"""Two-pass annotation pipeline: free-form reasoning, then strict UI formatting."""

from __future__ import annotations

import json
import traceback
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Callable, Dict

from common.model_client import ModelConfig, MultiProviderLLMClient
from common.paper_package import PaperPackage

from .final_prompts import (
    SYSTEM_TWO_PASS_FORMATTER,
    SYSTEM_TWO_PASS_JUDGE,
    SYSTEM_TWO_PASS_REASONING,
    formatter_prompt,
    judge_prompt,
    reasoning_prompt,
)
from .schemas import JudgeResult, UIPayload


@dataclass
class TwoPassPipelineResult:
    """Outcome of one pipeline run."""

    # Directory that holds every artifact written during the run.
    run_dir: Path
    # The same summary dict that is written to ``run_output.json``.
    result: Dict[str, Any]


class FormatterStageError(RuntimeError):
    """Raised when the formatter stage fails on every allowed attempt.

    ``run_dir`` points at the run directory, which still contains the pass 1
    outputs and the per-attempt error files.
    """

    def __init__(self, message: str, run_dir: Path):
        super().__init__(message)
        self.run_dir = run_dir


class TwoPassAnnotationPipeline:
    """Run reasoning (pass 1), optional judging, and formatting (pass 2).

    Pass 1 generates ``candidate_count`` free-form reasoning texts. When more
    than one candidate is requested, a judge stage picks one of them. Pass 2
    turns the selected reasoning into a ``UIPayload``, retrying the formatter
    alone on failure. Every prompt and output is written under a per-run
    directory below ``output_root``.
    """

    def __init__(
        self,
        *,
        provider: str,
        model: str,
        formatter_model: str | None,
        judge_model: str | None,
        output_root: Path,
        run_label: str | None = None,
        annotator_id: str = "llm",
        temperature: float = 0.2,
        max_tokens: int = 16000,
        candidate_count: int = 1,
        formatter_max_attempts: int = 3,
        include_reference_examples: bool = True,
        prompt_profile: str = "full",
        progress_callback: Callable[[str], None] | None = None,
    ):
        """Configure the pipeline and build the LLM client.

        Args:
            provider: Provider for the default model configuration.
            model: Default model; used for every stage without an override.
            formatter_model: Optional override for the formatter stage.
            judge_model: Optional override for the judge stage; ignored
                unless judging is enabled.
            output_root: Root directory under which run directories are made.
            run_label: Optional label, slugified and prefixed to the run
                directory name.
            annotator_id: Forwarded to the formatter prompt.
            temperature: Sampling temperature for the default configuration.
            max_tokens: Token limit for the default configuration.
            candidate_count: Number of reasoning candidates; clamped to >= 1.
            formatter_max_attempts: Formatter attempts; clamped to >= 1.
            include_reference_examples: Forwarded to the reasoning prompt.
            prompt_profile: Forwarded to the reasoning prompt.
            progress_callback: Optional sink for progress messages.
        """
        self.output_root = output_root
        self.annotator_id = annotator_id
        self.progress_callback = progress_callback
        self.run_label = run_label
        self.candidate_count = max(1, candidate_count)
        self.formatter_max_attempts = max(1, formatter_max_attempts)
        self.include_reference_examples = include_reference_examples
        self.prompt_profile = prompt_profile
        # A judge only makes sense when there is something to choose between.
        self.use_judge = self.candidate_count > 1

        stage_models = {}
        if formatter_model:
            stage_models["two_pass_formatter"] = formatter_model
        if judge_model and self.use_judge:
            stage_models["two_pass_judge"] = judge_model

        self.client = MultiProviderLLMClient(
            default_config=ModelConfig(
                provider=provider,
                model=model,
                temperature=temperature,
                max_tokens=max_tokens,
            ),
            stage_models=stage_models,
        )

    def run(self, paper: PaperPackage) -> TwoPassPipelineResult:
        """Execute both passes for ``paper`` and persist all artifacts.

        Args:
            paper: The paper package to annotate.

        Returns:
            The run directory together with the summary dict that is also
            written to ``run_output.json``.

        Raises:
            IndexError: If the judge selects an index outside the generated
                candidates.
            FormatterStageError: If the formatter fails on every attempt.
                Pass 1 outputs and per-attempt error files stay in the run
                directory.
        """
        run_dir = self._make_run_dir(paper)
        payload = {
            **paper.to_prompt_payload(),
            "paper_dir": paper.paper_dir,
            "full_processed_text": self._load_full_processed_text(paper),
        }

        formatter_config = self.client.config_for_stage("two_pass_formatter")
        default_config = self.client.default_config
        self._log(
            f"[run] paper={paper.paper_dir.name} "
            f"provider={default_config.provider} "
            f"model={default_config.model_name}"
        )
        self._log(f"[run] formatter_model={formatter_config.model_name}")
        self._log(f"[run] include_reference_examples={self.include_reference_examples}")
        self._log(f"[run] prompt_profile={self.prompt_profile}")
        judge_model_name: str | None = None
        if self.use_judge:
            judge_model_name = self.client.config_for_stage("two_pass_judge").model_name
            self._log(f"[run] judge_model={judge_model_name}")
        else:
            self._log("[run] judge_model=disabled (candidate_count=1)")
        self._log(f"[run] output={run_dir}")

        # ---- Pass 1: free-form reasoning candidates -----------------------
        reasoning_user_prompt = reasoning_prompt(
            payload,
            include_reference_examples=self.include_reference_examples,
            prompt_profile=self.prompt_profile,
        )
        self._write_text(run_dir / "pass_1_reasoning.prompt.txt", reasoning_user_prompt)
        plural = "s" if self.candidate_count != 1 else ""
        self._log(
            f"[pass 1] free-form reasoning ({self.candidate_count} candidate{plural})"
        )

        candidate_texts: list[str] = []
        candidate_paths: list[str] = []
        for number in range(1, self.candidate_count + 1):
            reasoning_text = self.client.generate_text(
                stage_name="two_pass_reasoning",
                system_prompt=SYSTEM_TWO_PASS_REASONING,
                user_prompt=reasoning_user_prompt,
            )
            candidate_path = run_dir / f"pass_1_reasoning.output.candidate_{number}.md"
            self._write_text(candidate_path, reasoning_text)
            candidate_texts.append(reasoning_text)
            candidate_paths.append(str(candidate_path))

        # Without a judge the single (first) candidate is selected.
        selected_candidate_index = 0
        selected_candidate_id = "candidate_1"
        judge_output_path: Path | None = None

        if self.use_judge:
            judge_user_prompt = judge_prompt(payload, candidate_texts)
            self._write_text(
                run_dir / "pass_1_reasoning.judge.prompt.txt", judge_user_prompt
            )
            self._log("[pass 1] candidate judging")
            judge_result = self.client.generate_structured(
                stage_name="two_pass_judge",
                system_prompt=SYSTEM_TWO_PASS_JUDGE,
                user_prompt=judge_user_prompt,
                response_model=JudgeResult,
            )
            # Persist the verdict first so it survives a rejected index below.
            judge_output_path = run_dir / "pass_1_reasoning.judge.output.json"
            self._write_json(judge_output_path, judge_result.model_dump())
            selected_candidate_index = judge_result.selected_candidate_index
            selected_candidate_id = judge_result.selected_candidate_id
            # The index comes from model output: reject values that would
            # either fail opaquely or, if negative, silently pick a candidate
            # from the end of the list.
            if not 0 <= selected_candidate_index < len(candidate_texts):
                raise IndexError(
                    f"Judge selected candidate index {selected_candidate_index}, "
                    f"but valid indices are 0..{len(candidate_texts) - 1}"
                )

        selected_reasoning_text = candidate_texts[selected_candidate_index]
        selected_reasoning_path = run_dir / "pass_1_reasoning.selected.md"
        self._write_text(selected_reasoning_path, selected_reasoning_text)

        # ---- Pass 2: strict UI JSON formatting, retried on its own --------
        formatter_user_prompt = formatter_prompt(
            payload, selected_reasoning_text, self.annotator_id
        )
        self._write_text(run_dir / "pass_2_formatter.prompt.txt", formatter_user_prompt)

        final_payload: UIPayload | None = None
        last_error: Exception | None = None
        formatter_attempts: list[dict[str, Any]] = []
        for attempt in range(1, self.formatter_max_attempts + 1):
            self._log(
                f"[pass 2] strict ui json formatting (attempt {attempt}/{self.formatter_max_attempts})"
            )
            try:
                final_payload = self.client.generate_structured(
                    stage_name="two_pass_formatter",
                    system_prompt=SYSTEM_TWO_PASS_FORMATTER,
                    user_prompt=formatter_user_prompt,
                    response_model=UIPayload,
                )
            except Exception as exc:  # any formatter failure is retried
                last_error = exc
                # Three-argument form: the single-argument call is only
                # accepted on Python 3.10+.
                error_text = "".join(
                    traceback.format_exception(type(exc), exc, exc.__traceback__)
                ).strip()
                error_path = run_dir / f"pass_2_formatter.attempt_{attempt}.error.txt"
                self._write_text(error_path, error_text)
                formatter_attempts.append(
                    {
                        "attempt": attempt,
                        "status": "failed",
                        "error": str(exc),
                        "error_path": str(error_path),
                    }
                )
                if attempt < self.formatter_max_attempts:
                    self._log("[pass 2] formatter failed; retrying formatter only")
            else:
                formatter_attempts.append({"attempt": attempt, "status": "success"})
                break

        # The attempt log is written whether or not formatting succeeded.
        self._write_json(
            run_dir / "formatter_attempts.json", {"attempts": formatter_attempts}
        )
        if final_payload is None:
            raise FormatterStageError(
                f"Formatter failed after {self.formatter_max_attempts} attempts; pass 1 outputs kept in {run_dir}",
                run_dir,
            ) from last_error

        ui_payload_path = run_dir / "pass_2_ui_payload.json"
        ui_payload = final_payload.model_dump()
        self._write_json(ui_payload_path, ui_payload)

        result = {
            "paper_id": paper.paper_dir.name,
            "paper_dir": str(paper.paper_dir),
            "generated_at": datetime.now(timezone.utc).isoformat(),
            "reasoner_model": self.client.default_config.model_name,
            "formatter_model": formatter_config.model_name,
            "judge_model": judge_model_name,
            "candidate_count": self.candidate_count,
            "include_reference_examples": self.include_reference_examples,
            "prompt_profile": self.prompt_profile,
            "reasoning_candidate_paths": list(candidate_paths),
            "selected_reasoning_candidate": selected_candidate_id,
            "selected_candidate_index": selected_candidate_index,
            "selected_reasoning_path": str(selected_reasoning_path),
            "judge_output_path": (
                str(judge_output_path) if judge_output_path is not None else None
            ),
            "formatter_attempts": formatter_attempts,
            "ui_payload_path": str(ui_payload_path),
            "ui_payload": ui_payload,
        }
        self._write_json(run_dir / "run_output.json", result)
        self._log("[run] complete")
        return TwoPassPipelineResult(run_dir=run_dir, result=result)

    def _make_run_dir(self, paper: PaperPackage) -> Path:
        """Create and return ``output_root/<paper name>/<[label__]UTC stamp>``."""
        stamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
        run_name = stamp
        if self.run_label:
            run_name = f"{self._slugify(self.run_label)}__{stamp}"
        run_dir = self.output_root / paper.paper_dir.name / run_name
        run_dir.mkdir(parents=True, exist_ok=True)
        return run_dir

    def _write_json(self, path: Path, payload: Dict[str, Any]) -> None:
        """Write ``payload`` as indented, ASCII-escaped JSON with a trailing newline."""
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(
            json.dumps(payload, indent=2, ensure_ascii=True) + "\n",
            encoding="utf-8",
        )

    def _write_text(self, path: Path, text: str) -> None:
        """Write ``text`` to ``path`` as UTF-8, creating parent directories.

        The encoding is pinned so non-ASCII model output cannot fail on a
        platform whose locale encoding is not UTF-8.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(text, encoding="utf-8")

    def _log(self, message: str) -> None:
        """Forward ``message`` to the progress callback, if one was given."""
        if self.progress_callback:
            self.progress_callback(message)

    @staticmethod
    def _slugify(value: str) -> str:
        """Return a path-safe slug of ``value``.

        Characters other than alphanumerics, ``-``, ``_`` and ``.`` become
        ``-``; runs of ``-`` collapse and leading/trailing ``-`` are dropped.
        The slug is cut to 160 characters; an empty slug becomes ``"run"``.
        """
        slug = "".join(
            ch if ch.isalnum() or ch in {"-", "_", "."} else "-"
            for ch in value.strip()
        )
        slug = "-".join(part for part in slug.split("-") if part)
        return slug[:160] or "run"

    def _load_full_processed_text(self, paper: PaperPackage) -> str:
        """Return the paper's full text on a best-effort basis.

        ``processed_main.tex`` in the paper directory is preferred. If it is
        missing or unreadable, the non-empty files under ``sections/`` are
        joined in sorted order, each prefixed with ``[<file name>]``.
        Unreadable section files are skipped. Returns an empty string when
        nothing could be read.
        """
        processed_path = paper.paper_dir / "processed_main.tex"
        if processed_path.exists():
            try:
                return processed_path.read_text(encoding="utf-8")
            except (OSError, UnicodeError):
                # Best effort: fall back to the per-section files below.
                pass

        sections_dir = paper.paper_dir / "sections"
        parts: list[str] = []
        if sections_dir.exists():
            for path in sorted(sections_dir.iterdir()):
                if not path.is_file():
                    continue
                try:
                    text = path.read_text(encoding="utf-8").strip()
                except (OSError, UnicodeError):
                    continue  # skip unreadable section, keep the rest
                if text:
                    parts.append(f"[{path.name}]\n{text}")
        return "\n\n".join(parts)