| |
|
|
| import os |
| import requests |
| import wikipedia |
| import gradio as gr |
| import torch |
|
|
| from functools import lru_cache |
| from concurrent.futures import ThreadPoolExecutor |
| from typing import List |
|
|
| from transformers import ( |
| SeamlessM4TTokenizer, |
| SeamlessM4TProcessor, |
| SeamlessM4TForTextToText, |
| pipeline as hf_pipeline |
| ) |
|
|
| |
|
|
# Seamless-M4T checkpoint used for all text-to-text translation below.
MODEL = "facebook/hf-seamless-m4t-medium"
device = "cuda" if torch.cuda.is_available() else "cpu"

# use_fast=False forces the slow (sentencepiece) tokenizer — presumably
# required for this checkpoint's processor; confirm before changing.
tokenizer = SeamlessM4TTokenizer.from_pretrained(MODEL, use_fast=False)
processor = SeamlessM4TProcessor.from_pretrained(MODEL, tokenizer=tokenizer)

m4t_model = SeamlessM4TForTextToText.from_pretrained(MODEL).to(device)
if device == "cuda":
    # Half precision roughly halves GPU memory use for inference.
    m4t_model = m4t_model.half()
m4t_model.eval()
|
|
def translate_m4t(text: str, src_iso3: str, tgt_iso3: str, auto_detect=False) -> str:
    """Translate a single string with Seamless-M4T.

    When *auto_detect* is true the source language is left unspecified and
    the model infers it; otherwise *src_iso3* (ISO 639-3) is used. Returns
    the translation in *tgt_iso3*.
    """
    source = src_iso3 if not auto_detect else None
    encoded = processor(text=text, src_lang=source, return_tensors="pt").to(device)
    generated = m4t_model.generate(**encoded, tgt_lang=tgt_iso3)
    return processor.decode(generated[0].tolist(), skip_special_tokens=True)
|
|
def translate_m4t_batch(
    texts: List[str], src_iso3: str, tgt_iso3: str, auto_detect=False
) -> List[str]:
    """Translate a batch of strings with Seamless-M4T.

    Greedy decoding (num_beams=1) capped at 60 new tokens per item keeps
    the batch fast; padding aligns variable-length inputs. When
    *auto_detect* is true the source language is inferred by the model.

    Returns one translation per input, in order. An empty input list
    returns [] immediately — the processor/generate pair cannot handle
    an empty batch.
    """
    if not texts:
        return []
    src = None if auto_detect else src_iso3
    inputs = processor(
        text=texts, src_lang=src, return_tensors="pt", padding=True
    ).to(device)
    tokens = m4t_model.generate(
        **inputs,
        tgt_lang=tgt_iso3,
        max_new_tokens=60,
        num_beams=1
    )
    return processor.batch_decode(tokens, skip_special_tokens=True)
|
|
|
|
| |
|
|
# Named-entity recognition pipeline. "simple" aggregation merges word
# pieces into whole spans, so each result carries an "entity_group"
# (e.g. LOC, PER) and the full surface "word".
ner = hf_pipeline(
    "ner",
    model="dslim/bert-base-NER-uncased",
    aggregation_strategy="simple"
)
|
|
|
|
| |
|
|
@lru_cache(maxsize=256)
def geocode_cache(place: str):
    """Resolve *place* to {"lat": float, "lon": float} via Nominatim.

    Returns None when the lookup fails (network error, bad response) or
    yields no match. Memoized so repeated entities cost one HTTP call.
    """
    try:
        resp = requests.get(
            "https://nominatim.openstreetmap.org/search",
            params={"q": place, "format": "json", "limit": 1},
            headers={"User-Agent": "iVoiceContext/1.0"},
            timeout=10,  # never hang the whole request on a slow geocoder
        )
        resp.raise_for_status()
        results = resp.json()
    except (requests.RequestException, ValueError):
        # Treat any transport/HTTP/JSON failure as "not geocodable";
        # callers already handle the None case.
        return None
    if not results:
        return None
    return {"lat": float(results[0]["lat"]), "lon": float(results[0]["lon"])}
|
|
@lru_cache(maxsize=256)
def fetch_osm_cache(lat: float, lon: float, osm_filter: str, limit: int = 5):
    """List up to *limit* named OSM features matching *osm_filter* near a point.

    Queries Overpass for nodes and ways within 1 km of (lat, lon) and
    returns [{"name": ...}, ...] for elements that carry a name tag.
    Returns [] on any network/HTTP/JSON failure so enrichment stays
    best-effort. Memoized per (lat, lon, filter, limit).
    """
    payload = f"""
    [out:json][timeout:25];
    (
      node{osm_filter}(around:1000,{lat},{lon});
      way{osm_filter}(around:1000,{lat},{lon});
    );
    out center {limit};
    """
    try:
        resp = requests.post(
            "https://overpass-api.de/api/interpreter",
            data={"data": payload},
            timeout=30,  # Overpass allows itself 25s; bound our wait too
        )
        resp.raise_for_status()
        elems = resp.json().get("elements", [])
    except (requests.RequestException, ValueError):
        return []
    return [
        {"name": e["tags"]["name"]}
        for e in elems
        if e.get("tags", {}).get("name")
    ]
|
|
@lru_cache(maxsize=256)
def wiki_summary_cache(name: str) -> str:
    """Two-sentence Wikipedia summary for *name*, memoized.

    Falls back to a placeholder on any lookup failure (missing page,
    disambiguation, network error) so callers always get a string.
    """
    try:
        return wikipedia.summary(name, sentences=2)
    except Exception:  # was a bare except:, which also swallowed SystemExit/KeyboardInterrupt
        return "No summary available."
|
|
|
|
| |
|
|
def process_entity(ent) -> dict:
    """Build the context payload for one aggregated NER entity.

    LOC entities are geocoded and enriched with nearby restaurants and
    tourist attractions; every other entity type gets a short Wikipedia
    summary instead.
    """
    word = ent["word"]
    group = ent["entity_group"]

    # Non-location entities: summary lookup only.
    if group != "LOC":
        return {
            "text": word,
            "label": group,
            "type": "wiki",
            "summary": wiki_summary_cache(word),
        }

    geo = geocode_cache(word)
    if not geo:
        return {
            "text": word,
            "label": group,
            "type": "location",
            "error": "could not geocode",
        }

    return {
        "text": word,
        "label": group,
        "type": "location",
        "geo": geo,
        "restaurants": fetch_osm_cache(geo["lat"], geo["lon"], '["amenity"="restaurant"]'),
        "attractions": fetch_osm_cache(geo["lat"], geo["lon"], '["tourism"="attraction"]'),
    }
|
|
|
|
| |
|
|
def get_context(
    text: str,
    source_lang: str,
    output_lang: str,
    auto_detect: bool
):
    """Extract entities from *text* and return enriched context for each.

    Pipeline: translate the input to English (the NER model's language),
    run NER, enrich each unique entity concurrently (geocoding / OSM /
    Wikipedia), then translate the enrichment text into *output_lang*.
    Returns {"entities": [...]}.
    """
    # 1) Normalize the input to English for the NER model.
    if auto_detect or source_lang != "eng":
        en = translate_m4t(text, source_lang, "eng", auto_detect=auto_detect)
    else:
        en = text

    # 2) Run NER and drop duplicate surface forms, keeping first occurrence.
    seen = set()
    unique_ents = []
    for ent in ner(en):
        w = ent["word"]
        if w not in seen:
            seen.add(w)
            unique_ents.append(ent)

    # 3) Enrich entities concurrently — the work is network-bound.
    with ThreadPoolExecutor(max_workers=8) as exe:
        entities = list(exe.map(process_entity, unique_ents))

    # 4) Translate enrichment text into the requested output language.
    # BUG FIX: this previously tested and targeted source_lang, leaving
    # the output_lang parameter entirely unused.
    if output_lang != "eng":
        to_translate = []
        translations_info = []

        for i, e in enumerate(entities):
            if e["type"] == "wiki":
                translations_info.append(("summary", i))
                to_translate.append(e["summary"])
            elif e["type"] == "location":
                # .get(...): geocode-failure entries are type "location"
                # but carry no restaurants/attractions keys (was a KeyError).
                for j, r in enumerate(e.get("restaurants", [])):
                    translations_info.append(("restaurant", i, j))
                    to_translate.append(r["name"])
                for j, a in enumerate(e.get("attractions", [])):
                    translations_info.append(("attraction", i, j))
                    to_translate.append(a["name"])

        if to_translate:  # skip the model call on an empty batch
            translated = translate_m4t_batch(to_translate, "eng", output_lang)
            for txt, info in zip(translated, translations_info):
                kind = info[0]
                if kind == "summary":
                    entities[info[1]]["summary"] = txt
                elif kind == "restaurant":
                    entities[info[1]]["restaurants"][info[2]]["name"] = txt
                elif kind == "attraction":
                    entities[info[1]]["attractions"][info[2]]["name"] = txt

    return {"entities": entities}
|
|
|
|
| |
|
|
# Gradio UI: free-text input plus ISO 639-3 source/target codes and an
# auto-detect toggle; the raw get_context() dict is rendered as JSON.
# .queue() serializes requests so concurrent users don't contend for the GPU.
iface = gr.Interface(
    fn=get_context,
    inputs=[
        gr.Textbox(lines=3, placeholder="Enter text…"),
        gr.Textbox(label="Source Language (ISO 639-3)"),
        gr.Textbox(label="Target Language (ISO 639-3)"),
        gr.Checkbox(label="Auto-detect source language")
    ],
    outputs="json",
    title="iVoice Context-Aware",
    description="Returns only the detected entities and their related info."
).queue()
| |
|
|
if __name__ == "__main__":
    # Bind on all interfaces and honor a platform-provided PORT
    # (e.g. hosted containers), defaulting to Gradio's standard 7860.
    iface.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
        share=True
    )
|
|