import os

import librosa
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration

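# Instruction template for prompt enhancement: given a short user prompt and an
# audio clip, Qwen2-Audio is asked to produce a detailed, generation-ready description.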
QWEN_AUDIO_PREFIX = '''Given a user prompt and an audio clip, generate an "Enhanced prompt" that provides detailed descriptions suitable for audio generation. Evaluate the audio and user prompt:
- If the prompt is simple, focus on adding specifics about tones, instruments, rhythms, tempos, and audio characteristics to create vivid and concrete audio descriptions.
- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n
Here are examples of how to transform or refine prompts:
- User Prompt: Piano music -> Enhanced: A gentle, melancholic piano piece with delicate arpeggios in a minor key, featuring subtle reverb that creates a sense of space and intimacy.
- User Prompt: City sounds -> Enhanced: A bustling urban soundscape with distant traffic noise, occasional car horns, footsteps on concrete sidewalks, and the murmur of crowd conversations, with subtle pigeons cooing in the background.\n
Please generate only the enhanced description for the audio and prompt below and avoid including any additional commentary or evaluations:
User Prompt:'''
|
|
class Qwen2AudioEmbedder(nn.Module):
    def __init__(self, model_path, embed_dim=256, max_length=320, dtype=torch.float, device="cuda"):
        super().__init__()
        self.max_length = max_length
        self.device = device
        self.embed_dim = embed_dim

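        # Load the frozen Qwen2-Audio backbone; device_map pins the weights to
        # this process's GPU (LOCAL_RANK) under distributed launchers.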
        self.model = Qwen2AudioForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=dtype,
            device_map={"": int(os.environ.get("LOCAL_RANK", 0))},
        )
        self.model.requires_grad_(False)
        self.model.eval()
        self.processor = AutoProcessor.from_pretrained(model_path)

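        # Project the backbone's last hidden states (4096-dim here) down to
        # embed_dim; this linear layer is the only trainable part of the module.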
        self.proj = nn.Linear(4096, embed_dim, device=device, dtype=dtype)
        # Enhancement instruction; stored for reuse but not applied automatically in encode().
        self.prefix = QWEN_AUDIO_PREFIX
|
|
    def forward(self, text, audio_data):
        """
        Args:
            text: list of text prompts.
            audio_data: list of audio clips, each a numpy array (or torch tensor).
        Returns:
            dict with "output" (projected embedding tensor) and "mask" (boolean mask tensor).
        """
        output, mask = self.encode(text, audio_data)
        output = self.projection(output)
        return {"output": output, "mask": mask}
|
|
    def encode(self, text, audio_data):
        """Encode text and audio into the embedding space."""

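        # The Whisper-style feature extractor expects 16 kHz audio, so resample
        # each 24 kHz clip before feature extraction.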
        processed_audios = []
        for audio in audio_data:
            if isinstance(audio, torch.Tensor):
                audio = audio.cpu().numpy()
            audio = librosa.resample(audio, orig_sr=24000, target_sr=16000)
            processed_audios.append(audio)

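        # Render each prompt through the chat template; the "audio" entry marks
        # where the template inserts the audio placeholder tokens.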
        conversations = []
        for txt in text:
            conversation = [
                {"role": "user", "content": [
                    {"type": "audio", "audio": None},
                    {"type": "text", "text": txt},
                ]},
            ]
            formatted_text = self.processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
            conversations.append(formatted_text)
|
|
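        # The backbone is frozen, so run tokenization and the forward pass under
        # no_grad; the trainable projection is applied later, outside encode().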
        with torch.no_grad():
            inputs = self.processor(
                text=conversations,
                audio=processed_audios,  # note: some transformers versions name this kwarg `audios`
                return_tensors="pt",
                sampling_rate=16000,
                padding=True,
                truncation=True,
            )
|
|
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
|
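            # One forward pass through the multimodal LM; output_hidden_states
            # exposes per-layer states so the last layer can serve as the embedding.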
            outputs = self.model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                input_features=inputs["input_features"],
                feature_attention_mask=inputs["feature_attention_mask"],
                output_hidden_states=True,
            )
|
|
            hidden_states_full = outputs.hidden_states[-1]

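        # Fix the sequence dimension at max_length: truncate long sequences,
        # zero-pad short ones.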
        current_len = hidden_states_full.shape[1]
        if current_len > self.max_length:
            embs = hidden_states_full[:, :self.max_length, :]
        else:
            pad_width = self.max_length - current_len
            padding = torch.zeros(
                hidden_states_full.shape[0],
                pad_width,
                hidden_states_full.shape[2],
                device=self.device,
                dtype=hidden_states_full.dtype,
            )
            embs = torch.cat([hidden_states_full, padding], dim=1)
|
|
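        # Apply the same truncate/pad policy to the attention mask so padded
        # positions stay masked out (False).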
        attention_mask = inputs["attention_mask"]
        if current_len > self.max_length:
            masks = attention_mask[:, :self.max_length].bool()
        else:
            pad_width = self.max_length - current_len
            mask_padding = torch.zeros(
                attention_mask.shape[0],
                pad_width,
                device=self.device,
                dtype=torch.bool,
            )
            masks = torch.cat([attention_mask.bool(), mask_padding], dim=1)

        return embs, masks
|
|
    def projection(self, x):
        """Project embeddings to the target dimension."""
        return self.proj(x)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Test Qwen Audio Encoder")
    parser.add_argument("--model_path", type=str, default="/mnt/petrelfs/taoye/workspace/model/qwen25audio",
                        help="Path to Qwen Audio model")
    parser.add_argument("--embed_dim", type=int, default=4096,
                        help="Target embedding dimension after projection")
    args = parser.parse_args()

    print(f"Loading model from {args.model_path}...")
| |
| |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
| embedder = Qwen2AudioEmbedder( |
| model_path=args.model_path, |
| embed_dim=args.embed_dim, |
| max_length=640, |
| dtype=torch.float, |
| device=device |
| ) |
| |
| |
    captions = [
        "Describe this audio",
        "What musical instruments are being played in this recording?",
    ]
| |
| |
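    # Load the same test clip for every caption; sr=24000 matches the sample
    # rate encode() expects before its internal resample to 16 kHz.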
    audio_path = "/mnt/petrelfs/taoye/workspace/editing/data/add/add_fore_audio_caps_begin_1/audio/edit/syn_5.wav"
    audio_data = []
    for _ in range(len(captions)):
        waveform, sr = librosa.load(audio_path, sr=24000)
        audio_data.append(waveform)
| |
| |
    with torch.no_grad():
        output = embedder(captions, audio_data)
| |
| print("模型输出的字典:") |
| print(f"包含keys: {list(output.keys())}") |
| |
| print("\n输出张量的形状:") |
| print(output['output'].shape) |
| |
| print("\n掩码张量的形状:") |
| print(output['mask'].shape) |
| |
| |
| assert output['output'].shape[-1] == args.embed_dim, f"输出维度 {output['output'].shape[-1]} 不等于预期维度 {args.embed_dim}" |
| print(f"\n成功验证:输出维度 = {args.embed_dim}") |
| |
| |
| print(f"样本嵌入值:\n{output['output'][0, :5, :5]}") |
| print(f"非零掩码位置数量: {output['mask'][0,:]}") |
| |
| print(f"第一个样本的非零掩码位置数量: {output['mask'][0].sum().item()}") |
|
|
|
|
|
|