# Baseline agent runner for the enterprise finance OpenEnv environment.
| from __future__ import annotations | |
| import argparse | |
| import asyncio | |
| import json | |
| import os | |
| import re | |
| from dataclasses import dataclass | |
| from typing import Any, Protocol | |
| from openai import OpenAI | |
| try: | |
| from dotenv import load_dotenv | |
| except ImportError: # pragma: no cover - optional dependency for local convenience. | |
| def load_dotenv() -> bool: | |
| return False | |
| try: | |
| from openenv.core.client_types import StepResult | |
| except ImportError: | |
| from enterprise_finance_env._compat import StepResult | |
| from enterprise_finance_env.client import EnterpriseFinanceClient | |
| from enterprise_finance_env.models import ( | |
| ActionLike, | |
| ApplyForexAdjustment, | |
| EnterpriseFinanceActionPayload, | |
| EnterpriseFinanceObservation, | |
| EnterpriseFinanceState, | |
| LinkTransactions, | |
| PostEliminationEntry, | |
| QuerySubledger, | |
| ) | |
| load_dotenv() | |
| DEFAULT_ENV_BASE_URL = "https://prasham1710-enterprise-finance-env.hf.space" | |
| API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1") | |
| API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY") | |
| MODEL_NAME = os.getenv("MODEL_NAME") or os.getenv("OPENAI_MODEL") | |
| ENV_BASE_URL = os.getenv("ENV_BASE_URL", DEFAULT_ENV_BASE_URL) | |
| TEMPERATURE = float(os.getenv("TEMPERATURE", "0.2")) | |
| MAX_TOKENS = int(os.getenv("MAX_TOKENS", "300")) | |
| MAX_STEPS = int(os.getenv("MAX_STEPS", "8")) | |
| SYSTEM_PROMPT = """You are the Consolidation Controller for a GAAP-compliant enterprise finance simulation. | |
| Use exactly one tool call on each turn. Prefer safe, incremental actions: | |
| - QuerySubledger to narrow the ledger slice. | |
| - LinkTransactions only when the debit-credit pair is clearly valid. | |
| - ApplyForexAdjustment only when the observation exposes an approved rate and date. | |
| - PostEliminationEntry only when you are confident the residual and account are correct. | |
| Never invent transaction ids, FX rates, dates, or accounts. | |
| """ | |
| JSON_FALLBACK_SYSTEM_PROMPT = """You are the Consolidation Controller for a GAAP-compliant enterprise finance simulation. | |
| Reply with exactly one JSON object and nothing else. | |
| The JSON object must match one of these shapes: | |
| {"type":"query_subledger","entity":"PARENT_US","account_code":"IC_AR","date_range":["2026-01-01","2026-01-31"]} | |
| {"type":"link_transactions","debit_txn_id":"TXN1","credit_txn_id":"TXN2","rationale":"Explain the match."} | |
| {"type":"apply_forex_adjustment","txn_id":"TXN1","exchange_rate":1.3025,"date":"2026-02-05"} | |
| {"type":"post_elimination_entry","entity_id":"GROUP","amount":12.34,"account":"IC_FX_ELIM_CLEARING"} | |
| Choose exactly one action for this turn. Do not emit multiple actions. Do not use markdown fences. | |
| """ | |
class EpisodeClient(Protocol):
    """Structural interface for the slice of the environment client the runners need."""

    async def reset(self, **kwargs: Any) -> StepResult[EnterpriseFinanceObservation]:
        """Start a new episode and return its first step result."""
        ...

    async def step(self, action: ActionLike, **kwargs: Any) -> StepResult[EnterpriseFinanceObservation]:
        """Apply a single action and return the resulting step."""
        ...

    async def state(self) -> EnterpriseFinanceState:
        """Fetch the current server-side episode state snapshot."""
        ...
@dataclass
class EpisodeSummary:
    """Terminal snapshot of one episode: the last step result and the final env state.

    Bug fix: the original class declared annotated attributes without the
    @dataclass decorator, so the keyword construction used by the runners
    (``EpisodeSummary(final_result=..., final_state=...)``) raised TypeError.
    """

    # Result returned by the last client.step() of the episode.
    final_result: StepResult[EnterpriseFinanceObservation]
    # State snapshot fetched after that final step.
    final_state: EnterpriseFinanceState
| def _normalize_reference(reference: str) -> str: | |
| return "-".join(reference.split("-")[:-1]) if "-" in reference else reference | |
| def _date_bounds(rows: list[Any]) -> tuple[str, str]: | |
| dates = sorted(row.txn_date for row in rows) | |
| return dates[0], dates[-1] | |
def _pair_by_reference(observation: EnterpriseFinanceObservation) -> list[tuple[str, str]]:
    """Group ledger rows by normalized reference and return complete (debit, credit) txn-id pairs.

    Groups with only one side present are skipped.
    """
    sides_by_ref: dict[str, dict[str, str]] = {}
    for row in observation.structured_ledgers:
        sides_by_ref.setdefault(_normalize_reference(row.reference), {})[row.side] = row.txn_id
    return [
        (sides["debit"], sides["credit"])
        for sides in sides_by_ref.values()
        if sides.get("debit") and sides.get("credit")
    ]
| def _medium_instructions(observation: EnterpriseFinanceObservation) -> list[dict[str, Any]]: | |
| instructions: list[dict[str, Any]] = [] | |
| for blob in observation.unstructured_invoices: | |
| data = json.loads(blob) if isinstance(blob, str) else blob | |
| instructions.append( | |
| { | |
| "debit_txn_id": data["debit_txn_id"], | |
| "credit_txn_id": data["credit_txn_id"], | |
| "fx_txn_id": data["fx_txn_id"], | |
| "fx_date": data["fx_date"], | |
| "fx_rate": float(data["fx_rate"]), | |
| "approved_fx_drift_usd": float(data["approved_fx_drift_usd"]), | |
| } | |
| ) | |
| return instructions | |
| def _terminal_posting(observation: EnterpriseFinanceObservation, difficulty: str) -> tuple[float, str]: | |
| if difficulty == "easy": | |
| return 0.0, "IC_ELIM_CLEARING" | |
| if difficulty == "medium": | |
| total = sum( | |
| float(blob["approved_fx_drift_usd"]) | |
| if not isinstance(blob, str) | |
| else float(json.loads(blob)["approved_fx_drift_usd"]) | |
| for blob in observation.unstructured_invoices | |
| ) | |
| return round(total, 2), "IC_FX_ELIM_CLEARING" | |
| total = sum( | |
| float(blob["approved_write_off_usd"]) | |
| if not isinstance(blob, str) | |
| else float(json.loads(blob)["approved_write_off_usd"]) | |
| for blob in observation.dispute_inbox | |
| ) | |
| return round(total, 2), "IC_WRITE_OFF_CLEARING" | |
| def _observation_digest(observation: EnterpriseFinanceObservation) -> dict[str, Any]: | |
| rows = observation.structured_ledgers | |
| sample_rows = [ | |
| { | |
| "txn_id": row.txn_id, | |
| "entity_id": row.entity_id, | |
| "counterparty_entity_id": row.counterparty_entity_id, | |
| "account_code": row.account_code, | |
| "side": row.side, | |
| "amount": row.amount, | |
| "currency": row.currency, | |
| "txn_date": row.txn_date, | |
| "reference": row.reference, | |
| "has_fx_adjustment": row.has_fx_adjustment, | |
| "disputed": row.disputed, | |
| } | |
| for row in rows[:12] | |
| ] | |
| invoice_sample = [ | |
| blob if isinstance(blob, dict) else json.loads(blob) | |
| for blob in observation.unstructured_invoices[:4] | |
| ] | |
| dispute_sample = [ | |
| blob if isinstance(blob, dict) else json.loads(blob) | |
| for blob in observation.dispute_inbox[:4] | |
| ] | |
| return { | |
| "feedback_message": observation.feedback_message, | |
| "unmatched_row_count": len(rows), | |
| "sample_rows": sample_rows, | |
| "fx_rates_sample": dict(list(observation.fx_rates.items())[:10]), | |
| "unstructured_invoices_sample": invoice_sample, | |
| "dispute_inbox_sample": dispute_sample, | |
| } | |
def _build_user_prompt(
    step: int,
    observation: EnterpriseFinanceObservation,
    state: EnterpriseFinanceState,
    history: list[str],
) -> str:
    """Serialize step context, env state counters, observation digest, and recent history.

    Only the last four history lines are included, to bound prompt growth.
    """
    recent = "\n".join(history[-4:]) if history else "None"
    return json.dumps(
        {
            "step": step,
            "difficulty": state.difficulty,
            "step_count": state.step_count,
            "open_abs_balance": state.open_abs_balance,
            "unmatched_valid_count": state.unmatched_valid_count,
            "total_reward": state.total_reward,
            "observation": _observation_digest(observation),
            "recent_history": recent,
        },
        indent=2,
    )
| def _extract_json_block(content: str) -> str: | |
| stripped = content.strip() | |
| if stripped.startswith("```"): | |
| stripped = re.sub(r"^```(?:json)?", "", stripped).strip() | |
| stripped = re.sub(r"```$", "", stripped).strip() | |
| return stripped | |
| def _build_tools() -> list[dict[str, Any]]: | |
| return [ | |
| { | |
| "type": "function", | |
| "function": { | |
| "name": "query_subledger", | |
| "description": "Query a ledger slice for one entity and account code over a date range.", | |
| "parameters": { | |
| "type": "object", | |
| "properties": { | |
| "entity": {"type": "string"}, | |
| "account_code": {"type": "string"}, | |
| "date_range": { | |
| "type": "array", | |
| "items": {"type": "string"}, | |
| "minItems": 2, | |
| "maxItems": 2, | |
| }, | |
| }, | |
| "required": ["entity", "account_code", "date_range"], | |
| "additionalProperties": False, | |
| }, | |
| }, | |
| }, | |
| { | |
| "type": "function", | |
| "function": { | |
| "name": "link_transactions", | |
| "description": "Link a debit and credit transaction when they are valid elimination counterparts.", | |
| "parameters": { | |
| "type": "object", | |
| "properties": { | |
| "debit_txn_id": {"type": "string"}, | |
| "credit_txn_id": {"type": "string"}, | |
| "rationale": {"type": "string"}, | |
| }, | |
| "required": ["debit_txn_id", "credit_txn_id", "rationale"], | |
| "additionalProperties": False, | |
| }, | |
| }, | |
| }, | |
| { | |
| "type": "function", | |
| "function": { | |
| "name": "apply_forex_adjustment", | |
| "description": "Apply an approved FX adjustment using the exact published rate and date.", | |
| "parameters": { | |
| "type": "object", | |
| "properties": { | |
| "txn_id": {"type": "string"}, | |
| "exchange_rate": {"type": "number"}, | |
| "date": {"type": "string"}, | |
| }, | |
| "required": ["txn_id", "exchange_rate", "date"], | |
| "additionalProperties": False, | |
| }, | |
| }, | |
| }, | |
| { | |
| "type": "function", | |
| "function": { | |
| "name": "post_elimination_entry", | |
| "description": "Submit the final GROUP elimination entry once the residual is fully understood.", | |
| "parameters": { | |
| "type": "object", | |
| "properties": { | |
| "entity_id": {"type": "string"}, | |
| "amount": {"type": "number"}, | |
| "account": {"type": "string"}, | |
| }, | |
| "required": ["entity_id", "amount", "account"], | |
| "additionalProperties": False, | |
| }, | |
| }, | |
| }, | |
| ] | |
def _tool_call_to_action(name: str, arguments: dict[str, Any]) -> ActionLike:
    """Translate a tool-call name plus decoded JSON arguments into a typed env action.

    Raises ValueError for tool names outside the four supported actions.
    """
    builders: dict[str, Any] = {
        "query_subledger": lambda args: QuerySubledger(
            entity=args["entity"],
            account_code=args["account_code"],
            date_range=tuple(args["date_range"]),
        ),
        "link_transactions": lambda args: LinkTransactions(
            debit_txn_id=args["debit_txn_id"],
            credit_txn_id=args["credit_txn_id"],
            rationale=args["rationale"],
        ),
        "apply_forex_adjustment": lambda args: ApplyForexAdjustment(
            txn_id=args["txn_id"],
            exchange_rate=args["exchange_rate"],
            date=args["date"],
        ),
        "post_elimination_entry": lambda args: PostEliminationEntry(
            entity_id=args["entity_id"],
            amount=args["amount"],
            account=args["account"],
        ),
    }
    if name not in builders:
        raise ValueError(f"Unsupported tool call: {name}")
    return builders[name](arguments)
def _json_dict_to_action(payload: dict[str, Any]) -> ActionLike:
    """Validate a raw JSON payload against the action union and unwrap the concrete action."""
    validated = EnterpriseFinanceActionPayload.model_validate(payload)
    return validated.root
def _fallback_action(observation: EnterpriseFinanceObservation) -> ActionLike:
    """Choose a safe default action when the model produced no usable tool call.

    With rows still visible, re-query the first row's subledger across the full
    date span; otherwise post a zero-value clearing entry.
    """
    rows = observation.structured_ledgers
    if not rows:
        return PostEliminationEntry(
            entity_id="GROUP",
            amount=0.0,
            account="IC_ELIM_CLEARING",
        )
    earliest, latest = _date_bounds(rows)
    lead_row = rows[0]
    return QuerySubledger(
        entity=lead_row.entity_id,
        account_code=lead_row.account_code,
        date_range=(earliest, latest),
    )
| def _format_action(action: ActionLike) -> str: | |
| payload = ( | |
| action.model_dump(mode="json") | |
| if hasattr(action, "model_dump") | |
| else EnterpriseFinanceActionPayload.model_validate(action).model_dump(mode="json") | |
| ) | |
| return json.dumps(payload, separators=(",", ":")) | |
| def _provider_prefers_json_fallback(api_base_url: str) -> bool: | |
| return "groq.com" in api_base_url.lower() | |
def _fallback_json_completion(
    *,
    llm_client: OpenAI,
    model: str,
    user_prompt: str,
    temperature: float,
    max_tokens: int,
) -> ActionLike:
    """Ask the model for a bare-JSON action (no tool calling) and parse the reply."""
    messages = [
        {"role": "system", "content": JSON_FALLBACK_SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    completion = llm_client.chat.completions.create(
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        messages=messages,
    )
    reply_text = completion.choices[0].message.content or ""
    return _json_dict_to_action(json.loads(_extract_json_block(reply_text)))
def _print_step_trace(
    step_index: int,
    action: ActionLike,
    result: StepResult[EnterpriseFinanceObservation],
    state: EnterpriseFinanceState,
) -> None:
    """Print a one-step trace: the action taken, reward/done, and key state counters."""
    # A missing reward is reported as zero.
    reward = result.reward or 0.0
    print(f"Step {step_index}: { _format_action(action) }")
    counters_line = (
        " reward="
        f"{reward:+.2f}"
        f" done={result.done}"
        f" open_balance={state.open_abs_balance:.2f}"
        f" unmatched={state.unmatched_valid_count}"
        f" total_reward={state.total_reward:.2f}"
        f" score={state.final_score}"
    )
    print(counters_line)
    print(f" feedback={result.observation.feedback_message}")
async def run_heuristic_episode(client: EpisodeClient, difficulty: str) -> EpisodeSummary:
    """Run a scripted baseline: scan the ledger, link pairs, then post the terminal entry.

    Medium scenarios apply each invoice's approved FX fix before linking; other
    difficulties link debit/credit pairs matched by normalized reference.
    """
    reset_result = await client.reset(difficulty=difficulty)
    observation = reset_result.observation
    if observation.structured_ledgers:
        # Open with one broad subledger query spanning the full observed date range.
        earliest, latest = _date_bounds(observation.structured_ledgers)
        lead_row = observation.structured_ledgers[0]
        await client.step(
            QuerySubledger(
                entity=lead_row.entity_id,
                account_code=lead_row.account_code,
                date_range=(earliest, latest),
            )
        )
    if difficulty == "medium":
        for instruction in _medium_instructions(observation):
            # Fix the FX leg first, then link the now-matching pair.
            await client.step(
                ApplyForexAdjustment(
                    txn_id=instruction["fx_txn_id"],
                    exchange_rate=instruction["fx_rate"],
                    date=instruction["fx_date"],
                )
            )
            link_result = await client.step(
                LinkTransactions(
                    debit_txn_id=instruction["debit_txn_id"],
                    credit_txn_id=instruction["credit_txn_id"],
                    rationale="Invoice blob, legal entities, and approved FX rate align.",
                )
            )
            if link_result.done:
                break
    else:
        for debit_txn_id, credit_txn_id in _pair_by_reference(observation):
            link_result = await client.step(
                LinkTransactions(
                    debit_txn_id=debit_txn_id,
                    credit_txn_id=credit_txn_id,
                    rationale="Counterparty references and debit-credit orientation agree.",
                )
            )
            if link_result.done:
                break
    amount, account = _terminal_posting(observation, difficulty)
    final_result = await client.step(
        PostEliminationEntry(entity_id="GROUP", amount=amount, account=account)
    )
    final_state = await client.state()
    return EpisodeSummary(final_result=final_result, final_state=final_state)
async def run_openai_episode(
    client: EpisodeClient,
    *,
    llm_client: OpenAI,
    api_base_url: str,
    difficulty: str,
    model: str,
    max_steps: int,
    temperature: float,
    max_tokens: int,
) -> EpisodeSummary:
    """Drive the environment with an OpenAI-compatible model for up to *max_steps* turns.

    Uses forced tool calling by default; providers known to mishandle tool use
    (and tool-use failures at runtime) are routed to the bare-JSON fallback.
    """
    history: list[str] = []
    tools = _build_tools()
    result = await client.reset(difficulty=difficulty)
    current_state = await client.state()

    def _json_mode_action(prompt: str) -> ActionLike:
        # Shared fallback path: request a bare JSON action instead of a tool call.
        return _fallback_json_completion(
            llm_client=llm_client,
            model=model,
            user_prompt=prompt,
            temperature=temperature,
            max_tokens=max_tokens,
        )

    for step_index in range(1, max_steps + 1):
        user_prompt = _build_user_prompt(step_index, result.observation, current_state, history)
        action: ActionLike
        try:
            if _provider_prefers_json_fallback(api_base_url):
                action = _json_mode_action(user_prompt)
            else:
                completion = llm_client.chat.completions.create(
                    model=model,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    tool_choice="required",
                    parallel_tool_calls=False,
                    tools=tools,
                    messages=[
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user", "content": user_prompt},
                    ],
                )
                message = completion.choices[0].message
                tool_calls = getattr(message, "tool_calls", None)
                if not tool_calls:
                    # No tool call despite tool_choice="required": take a safe default.
                    action = _fallback_action(result.observation)
                else:
                    first_call = tool_calls[0]
                    arguments = json.loads(first_call.function.arguments or "{}")
                    action = _tool_call_to_action(first_call.function.name, arguments)
        except Exception as exc:  # noqa: BLE001
            # Only tool-use failures are retried in JSON mode; everything else propagates.
            error_text = str(exc)
            if "tool_use_failed" not in error_text and "Failed to call a function" not in error_text:
                raise
            action = _json_mode_action(user_prompt)
        result = await client.step(action)
        current_state = await client.state()
        history.append(
            f"step={step_index} reward={(result.reward or 0.0):+.2f} "
            f"unmatched={current_state.unmatched_valid_count} "
            f"open_balance={current_state.open_abs_balance:.2f}"
        )
        _print_step_trace(step_index, action, result, current_state)
        if result.done:
            break
    return EpisodeSummary(final_result=result, final_state=current_state)
def _print_summary(summary: EpisodeSummary) -> None:
    """Print the terminal step result plus a compact JSON view of the final state."""
    final = summary.final_result
    print("Final Result")
    print(f" reward={final.reward}")
    print(f" done={final.done}")
    print(f" feedback={final.observation.feedback_message}")
    state = summary.final_state
    compact_state = {
        "difficulty": state.difficulty,
        "scenario_id": state.scenario_id,
        "step_count": state.step_count,
        "open_abs_balance": state.open_abs_balance,
        "linked_pairs": len(state.linked_pairs),
        "fx_adjustments": len(state.fx_adjustments),
        "invalid_actions": state.invalid_actions,
        "unmatched_valid_count": state.unmatched_valid_count,
        "expected_terminal_amount": state.expected_terminal_amount,
        "expected_terminal_account": state.expected_terminal_account,
        "force_balance_attempted": state.force_balance_attempted,
        "final_score": state.final_score,
        "total_reward": state.total_reward,
    }
    print(json.dumps(compact_state, indent=2))
async def _main_async(args: argparse.Namespace) -> None:
    """Connect to the environment, run the selected policy, and print the summary.

    Any failure while connected (network, auth, model) is re-raised as SystemExit
    with a readable message.
    """

    async def _run_selected_policy(client: EnterpriseFinanceClient) -> EpisodeSummary:
        if args.policy == "heuristic":
            return await run_heuristic_episode(client, difficulty=args.difficulty)
        # The OpenAI policy needs credentials and a model name before it can run.
        if not API_KEY:
            raise RuntimeError(
                "HF_TOKEN (or API_KEY) must be set for --policy openai."
            )
        if not args.model_name:
            raise RuntimeError("MODEL_NAME (or --model-name) must be set for --policy openai.")
        llm_client = OpenAI(base_url=args.api_base_url, api_key=API_KEY)
        return await run_openai_episode(
            client,
            llm_client=llm_client,
            api_base_url=args.api_base_url,
            difficulty=args.difficulty,
            model=args.model_name,
            max_steps=args.max_steps,
            temperature=args.temperature,
            max_tokens=args.max_tokens,
        )

    try:
        async with EnterpriseFinanceClient(base_url=args.env_base_url) as client:
            summary = await _run_selected_policy(client)
    except Exception as exc:  # noqa: BLE001
        raise SystemExit(
            f"Failed to run inference against environment API at {args.env_base_url}: {exc}"
        ) from exc
    _print_summary(summary)
def build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser; every default is sourced from the module's env-config constants."""
    parser = argparse.ArgumentParser(description="Run enterprise finance OpenEnv baselines.")
    option_specs: list[tuple[str, dict[str, Any]]] = [
        ("--env-base-url", {"default": ENV_BASE_URL}),
        ("--api-base-url", {"default": API_BASE_URL}),
        ("--difficulty", {"choices": ["easy", "medium", "hard"], "default": "easy"}),
        ("--policy", {"choices": ["openai", "heuristic"], "default": "openai"}),
        ("--model-name", {"default": MODEL_NAME}),
        ("--max-steps", {"type": int, "default": MAX_STEPS}),
        ("--temperature", {"type": float, "default": TEMPERATURE}),
        ("--max-tokens", {"type": int, "default": MAX_TOKENS}),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser
def main() -> None:
    """CLI entry point: parse arguments and run a single episode."""
    args = build_parser().parse_args()
    asyncio.run(_main_async(args))


if __name__ == "__main__":
    main()