ritishshrirao commited on
Commit
ea695ac
·
1 Parent(s): 02fe199

Update inference

Browse files
Files changed (1) hide show
  1. inference.py +4 -3
inference.py CHANGED
@@ -94,10 +94,10 @@ def log_step(step: int, action: str, reward: float, done: bool, error: str | Non
94
  )
95
 
96
 
97
- def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
98
  rewards_text = ",".join(f"{value:.2f}" for value in rewards)
99
  print(
100
- f"[END] success={str(bool(success)).lower()} steps={steps} score={score:.2f} rewards={rewards_text}",
101
  flush=True,
102
  )
103
 
@@ -441,8 +441,9 @@ def main() -> None:
441
  )
442
 
443
  score = float(summary.get("avg_reward", 0.0) or 0.0)
 
444
  success = score >= SUCCESS_SCORE_THRESHOLD
445
- log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
446
 
447
  record, dashboard = _maybe_write_artifacts(
448
  env=env,
 
94
  )
95
 
96
 
97
+ def log_end(task: str, success: bool, steps: int, score: float, rewards: list[float]) -> None:
98
  rewards_text = ",".join(f"{value:.2f}" for value in rewards)
99
  print(
100
+ f"[END] task={task} success={str(bool(success)).lower()} steps={steps} score={score:.2f} rewards={rewards_text}",
101
  flush=True,
102
  )
103
 
 
441
  )
442
 
443
  score = float(summary.get("avg_reward", 0.0) or 0.0)
444
+ score = max(0.01, min(0.99, score))
445
  success = score >= SUCCESS_SCORE_THRESHOLD
446
+ log_end(task=TASK_NAME, success=success, steps=steps_taken, score=score, rewards=rewards)
447
 
448
  record, dashboard = _maybe_write_artifacts(
449
  env=env,