Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import json | |
| import torch | |
| import gradio as gr | |
| import functools | |
| from peft import get_peft_model, LoraConfig | |
| from huggingface_hub import hf_hub_download | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
def load_example_index(index_path):
    """Read the JSON document stored at ``index_path`` and return it.

    The file is opened as UTF-8 text; the return value is whatever object
    the JSON decodes to.
    """
    with open(index_path, mode='r', encoding='utf-8') as index_file:
        raw_json = index_file.read()
    return json.loads(raw_json)
def load_paper_text(text_path):
    """Return the entire contents of the UTF-8 text file at ``text_path``."""
    with open(text_path, mode='r', encoding='utf-8') as paper_file:
        paper_text = paper_file.read()
    return paper_text
def load_model(repo_id="strarr/reddit-sumarize-model", filename="model.pt",
               base_model_name="Qwen/Qwen3-0.6B-Base"):
    """Build the summarization model: base causal LM plus LoRA adapter weights.

    Downloads ``filename`` from the Hugging Face Hub repository ``repo_id``,
    loads it on CPU, wraps the ``base_model_name`` model with a PEFT LoRA
    configuration and loads the checkpoint's ``lora_state_dict`` into it.

    LoRA hyperparameters are read from ``checkpoint['config']['advanced']['lora']``
    when present; any missing entry falls back to the defaults below.

    Args:
        repo_id: Hub repository that hosts the checkpoint.
        filename: Name of the checkpoint file inside ``repo_id``.
        base_model_name: Model identifier passed to ``from_pretrained`` for
            both the tokenizer and the base model.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is in eval mode.

    Raises:
        ValueError: If the checkpoint has no ``'lora_state_dict'`` entry.
    """
    import warnings

    pt_path = hf_hub_download(repo_id=repo_id, filename=filename)
    # weights_only=True restricts unpickling to tensors and plain containers.
    checkpoint = torch.load(pt_path, map_location="cpu", weights_only=True)
    if 'lora_state_dict' not in checkpoint:
        raise ValueError("Checkpoint does not contain LoRA state dict")

    config_dict = checkpoint.get('config', {})
    lora_settings = config_dict.get('advanced', {}).get('lora', {})

    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    base_model = AutoModelForCausalLM.from_pretrained(base_model_name)

    lora_config = LoraConfig(
        r=lora_settings.get('r', 8),
        lora_alpha=lora_settings.get('alpha', 16),
        target_modules=lora_settings.get('target_modules', ["q_proj", "v_proj"]),
        lora_dropout=lora_settings.get('dropout', 0.1),
        bias=lora_settings.get('bias', "none"),
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(base_model, lora_config)

    # strict=False is required because the checkpoint only holds adapter
    # weights, so the base-model keys are legitimately missing. The flip side
    # is that adapter keys which do not match the wrapped model are dropped
    # without an error; surface that instead of silently serving the
    # un-adapted base model.
    load_result = model.load_state_dict(checkpoint["lora_state_dict"], strict=False)
    unexpected = list(load_result.unexpected_keys)
    if unexpected:
        warnings.warn(
            f"{len(unexpected)} key(s) in 'lora_state_dict' did not match the "
            f"PEFT model and were ignored (e.g. {unexpected[0]!r}); "
            "the LoRA adapter may not be fully applied.",
            RuntimeWarning,
            stacklevel=2,
        )

    model.eval()
    return tokenizer, model
def summarize(content, tokenizer, model, title="Untitled", topic="arXiv", max_input=1024, max_output=200):
    """Generate a TL;DR-style summary of ``content`` with ``model``.

    The prompt has the form
    ``SUBREDDIT: r/<topic>\\nTITLE: <title>\\nPOST: <content>\\nTL;DR:``.
    When the prompt would exceed ``max_input`` tokens, only the post text is
    shortened so that the trailing ``TL;DR:`` cue always reaches the model;
    prompts that already fit are passed through unchanged.

    Args:
        content: Text to summarize.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        model: Model exposing ``generate``.
        title: Title inserted into the prompt.
        topic: Name inserted after ``r/`` in the prompt.
        max_input: Maximum number of prompt tokens.
        max_output: Maximum number of newly generated tokens.

    Returns:
        The generated continuation (prompt tokens removed), stripped of
        surrounding whitespace. Sampling is enabled, so the output is not
        deterministic.
    """
    header = f"SUBREDDIT: r/{topic}\nTITLE: {title}\nPOST: "
    cue = "\nTL;DR:"
    prompt = f"{header}{content}{cue}"

    def token_count(text):
        return len(tokenizer(text)["input_ids"])

    # Plain right-side truncation would cut off the "TL;DR:" cue for long
    # posts, so trim the post body instead until the whole prompt fits.
    overflow = token_count(prompt) - max_input
    if overflow > 0:
        content_ids = tokenizer(content, add_special_tokens=False)["input_ids"]
        keep = len(content_ids)
        # Re-tokenizing the rebuilt prompt can shift the count by a few tokens
        # at the seams, hence the loop; `keep` strictly decreases, so it ends.
        while overflow > 0 and keep > 0:
            keep = max(keep - overflow, 0)
            trimmed = tokenizer.decode(content_ids[:keep], skip_special_tokens=True)
            prompt = f"{header}{trimmed}{cue}"
            overflow = token_count(prompt) - max_input

    # truncation stays on as a hard cap for the case where the scaffold alone
    # (header + cue) is longer than max_input.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_input)

    # Not every tokenizer defines a pad token; fall back to EOS in that case.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_output,
            pad_token_id=pad_token_id,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    # The generated sequence starts with the prompt; keep only what follows.
    prompt_len = inputs['input_ids'].shape[1]
    summary = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return summary.strip()
def parse_example_file(filepath):
    """Parse an example file into a dict with 'topic', 'title' and 'content'.

    The file is read as UTF-8 and stripped. It is expected to contain lines
    beginning with ``TOPIC:``, ``TITLE:`` and ``CONTENT:``; the content value
    runs from its marker to the end of the file. If any of the three markers
    is absent, the whole text becomes the content, the topic is "arXiv" and
    the title is the file's base name.
    """
    with open(filepath, "r", encoding="utf-8") as example_file:
        raw = example_file.read().strip()

    field_patterns = {
        'topic': r'^TOPIC:\s*(.*)$',
        'title': r'^TITLE:\s*(.*)$',
        'content': r'^CONTENT:\s*([\s\S]*)$',
    }
    found = {
        field: re.search(pattern, raw, flags=re.MULTILINE)
        for field, pattern in field_patterns.items()
    }

    # Any missing marker means the file is treated as unstructured text.
    if any(hit is None for hit in found.values()):
        return {
            'topic': "arXiv",
            'title': os.path.basename(filepath),
            'content': raw
        }

    return {field: hit.group(1).strip() for field, hit in found.items()}