from __future__ import annotations
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
from openenv.core.env_server.types import (
Action as OpenEnvAction,
Observation as OpenEnvObservation,
State as OpenEnvState,
)
from pydantic import BaseModel, ConfigDict, Field
# ---------------------------------------------------------------------------
# Local base classes — thin subclasses of the OpenEnv core types that add the
# permissive Pydantic configuration (extra fields allowed, assignment
# re-validation) expected by the FastAPI schema.
# ---------------------------------------------------------------------------
class Action(OpenEnvAction):
    """Base class for all environment actions.

    Extends the OpenEnv core ``Action`` so that unknown extra fields are
    retained rather than rejected, and attribute assignments made after
    construction are re-validated.
    """

    model_config = ConfigDict(validate_assignment=True, extra="allow")
class Observation(OpenEnvObservation):
    """Base class for all environment observations.

    Carries the standard episode signals (``done``, ``reward``) plus a
    free-form ``metadata`` mapping; unknown extra fields are preserved and
    assignments are re-validated.
    """

    model_config = ConfigDict(
        extra="allow",
        validate_assignment=True,
        arbitrary_types_allowed=True,
    )

    # Episode-termination flag.
    done: bool = Field(default=False, description="Whether the episode has terminated")
    # Scalar reward from the most recent action; None when not yet available.
    reward: Union[bool, int, float, None] = Field(
        default=None, description="Reward signal from the last action"
    )
    # Arbitrary per-observation extras.
    metadata: Dict[str, Any] = Field(
        default_factory=dict, description="Additional metadata for the observation"
    )
class State(OpenEnvState):
    """Base class for environment state.

    Tracks episode identity and step progression with the same permissive
    Pydantic configuration as ``Action`` and ``Observation``.
    """

    model_config = ConfigDict(
        extra="allow",
        validate_assignment=True,
        arbitrary_types_allowed=True,
    )

    # Identifier of the episode currently in progress.
    episode_id: Optional[str] = Field(
        default=None, description="Unique identifier for the current episode"
    )
    # Step counter; validation enforces it is never negative.
    step_count: int = Field(
        default=0, ge=0, description="Number of steps taken in the current episode"
    )
# ---------------------------------------------------------------------------
# Domain models
# ---------------------------------------------------------------------------
class ClaimStatus(str, Enum):
    """Processing status of an insurance claim.

    Mixes in ``str`` so members compare equal to (and serialize as) their
    lowercase string values.
    """

    OPEN = "open"
    INVESTIGATING = "investigating"
    DECIDED = "decided"
    CLOSED = "closed"
class InsuranceClaimReward(BaseModel):
    """Decomposed reward for a single step of claim processing.

    All ``*_score`` components are normalized to [0.0, 1.0].
    ``calibration_score`` is a signed score in [-1.0, 1.0], populated only
    on terminal actions.  Penalties accumulate separately and are subtracted
    from the weighted score before ``total`` is clamped to [0.0, 1.0].
    """

    fraud_detection_score: float = Field(
        default=0.0, ge=0.0, le=1.0,
        description="Fraction of expected fraud signals found",
    )
    decision_accuracy: float = Field(
        default=0.0, ge=0.0, le=1.0,
        description="1.0 if final decision matches allowed decisions, else 0.0",
    )
    payout_accuracy: float = Field(
        default=0.0, ge=0.0, le=1.0,
        description="Score for payout estimate within the expected band",
    )
    efficiency_score: float = Field(
        default=0.0, ge=0.0, le=1.0,
        description="Step efficiency: higher when fewer steps used",
    )
    consistency_score: float = Field(
        default=0.0, ge=0.0, le=1.0,
        description="For coordinated_fraud: quality of linked-claim targeting",
    )
    evidence_quality_score: float = Field(
        default=0.0, ge=0.0, le=1.0,
        description="Fraction of flagged signals backed by keyword-grounded evidence",
    )
    # Fix: enforce the documented [-1.0, 1.0] range — the description promised
    # the bounds but no ge/le constraints validated them, unlike every sibling
    # score field.
    calibration_score: Optional[float] = Field(
        default=None, ge=-1.0, le=1.0,
        description="3×2 matrix calibration score in [-1.0, 1.0]. Only populated on terminal actions.",
    )
    exploit_penalty: float = Field(
        default=0.0, ge=0.0,
        description="Penalty for looping or duplicate actions",
    )
    penalty: float = Field(
        default=0.0,
        description="Total accumulated penalty subtracted from weighted score",
    )
    total: float = Field(
        default=0.0, ge=0.0, le=1.0,
        description="Final clamped reward in [0.0, 1.0]",
    )
class InsuranceClaimAction(Action):
    """A single agent action against the current claim.

    ``action_type`` selects the operation and ``parameters`` supplies its
    action-specific arguments.  Terminal actions (approve_claim, deny_claim,
    escalate_to_human) must also declare ``confidence``; this is enforced in
    ``model_post_init``.
    """

    action_type: Literal[
        "validate_document",
        "request_information",
        "lookup_policy_history",
        "compare_documents",
        "flag_fraud_signal",
        "estimate_payout",
        "query_historical_data",  # DebateFloor: alias for lookup_policy_history
        "query_linked_claim",  # coordinated_ring: reveals linked claim detail
        "verify_identity",  # identity_fraud: cross-checks registry
        "verify_provider_registration",  # phantom_provider: checks IRDAI registry
        "convene_debate_panel",  # Multi-agent: prosecutor vs defender arguments
        "approve_claim",
        "deny_claim",
        "request_investigation",
        "escalate_to_human",  # DebateFloor terminal: for coordinated_ring / hard tasks
    ] = Field(..., description="The type of action to perform on the claim")

    # Per-action arguments; the /schema endpoint documents what each
    # action_type expects here.
    parameters: Dict[str, Any] = Field(
        default_factory=dict,
        description="Action-specific parameters. See /schema for required fields per action_type.",
    )
    # Free-text rationale, capped at 4000 chars; feeds evidence scoring.
    reasoning: str = Field(
        default="",
        max_length=4000,
        description="Agent's reasoning for this action. Used for evidence quality scoring.",
    )
    # Mandatory on terminal actions only; see model_post_init.
    confidence: Optional[Literal["HIGH", "MED", "LOW"]] = Field(
        default=None,
        description="Agent's declared confidence level. Required for terminal actions (approve_claim, deny_claim, escalate_to_human). Graded via 3×2 calibration matrix.",
    )

    def model_post_init(self, __context: Any) -> None:
        """Reject terminal actions that omit a confidence declaration."""
        is_terminal = self.action_type in ("approve_claim", "deny_claim", "escalate_to_human")
        if not is_terminal or self.confidence is not None:
            return
        raise ValueError(
            f"confidence is required for terminal action '{self.action_type}'. Must be HIGH, MED, or LOW."
        )
class InsuranceClaimObservation(Observation):
    """Full view of the claim returned to the agent after reset/step.

    Combines the static claim payload (claimant, incident, documents),
    episode bookkeeping (steps, budget, status), investigation progress
    (flags, discovered signals), and the reward/rubric breakdown for the
    current step.
    """

    # --- static claim payload -------------------------------------------
    claim_id: str = Field(..., description="Unique identifier for this claim")
    task_id: str = Field(
        ..., description="Task identifier: clean_claim | contradictory_claim | coordinated_fraud"
    )
    claimant: Dict[str, Any] = Field(..., description="Claimant personal and policy details")
    incident: Dict[str, Any] = Field(
        ..., description="Incident date, location, type, and description"
    )
    documents: List[Dict[str, Any]] = Field(
        ..., description="Claim documents available for validation"
    )
    linked_claims: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="For coordinated_fraud: stub entries with claim_id and claimant only. Use query_linked_claim to retrieve full details.",
    )

    # --- episode bookkeeping --------------------------------------------
    action_history: List[Dict[str, Any]] = Field(
        default_factory=list, description="Actions taken so far this episode"
    )
    available_actions: List[str] = Field(
        default_factory=list, description="Valid action_type values for this task"
    )
    step_number: int = Field(default=0, description="Current step number (0-indexed from reset)")
    max_steps: int = Field(default=0, description="Maximum steps allowed before episode closes")
    investigation_budget: int = Field(default=0, description="Total budget units for this episode")
    budget_remaining: int = Field(
        default=0,
        description="Budget units remaining. Going negative adds a 0.02 penalty per unit over budget.",
    )

    # --- investigation progress -----------------------------------------
    flags_raised: List[str] = Field(
        default_factory=list, description="Fraud signal flag IDs raised so far"
    )
    discovered_signals: List[str] = Field(
        default_factory=list,
        description="Fraud signals actually discovered through allowed investigative actions.",
    )
    status: ClaimStatus = Field(default=ClaimStatus.OPEN, description="Current claim processing status")
    message: str = Field(
        default="", description="Human-readable message describing result of last action"
    )
    confidence_required: bool = Field(
        default=True, description="Whether next action requires a confidence declaration"
    )

    # --- reward / rubric breakdown --------------------------------------
    reward_breakdown: InsuranceClaimReward = Field(
        default_factory=InsuranceClaimReward,
        description="Detailed reward components for current step",
    )
    rubric_reward: float = Field(
        default=0.0, description="Reward returned by the composed OpenEnv rubric"
    )
    rubric_components: Dict[str, float] = Field(
        default_factory=dict,
        description="Named leaf rubric scores for logging and analysis",
    )
    debate_transcript: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Multi-agent debate panel output. Populated after convene_debate_panel action. Contains prosecutor_argument, defender_argument, and panel_verdict.",
    )
class InsuranceClaimState(State):
    """Server-side episode state for the insurance-claim environment.

    Mirrors the bookkeeping exposed in ``InsuranceClaimObservation`` and
    additionally holds grading inputs (found signals, penalties, final
    decision and score) that are not necessarily shown to the agent.
    """
    # Task/claim identity for the current episode.
    task_id: str = ""
    claim_id: str = ""
    # Step progression and limit.
    step_number: int = 0
    max_steps: int = 0
    # Current claim lifecycle status.
    status: ClaimStatus = ClaimStatus.OPEN
    # Flag IDs the agent has raised via flag_fraud_signal.
    flags_raised: List[str] = Field(default_factory=list)
    # Signals surfaced through investigative actions.
    discovered_signals: List[str] = Field(default_factory=list)
    # NOTE(review): kept alongside discovered_signals — presumably the subset
    # credited by grading; confirm against the environment's step logic.
    found_signals: List[str] = Field(default_factory=list)
    # Accumulated penalty to subtract from the reward.
    penalty_total: float = 0.0
    # True once the episode has terminated.
    done: bool = False
    # Error message from the most recent invalid action, if any.
    last_action_error: Optional[str] = None
    # Agent's payout estimate in INR, once estimate_payout has been used.
    payout_estimate_inr: Optional[float] = None
    # Terminal decision recorded (e.g. approve/deny/escalate), if reached.
    final_decision: Optional[str] = None
    # Final episode score after grading.
    final_score: float = 0.0