""" Paper Researcher — terminal-style Gradio demo with parallel streaming. """ import os import re import json import ast from paper_researcher import PaperResearcher from paper_researcher.tasks import SYSTEM_PROMPT import paper_researcher.tasks as t import gradio as gr MODEL_ID = os.environ.get("MODEL_ID", "JaydeepR/SmolLM-135M-SFT-exp01") MAX_NEW_TOKENS = 256 TEMPERATURE = 0.4 researcher = PaperResearcher(MODEL_ID) THEME = gr.themes.Base( font=[gr.themes.GoogleFont("JetBrains Mono"), "Fira Code", "monospace"], ).set( body_background_fill="#0d1117", block_background_fill="#0d1117", panel_background_fill="#0d1117", input_background_fill="#161b22", input_border_color="#30363d", input_border_color_focus="#58a6ff", block_border_color="#30363d", block_border_width="1px", block_label_text_color="#58a6ff", block_label_text_size="11px", body_text_color="#c9d1d9", body_text_size="12px", button_primary_background_fill="#238636", button_primary_background_fill_hover="#2ea043", button_primary_text_color="#ffffff", button_secondary_background_fill="#21262d", button_secondary_border_color="#30363d", button_secondary_text_color="#c9d1d9", ) CSS = """ @import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;600&display=swap'); footer { display: none !important; } .output-box textarea { color: #ffffff !important; font-size: 14px !important; line-height: 1.6 !important; } ul[role="listbox"] { background: #161b22 !important; } ul[role="listbox"] li { color: #e6edf3 !important; background: #161b22 !important; } ul[role="listbox"] li:hover { background: #21262d !important; } .messages-box textarea { color: #e3b341 !important; font-size: 11px !important; } .temp-header p { color: #f78166 !important; font-size: 11px !important; font-weight: 600 !important; letter-spacing: 2px !important; margin: 4px 0 2px 0 !important; } .gen-btn button { width: 100% !important; letter-spacing: 1px !important; } .stats-bar p { color: #8b949e !important; font-size: 14px !important; letter-spacing: 1px !important; margin: 2px 0 4px 0 !important; } .block { padding: 6px !important; } .gap, .gap-2 { gap: 6px !important; } .left-col { border-right: 1px solid #21262d !important; } #right-panel .output-box textarea { height: calc(100vh - 200px) !important; overflow-y: auto !important; resize: none !important; } """ MODES = { "bullets": {"desc": "Extract key points as bullets", "hint": "Paste passage...", "call": lambda r, x, **k: "\n".join(f"- {b}" for b in r.extract_bullets(x, **k))}, "qa_pairs": {"desc": "Generate Q&A pairs", "hint": "Paste passage...", "call": lambda r, x, **k: "\n\n".join(f"Q: {p.question}\nA: {p.answer}" for p in r.generate_qa_pairs(x, **k))}, "question": {"desc": "Generate a question from passage", "hint": "Paste passage...", "call": lambda r, x, **k: r.generate_question(x, **k)}, "fact": {"desc": "Extract a single fact", "hint": "Paste passage...", "call": lambda r, x, **k: r.extract_fact(x, **k)}, "answer": {"desc": "Answer question given passage", "hint": "passage...\n\n[blank line]\n\nquestion", "call": lambda r, x, **k: r.answer(*reversed(_split2(x)), **k)}, "rephrase": {"desc": "Rephrase and elaborate", "hint": "Paste passage...", "call": lambda r, x, **k: r.rephrase(x, **k)}, "continuation": {"desc": "Continue passage from beginning", "hint": "Paste start of passage...", "call": lambda r, x, **k: r.continue_from(x, **k)}, "triplets": {"desc": "Extract knowledge graph triplets", "hint": "Paste passage...", "call": lambda r, x, **k: "\n".join(f"({tri.subject}, {tri.relation}, {tri.object})" for tri in r.extract_triplets(x, **k))}, "comparison": {"desc": "Compare two passages", "hint": "passage 1...\n\n[blank line]\n\npassage 2", "call": lambda r, x, **k: r.compare(*_split2(x), **k)}, "retrieval": {"desc": "Find which passage answers a question", "hint": "Passage 1: ...\n\nPassage 2: ...\n\nQuestion: ...", "call": lambda r, x, **k: _fmt_retrieval(r.find_relevant(*_parse_retrieval(x), **k))}, } MODE_KEYS = list(MODES.keys()) MODE_CHOICES = [f"{k} — {MODES[k]['desc']}" for k in MODE_KEYS] INSTRUCTION_MAP = { "bullets": t.BULLETS_INSTRUCTION, "qa_pairs": t.QA_PAIRS_INSTRUCTION, "question": t.QUESTION_FROM_PASSAGE_INSTRUCTION, "fact": t.FACT_FROM_PASSAGE_INSTRUCTION, "answer": t.QA_ANSWER_INSTRUCTION, "rephrase": t.REPHRASE_INSTRUCTION, "continuation": t.CONTINUATION_INSTRUCTION, "triplets": t.TRIPLETS_INSTRUCTION, "comparison": t.COMPARISON_INSTRUCTION, "retrieval": t.RETRIEVAL_INSTRUCTION, } def _split2(text): parts = text.split("\n\n", 1) return parts[0].strip(), (parts[1].strip() if len(parts) > 1 else "") def _fmt_retrieval(r): idx = f"Passage {r.index + 1}" if r.index is not None else "None" return f"Passage: {idx}\n\n{r.reasoning}" def _parse_retrieval(text): passages = re.findall(r"Passage \d+:\s*(.*?)(?=Passage \d+:|Question:|$)", text, re.DOTALL) passages = [p.strip() for p in passages if p.strip()] q = re.search(r"Question:\s*(.+)", text, re.DOTALL) return (q.group(1).strip() if q else ""), passages def _maybe_prettify(text: str) -> str: stripped = text.strip() if stripped.startswith("[") and stripped.endswith("]"): try: val = ast.literal_eval(stripped) if isinstance(val, list): return "\n\n".join(str(item) for item in val) except (ValueError, SyntaxError): pass try: return json.dumps(json.loads(stripped), indent=2) except (json.JSONDecodeError, ValueError): pass return text def key_from_choice(c): return c.split(" — ")[0].strip() def on_mode_change(choice): key = key_from_choice(choice) return gr.update(placeholder=MODES[key]["hint"]) def build_messages_preview(choice, text): if not text.strip(): return "" key = key_from_choice(choice) inst = INSTRUCTION_MAP.get(key, "") messages = [{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": f"{inst}\n\n{text}"}] return json.dumps(messages, indent=2) def build_prompt(mode_key, user_text): inst = INSTRUCTION_MAP[mode_key] tok = researcher._backend.tokenizer msgs = [{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": f"{inst}\n\n{user_text}"}] return tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True) def generate(mode_choice, user_text): if not user_text.strip(): return "// no input", "" key = key_from_choice(mode_choice) try: result = MODES[key]["call"](researcher, user_text, temperature=TEMPERATURE, max_new_tokens=MAX_NEW_TOKENS) return _maybe_prettify(str(result)), "" except Exception as e: return f"// error: {e}", "" with gr.Blocks(title="paper_researcher", theme=THEME, css=CSS) as demo: gr.Markdown( f"// paper_researcher" f"  {MODEL_ID}" f"  ⚠ CPU — responses may take 30-60s" ) with gr.Row(equal_height=False): with gr.Column(scale=1, elem_classes=["left-col"], elem_id="left-panel"): mode_dd = gr.Dropdown(choices=MODE_CHOICES, value=MODE_CHOICES[0], label="mode") user_box = gr.Textbox(label="input", lines=10, placeholder=MODES[MODE_KEYS[0]]["hint"]) messages_box = gr.Textbox(label="messages preview", lines=8, interactive=False, elem_classes=["messages-box"]) gen_btn = gr.Button("▶ GENERATE", variant="primary", elem_classes=["gen-btn"]) with gr.Column(scale=1, elem_id="right-panel"): stats_md = gr.Markdown("", elem_classes=["stats-bar"]) output_box = gr.Textbox(show_label=False, lines=20, interactive=False, elem_classes=["output-box"]) mode_dd.change(fn=on_mode_change, inputs=mode_dd, outputs=user_box) mode_dd.change(fn=build_messages_preview, inputs=[mode_dd, user_box], outputs=messages_box) user_box.change(fn=build_messages_preview, inputs=[mode_dd, user_box], outputs=messages_box) gen_btn.click(fn=generate, inputs=[mode_dd, user_box], outputs=[output_box, stats_md]) demo.launch()