| """This recipe to train CLAP. |
| It supports distillation using tinyCLAP (https://arxiv.org/abs/2311.14517). |
| |
| Authors |
| * Francesco Paissan 2024 |
| """ |
|
|
import gradio as gr
import speechbrain as sb
import torch
import torchaudio
import torchaudio.transforms as T
from hyperpyyaml import load_hyperpyyaml
from speechbrain.utils.distributed import run_on_main
|
|
torch.backends.cudnn.enabled = False

eps = 1e-10
|
|
|
|
class CLAPBrain(sb.Brain):
    """Wraps the CLAP forward computation and its tinyCLAP student."""

    def preprocess(self, wavs):
        """Pre-processes wavs into log-mel spectrograms."""
        x = self.hparams.spectrogram_extractor(wavs)
        x = self.hparams.logmel_extractor(x)

        return x
|
|
    def prepare_txt_features(self, text):
        """Prepares the text features for the CLAP text encoder."""
        txt_inp = self.hparams.txt_tokenizer(
            text,
            max_length=self.hparams.text_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(self.device)

        return txt_inp
|
|
    def compute_sim(self, audio_embed, caption_embed):
        """Computes the CLAP similarity matrix (a dot product of embeddings,
        i.e. cosine similarity when both inputs are L2-normalized)."""
        similarity = audio_embed @ caption_embed.t()

        return similarity
|
|
    def compute_forward(self, batch, stage):
        """Computes the text and audio embeddings in the shared CLAP space."""
        if len(batch) == 2:
            wavs, caption = batch
        else:
            wavs, caption, _, _ = batch

        wavs = wavs.to(self.device).squeeze(1)

        x_sb = self.preprocess(wavs)

        text_inp = self.prepare_txt_features(caption)

        txt_shared, aud_shared = self.hparams.clap(
            x_sb,
            text_inp.input_ids.data,
            text_inp.token_type_ids.data,
            text_inp.attention_mask.data,
        )

        # Student embeddings are computed only when a distilled (tinyCLAP)
        # audio encoder is available among the modules; otherwise None is
        # returned in their place.
        aud_shared_student = None
        if hasattr(self.modules, "clap_student"):
            aud_shared_student, _, _ = self.modules.clap_student(x_sb)
            aud_shared_student = aud_shared_student / aud_shared_student.norm(
                dim=1, keepdim=True
            )

        return txt_shared, aud_shared, aud_shared_student
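

# A minimal zero-shot scoring sketch (not part of the original recipe), built
# only on the methods above; `class_list` is a hypothetical list of candidate
# captions and `wavs` a (batch, time) waveform tensor at the model sample rate.
def zero_shot_scores(clap_brain, wavs, class_list):
    """Ranks candidate captions for a batch of waveforms."""
    txt_e, aud_e, stu_e = clap_brain.compute_forward(
        [wavs, class_list], stage=sb.Stage.TEST
    )
    # Prefer the distilled (tinyCLAP) audio embeddings when available.
    audio_embed = stu_e if stu_e is not None else aud_e
    # (n_wavs, n_captions) probabilities over the candidate captions.
    return clap_brain.compute_sim(audio_embed, txt_e).softmax(dim=-1)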
|
|
|
|
def audio_preprocess(x, sample_rate):
    """Loads an audio file, resamples it, and downmixes it to mono."""
    tmp, sr = torchaudio.load(x)
    resample = T.Resample(sr, sample_rate)

    tmp = resample(tmp)
    # Downmix multi-channel audio to a single channel.
    tmp = tmp.sum(0, keepdim=True)

    return tmp
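
# For example (hypothetical call): `audio_preprocess("./siren.wav", 16000)`
# returns a (1, num_samples) mono tensor resampled to 16 kHz; the actual rate
# is taken from `hparams.sample_rate` below.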
|
|
|
|
def inference_wrapper(clap_brain):
    """Builds the scoring function served by the Gradio interface."""

    @torch.no_grad()
    def f(wav_path, prompt):
        # Gradients are never needed at inference time.
        clap_brain.modules.eval()
        tmp = audio_preprocess(wav_path, clap_brain.hparams.sample_rate)

        ret = clap_brain.compute_forward([tmp, prompt], stage=sb.Stage.TEST)
        # Scores the prompt against the student (tinyCLAP) embedding when
        # available, falling back to the teacher embedding otherwise.
        audio_embed = ret[2] if ret[2] is not None else ret[1]
        sim = clap_brain.compute_sim(audio_embed, ret[0])

        return f"tinyCLAP similarity is: {round(sim.item(), 2)}"

    return f
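
# Hypothetical direct usage (outside the Gradio UI), assuming `clap_brain`
# holds a loaded checkpoint:
#
#     score_fn = inference_wrapper(clap_brain)
#     print(score_fn("./siren.wav", "this is the sound of sirens wailing"))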
|
|
|
|
| if __name__ == "__main__": |
|
|
| |
| |
| hparams_file = "hparams/inference.yaml" |
|
|
| |
| with open(hparams_file) as fin: |
| hparams = load_hyperpyyaml(fin, {}) |
|
|
    if hparams["use_tensorboard"]:
        from speechbrain.utils.train_logger import TensorboardLogger

        hparams["tensorboard_train_logger"] = TensorboardLogger(
            hparams["tensorboard_logs_folder"]
        )
|
|
| hparams["clap"].to(hparams["device"]) |
| hparams["clap"].requires_grad_(False) |
| hparams["clap"].eval() |
|
|
| if hparams["zs_eval"]: |
| hparams["class_list"] = datasets["train"].dataset.classes |
|
|
| if hparams["audioenc_name_student"] is not None: |
| if hparams["projection_only"]: |
| print("Freezing Base AudioEncoder. Updating only the projection layers.") |
| hparams["student_model"].base.requires_grad_(False) |
|
|
| hparams["spectrogram_extractor"].to(hparams["device"]) |
| hparams["logmel_extractor"].to(hparams["device"]) |
|
|
    clap_brain = CLAPBrain(
        modules=hparams["modules"],
        hparams=hparams,
    )
|
|
| if hparams["pretrained_CLAP"] is not None: |
| print("Loading CLAP model...") |
| run_on_main(hparams["load_CLAP"].collect_files) |
| hparams["load_CLAP"].load_collected() |
|
|
    inference_api = inference_wrapper(clap_brain)
|
|
    examples_list = [
        ["./tunztunz_music.wav", "this is the sound of house music"],
        ["./siren.wav", "this is the sound of sirens wailing"],
        [
            "./whistling_and_chirping.wav",
            "someone is whistling while birds are chirping",
        ],
    ]
|
|
    demo = gr.Interface(
        fn=inference_api,
        inputs=[gr.Audio(type="filepath"), gr.Textbox()],
        outputs=["text"],
        examples=examples_list,
    )
    demo.launch()
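    # NOTE: `demo.launch(share=True)` would also expose a temporary public URL.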
|
|