Prasham1710 committed on
Commit
d9ced2a
·
1 Parent(s): 14a2eb9

Add Groq-safe inference fallback

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .env
__pycache__/inference.cpython-312.pyc CHANGED
Binary files a/__pycache__/inference.cpython-312.pyc and b/__pycache__/inference.cpython-312.pyc differ
 
inference.py CHANGED
@@ -4,6 +4,7 @@ import argparse
4
  import asyncio
5
  import json
6
  import os
 
7
  from dataclasses import dataclass
8
  from typing import Any, Protocol
9
 
@@ -54,6 +55,18 @@ Use exactly one tool call on each turn. Prefer safe, incremental actions:
54
  Never invent transaction ids, FX rates, dates, or accounts.
55
  """
56
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  class EpisodeClient(Protocol):
59
  async def reset(self, **kwargs: Any) -> StepResult[EnterpriseFinanceObservation]:
@@ -191,6 +204,14 @@ def _build_user_prompt(
191
  return json.dumps(prompt_payload, indent=2)
192
 
193
 
 
 
 
 
 
 
 
 
194
  def _build_tools() -> list[dict[str, Any]]:
195
  return [
196
  {
@@ -297,6 +318,10 @@ def _tool_call_to_action(name: str, arguments: dict[str, Any]) -> ActionLike:
297
  raise ValueError(f"Unsupported tool call: {name}")
298
 
299
 
 
 
 
 
300
  def _fallback_action(observation: EnterpriseFinanceObservation) -> ActionLike:
301
  if observation.structured_ledgers:
302
  start_date, end_date = _date_bounds(observation.structured_ledgers)
@@ -323,6 +348,32 @@ def _format_action(action: ActionLike) -> str:
323
  return json.dumps(payload, separators=(",", ":"))
324
 
325
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
  def _print_step_trace(
327
  step_index: int,
328
  action: ActionLike,
@@ -399,6 +450,7 @@ async def run_openai_episode(
399
  client: EpisodeClient,
400
  *,
401
  llm_client: OpenAI,
 
402
  difficulty: str,
403
  model: str,
404
  max_steps: int,
@@ -411,33 +463,52 @@ async def run_openai_episode(
411
  current_state = await client.state()
412
 
413
  for step_index in range(1, max_steps + 1):
414
- completion = llm_client.chat.completions.create(
415
- model=model,
416
- temperature=temperature,
417
- max_tokens=max_tokens,
418
- tool_choice="required",
419
- tools=tools,
420
- messages=[
421
- {"role": "system", "content": SYSTEM_PROMPT},
422
- {
423
- "role": "user",
424
- "content": _build_user_prompt(
425
- step_index,
426
- result.observation,
427
- current_state,
428
- history,
429
- ),
430
- },
431
- ],
432
  )
433
-
434
- message = completion.choices[0].message
435
- tool_call = message.tool_calls[0] if getattr(message, "tool_calls", None) else None
436
- if tool_call is None:
437
- action = _fallback_action(result.observation)
438
- else:
439
- arguments = json.loads(tool_call.function.arguments or "{}")
440
- action = _tool_call_to_action(tool_call.function.name, arguments)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441
 
442
  result = await client.step(action)
443
  current_state = await client.state()
@@ -494,6 +565,7 @@ async def _main_async(args: argparse.Namespace) -> None:
494
  summary = await run_openai_episode(
495
  client,
496
  llm_client=llm_client,
 
497
  difficulty=args.difficulty,
498
  model=args.model_name,
499
  max_steps=args.max_steps,
 
4
  import asyncio
5
  import json
6
  import os
7
+ import re
8
  from dataclasses import dataclass
9
  from typing import Any, Protocol
10
 
 
55
  Never invent transaction ids, FX rates, dates, or accounts.
56
  """
57
 
58
+ JSON_FALLBACK_SYSTEM_PROMPT = """You are the Consolidation Controller for a GAAP-compliant enterprise finance simulation.
59
+
60
+ Reply with exactly one JSON object and nothing else.
61
+ The JSON object must match one of these shapes:
62
+ {"type":"query_subledger","entity":"PARENT_US","account_code":"IC_AR","date_range":["2026-01-01","2026-01-31"]}
63
+ {"type":"link_transactions","debit_txn_id":"TXN1","credit_txn_id":"TXN2","rationale":"Explain the match."}
64
+ {"type":"apply_forex_adjustment","txn_id":"TXN1","exchange_rate":1.3025,"date":"2026-02-05"}
65
+ {"type":"post_elimination_entry","entity_id":"GROUP","amount":12.34,"account":"IC_FX_ELIM_CLEARING"}
66
+
67
+ Choose exactly one action for this turn. Do not emit multiple actions. Do not use markdown fences.
68
+ """
69
+
70
 
71
  class EpisodeClient(Protocol):
72
  async def reset(self, **kwargs: Any) -> StepResult[EnterpriseFinanceObservation]:
 
204
  return json.dumps(prompt_payload, indent=2)
205
 
206
 
207
+ def _extract_json_block(content: str) -> str:
208
+ stripped = content.strip()
209
+ if stripped.startswith("```"):
210
+ stripped = re.sub(r"^```(?:json)?", "", stripped).strip()
211
+ stripped = re.sub(r"```$", "", stripped).strip()
212
+ return stripped
213
+
214
+
215
  def _build_tools() -> list[dict[str, Any]]:
216
  return [
217
  {
 
318
  raise ValueError(f"Unsupported tool call: {name}")
319
 
320
 
321
+ def _json_dict_to_action(payload: dict[str, Any]) -> ActionLike:
322
+ return EnterpriseFinanceActionPayload.model_validate(payload).root
323
+
324
+
325
  def _fallback_action(observation: EnterpriseFinanceObservation) -> ActionLike:
326
  if observation.structured_ledgers:
327
  start_date, end_date = _date_bounds(observation.structured_ledgers)
 
348
  return json.dumps(payload, separators=(",", ":"))
349
 
350
 
351
+ def _provider_prefers_json_fallback(api_base_url: str) -> bool:
352
+ return "groq.com" in api_base_url.lower()
353
+
354
+
355
+ def _fallback_json_completion(
356
+ *,
357
+ llm_client: OpenAI,
358
+ model: str,
359
+ user_prompt: str,
360
+ temperature: float,
361
+ max_tokens: int,
362
+ ) -> ActionLike:
363
+ completion = llm_client.chat.completions.create(
364
+ model=model,
365
+ temperature=temperature,
366
+ max_tokens=max_tokens,
367
+ messages=[
368
+ {"role": "system", "content": JSON_FALLBACK_SYSTEM_PROMPT},
369
+ {"role": "user", "content": user_prompt},
370
+ ],
371
+ )
372
+ content = completion.choices[0].message.content or ""
373
+ payload = json.loads(_extract_json_block(content))
374
+ return _json_dict_to_action(payload)
375
+
376
+
377
  def _print_step_trace(
378
  step_index: int,
379
  action: ActionLike,
 
450
  client: EpisodeClient,
451
  *,
452
  llm_client: OpenAI,
453
+ api_base_url: str,
454
  difficulty: str,
455
  model: str,
456
  max_steps: int,
 
463
  current_state = await client.state()
464
 
465
  for step_index in range(1, max_steps + 1):
466
+ user_prompt = _build_user_prompt(
467
+ step_index,
468
+ result.observation,
469
+ current_state,
470
+ history,
 
 
 
 
 
 
 
 
 
 
 
 
 
471
  )
472
+ action: ActionLike
473
+ try:
474
+ if _provider_prefers_json_fallback(api_base_url):
475
+ action = _fallback_json_completion(
476
+ llm_client=llm_client,
477
+ model=model,
478
+ user_prompt=user_prompt,
479
+ temperature=temperature,
480
+ max_tokens=max_tokens,
481
+ )
482
+ else:
483
+ completion = llm_client.chat.completions.create(
484
+ model=model,
485
+ temperature=temperature,
486
+ max_tokens=max_tokens,
487
+ tool_choice="required",
488
+ parallel_tool_calls=False,
489
+ tools=tools,
490
+ messages=[
491
+ {"role": "system", "content": SYSTEM_PROMPT},
492
+ {"role": "user", "content": user_prompt},
493
+ ],
494
+ )
495
+ message = completion.choices[0].message
496
+ tool_call = message.tool_calls[0] if getattr(message, "tool_calls", None) else None
497
+ if tool_call is None:
498
+ action = _fallback_action(result.observation)
499
+ else:
500
+ arguments = json.loads(tool_call.function.arguments or "{}")
501
+ action = _tool_call_to_action(tool_call.function.name, arguments)
502
+ except Exception as exc: # noqa: BLE001
503
+ if "tool_use_failed" not in str(exc) and "Failed to call a function" not in str(exc):
504
+ raise
505
+ action = _fallback_json_completion(
506
+ llm_client=llm_client,
507
+ model=model,
508
+ user_prompt=user_prompt,
509
+ temperature=temperature,
510
+ max_tokens=max_tokens,
511
+ )
512
 
513
  result = await client.step(action)
514
  current_state = await client.state()
 
565
  summary = await run_openai_episode(
566
  client,
567
  llm_client=llm_client,
568
+ api_base_url=args.api_base_url,
569
  difficulty=args.difficulty,
570
  model=args.model_name,
571
  max_steps=args.max_steps,
tests/__pycache__/test_end_to_end.cpython-312-pytest-9.0.2.pyc CHANGED
Binary files a/tests/__pycache__/test_end_to_end.cpython-312-pytest-9.0.2.pyc and b/tests/__pycache__/test_end_to_end.cpython-312-pytest-9.0.2.pyc differ
 
tests/test_end_to_end.py CHANGED
@@ -132,6 +132,7 @@ async def test_openai_policy_path_can_solve_easy_with_fake_client() -> None:
132
  summary = await run_openai_episode(
133
  LocalAsyncAdapter("easy"),
134
  llm_client=FakeOpenAIClient(),
 
135
  difficulty="easy",
136
  model="fake-model",
137
  max_steps=200,
 
132
  summary = await run_openai_episode(
133
  LocalAsyncAdapter("easy"),
134
  llm_client=FakeOpenAIClient(),
135
+ api_base_url="https://router.huggingface.co/v1",
136
  difficulty="easy",
137
  model="fake-model",
138
  max_steps=200,
uv.lock CHANGED
@@ -620,6 +620,7 @@ dependencies = [
620
  { name = "openai" },
621
  { name = "openenv-core" },
622
  { name = "pydantic" },
 
623
  { name = "uvicorn" },
624
  ]
625
 
@@ -638,6 +639,7 @@ requires-dist = [
638
  { name = "pydantic", specifier = ">=2.8.0" },
639
  { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },
640
  { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23.0" },
 
641
  { name = "uvicorn", specifier = ">=0.30.0" },
642
  ]
643
  provides-extras = ["dev"]
 
620
  { name = "openai" },
621
  { name = "openenv-core" },
622
  { name = "pydantic" },
623
+ { name = "python-dotenv" },
624
  { name = "uvicorn" },
625
  ]
626
 
 
639
  { name = "pydantic", specifier = ">=2.8.0" },
640
  { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },
641
  { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23.0" },
642
+ { name = "python-dotenv", specifier = ">=1.0.0" },
643
  { name = "uvicorn", specifier = ">=0.30.0" },
644
  ]
645
  provides-extras = ["dev"]