Spaces:
Sleeping
Sleeping
| import re | |
| import torch | |
| from torch.nn.functional import ctc_loss, log_softmax | |
| from torch.optim import RAdam | |
| from pytorch_lightning import LightningModule, Trainer | |
| from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint | |
| from torchmetrics.functional import word_error_rate, char_error_rate | |
| import sentencepiece as spm | |
| from nemo.collections.asr.modules import ( | |
| AudioToMelSpectrogramPreprocessor, | |
| SpectrogramAugmentation, | |
| SqueezeformerEncoder, | |
| ConvASRDecoder, | |
| ) | |
| from nemo.core import typecheck | |
| from torchaudio.models.decoder import ctc_decoder | |
# NeMo neural-type checking is disabled process-wide.
# NOTE(review): presumably so the NeMo modules below can be called with
# positional arguments -- confirm before re-enabling.
typecheck.set_typecheck_enabled(False)

# -------------------------
# Tokenizer
# -------------------------
# SentencePiece model used to turn target token ids back into text.
sp = spm.SentencePieceProcessor()
sp.Load("ressources/tokenizer/128_v7.model")

# -------------------------
# CTC Decoder
# -------------------------
# Token list consumed by the torchaudio CTC decoder.
tokens_file = "ressources/tokenizer/128_v7.txt"

# Every beam-related setting is 1: the decoder keeps a single hypothesis.
_single_beam = {
    "beam_size": 1,
    "beam_threshold": 1,
    "beam_size_token": 1,
    "nbest": 1,
}

# Lexicon-free decoder; "_" is the blank symbol and "|" the silence symbol.
decoder = ctc_decoder(
    lexicon=None,
    tokens=tokens_file,
    log_add=True,
    blank_token="_",
    sil_token="|",
    unk_word="<unk>",
    **_single_beam,
)

# -------------------------
# Hyperparameters
# -------------------------
# Default learning rate for MySqueezeformer.
LR = 2e-4

# Running count of training steps skipped because the loss was NaN/inf
# (incremented in MySqueezeformer.training_step).
NONE_COUNT = 0
| # ------------------------- | |
| # LightningModule | |
| # ------------------------- | |
class MySqueezeformer(LightningModule):
    """Squeezeformer acoustic model trained with CTC loss.

    Pipeline: mel-spectrogram preprocessing -> spectrogram augmentation
    (training mode only) -> Squeezeformer encoder -> convolutional decoder
    head with 128 output classes.

    Args:
        LR: Learning rate passed to the RAdam optimizer.
    """

    def __init__(self, LR=LR):
        super().__init__()
        self.LR = LR
        self.processor = AudioToMelSpectrogramPreprocessor(
            sample_rate=16000,
            features=80,
            n_fft=512,
            window_size=0.025,
            window_stride=0.01,
            log=True,
            frame_splicing=True,
        )
        # NOTE(review): positional arguments; presumably
        # (freq_masks, time_masks, freq_width, time_width) -- confirm
        # against the installed NeMo version.
        self.augmentation = SpectrogramAugmentation(2, 5, 27, 0.05)
        self.encoder = SqueezeformerEncoder(
            feat_in=80,
            feat_out=-1,
            n_layers=16,
            d_model=144,
            adaptive_scale=True,
            time_reduce_idx=7,
            dropout_emb=0,
            dropout_att=0.1,
            subsampling_factor=4,
        )
        self.decoder = ConvASRDecoder(feat_in=144, num_classes=128)

    # -------------------------
    # Forward
    # -------------------------
    def forward(self, x, lengths):
        """Run the full model on a batch.

        Args:
            x: Batch of inputs handed to the preprocessor.
            lengths: Per-sample input lengths matching ``x``.

        Returns:
            Tuple ``(decoded, logits_lengths)``: the decoder output and the
            per-sample number of valid output frames.
        """
        spec, lengths = self.processor(x, lengths)
        if self.training:
            spec = self.augmentation(spec, lengths)
        encoded = self.encoder(spec, lengths)
        decoded = self.decoder(encoded[0])
        # Use the lengths reported by the encoder rather than the padded
        # frame count: counting frames of the padded output gives every
        # sample the same length and makes CTC score padding frames.
        logits_lengths = encoded[1]
        return decoded, logits_lengths

    # -------------------------
    # Helpers
    # -------------------------
    @staticmethod
    def _as_batch_tensor(outputs):
        """Return ``outputs`` as one batch-first tensor.

        Accepts either a single tensor or a sequence of per-sample tensors;
        ``torch.stack`` itself rejects a bare tensor.
        """
        if torch.is_tensor(outputs):
            return outputs
        return torch.stack(list(outputs))

    def _ctc_loss(self, outputs, targets, logits_lengths, target_lengths):
        """Compute the CTC loss shared by the training and validation steps."""
        # ctc_loss expects time-major input, hence the transpose.
        log_probs = log_softmax(
            self._as_batch_tensor(outputs).transpose(0, 1), dim=2
        )
        # NOTE(review): blank=1 must be the index of the "_" blank token in
        # the decoder's token file -- confirm.
        return ctc_loss(
            log_probs,
            targets,
            logits_lengths,
            target_lengths,
            blank=1,
            zero_infinity=True,
        )

    @staticmethod
    def _tokens_to_text(tokens):
        """Turn decoded token strings into a cleaned transcription."""
        # Drop the blank ("_") and silence ("|") symbols.
        text = "".join(tokens).replace("_", "").replace("|", "")
        # The word-boundary marker "▁" becomes a space.
        text = text.replace("▁", " ")
        # Collapse runs of hyphens into one.
        text = re.sub(r"-{2,}", "-", text)
        return text.strip()

    # -------------------------
    # Training Step
    # -------------------------
    def training_step(self, batch, batch_idx):
        """Compute and log the CTC training loss for one batch.

        Returns:
            The loss tensor, or ``None`` when the loss is NaN/inf; in that
            case the module-level ``NONE_COUNT`` is incremented and logged
            as ``N_c``.
        """
        global NONE_COUNT
        spectrograms, transcriptions, specs_lengths, transcriptions_lengths = batch
        outputs, logits_lengths = self(spectrograms, specs_lengths)
        loss = self._ctc_loss(
            outputs, transcriptions, logits_lengths, transcriptions_lengths
        )
        if not torch.isfinite(loss):
            NONE_COUNT += 1
            self.log("N_c", float(NONE_COUNT), prog_bar=True, sync_dist=True)
            return None
        self.log("loss", loss, sync_dist=True, on_epoch=True, on_step=False)
        return loss

    # -------------------------
    # Validation Step
    # -------------------------
    def validation_step(self, batch, batch_idx):
        """Decode one validation batch and log ``val_loss``, ``wer``, ``cer``."""
        spectrograms, transcriptions, specs_lengths, transcriptions_lengths = batch
        outputs, logits_lengths = self(spectrograms, specs_lengths)
        outputs = self._as_batch_tensor(outputs)

        # Reference texts: decode each target up to its true length.
        all_targets = [
            sp.Decode(tgt[: int(tgt_len)].tolist())
            for tgt, tgt_len in zip(transcriptions, transcriptions_lengths)
        ]

        # The torchaudio decoder needs float32, contiguous, CPU emissions.
        # Passing the real lengths keeps padding frames out of the
        # hypotheses.
        emissions = outputs.detach().float().cpu().contiguous()
        frame_counts = logits_lengths.detach().cpu()
        hypotheses = decoder(emissions, frame_counts)
        all_transcriptions = [
            self._tokens_to_text(decoder.idxs_to_tokens(best[0].tokens))
            for best in hypotheses
        ]

        wer = word_error_rate(all_transcriptions, all_targets)
        cer = char_error_rate(all_transcriptions, all_targets)

        # CTC loss, computed for logging only.
        val_loss = self._ctc_loss(
            outputs, transcriptions, logits_lengths, transcriptions_lengths
        )
        self.log("val_loss", val_loss, sync_dist=True, on_epoch=True)
        self.log("wer", wer, prog_bar=True, sync_dist=True, on_epoch=True)
        self.log("cer", cer, sync_dist=True, on_epoch=True)

    # -------------------------
    # Optimizer
    # -------------------------
    def configure_optimizers(self):
        """Return the RAdam optimizer over all parameters, using ``self.LR``."""
        return RAdam(
            self.parameters(),
            lr=self.LR,
            betas=(0.9, 0.98),
            weight_decay=1e-6,
            eps=1e-9,
        )
| # ------------------------- | |
| # Training | |
| # ------------------------- | |
if __name__ == "__main__":
    # One checkpoint callback per monitored validation metric; only the
    # val_loss callback also keeps the most recent checkpoint.
    checkpoint_root = "./checkpoints_vZ2"
    monitored_metrics = (("val_loss", True), ("wer", False), ("cer", False))

    callbacks = [LearningRateMonitor(logging_interval="epoch")]
    callbacks.extend(
        ModelCheckpoint(
            dirpath=f"{checkpoint_root}/{metric}",
            verbose=False,
            save_on_train_epoch_end=True,
            save_top_k=1,
            save_last=keep_last,
            monitor=metric,
        )
        for metric, keep_last in monitored_metrics
    )

    model = MySqueezeformer()

    trainer = Trainer(
        accelerator="auto",
        precision="bf16",
        callbacks=callbacks,
        default_root_dir=f"{checkpoint_root}/logs",
        reload_dataloaders_every_n_epochs=1,
        max_epochs=300,
    )

    # Training is not launched here: no dataloaders are defined in this
    # file. Once they exist, start training with:
    # trainer.fit(
    #     model, train_dataloaders=train_dataloader, val_dataloaders=validation_dataloader
    # )