| import os |
| from pathlib import Path |
| from typing import Any, BinaryIO, Mapping, Optional, Union |
| import torch |
|
|
| from config import default_config |
| from featex import load_audio, Preprocessor |
| from model import Classifier |
|
|
class Pipeline:
    """Inference wrapper around ``Classifier``.

    Construction builds the model from a config mapping, creates the matching
    ``Preprocessor`` from ``model.preprocessor_config``, loads checkpoint
    weights, moves the model to the chosen device and switches it to eval
    mode. The ``run_on_*`` methods then score features, audio tensors, or
    audio files.
    """

    def __init__(
        self,
        checkpoint: Optional[Union[str, Path]] = None,
        config: Optional[Mapping[str, Any]] = None,
        device: Optional[Union[torch.device, str]] = None,
    ):
        """Build the model, load its weights and prepare it for inference.

        Args:
            checkpoint: Path to a state dict readable by ``torch.load``.
                Defaults to ``dam3.1.ckpt`` in the directory of this module.
            config: Keyword arguments forwarded to ``Classifier``. Defaults
                to ``default_config``.
            device: Device to run on, as a ``torch.device`` or a device
                string. Defaults to ``cuda:0`` when CUDA is available,
                otherwise the CPU.
        """
        if checkpoint is None:
            checkpoint = Path(__file__).parent.resolve() / "dam3.1.ckpt"
        if config is None:
            config = default_config
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # Normalize so that both torch.device objects and strings are accepted.
        self.device = torch.device(device)

        self.model = Classifier(**config)
        self.preprocessor = Preprocessor(**self.model.preprocessor_config)

        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source, or pass
        # weights_only=True if the installed torch version supports it.
        state_dict = torch.load(checkpoint, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

    def run_on_features(self, features: torch.Tensor, quantize: bool = True):
        """Score already-extracted features.

        Args:
            features: Feature tensor for a single input. The size of its first
                dimension is passed to the model as the only entry of the
                lengths tensor (presumably the frame count; confirm against
                ``Classifier``).
            quantize: When true, pass the scores through
                ``model.quantize_scores`` and convert each value to ``int``.

        Returns:
            A dict mapping each key of ``model.quantize_scores(scores)`` to an
            ``int`` when ``quantize`` is true; otherwise the first element of
            the model output, unchanged.
        """
        # Callers may hand over features on any device; .to() is a no-op when
        # they already live on self.device.
        features = features.to(self.device)
        lengths = torch.tensor([features.shape[0]], device=self.device)
        # Inference only: skip building an autograd graph for the forward pass.
        with torch.no_grad():
            scores = self.model(features, lengths)[0]
            if not quantize:
                return scores
            quantized = self.model.quantize_scores(scores)
        return {name: int(value.item()) for name, value in quantized.items()}

    def run_on_audio(self, audio: torch.Tensor, quantize: bool = True):
        """Preprocess an audio tensor and score the resulting features.

        Args:
            audio: Audio tensor handed to
                ``Preprocessor.preprocess_with_audio_normalization``.
            quantize: Forwarded to ``run_on_features``.

        Returns:
            Whatever ``run_on_features`` returns for the extracted features.
        """
        features = self.preprocessor.preprocess_with_audio_normalization(audio)
        return self.run_on_features(features, quantize=quantize)

    def run_on_file(
        self,
        source: Union[BinaryIO, str, os.PathLike],
        quantize: bool = True,
    ):
        """Load audio with ``load_audio`` and score it.

        Args:
            source: Binary file object or path accepted by ``load_audio``.
            quantize: Forwarded to ``run_on_audio``.

        Returns:
            Whatever ``run_on_audio`` returns for the loaded audio.
        """
        audio = load_audio(source)
        return self.run_on_audio(audio, quantize=quantize)
|
|