Spaces:
Sleeping
Sleeping
File size: 5,768 Bytes
"""Data models for the Flatmate RL environment."""
from __future__ import annotations
from typing import Any, Literal
from openenv.core.env_server.types import Action, Observation, State
from pydantic import Field, model_validator
# Values accepted by FlatmateRlAction.action_type.
ActionType = Literal["assistant_message", "tool_call"]
class FlatmateRlAction(Action):
    """One assistant turn or one backend tool call."""

    # Selects which of the fields below must be populated; see validate_shape.
    action_type: ActionType = Field(
        ...,
        description="assistant_message or tool_call.",
    )

    # Payload for action_type == "assistant_message".
    assistant_message: str = Field(
        default="",
        description="Broker message shown to the active user.",
    )

    # Payload for action_type == "tool_call".
    tool_name: str = Field(
        default="",
        description="Tool name when action_type='tool_call'.",
    )
    tool_arguments: dict[str, Any] = Field(
        default_factory=dict,
        description="Tool arguments.",
    )

    @model_validator(mode="after")
    def validate_shape(self) -> "FlatmateRlAction":
        """Check that the populated fields agree with ``action_type``.

        Returns:
            The same instance, unchanged, when the shape is valid.

        Raises:
            ValueError: If an ``assistant_message`` action has a blank message
                or carries a tool name / tool arguments, or if a ``tool_call``
                action has a blank tool name.
        """
        kind = self.action_type

        if kind == "assistant_message":
            message_is_blank = not self.assistant_message.strip()
            if message_is_blank:
                raise ValueError("assistant_message is required when action_type='assistant_message'.")

            # A message turn must not smuggle in a tool invocation.
            has_tool_payload = any((self.tool_name, self.tool_arguments))
            if has_tool_payload:
                raise ValueError("tool_name and tool_arguments must be empty for assistant_message actions.")
        elif kind == "tool_call":
            tool_name_is_blank = not self.tool_name.strip()
            if tool_name_is_blank:
                raise ValueError("tool_name is required when action_type='tool_call'.")

        return self
class FlatmateRlObservation(Observation):
    """Visible state returned after reset and each step."""

    # --- Episode status and scenario identity -------------------------------
    status: str = Field(
        default="ready",
        description="ready, user_response, tool_result, completed, or failed.",
    )
    scenario_id: str = Field(
        default="",
        description="Active scenario id.",
    )
    scenario_label: str = Field(
        default="",
        description="Human-readable scenario label.",
    )
    difficulty: str = Field(
        default="",
        description="Scenario difficulty.",
    )
    phase: str = Field(
        default="buyer",
        description="buyer or seller.",
    )

    # --- Messages and transcripts -------------------------------------------
    current_user_request: str = Field(
        default="",
        description="Latest visible user-side request or reply.",
    )
    last_user_message: str = Field(
        default="",
        description="Most recent user/seller message.",
    )
    conversation_history: list[dict[str, Any]] = Field(
        default_factory=list,
        description="Visible transcript.",
    )
    buyer_conversation_history: list[dict[str, Any]] = Field(
        default_factory=list,
        description="Buyer/broker transcript.",
    )
    seller_conversation_history: list[dict[str, Any]] = Field(
        default_factory=list,
        description="Seller/broker transcript.",
    )

    # --- Tool activity ------------------------------------------------------
    last_tool_result: dict[str, Any] = Field(
        default_factory=dict,
        description="Most recent tool result.",
    )
    tool_results: list[dict[str, Any]] = Field(
        default_factory=list,
        description="All tool results so far.",
    )
    tool_trace: list[dict[str, Any]] = Field(
        default_factory=list,
        description="Tool call trace with args and outcomes.",
    )
    available_tools: list[str] = Field(
        default_factory=list,
        description="Tools available in the current phase.",
    )
    prerequisites_satisfied: dict[str, bool] = Field(
        default_factory=dict,
        description="Compact workflow state: details_stored, posts_searched, location_matched, slots_checked, buyer_confirmed, poster_confirmed.",
    )
    recent_tool_calls: list[dict[str, Any]] = Field(
        default_factory=list,
        description="Last five tool calls as compact tool_name, tool_arguments_summary, success entries.",
    )

    # --- Workflow progress --------------------------------------------------
    gathered_fields: list[str] = Field(
        default_factory=list,
        description="Fields gathered so far for the active phase.",
    )
    remaining_required_fields: list[str] = Field(
        default_factory=list,
        description="Required fields still missing.",
    )
    selected_posts: list[str] = Field(
        default_factory=list,
        description="Posts selected by the buyer.",
    )
    booked_visits: list[dict[str, Any]] = Field(
        default_factory=list,
        description="Booked visits.",
    )
    profile_stored: bool = Field(
        default=False,
        description="Whether the current phase profile is stored.",
    )
    buyer_profile_stored: bool = Field(
        default=False,
        description="Whether buyer details are stored.",
    )
    seller_profile_stored: bool = Field(
        default=False,
        description="Whether seller details are stored.",
    )
    violations: list[str] = Field(
        default_factory=list,
        description="Ordering or policy violations.",
    )

    # --- Reward and feedback ------------------------------------------------
    step_reward: float = Field(
        default=0.0,
        description="Reward from the most recent step.",
    )
    total_reward: float = Field(
        default=0.0,
        description="Cumulative reward for the episode.",
    )
    message: str = Field(
        default="",
        description="Human-readable step summary.",
    )
    feedback_summary: str = Field(
        default="",
        description="Broker-facing guidance derived from the latest state or tool result.",
    )
class FlatmateRlState(State):
    """Server-side state snapshot for the active episode."""

    # --- Episode identity and lifecycle -------------------------------------
    scenario_id: str = Field(
        default="",
        description="Active scenario id.",
    )
    phase: str = Field(
        default="buyer",
        description="Current episode phase.",
    )
    status: str = Field(
        default="ready",
        description="Current episode status.",
    )

    # --- Workflow progress --------------------------------------------------
    gathered_fields: list[str] = Field(
        default_factory=list,
        description="Gathered fields in the active phase.",
    )
    selected_posts: list[str] = Field(
        default_factory=list,
        description="Selected post ids.",
    )
    booked_visits: list[dict[str, Any]] = Field(
        default_factory=list,
        description="Booked visits.",
    )
    buyer_profile_stored: bool = Field(
        default=False,
        description="Whether buyer details are stored.",
    )
    seller_profile_stored: bool = Field(
        default=False,
        description="Whether seller details are stored.",
    )

    # --- Tool trace, reward, termination ------------------------------------
    tool_trace: list[dict[str, Any]] = Field(
        default_factory=list,
        description="Structured tool trace.",
    )
    total_reward: float = Field(
        default=0.0,
        description="Cumulative reward for the episode.",
    )
    done: bool = Field(
        default=False,
        description="Whether the episode is complete.",
    )
|