Spaces:
Running
Running
Commit ·
e76b46b
1
Parent(s): 53afd2e
fix: normalise LangGraph dict output back to AgentState in _run_with_langgraph
Browse files- agent/reflection_agent.py +30 -6
agent/reflection_agent.py
CHANGED
|
@@ -380,6 +380,15 @@ class ReflectionAgent:
|
|
| 380 |
|
| 381 |
app = graph.compile()
|
| 382 |
final = app.invoke(state)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
return final
|
| 384 |
|
| 385 |
except Exception as e:
|
|
@@ -389,10 +398,25 @@ class ReflectionAgent:
|
|
| 389 |
def _log_trajectories(self, state: AgentState) -> None:
|
| 390 |
"""Write all attempt records to the trajectory logger."""
|
| 391 |
from agent.trajectory_logger import TrajectoryEntry
|
| 392 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
entry = TrajectoryEntry(
|
| 394 |
-
instance_id=
|
| 395 |
-
repo=
|
| 396 |
attempt=attempt_data["attempt_num"],
|
| 397 |
patch=attempt_data["patch"],
|
| 398 |
test_stdout=attempt_data["test_stdout"],
|
|
@@ -400,9 +424,9 @@ class ReflectionAgent:
|
|
| 400 |
pass_to_pass_results=attempt_data["pass_to_pass_results"],
|
| 401 |
resolved=attempt_data["resolved"],
|
| 402 |
failure_category=attempt_data["failure_category"],
|
| 403 |
-
elapsed_seconds=0.0,
|
| 404 |
-
localised_files=
|
| 405 |
-
problem_statement=
|
| 406 |
token_cost={},
|
| 407 |
)
|
| 408 |
self.traj_logger.log(entry)
|
|
|
|
| 380 |
|
| 381 |
app = graph.compile()
|
| 382 |
final = app.invoke(state)
|
| 383 |
+
|
| 384 |
+
# LangGraph may return a plain dict instead of AgentState.
|
| 385 |
+
# Normalise back to the dataclass so downstream code is consistent.
|
| 386 |
+
if isinstance(final, dict):
|
| 387 |
+
final = AgentState(**{
|
| 388 |
+
k: final[k] for k in AgentState.__dataclass_fields__
|
| 389 |
+
if k in final
|
| 390 |
+
})
|
| 391 |
+
|
| 392 |
return final
|
| 393 |
|
| 394 |
except Exception as e:
|
|
|
|
| 398 |
def _log_trajectories(self, state: AgentState) -> None:
|
| 399 |
"""Write all attempt records to the trajectory logger."""
|
| 400 |
from agent.trajectory_logger import TrajectoryEntry
|
| 401 |
+
|
| 402 |
+
# Handle both AgentState dataclass and plain dict (LangGraph compat)
|
| 403 |
+
if isinstance(state, dict):
|
| 404 |
+
attempts = state.get("attempts", [])
|
| 405 |
+
instance_id = state.get("instance_id", "")
|
| 406 |
+
repo = state.get("repo", "")
|
| 407 |
+
localised = state.get("localised_files", [])
|
| 408 |
+
problem = state.get("problem_statement", "")
|
| 409 |
+
else:
|
| 410 |
+
attempts = state.attempts
|
| 411 |
+
instance_id = state.instance_id
|
| 412 |
+
repo = state.repo
|
| 413 |
+
localised = state.localised_files
|
| 414 |
+
problem = state.problem_statement
|
| 415 |
+
|
| 416 |
+
for attempt_data in attempts:
|
| 417 |
entry = TrajectoryEntry(
|
| 418 |
+
instance_id=instance_id,
|
| 419 |
+
repo=repo,
|
| 420 |
attempt=attempt_data["attempt_num"],
|
| 421 |
patch=attempt_data["patch"],
|
| 422 |
test_stdout=attempt_data["test_stdout"],
|
|
|
|
| 424 |
pass_to_pass_results=attempt_data["pass_to_pass_results"],
|
| 425 |
resolved=attempt_data["resolved"],
|
| 426 |
failure_category=attempt_data["failure_category"],
|
| 427 |
+
elapsed_seconds=0.0,
|
| 428 |
+
localised_files=localised,
|
| 429 |
+
problem_statement=problem,
|
| 430 |
token_cost={},
|
| 431 |
)
|
| 432 |
self.traj_logger.log(entry)
|