| from typing import Any |
| import torch |
| import torch.nn as nn |
|
|
|
|
| class ContentEncoder(nn.Module): |
| def __init__( |
| self, |
| embed_dim: int, |
| text_encoder: nn.Module = None, |
| llm_encoder: nn.Module = None, |
| video_encoder: nn.Module = None, |
| midi_encoder: nn.Module = None, |
| phoneme_encoder: nn.Module = None, |
| pitch_encoder: nn.Module = None, |
| audio_encoder: nn.Module = None |
| ): |
| super().__init__() |
| self.embed_dim = embed_dim |
| self.text_encoder = text_encoder |
| self.midi_encoder = midi_encoder |
| self.phoneme_encoder = phoneme_encoder |
| self.pitch_encoder = pitch_encoder |
| self.audio_encoder = audio_encoder |
| self.video_encoder = video_encoder |
|
|
| def encode_content( |
| self, batch_content: list[Any], batch_task: list[str], |
| device: str | torch.device |
| ): |
| batch_content_output = [] |
| batch_content_mask = [] |
| batch_la_content_output = [] |
| batch_la_content_output_mask = [] |
| zero_la_content = torch.zeros(1, 1, self.embed_dim, device=device) |
| |
| for i,(content, task) in enumerate(zip(batch_content, batch_task)): |
| if task == "audio_editing": |
| raw_waveform = torch.as_tensor(content["audio"]).float() |
| waveform_with_batch_dim = raw_waveform.unsqueeze(0).to(device) |
| waveform_lengths = torch.as_tensor([raw_waveform.shape[0]]) |
| |
| |
| content_output_dict = self.text_encoder( |
| [content["caption"]], waveform_with_batch_dim |
| ) |
| audio_dict = { |
| "waveform": waveform_with_batch_dim, |
| "waveform_lengths": waveform_lengths |
| } |
| audio_output_dict = self.audio_encoder(**audio_dict) |
| la_content_output_dict = { |
| "output": audio_output_dict["output"], |
| "mask": audio_output_dict["mask"] |
| } |
|
|
| batch_content_output.append(content_output_dict["output"][0]) |
| batch_content_mask.append(content_output_dict["mask"][0]) |
| batch_la_content_output.append(la_content_output_dict["output"][0]) |
| batch_la_content_output_mask.append( |
| la_content_output_dict.get("mask", zero_la_content)[0] |
| ) |
|
|
| batch_content_output = nn.utils.rnn.pad_sequence( |
| batch_content_output, batch_first=True, padding_value=0 |
| ) |
| batch_content_mask = nn.utils.rnn.pad_sequence( |
| batch_content_mask, batch_first=True, padding_value=False |
| ) |
| batch_la_content_output = nn.utils.rnn.pad_sequence( |
| batch_la_content_output, batch_first=True, padding_value=0 |
| ) |
|
|
| batch_la_content_output_mask = nn.utils.rnn.pad_sequence( |
| batch_la_content_output_mask, batch_first=True, padding_value=False |
| ) |
| return { |
| "content": batch_content_output , |
| "content_mask": batch_content_mask, |
| "length_aligned_content": batch_la_content_output, |
| "time_aligned_content_mask": batch_la_content_output_mask |
| } |
|
|
|
|
|
|
class BatchedContentEncoder(ContentEncoder):
    """ContentEncoder variant that encodes all captions in one batched
    text-encoder call instead of one call per item.

    The audio encoder is still invoked per item (batching it would require
    padding the waveforms up front). Only the ``"audio_editing"`` task is
    supported.
    """

    def encode_content(
        self, batch_content: list[dict], batch_task: list[str],
        device: str | torch.device
    ) -> dict[str, torch.Tensor]:
        """Encode a batch of audio-editing content dicts.

        Args:
            batch_content: one dict per item with ``"audio"`` (1-D waveform,
                array-like) and ``"caption"`` (str).
            batch_task: task name per item; every entry must be
                ``"audio_editing"``.
            device: target device for the produced tensors.

        Returns:
            Same key layout as ``ContentEncoder.encode_content``.

        Raises:
            ValueError: if any task is not ``"audio_editing"``.
        """
        # Explicit validation instead of ``assert`` — asserts are stripped
        # under ``python -O``. (The previous assert message was also
        # garbled: "now are only support".)
        if any(task != "audio_editing" for task in batch_task):
            raise ValueError(
                "BatchedContentEncoder currently only supports the "
                "'audio_editing' task"
            )

        # Gather captions and waveforms so the text encoder runs once over
        # the whole batch. (The unused ``zero_la_content`` tensor from the
        # original implementation has been removed.)
        captions = []
        waveforms = []
        waveform_lengths = []
        for content in batch_content:
            raw_waveform = torch.as_tensor(content["audio"]).float().to(device)
            captions.append(content["caption"])
            waveforms.append(raw_waveform)
            waveform_lengths.append(raw_waveform.shape[0])

        # Single batched call; the list of variable-length waveforms is
        # passed through as extra context.
        content_output_dict = self.text_encoder(captions, waveforms)

        # Frame-aligned audio features are produced per item and padded
        # into one batch below.
        batch_la_content_output = []
        batch_la_content_output_mask = []
        for waveform, length in zip(waveforms, waveform_lengths):
            audio_output_dict = self.audio_encoder(
                waveform=waveform.unsqueeze(0),
                waveform_lengths=torch.as_tensor([length], device=device)
            )
            batch_la_content_output.append(audio_output_dict["output"][0])
            batch_la_content_output_mask.append(audio_output_dict["mask"][0])

        batch_la_content_output = nn.utils.rnn.pad_sequence(
            batch_la_content_output, batch_first=True, padding_value=0
        )
        batch_la_content_output_mask = nn.utils.rnn.pad_sequence(
            batch_la_content_output_mask, batch_first=True, padding_value=False
        )

        return {
            "content": content_output_dict["output"],
            "content_mask": content_output_dict["mask"],
            "length_aligned_content": batch_la_content_output,
            "time_aligned_content_mask": batch_la_content_output_mask
        }
|
|