"""Batch song generation: read tasks from a JSONL file and write audio files."""
import glob
import time
import torch
from codeclm.models.codeclm_gen import CodecLM_gen
from codeclm.models import builders
import sys
import os
import torchaudio
import numpy as np
import json
from vllm import LLM, SamplingParams
import re
import argparse
import librosa

auto_prompt_type = ['Pop', 'Latin', 'Rock', 'Electronic', 'Metal', 'Country', 'R&B/Soul', 'Ballad', 'Jazz', 'World', 'Hip-Hop', 'Funk', 'Soundtrack','Auto']

# Prompt audio is loaded at this sample rate and cut to this many seconds.
PROMPT_SAMPLE_RATE = 48000
PROMPT_MAX_SECONDS = 10

# Compiled once instead of on every call.
_CHINESE_PATTERN = re.compile(r'[\u4e00-\u9fff]')


def check_language_by_text(text: str) -> str:
    """Classify lyric text as Chinese ("zh") or English ("en").

    The text is "zh" when at least 20% of its characters fall in the
    U+4E00..U+9FFF range; every other input, including empty text, is "en".

    Args:
        text: The lyric text to classify.

    Returns:
        "zh" or "en".
    """
    # Empty text used to raise ZeroDivisionError; fall back to the default.
    if not text:
        return "en"
    chinese_ratio = len(_CHINESE_PATTERN.findall(text)) / len(text)
    # The original also measured the share of ASCII letters, but both of its
    # remaining branches returned "en", so only the Chinese ratio matters.
    return "zh" if chinese_ratio >= 0.2 else "en"


def load_audio(f) -> torch.Tensor:
    """Load an audio file and return at most its first 10 seconds at 48 kHz.

    Args:
        f: Path (or other source accepted by ``librosa.load``) of the audio.

    Returns:
        A tensor with a leading dimension added by ``unsqueeze(0)`` and the
        last dimension truncated to ``PROMPT_SAMPLE_RATE * PROMPT_MAX_SECONDS``
        samples.
    """
    a, fs = librosa.load(f, sr=PROMPT_SAMPLE_RATE)
    a = torch.tensor(a).unsqueeze(0)
    # librosa already resamples to the requested rate; this guard is kept as a
    # safety net in case the returned rate ever differs.
    if fs != PROMPT_SAMPLE_RATE:
        a = torchaudio.functional.resample(a, fs, PROMPT_SAMPLE_RATE)
    # A single slice is enough: slicing past the end is a no-op.
    return a[..., :PROMPT_SAMPLE_RATE * PROMPT_MAX_SECONDS]


def parse_args():
    """Parse the command-line arguments of the generation script.

    Returns:
        An ``argparse.Namespace`` with ``input_jsonl``, ``save_dir`` and
        ``config_path``.
    """
    parser = argparse.ArgumentParser(description='Song Generation Script')

    # Required arguments
    parser.add_argument('--input_jsonl', type=str, required=True,
                        help='Path to input JSONL file containing generation tasks')
    parser.add_argument('--save_dir', type=str, required=True,
                        help='Directory to save generated audio files and results')
    parser.add_argument('--config_path', type=str, required=True,
                        help='Path to the config file')
    return parser.parse_args()


def _prepare_prompt(item, audio_tokenizer, auto_prompt):
    """Build the prompt input for one task.

    When the task carries ``prompt_audio_path`` the audio is loaded, stored on
    the task as ``raw_pmt_wav`` and encoded with ``audio_tokenizer``. When it
    carries ``auto_prompt_audio_type`` an entry is drawn at random from
    ``auto_prompt``. Otherwise there is no prompt.

    Args:
        item: The task dict parsed from one JSONL line; may be mutated.
        audio_tokenizer: Model exposing ``encode``.
        auto_prompt: Mapping indexed as ``auto_prompt[type][lang]``.

    Returns:
        A ``(pmt_wav, melody_is_wav)`` pair.

    Raises:
        FileNotFoundError: If ``prompt_audio_path`` does not exist.
        ValueError: If both prompt kinds are given, the prompt type is
            unknown, or the loaded audio does not have 2 or 3 dimensions.
    """
    if "prompt_audio_path" in item:
        prompt_path = item['prompt_audio_path']
        # Explicit raises instead of asserts: asserts vanish under ``python -O``.
        if not os.path.exists(prompt_path):
            raise FileNotFoundError(f"prompt_audio_path {prompt_path} not found")
        if 'auto_prompt_audio_type' in item:
            raise ValueError("auto_prompt_audio_type and prompt_audio_path cannot be used together")
        with torch.no_grad():
            raw_wav = load_audio(prompt_path)
        # Kept on the task so the audio decoder can reuse the raw waveform.
        item['raw_pmt_wav'] = raw_wav
        pmt_wav = raw_wav[None] if raw_wav.dim() == 2 else raw_wav
        if pmt_wav.dim() != 3:
            raise ValueError("Melody wavs should have a shape [B, C, T].")
        with torch.no_grad():
            pmt_wav, _ = audio_tokenizer.encode(pmt_wav.cuda())
        print(pmt_wav.shape)
        return pmt_wav, False

    if "auto_prompt_audio_type" in item:
        prompt_type = item["auto_prompt_audio_type"]
        if prompt_type not in auto_prompt_type:
            raise ValueError(f"auto_prompt_audio_type {prompt_type} not found")
        lang = check_language_by_text(item['gt_lyric'])
        candidates = auto_prompt[prompt_type][lang]
        prompt_token = candidates[np.random.randint(0, len(candidates))]
        return prompt_token[:, [0], :], False

    return None, True


def main():
    """Generate one audio file per task in the input JSONL.

    Loads the config, the audio tokenizer, the conditioners and the vLLM
    engine, then for every task whose target ``.flac`` does not exist yet:
    builds the conditioning, samples tokens, converts them to audio and saves
    the result. Tasks generated in this run are written to
    ``<save_dir>/jsonl/<input name>.jsonl``.
    """
    torch.set_num_threads(1)
    # cuDNN is disabled because some Taiji nodes raise odd errors with it on.
    torch.backends.cudnn.enabled = False
    from omegaconf import OmegaConf
    # NOTE(review): this resolver runs eval() on config text; only load
    # config files you trust.
    OmegaConf.register_new_resolver("eval", lambda x: eval(x))
    OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
    OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0])
    OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))

    args = parse_args()
    input_jsonl = args.input_jsonl
    save_dir = args.save_dir
    cfg_path = args.config_path

    cfg = OmegaConf.load(cfg_path)
    cfg.mode = 'inference'
    max_duration = cfg.max_dur

    audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
    if audio_tokenizer is not None:
        for param in audio_tokenizer.parameters():
            param.requires_grad = False
    print("Audio tokenizer successfully loaded!")
    audio_tokenizer = audio_tokenizer.eval().cuda()

    model_condition = CodecLM_gen(cfg=cfg,name = "tmp",audiotokenizer = audio_tokenizer,max_duration = max_duration)
    model_condition.condition_provider.conditioners.load_state_dict(torch.load(cfg.lm_checkpoint+"/conditioners_weights.pth"))
    print('Conditioner successfully loaded!')

    llm = LLM(
        model=cfg.lm_checkpoint,
        trust_remote_code=True,
        tensor_parallel_size=cfg.vllm.device_num,
        enforce_eager=False,
        dtype="bfloat16",
        gpu_memory_utilization=cfg.vllm.gpu_memory_utilization,
        tokenizer=None,
        skip_tokenizer_init=True,
        enable_prompt_embeds=True,
        enable_chunked_prefill=True,
    )
    print("LLM 初始化成功")

    auto_prompt = torch.load('tools/new_prompt.pt')
    guidance_scale = cfg.vllm.guidance_scale
    temp = cfg.vllm.temp
    top_k = cfg.vllm.top_k
    sum_time = 0
    sum_wav_len = 0

    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(save_dir + "/audios", exist_ok=True)
    os.makedirs(save_dir + "/jsonl", exist_ok=True)

    # Read as UTF-8 to match the encoding used for the output JSONL below.
    with open(input_jsonl, "r", encoding='utf-8') as fp:
        lines = fp.readlines()

    new_items = []
    for line in lines:
        # A blank (e.g. trailing) line is not a task and would break json.loads.
        if not line.strip():
            continue
        item = json.loads(line)
        lyric = item["gt_lyric"]
        descriptions = item["descriptions"].lower() if "descriptions" in item else '.'
        descriptions = '[Musicality-very-high]' + ', ' + descriptions

        target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
        # Already generated in a previous run: skip so the job can be resumed.
        if os.path.exists(target_wav_name):
            continue

        pmt_wav, melody_is_wav = _prepare_prompt(item, audio_tokenizer, auto_prompt)

        item["idx"] = f"{item['idx']}"
        item["wav_path"] = target_wav_name

        # NOTE(review): reloaded for every task; could be hoisted out of the
        # loop if generate_condition does not modify it in place - confirm.
        embeded_eosp1 = torch.load(cfg.lm_checkpoint+'/embeded_eosp1.pt')
        generate_inp = {
            # NOTE(review): as written this replace is a no-op; it may have
            # been meant to collapse double spaces - confirm against upstream.
            'descriptions': [lyric.replace(" ", " ")],
            'type_info': [descriptions],
            'melody_wavs': pmt_wav,
            'melody_is_wav': melody_is_wav,
            'embeded_eosp1': embeded_eosp1,
        }
        fused_input, audio_qt_embs = model_condition.generate_condition(**generate_inp, return_tokens=True)

        # Token ids that occur in the prompt are excluded from sampling.
        prompt_token = audio_qt_embs[0][0].tolist() if audio_qt_embs else []
        allowed_token_ids = [x for x in range(cfg.lm.code_size+1) if x not in prompt_token]
        sampling_params = SamplingParams(
            max_tokens=cfg.audio_tokenizer_frame_rate*cfg.max_dur,
            temperature=temp,
            stop_token_ids=[cfg.lm.code_size],
            top_k=top_k,
            frequency_penalty=0.2,
            seed=int(time.time() * 1000000) % (2**32) if cfg.vllm.cfg else -1,
            allowed_token_ids=allowed_token_ids,
            guidance_scale=guidance_scale
        )

        # Split into the currently supported batch-of-3 CFG layout: every
        # second entry of the doubled prompt list, followed by prompts[1]
        # (presumably the unconditional prompt, going by its name).
        prompts = [{"prompt_embeds": embed} for embed in fused_input]
        uncondi = prompts[1]
        promptss = (prompts + prompts)[::2] + [uncondi]

        start_time = time.time()
        outputs = llm.generate(promptss, sampling_params=sampling_params)
        mid_time = time.time()

        token_ids_CFG = torch.tensor(outputs[1].outputs[0].token_ids)
        # Drop the last token (presumably the stop token) and add two leading
        # dimensions.
        token_ids_CFG = token_ids_CFG[:-1].unsqueeze(0).unsqueeze(0)

        with torch.no_grad():
            if 'raw_pmt_wav' in item:
                wav_cfg = model_condition.generate_audio(token_ids_CFG, item['raw_pmt_wav'])
                # The tensor is not JSON-serializable; remove it before the
                # task is written to the result file.
                del item['raw_pmt_wav']
            else:
                wav_cfg = model_condition.generate_audio(token_ids_CFG)
        end_time = time.time()

        torchaudio.save(target_wav_name, wav_cfg[0].cpu().float(), cfg.sample_rate)

        # NOTE(review): 25 is a hard-coded tokens-per-second figure, while
        # max_tokens above uses cfg.audio_tokenizer_frame_rate - confirm they
        # agree.
        sum_time += end_time - start_time
        sum_wav_len += (token_ids_CFG.shape[-1] / 25)
        print(f"process{item['idx']}, lm cost {mid_time - start_time}s, diffusion cost {end_time - mid_time}, rtf {(end_time - start_time) / token_ids_CFG.shape[-1] * 25:.2f}")
        new_items.append(item)

    if not new_items:
        # Every target already existed (or the input was empty). Previously
        # this path crashed with ZeroDivisionError in the summary below,
        # before the result file was touched; keep the file untouched.
        print("No new items were generated; existing results are left untouched.")
        return

    total_rtf = sum_time / sum_wav_len if sum_wav_len > 0 else float("nan")
    print(f"Total time: {sum_time:.4f} seconds, total wav length: {sum_wav_len:.4f} seconds, rtf {total_rtf:.2f}")

    src_jsonl_name = os.path.split(input_jsonl)[-1]
    # NOTE(review): the input name keeps its own extension, so the result is
    # named like "tasks.jsonl.jsonl"; left as-is for downstream compatibility.
    with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
        for item in new_items:
            fw.write(json.dumps(item, ensure_ascii=False)+"\n")


if __name__ == "__main__":
    main()