SHAD_Homework / src /paper_classifier.py
AndreyForty's picture
Update src/paper_classifier.py
bab7a71 verified
from __future__ import annotations
from typing import Iterable
from pathlib import Path
# Model identifier string; presumably a Hugging Face Hub id for the base
# checkpoint -- confirm against the loading code.
BASE_MODEL_NAME: str = "distilbert/distilbert-base-cased"

# Directory two levels above this file's resolved location.
DEFAULT_MODEL_DIR: Path = Path(__file__).resolve().parent.parent

# NOTE(review): looks like the tokenizer truncation length -- confirm at the
# call site, which is not visible here.
MAX_LENGTH: int = 256

# Cumulative-score cut-off; presumably the default ``threshold`` handed to
# ``take_top_p`` -- confirm against the caller.
TOP_P_THRESHOLD: float = 0.95
# Human-readable category names, in a fixed order.
# NOTE(review): presumably the label set the classifier is expected to
# produce -- verify against the trained model's label mapping.
EXPECTED_ARXIV_CATEGORIES: list[str] = [
    "Computer Science",
    "Physics",
    "Mathematics",
    "Statistics",
    "Quantitative Biology",
    "Quantitative Finance",
    "Economics",
    "Electrical Engineering and Systems Science",
]
# Sample papers keyed by a short display name. Every entry carries exactly
# two string fields: "title" and "abstract".
# NOTE(review): presumably used to pre-fill a demo form -- confirm at the
# point of use, which is not visible here.
EXAMPLES: dict[str, dict[str, str]] = {
    "Graph Neural Networks": {
        "title": "Message Passing Neural Networks for Molecular Property Prediction",
        "abstract": (
            "We introduce a graph-based neural architecture for supervised learning "
            "on molecular graphs. "
            "The model propagates messages between atoms, aggregates node states "
            "into a graph embedding, and predicts physical and chemical properties "
            "with competitive accuracy."
        ),
    },
    "Physics": {
        "title": "Topological phase transitions in two-dimensional quantum materials",
        "abstract": (
            "We study a lattice model with strong spin-orbit coupling and show how "
            "interactions modify the phase diagram. "
            "Using numerical simulations we characterize edge states, quantify "
            "transport signatures, and discuss observable consequences for "
            "low-temperature experiments."
        ),
    },
    "Bioinformatics": {
        "title": "Transformer models for protein function annotation from sequence",
        "abstract": (
            "We pretrain a transformer encoder on amino acid sequences and finetune "
            "it for protein function prediction. "
            "The approach improves annotation quality for underrepresented families "
            "and reveals biologically meaningful sequence patterns."
        ),
    },
}
def format_input_text(title: str, abstract: str) -> str:
    """Build the single text string describing a paper.

    Both fields are stripped of surrounding whitespace. A non-empty title
    contributes a two-line section in which the title appears twice (once
    after ``Title:`` and once after ``Title summary:``); a non-empty abstract
    contributes an ``Abstract:`` section. Sections are separated by one blank
    line, and a field that is empty after stripping is omitted entirely.

    Args:
        title: Paper title; may be empty or whitespace-only.
        abstract: Paper abstract; may be empty or whitespace-only.

    Returns:
        The combined text, or an empty string when both fields are blank.
    """
    title_text = title.strip()
    abstract_text = abstract.strip()

    # Each candidate section collapses to "" when its field is blank, so the
    # filter below drops it and no stray separator is emitted.
    sections = [
        f"Title: {title_text}\nTitle summary: {title_text}" if title_text else "",
        f"Abstract: {abstract_text}" if abstract_text else "",
    ]
    return "\n\n".join(section for section in sections if section)
def take_top_p(records: Iterable[dict[str, float]], threshold: float) -> list[dict[str, float]]:
    """Return the leading records whose summed ``"score"`` reaches *threshold*.

    Records are consumed in the order given; no sorting is performed here, so
    the caller decides the ranking. The record that pushes the running total
    to or past *threshold* is included, and iteration stops immediately after
    it. The first record is always kept when the input is non-empty.

    Args:
        records: Mappings that each expose a ``"score"`` entry.
        threshold: Cumulative score at which to stop collecting.

    Returns:
        The collected prefix of *records*: every record if the total never
        reaches *threshold*, or an empty list for empty input.
    """
    kept: list[dict[str, float]] = []
    running_total = 0.0
    for item in records:
        kept.append(item)
        running_total += item["score"]
        if running_total >= threshold:
            # Threshold reached: leave the rest of the iterable unconsumed.
            return kept
    return kept