SouravNath commited on
Commit
e76b46b
·
1 Parent(s): 53afd2e

fix: normalise LangGraph dict output back to AgentState in _run_with_langgraph

Browse files
Files changed (1) hide show
  1. agent/reflection_agent.py +30 -6
agent/reflection_agent.py CHANGED
@@ -380,6 +380,15 @@ class ReflectionAgent:
380
 
381
  app = graph.compile()
382
  final = app.invoke(state)
 
 
 
 
 
 
 
 
 
383
  return final
384
 
385
  except Exception as e:
@@ -389,10 +398,25 @@ class ReflectionAgent:
389
  def _log_trajectories(self, state: AgentState) -> None:
390
  """Write all attempt records to the trajectory logger."""
391
  from agent.trajectory_logger import TrajectoryEntry
392
- for attempt_data in state.attempts:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  entry = TrajectoryEntry(
394
- instance_id=state.instance_id,
395
- repo=state.repo,
396
  attempt=attempt_data["attempt_num"],
397
  patch=attempt_data["patch"],
398
  test_stdout=attempt_data["test_stdout"],
@@ -400,9 +424,9 @@ class ReflectionAgent:
400
  pass_to_pass_results=attempt_data["pass_to_pass_results"],
401
  resolved=attempt_data["resolved"],
402
  failure_category=attempt_data["failure_category"],
403
- elapsed_seconds=0.0, # per-attempt timing tracked separately
404
- localised_files=state.localised_files,
405
- problem_statement=state.problem_statement,
406
  token_cost={},
407
  )
408
  self.traj_logger.log(entry)
 
380
 
381
  app = graph.compile()
382
  final = app.invoke(state)
383
+
384
+ # LangGraph may return a plain dict instead of AgentState.
385
+ # Normalise back to the dataclass so downstream code is consistent.
386
+ if isinstance(final, dict):
387
+ final = AgentState(**{
388
+ k: final[k] for k in AgentState.__dataclass_fields__
389
+ if k in final
390
+ })
391
+
392
  return final
393
 
394
  except Exception as e:
 
398
  def _log_trajectories(self, state: AgentState) -> None:
399
  """Write all attempt records to the trajectory logger."""
400
  from agent.trajectory_logger import TrajectoryEntry
401
+
402
+ # Handle both AgentState dataclass and plain dict (LangGraph compat)
403
+ if isinstance(state, dict):
404
+ attempts = state.get("attempts", [])
405
+ instance_id = state.get("instance_id", "")
406
+ repo = state.get("repo", "")
407
+ localised = state.get("localised_files", [])
408
+ problem = state.get("problem_statement", "")
409
+ else:
410
+ attempts = state.attempts
411
+ instance_id = state.instance_id
412
+ repo = state.repo
413
+ localised = state.localised_files
414
+ problem = state.problem_statement
415
+
416
+ for attempt_data in attempts:
417
  entry = TrajectoryEntry(
418
+ instance_id=instance_id,
419
+ repo=repo,
420
  attempt=attempt_data["attempt_num"],
421
  patch=attempt_data["patch"],
422
  test_stdout=attempt_data["test_stdout"],
 
424
  pass_to_pass_results=attempt_data["pass_to_pass_results"],
425
  resolved=attempt_data["resolved"],
426
  failure_category=attempt_data["failure_category"],
427
+ elapsed_seconds=0.0,
428
+ localised_files=localised,
429
+ problem_statement=problem,
430
  token_cost={},
431
  )
432
  self.traj_logger.log(entry)