# debatefloor/app/models.py
# Synced to the Space as a mirror of git commit d05fcb5 (Space revision b4ac377).
from __future__ import annotations

from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union

from openenv.core.env_server.types import (
    Action as OpenEnvAction,
    Observation as OpenEnvObservation,
    State as OpenEnvState,
)
from pydantic import BaseModel, ConfigDict, Field, model_validator
# ---------------------------------------------------------------------------
# Local base classes — thin subclasses of the OpenEnv core types that keep the
# permissive Pydantic behavior (extra fields allowed, assignments validated)
# expected by the FastAPI schema.
# ---------------------------------------------------------------------------
class Action(OpenEnvAction):
    """Base class for all environment actions."""

    # Unknown fields are kept instead of rejected, and attribute assignment
    # after construction is validated.
    model_config = ConfigDict(
        validate_assignment=True,
        extra="allow",
    )
class Observation(OpenEnvObservation):
    """Base class for all environment observations."""

    # Same permissive settings as the action base, plus support for field
    # types Pydantic has no built-in validator for.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
        extra="allow",
    )

    # Termination flag; an observation is "not done" unless stated otherwise.
    done: bool = Field(
        description="Whether the episode has terminated",
        default=False,
    )

    # Reward may be absent (None) or any of bool / int / float.
    reward: Union[bool, int, float, None] = Field(
        description="Reward signal from the last action",
        default=None,
    )

    # Free-form side channel; each instance gets its own fresh dict.
    metadata: Dict[str, Any] = Field(
        description="Additional metadata for the observation",
        default_factory=dict,
    )
class State(OpenEnvState):
    """Base class for environment state."""

    # Extra fields allowed, assignments validated, arbitrary field types OK.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
        extra="allow",
    )

    # No episode id is assigned until one is provided.
    episode_id: Optional[str] = Field(
        description="Unique identifier for the current episode",
        default=None,
    )

    # Step counter; negative values are rejected (ge=0).
    step_count: int = Field(
        description="Number of steps taken in the current episode",
        default=0,
        ge=0,
    )
# ---------------------------------------------------------------------------
# Domain models
# ---------------------------------------------------------------------------
class ClaimStatus(str, Enum):
    # Processing stages a claim can be in. The ``str`` mixin makes every
    # member compare equal to its plain lowercase string value.
    OPEN = "open"
    INVESTIGATING = "investigating"
    DECIDED = "decided"
    CLOSED = "closed"
class InsuranceClaimReward(BaseModel):
    # Per-component reward breakdown. Every component defaults to 0.0 except
    # ``calibration_score``, which defaults to None.

    # --- component scores, each validated to lie in [0.0, 1.0] -------------
    fraud_detection_score: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Fraction of expected fraud signals found",
    )
    decision_accuracy: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="1.0 if final decision matches allowed decisions, else 0.0",
    )
    payout_accuracy: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Score for payout estimate within the expected band",
    )
    efficiency_score: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Step efficiency: higher when fewer steps used",
    )
    consistency_score: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="For coordinated_fraud: quality of linked-claim targeting",
    )
    evidence_quality_score: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Fraction of flagged signals backed by keyword-grounded evidence",
    )

    # NOTE(review): the description documents a [-1.0, 1.0] range, but unlike
    # the sibling scores no ge/le bound is enforced here — confirm intent.
    calibration_score: Optional[float] = Field(
        default=None,
        description="3×2 matrix calibration score in [-1.0, 1.0]. Only populated on terminal actions.",
    )

    # --- penalties ---------------------------------------------------------
    # exploit_penalty must be non-negative; penalty is unconstrained.
    exploit_penalty: float = Field(
        default=0.0,
        ge=0.0,
        description="Penalty for looping or duplicate actions",
    )
    penalty: float = Field(
        default=0.0,
        description="Total accumulated penalty subtracted from weighted score",
    )

    # --- final value, validated to lie in [0.0, 1.0] -----------------------
    total: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Final clamped reward in [0.0, 1.0]",
    )
class InsuranceClaimAction(Action):
    # A single agent action against a claim: the action type, its free-form
    # parameters, the agent's reasoning, and (for terminal actions) a declared
    # confidence level.

    action_type: Literal[
        "validate_document",
        "request_information",
        "lookup_policy_history",
        "compare_documents",
        "flag_fraud_signal",
        "estimate_payout",
        "query_historical_data",  # DebateFloor: alias for lookup_policy_history
        "query_linked_claim",  # coordinated_ring: reveals linked claim detail
        "verify_identity",  # identity_fraud: cross-checks registry
        "verify_provider_registration",  # phantom_provider: checks IRDAI registry
        "convene_debate_panel",  # Multi-agent: prosecutor vs defender arguments
        "approve_claim",
        "deny_claim",
        "request_investigation",
        "escalate_to_human",  # DebateFloor terminal: for coordinated_ring / hard tasks
    ] = Field(..., description="The type of action to perform on the claim")

    parameters: Dict[str, Any] = Field(
        default_factory=dict,
        description="Action-specific parameters. See /schema for required fields per action_type.",
    )

    reasoning: str = Field(
        default="",
        max_length=4000,
        description="Agent's reasoning for this action. Used for evidence quality scoring.",
    )

    confidence: Optional[Literal["HIGH", "MED", "LOW"]] = Field(
        default=None,
        description="Agent's declared confidence level. Required for terminal actions (approve_claim, deny_claim, escalate_to_human). Graded via 3×2 calibration matrix.",
    )

    @model_validator(mode="after")
    def check_terminal_confidence(self) -> "InsuranceClaimAction":
        """Reject terminal actions that do not declare a confidence level.

        Implemented as an "after" model validator rather than
        ``model_post_init`` so the rule is enforced not only at construction
        but also when ``action_type`` or ``confidence`` is reassigned later
        (the inherited config enables ``validate_assignment``).

        Returns:
            The validated instance, as Pydantic requires of "after" validators.

        Raises:
            ValueError: if ``action_type`` is one of ``approve_claim``,
                ``deny_claim`` or ``escalate_to_human`` and ``confidence`` is
                None. Pydantic surfaces this as a validation error.
        """
        terminal_actions = {"approve_claim", "deny_claim", "escalate_to_human"}
        if self.action_type in terminal_actions and self.confidence is None:
            raise ValueError(
                f"confidence is required for terminal action '{self.action_type}'. Must be HIGH, MED, or LOW."
            )
        return self
class InsuranceClaimObservation(Observation):
    # What the agent sees after each step: the claim itself, episode progress
    # and budget, investigation findings, and reward details.

    # --- claim payload (all required) --------------------------------------
    claim_id: str = Field(
        ...,
        description="Unique identifier for this claim",
    )
    task_id: str = Field(
        ...,
        description="Task identifier: clean_claim | contradictory_claim | coordinated_fraud",
    )
    claimant: Dict[str, Any] = Field(
        ...,
        description="Claimant personal and policy details",
    )
    incident: Dict[str, Any] = Field(
        ...,
        description="Incident date, location, type, and description",
    )
    documents: List[Dict[str, Any]] = Field(
        ...,
        description="Claim documents available for validation",
    )
    linked_claims: List[Dict[str, Any]] = Field(
        description="For coordinated_fraud: stub entries with claim_id and claimant only. Use query_linked_claim to retrieve full details.",
        default_factory=list,
    )

    # --- episode progress --------------------------------------------------
    action_history: List[Dict[str, Any]] = Field(
        description="Actions taken so far this episode",
        default_factory=list,
    )
    available_actions: List[str] = Field(
        description="Valid action_type values for this task",
        default_factory=list,
    )
    step_number: int = Field(
        description="Current step number (0-indexed from reset)",
        default=0,
    )
    max_steps: int = Field(
        description="Maximum steps allowed before episode closes",
        default=0,
    )

    # --- investigation budget ----------------------------------------------
    investigation_budget: int = Field(
        description="Total budget units for this episode",
        default=0,
    )
    budget_remaining: int = Field(
        description="Budget units remaining. Going negative adds a 0.02 penalty per unit over budget.",
        default=0,
    )

    # --- findings and status -----------------------------------------------
    flags_raised: List[str] = Field(
        description="Fraud signal flag IDs raised so far",
        default_factory=list,
    )
    discovered_signals: List[str] = Field(
        description="Fraud signals actually discovered through allowed investigative actions.",
        default_factory=list,
    )
    status: ClaimStatus = Field(
        description="Current claim processing status",
        default=ClaimStatus.OPEN,
    )
    message: str = Field(
        description="Human-readable message describing result of last action",
        default="",
    )
    confidence_required: bool = Field(
        description="Whether next action requires a confidence declaration",
        default=True,
    )

    # --- reward reporting --------------------------------------------------
    # A fresh all-default breakdown is created per observation.
    reward_breakdown: InsuranceClaimReward = Field(
        description="Detailed reward components for current step",
        default_factory=InsuranceClaimReward,
    )
    rubric_reward: float = Field(
        description="Reward returned by the composed OpenEnv rubric",
        default=0.0,
    )
    rubric_components: Dict[str, float] = Field(
        description="Named leaf rubric scores for logging and analysis",
        default_factory=dict,
    )

    # --- multi-agent debate ------------------------------------------------
    # Stays None unless a transcript is supplied.
    debate_transcript: Optional[Dict[str, Any]] = Field(
        description="Multi-agent debate panel output. Populated after convene_debate_panel action. Contains prosecutor_argument, defender_argument, and panel_verdict.",
        default=None,
    )
class InsuranceClaimState(State):
    # Environment-side bookkeeping for one claim episode. Every field has a
    # default, so the state can be created empty and filled in later.

    # --- identity ----------------------------------------------------------
    task_id: str = ""
    claim_id: str = ""

    # --- progress ----------------------------------------------------------
    step_number: int = 0
    max_steps: int = 0
    status: ClaimStatus = ClaimStatus.OPEN

    # --- signal tracking (each instance gets its own lists) ----------------
    flags_raised: List[str] = Field(default_factory=list)
    discovered_signals: List[str] = Field(default_factory=list)
    found_signals: List[str] = Field(default_factory=list)

    # --- scoring and outcome -----------------------------------------------
    penalty_total: float = 0.0
    done: bool = False
    last_action_error: Optional[str] = None
    payout_estimate_inr: Optional[float] = None
    final_decision: Optional[str] = None
    final_score: float = 0.0