import os
import tempfile

import librosa
import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import AutoTokenizer, ClapTextModelWithProjection

from src.models.sed_decoder import Decoder, TSED_Wrapper
from src.models.transformer import Dasheng_Encoder
from src.utils import load_yaml_with_includes
|
|
|
|
class FlexSED:
    """Text-queried sound event detection with plotting and video helpers.

    The detection model, the CLAP text encoder and its tokenizer are loaded
    once in ``__init__`` and reused by every ``run_inference`` call.
    """

    def __init__(
        self,
        config_path='src/configs/model.yml',
        ckpt_path='ckpts/flexsed_as.pt',
        ckpt_url='https://huggingface.co/Higobeatz/FlexSED/resolve/main/ckpts/flexsed_as.pt',
        device='cuda'
    ):
        """
        Initialize FlexSED with model, CLAP, and tokenizer loaded once.
        If the checkpoint is not available locally, it will be downloaded automatically.

        Args:
            config_path: YAML config read with ``load_yaml_with_includes``; must
                provide the ``encoder``, ``decoder``, ``ft_blocks`` and
                ``frozen_encoder`` entries.
            ckpt_path: Local checkpoint file; used when it exists.
            ckpt_url: URL the checkpoint is fetched from (through
                ``torch.hub``) when ``ckpt_path`` does not exist.
            device: Device the detection model and audio tensors are placed on.
        """
        self.device = device
        params = load_yaml_with_includes(config_path)

        # NOTE(security): both loaders below unpickle the checkpoint, so only
        # point ckpt_path / ckpt_url at sources you trust.
        if os.path.exists(ckpt_path):
            state_dict = torch.load(ckpt_path, map_location="cpu")
        else:
            print(f"[FlexSED] Downloading checkpoint from {ckpt_url} ...")
            state_dict = torch.hub.load_state_dict_from_url(ckpt_url, map_location="cpu")

        encoder = Dasheng_Encoder(**params['encoder']).to(self.device)
        decoder = Decoder(**params['decoder']).to(self.device)
        self.model = TSED_Wrapper(encoder, decoder, params['ft_blocks'], params['frozen_encoder'])
        self.model.load_state_dict(state_dict['model'])
        # Encoder and decoder were moved individually above; moving the wrapper
        # as well covers anything registered directly on it.
        self.model.to(self.device)
        self.model.eval()

        # The CLAP text encoder is left on its default device; its embeddings
        # are moved to self.device inside run_inference.
        self.clap = ClapTextModelWithProjection.from_pretrained("laion/clap-htsat-unfused")
        self.clap.eval()
        self.tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")

    @staticmethod
    def _build_prompt(event):
        """Return the CLAP text prompt used to query for ``event``."""
        return f"The sound of {event.replace('_', ' ').capitalize()}"

    def run_inference(self, audio_path, events, norm_audio=True):
        """
        Run inference on audio for given events.

        Args:
            audio_path: Audio file to analyse; loaded with librosa at 16 kHz.
            events: Iterable of event labels. Underscores are replaced by
                spaces before the label is embedded as a text query.
            norm_audio: If True, peak-normalise the waveform first.

        Returns:
            CPU tensor of sigmoid probabilities produced by the model.

        Raises:
            ValueError: If ``events`` is empty.
        """
        # Validate before any expensive work; an empty query list would
        # otherwise only fail later inside torch.cat with an obscure error.
        events = list(events)
        if not events:
            raise ValueError("`events` must contain at least one event label.")

        waveform, _ = librosa.load(audio_path, sr=16000)
        # torch.tensor([ndarray]) takes PyTorch's slow list-of-arrays path and
        # warns; convert the array directly and add the batch dimension.
        audio = torch.tensor(waveform).unsqueeze(0).to(self.device)

        if norm_audio:
            eps = 1e-9  # guards against division by zero on silent input
            peak = torch.max(torch.abs(audio))
            audio = audio / (peak + eps)

        with torch.no_grad():
            clap_embeds = []
            for event in events:
                inputs = self.tokenizer(
                    [self._build_prompt(event)], padding=True, return_tensors="pt"
                )
                outputs = self.clap(**inputs)
                clap_embeds.append(outputs.text_embeds.unsqueeze(1))

            # One query embedding per event, concatenated along dim 1.
            query = torch.cat(clap_embeds, dim=1).to(self.device)
            mel = self.model.forward_to_spec(audio)
            preds = self.model(mel, query)
            preds = torch.sigmoid(preds).cpu()

        return preds

    @staticmethod
    def _draw_event_heatmap(matrix, events, x_end):
        """Draw the event/time heatmap with its labels on the current figure."""
        plt.imshow(
            matrix,
            aspect="auto",
            cmap="Blues",
            extent=[0, x_end, 0, len(events)],
            vmin=0, vmax=1, origin="lower"
        )
        plt.colorbar(label="Probability")
        # Centre each label on its row of the heatmap.
        plt.yticks(np.arange(len(events)) + 0.5, events)
        plt.xlabel("Time (s)")
        plt.ylabel("Events")
        plt.title("Event Predictions")

    @staticmethod
    def plot_and_save_multi(preds, events, sr=25, out_dir="./plots", fname="all_events"):
        """
        Save a heatmap of all event predictions as ``<out_dir>/<fname>.png``.

        Args:
            preds: Prediction tensor; assumed to become (n_events, n_steps)
                after ``squeeze(1)`` -- TODO confirm against the model output.
            events: Event labels, one per heatmap row.
            sr: Prediction steps per second, used to scale the time axis.
            out_dir: Output directory, created if missing.
            fname: Output file name without extension.

        Returns:
            Path of the written PNG file.
        """
        os.makedirs(out_dir, exist_ok=True)
        preds_np = preds.squeeze(1).numpy()
        T = preds_np.shape[1]
        save_path = os.path.join(out_dir, f"{fname}.png")

        fig = plt.figure(figsize=(12, len(events) * 0.6 + 2))
        try:
            FlexSED._draw_event_heatmap(preds_np, events, T / sr)
            plt.savefig(save_path, dpi=200, bbox_inches="tight")
        finally:
            # Release the figure even if drawing or saving fails.
            plt.close(fig)
        return save_path

    def to_multi_plot(self, preds, events, out_dir="./plots", fname="all_events"):
        """Instance-level shortcut for ``plot_and_save_multi`` with its default ``sr``."""
        return self.plot_and_save_multi(preds, events, out_dir=out_dir, fname=fname)

    @staticmethod
    def make_multi_event_video(preds, events, sr=25, out_dir="./videos",
                               audio_path=None, fps=25, highlight=True, fname="all_events"):
        """
        Render the predictions as an MP4 that reveals the heatmap over time.

        Args:
            preds: Prediction tensor; assumed to become (n_events, n_steps)
                after ``squeeze(1)`` -- TODO confirm against the model output.
            events: Event labels, one per heatmap row.
            sr: Prediction steps per second, used to scale the time axis.
            out_dir: Output directory, created if missing.
            audio_path: Optional audio file attached as the soundtrack.
            fps: Frame rate of the rendered video.
            highlight: If True, keep the full time axis and fill it in
                progressively; otherwise the axis grows with the data.
            fname: Output file name without extension.

        Returns:
            Path of the written MP4 file.

        Raises:
            ValueError: If the predictions are too short to yield one frame.
        """
        from moviepy.editor import ImageSequenceClip, AudioFileClip
        from tqdm import tqdm

        os.makedirs(out_dir, exist_ok=True)
        preds_np = preds.squeeze(1).numpy()
        T = preds_np.shape[1]
        duration = T / sr
        n_frames = int(duration * fps)
        if n_frames < 1:
            raise ValueError(
                "Predictions are too short to render a single video frame."
            )

        save_path = os.path.join(out_dir, f"{fname}.mp4")

        # A private temporary directory keeps concurrent runs from overwriting
        # each other's frames and is removed even if rendering fails.
        with tempfile.TemporaryDirectory(prefix="flexsed_frames_") as frame_dir:
            frames = []
            for i in tqdm(range(n_frames)):
                # Last prediction step visible in this frame.
                t = int(i * T / n_frames)

                if highlight:
                    shown = np.zeros_like(preds_np)
                    shown[:, :t+1] = preds_np[:, :t+1]
                    x_end = T / sr
                else:
                    shown = preds_np[:, :t+1]
                    x_end = (t + 1) / sr

                frame_path = os.path.join(frame_dir, f"frame_{i:04d}.png")
                fig = plt.figure(figsize=(12, len(events) * 0.6 + 2))
                try:
                    FlexSED._draw_event_heatmap(shown, events, x_end)
                    plt.savefig(frame_path, dpi=150, bbox_inches="tight")
                finally:
                    plt.close(fig)
                frames.append(frame_path)

            clip = ImageSequenceClip(frames, fps=fps)
            source_audio = None
            try:
                if audio_path is not None:
                    source_audio = AudioFileClip(audio_path)
                    clip = clip.set_audio(source_audio.subclip(0, duration))

                clip.write_videofile(
                    save_path,
                    fps=fps,
                    codec="mpeg4",
                    audio_codec="aac"
                )
            finally:
                # Release the audio reader held by moviepy.
                if source_audio is not None:
                    source_audio.close()

        return save_path

    def to_multi_video(self, preds, events, audio_path, out_dir="./videos", fname="all_events"):
        """Instance-level shortcut for ``make_multi_event_video`` with default ``sr``/``fps``."""
        return self.make_multi_event_video(
            preds, events, audio_path=audio_path, out_dir=out_dir, fname=fname
        )
|
|
|
|
if __name__ == "__main__":
    # Demo: detect a few events in an example clip and save the heatmap.
    # Fall back to the CPU so the demo also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    flexsed = FlexSED(device=device)

    events = ["Door", "Laughter", "Dog"]
    preds = flexsed.run_inference("example2.wav", events)

    flexsed.to_multi_plot(preds, events, fname="example2")
| |
|
|