| from typing import Any, Callable, Dict |
| import random |
| import lightning.pytorch as pl |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.optim.lr_scheduler import LambdaLR |
|
|
|
|
class AudioSep(pl.LightningModule):
    def __init__(
        self,
        ss_model: nn.Module,
        waveform_mixer,
        query_encoder,
        loss_function,
        optimizer_type: str,
        learning_rate: float,
        lr_lambda_func,
        use_text_ratio: float = 1.0,
    ):
        r"""PyTorch Lightning wrapper around a source-separation model.

        Handles the training step (on-the-fly mixture creation, query
        encoding, loss computation) and optimizer/scheduler configuration.

        Args:
            ss_model: separation network; called as ``ss_model(input_dict)``
                and expected to return a dict with a ``'waveform'`` entry.
            waveform_mixer: callable that turns a batch of clean waveforms
                into ``(mixtures, segments)`` pairs.
            query_encoder: query embedding model; must expose
                ``encoder_type`` and ``get_query_embed(...)``.
            loss_function: callable ``loss_function(output_dict, target_dict)``.
            optimizer_type: currently only ``"AdamW"`` is supported.
            learning_rate: base learning rate for the optimizer.
            lr_lambda_func: multiplicative LR schedule passed to ``LambdaLR``.
            use_text_ratio: probability of using the text (vs. audio) query
                embedding when the encoder operates in hybrid mode.
        """
        super().__init__()
        self.ss_model = ss_model
        self.waveform_mixer = waveform_mixer
        self.query_encoder = query_encoder
        self.query_encoder_type = self.query_encoder.encoder_type
        self.use_text_ratio = use_text_ratio
        self.loss_function = loss_function
        self.optimizer_type = optimizer_type
        self.learning_rate = learning_rate
        self.lr_lambda_func = lr_lambda_func

    def forward(self, x):
        # Inference is performed elsewhere (e.g. dedicated separation
        # utilities); the LightningModule forward is intentionally a no-op.
        pass

    def training_step(self, batch_data_dict, batch_idx):
        r"""Run one optimization step on a mini-batch.

        Args:
            batch_data_dict: e.g.
                {'audio_text': {
                    'text': ['a sound of dog', ...],
                    'waveform': (batch_size, 1, samples),
                }}
            batch_idx: int

        Returns:
            loss: scalar tensor, training loss for this mini-batch.

        Raises:
            NotImplementedError: if the query encoder type is unsupported.
        """
        # NOTE(review): reseeds the *global* RNG every step, presumably so
        # that waveform mixing is deterministic per batch index — confirm
        # this is intentional, as it also affects any other `random` users.
        random.seed(batch_idx)

        batch_audio_text_dict = batch_data_dict['audio_text']
        batch_text = batch_audio_text_dict['text']
        batch_audio = batch_audio_text_dict['waveform']

        # Create synthetic mixtures and their target source segments.
        mixtures, segments = self.waveform_mixer(waveforms=batch_audio)

        if self.query_encoder_type == 'CLAP':
            conditions = self.query_encoder.get_query_embed(
                # 'hybird' [sic] is the key the encoder expects; do not "fix"
                # the spelling here without changing the encoder too.
                modality='hybird',
                text=batch_text,
                audio=segments.squeeze(1),
                use_text_ratio=self.use_text_ratio,
            )
        else:
            # Previously `conditions` was silently left unbound here,
            # producing a confusing NameError below. Fail explicitly.
            raise NotImplementedError(
                f"Unsupported query encoder type: {self.query_encoder_type!r}"
            )

        input_dict = {
            # Original code did `mixtures[:, None, :].squeeze(1)`, which
            # inserts axis 1 and immediately removes it — an identity op.
            'mixture': mixtures,
            'condition': conditions,
        }

        target_dict = {
            'segment': segments.squeeze(1),
        }

        self.ss_model.train()
        sep_segment = self.ss_model(input_dict)['waveform']
        sep_segment = sep_segment.squeeze()

        output_dict = {
            'segment': sep_segment,
        }

        loss = self.loss_function(output_dict, target_dict)
        self.log_dict({"train_loss": loss})

        return loss

    def test_step(self, batch, batch_idx):
        # Evaluation is handled by external benchmark scripts; no-op here.
        pass

    def configure_optimizers(self):
        r"""Configure the optimizer and per-step LR scheduler.

        Returns:
            Dict with ``optimizer`` and a Lightning ``lr_scheduler`` config
            (stepped every training step).

        Raises:
            NotImplementedError: if ``optimizer_type`` is not ``"AdamW"``.
        """
        if self.optimizer_type == "AdamW":
            optimizer = optim.AdamW(
                params=self.ss_model.parameters(),
                lr=self.learning_rate,
                betas=(0.9, 0.999),
                eps=1e-08,
                weight_decay=0.0,
                amsgrad=True,
            )
        else:
            raise NotImplementedError(
                f"Unsupported optimizer type: {self.optimizer_type!r}"
            )

        scheduler = LambdaLR(optimizer, self.lr_lambda_func)

        output_dict = {
            "optimizer": optimizer,
            "lr_scheduler": {
                'scheduler': scheduler,
                'interval': 'step',   # step the schedule every batch, not epoch
                'frequency': 1,
            }
        }

        return output_dict
| |
|
|
def get_model_class(model_type: str):
    """Return the separation-model class registered under ``model_type``.

    The import is deferred into the matching branch so that unused model
    backends are never loaded.

    Args:
        model_type: architecture name, e.g. ``'ResUNet30'``.

    Returns:
        The model class (not an instance).

    Raises:
        NotImplementedError: if ``model_type`` is not a known architecture.
    """
    if model_type == 'ResUNet30':
        from models.resunet import ResUNet30
        return ResUNet30

    # Include the offending value so misconfigurations are easy to diagnose.
    raise NotImplementedError(f"Unknown model_type: {model_type!r}")
|
|