"""Classify a single audio clip with a fine-tuned Whisper audio classifier.

Loads a checkpoint produced by a previous training run, extracts Whisper
log-mel input features from one wav file, and prints the predicted label.
"""

import librosa
import torch

from transformers import WhisperFeatureExtractor, WhisperForAudioClassification

CHECKPOINT = "results/checkpoint-30"
AUDIO_PATH = "dataset/lisp/sample_01.wav"
TARGET_SR = 16000  # Whisper models operate on 16 kHz audio

# Load the fine-tuned classifier and switch to inference mode so layers such
# as dropout behave deterministically.
model = WhisperForAudioClassification.from_pretrained(CHECKPOINT)
model.eval()

# Load the clip directly at the target rate. librosa resamples during load,
# which avoids the lossy double resample (file rate -> 44.1 kHz -> 16 kHz)
# of loading at one rate and resampling again afterwards.
audio, _ = librosa.load(AUDIO_PATH, sr=TARGET_SR)

# Use Whisper's own feature extractor rather than a hand-rolled librosa mel
# spectrogram: the model expects log-mel features with n_fft=400 and
# hop_length=160, Whisper's specific log scaling, padded/truncated to 3000
# frames. A librosa melspectrogram with hop_length=512 + power_to_db does not
# match that contract, so the classifier would see out-of-distribution input
# (and a manual pad to 3000 frames crashes on clips longer than ~96 s at
# hop 512). The extractor also returns the (1, 80, 3000) batch tensor the
# model expects, removing the fragile unsqueeze/pad/permute dance.
# NOTE(review): assumes the preprocessor config was saved alongside the
# checkpoint (Trainer does this when given the feature extractor); if not,
# load it from the base model id used for fine-tuning — TODO confirm.
feature_extractor = WhisperFeatureExtractor.from_pretrained(CHECKPOINT)
inputs = feature_extractor(audio, sampling_rate=TARGET_SR, return_tensors="pt")

# Inference only: no gradient tracking needed.
with torch.no_grad():
    logits = model(**inputs).logits

# Batch of one, but make the reduction axis explicit instead of relying on
# argmax over the flattened tensor.
predicted_class_id = torch.argmax(logits, dim=-1).item()
predicted_label = model.config.id2label[predicted_class_id]

print("Predicted label:", predicted_label)