from __future__ import annotations import argparse import asyncio import json import os import re from dataclasses import dataclass from typing import Any, Protocol from openai import OpenAI try: from dotenv import load_dotenv except ImportError: # pragma: no cover - optional dependency for local convenience. def load_dotenv() -> bool: return False try: from openenv.core.client_types import StepResult except ImportError: from enterprise_finance_env._compat import StepResult from enterprise_finance_env.client import EnterpriseFinanceClient from enterprise_finance_env.models import ( ActionLike, ApplyForexAdjustment, EnterpriseFinanceActionPayload, EnterpriseFinanceObservation, EnterpriseFinanceState, LinkTransactions, PostEliminationEntry, QuerySubledger, ) load_dotenv() DEFAULT_ENV_BASE_URL = "https://prasham1710-enterprise-finance-env.hf.space" API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1") API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY") MODEL_NAME = os.getenv("MODEL_NAME") or os.getenv("OPENAI_MODEL") ENV_BASE_URL = os.getenv("ENV_BASE_URL", DEFAULT_ENV_BASE_URL) TEMPERATURE = float(os.getenv("TEMPERATURE", "0.2")) MAX_TOKENS = int(os.getenv("MAX_TOKENS", "300")) MAX_STEPS = int(os.getenv("MAX_STEPS", "8")) SYSTEM_PROMPT = """You are the Consolidation Controller for a GAAP-compliant enterprise finance simulation. Use exactly one tool call on each turn. Prefer safe, incremental actions: - QuerySubledger to narrow the ledger slice. - LinkTransactions only when the debit-credit pair is clearly valid. - ApplyForexAdjustment only when the observation exposes an approved rate and date. - PostEliminationEntry only when you are confident the residual and account are correct. Never invent transaction ids, FX rates, dates, or accounts. """ JSON_FALLBACK_SYSTEM_PROMPT = """You are the Consolidation Controller for a GAAP-compliant enterprise finance simulation. Reply with exactly one JSON object and nothing else. The JSON object must match one of these shapes: {"type":"query_subledger","entity":"PARENT_US","account_code":"IC_AR","date_range":["2026-01-01","2026-01-31"]} {"type":"link_transactions","debit_txn_id":"TXN1","credit_txn_id":"TXN2","rationale":"Explain the match."} {"type":"apply_forex_adjustment","txn_id":"TXN1","exchange_rate":1.3025,"date":"2026-02-05"} {"type":"post_elimination_entry","entity_id":"GROUP","amount":12.34,"account":"IC_FX_ELIM_CLEARING"} Choose exactly one action for this turn. Do not emit multiple actions. Do not use markdown fences. """ class EpisodeClient(Protocol): async def reset(self, **kwargs: Any) -> StepResult[EnterpriseFinanceObservation]: ... async def step(self, action: ActionLike, **kwargs: Any) -> StepResult[EnterpriseFinanceObservation]: ... async def state(self) -> EnterpriseFinanceState: ... @dataclass class EpisodeSummary: final_result: StepResult[EnterpriseFinanceObservation] final_state: EnterpriseFinanceState def _normalize_reference(reference: str) -> str: return "-".join(reference.split("-")[:-1]) if "-" in reference else reference def _date_bounds(rows: list[Any]) -> tuple[str, str]: dates = sorted(row.txn_date for row in rows) return dates[0], dates[-1] def _pair_by_reference(observation: EnterpriseFinanceObservation) -> list[tuple[str, str]]: grouped: dict[str, dict[str, str]] = {} for row in observation.structured_ledgers: bucket = grouped.setdefault(_normalize_reference(row.reference), {}) bucket[row.side] = row.txn_id pairs: list[tuple[str, str]] = [] for bucket in grouped.values(): debit = bucket.get("debit") credit = bucket.get("credit") if debit and credit: pairs.append((debit, credit)) return pairs def _medium_instructions(observation: EnterpriseFinanceObservation) -> list[dict[str, Any]]: instructions: list[dict[str, Any]] = [] for blob in observation.unstructured_invoices: data = json.loads(blob) if isinstance(blob, str) else blob instructions.append( { "debit_txn_id": data["debit_txn_id"], "credit_txn_id": data["credit_txn_id"], "fx_txn_id": data["fx_txn_id"], "fx_date": data["fx_date"], "fx_rate": float(data["fx_rate"]), "approved_fx_drift_usd": float(data["approved_fx_drift_usd"]), } ) return instructions def _terminal_posting(observation: EnterpriseFinanceObservation, difficulty: str) -> tuple[float, str]: if difficulty == "easy": return 0.0, "IC_ELIM_CLEARING" if difficulty == "medium": total = sum( float(blob["approved_fx_drift_usd"]) if not isinstance(blob, str) else float(json.loads(blob)["approved_fx_drift_usd"]) for blob in observation.unstructured_invoices ) return round(total, 2), "IC_FX_ELIM_CLEARING" total = sum( float(blob["approved_write_off_usd"]) if not isinstance(blob, str) else float(json.loads(blob)["approved_write_off_usd"]) for blob in observation.dispute_inbox ) return round(total, 2), "IC_WRITE_OFF_CLEARING" def _observation_digest(observation: EnterpriseFinanceObservation) -> dict[str, Any]: rows = observation.structured_ledgers sample_rows = [ { "txn_id": row.txn_id, "entity_id": row.entity_id, "counterparty_entity_id": row.counterparty_entity_id, "account_code": row.account_code, "side": row.side, "amount": row.amount, "currency": row.currency, "txn_date": row.txn_date, "reference": row.reference, "has_fx_adjustment": row.has_fx_adjustment, "disputed": row.disputed, } for row in rows[:12] ] invoice_sample = [ blob if isinstance(blob, dict) else json.loads(blob) for blob in observation.unstructured_invoices[:4] ] dispute_sample = [ blob if isinstance(blob, dict) else json.loads(blob) for blob in observation.dispute_inbox[:4] ] return { "feedback_message": observation.feedback_message, "unmatched_row_count": len(rows), "sample_rows": sample_rows, "fx_rates_sample": dict(list(observation.fx_rates.items())[:10]), "unstructured_invoices_sample": invoice_sample, "dispute_inbox_sample": dispute_sample, } def _build_user_prompt( step: int, observation: EnterpriseFinanceObservation, state: EnterpriseFinanceState, history: list[str], ) -> str: history_block = "\n".join(history[-4:]) if history else "None" prompt_payload = { "step": step, "difficulty": state.difficulty, "step_count": state.step_count, "open_abs_balance": state.open_abs_balance, "unmatched_valid_count": state.unmatched_valid_count, "total_reward": state.total_reward, "observation": _observation_digest(observation), "recent_history": history_block, } return json.dumps(prompt_payload, indent=2) def _extract_json_block(content: str) -> str: stripped = content.strip() if stripped.startswith("```"): stripped = re.sub(r"^```(?:json)?", "", stripped).strip() stripped = re.sub(r"```$", "", stripped).strip() return stripped def _build_tools() -> list[dict[str, Any]]: return [ { "type": "function", "function": { "name": "query_subledger", "description": "Query a ledger slice for one entity and account code over a date range.", "parameters": { "type": "object", "properties": { "entity": {"type": "string"}, "account_code": {"type": "string"}, "date_range": { "type": "array", "items": {"type": "string"}, "minItems": 2, "maxItems": 2, }, }, "required": ["entity", "account_code", "date_range"], "additionalProperties": False, }, }, }, { "type": "function", "function": { "name": "link_transactions", "description": "Link a debit and credit transaction when they are valid elimination counterparts.", "parameters": { "type": "object", "properties": { "debit_txn_id": {"type": "string"}, "credit_txn_id": {"type": "string"}, "rationale": {"type": "string"}, }, "required": ["debit_txn_id", "credit_txn_id", "rationale"], "additionalProperties": False, }, }, }, { "type": "function", "function": { "name": "apply_forex_adjustment", "description": "Apply an approved FX adjustment using the exact published rate and date.", "parameters": { "type": "object", "properties": { "txn_id": {"type": "string"}, "exchange_rate": {"type": "number"}, "date": {"type": "string"}, }, "required": ["txn_id", "exchange_rate", "date"], "additionalProperties": False, }, }, }, { "type": "function", "function": { "name": "post_elimination_entry", "description": "Submit the final GROUP elimination entry once the residual is fully understood.", "parameters": { "type": "object", "properties": { "entity_id": {"type": "string"}, "amount": {"type": "number"}, "account": {"type": "string"}, }, "required": ["entity_id", "amount", "account"], "additionalProperties": False, }, }, }, ] def _tool_call_to_action(name: str, arguments: dict[str, Any]) -> ActionLike: if name == "query_subledger": return QuerySubledger( entity=arguments["entity"], account_code=arguments["account_code"], date_range=tuple(arguments["date_range"]), ) if name == "link_transactions": return LinkTransactions( debit_txn_id=arguments["debit_txn_id"], credit_txn_id=arguments["credit_txn_id"], rationale=arguments["rationale"], ) if name == "apply_forex_adjustment": return ApplyForexAdjustment( txn_id=arguments["txn_id"], exchange_rate=arguments["exchange_rate"], date=arguments["date"], ) if name == "post_elimination_entry": return PostEliminationEntry( entity_id=arguments["entity_id"], amount=arguments["amount"], account=arguments["account"], ) raise ValueError(f"Unsupported tool call: {name}") def _json_dict_to_action(payload: dict[str, Any]) -> ActionLike: return EnterpriseFinanceActionPayload.model_validate(payload).root def _fallback_action(observation: EnterpriseFinanceObservation) -> ActionLike: if observation.structured_ledgers: start_date, end_date = _date_bounds(observation.structured_ledgers) first_row = observation.structured_ledgers[0] return QuerySubledger( entity=first_row.entity_id, account_code=first_row.account_code, date_range=(start_date, end_date), ) return PostEliminationEntry( entity_id="GROUP", amount=0.0, account="IC_ELIM_CLEARING", ) def _format_action(action: ActionLike) -> str: payload = ( action.model_dump(mode="json") if hasattr(action, "model_dump") else EnterpriseFinanceActionPayload.model_validate(action).model_dump(mode="json") ) return json.dumps(payload, separators=(",", ":")) def _provider_prefers_json_fallback(api_base_url: str) -> bool: return "groq.com" in api_base_url.lower() def _fallback_json_completion( *, llm_client: OpenAI, model: str, user_prompt: str, temperature: float, max_tokens: int, ) -> ActionLike: completion = llm_client.chat.completions.create( model=model, temperature=temperature, max_tokens=max_tokens, messages=[ {"role": "system", "content": JSON_FALLBACK_SYSTEM_PROMPT}, {"role": "user", "content": user_prompt}, ], ) content = completion.choices[0].message.content or "" payload = json.loads(_extract_json_block(content)) return _json_dict_to_action(payload) def _print_step_trace( step_index: int, action: ActionLike, result: StepResult[EnterpriseFinanceObservation], state: EnterpriseFinanceState, ) -> None: print(f"Step {step_index}: { _format_action(action) }") print( " reward=" f"{(result.reward or 0.0):+.2f}" f" done={result.done}" f" open_balance={state.open_abs_balance:.2f}" f" unmatched={state.unmatched_valid_count}" f" total_reward={state.total_reward:.2f}" f" score={state.final_score}" ) print(f" feedback={result.observation.feedback_message}") async def run_heuristic_episode(client: EpisodeClient, difficulty: str) -> EpisodeSummary: reset_result = await client.reset(difficulty=difficulty) observation = reset_result.observation if observation.structured_ledgers: start_date, end_date = _date_bounds(observation.structured_ledgers) first_row = observation.structured_ledgers[0] await client.step( QuerySubledger( entity=first_row.entity_id, account_code=first_row.account_code, date_range=(start_date, end_date), ) ) if difficulty == "medium": for instruction in _medium_instructions(observation): await client.step( ApplyForexAdjustment( txn_id=instruction["fx_txn_id"], exchange_rate=instruction["fx_rate"], date=instruction["fx_date"], ) ) result = await client.step( LinkTransactions( debit_txn_id=instruction["debit_txn_id"], credit_txn_id=instruction["credit_txn_id"], rationale="Invoice blob, legal entities, and approved FX rate align.", ) ) if result.done: break else: for debit_txn_id, credit_txn_id in _pair_by_reference(observation): result = await client.step( LinkTransactions( debit_txn_id=debit_txn_id, credit_txn_id=credit_txn_id, rationale="Counterparty references and debit-credit orientation agree.", ) ) if result.done: break amount, account = _terminal_posting(observation, difficulty) final_result = await client.step( PostEliminationEntry(entity_id="GROUP", amount=amount, account=account) ) final_state = await client.state() return EpisodeSummary(final_result=final_result, final_state=final_state) async def run_openai_episode( client: EpisodeClient, *, llm_client: OpenAI, api_base_url: str, difficulty: str, model: str, max_steps: int, temperature: float, max_tokens: int, ) -> EpisodeSummary: history: list[str] = [] tools = _build_tools() result = await client.reset(difficulty=difficulty) current_state = await client.state() for step_index in range(1, max_steps + 1): user_prompt = _build_user_prompt( step_index, result.observation, current_state, history, ) action: ActionLike try: if _provider_prefers_json_fallback(api_base_url): action = _fallback_json_completion( llm_client=llm_client, model=model, user_prompt=user_prompt, temperature=temperature, max_tokens=max_tokens, ) else: completion = llm_client.chat.completions.create( model=model, temperature=temperature, max_tokens=max_tokens, tool_choice="required", parallel_tool_calls=False, tools=tools, messages=[ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user_prompt}, ], ) message = completion.choices[0].message tool_call = message.tool_calls[0] if getattr(message, "tool_calls", None) else None if tool_call is None: action = _fallback_action(result.observation) else: arguments = json.loads(tool_call.function.arguments or "{}") action = _tool_call_to_action(tool_call.function.name, arguments) except Exception as exc: # noqa: BLE001 if "tool_use_failed" not in str(exc) and "Failed to call a function" not in str(exc): raise action = _fallback_json_completion( llm_client=llm_client, model=model, user_prompt=user_prompt, temperature=temperature, max_tokens=max_tokens, ) result = await client.step(action) current_state = await client.state() history.append( f"step={step_index} reward={(result.reward or 0.0):+.2f} " f"unmatched={current_state.unmatched_valid_count} " f"open_balance={current_state.open_abs_balance:.2f}" ) _print_step_trace(step_index, action, result, current_state) if result.done: break return EpisodeSummary(final_result=result, final_state=current_state) def _print_summary(summary: EpisodeSummary) -> None: print("Final Result") print(f" reward={summary.final_result.reward}") print(f" done={summary.final_result.done}") print(f" feedback={summary.final_result.observation.feedback_message}") state = summary.final_state compact_state = { "difficulty": state.difficulty, "scenario_id": state.scenario_id, "step_count": state.step_count, "open_abs_balance": state.open_abs_balance, "linked_pairs": len(state.linked_pairs), "fx_adjustments": len(state.fx_adjustments), "invalid_actions": state.invalid_actions, "unmatched_valid_count": state.unmatched_valid_count, "expected_terminal_amount": state.expected_terminal_amount, "expected_terminal_account": state.expected_terminal_account, "force_balance_attempted": state.force_balance_attempted, "final_score": state.final_score, "total_reward": state.total_reward, } print(json.dumps(compact_state, indent=2)) async def _main_async(args: argparse.Namespace) -> None: try: async with EnterpriseFinanceClient(base_url=args.env_base_url) as client: if args.policy == "heuristic": summary = await run_heuristic_episode(client, difficulty=args.difficulty) else: if not API_KEY: raise RuntimeError( "HF_TOKEN (or API_KEY) must be set for --policy openai." ) if not args.model_name: raise RuntimeError("MODEL_NAME (or --model-name) must be set for --policy openai.") llm_client = OpenAI(base_url=args.api_base_url, api_key=API_KEY) summary = await run_openai_episode( client, llm_client=llm_client, api_base_url=args.api_base_url, difficulty=args.difficulty, model=args.model_name, max_steps=args.max_steps, temperature=args.temperature, max_tokens=args.max_tokens, ) except Exception as exc: # noqa: BLE001 raise SystemExit( f"Failed to run inference against environment API at {args.env_base_url}: {exc}" ) from exc _print_summary(summary) def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Run enterprise finance OpenEnv baselines.") parser.add_argument("--env-base-url", default=ENV_BASE_URL) parser.add_argument("--api-base-url", default=API_BASE_URL) parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy") parser.add_argument("--policy", choices=["openai", "heuristic"], default="openai") parser.add_argument("--model-name", default=MODEL_NAME) parser.add_argument("--max-steps", type=int, default=MAX_STEPS) parser.add_argument("--temperature", type=float, default=TEMPERATURE) parser.add_argument("--max-tokens", type=int, default=MAX_TOKENS) return parser def main() -> None: parser = build_parser() args = parser.parse_args() asyncio.run(_main_async(args)) if __name__ == "__main__": main()