File size: 1,248 Bytes
e078b1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from __future__ import annotations
import argparse
from pathlib import Path
import pandas as pd
from tqdm import tqdm
from src.models.registry import summarize_text

def build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Summarize a text column of a CSV with one or more models.",
    )
    parser.add_argument("--input", required=True, help="Path of the input CSV.")
    parser.add_argument(
        "--output",
        required=True,
        help="Path of the output CSV; missing parent directories are created.",
    )
    parser.add_argument(
        "--models",
        nargs="+",
        required=True,
        help="One or more model names; each row is summarized once per model.",
    )
    parser.add_argument(
        "--text-column",
        default="Description",
        help="Column holding the text to summarize.",
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=96,
        help="Forwarded to the summarizer as max_new_tokens.",
    )
    return parser


def summarize_row(
    row: pd.Series,
    text_column: str,
    models: list[str],
    max_new_tokens: int,
    summarize,
) -> list[dict]:
    """Return one output record per model for a single input row.

    Args:
        row: One row of the input frame (as yielded by ``DataFrame.iterrows``).
        text_column: Name of the column whose value is summarized.
        models: Model names; each produces its own record, in this order.
        max_new_tokens: Forwarded unchanged to ``summarize``.
        summarize: Callable invoked as
            ``summarize(text, model_name=..., max_new_tokens=...)``.

    Returns:
        A list of dicts, each holding all of the row's columns plus
        ``model_name`` and ``generated_summary`` (overwriting any input
        columns with those names).

    A missing source value (NaN/None) is not sent to the model, because
    ``str()`` would turn it into the literal text "nan". Its records get
    ``generated_summary = None`` instead, which is an empty cell in the CSV.
    """
    base = row.to_dict()  # computed once; each record gets its own shallow copy
    value = row[text_column]
    if pd.isna(value):
        return [
            {**base, "model_name": model_name, "generated_summary": None}
            for model_name in models
        ]
    source_text = str(value)
    return [
        {
            **base,
            "model_name": model_name,
            "generated_summary": summarize(
                source_text, model_name=model_name, max_new_tokens=max_new_tokens
            ),
        }
        for model_name in models
    ]


def summarize_frame(
    df: pd.DataFrame,
    text_column: str,
    models: list[str],
    max_new_tokens: int,
    summarize,
) -> pd.DataFrame:
    """Summarize every row of *df* with every model, showing a progress bar.

    The result has ``len(df) * len(models)`` rows. Its columns are the input
    columns followed by ``model_name`` and ``generated_summary``. They are set
    explicitly so that an empty input still yields a CSV with a header row.
    """
    records: list[dict] = []
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Generating summaries"):
        records.extend(summarize_row(row, text_column, models, max_new_tokens, summarize))
    columns = list(dict.fromkeys([*df.columns, "model_name", "generated_summary"]))
    return pd.DataFrame(records, columns=columns)


def main(argv: list[str] | None = None) -> None:
    """Read the input CSV, summarize it with every model, and write the result.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    df = pd.read_csv(args.input)
    # Fail fast with a clear CLI error rather than a KeyError mid-loop.
    if args.text_column not in df.columns:
        parser.error(
            f"--text-column {args.text_column!r} not found in {args.input}; "
            f"available columns: {', '.join(map(str, df.columns))}"
        )

    out_df = summarize_frame(df, args.text_column, args.models, args.max_length, summarize_text)
    output = Path(args.output)
    output.parent.mkdir(parents=True, exist_ok=True)
    out_df.to_csv(output, index=False)
    print(f"Wrote {len(out_df)} rows to {args.output}")


if __name__ == "__main__":
    main()