"""Approximate per-word timestamps for Quranic audio via wav2vec2 CTC emissions.

Reads the word-level alignment from ALIGN_PATH, runs an Arabic wav2vec2 CTC
model over AUDIO_PATH, and heuristically maps each aligned ASR token onto a
span of CTC frames, writing start/end seconds per word to OUT_PATH.
"""

import json
import re

import librosa
import numpy as np
import torch
from transformers import AutoProcessor, AutoModelForCTC

AUDIO_PATH = "sample_trim.wav"
ALIGN_PATH = "output/text_alignment_global.json"
OUT_PATH = "output/word_timestamps.json"
MODEL_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-arabic"

# Arabic short-vowel/diacritic marks (tashkeel), superscript alef, and hamza marks.
ARABIC_DIACRITICS = re.compile(r"[\u064B-\u0652\u0670\u0653\u0654\u0655]")
TATWEEL = "\u0640"  # kashida (elongation) character — purely typographic


def normalize_ar(s: str) -> str:
    """Normalize Arabic text for fuzzy matching.

    Removes tatweel and diacritics, folds hamza-carrying alef variants to bare
    alef, alef maqsura to yaa, taa marbuta to haa, and collapses whitespace.
    """
    s = s.replace(TATWEEL, "")
    s = ARABIC_DIACRITICS.sub("", s)
    s = s.replace("أ", "ا").replace("إ", "ا").replace("آ", "ا")
    s = s.replace("ى", "ي")
    s = s.replace("ة", "ه")
    s = re.sub(r"\s+", " ", s).strip()
    return s


def _ctc_collapse(audio, sr: int):
    """Run the CTC model and greedily collapse its frame predictions.

    Returns (collapsed, n_frames) where ``collapsed`` is a list of
    (frame_index, token) pairs with consecutive repeats and blanks removed,
    and ``n_frames`` is the total number of model frames (used to map a
    frame index back to seconds).
    """
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForCTC.from_pretrained(MODEL_ID)
    model.eval()

    inputs = processor(audio, sampling_rate=sr, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits[0]  # (T, V)
    pred_ids = torch.argmax(logits, dim=-1).cpu().numpy().tolist()

    vocab = processor.tokenizer.get_vocab()
    inv_vocab = {i: t for t, i in vocab.items()}  # id -> token
    blank_id = processor.tokenizer.pad_token_id
    if blank_id is None:
        # wav2vec2 CTC tokenizers use the pad token as the CTC blank symbol.
        blank_id = vocab.get("<pad>")

    collapsed = []
    prev = None
    for t_idx, tok_id in enumerate(pred_ids):
        if tok_id == prev:  # CTC rule: merge consecutive repeats
            continue
        prev = tok_id
        if blank_id is not None and tok_id == blank_id:
            continue
        tok = inv_vocab.get(tok_id, "")
        if tok.strip() == "":
            continue
        # NOTE: the word-delimiter token ("|") survives here; it never matches
        # an Arabic character in the span search, so it is effectively skipped.
        collapsed.append((t_idx, tok))
    return collapsed, logits.shape[0]


def _find_span_for_word(collapsed, word_norm, start_search_idx):
    """Locate a CTC frame span whose tokens contain ``word_norm``'s letters in order.

    Greedy heuristic: starting at ``start_search_idx`` into ``collapsed``,
    advance until each character of the (space-stripped) word is found inside
    some token. Returns ((first_frame, last_frame), next_pointer) on success,
    or (None, start_search_idx) if any character cannot be matched.
    """
    if not word_norm:
        return None, start_search_idx
    target = word_norm.replace(" ", "")
    if target == "":
        return None, start_search_idx

    i = start_search_idx
    start_idx = None
    last_idx = None
    for ch in target:
        found = False
        while i < len(collapsed):
            t_idx, tok = collapsed[i]
            # Tokens may be single characters or pieces; substring match suffices.
            if ch in tok:
                if start_idx is None:
                    start_idx = t_idx
                last_idx = t_idx
                i += 1
                found = True
                break
            i += 1
        if not found:
            return None, start_search_idx
    return (start_idx, last_idx), i


def main():
    """Produce approximate word timestamps and write them to OUT_PATH."""
    # Load the alignment produced by the previous pipeline stage; keep only
    # entries that map onto a canonical word.
    with open(ALIGN_PATH, encoding="utf-8") as f:
        align = json.load(f)
    alignment = [a for a in align["alignment"] if a.get("canon")]

    # Load audio at the model's expected 16 kHz mono.
    audio, sr = librosa.load(AUDIO_PATH, sr=16000, mono=True)
    total_sec = len(audio) / sr

    collapsed, n_frames = _ctc_collapse(audio, sr)

    def idx_to_time(i):
        # Model frames span the full clip uniformly: frame index -> seconds.
        return (i / n_frames) * total_sec

    out_rows = []
    search_ptr = 0
    for a in alignment:
        cw = a["canon"]
        tok = a["asr_token"]
        tok_norm = normalize_ar(tok) if tok else None

        if tok_norm:
            span, next_ptr = _find_span_for_word(collapsed, tok_norm, search_ptr)
        else:
            span, next_ptr = None, search_ptr

        if span is None:
            start_t = None
            end_t = None
        else:
            s_idx, e_idx = span
            start_t = round(float(idx_to_time(s_idx)), 3)
            end_t = round(float(idx_to_time(e_idx)), 3)
            # Advance only on a successful match so later words keep ordering.
            search_ptr = next_ptr

        out_rows.append({
            "ayah": cw["ayah"],
            "word": cw["word"],
            "asr_token": tok,
            "score": a["score"],
            "match": a["match"],
            "timestamp": None if start_t is None else {"start": start_t, "end": end_t}
        })

    out = {
        "audio_path": AUDIO_PATH,
        "model": MODEL_ID,
        "note": "CTC-based approximate word timestamps; upgrade later with forced alignment for higher accuracy.",
        "stats": {
            "words": len(out_rows),
            "timestamped": sum(1 for r in out_rows if r["timestamp"] is not None)
        },
        "words": out_rows
    }
    with open(OUT_PATH, "w", encoding="utf-8") as f:
        json.dump(out, f, ensure_ascii=False, indent=2)

    print("OK ✅ wrote", OUT_PATH)
    print("Timestamped:", out["stats"]["timestamped"], "/", out["stats"]["words"])
    if out_rows:  # guard: alignment file may contain no canonical words
        print("Sample:", out_rows[0])


if __name__ == "__main__":
    main()