claims-env / models.py
akhiilll's picture
Deploy ClaimSense adjudication gym
1cfeb15 verified
"""ClaimSense — typed payloads exchanged with the adjudication gym.

Three Pydantic shells sit on top of OpenEnv's base contracts:

* ``AdjudicatorAction`` — what the agent submits each turn.
* ``AdjudicatorObservation`` — what comes back to the agent.
* ``AdjudicatorState`` — bookkeeping the server retains, including
  hidden ground truth used for reward shaping.

The ``Claims*`` aliases at the bottom keep the OpenEnv ``create_fastapi_app``
wiring stable and let any older import paths continue to resolve, but new
code should reference the descriptive names.
"""
from __future__ import annotations
from typing import Any
from openenv.core import Action, Observation, State
from pydantic import Field
# --- Action vocabulary -----------------------------------------------------
# One shared definition of the verb names, so the env, the client helpers,
# and the tests all read the same list instead of re-typing the strings.

# Verbs grouped as "terminal" (presumably the ones that close an episode).
TERMINAL_ACTIONS: tuple[str, ...] = ("approve", "deny", "escalate")

# Verbs grouped as "information" (presumably the investigative queries).
INFORMATION_ACTIONS: tuple[str, ...] = (
    "query_policy",
    "query_claim_history",
    "check_fraud",
    "request_documents",
    "verify_coverage",
    "verify_purchase",
    "calculate_payout",
)

# Full vocabulary: information verbs first, then the terminal verbs.
ALL_ACTIONS: tuple[str, ...] = (*INFORMATION_ACTIONS, *TERMINAL_ACTIONS)
# --- Action ---------------------------------------------------------------
class AdjudicatorAction(Action):
    """One verb, plus its arguments, submitted by the adjudicator agent.

    ``action_type`` names the verb and is the only field declared here
    without a default. ``parameters`` holds the per-verb keyword arguments
    (for example ``payout``, ``reason`` or ``damage_type``).
    """

    # Typed as a plain ``str``: this model does not itself restrict the value
    # to the ``ALL_ACTIONS`` vocabulary.
    action_type: str = Field(description="Verb the agent wants to perform")

    # Defaults to the empty string when the caller does not name a claim.
    claim_id: str = Field("", description="Claim under review (optional)")

    # ``default_factory`` gives every instance its own fresh dict.
    parameters: dict[str, Any] = Field(
        description="Free-form keyword payload for the chosen verb",
        default_factory=dict,
    )
# --- Observation ----------------------------------------------------------
class AdjudicatorObservation(Observation):
    """Payload handed back to the agent after each action.

    ``revealed_info`` is the partial-observability channel: it is meant to
    contain only what the agent has explicitly queried. The termination
    flags and the reward travel on this same object so that RL frameworks
    can read everything from a single payload.

    Every field declared here has a default, so an empty observation can
    be constructed without arguments.
    """

    # Claim header — intended to be populated on every observation.
    claim_id: str = ""
    claim_type: str = ""
    claim_amount_requested: float = 0.0
    claimant_name: str = ""
    incident_date: str = ""
    description: str = ""

    # Environment's reply to the most recent action.
    system_response: str = ""
    action_success: bool = True

    # Facts unlocked so far; expected to grow as the episode proceeds.
    revealed_info: dict[str, Any] = Field(default_factory=dict)

    # Hint for constrained policies: the verbs that are still legal.
    available_actions: list[str] = Field(default_factory=list)

    # Informational counters only.
    time_elapsed_minutes: int = 0
    queries_made: int = 0

    # End-of-episode marker and the reason for it.
    is_terminal: bool = False
    terminal_reason: str = ""

    # OpenEnv expects the reward to be carried on the observation itself.
    reward: float = 0.0
# --- State ----------------------------------------------------------------
class AdjudicatorState(State):
    """Per-episode record kept on the server, including hidden ground truth.

    The ground-truth fields (``true_verdict``, ``correct_payout``,
    ``is_fraud`` and friends) feed reward shaping and are not meant to be
    shown to the agent directly. Every field declared here has a default.
    """

    # --- Public claim summary ---
    claim_id: str = ""
    claim_type: str = ""
    claim_amount_requested: float = 0.0

    # --- Hidden ground truth (input to reward computation) ---
    true_verdict: str = ""
    correct_payout: float = 0.0
    is_fraud: bool = False
    fraud_type: str | None = None

    # --- Policy details, surfaced only once the agent queries them ---
    policy_coverage_limit: float = 0.0
    policy_deductible: float = 0.0
    policy_status: str = ""
    coverage_exclusions: list[str] = Field(default_factory=list)

    # --- Shape of the case ---
    complexity: str = "standard"
    requires_documents: list[str] = Field(default_factory=list)
    requires_escalation: bool = False

    # --- Episode counters ---
    actions_taken: int = 0
    queries_made: int = 0
    time_elapsed_minutes: int = 0

    # --- One "already asked?" flag per information channel ---
    # NOTE(review): there is no flag here for the ``verify_purchase`` verb;
    # confirm whether that is tracked elsewhere or intentionally omitted.
    policy_queried: bool = False
    history_queried: bool = False
    fraud_checked: bool = False
    documents_requested: bool = False
    coverage_verified: bool = False
    payout_calculated: bool = False

    # --- Final decision recorded for the agent ---
    agent_decision: str = ""
    agent_payout: float = 0.0
    decision_reason: str = ""

    # --- Reward breakdown (retained for analysis dashboards) ---
    correctness_reward: float = 0.0
    efficiency_reward: float = 0.0
    fraud_detection_reward: float = 0.0
    total_reward: float = 0.0
# --- Compatibility aliases -----------------------------------------------
# The legacy ``Claims*`` names are bound to the very same class objects as
# the ``Adjudicator*`` models. Per the original author's note, OpenEnv's
# serialiser and a handful of older snippets still look these names up, so
# dropping them would break at runtime rather than at import time.
ClaimsAction, ClaimsObservation, ClaimsState = (
    AdjudicatorAction,
    AdjudicatorObservation,
    AdjudicatorState,
)

# Public surface of the module: descriptive names, legacy aliases, and the
# action-vocabulary constants.
__all__ = [
    "AdjudicatorAction",
    "AdjudicatorObservation",
    "AdjudicatorState",
    "ClaimsAction",
    "ClaimsObservation",
    "ClaimsState",
    "INFORMATION_ACTIONS",
    "TERMINAL_ACTIONS",
    "ALL_ACTIONS",
]