File size: 2,043 Bytes
db4fa53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from __future__ import annotations

from typing import Any

from pydantic import BaseModel, Field


class OpenEnvTaskSummary(BaseModel):
    """Compact summary of a single task.

    ``task_id``, ``task_type`` and ``question`` must be supplied. ``difficulty``
    falls back to the string ``"unknown"`` and ``grader`` to an empty dict that
    is created fresh for every instance.
    """

    # Required identifying / descriptive fields.
    task_id: str
    task_type: str
    question: str

    # Optional metadata; the sentinel string "unknown" is used when omitted.
    difficulty: str = Field(default="unknown")
    # Grader configuration; its structure is opaque to this model (TODO confirm schema).
    grader: dict[str, Any] = Field(default_factory=dict)


class OpenEnvObservationModel(BaseModel):
    """Observation carried in ``OpenEnvResponseEnvelope.observation``.

    Every field is required. The inner dict schemas are not defined here; the
    values pass through as free-form ``dict[str, Any]`` entries.
    """

    # Tool output records; per-entry keys are not constrained by this model.
    tool_outputs: list[dict[str, Any]]
    # Snapshot of the graph state (schema defined elsewhere, TODO confirm).
    graph_snapshot: dict[str, Any]
    # Previously taken actions; presumably chronological. Verify with the producer.
    action_history: list[dict[str, Any]]
    # Task description dict; may mirror OpenEnvTaskSummary fields (unverified).
    task: dict[str, Any]


class OpenEnvResetRequest(BaseModel):
    """Body of a reset request.

    Both task selectors are optional and default to ``None``. This model does
    not enforce that at most one of them is given. NOTE(review): confirm how
    the handler resolves the case where both, or neither, are supplied.
    """

    # Select a task by identifier (compare OpenEnvTaskSummary.task_id).
    task_id: str | None = None
    # Select a task by position; presumably an index into a task list. Verify.
    task_index: int | None = None


class OpenEnvActionRequest(BaseModel):
    """Action submitted to a session.

    Two request shapes are accepted:

    * flat: ``action_type`` and ``payload`` sit at the top level;
    * wrapped: both keys live inside the ``action`` dict.

    The ``resolved_*`` helpers merge the two shapes, preferring a top-level
    value whenever it is truthy.
    """

    session_id: str | None = Field(
        default=None,
        description="Session identifier. Optional for /step compatibility alias, which uses the latest session.",
    )
    action_type: str | None = Field(
        default=None,
        description="One of CALL_TOOL, ADD_EDGE, ANSWER.",
    )
    payload: dict[str, Any] = Field(default_factory=dict)
    # Wrapped form: may carry "action_type" and/or "payload" keys.
    action: dict[str, Any] | None = None

    def resolved_action_type(self) -> str:
        """Return the effective action type as a string.

        A truthy top-level ``action_type`` wins. Otherwise a truthy
        ``action["action_type"]`` is used. Returns ``""`` when neither source
        supplies a value. No check against the documented action names is
        performed here.
        """
        direct = self.action_type
        if direct:
            return str(direct)
        wrapped = self.action if isinstance(self.action, dict) else {}
        fallback = wrapped.get("action_type")
        return str(fallback) if fallback else ""

    def resolved_payload(self) -> dict[str, Any]:
        """Return a shallow copy of the effective payload.

        A non-empty top-level ``payload`` wins. Otherwise ``action["payload"]``
        is used when it is a dict. In every other case a new empty dict is
        returned.
        """
        if self.payload:
            return dict(self.payload)
        wrapped = self.action if isinstance(self.action, dict) else {}
        candidate = wrapped.get("payload")
        return dict(candidate) if isinstance(candidate, dict) else {}


class OpenEnvResponseEnvelope(BaseModel):
    """Response wrapper pairing an observation with reward and done data.

    All fields are required.
    """

    # Session this response belongs to.
    session_id: str
    # Observation, validated against OpenEnvObservationModel.
    observation: OpenEnvObservationModel
    # Scalar reward; its scale and range are not constrained by this model.
    reward: float
    # Presumably True once the episode has ended. Confirm with the environment.
    done: bool
    # Free-form auxiliary data.
    info: dict[str, Any]


class OpenEnvInferenceReportRequest(BaseModel):
    """Body of an inference-report submission.

    Only ``summary`` is required. ``run`` defaults to an empty dict and
    ``episodes`` to an empty list, each created fresh per instance via
    ``default_factory``.
    """

    # Run-level metadata (schema not defined here).
    run: dict[str, Any] = Field(default_factory=dict)
    # Required aggregate summary. A required field after a defaulted one is
    # legal here because pydantic model fields are passed by keyword.
    summary: dict[str, Any]
    # Per-episode records; may be omitted.
    episodes: list[dict[str, Any]] = Field(default_factory=list)


class OpenEnvInferenceReportResponse(BaseModel):
    """Response returned for an inference-report submission.

    All three fields are required strings.
    """

    # Status indicator; allowed values are not enumerated by this model.
    status: str
    # Presumably where the report was written. Confirm with the handler.
    output_path: str
    # Presumably the location of a generated dashboard. Confirm with the handler.
    dashboard_path: str