# Enterprise_Finance_env / inference.py
# Last commit d9ced2a by Prasham1710: "Add Groq-safe inference fallback"
from __future__ import annotations
import argparse
import asyncio
import json
import os
import re
from dataclasses import dataclass
from typing import Any, Protocol
from openai import OpenAI
try:
    from dotenv import load_dotenv
except ImportError:  # pragma: no cover - optional dependency for local convenience.

    def load_dotenv(*_args: Any, **_kwargs: Any) -> bool:
        """No-op stand-in for ``dotenv.load_dotenv`` when python-dotenv is absent.

        Accepts and ignores any positional/keyword arguments so callers may pass
        options (e.g. ``override=...``) regardless of which implementation is
        installed. Always returns ``False`` because nothing is loaded.
        """
        return False
try:
from openenv.core.client_types import StepResult
except ImportError:
from enterprise_finance_env._compat import StepResult
from enterprise_finance_env.client import EnterpriseFinanceClient
from enterprise_finance_env.models import (
ActionLike,
ApplyForexAdjustment,
EnterpriseFinanceActionPayload,
EnterpriseFinanceObservation,
EnterpriseFinanceState,
LinkTransactions,
PostEliminationEntry,
QuerySubledger,
)
load_dotenv()

# Runtime configuration, read once at import time. A variable that is set but
# empty (e.g. ``TEMPERATURE=`` left blank in a .env file) counts as unset and
# falls back to the default. This is the same treatment the ``or`` chains for the
# API key and model name give, and it avoids crashing in float()/int() or
# producing an empty base URL.
DEFAULT_ENV_BASE_URL = "https://prasham1710-enterprise-finance-env.hf.space"
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
# HF_TOKEN takes precedence over API_KEY. The value may be None; _main_async
# rejects that for the openai policy.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
MODEL_NAME = os.getenv("MODEL_NAME") or os.getenv("OPENAI_MODEL")
ENV_BASE_URL = os.getenv("ENV_BASE_URL") or DEFAULT_ENV_BASE_URL
TEMPERATURE = float(os.getenv("TEMPERATURE") or "0.2")
MAX_TOKENS = int(os.getenv("MAX_TOKENS") or "300")
MAX_STEPS = int(os.getenv("MAX_STEPS") or "8")
# System prompt for the tool-calling path: one tool call per turn, with
# conservative guidance on when each of the four actions is appropriate.
SYSTEM_PROMPT = (
    "You are the Consolidation Controller for a GAAP-compliant enterprise finance simulation.\n"
    "Use exactly one tool call on each turn. Prefer safe, incremental actions:\n"
    "- QuerySubledger to narrow the ledger slice.\n"
    "- LinkTransactions only when the debit-credit pair is clearly valid.\n"
    "- ApplyForexAdjustment only when the observation exposes an approved rate and date.\n"
    "- PostEliminationEntry only when you are confident the residual and account are correct.\n"
    "Never invent transaction ids, FX rates, dates, or accounts.\n"
)
# System prompt for the JSON-reply path (providers without reliable tool
# calling). It carries the same "never invent" guardrail as SYSTEM_PROMPT, and
# marks the example values as placeholders so the model does not copy them.
JSON_FALLBACK_SYSTEM_PROMPT = """You are the Consolidation Controller for a GAAP-compliant enterprise finance simulation.
Reply with exactly one JSON object and nothing else.
The JSON object must match one of these shapes (the values shown are placeholders only):
{"type":"query_subledger","entity":"PARENT_US","account_code":"IC_AR","date_range":["2026-01-01","2026-01-31"]}
{"type":"link_transactions","debit_txn_id":"TXN1","credit_txn_id":"TXN2","rationale":"Explain the match."}
{"type":"apply_forex_adjustment","txn_id":"TXN1","exchange_rate":1.3025,"date":"2026-02-05"}
{"type":"post_elimination_entry","entity_id":"GROUP","amount":12.34,"account":"IC_FX_ELIM_CLEARING"}
Choose exactly one action for this turn. Do not emit multiple actions. Do not use markdown fences.
Never invent transaction ids, FX rates, dates, or accounts; copy them from the observation.
"""
class EpisodeClient(Protocol):
    """Structural interface for the async environment client the runners drive.

    ``_main_async`` passes an ``EnterpriseFinanceClient`` wherever this is
    expected; any object exposing these three coroutines is accepted.
    """

    async def reset(self, **kwargs: Any) -> StepResult[EnterpriseFinanceObservation]:
        """Start a new episode. The runners pass ``difficulty=...``."""

    async def step(self, action: ActionLike, **kwargs: Any) -> StepResult[EnterpriseFinanceObservation]:
        """Submit one action and return the resulting step."""

    async def state(self) -> EnterpriseFinanceState:
        """Return the current environment state."""
@dataclass
class EpisodeSummary:
    """Outcome of one episode run, as returned by both runner coroutines."""

    # Last StepResult produced by the episode.
    final_result: StepResult[EnterpriseFinanceObservation]
    # Environment state fetched after the final step.
    final_state: EnterpriseFinanceState
def _normalize_reference(reference: str) -> str:
return "-".join(reference.split("-")[:-1]) if "-" in reference else reference
def _date_bounds(rows: list[Any]) -> tuple[str, str]:
dates = sorted(row.txn_date for row in rows)
return dates[0], dates[-1]
def _pair_by_reference(observation: EnterpriseFinanceObservation) -> list[tuple[str, str]]:
    """Pair debit and credit txn ids whose references agree after normalization.

    Rows are grouped by ``_normalize_reference(row.reference)`` and keyed by
    ``row.side`` within each group. A later row with the same side overwrites an
    earlier one, so at most one pair per reference is produced.
    NOTE(review): duplicate debits/credits sharing a reference are silently
    dropped; confirm the scenarios never contain them.
    Groups missing either side are skipped. Output order follows the first
    appearance of each reference.
    """
    sides_by_reference: dict[str, dict[str, str]] = {}
    for ledger_row in observation.structured_ledgers:
        key = _normalize_reference(ledger_row.reference)
        sides_by_reference.setdefault(key, {})[ledger_row.side] = ledger_row.txn_id
    return [
        (sides["debit"], sides["credit"])
        for sides in sides_by_reference.values()
        if sides.get("debit") and sides.get("credit")
    ]
def _medium_instructions(observation: EnterpriseFinanceObservation) -> list[dict[str, Any]]:
instructions: list[dict[str, Any]] = []
for blob in observation.unstructured_invoices:
data = json.loads(blob) if isinstance(blob, str) else blob
instructions.append(
{
"debit_txn_id": data["debit_txn_id"],
"credit_txn_id": data["credit_txn_id"],
"fx_txn_id": data["fx_txn_id"],
"fx_date": data["fx_date"],
"fx_rate": float(data["fx_rate"]),
"approved_fx_drift_usd": float(data["approved_fx_drift_usd"]),
}
)
return instructions
def _terminal_posting(observation: EnterpriseFinanceObservation, difficulty: str) -> tuple[float, str]:
if difficulty == "easy":
return 0.0, "IC_ELIM_CLEARING"
if difficulty == "medium":
total = sum(
float(blob["approved_fx_drift_usd"])
if not isinstance(blob, str)
else float(json.loads(blob)["approved_fx_drift_usd"])
for blob in observation.unstructured_invoices
)
return round(total, 2), "IC_FX_ELIM_CLEARING"
total = sum(
float(blob["approved_write_off_usd"])
if not isinstance(blob, str)
else float(json.loads(blob)["approved_write_off_usd"])
for blob in observation.dispute_inbox
)
return round(total, 2), "IC_WRITE_OFF_CLEARING"
def _observation_digest(observation: EnterpriseFinanceObservation) -> dict[str, Any]:
rows = observation.structured_ledgers
sample_rows = [
{
"txn_id": row.txn_id,
"entity_id": row.entity_id,
"counterparty_entity_id": row.counterparty_entity_id,
"account_code": row.account_code,
"side": row.side,
"amount": row.amount,
"currency": row.currency,
"txn_date": row.txn_date,
"reference": row.reference,
"has_fx_adjustment": row.has_fx_adjustment,
"disputed": row.disputed,
}
for row in rows[:12]
]
invoice_sample = [
blob if isinstance(blob, dict) else json.loads(blob)
for blob in observation.unstructured_invoices[:4]
]
dispute_sample = [
blob if isinstance(blob, dict) else json.loads(blob)
for blob in observation.dispute_inbox[:4]
]
return {
"feedback_message": observation.feedback_message,
"unmatched_row_count": len(rows),
"sample_rows": sample_rows,
"fx_rates_sample": dict(list(observation.fx_rates.items())[:10]),
"unstructured_invoices_sample": invoice_sample,
"dispute_inbox_sample": dispute_sample,
}
def _build_user_prompt(
    step: int,
    observation: EnterpriseFinanceObservation,
    state: EnterpriseFinanceState,
    history: list[str],
) -> str:
    """Serialize the per-turn context sent to the model as indented JSON.

    The payload contains:
    - the step number;
    - a few state metrics;
    - ``_observation_digest(observation)``;
    - the last four history lines joined by newlines, or the string ``"None"``
      when there is no history yet.
    """
    recent_lines = history[-4:]
    history_block = "\n".join(recent_lines) if recent_lines else "None"
    payload: dict[str, Any] = dict(
        step=step,
        difficulty=state.difficulty,
        step_count=state.step_count,
        open_abs_balance=state.open_abs_balance,
        unmatched_valid_count=state.unmatched_valid_count,
        total_reward=state.total_reward,
        observation=_observation_digest(observation),
        recent_history=history_block,
    )
    return json.dumps(payload, indent=2)
def _extract_json_block(content: str) -> str:
stripped = content.strip()
if stripped.startswith("```"):
stripped = re.sub(r"^```(?:json)?", "", stripped).strip()
stripped = re.sub(r"```$", "", stripped).strip()
return stripped
def _build_tools() -> list[dict[str, Any]]:
return [
{
"type": "function",
"function": {
"name": "query_subledger",
"description": "Query a ledger slice for one entity and account code over a date range.",
"parameters": {
"type": "object",
"properties": {
"entity": {"type": "string"},
"account_code": {"type": "string"},
"date_range": {
"type": "array",
"items": {"type": "string"},
"minItems": 2,
"maxItems": 2,
},
},
"required": ["entity", "account_code", "date_range"],
"additionalProperties": False,
},
},
},
{
"type": "function",
"function": {
"name": "link_transactions",
"description": "Link a debit and credit transaction when they are valid elimination counterparts.",
"parameters": {
"type": "object",
"properties": {
"debit_txn_id": {"type": "string"},
"credit_txn_id": {"type": "string"},
"rationale": {"type": "string"},
},
"required": ["debit_txn_id", "credit_txn_id", "rationale"],
"additionalProperties": False,
},
},
},
{
"type": "function",
"function": {
"name": "apply_forex_adjustment",
"description": "Apply an approved FX adjustment using the exact published rate and date.",
"parameters": {
"type": "object",
"properties": {
"txn_id": {"type": "string"},
"exchange_rate": {"type": "number"},
"date": {"type": "string"},
},
"required": ["txn_id", "exchange_rate", "date"],
"additionalProperties": False,
},
},
},
{
"type": "function",
"function": {
"name": "post_elimination_entry",
"description": "Submit the final GROUP elimination entry once the residual is fully understood.",
"parameters": {
"type": "object",
"properties": {
"entity_id": {"type": "string"},
"amount": {"type": "number"},
"account": {"type": "string"},
},
"required": ["entity_id", "amount", "account"],
"additionalProperties": False,
},
},
},
]
def _tool_call_to_action(name: str, arguments: dict[str, Any]) -> ActionLike:
if name == "query_subledger":
return QuerySubledger(
entity=arguments["entity"],
account_code=arguments["account_code"],
date_range=tuple(arguments["date_range"]),
)
if name == "link_transactions":
return LinkTransactions(
debit_txn_id=arguments["debit_txn_id"],
credit_txn_id=arguments["credit_txn_id"],
rationale=arguments["rationale"],
)
if name == "apply_forex_adjustment":
return ApplyForexAdjustment(
txn_id=arguments["txn_id"],
exchange_rate=arguments["exchange_rate"],
date=arguments["date"],
)
if name == "post_elimination_entry":
return PostEliminationEntry(
entity_id=arguments["entity_id"],
amount=arguments["amount"],
account=arguments["account"],
)
raise ValueError(f"Unsupported tool call: {name}")
def _json_dict_to_action(payload: dict[str, Any]) -> ActionLike:
    """Validate a JSON action dict with ``EnterpriseFinanceActionPayload`` and unwrap its ``.root``."""
    validated = EnterpriseFinanceActionPayload.model_validate(payload)
    return validated.root
def _fallback_action(observation: EnterpriseFinanceObservation) -> ActionLike:
    """Return a conservative action for when the model produced no usable one.

    - If ledger rows remain: query the first row's entity and account over the
      full date span of the rows.
    - Otherwise: post a zero GROUP elimination to ``IC_ELIM_CLEARING``.
    """
    ledger_rows = observation.structured_ledgers
    if not ledger_rows:
        return PostEliminationEntry(
            entity_id="GROUP",
            amount=0.0,
            account="IC_ELIM_CLEARING",
        )
    earliest, latest = _date_bounds(ledger_rows)
    anchor_row = ledger_rows[0]
    return QuerySubledger(
        entity=anchor_row.entity_id,
        account_code=anchor_row.account_code,
        date_range=(earliest, latest),
    )
def _format_action(action: ActionLike) -> str:
payload = (
action.model_dump(mode="json")
if hasattr(action, "model_dump")
else EnterpriseFinanceActionPayload.model_validate(action).model_dump(mode="json")
)
return json.dumps(payload, separators=(",", ":"))
def _provider_prefers_json_fallback(api_base_url: str) -> bool:
return "groq.com" in api_base_url.lower()
def _fallback_json_completion(
    *,
    llm_client: OpenAI,
    model: str,
    user_prompt: str,
    temperature: float,
    max_tokens: int,
) -> ActionLike:
    """Ask the model for a bare JSON action (no tool calling) and parse it.

    The request uses ``JSON_FALLBACK_SYSTEM_PROMPT``. The reply is cleaned
    with ``_extract_json_block`` and validated by ``_json_dict_to_action``.

    Raises:
        ValueError: if the reply is empty, is not valid JSON, or is not a JSON
            object. The message includes the start of the reply for debugging.
            ``json.JSONDecodeError`` is a ValueError subclass, so callers that
            caught it still do.
        Errors from payload validation in ``_json_dict_to_action`` propagate
        unchanged.
    """
    completion = llm_client.chat.completions.create(
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        messages=[
            {"role": "system", "content": JSON_FALLBACK_SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ],
    )
    content = completion.choices[0].message.content or ""
    try:
        payload = json.loads(_extract_json_block(content))
    except json.JSONDecodeError as exc:
        raise ValueError(f"Model reply is not valid JSON: {content[:200]!r}") from exc
    if not isinstance(payload, dict):
        raise ValueError(
            f"Model reply must be a JSON object, got {type(payload).__name__}: {content[:200]!r}"
        )
    return _json_dict_to_action(payload)
def _print_step_trace(
    step_index: int,
    action: ActionLike,
    result: StepResult[EnterpriseFinanceObservation],
    state: EnterpriseFinanceState,
) -> None:
    """Print a three-line trace for one step.

    The lines are: the action as compact JSON, the reward/state metrics, and
    the environment feedback. A missing reward is shown as ``+0.00``.
    """
    action_json = _format_action(action)
    print(f"Step {step_index}: {action_json}")
    reward = result.reward or 0.0
    metrics_line = (
        f" reward={reward:+.2f}"
        f" done={result.done}"
        f" open_balance={state.open_abs_balance:.2f}"
        f" unmatched={state.unmatched_valid_count}"
        f" total_reward={state.total_reward:.2f}"
        f" score={state.final_score}"
    )
    print(metrics_line)
    print(f" feedback={result.observation.feedback_message}")
async def run_heuristic_episode(client: EpisodeClient, difficulty: str) -> EpisodeSummary:
    """Run the scripted (non-LLM) baseline for one episode.

    The plan is derived entirely from the reset observation:

    1. Query the first ledger row's entity/account over the date span of all rows.
    2. For ``medium``: for each parsed invoice, apply its FX adjustment, then
       link its debit/credit pair. For other difficulties: link every
       debit/credit pair that shares a normalized reference.
    3. Post the GROUP elimination entry computed by ``_terminal_posting``.

    The episode stops as soon as any step reports ``done``. In particular, the
    terminal posting is not sent to an environment whose episode has already
    finished.

    Returns:
        The last step result (the reset result if no step was taken) and the
        state fetched at the end.
    """
    reset_result = await client.reset(difficulty=difficulty)
    observation = reset_result.observation
    result = reset_result

    rows = observation.structured_ledgers
    if rows and not result.done:
        start_date, end_date = _date_bounds(rows)
        first_row = rows[0]
        result = await client.step(
            QuerySubledger(
                entity=first_row.entity_id,
                account_code=first_row.account_code,
                date_range=(start_date, end_date),
            )
        )

    if difficulty == "medium":
        for instruction in _medium_instructions(observation):
            if result.done:
                break
            result = await client.step(
                ApplyForexAdjustment(
                    txn_id=instruction["fx_txn_id"],
                    exchange_rate=instruction["fx_rate"],
                    date=instruction["fx_date"],
                )
            )
            if result.done:
                break
            result = await client.step(
                LinkTransactions(
                    debit_txn_id=instruction["debit_txn_id"],
                    credit_txn_id=instruction["credit_txn_id"],
                    rationale="Invoice blob, legal entities, and approved FX rate align.",
                )
            )
    else:
        for debit_txn_id, credit_txn_id in _pair_by_reference(observation):
            if result.done:
                break
            result = await client.step(
                LinkTransactions(
                    debit_txn_id=debit_txn_id,
                    credit_txn_id=credit_txn_id,
                    rationale="Counterparty references and debit-credit orientation agree.",
                )
            )

    if not result.done:
        amount, account = _terminal_posting(observation, difficulty)
        result = await client.step(
            PostEliminationEntry(entity_id="GROUP", amount=amount, account=account)
        )

    final_state = await client.state()
    return EpisodeSummary(final_result=result, final_state=final_state)
def _request_model_action(
    *,
    llm_client: OpenAI,
    api_base_url: str,
    model: str,
    user_prompt: str,
    tools: list[dict[str, Any]],
    observation: EnterpriseFinanceObservation,
    temperature: float,
    max_tokens: int,
) -> ActionLike:
    """Ask the LLM for one action. This is a blocking call, meant to run in a worker thread.

    Behavior by case:
    - Endpoints flagged by ``_provider_prefers_json_fallback`` go straight to
      the JSON-reply path.
    - Otherwise exactly one tool call is required.
    - If the provider rejects tool calling (error text mentions
      ``tool_use_failed`` or ``Failed to call a function``), the JSON-reply
      path is used instead.
    - A reply without any tool call yields ``_fallback_action(observation)``.

    Raises:
        ValueError, KeyError or TypeError: typically when the model's reply
            cannot be turned into an action (bad JSON, unknown tool name,
            missing or non-dict arguments).
        Any other exception from the OpenAI client propagates unchanged.
    """
    json_request = {
        "llm_client": llm_client,
        "model": model,
        "user_prompt": user_prompt,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    if _provider_prefers_json_fallback(api_base_url):
        return _fallback_json_completion(**json_request)
    try:
        completion = llm_client.chat.completions.create(
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            tool_choice="required",
            parallel_tool_calls=False,
            tools=tools,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
        )
    except Exception as exc:  # noqa: BLE001
        error_text = str(exc)
        if "tool_use_failed" not in error_text and "Failed to call a function" not in error_text:
            raise
        return _fallback_json_completion(**json_request)
    message = completion.choices[0].message
    tool_calls = getattr(message, "tool_calls", None)
    if not tool_calls:
        return _fallback_action(observation)
    tool_call = tool_calls[0]
    arguments = json.loads(tool_call.function.arguments or "{}")
    return _tool_call_to_action(tool_call.function.name, arguments)


async def run_openai_episode(
    client: EpisodeClient,
    *,
    llm_client: OpenAI,
    api_base_url: str,
    difficulty: str,
    model: str,
    max_steps: int,
    temperature: float,
    max_tokens: int,
) -> EpisodeSummary:
    """Run one episode where the LLM chooses every action, for up to *max_steps* steps.

    Each turn proceeds as follows:
    1. Build the prompt from the latest observation, the state and recent history.
    2. Obtain one action via ``_request_model_action``.
    3. Step the environment and print a trace.

    A reply that cannot be turned into an action costs one step: it is
    replaced by ``_fallback_action`` instead of aborting the run. The loop
    ends early when a step reports ``done``.

    Returns:
        The last step result (the reset result if *max_steps* < 1) and the
        latest state.
    """
    history: list[str] = []
    tools = _build_tools()
    result = await client.reset(difficulty=difficulty)
    current_state = await client.state()
    for step_index in range(1, max_steps + 1):
        observation = result.observation
        user_prompt = _build_user_prompt(step_index, observation, current_state, history)
        action: ActionLike
        try:
            # The OpenAI client is synchronous. Running it in a worker thread
            # keeps a slow completion from blocking the event loop that the
            # async environment client runs on.
            action = await asyncio.to_thread(
                _request_model_action,
                llm_client=llm_client,
                api_base_url=api_base_url,
                model=model,
                user_prompt=user_prompt,
                tools=tools,
                observation=observation,
                temperature=temperature,
                max_tokens=max_tokens,
            )
        except (ValueError, KeyError, TypeError) as exc:
            # Malformed model output should cost one conservative step, not the
            # whole run. JSON errors are ValueErrors, and the action models
            # appear to be pydantic v2, whose ValidationError also subclasses
            # ValueError.
            print(
                f"  warning: step {step_index} model reply unusable "
                f"({type(exc).__name__}: {str(exc)[:200]}); using fallback action"
            )
            action = _fallback_action(observation)
        result = await client.step(action)
        current_state = await client.state()
        # Record the action itself so the model can see what it already tried.
        history.append(
            f"step={step_index} action={_format_action(action)} "
            f"reward={(result.reward or 0.0):+.2f} "
            f"unmatched={current_state.unmatched_valid_count} "
            f"open_balance={current_state.open_abs_balance:.2f}"
        )
        _print_step_trace(step_index, action, result, current_state)
        if result.done:
            break
    return EpisodeSummary(final_result=result, final_state=current_state)
def _print_summary(summary: EpisodeSummary) -> None:
print("Final Result")
print(f" reward={summary.final_result.reward}")
print(f" done={summary.final_result.done}")
print(f" feedback={summary.final_result.observation.feedback_message}")
state = summary.final_state
compact_state = {
"difficulty": state.difficulty,
"scenario_id": state.scenario_id,
"step_count": state.step_count,
"open_abs_balance": state.open_abs_balance,
"linked_pairs": len(state.linked_pairs),
"fx_adjustments": len(state.fx_adjustments),
"invalid_actions": state.invalid_actions,
"unmatched_valid_count": state.unmatched_valid_count,
"expected_terminal_amount": state.expected_terminal_amount,
"expected_terminal_account": state.expected_terminal_account,
"force_balance_attempted": state.force_balance_attempted,
"final_score": state.final_score,
"total_reward": state.total_reward,
}
print(json.dumps(compact_state, indent=2))
async def _main_async(args: argparse.Namespace) -> None:
    """Run one episode with the selected policy and print its summary.

    For the LLM policy, the API key and model name are checked before
    connecting to the environment. A missing setting is therefore reported
    as such, rather than as a failure of the environment API.

    Raises:
        SystemExit: on missing LLM configuration, or wrapping any exception
            raised while connecting to the environment or running the episode.
    """
    use_llm = args.policy != "heuristic"
    if use_llm:
        if not API_KEY:
            raise SystemExit("HF_TOKEN (or API_KEY) must be set for --policy openai.")
        if not args.model_name:
            raise SystemExit("MODEL_NAME (or --model-name) must be set for --policy openai.")
    try:
        async with EnterpriseFinanceClient(base_url=args.env_base_url) as client:
            if not use_llm:
                summary = await run_heuristic_episode(client, difficulty=args.difficulty)
            else:
                llm_client = OpenAI(base_url=args.api_base_url, api_key=API_KEY)
                summary = await run_openai_episode(
                    client,
                    llm_client=llm_client,
                    api_base_url=args.api_base_url,
                    difficulty=args.difficulty,
                    model=args.model_name,
                    max_steps=args.max_steps,
                    temperature=args.temperature,
                    max_tokens=args.max_tokens,
                )
    except Exception as exc:  # noqa: BLE001
        raise SystemExit(
            f"Failed to run inference against environment API at {args.env_base_url}: {exc}"
        ) from exc
    _print_summary(summary)
def build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser.

    Defaults come from the environment-derived module constants, and
    ``--help`` shows them alongside a description of each option.
    """
    parser = argparse.ArgumentParser(
        description="Run enterprise finance OpenEnv baselines.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--env-base-url",
        default=ENV_BASE_URL,
        help="Base URL of the enterprise finance environment server.",
    )
    parser.add_argument(
        "--api-base-url",
        default=API_BASE_URL,
        help="OpenAI-compatible LLM endpoint used by --policy openai.",
    )
    parser.add_argument(
        "--difficulty",
        choices=["easy", "medium", "hard"],
        default="easy",
        help="Scenario difficulty requested on reset.",
    )
    parser.add_argument(
        "--policy",
        choices=["openai", "heuristic"],
        default="openai",
        help="'openai' asks the LLM for every action; 'heuristic' runs the scripted baseline.",
    )
    parser.add_argument(
        "--model-name",
        default=MODEL_NAME,
        help="LLM model id (required for --policy openai).",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=MAX_STEPS,
        help="Maximum LLM-driven steps per episode (not used by the heuristic policy).",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=TEMPERATURE,
        help="Sampling temperature for LLM requests.",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=MAX_TOKENS,
        help="Maximum completion tokens per LLM request.",
    )
    return parser
def main() -> None:
    """CLI entry point: parse the command line and run the async driver to completion."""
    asyncio.run(_main_async(build_parser().parse_args()))


if __name__ == "__main__":
    main()