Spaces:
Sleeping
Sleeping
File size: 4,825 Bytes
4ccc966 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 | from __future__ import annotations
from datetime import date
from typing import Annotated, Any, Literal
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
try:
from openenv.core.env_server.types import Action, Observation, State
except ImportError:
from enterprise_finance_env._compat import Action, Observation, State
class StructuredLedgerRow(BaseModel):
    """One structured ledger line exposed to the agent.

    Unknown keys are rejected (``extra="forbid"``). ``txn_date`` must be an
    ISO-8601 calendar date and is stored in canonical ``YYYY-MM-DD`` form;
    ``currency`` must be exactly three characters and is upper-cased.
    """

    model_config = ConfigDict(extra="forbid")

    txn_id: str = Field(..., min_length=3)
    entity_id: str = Field(..., min_length=2)
    counterparty_entity_id: str = Field(..., min_length=2)
    account_code: str = Field(..., min_length=2)
    side: Literal["debit", "credit"]
    # Strictly positive AND finite: ``gt=0`` on its own still lets
    # ``float("inf")`` through, which is not a usable monetary amount.
    amount: float = Field(..., gt=0, allow_inf_nan=False)
    currency: str = Field(..., min_length=3, max_length=3)
    txn_date: str
    reference: str = Field(..., min_length=3)
    matched: bool = False
    eligible_for_linking: bool = True
    has_fx_adjustment: bool = False
    disputed: bool = False

    @field_validator("txn_date")
    @classmethod
    def validate_txn_date(cls, value: str) -> str:
        """Reject non-ISO dates and return the canonical ``YYYY-MM-DD`` spelling.

        Since Python 3.11, ``date.fromisoformat`` also accepts forms such as
        ``"20240105"`` or ``"2024-W01-5"``. Re-serialising the parsed date keeps
        every stored value in one format, so plain string comparison and sorting
        of dates stay chronological.

        Raises:
            ValueError: if ``value`` is not a valid ISO date (pydantic reports
                it as a validation error).
        """
        return date.fromisoformat(value).isoformat()

    @field_validator("currency")
    @classmethod
    def normalize_currency(cls, value: str) -> str:
        """Upper-case the currency code (the 3-character length check runs first)."""
        return value.upper()
class EnterpriseFinanceObservation(Observation):
    """Observation payload extending the ``Observation`` base class.

    Every collection field defaults to a fresh empty container, and
    ``feedback_message`` defaults to the empty string.
    """

    # Typed ledger rows (validated as ``StructuredLedgerRow``).
    structured_ledgers: list[StructuredLedgerRow] = Field(default_factory=list)
    # Raw invoice records: each entry is either a dict or a plain string.
    unstructured_invoices: list[dict[str, Any] | str] = Field(default_factory=list)
    # String-keyed float rates; keys are presumably currency codes — TODO confirm.
    fx_rates: dict[str, float] = Field(default_factory=dict)
    # Dispute items: each entry is either a dict or a plain string.
    dispute_inbox: list[dict[str, Any] | str] = Field(default_factory=list)
    feedback_message: str = ""
class EnterpriseFinanceState(State):
    """Episode state extending the ``State`` base class.

    Fields declared with ``ge=0`` reject negative values. ``final_score`` is
    ``None`` by default and, when set, must lie in the closed range [0, 1].
    """

    difficulty: Literal["easy", "medium", "hard"] = "easy"
    scenario_id: str = ""

    # Absolute balances; both must be non-negative.
    initial_abs_balance: float = Field(default=0.0, ge=0)
    open_abs_balance: float = Field(default=0.0, ge=0)

    linked_pairs: list[tuple[str, str]] = Field(default_factory=list)
    fx_adjustments: dict[str, float] = Field(default_factory=dict)
    invalid_actions: int = Field(default=0, ge=0)
    deleted_txn_ids: list[str] = Field(default_factory=list)

    expected_terminal_amount: float = 0.0
    expected_terminal_account: str = ""

    audit_flags: list[str] = Field(default_factory=list)
    valid_event_count: int = Field(default=0, ge=0)
    unmatched_valid_count: int = Field(default=0, ge=0)

    total_reward: float = 0.0
    final_score: float | None = Field(default=None, ge=0, le=1)
    force_balance_attempted: bool = False
class QuerySubledger(Action):
    """Action that queries a sub-ledger by entity, account code and date range.

    ``date_range`` is a ``(start, end)`` pair of ISO-8601 dates, and ``start``
    must not fall after ``end``. Both endpoints are stored in canonical
    ``YYYY-MM-DD`` form.
    """

    type: Literal["query_subledger"] = "query_subledger"
    entity: str = Field(..., min_length=2)
    account_code: str = Field(..., min_length=2)
    date_range: tuple[str, str]

    @field_validator("date_range")
    @classmethod
    def validate_date_range(cls, value: tuple[str, str]) -> tuple[str, str]:
        """Validate ordering and canonicalise both endpoints.

        Since Python 3.11, ``date.fromisoformat`` accepts several ISO spellings.
        Returning ``isoformat()`` keeps the stored strings in one format, so
        they compare consistently with other canonicalised dates.

        Raises:
            ValueError: if either endpoint is not an ISO date, or start > end.
        """
        start = date.fromisoformat(value[0])
        end = date.fromisoformat(value[1])
        if start > end:
            raise ValueError("date_range start must be on or before end")
        return (start.isoformat(), end.isoformat())
class LinkTransactions(Action):
    """Action that links one debit transaction to one credit transaction.

    ``rationale`` is a free-text justification of 5 to 500 characters.
    """

    type: Literal["link_transactions"] = "link_transactions"

    # Transaction identifiers (at least 3 characters each).
    debit_txn_id: str = Field(..., min_length=3)
    credit_txn_id: str = Field(..., min_length=3)

    rationale: str = Field(..., min_length=5, max_length=500)
class ApplyForexAdjustment(Action):
    """Action that applies an exchange rate to a transaction as of a given date.

    ``exchange_rate`` must be strictly positive and finite. ``date`` must be
    an ISO-8601 date and is stored in canonical ``YYYY-MM-DD`` form.
    """

    type: Literal["apply_forex_adjustment"] = "apply_forex_adjustment"
    txn_id: str = Field(..., min_length=3)
    # ``gt=0`` alone still accepts float("inf"); an infinite rate is meaningless.
    exchange_rate: float = Field(..., gt=0, allow_inf_nan=False)
    # NOTE: this field shares its name with ``datetime.date``. Method bodies
    # do not see class-level names, so ``date`` inside the validator below
    # still resolves to the module-level ``datetime.date``.
    date: str

    @field_validator("date")
    @classmethod
    def validate_date(cls, value: str) -> str:
        """Reject non-ISO dates and return the canonical ``YYYY-MM-DD`` spelling.

        Since Python 3.11, ``date.fromisoformat`` accepts other ISO spellings
        (e.g. ``"20240105"``). Normalising keeps stored dates comparable as
        strings.

        Raises:
            ValueError: if ``value`` is not a valid ISO date.
        """
        return date.fromisoformat(value).isoformat()
class PostEliminationEntry(Action):
    """Action that posts an elimination entry for an entity to an account.

    ``amount`` has no sign constraint, but it must be a finite number.
    """

    type: Literal["post_elimination_entry"] = "post_elimination_entry"
    entity_id: str = Field(..., min_length=2)
    # Either sign is allowed. NaN/inf are rejected because they cannot
    # represent a monetary amount; pydantic accepts them for floats by default.
    amount: float = Field(..., allow_inf_nan=False)
    account: str = Field(..., min_length=2)
# Discriminated union of every concrete action model. Pydantic selects the
# member whose ``type`` literal matches the incoming payload's ``type`` field.
EnterpriseFinanceAction = Annotated[
    (
        QuerySubledger
        | LinkTransactions
        | ApplyForexAdjustment
        | PostEliminationEntry
    ),
    Field(discriminator="type"),
]
class EnterpriseFinanceActionPayload(RootModel[EnterpriseFinanceAction]):
    """Root-model wrapper around a single action.

    Validation dispatches on the ``type`` discriminator of
    ``EnterpriseFinanceAction``; the resulting concrete action is ``root``.
    """

    root: EnterpriseFinanceAction
# Everything ``unwrap_action`` accepts: any concrete action model, or the
# RootModel wrapper around one. Member order matches the original union.
ActionLike = (
    QuerySubledger | LinkTransactions | ApplyForexAdjustment
    | PostEliminationEntry | EnterpriseFinanceActionPayload
)
class EnterpriseFinanceStepPayload(BaseModel):
    """Result envelope of one step: observation, optional reward, done flag.

    Unknown keys are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    observation: EnterpriseFinanceObservation
    # ``None`` when no reward is attached.
    reward: float | None = None
    done: bool = False
def unwrap_action(
    action: ActionLike,
) -> QuerySubledger | LinkTransactions | ApplyForexAdjustment | PostEliminationEntry:
    """Return the concrete action model behind ``action``.

    An ``EnterpriseFinanceActionPayload`` is unwrapped to its ``root``.
    Any other value is returned unchanged.
    """
    wrapped = isinstance(action, EnterpriseFinanceActionPayload)
    return action.root if wrapped else action
|