| |
| """ |
| Inference script for the Academic Paper Classifier. |
| |
| Loads a fine-tuned DistilBERT model and predicts the arxiv category for a |
| given paper abstract. Returns the predicted category along with per-class |
| confidence scores. |
| |
| Usage examples: |
| # Use a local model directory |
| python inference.py --model_path ./model --abstract "We propose a novel ..." |
| |
| # Use a HuggingFace Hub model |
| python inference.py --model_path gr8monk3ys/paper-classifier-model \ |
| --abstract "We propose a novel ..." |
| |
| # Interactive mode (reads from stdin) |
| python inference.py --model_path ./model |
| |
| Author: Lorenzo Scaturchio (gr8monk3ys) |
| License: MIT |
| """ |
|
|
| import argparse |
| import json |
| import logging |
| import sys |
| from pathlib import Path |
|
|
| import torch |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
|
| |
| |
| |
# Module-wide logging configuration. Records go to stderr so they never mix
# with the prediction output this script writes to stdout -- in particular the
# machine-readable ``--json`` output, which must contain nothing but JSON.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stderr)],
)
logger = logging.getLogger(__name__)
|
|
|
|
| |
| |
| |
| class PaperClassifier: |
| """Thin wrapper around a fine-tuned sequence-classification model. |
| |
| Parameters |
| ---------- |
| model_path : str |
| Path to a local model directory **or** a HuggingFace Hub model id. |
| device : str | None |
| Target device (``"cpu"``, ``"cuda"``, ``"mps"``). If *None* the best |
| available device is selected automatically. |
| """ |
|
|
| def __init__(self, model_path: str, device: str | None = None) -> None: |
| if device is None: |
| if torch.cuda.is_available(): |
| device = "cuda" |
| elif torch.backends.mps.is_available(): |
| device = "mps" |
| else: |
| device = "cpu" |
| self.device = torch.device(device) |
|
|
| logger.info("Loading tokenizer from: %s", model_path) |
| self.tokenizer = AutoTokenizer.from_pretrained(model_path) |
|
|
| logger.info("Loading model from: %s", model_path) |
| self.model = AutoModelForSequenceClassification.from_pretrained(model_path) |
| self.model.to(self.device) |
|
|
| |
| self.id2label: dict[int, str] = self.model.config.id2label |
| logger.info("Labels: %s", list(self.id2label.values())) |
|
|
| @torch.no_grad() |
| def predict(self, abstract: str, top_k: int | None = None) -> dict: |
| """Classify a single paper abstract. |
| |
| Parameters |
| ---------- |
| abstract : str |
| The paper abstract to classify. |
| top_k : int | None |
| If given, only the *top_k* categories (by confidence) are returned |
| in ``scores``. Pass *None* to return all categories. |
| |
| Returns |
| ------- |
| dict |
| ``{"label": str, "confidence": float, "scores": {label: prob}}`` |
| """ |
| self.model.eval() |
|
|
| inputs = self.tokenizer( |
| abstract, |
| return_tensors="pt", |
| truncation=True, |
| padding=True, |
| max_length=512, |
| ).to(self.device) |
|
|
| logits = self.model(**inputs).logits |
| probs = torch.softmax(logits, dim=-1).squeeze(0).cpu().numpy() |
|
|
| sorted_indices = probs.argsort()[::-1] |
| if top_k is not None: |
| sorted_indices = sorted_indices[:top_k] |
|
|
| scores = { |
| self.id2label[int(idx)]: float(probs[idx]) for idx in sorted_indices |
| } |
|
|
| best_idx = int(probs.argmax()) |
| return { |
| "label": self.id2label[best_idx], |
| "confidence": float(probs[best_idx]), |
| "scores": scores, |
| } |
|
|
|
|
| |
| |
| |
| def parse_args() -> argparse.Namespace: |
| parser = argparse.ArgumentParser( |
| description="Classify an academic paper abstract into an arxiv category." |
| ) |
| parser.add_argument( |
| "--model_path", |
| type=str, |
| default="./model", |
| help="Path to the fine-tuned model directory or HF Hub id (default: %(default)s).", |
| ) |
| parser.add_argument( |
| "--abstract", |
| type=str, |
| default=None, |
| help="Paper abstract text. If omitted, the script enters interactive mode.", |
| ) |
| parser.add_argument( |
| "--top_k", |
| type=int, |
| default=None, |
| help="Only show the top-k predictions (default: show all).", |
| ) |
| parser.add_argument( |
| "--device", |
| type=str, |
| default=None, |
| choices=["cpu", "cuda", "mps"], |
| help="Device to run inference on (default: auto-detect).", |
| ) |
| parser.add_argument( |
| "--json", |
| action="store_true", |
| default=False, |
| dest="output_json", |
| help="Output raw JSON instead of human-readable text.", |
| ) |
| return parser.parse_args() |
|
|
|
|
def _print_result(result: dict, output_json: bool) -> None:
    """Pretty-print or JSON-dump a prediction result.

    Parameters
    ----------
    result : dict
        A prediction with keys ``"label"``, ``"confidence"`` and ``"scores"``
        (a mapping of label -> probability), as built by
        ``PaperClassifier.predict``.
    output_json : bool
        If true, print *result* as indented JSON; otherwise print a text
        table with a 40-character-wide bar per score.
    """
    if output_json:
        print(json.dumps(result, indent=2))
        return

    scores = result["scores"]
    # Size the label column to the longest label (minimum 10) so long
    # category names such as "cond-mat.stat-mech" keep rows aligned.
    width = max([10] + [len(label) for label in scores])

    print(f"\n Predicted category : {result['label']}")
    print(f" Confidence : {result['confidence']:.4f}")
    print(" ---------------------------------")
    for label, score in scores.items():
        bar = "#" * int(score * 40)
        print(f" {label:<{width}s} {score:6.4f} {bar}")
    print()
|
|
|
|
def _read_abstract(interactive: bool) -> tuple[str, str | None]:
    """Read one abstract from stdin, line by line.

    Lines are collected until an empty (or whitespace-only) line follows some
    text; blank lines before any text are skipped. Collected lines are joined
    with single spaces.

    Parameters
    ----------
    interactive : bool
        Whether stdin is a terminal; prompts are only shown when true.

    Returns
    -------
    tuple[str, str | None]
        ``(abstract, stop)``. ``stop`` is *None* when the caller should keep
        reading, ``"quit"`` when the user typed ``quit`` (``abstract`` is then
        empty), or ``"eof"`` when input ran out. On EOF, ``abstract`` holds any
        text collected so far, so a final abstract that is not followed by a
        blank line (typical for piped input) is not lost.
    """
    lines: list[str] = []
    while True:
        if interactive:
            prompt = "... " if lines else "abstract> "
        else:
            prompt = ""
        try:
            line = input(prompt)
        except EOFError:
            return " ".join(lines).strip(), "eof"
        if line.strip().lower() == "quit":
            return "", "quit"
        if not line.strip():
            if lines:
                break
            continue  # ignore blank lines before the abstract starts
        lines.append(line)
    return " ".join(lines).strip(), None


def main() -> None:
    """Command-line entry point.

    With ``--abstract`` the given text is classified once. Otherwise
    abstracts are read from stdin (multi-line, terminated by an empty line)
    and classified one after another until the user types ``quit``, input
    ends, or Ctrl-C is pressed.
    """
    args = parse_args()
    classifier = PaperClassifier(model_path=args.model_path, device=args.device)

    if args.abstract is not None:
        result = classifier.predict(args.abstract, top_k=args.top_k)
        _print_result(result, args.output_json)
        return

    interactive = sys.stdin.isatty()
    if interactive:
        # Banner is for humans only; piped runs (e.g. with --json) keep a
        # clean stdout containing nothing but results.
        print("Academic Paper Classifier - Interactive Mode")
        print("Enter a paper abstract (or 'quit' to exit).")
        print("For multi-line input, end with an empty line.\n")

    while True:
        try:
            abstract, stop = _read_abstract(interactive)
            # Classify even when input just ended, so the last abstract of a
            # piped stream without a trailing blank line is still processed.
            if abstract:
                result = classifier.predict(abstract, top_k=args.top_k)
                _print_result(result, args.output_json)
        except KeyboardInterrupt:
            stop = "interrupt"

        if stop is None:
            continue
        if stop != "quit":
            print()  # terminate the line left open by ^D / ^C
        logger.info("Exiting.")
        return
|
|
|
|
# Script entry point: run the CLI only when executed directly, never on import.
if __name__ == "__main__":
    main()
|
|