| |
|
|
| import gradio as gr |
| import torch |
| import torchaudio |
|
|
| from ced_model.feature_extraction_ced import CedFeatureExtractor |
| from ced_model.modeling_ced import CedForAudioClassification |
|
|
# Hugging Face Hub checkpoint to load. The feature extractor and model are
# created once at module import so every Gradio request reuses the same
# (already-downloaded) weights instead of reloading per call.
model_path = "mispeech/ced-base"
feature_extractor = CedFeatureExtractor.from_pretrained(model_path)
model = CedForAudioClassification.from_pretrained(model_path)
|
|
|
|
def process(audio_path: str) -> str:
    """Classify an uploaded audio clip and return the predicted label.

    Args:
        audio_path: Filesystem path to the uploaded audio file (any format
            torchaudio can decode). Gradio passes ``None`` when the user
            submits without uploading anything.

    Returns:
        The predicted class label, or an explanatory message when no file
        was provided.
    """
    if audio_path is None:
        return "No audio file uploaded."

    # NOTE: the old `global model` / `global label_maps` statements were
    # removed — the function only reads `model`, which needs no `global`,
    # and `label_maps` is never defined anywhere in this module.
    audio, sr = torchaudio.load(audio_path)

    # Down-mix multi-channel recordings to a single (mono) channel, which
    # is what the feature extractor expects.
    if audio.shape[0] > 1:
        audio = audio.mean(dim=0, keepdim=True)

    # The model was trained on 16 kHz audio; resample arbitrary-rate input
    # instead of rejecting it outright.
    if sr != 16000:
        audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=16000)
        sr = 16000

    inputs = feature_extractor(audio, sampling_rate=sr, return_tensors="pt")

    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_id = torch.argmax(logits, dim=-1).item()
    return model.config.id2label[predicted_class_id]
|
|
|
|
# Wire up the Gradio UI: a single file-upload audio input handed to
# `process` as a filesystem path (type="filepath"), with the predicted
# label rendered as plain text.
iface_audio_file = gr.Interface(
    fn=process,
    inputs=gr.Audio(sources="upload", type="filepath", streaming=False),
    outputs="text",
)
# Shut down any interfaces left over from a previous run (e.g. repeated
# notebook executions) before launching this one.
gr.close_all()
iface_audio_file.launch()
|
|