"""Constants and small text helpers for classifying papers into arXiv categories.

The helpers below are pure functions: one builds the model input text from a
paper's title and abstract, the other selects a top-p (cumulative score)
prefix of scored records.
"""

from __future__ import annotations

from collections.abc import Iterable
from pathlib import Path
from typing import Any

# Identifier of the pretrained base checkpoint ("org/name" hub-style id).
BASE_MODEL_NAME = "distilbert/distilbert-base-cased"

# Parent of the directory that contains this module (two levels above the file).
DEFAULT_MODEL_DIR = Path(__file__).resolve().parent.parent

# NOTE(review): presumably the maximum tokenized sequence length — confirm at
# the tokenizer call site.
MAX_LENGTH = 256

# Default cumulative-score cutoff used by ``take_top_p``.
TOP_P_THRESHOLD = 0.95

# Top-level arXiv subject groups.
# NOTE(review): presumably must match the classifier's label set — verify
# before adding, removing or renaming entries.
EXPECTED_ARXIV_CATEGORIES = [
    "Computer Science",
    "Physics",
    "Mathematics",
    "Statistics",
    "Quantitative Biology",
    "Quantitative Finance",
    "Economics",
    "Electrical Engineering and Systems Science",
]

# Sample inputs keyed by a display name; each entry has a "title" and an
# "abstract" suitable for ``format_input_text``.
EXAMPLES: dict[str, dict[str, str]] = {
    "Graph Neural Networks": {
        "title": "Message Passing Neural Networks for Molecular Property Prediction",
        "abstract": (
            "We introduce a graph-based neural architecture for supervised learning on "
            "molecular graphs. The model propagates messages between atoms, aggregates "
            "node states into a graph embedding, and predicts physical and chemical "
            "properties with competitive accuracy."
        ),
    },
    "Physics": {
        "title": "Topological phase transitions in two-dimensional quantum materials",
        "abstract": (
            "We study a lattice model with strong spin-orbit coupling and show how "
            "interactions modify the phase diagram. Using numerical simulations we "
            "characterize edge states, quantify transport signatures, and discuss "
            "observable consequences for low-temperature experiments."
        ),
    },
    "Bioinformatics": {
        "title": "Transformer models for protein function annotation from sequence",
        "abstract": (
            "We pretrain a transformer encoder on amino acid sequences and finetune it "
            "for protein function prediction. The approach improves annotation quality "
            "for underrepresented families and reveals biologically meaningful sequence "
            "patterns."
        ),
    },
}


def format_input_text(title: str, abstract: str) -> str:
    """Build the model input text from a paper's title and abstract.

    Both fields are stripped of surrounding whitespace. An empty section is
    omitted entirely. The remaining sections are joined by a blank line.

    Args:
        title: Paper title; may be empty.
        abstract: Paper abstract; may be empty.

    Returns:
        The formatted text, or ``""`` when both fields are blank.
    """
    title = title.strip()
    abstract = abstract.strip()

    parts: list[str] = []
    if title:
        # The title is emitted twice on purpose.
        # NOTE(review): presumably this must stay identical to the format used
        # when the model was trained — do not change without retraining.
        parts.append(f"Title: {title}\nTitle summary: {title}")
    if abstract:
        parts.append(f"Abstract: {abstract}")
    return "\n\n".join(parts)


def take_top_p(
    records: Iterable[dict[str, Any]],
    threshold: float = TOP_P_THRESHOLD,
) -> list[dict[str, Any]]:
    """Return the shortest prefix of *records* whose scores reach *threshold*.

    Records are consumed lazily, in the order given. Iteration stops as soon
    as the running sum of ``record["score"]`` is >= *threshold*, and the
    record that crossed the threshold is included. If the threshold is never
    reached, every record is returned. Only the ``"score"`` key is read; the
    record objects themselves are returned unchanged.

    For true top-p (nucleus) selection, the caller should pass records sorted
    by descending score; this function does not sort.

    Args:
        records: Scored records, each with a numeric ``"score"`` entry.
        threshold: Cumulative-score cutoff. Defaults to ``TOP_P_THRESHOLD``.
            With a value <= 0, at most the first record is returned.

    Returns:
        The selected records in input order; ``[]`` for empty input.

    Raises:
        KeyError: If a consumed record has no ``"score"`` key.
    """
    selected: list[dict[str, Any]] = []
    cumulative = 0.0
    for record in records:
        selected.append(record)
        cumulative += record["score"]
        if cumulative >= threshold:
            break
    return selected