Spaces:
Sleeping
Sleeping
File size: 3,625 Bytes
5403e87 e424d70 5403e87 c09b215 5403e87 548816e 5403e87 548816e 5403e87 c09b215 5403e87 548816e 5403e87 548816e 5403e87 548816e e424d70 548816e e424d70 548816e e424d70 548816e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 | import os
import re
import torch
import sentencepiece as spm
import torchaudio
from torchaudio.models.decoder import ctc_decoder
from torchaudio.transforms import Resample
from squeezeformer import MySqueezeformer
import torch.ao.quantization
# -------------------------
# Paths
# -------------------------
# Resource paths are resolved against this file's directory, not the current
# working directory.
dirname = os.path.dirname(__file__)
_tokenizer_model_path = os.path.join(dirname, "../ressources/tokenizer/128_v7.model")

# SentencePiece tokenizer, loaded once at import time.
# NOTE(review): `sp` is not referenced elsewhere in this file; presumably other
# modules import it — confirm before removing.
sp = spm.SentencePieceProcessor()
sp.Load(_tokenizer_model_path)

# Disabled LM-rescoring resources, kept for reference. Re-enabling them needs a
# `kenlm` import, and note the KenLM path is relative to the CWD, unlike the
# paths above.
# sp_lm = spm.SentencePieceProcessor()
# sp_lm.Load(os.path.join(dirname, "../ressources/tokenizer/5K.model"))
# lm = kenlm.Model("./ressources/kenLM_model/kab_5k_6-gram_v2.bin")
# -------------------------
# Device
# -------------------------
# Run on CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# -------------------------
# Model
# -------------------------
model = MySqueezeformer().to(device)
acoustic_model_path = os.path.join(dirname, "../ressources/e2e_model/squeezeformer")

# Load the float checkpoint BEFORE any quantization. quantize_dynamic replaces
# nn.Linear with dynamically-quantized modules whose state-dict entries differ
# from the float `weight`/`bias` keys in the checkpoint, so loading after
# quantizing would not match.
model.load_state_dict(torch.load(acoustic_model_path, map_location=device))

# `device` is a torch.device, which never compares equal to the plain string
# "cpu". Test its `.type` so this CPU-only int8 optimisation actually runs.
# NOTE(review): this branch was previously dead, so CPU inference was never
# quantized; re-check CPU transcription accuracy after enabling it.
if device.type == "cpu":
    model = torch.ao.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )

model.eval()
# -------------------------
# Decoder
# -------------------------
tokens_file = os.path.join(dirname, "../ressources/tokenizer/128_v7.txt")

# Lexicon-free CTC beam-search decoder that keeps a single hypothesis
# (beam_size=1, nbest=1). "_" is the CTC blank symbol and "|" the silence
# symbol in the token file. A wider configuration (beam_size=128,
# beam_threshold=10, beam_size_token=10, nbest=50) was previously used to
# produce n-best lists for LM rescoring and is currently disabled.
_decoder_kwargs = {
    "tokens": tokens_file,
    "lexicon": None,
    "beam_size": 1,
    "nbest": 1,
    "log_add": True,
    "blank_token": "_",
    "sil_token": "|",
    "unk_word": "<unk>",
}
decoder = ctc_decoder(**_decoder_kwargs)
# -------------------------
# Inference
# -------------------------
def _postprocess(tokens) -> str:
    """Turn the decoder's token strings into a clean transcription.

    Drops the CTC blank ("_") and silence ("|") symbols, maps the SentencePiece
    word-boundary marker "▁" to a space, collapses runs of whitespace and of
    hyphens, and strips the result.
    """
    text = "".join(tokens)
    text = text.replace("_", "").replace("|", "").replace("▁", " ")
    text = " ".join(text.split())
    return re.sub(r"-{2,}", "-", text).strip()


@torch.no_grad()
def inference(audiofile: str) -> str:
    """Transcribe an audio file with the acoustic model and the CTC decoder.

    The audio is resampled to 16 kHz if needed and down-mixed to mono. It is
    then run through the module-level ``model``, and the emissions are decoded
    with the module-level ``decoder``. The best hypothesis is cleaned up into
    plain text.

    Args:
        audiofile: Path to the audio file to transcribe.

    Returns:
        The transcription, or an empty string if the decoder produced no
        hypothesis.
    """
    # ---- Load audio ----
    waveform, sr = torchaudio.load(audiofile)  # [channels, samples]

    target_sr = 16000
    if sr != target_sr:
        waveform = Resample(orig_freq=sr, new_freq=target_sr)(waveform)

    # Down-mix to mono so the single channel dimension acts as a batch of 1.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    waveform = waveform.to(device)  # [1, T]
    lengths = torch.tensor([waveform.size(1)], device=device)

    # ---- Model forward ----
    # Call the module rather than .forward() so any registered hooks run.
    outputs, _ = model(waveform, lengths)

    # torchaudio's CTC decoder only accepts contiguous float32 CPU tensors.
    emissions = outputs.cpu().float().contiguous()

    # ---- Decode ----
    hypotheses = decoder(emissions)[0]  # n-best list for the single utterance
    if not hypotheses:
        return ""
    # NOTE: n-best LM rescoring (KenLM) is not implemented; the decoder's top
    # hypothesis is returned as-is.
    tokens = decoder.idxs_to_tokens(hypotheses[0].tokens)
    return _postprocess(tokens)
|