File size: 2,355 Bytes
3552405 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 | """Agent 2: Classifier — assigns clause types and detects contract type."""
import json
import logging
from typing import Any
from clauseguard.config.prompts import CLASSIFIER_SYSTEM_PROMPT
from clauseguard.models.clause import Clause, ClauseList, ClauseType
from clauseguard.services.model_service import call_model, clean_json_response
# Module-level logger for the Classifier agent.
logger = logging.getLogger(__name__)
# Forwarded to call_model() as ``max_retries`` for the classifier request.
# NOTE(review): whether 1 means "one retry" or "one attempt total" depends on
# call_model's implementation — confirm there.
MAX_RETRIES = 1
async def run_classifier(clause_list: ClauseList) -> ClauseList:
    """Ask the model to classify every clause and detect the contract type.

    The whole ClauseList is serialized to indented JSON and sent to the model
    together with the classifier system prompt.

    Args:
        clause_list: Clauses produced by the Extractor agent.

    Returns:
        The ClauseList built from the model's response. When call_model
        returns None (no usable output), ``clause_list`` itself is returned.
    """
    prompt = "Classify these clauses:\n" + clause_list.model_dump_json(indent=2)
    response = await call_model(
        system_prompt=CLASSIFIER_SYSTEM_PROMPT,
        user_prompt=prompt,
        agent_name="Classifier",
        max_retries=MAX_RETRIES,
    )
    if response is not None:
        return _parse_response(response, clause_list)
    # Degrade gracefully: keep the unclassified clauses rather than failing.
    logger.warning("Classifier produced no valid output, returning original clauses")
    return clause_list
def _parse_response(content: str, original: ClauseList) -> ClauseList:
    """Parse the classifier JSON response and merge it with the original clauses.

    The model output may be either an object such as
    ``{"clauses": [...], "contract_type": "..."}`` or a bare list of clause
    objects. For every returned clause, fields the model omitted are taken
    from the original clause with the same ``id`` (when one exists), so a
    partial response does not erase data produced by the Extractor.

    Args:
        content: Raw text returned by the model.
        original: The ClauseList that was sent to the classifier.

    Returns:
        A new ClauseList with classified clauses. ``original`` is returned
        unchanged when the response is not valid JSON or has an unexpected
        shape. If the response yields no usable clauses, the original clauses
        are kept together with the detected contract type.
    """
    try:
        data = json.loads(clean_json_response(content))
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        logger.warning("Classifier returned invalid JSON, returning original clauses")
        return original

    if isinstance(data, dict):
        clauses_data = data.get("clauses", [])
        raw_contract_type = data.get("contract_type")
    elif isinstance(data, list):
        # A bare list of clauses carries no contract type. (Previously this
        # shape crashed on ``data.get`` despite being intended as a fallback.)
        clauses_data = data
        raw_contract_type = None
    else:
        logger.warning("Classifier returned unexpected JSON shape, returning original clauses")
        return original

    if not isinstance(clauses_data, list):
        logger.warning("Classifier 'clauses' field is not a list, returning original clauses")
        return original

    contract_type = (
        raw_contract_type
        if isinstance(raw_contract_type, str) and raw_contract_type
        else "Other"
    )

    # Key by str(id) so an id echoed back as "3" still matches clause 3.
    originals_by_id = {str(clause.id): clause for clause in original.clauses}

    classified_clauses: list[Clause] = []
    for entry in clauses_data:
        if not isinstance(entry, dict):
            logger.warning("Skipping non-object clause entry from Classifier: %r", entry)
            continue

        raw_id = entry.get("id")
        base = originals_by_id.get(str(raw_id)) if raw_id is not None else None

        raw_type = entry.get("clause_type", "OTHER")
        try:
            clause_type = ClauseType(raw_type)
        except (ValueError, TypeError):
            # Unknown label from the model: bucket it rather than fail.
            clause_type = ClauseType.OTHER

        try:
            clause = Clause(
                id=_merged_field(entry, base, "id", 0),
                raw_text=_merged_field(entry, base, "raw_text", ""),
                plain_english=_merged_field(entry, base, "plain_english", None),
                clause_type=clause_type,
                section_heading=_merged_field(entry, base, "section_heading", None),
                position=_merged_field(entry, base, "position", 0),
                confidence_score=_merged_field(entry, base, "confidence_score", None),
            )
        except ValueError:
            # Presumably a pydantic ValidationError (a ValueError subclass):
            # keep the original clause if we have one, otherwise drop the entry.
            logger.warning("Classifier returned an invalid clause entry: %r", entry)
            if base is None:
                continue
            clause = base
        classified_clauses.append(clause)

    if not classified_clauses and original.clauses:
        # Never silently lose every clause because the model omitted them.
        logger.warning("Classifier returned no usable clauses, keeping original clauses")
        classified_clauses = list(original.clauses)

    return ClauseList(
        clauses=classified_clauses,
        contract_type=contract_type,
        total_clauses=len(classified_clauses),
    )


def _merged_field(entry: "dict[str, Any]", base: "Clause | None", key: str, default: Any) -> Any:
    """Return ``entry[key]`` if present, else the original clause's value, else ``default``."""
    if key in entry:
        return entry[key]
    if base is not None:
        return getattr(base, key, default)
    return default
|