# Pydantic schemas for the enterprise-finance environment: ledger rows, the
# observation/state containers, the four agent actions, and the wrappers that
# carry them.
#
# NOTE: models are documented with ``#`` comments rather than class docstrings
# because pydantic copies a model's docstring into the ``description`` of its
# generated JSON schema; comments leave the emitted schema untouched.
from __future__ import annotations

from datetime import date
from typing import Annotated, Any, Literal

from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    RootModel,
    field_validator,
    model_validator,
)

# Prefer the OpenEnv base types; fall back to the package-local shim when the
# OpenEnv import fails.
try:
    from openenv.core.env_server.types import Action, Observation, State
except ImportError:
    from enterprise_finance_env._compat import Action, Observation, State


# One row of a structured ledger. Unknown keys are rejected (extra="forbid").
class StructuredLedgerRow(BaseModel):
    model_config = ConfigDict(extra="forbid")

    # Required identifiers, each with a minimum length.
    txn_id: str = Field(min_length=3)
    entity_id: str = Field(min_length=2)
    counterparty_entity_id: str = Field(min_length=2)
    account_code: str = Field(min_length=2)
    side: Literal["debit", "credit"]
    amount: float = Field(gt=0)  # must be strictly positive
    currency: str = Field(min_length=3, max_length=3)  # exactly three characters
    txn_date: str  # kept as a string; checked by validate_txn_date
    reference: str = Field(min_length=3)

    # Status flags; every one is optional.
    matched: bool = False
    eligible_for_linking: bool = True
    has_fx_adjustment: bool = False
    disputed: bool = False

    @field_validator("txn_date")
    @classmethod
    def validate_txn_date(cls, value: str) -> str:
        """Require ``txn_date`` to parse via ``date.fromisoformat``.

        The parsed date is discarded; the original string is returned as is.
        """
        date.fromisoformat(value)  # raises ValueError when the text is not a valid date
        return value

    @field_validator("currency")
    @classmethod
    def normalize_currency(cls, value: str) -> str:
        """Return the currency code converted to upper case."""
        return value.upper()


# Observation payload. Every collection defaults to empty and the feedback
# message defaults to the empty string.
class EnterpriseFinanceObservation(Observation):
    structured_ledgers: list[StructuredLedgerRow] = Field(default_factory=list)
    # Invoice and dispute entries may each be a free-form dict or a plain string.
    unstructured_invoices: list[dict[str, Any] | str] = Field(default_factory=list)
    fx_rates: dict[str, float] = Field(default_factory=dict)
    dispute_inbox: list[dict[str, Any] | str] = Field(default_factory=list)
    feedback_message: str = ""


# Environment state. Counters and balances are constrained to be non-negative;
# final_score is either None or a value within [0, 1].
class EnterpriseFinanceState(State):
    difficulty: Literal["easy", "medium", "hard"] = "easy"
    scenario_id: str = ""
    initial_abs_balance: float = Field(default=0.0, ge=0)
    open_abs_balance: float = Field(default=0.0, ge=0)
    linked_pairs: list[tuple[str, str]] = Field(default_factory=list)
    fx_adjustments: dict[str, float] = Field(default_factory=dict)
    invalid_actions: int = Field(default=0, ge=0)
    deleted_txn_ids: list[str] = Field(default_factory=list)
    expected_terminal_amount: float = 0.0
    expected_terminal_account: str = ""
    audit_flags: list[str] = Field(default_factory=list)
    valid_event_count: int = Field(default=0, ge=0)
    unmatched_valid_count: int = Field(default=0, ge=0)
    total_reward: float = 0.0
    final_score: float | None = Field(default=None, ge=0, le=1)
    force_balance_attempted: bool = False


# Action selected by type == "query_subledger".
class QuerySubledger(Action):
    type: Literal["query_subledger"] = "query_subledger"
    entity: str = Field(min_length=2)
    account_code: str = Field(min_length=2)
    date_range: tuple[str, str]  # (start, end), both parseable by date.fromisoformat

    @field_validator("date_range")
    @classmethod
    def validate_date_range(cls, value: tuple[str, str]) -> tuple[str, str]:
        """Check both endpoints parse as dates and that start is not after end.

        Raises:
            ValueError: if either endpoint fails ``date.fromisoformat`` or the
                start date falls after the end date.
        """
        start_text, end_text = value
        start_day = date.fromisoformat(start_text)
        end_day = date.fromisoformat(end_text)
        if end_day < start_day:
            raise ValueError("date_range start must be on or before end")
        return value


# Action selected by type == "link_transactions".
class LinkTransactions(Action):
    type: Literal["link_transactions"] = "link_transactions"
    debit_txn_id: str = Field(min_length=3)
    credit_txn_id: str = Field(min_length=3)
    rationale: str = Field(min_length=5, max_length=500)


# Action selected by type == "apply_forex_adjustment".
class ApplyForexAdjustment(Action):
    type: Literal["apply_forex_adjustment"] = "apply_forex_adjustment"
    txn_id: str = Field(min_length=3)
    exchange_rate: float = Field(gt=0)  # must be strictly positive
    date: str  # kept as a string; checked by validate_date

    @field_validator("date")
    @classmethod
    def validate_date(cls, value: str) -> str:
        """Require ``date`` to parse via ``date.fromisoformat``; return it unchanged."""
        # Inside this method the name `date` is the module-level datetime.date
        # import: the annotation-only field above binds no class attribute, and
        # method bodies do not look names up in the class scope.
        date.fromisoformat(value)
        return value


# Action selected by type == "post_elimination_entry". `amount` carries no
# sign or range constraint.
class PostEliminationEntry(Action):
    type: Literal["post_elimination_entry"] = "post_elimination_entry"
    entity_id: str = Field(min_length=2)
    amount: float
    account: str = Field(min_length=2)


# Tagged union over the four actions; the "type" field is the discriminator.
EnterpriseFinanceAction = Annotated[
    QuerySubledger | LinkTransactions | ApplyForexAdjustment | PostEliminationEntry,
    Field(discriminator="type"),
]


# Root-model wrapper whose `root` holds one concrete action.
class EnterpriseFinanceActionPayload(RootModel[EnterpriseFinanceAction]):
    root: EnterpriseFinanceAction


# Anything accepted by unwrap_action: a bare action or the root-model wrapper.
ActionLike = (
    QuerySubledger
    | LinkTransactions
    | ApplyForexAdjustment
    | PostEliminationEntry
    | EnterpriseFinanceActionPayload
)


# Step result: an observation plus optional reward and a done flag. Unknown
# keys are rejected (extra="forbid").
class EnterpriseFinanceStepPayload(BaseModel):
    model_config = ConfigDict(extra="forbid")

    observation: EnterpriseFinanceObservation
    reward: float | None = None
    done: bool = False


def unwrap_action(
    action: ActionLike,
) -> QuerySubledger | LinkTransactions | ApplyForexAdjustment | PostEliminationEntry:
    """Return the concrete action behind *action*.

    An ``EnterpriseFinanceActionPayload`` yields its ``root``; any other value
    is handed back unchanged.
    """
    if not isinstance(action, EnterpriseFinanceActionPayload):
        return action
    return action.root