Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from datetime import date | |
| from typing import Annotated, Any, Literal | |
| from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator | |
| try: | |
| from openenv.core.env_server.types import Action, Observation, State | |
| except ImportError: | |
| from enterprise_finance_env._compat import Action, Observation, State | |
class StructuredLedgerRow(BaseModel):
    """A single validated row of a structured ledger.

    Unknown keys are rejected (``extra="forbid"``). ``txn_date`` must be a
    string accepted by ``datetime.date.fromisoformat`` and ``currency`` is
    normalised to upper case during validation.
    """

    model_config = ConfigDict(extra="forbid")

    # Identifiers: non-empty strings with the minimum lengths given below.
    txn_id: str = Field(..., min_length=3)
    entity_id: str = Field(..., min_length=2)
    counterparty_entity_id: str = Field(..., min_length=2)
    account_code: str = Field(..., min_length=2)

    side: Literal["debit", "credit"]
    amount: float = Field(..., gt=0)  # strictly positive
    currency: str = Field(..., min_length=3, max_length=3)  # exactly 3 chars
    txn_date: str  # kept as the original string; only its format is checked
    reference: str = Field(..., min_length=3)

    # Status flags and their defaults.
    matched: bool = False
    eligible_for_linking: bool = True
    has_fx_adjustment: bool = False
    disputed: bool = False

    # Without @field_validator these methods are never invoked by Pydantic,
    # so the date check and the currency normalisation would silently not run.
    @field_validator("txn_date")
    @classmethod
    def validate_txn_date(cls, value: str) -> str:
        """Reject ``txn_date`` values that ``date.fromisoformat`` cannot parse.

        Returns the input unchanged. A parse failure raises ``ValueError``,
        which Pydantic reports as a validation error.
        """
        date.fromisoformat(value)
        return value

    @field_validator("currency")
    @classmethod
    def normalize_currency(cls, value: str) -> str:
        """Return the currency code converted to upper case."""
        return value.upper()
class EnterpriseFinanceObservation(Observation):
    """Observation extended with ledger rows, invoices, FX rates and disputes.

    Every field is optional: collections default to empty containers and the
    feedback message defaults to an empty string.
    """

    structured_ledgers: list[StructuredLedgerRow] = Field(default_factory=list)
    # Invoice and dispute entries may each be a free-form dict or a plain string.
    unstructured_invoices: list[dict[str, Any] | str] = Field(default_factory=list)
    fx_rates: dict[str, float] = Field(default_factory=dict)
    dispute_inbox: list[dict[str, Any] | str] = Field(default_factory=list)
    feedback_message: str = ""
class EnterpriseFinanceState(State):
    """State extended with scenario, balance, linking and scoring fields.

    All fields have defaults, so no extra argument is required by this
    subclass. Counters and absolute balances are constrained to be
    non-negative; ``final_score`` is either ``None`` or within ``[0, 1]``.
    """

    difficulty: Literal["easy", "medium", "hard"] = "easy"
    scenario_id: str = ""

    # Balances (non-negative).
    initial_abs_balance: float = Field(0.0, ge=0)
    open_abs_balance: float = Field(0.0, ge=0)

    # Bookkeeping of what has been linked, adjusted, rejected or deleted.
    linked_pairs: list[tuple[str, str]] = Field(default_factory=list)
    fx_adjustments: dict[str, float] = Field(default_factory=dict)
    invalid_actions: int = Field(0, ge=0)
    deleted_txn_ids: list[str] = Field(default_factory=list)

    expected_terminal_amount: float = 0.0
    expected_terminal_account: str = ""
    audit_flags: list[str] = Field(default_factory=list)

    # Counters (non-negative).
    valid_event_count: int = Field(0, ge=0)
    unmatched_valid_count: int = Field(0, ge=0)

    total_reward: float = 0.0
    final_score: float | None = Field(None, ge=0, le=1)
    force_balance_attempted: bool = Field(default=False)
class QuerySubledger(Action):
    """Action with discriminator ``"query_subledger"``.

    Carries an entity, an account code and a ``(start, end)`` pair of date
    strings. Both dates must be parseable by ``date.fromisoformat`` and the
    start must not be later than the end.
    """

    type: Literal["query_subledger"] = "query_subledger"
    entity: str = Field(..., min_length=2)
    account_code: str = Field(..., min_length=2)
    date_range: tuple[str, str]

    # The decorator is required: an undecorated method is never called by
    # Pydantic, which would let malformed or inverted ranges through.
    @field_validator("date_range")
    @classmethod
    def validate_date_range(cls, value: tuple[str, str]) -> tuple[str, str]:
        """Check both bounds parse as ISO dates and are correctly ordered.

        Returns the input unchanged. Raises ``ValueError`` if either bound
        fails to parse or if the start is after the end.
        """
        start = date.fromisoformat(value[0])
        end = date.fromisoformat(value[1])
        if start > end:
            raise ValueError("date_range start must be on or before end")
        return value
class LinkTransactions(Action):
    """Action with discriminator ``"link_transactions"``.

    Names one debit and one credit transaction id and requires a rationale
    of 5 to 500 characters.
    """

    type: Literal["link_transactions"] = "link_transactions"
    # No default is supplied to Field, so these three fields stay required.
    debit_txn_id: str = Field(min_length=3)
    credit_txn_id: str = Field(min_length=3)
    rationale: str = Field(min_length=5, max_length=500)
class ApplyForexAdjustment(Action):
    """Action with discriminator ``"apply_forex_adjustment"``.

    Carries a transaction id, a strictly positive exchange rate and a date
    string that must be parseable by ``date.fromisoformat``.
    """

    type: Literal["apply_forex_adjustment"] = "apply_forex_adjustment"
    txn_id: str = Field(..., min_length=3)
    exchange_rate: float = Field(..., gt=0)
    date: str

    # The decorator is required: an undecorated method is never called by
    # Pydantic, which would leave the ``date`` field unchecked.
    @field_validator("date")
    @classmethod
    def validate_date(cls, value: str) -> str:
        """Reject ``date`` values that ``date.fromisoformat`` cannot parse.

        Returns the input unchanged. A parse failure raises ``ValueError``,
        which Pydantic reports as a validation error.
        """
        # Inside a method body ``date`` resolves to the module-level
        # ``datetime.date`` import, not to the field of the same name,
        # because class scope is not searched from function bodies.
        date.fromisoformat(value)
        return value
class PostEliminationEntry(Action):
    """Action with discriminator ``"post_elimination_entry"``.

    Carries an entity id, an amount and an account. The amount has no range
    constraint, so zero and negative values pass validation.
    """

    type: Literal["post_elimination_entry"] = "post_elimination_entry"
    # No default is supplied to Field, so these fields stay required.
    entity_id: str = Field(min_length=2)
    amount: float
    account: str = Field(min_length=2)
# Tagged union of the four action models; the ``type`` field of the incoming
# payload selects which member is validated.
EnterpriseFinanceAction = Annotated[
    (
        QuerySubledger
        | LinkTransactions
        | ApplyForexAdjustment
        | PostEliminationEntry
    ),
    Field(discriminator="type"),
]
class EnterpriseFinanceActionPayload(RootModel[EnterpriseFinanceAction]):
    """Root model whose ``root`` is one of the ``EnterpriseFinanceAction`` members."""

    root: EnterpriseFinanceAction
# Anything accepted where an action is expected: a concrete action model or
# the root-model wrapper around one.
ActionLike = (
    QuerySubledger |
    LinkTransactions |
    ApplyForexAdjustment |
    PostEliminationEntry |
    EnterpriseFinanceActionPayload
)
class EnterpriseFinanceStepPayload(BaseModel):
    """Bundle of an observation, an optional reward and a ``done`` flag.

    Unknown keys are rejected (``extra="forbid"``). Only ``observation`` is
    required.
    """

    model_config = ConfigDict(extra="forbid")

    observation: EnterpriseFinanceObservation
    reward: float | None = Field(default=None)
    done: bool = Field(default=False)
def unwrap_action(
    action: ActionLike,
) -> QuerySubledger | LinkTransactions | ApplyForexAdjustment | PostEliminationEntry:
    """Return the concrete action behind ``action``.

    An ``EnterpriseFinanceActionPayload`` is replaced by its ``root`` value;
    any other argument is returned as-is.
    """
    is_wrapped = isinstance(action, EnterpriseFinanceActionPayload)
    return action.root if is_wrapped else action