| import json
|
| import logging
|
| import numpy as np
|
|
|
| import torchaudio
|
| from torch.utils.data import Dataset
|
|
|
|
|
def _handle_wav(wav_path, target_rate=16000):
    """Load one audio file and return its first channel at ``target_rate``.

    Args:
        wav_path: Path of the audio file, passed straight to
            ``torchaudio.load``.
        target_rate: Desired sampling rate in Hz. The audio is resampled only
            when the file's own rate differs from this value.

    Returns:
        The first channel of the (possibly resampled) waveform as a torch
        tensor -- not a numpy array; callers convert it with ``.numpy()``.
        Any further channels in the file are discarded.
    """
    waveform, sample_rate = torchaudio.load(wav_path)

    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate,
            new_freq=target_rate,
        )
        waveform = resampler(waveform)

    # Keep channel 0 only; multi-channel input is not down-mixed.
    return waveform[0]
|
|
|
|
|
def _handle_dialogue_evaluation(obj, sample_rate=16000):
    """Build one dialogue-evaluation sample from a raw annotation record.

    The user recording and the robot recording are loaded at ``sample_rate``
    and joined into a single signal with 0.5 s of silence between them (the
    prompt text tells the model about this gap, so the two must stay in sync).

    Args:
        obj: Mapping providing the keys ``"id"``, ``"user_wav"``,
            ``"robot_wav"`` and ``"gt_score"``.
        sample_rate: Sampling rate in Hz used for both recordings and for
            sizing the silent gap.

    Returns:
        A dict with the keys ``"id"``, ``"prompt"`` (a single user turn with
        an audio entry and the instruction text), ``"solution"`` (reference
        answer embedding the ground-truth score), ``"audio"`` (the combined
        1-D numpy signal) and ``"gt_score"``.
    """
    user_audio = _handle_wav(obj["user_wav"], sample_rate).numpy()
    robot_audio = _handle_wav(obj["robot_wav"], sample_rate).numpy()

    # np.zeros defaults to float64, which would silently upcast the whole
    # concatenated signal; build the gap with the audio's own dtype instead.
    silence = np.zeros(
        int(sample_rate * 0.5),
        dtype=np.result_type(user_audio, robot_audio),
    )

    combined_audio = np.concatenate([user_audio, silence, robot_audio])

    prompt_template = (
        "上面有一段对话,用户先说话,中间隔0.5s,机器人回答。请评价上述对话中机器人回答的合理性。"
        "考虑机器人回答的情感、语气、内容等方面。"
        "首先在<think></think>标签中详细分析,然后在<score></score>标签中给出1-10的评分。"
    )

    processed_obj = {
        "id": obj["id"],
        "prompt": [{"role": "user", "content": [
            # The string "combined_audio" is a fixed placeholder, not a path;
            # the samples themselves are carried in the "audio" field below.
            {"type": "audio", "audio_url": "combined_audio"},
            {"type": "text", "text": prompt_template},
        ]}],
        "solution": f"<think>分析机器人回答的合理性</think><score>{obj['gt_score']}</score>",
        "audio": combined_audio,
        "gt_score": obj["gt_score"],
    }

    return processed_obj
|
|
|
|
|
class AudioDataset(Dataset):
    """Dataset of dialogue-evaluation samples, fully built at construction."""

    def __init__(self, data_file, sample_rate=16000, is_perturb=False):
        """Read ``data_file`` and preprocess every record it contains.

        Args:
            data_file: Path to a UTF-8 JSON file whose top-level value is
                iterated record by record.
            sample_rate: Sampling rate in Hz forwarded to the per-record
                preprocessing.
            is_perturb: Flag stored on the instance; it is not read anywhere
                in this class.
        """
        super().__init__()
        self.sample_rate = sample_rate
        self.is_perturb = is_perturb

        with open(data_file, 'r', encoding='utf8') as f:
            records = json.load(f)

        # All audio is loaded and combined up front, so indexing later is a
        # plain list lookup.
        self.data = [
            _handle_dialogue_evaluation(record, sample_rate)
            for record in records
        ]

        logging.info(f"加载数据集: {data_file}, 样本数: {len(self.data)}, 采样率: {sample_rate}")

    def __len__(self):
        """Return the number of preprocessed samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the preprocessed sample stored at ``index``."""
        return self.data[index]