muhammadbinmurtza
Restructure: clauseguard as package subfolder, app_file: clauseguard/app.py
913a064 | """Agent 2: Classifier — assigns clause types and detects contract type.""" | |
| import json | |
| import logging | |
| from typing import Any | |
| from clauseguard.config.prompts import CLASSIFIER_SYSTEM_PROMPT | |
| from clauseguard.models.clause import Clause, ClauseList, ClauseType | |
| from clauseguard.services.model_service import call_model, clean_json_response | |
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Retry budget forwarded to call_model() for every classifier request.
MAX_RETRIES = 1
async def run_classifier(clause_list: ClauseList) -> ClauseList:
    """Run the Classifier agent over an extracted clause list.

    The clause list is serialized to indented JSON and sent to the model
    together with ``CLASSIFIER_SYSTEM_PROMPT``. The model's reply is then
    handed to ``_parse_response`` to build the classified result.

    Args:
        clause_list: The ClauseList produced by the Extractor agent.

    Returns:
        The ClauseList built from the model's reply, or ``clause_list``
        itself (unchanged) when ``call_model`` returns ``None``.
    """
    serialized_clauses = clause_list.model_dump_json(indent=2)
    prompt = f"Classify these clauses:\n{serialized_clauses}"

    model_reply = await call_model(
        system_prompt=CLASSIFIER_SYSTEM_PROMPT,
        user_prompt=prompt,
        agent_name="Classifier",
        max_retries=MAX_RETRIES,
    )

    if model_reply is not None:
        return _parse_response(model_reply, clause_list)

    # No usable reply: degrade gracefully by passing the input through.
    logger.warning("Classifier produced no valid output, returning original clauses")
    return clause_list
def _merged_field(item: dict[str, Any], source: Any, key: str, default: Any) -> Any:
    """Pick *key* from the model's clause entry, else the original clause, else *default*."""
    if key in item:
        return item[key]
    if source is not None:
        return getattr(source, key, default)
    return default


def _parse_response(content: str, original: ClauseList) -> ClauseList:
    """Parse the classifier JSON response and merge it with the original data.

    The response may be either an object of the form
    ``{"clauses": [...], "contract_type": "..."}`` or a bare list of clause
    objects. Each returned clause is matched to an original clause by ``id``;
    any field the model omitted is taken from that original clause, so
    extractor output is not blanked out by a terse reply.

    Args:
        content: Raw text returned by the model.
        original: The ClauseList that was sent for classification.

    Returns:
        A new ClauseList built from the response. ``original`` is returned
        unchanged when the response is not valid JSON or its clause
        collection is not a list.
    """
    cleaned = clean_json_response(content)
    try:
        data = json.loads(cleaned)
    except json.JSONDecodeError:
        # Mirror run_classifier's fallback for "no valid output".
        logger.warning("Classifier returned invalid JSON, returning original clauses")
        return original

    if isinstance(data, dict):
        clauses_data = data.get("clauses", [])
        contract_type = data.get("contract_type", "Other")
    else:
        # A bare list carries clauses only; there is no contract type in it.
        clauses_data = data
        contract_type = "Other"

    if not isinstance(clauses_data, list):
        logger.warning(
            "Classifier returned an unexpected payload shape, returning original clauses"
        )
        return original

    originals_by_id = {clause.id: clause for clause in original.clauses}

    classified_clauses: list[Clause] = []
    for item in clauses_data:
        if not isinstance(item, dict):
            logger.warning("Classifier returned a non-object clause entry, skipping it")
            continue

        clause_id = item.get("id")
        # Only hashable scalar ids can be looked up; anything else has no match.
        source = (
            originals_by_id.get(clause_id)
            if isinstance(clause_id, (int, str))
            else None
        )

        try:
            clause_type = ClauseType(item.get("clause_type", "OTHER"))
        except ValueError:
            # Unknown label from the model: fall back to the catch-all type.
            clause_type = ClauseType.OTHER

        classified_clauses.append(
            Clause(
                id=item.get("id", 0),
                raw_text=_merged_field(item, source, "raw_text", ""),
                plain_english=_merged_field(item, source, "plain_english", None),
                clause_type=clause_type,
                section_heading=_merged_field(item, source, "section_heading", None),
                position=_merged_field(item, source, "position", 0),
                confidence_score=_merged_field(item, source, "confidence_score", None),
            )
        )

    return ClauseList(
        clauses=classified_clauses,
        contract_type=contract_type,
        total_clauses=len(classified_clauses),
    )