File size: 3,625 Bytes
5403e87
e424d70
5403e87
 
 
 
 
 
c09b215
5403e87
 
 
 
 
 
 
 
 
548816e
 
5403e87
548816e
5403e87
 
 
 
 
 
 
 
 
 
 
 
c09b215
 
 
 
 
5403e87
 
 
 
 
 
 
 
 
 
548816e
 
 
 
 
 
 
 
 
 
 
 
5403e87
 
 
548816e
 
5403e87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
548816e
 
 
 
e424d70
548816e
 
e424d70
548816e
 
 
 
 
e424d70
 
 
 
 
 
548816e
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import re
import torch
import sentencepiece as spm
import torchaudio
from torchaudio.models.decoder import ctc_decoder
from torchaudio.transforms import Resample
from squeezeformer import MySqueezeformer
import torch.ao.quantization

# -------------------------
# Paths
# -------------------------
# Directory holding this module; every resource path below is built
# relative to it.
dirname = os.path.split(__file__)[0]

# SentencePiece tokenizer model. `sp` is not referenced anywhere else in this
# module — presumably it is imported by other code; verify before removing.
sp = spm.SentencePieceProcessor()
sp.Load(os.path.join(dirname, "../ressources/tokenizer/128_v7.model"))

# Disabled language-model resources, kept for reference:
#   - LM tokenizer: ../ressources/tokenizer/5K.model (relative to `dirname`)
#   - KenLM model:  ./ressources/kenLM_model/kab_5k_6-gram_v2.bin
#     NOTE(review): the KenLM path is relative to the current working
#     directory, unlike the other paths; resolve it against `dirname` if the
#     LM is ever re-enabled.

# -------------------------
# Device
# -------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------------
# Model
# -------------------------
model = MySqueezeformer().to(device)

# Load the checkpoint into the plain float model first: a dynamically
# quantized module exposes a different state_dict layout, so loading has to
# happen before quantization.
acoustic_model_path = os.path.join(dirname, "../ressources/e2e_model/squeezeformer")
model.load_state_dict(torch.load(acoustic_model_path, map_location=device))
model.eval()

# On CPU, swap the Linear layers for dynamically int8-quantized versions.
# Compare `device.type`: a torch.device never compares equal to a plain
# string, so `device == "cpu"` was always False and this step never ran.
if device.type == "cpu":
    model = torch.ao.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    model.eval()

# -------------------------
# Decoder
# -------------------------
# Token inventory for the CTC decoder. In it, "_" is the blank symbol and "|"
# the silence symbol (see blank_token / sil_token below).
tokens_file = os.path.join(dirname, "../ressources/tokenizer/128_v7.txt")

# Lexicon-free decoding that keeps a single hypothesis (beam_size=1,
# nbest=1). A wider search was configured here previously: beam_size=128,
# beam_threshold=10, beam_size_token=10, nbest=50.
decoder = ctc_decoder(
    lexicon=None,
    tokens=tokens_file,
    blank_token="_",
    sil_token="|",
    unk_word="<unk>",
    nbest=1,
    beam_size=1,
    log_add=True,
)


# -------------------------
# Inference
# -------------------------
# Sample rate the audio is converted to before it is fed to the model.
_TARGET_SR = 16000

# Decoder symbols to drop ("_" blank, "|" silence) and the SentencePiece
# word-boundary marker "▁", which becomes a space.
_SYMBOL_MAP = str.maketrans({"_": None, "|": None, "▁": " "})
_HYPHEN_RUN = re.compile(r"-{2,}")


def _clean_transcription(text: str) -> str:
    """Convert the decoder's concatenated token string to readable text.

    Drops "_" and "|", maps "▁" to a space, collapses every whitespace run
    to a single space (trimming both ends), then collapses runs of two or
    more hyphens into a single hyphen.
    """
    text = " ".join(text.translate(_SYMBOL_MAP).split())
    return _HYPHEN_RUN.sub("-", text)


@torch.no_grad()
def inference(audiofile: str) -> str:
    """Transcribe an audio file with the acoustic model and the CTC decoder.

    The audio is down-mixed to mono and resampled to 16 kHz, run through the
    model as a batch of one, and the top decoder hypothesis is turned into
    text with `_clean_transcription`.

    Args:
        audiofile: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        The cleaned transcription, or ``""`` if the decoder returns no
        hypothesis.
    """
    # ---- Load audio ----
    waveform, sr = torchaudio.load(audiofile)

    # Down-mix before resampling: resampling acts on each channel
    # independently, so mixing first gives the same signal (up to float
    # rounding) while resampling only one channel.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if sr != _TARGET_SR:
        waveform = Resample(orig_freq=sr, new_freq=_TARGET_SR)(waveform)

    waveform = waveform.to(device)
    # One utterance per batch; its length is the full sample count.
    lengths = torch.tensor([waveform.size(1)], device=device)

    # ---- Model forward ----
    # Call the module itself (not .forward) so registered hooks still run.
    # The second output is not used here.
    emissions, _ = model(waveform, lengths)

    # torchaudio's CTC decoder requires a contiguous float32 CPU tensor.
    emissions = emissions.to("cpu", torch.float32).contiguous()

    # ---- Decode ----
    # NOTE: only the best hypothesis is used. An earlier version rescored an
    # n-best list with a KenLM model (0.25 * LM score + 0.75 * decoder
    # score); see version control history if it needs to be restored.
    hypotheses = decoder(emissions)[0]  # n-best list for the single utterance
    if not hypotheses:
        return ""
    tokens = decoder.idxs_to_tokens(hypotheses[0].tokens)
    return _clean_transcription("".join(tokens))