| import os |
| import json |
| import asyncio |
| import argparse |
| import httpx |
| from tqdm.asyncio import tqdm |
| from transformers import AutoProcessor |
|
|
| |
| DATA_PATH = "/home/mshahidul/readctrl/data/testing_data_gs/multiclinsum_gs_train_en.json" |
| OUT_PATH_TEMPLATE = ( |
| "/home/mshahidul/readctrl/data/translated_data/" |
| "multiclinsum_gs_train_{source_lang}2{target_lang}_gemma({start}_{end}).json" |
| ) |
|
|
| TRANSLATE_URL = "http://127.0.0.1:8080/v1/chat/completions" |
| CONCURRENCY_LIMIT = 8 |
|
|
| model_id = "google/translategemma-27b-it" |
| processor = AutoProcessor.from_pretrained(model_id) |
|
|
| semaphore = asyncio.Semaphore(CONCURRENCY_LIMIT) |
|
|
async def call_llm(client, url, model, messages, temperature=0.1, max_tokens=None,
                   timeout=60.0):
    """Generic async caller for both Translation and Judge.

    Sends one OpenAI-style chat-completions request, holding the shared
    module-level semaphore so at most CONCURRENCY_LIMIT requests are in
    flight at once.

    Args:
        client: httpx.AsyncClient used to POST the request.
        url: Chat-completions endpoint URL.
        model: Model name placed in the request payload.
        messages: Messages list for the payload's "messages" field.
        temperature: Sampling temperature (default 0.1).
        max_tokens: Optional completion-length cap; omitted from the
            payload when None.
        timeout: Per-request timeout in seconds (default 60.0, matching
            the previous hard-coded value).

    Returns:
        The stripped message content on success, or None on any failure —
        callers treat None as "translation missing, retry on a later run".
    """
    async with semaphore:
        try:
            payload = {
                "model": model,
                "messages": messages,
                "temperature": temperature,
            }
            if max_tokens is not None:
                payload["max_tokens"] = max_tokens
            response = await client.post(url, json=payload, timeout=timeout)
            # Fail fast on HTTP errors instead of mis-parsing an error body
            # and surfacing it as a confusing KeyError.
            response.raise_for_status()
            result = response.json()
            return result['choices'][0]['message']['content'].strip()
        except Exception as e:
            # Best-effort by design: one bad record must not abort the run,
            # but previously the error was swallowed silently — log it.
            print(f"call_llm request to {url} failed: {e!r}")
            return None
|
|
def build_gemma_prompt(text, source_lang="en", target_lang="bn"):
    """Render *text* through the TranslateGemma chat template.

    Builds the Gemma translation message (with source/target language
    codes), applies the processor's chat template locally, and returns a
    single-turn OpenAI-style messages list carrying the rendered prompt.
    """
    gemma_messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "source_lang_code": source_lang,
                    "target_lang_code": target_lang,
                    "text": text,
                }
            ],
        }
    ]
    rendered = processor.apply_chat_template(
        gemma_messages, tokenize=False, add_generation_prompt=True
    )
    return [{"role": "user", "content": rendered}]
|
|
async def process_record(client, record, source_lang, target_lang):
    """Translates a single JSON record.

    Builds prompts for the record's "fulltext" and "summary" fields and
    issues both translation requests concurrently — the shared semaphore
    inside call_llm still caps total in-flight requests, so this only
    overlaps work that was previously awaited back-to-back.

    Mutates *record* in place, adding "translated_fulltext" and
    "translated_summary" (either may be None if its request failed),
    and returns the record.
    """
    fulltext_prompt = build_gemma_prompt(
        record['fulltext'], source_lang=source_lang, target_lang=target_lang
    )
    summary_prompt = build_gemma_prompt(
        record['summary'], source_lang=source_lang, target_lang=target_lang
    )
    # Fire both requests at once instead of sequentially.
    translated_fulltext, translated_summary = await asyncio.gather(
        call_llm(client, TRANSLATE_URL, "translate_gemma", fulltext_prompt, max_tokens=4092),
        call_llm(client, TRANSLATE_URL, "translate_gemma", summary_prompt, max_tokens=1024),
    )

    record['translated_fulltext'] = translated_fulltext
    record['translated_summary'] = translated_summary
    return record
|
|
def record_key(record):
    """Return a stable deduplication key for *record*.

    Uses the record's "id" (stringified) when present; otherwise falls
    back to a content-based key built from fulltext and summary.
    """
    rid = record.get("id")
    if rid is None:
        fulltext = record.get("fulltext", "")
        summary = record.get("summary", "")
        return f"{fulltext}||{summary}"
    return str(rid)
|
|
def has_valid_translation(record):
    """Return True iff both translated fields exist and are not None.

    Empty strings count as valid — only a missing field or an explicit
    None (a failed call_llm result) marks the record for regeneration.
    """
    fields = (
        record.get("translated_fulltext"),
        record.get("translated_summary"),
    )
    return None not in fields
|
|
async def main():
    """Translate a slice of the dataset, resuming from any existing output
    file and checkpointing to disk after every batch."""
    parser = argparse.ArgumentParser(description="Translate Multiclinsum dataset.")
    parser.add_argument("--source-lang", default="en", help="Source language code")
    parser.add_argument("--target-lang", default="bn", help="Target language code")
    parser.add_argument(
        "--start-idx",
        type=int,
        default=0,
        help="Start index (inclusive) of the slice to translate",
    )
    parser.add_argument(
        "--end-idx",
        type=int,
        default=200,
        help="End index (exclusive) of the slice to translate",
    )
    args = parser.parse_args()

    start_idx = args.start_idx
    end_idx = args.end_idx

    # The output file name encodes the language pair and slice bounds so
    # separate runs over different slices never clobber each other.
    out_path = OUT_PATH_TEMPLATE.format(
        source_lang=args.source_lang,
        target_lang=args.target_lang,
        start=start_idx,
        end=end_idx,
    )

    with open(DATA_PATH, 'r', encoding='utf-8') as f:
        all_data = json.load(f)
    data = all_data[start_idx:end_idx]

    async with httpx.AsyncClient() as client:
        # Resume support: load previously written translations, if any.
        existing_results = []
        if os.path.exists(out_path):
            with open(out_path, 'r', encoding='utf-8') as f:
                existing_results = json.load(f)

        # Index prior results by record key for O(1) lookup during resume.
        existing_by_key = {record_key(rec): rec for rec in existing_results}
        output_results = []

        batch_size = 10
        # NOTE(review): max_regen equals the slice length, so the
        # "regenerated < max_regen" guard below can never trip within a
        # single run — presumably a tunable left at "no limit"; confirm.
        max_regen = len(data)
        regenerated = 0
        for i in tqdm(range(0, len(data), batch_size)):
            batch = data[i:i + batch_size]
            pending = []       # coroutines for records still needing translation
            pending_keys = []  # keys aligned 1:1 with `pending`
            new_generated = 0

            for rec in batch:
                key = record_key(rec)
                existing = existing_by_key.get(key)
                if existing and has_valid_translation(existing):
                    # Already fully translated on a previous run: keep as-is.
                    output_results.append(existing)
                else:
                    if regenerated < max_regen:
                        pending.append(process_record(client, rec, args.source_lang, args.target_lang))
                        pending_keys.append(key)
                        regenerated += 1
                    elif existing:
                        # Regeneration budget exhausted: keep the partial result.
                        output_results.append(existing)

            if pending:
                # Run the batch concurrently; call_llm's semaphore caps the
                # number of simultaneous requests. A None result (failed
                # translation of the whole record) is dropped and retried
                # on a later run.
                processed = await asyncio.gather(*pending)
                for key, rec in zip(pending_keys, processed):
                    if rec is not None:
                        existing_by_key[key] = rec
                        output_results.append(rec)
                        new_generated += 1

            # Checkpoint after every batch so progress survives a crash.
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            with open(out_path, 'w', encoding='utf-8') as f:
                json.dump(output_results, f, ensure_ascii=False, indent=4)
            print(
                f"Batch {i // batch_size + 1}: new={new_generated}, total={len(output_results)}"
            )
|
|
| if __name__ == "__main__": |
| asyncio.run(main()) |