File size: 4,825 Bytes
4ccc966
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from __future__ import annotations

from datetime import date
from typing import Annotated, Any, Literal

from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator

try:
    from openenv.core.env_server.types import Action, Observation, State
except ImportError:
    from enterprise_finance_env._compat import Action, Observation, State


# One row of a structured ledger; used as the element type of
# ``EnterpriseFinanceObservation.structured_ledgers``. Unknown keys are
# rejected (``extra="forbid"``).
class StructuredLedgerRow(BaseModel):
    model_config = ConfigDict(extra="forbid")

    txn_id: str = Field(..., min_length=3)
    entity_id: str = Field(..., min_length=2)
    counterparty_entity_id: str = Field(..., min_length=2)
    account_code: str = Field(..., min_length=2)
    side: Literal["debit", "credit"]
    # Strictly positive AND finite: ``gt=0`` on its own still admits ``inf``,
    # which would poison any balance arithmetic done with this row.
    amount: float = Field(..., gt=0, allow_inf_nan=False)
    # Exactly three characters; upper-cased by ``normalize_currency``.
    currency: str = Field(..., min_length=3, max_length=3)
    # ISO-8601 calendar date, stored in canonical ``YYYY-MM-DD`` form
    # (see ``validate_txn_date``).
    txn_date: str
    reference: str = Field(..., min_length=3)
    matched: bool = False
    eligible_for_linking: bool = True
    has_fx_adjustment: bool = False
    disputed: bool = False

    @field_validator("txn_date")
    @classmethod
    def validate_txn_date(cls, value: str) -> str:
        """Check that ``value`` is an ISO-8601 date and return it as ``YYYY-MM-DD``.

        Since Python 3.11, ``date.fromisoformat`` also accepts compact and
        ISO-week spellings (``20240115``, ``2024-W03-1``). Echoing the raw
        input back would let two spellings of the same day compare unequal as
        strings, so the parsed date is re-serialised instead. Canonical input
        is returned unchanged.

        Raises:
            ValueError: if ``value`` is not a valid ISO-8601 date (pydantic
                reports it as a ``ValidationError``).
        """
        return date.fromisoformat(value).isoformat()

    @field_validator("currency")
    @classmethod
    def normalize_currency(cls, value: str) -> str:
        """Return the currency code upper-cased (e.g. ``usd`` -> ``USD``)."""
        return value.upper()


# Observation handed to the agent: the base ``Observation`` (imported from
# outside this file) extended with the finance-specific payload below.
# Every collection defaults to a fresh empty container.
class EnterpriseFinanceObservation(Observation):
    # Ledger rows, each validated as a ``StructuredLedgerRow``.
    structured_ledgers: list[StructuredLedgerRow] = Field(default_factory=list)
    # Invoice documents; each entry is either a mapping or a raw string.
    unstructured_invoices: list[dict[str, Any] | str] = Field(default_factory=list)
    # Exchange rates keyed by string. Key format is not fixed here
    # (presumably a currency code or pair; confirm against the environment).
    fx_rates: dict[str, float] = Field(default_factory=dict)
    # Dispute items; each entry is either a mapping or a raw string.
    dispute_inbox: list[dict[str, Any] | str] = Field(default_factory=list)
    # Free-text feedback; the empty string when none is provided.
    feedback_message: str = ""


# Episode state for the finance environment, layered on the base ``State``
# (imported from outside this file). Bounds such as ``ge=0`` / ``le=1`` are
# enforced by pydantic at validation time.
class EnterpriseFinanceState(State):
    # Scenario selection.
    difficulty: Literal["easy", "medium", "hard"] = "easy"
    scenario_id: str = ""
    # Starting and current absolute open balance (per the field names);
    # both must be non-negative.
    initial_abs_balance: float = Field(default=0.0, ge=0)
    open_abs_balance: float = Field(default=0.0, ge=0)
    # Linked transaction id pairs (presumably (debit, credit), mirroring
    # ``LinkTransactions``; confirm in the environment).
    linked_pairs: list[tuple[str, str]] = Field(default_factory=list)
    # Numeric adjustment per key (presumably per txn_id; confirm).
    fx_adjustments: dict[str, float] = Field(default_factory=dict)
    invalid_actions: int = Field(default=0, ge=0)
    deleted_txn_ids: list[str] = Field(default_factory=list)
    # Target amount/account for a terminal entry, per the field names.
    expected_terminal_amount: float = 0.0
    expected_terminal_account: str = ""
    audit_flags: list[str] = Field(default_factory=list)
    valid_event_count: int = Field(default=0, ge=0)
    unmatched_valid_count: int = Field(default=0, ge=0)
    total_reward: float = 0.0
    # ``None`` until a score is set; otherwise constrained to [0, 1].
    final_score: float | None = Field(default=None, ge=0, le=1)
    force_balance_attempted: bool = False


# Action that queries one entity/account sub-ledger over a date range.
# ``type`` is the discriminator used by ``EnterpriseFinanceAction``.
class QuerySubledger(Action):
    type: Literal["query_subledger"] = "query_subledger"
    entity: str = Field(..., min_length=2)
    account_code: str = Field(..., min_length=2)
    # (start, end) ISO-8601 dates; start may equal end. Stored in canonical
    # ``YYYY-MM-DD`` form (see ``validate_date_range``).
    date_range: tuple[str, str]

    @field_validator("date_range")
    @classmethod
    def validate_date_range(cls, value: tuple[str, str]) -> tuple[str, str]:
        """Validate both dates, check their order, and return them canonicalised.

        Since Python 3.11, ``date.fromisoformat`` also accepts compact and
        ISO-week spellings. Returning the parsed dates as ``YYYY-MM-DD`` keeps
        the range string-comparable with other ISO dates. Canonical input is
        returned unchanged.

        Raises:
            ValueError: if either element is not an ISO-8601 date, or if start
                falls after end (pydantic reports it as a ``ValidationError``).
        """
        start = date.fromisoformat(value[0])
        end = date.fromisoformat(value[1])
        if start > end:
            raise ValueError("date_range start must be on or before end")
        return (start.isoformat(), end.isoformat())


# Action pairing a debit transaction with a credit transaction, justified by a
# free-text rationale of 5-500 characters. ``type`` is the union discriminator.
# NOTE(review): nothing here rejects debit_txn_id == credit_txn_id; any such
# check must live in the environment.
class LinkTransactions(Action):
    type: Literal["link_transactions"] = "link_transactions"
    # Required ids, each at least 3 characters long.
    debit_txn_id: Annotated[str, Field(min_length=3)]
    credit_txn_id: Annotated[str, Field(min_length=3)]
    # Required justification text.
    rationale: Annotated[str, Field(min_length=5, max_length=500)]


# Action applying an exchange rate to a single transaction as of a given date.
# ``type`` is the discriminator used by ``EnterpriseFinanceAction``.
class ApplyForexAdjustment(Action):
    type: Literal["apply_forex_adjustment"] = "apply_forex_adjustment"
    txn_id: str = Field(..., min_length=3)
    # Strictly positive AND finite: ``gt=0`` on its own still admits ``inf``.
    exchange_rate: float = Field(..., gt=0, allow_inf_nan=False)
    # ISO-8601 date stored as ``YYYY-MM-DD``. The field name does not clash with
    # the ``datetime.date`` import used by the validator: a bare annotation binds
    # no name, and method bodies resolve ``date`` from module globals.
    date: str

    @field_validator("date")
    @classmethod
    def validate_date(cls, value: str) -> str:
        """Check that ``value`` is an ISO-8601 date and return it as ``YYYY-MM-DD``.

        Since Python 3.11, ``date.fromisoformat`` also accepts compact and
        ISO-week spellings. Re-serialising keeps the stored value in one
        canonical, string-comparable form. Canonical input is returned unchanged.

        Raises:
            ValueError: if ``value`` is not a valid ISO-8601 date (pydantic
                reports it as a ``ValidationError``).
        """
        return date.fromisoformat(value).isoformat()


# Action posting an elimination entry of ``amount`` to ``account`` for one
# entity. ``type`` is the discriminator used by ``EnterpriseFinanceAction``.
class PostEliminationEntry(Action):
    type: Literal["post_elimination_entry"] = "post_elimination_entry"
    entity_id: str = Field(..., min_length=2)
    # No sign constraint is applied here (unlike ``StructuredLedgerRow.amount``).
    # Only non-finite values (nan, +/-inf) are rejected, since they would make
    # any balance arithmetic on the posted amount meaningless.
    amount: float = Field(..., allow_inf_nan=False)
    account: str = Field(..., min_length=2)


# Discriminated union of every concrete action model. Pydantic selects the
# member by the ``type`` literal each model declares.
EnterpriseFinanceAction = Annotated[
    (
        QuerySubledger
        | LinkTransactions
        | ApplyForexAdjustment
        | PostEliminationEntry
    ),
    Field(discriminator="type"),
]


# Root-model wrapper: validates a bare action value against the
# ``EnterpriseFinanceAction`` discriminated union and exposes the resulting
# concrete model as ``.root`` (see ``unwrap_action``).
class EnterpriseFinanceActionPayload(RootModel[EnterpriseFinanceAction]):
    # Re-states the field already implied by ``RootModel[...]`` so the wrapped
    # type is visible at a glance.
    root: EnterpriseFinanceAction


# Every input ``unwrap_action`` accepts: one of the four concrete action
# models, or the ``EnterpriseFinanceActionPayload`` root-model wrapper.
ActionLike = (
    QuerySubledger | LinkTransactions | ApplyForexAdjustment | PostEliminationEntry
    | EnterpriseFinanceActionPayload
)


# Bundle of one step's result: the observation, an optional reward, and a
# done flag. Unknown keys are rejected (``extra="forbid"``).
class EnterpriseFinanceStepPayload(BaseModel):
    model_config = ConfigDict(extra="forbid")

    observation: EnterpriseFinanceObservation
    # ``None`` means no reward value was supplied.
    reward: float | None = Field(default=None)
    done: bool = Field(default=False)


def unwrap_action(action: ActionLike) -> QuerySubledger | LinkTransactions | ApplyForexAdjustment | PostEliminationEntry:
    """Return the concrete action model behind ``action``.

    An ``EnterpriseFinanceActionPayload`` wrapper is replaced by its ``root``.
    Any other value is returned as-is (the same object, not a copy).
    """
    wrapped = isinstance(action, EnterpriseFinanceActionPayload)
    return action.root if wrapped else action