"""Gradio-oriented helpers for loading a LoRA summarization model and
summarizing papers/posts in the Reddit TL;DR prompt format.

Pipeline: download a LoRA checkpoint from the Hugging Face Hub, attach it to
the Qwen3-0.6B base model, then generate "TL;DR" summaries for text parsed
from simple TOPIC/TITLE/CONTENT example files.
"""

import functools
import json
import os
import re

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_example_index(index_path):
    """Load and return the JSON index of bundled examples."""
    with open(index_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_paper_text(text_path):
    """Return the full UTF-8 text of a paper/example file."""
    with open(text_path, 'r', encoding='utf-8') as f:
        return f.read()


@functools.lru_cache(maxsize=1)
def load_model():
    """Download the LoRA checkpoint and build the summarization model.

    Cached (maxsize=1) so the Hub download and model construction happen
    only once per process.

    Returns:
        tuple: (tokenizer, model) — the base-model tokenizer and the PEFT
        model with LoRA weights loaded, switched to eval mode.

    Raises:
        ValueError: if the checkpoint lacks a 'lora_state_dict' entry.
    """
    pt_path = hf_hub_download(
        repo_id="strarr/reddit-sumarize-model",
        filename="model.pt",
    )
    # weights_only=True keeps torch.load from unpickling arbitrary objects.
    checkpoint = torch.load(pt_path, map_location="cpu", weights_only=True)

    if 'lora_state_dict' not in checkpoint:
        raise ValueError("Checkpoint does not contain LoRA state dict")

    # LoRA hyperparameters travel inside the checkpoint; fall back to the
    # defaults below when the nested config keys are absent.
    config_dict = checkpoint.get('config', {})
    lora_settings = config_dict.get('advanced', {}).get('lora', {})

    base_model_name = "Qwen/Qwen3-0.6B-Base"
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    base_model = AutoModelForCausalLM.from_pretrained(base_model_name)

    lora_config = LoraConfig(
        r=lora_settings.get('r', 8),
        lora_alpha=lora_settings.get('alpha', 16),
        target_modules=lora_settings.get('target_modules', ["q_proj", "v_proj"]),
        lora_dropout=lora_settings.get('dropout', 0.1),
        bias=lora_settings.get('bias', "none"),
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(base_model, lora_config)
    # strict=False: the checkpoint holds only the LoRA adapter weights,
    # not the full base-model state dict.
    model.load_state_dict(checkpoint["lora_state_dict"], strict=False)
    model.eval()
    return tokenizer, model


def summarize(content, tokenizer, model, title="Untitled", topic="arXiv",
              max_input=1024, max_output=200):
    """Generate a TL;DR summary of `content` with the LoRA model.

    The prompt mimics the Reddit TL;DR fine-tuning format
    (SUBREDDIT/TITLE/POST/TL;DR:).

    Args:
        content: Body text to summarize.
        tokenizer: Tokenizer paired with `model`.
        model: Causal LM (PEFT-wrapped) used for generation.
        title: Post title inserted into the prompt.
        topic: Subreddit-style topic inserted into the prompt.
        max_input: Max prompt tokens (longer input is truncated).
        max_output: Max new tokens to generate.

    Returns:
        str: The generated summary, whitespace-stripped.
    """
    prompt = f"SUBREDDIT: r/{topic}\nTITLE: {title}\nPOST: {content}\nTL;DR:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                       max_length=max_input)

    # Base tokenizers (e.g. Qwen) may not define a pad token; fall back to
    # EOS so generate() receives a valid pad_token_id instead of None.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_output,
            pad_token_id=pad_id,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    summary = tokenizer.decode(
        outputs[0][inputs['input_ids'].shape[1]:],
        skip_special_tokens=True,
    )
    # BUG FIX: original source had a bare `return` with `summary.strip()`
    # stranded on the next line, so the function returned None.
    return summary.strip()


def parse_example_file(filepath):
    """Parse a TOPIC/TITLE/CONTENT example file into a dict.

    Expects lines of the form `TOPIC: ...`, `TITLE: ...`, and
    `CONTENT: ...` (CONTENT captures everything to end-of-file). When any
    header is missing, the whole file body is treated as content with the
    filename as title and "arXiv" as the default topic.

    Returns:
        dict: keys 'topic', 'title', 'content' (all stripped strings).
    """
    with open(filepath, "r", encoding="utf-8") as f:
        text = f.read().strip()

    topic_match = re.search(r'^TOPIC:\s*(.*)$', text, flags=re.MULTILINE)
    title_match = re.search(r'^TITLE:\s*(.*)$', text, flags=re.MULTILINE)
    content_match = re.search(r'^CONTENT:\s*([\s\S]*)$', text, flags=re.MULTILINE)

    if topic_match and title_match and content_match:
        return {
            'topic': topic_match.group(1).strip(),
            'title': title_match.group(1).strip(),
            'content': content_match.group(1).strip(),
        }
    # Fallback: headerless file — use the whole text as content.
    return {
        'topic': "arXiv",
        'title': os.path.basename(filepath),
        'content': text,
    }