File size: 5,768 Bytes
53dbcc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""Data models for the Flatmate RL environment."""

from __future__ import annotations

from typing import Any, Literal

from openenv.core.env_server.types import Action, Observation, State
from pydantic import Field, model_validator


# The two kinds of step a broker agent may take: talk to the active user
# ("assistant_message") or invoke a backend tool ("tool_call").
ActionType = Literal[
    "assistant_message",
    "tool_call",
]


class FlatmateRlAction(Action):
    """A single broker step: either a message to the active user or a backend tool call.

    ``action_type`` selects which of the remaining fields are meaningful; the
    ``validate_shape`` validator rejects payloads whose populated fields do not
    match the selected kind.
    """

    action_type: ActionType = Field(..., description="assistant_message or tool_call.")
    assistant_message: str = Field(default="", description="Broker message shown to the active user.")
    tool_name: str = Field(default="", description="Tool name when action_type='tool_call'.")
    tool_arguments: dict[str, Any] = Field(default_factory=dict, description="Tool arguments.")

    @model_validator(mode="after")
    def validate_shape(self) -> "FlatmateRlAction":
        """Check that the populated fields are consistent with ``action_type``.

        - ``tool_call``: ``tool_name`` must contain non-whitespace text.
        - ``assistant_message``: ``assistant_message`` must contain
          non-whitespace text, and neither ``tool_name`` nor
          ``tool_arguments`` may be set.

        Returns:
            The validated model instance, unchanged.

        Raises:
            ValueError: If the fields do not match the selected action kind.
        """
        kind = self.action_type

        if kind == "tool_call":
            # Only the tool name is mandatory; arguments may legitimately be empty.
            if not self.tool_name.strip():
                raise ValueError("tool_name is required when action_type='tool_call'.")
            return self

        if kind == "assistant_message":
            message_is_blank = not self.assistant_message.strip()
            if message_is_blank:
                raise ValueError("assistant_message is required when action_type='assistant_message'.")
            # A plain chat turn must not smuggle in a tool invocation.
            carries_tool_payload = bool(self.tool_name) or bool(self.tool_arguments)
            if carries_tool_payload:
                raise ValueError("tool_name and tool_arguments must be empty for assistant_message actions.")

        return self


class FlatmateRlObservation(Observation):
    """Visible state returned after reset and each step.

    Everything here is what the broker agent is allowed to see: scenario
    identity, conversation transcripts, tool activity, workflow progress, and
    reward/diagnostic signals. All fields have defaults, so a bare instance
    is valid.

    NOTE(review): fields such as an episode-termination flag are not declared
    here; presumably they are inherited from openenv's ``Observation`` base —
    confirm against that class.
    """

    # --- Episode / scenario identity -------------------------------------
    # ``status`` and ``phase`` are plain ``str``; the allowed values are listed
    # only in their descriptions and are not enforced by the type.
    status: str = Field(default="ready", description="ready, user_response, tool_result, completed, or failed.")
    scenario_id: str = Field(default="", description="Active scenario id.")
    scenario_label: str = Field(default="", description="Human-readable scenario label.")
    difficulty: str = Field(default="", description="Scenario difficulty.")
    phase: str = Field(default="buyer", description="buyer or seller.")

    # --- Conversation ----------------------------------------------------
    # NOTE(review): current_user_request and last_user_message look like they
    # can overlap; confirm how the server populates each.
    current_user_request: str = Field(default="", description="Latest visible user-side request or reply.")
    last_user_message: str = Field(default="", description="Most recent user/seller message.")
    # Transcript entries are untyped dicts; their key schema is defined by the
    # server, not by this model.
    conversation_history: list[dict[str, Any]] = Field(default_factory=list, description="Visible transcript.")
    buyer_conversation_history: list[dict[str, Any]] = Field(default_factory=list, description="Buyer/broker transcript.")
    seller_conversation_history: list[dict[str, Any]] = Field(default_factory=list, description="Seller/broker transcript.")

    # --- Tool activity ---------------------------------------------------
    last_tool_result: dict[str, Any] = Field(default_factory=dict, description="Most recent tool result.")
    tool_results: list[dict[str, Any]] = Field(default_factory=list, description="All tool results so far.")
    tool_trace: list[dict[str, Any]] = Field(default_factory=list, description="Tool call trace with args and outcomes.")
    available_tools: list[str] = Field(default_factory=list, description="Tools available in the current phase.")
    # The key names listed in the description are a convention only; the type
    # accepts any str -> bool mapping.
    prerequisites_satisfied: dict[str, bool] = Field(
        default_factory=dict,
        description="Compact workflow state: details_stored, posts_searched, location_matched, slots_checked, buyer_confirmed, poster_confirmed.",
    )
    # "Last five" is not enforced here (no length validator); the producer is
    # presumably responsible for trimming — confirm on the server side.
    recent_tool_calls: list[dict[str, Any]] = Field(
        default_factory=list,
        description="Last five tool calls as compact tool_name, tool_arguments_summary, success entries.",
    )

    # --- Workflow progress -----------------------------------------------
    gathered_fields: list[str] = Field(default_factory=list, description="Fields gathered so far for the active phase.")
    remaining_required_fields: list[str] = Field(default_factory=list, description="Required fields still missing.")
    selected_posts: list[str] = Field(default_factory=list, description="Posts selected by the buyer.")
    booked_visits: list[dict[str, Any]] = Field(default_factory=list, description="Booked visits.")
    # profile_stored refers to whichever phase is active; the buyer/seller
    # flags below track each side independently.
    profile_stored: bool = Field(default=False, description="Whether the current phase profile is stored.")
    buyer_profile_stored: bool = Field(default=False, description="Whether buyer details are stored.")
    seller_profile_stored: bool = Field(default=False, description="Whether seller details are stored.")

    # --- Reward and diagnostics ------------------------------------------
    violations: list[str] = Field(default_factory=list, description="Ordering or policy violations.")
    step_reward: float = Field(default=0.0, description="Reward from the most recent step.")
    total_reward: float = Field(default=0.0, description="Cumulative reward for the episode.")
    message: str = Field(default="", description="Human-readable step summary.")
    feedback_summary: str = Field(default="", description="Broker-facing guidance derived from the latest state or tool result.")


class FlatmateRlState(State):
    """Server-side state snapshot for the active episode.

    A compact subset of the episode bookkeeping (scenario, phase, progress,
    tool trace, reward). All fields have defaults, so a bare instance is valid.
    """

    # --- Episode identity and lifecycle ----------------------------------
    scenario_id: str = Field(default="", description="Active scenario id.")
    # ``phase`` and ``status`` are plain ``str``; allowed values are not enforced here.
    phase: str = Field(default="buyer", description="Current episode phase.")
    status: str = Field(default="ready", description="Current episode status.")

    # --- Workflow progress -----------------------------------------------
    gathered_fields: list[str] = Field(default_factory=list, description="Gathered fields in the active phase.")
    selected_posts: list[str] = Field(default_factory=list, description="Selected post ids.")
    booked_visits: list[dict[str, Any]] = Field(default_factory=list, description="Booked visits.")
    buyer_profile_stored: bool = Field(default=False, description="Whether buyer details are stored.")
    seller_profile_stored: bool = Field(default=False, description="Whether seller details are stored.")

    # --- Tool activity and outcome ---------------------------------------
    # Entries are untyped dicts; their schema is defined by whoever appends them.
    tool_trace: list[dict[str, Any]] = Field(default_factory=list, description="Structured tool trace.")
    total_reward: float = Field(default=0.0, description="Cumulative reward for the episode.")
    # NOTE(review): ``done`` is declared explicitly here, whereas the observation
    # model does not declare it; confirm whether openenv's ``State`` base
    # already defines it.
    done: bool = Field(default=False, description="Whether the episode is complete.")