Prasham1710's picture
first commit
4ccc966
from __future__ import annotations
from datetime import date
from typing import Annotated, Any, Literal
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
try:
from openenv.core.env_server.types import Action, Observation, State
except ImportError:
from enterprise_finance_env._compat import Action, Observation, State
class StructuredLedgerRow(BaseModel):
    # Schema for one row of the structured ledger.
    #
    # Keys that are not declared below are rejected (extra="forbid").
    # `txn_date` is stored as the original string; it is only checked for
    # being parseable by `date.fromisoformat`.  `currency` is stored
    # upper-cased.
    model_config = ConfigDict(extra="forbid")

    # --- identifiers -----------------------------------------------------
    txn_id: str = Field(..., min_length=3)
    entity_id: str = Field(..., min_length=2)
    counterparty_entity_id: str = Field(..., min_length=2)
    account_code: str = Field(..., min_length=2)

    # --- posting details -------------------------------------------------
    side: Literal["debit", "credit"]
    amount: float = Field(..., gt=0)  # strictly positive
    currency: str = Field(..., min_length=3, max_length=3)  # exactly 3 chars
    txn_date: str
    reference: str = Field(..., min_length=3)

    # --- status flags (all optional) --------------------------------------
    matched: bool = False
    eligible_for_linking: bool = True
    has_fx_adjustment: bool = False
    disputed: bool = False

    @field_validator("txn_date")
    @classmethod
    def validate_txn_date(cls, value: str) -> str:
        """Accept ``value`` only if ``date.fromisoformat`` can parse it.

        The parsed date is discarded; the original string is returned
        unchanged.  A parse failure surfaces as the ``ValueError`` raised by
        ``date.fromisoformat``.
        """
        # Called purely for its ValueError on malformed input.
        date.fromisoformat(value)
        return value

    @field_validator("currency")
    @classmethod
    def normalize_currency(cls, value: str) -> str:
        """Return the currency code converted to upper case."""
        upper_code = value.upper()
        return upper_code
class EnterpriseFinanceObservation(Observation):
    # Observation payload for the enterprise-finance environment.
    #
    # Every field is optional: collections default to a fresh empty
    # container per instance and the feedback message defaults to "".
    structured_ledgers: list[StructuredLedgerRow] = Field(default_factory=list)
    # Invoice entries may be either dicts or plain strings.
    unstructured_invoices: list[dict[str, Any] | str] = Field(default_factory=list)
    # Mapping of string key -> float rate; the key format is not fixed here.
    fx_rates: dict[str, float] = Field(default_factory=dict)
    # Inbox entries may be either dicts or plain strings.
    dispute_inbox: list[dict[str, Any] | str] = Field(default_factory=list)
    feedback_message: str = ""
class EnterpriseFinanceState(State):
    # State record for the enterprise-finance environment.
    #
    # All fields carry defaults, so the model can be built with no
    # arguments.  Counters and absolute balances are constrained to be
    # non-negative (ge=0).

    # --- scenario selection ------------------------------------------------
    difficulty: Literal["easy", "medium", "hard"] = "easy"
    scenario_id: str = ""

    # --- balances (must be >= 0) -------------------------------------------
    initial_abs_balance: float = Field(default=0.0, ge=0)
    open_abs_balance: float = Field(default=0.0, ge=0)

    # --- bookkeeping ---------------------------------------------------------
    # Pairs of strings; presumably (debit_txn_id, credit_txn_id) -- confirm
    # against the code that appends to this list.
    linked_pairs: list[tuple[str, str]] = Field(default_factory=list)
    # String key -> float; presumably txn_id -> applied rate -- TODO confirm.
    fx_adjustments: dict[str, float] = Field(default_factory=dict)
    invalid_actions: int = Field(default=0, ge=0)
    deleted_txn_ids: list[str] = Field(default_factory=list)

    # --- expected terminal values (sign of the amount is unrestricted) ------
    expected_terminal_amount: float = 0.0
    expected_terminal_account: str = ""

    # --- audit / scoring -----------------------------------------------------
    audit_flags: list[str] = Field(default_factory=list)
    valid_event_count: int = Field(default=0, ge=0)
    unmatched_valid_count: int = Field(default=0, ge=0)
    total_reward: float = 0.0
    # None by default; when a number is supplied it must lie in [0, 1].
    final_score: float | None = Field(default=None, ge=0, le=1)
    force_balance_attempted: bool = False
class QuerySubledger(Action):
    # Action payload tagged "query_subledger".
    #
    # `date_range` is a (start, end) pair of date strings; both must be
    # parseable by `date.fromisoformat` and start may not be after end.
    type: Literal["query_subledger"] = "query_subledger"
    entity: str = Field(..., min_length=2)
    account_code: str = Field(..., min_length=2)
    date_range: tuple[str, str]

    @field_validator("date_range")
    @classmethod
    def validate_date_range(cls, value: tuple[str, str]) -> tuple[str, str]:
        """Check that both bounds parse as dates and are correctly ordered.

        Returns the original tuple of strings unchanged.

        Raises:
            ValueError: if either bound cannot be parsed by
                ``date.fromisoformat`` or if the start is later than the end.
        """
        start_text, end_text = value
        # Parse the start first so that, when both are malformed, the error
        # reported is the one for the start bound.
        first_day = date.fromisoformat(start_text)
        last_day = date.fromisoformat(end_text)
        if last_day < first_day:
            raise ValueError("date_range start must be on or before end")
        return value
class LinkTransactions(Action):
    # Action payload tagged "link_transactions".
    #
    # Carries one debit transaction id, one credit transaction id and a
    # free-text rationale.
    type: Literal["link_transactions"] = "link_transactions"

    # Both ids need at least 3 characters.
    debit_txn_id: str = Field(..., min_length=3)
    credit_txn_id: str = Field(..., min_length=3)

    # Required explanation, 5 to 500 characters long.
    rationale: str = Field(..., min_length=5, max_length=500)
class ApplyForexAdjustment(Action):
    # Action payload tagged "apply_forex_adjustment".
    #
    # NOTE: the field `date` below has the same name as the `date` class
    # imported from `datetime`.  The annotation-only field does not rebind
    # the name, and class-body names are not visible from inside methods, so
    # `date.fromisoformat` in the validator still refers to `datetime.date`.
    type: Literal["apply_forex_adjustment"] = "apply_forex_adjustment"
    txn_id: str = Field(..., min_length=3)
    exchange_rate: float = Field(..., gt=0)  # strictly positive
    date: str

    @field_validator("date")
    @classmethod
    def validate_date(cls, value: str) -> str:
        """Accept ``value`` only if ``date.fromisoformat`` can parse it.

        The original string is returned unchanged; a parse failure surfaces
        as the ``ValueError`` raised by ``date.fromisoformat``.
        """
        # Called purely for its ValueError on malformed input.
        date.fromisoformat(value)
        return value
class PostEliminationEntry(Action):
    # Action payload tagged "post_elimination_entry".
    #
    # `amount` may be positive, negative or zero, but must be a finite
    # number.
    type: Literal["post_elimination_entry"] = "post_elimination_entry"
    entity_id: str = Field(..., min_length=2)
    # Without a constraint pydantic accepts NaN and +/-Infinity for a float
    # field.  A non-finite amount would propagate through any later
    # arithmetic or comparison on it, so reject it at the schema boundary.
    # The sign is intentionally left unrestricted.
    amount: float = Field(..., allow_inf_nan=False)
    account: str = Field(..., min_length=2)
# Discriminated union of the four concrete action models; the `type`
# literal on each model selects which one a payload is validated against.
EnterpriseFinanceAction = Annotated[
    (
        QuerySubledger
        | LinkTransactions
        | ApplyForexAdjustment
        | PostEliminationEntry
    ),
    Field(discriminator="type"),
]
class EnterpriseFinanceActionPayload(RootModel[EnterpriseFinanceAction]):
    # Root-model wrapper around the discriminated action union: the wrapped
    # concrete action is available as `.root`.
    root: EnterpriseFinanceAction
# Anything `unwrap_action` accepts: one of the four concrete action models,
# or the root-model payload that wraps one of them.
ActionLike = (
    QuerySubledger | LinkTransactions | ApplyForexAdjustment | PostEliminationEntry
    | EnterpriseFinanceActionPayload
)
class EnterpriseFinanceStepPayload(BaseModel):
    # Bundle of an observation with an optional reward and a done flag.
    # Keys that are not declared below are rejected (extra="forbid").
    model_config = ConfigDict(extra="forbid")

    observation: EnterpriseFinanceObservation  # required
    reward: float | None = None  # None when no reward is supplied
    done: bool = False
def unwrap_action(action: ActionLike) -> QuerySubledger | LinkTransactions | ApplyForexAdjustment | PostEliminationEntry:
    """Return the concrete action, stripping the payload wrapper if present.

    Args:
        action: Either a concrete action model or an
            ``EnterpriseFinanceActionPayload`` wrapping one.

    Returns:
        ``action.root`` when ``action`` is an
        ``EnterpriseFinanceActionPayload``; otherwise ``action`` itself,
        unchanged.
    """
    is_wrapped = isinstance(action, EnterpriseFinanceActionPayload)
    return action.root if is_wrapped else action