"""Evaluation loop: runs agent episodes in an OSINTEnvironment and aggregates metrics."""

from __future__ import annotations

from osint_env.agents.single_agent import SingleAgentRunner
from osint_env.agents.swarm_agent import SwarmAgentRunner
from osint_env.env.environment import OSINTEnvironment
from osint_env.env.reward import compute_graph_f1
from osint_env.eval.metrics import EvalMetrics
from osint_env.llm.interface import LLMClient


def _edge_to_dict(edge) -> dict:
    """Serialize a knowledge-graph edge into a JSON-friendly dict."""
    return {
        "src": edge.src,
        "rel": edge.rel,
        "dst": edge.dst,
        "confidence": float(edge.confidence),
    }


def run_evaluation(
    env: OSINTEnvironment,
    episodes: int = 20,
    return_details: bool = False,
    llm: LLMClient | None = None,
) -> dict:
    """Run `episodes` episodes in `env` and return summary metrics.

    When `return_details` is True, the result also includes a per-episode
    breakdown under the "episodes" key.
    """
    metrics = EvalMetrics()
    # Choose the runner that matches the environment's configuration.
    if env.config.swarm.enabled:
        runner = SwarmAgentRunner(env=env, llm=llm)
    else:
        runner = SingleAgentRunner(env=env, llm=llm)
    episode_rows: list[dict] = []
    for _ in range(episodes):
        info = runner.run_episode()
        # Guard against episodes where the environment never produced state.
        task_type = env.state.task.task_type if env.state else "unknown"
        task_id = env.state.task.task_id if env.state else "unknown"
        truth = env.state.task.supporting_edges if env.state else []
        pred = env.memory_graph.edges if env.state else []
        # Graph F1 compares the agent's reconstructed edges against ground truth.
        graph_f1 = compute_graph_f1(pred, truth)
        metrics.add(info, task_type=task_type, graph_f1=graph_f1)
        agent_answer = info.get("agent_answer")
        episode_rows.append(
            {
                "task_id": task_id,
                "task_type": task_type,
                "question": env.state.task.question if env.state else "",
                "task_answer": str(info.get("task_answer", "")),
                "agent_answer": "" if agent_answer is None else str(agent_answer),
                "graph_f1": graph_f1,
                "reward": float(info.get("total_reward", 0.0)),
                "steps": int(info.get("step_count", 0)),
                "tool_calls": int(info.get("tool_calls", 0)),
                # Exact-match success on the raw (unstringified) answers.
                "success": int(agent_answer == info.get("task_answer")),
                "reward_components": dict(info.get("reward_components", {})),
                "spawn_count": int(info.get("spawn_count", 0)),
                "spawn_critical_steps": int(info.get("spawn_critical_steps", 0)),
                "pred_edges": [_edge_to_dict(edge) for edge in pred],
                "truth_edges": [_edge_to_dict(edge) for edge in truth],
            }
        )
    summary = metrics.summary()
    if return_details:
        return {"summary": summary, "episodes": episode_rows}
    return summary
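

# --- Usage sketch (illustrative only) ---
# A minimal example of driving `run_evaluation`, kept commented out because
# the environment construction shown here is an assumption: `load_config` and
# the OSINTEnvironment constructor signature are hypothetical; only
# `run_evaluation` above is defined in this module.
#
# if __name__ == "__main__":
#     env = OSINTEnvironment(config=load_config("configs/default.yaml"))  # hypothetical
#     results = run_evaluation(env, episodes=5, return_details=True)
#     print(results["summary"])
#     for row in results["episodes"]:
#         print(row["task_id"], row["graph_f1"], row["success"])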