"""Broker-app style debugging UI for the Flatmate RL environment.""" from __future__ import annotations import html import json import logging from typing import Any try: import gradio as gr except ImportError: # pragma: no cover gr = None from openenv.core.env_server.serialization import serialize_observation try: from ..env_config import load_repo_env from .scenarios import POSTS, SCENARIOS from .heuristic_policy import autopolicy_next_request except ImportError: from env_config import load_repo_env from server.scenarios import POSTS, SCENARIOS from server.heuristic_policy import autopolicy_next_request load_repo_env() BROKER_MODELS = ["heuristic_debug_policy"] USER_MODELS = ["openenv_builtin_user"] DEFAULT_BROKER_MODEL = BROKER_MODELS[0] DEFAULT_USER_MODEL = USER_MODELS[0] CHATBOT_USES_MESSAGES = True logger = logging.getLogger("flatmate_rl.web") if not logger.handlers: _handler = logging.StreamHandler() _handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s [%(name)s] %(message)s")) logger.addHandler(_handler) logger.setLevel(logging.INFO) logger.propagate = False CUSTOM_CSS = """ .panel-surface { border: 1px solid var(--border-color-primary); border-radius: 8px; background: var(--background-fill-secondary); color: var(--body-text-color); } .muted-text { color: var(--body-text-color-subdued); } .task-card { padding: 14px 16px; margin: 8px 0 14px; } .task-card__head { display: flex; align-items: center; gap: 10px; margin-bottom: 8px; } .task-card__title { font-size: 18px; font-weight: 800; color: var(--body-text-color); } .task-card__pill { font-size: 12px; font-weight: 700; color: var(--body-text-color); background: var(--block-background-fill); border: 1px solid var(--border-color-primary); border-radius: 999px; padding: 2px 8px; } .task-card__objective { font-size: 14px; margin-bottom: 10px; } .task-card__grid { display: grid; grid-template-columns: repeat(3, minmax(0, 1fr)); gap: 10px; font-size: 13px; } .task-card__section-title { font-weight: 700; color: var(--body-text-color); margin-bottom: 4px; } .status-card { padding: 12px; } .status-card pre { white-space: pre-wrap; margin: 8px 0 0; font-size: 13px; color: inherit; } .score-reasons, .score-awards { margin-top: 10px; font-size: 13px; } .score-reasons__title, .score-awards__title { font-weight: 800; margin-bottom: 4px; } .score-reasons ul, .score-awards ul { margin: 4px 0 0 18px; padding: 0; } .score-reasons li, .score-awards li { margin: 3px 0; } .status-card__label { font-size: 14px; font-weight: 700; margin-bottom: 6px; } .status-card__score { font-size: 28px; font-weight: 800; } .status-pass { background: color-mix(in srgb, var(--color-green-500, #22c55e) 16%, var(--background-fill-primary)); color: var(--body-text-color); border-color: color-mix(in srgb, var(--color-green-500, #22c55e) 45%, var(--border-color-primary)); } .status-warn { background: color-mix(in srgb, var(--color-yellow-500, #eab308) 18%, var(--background-fill-primary)); color: var(--body-text-color); border-color: color-mix(in srgb, var(--color-yellow-500, #eab308) 45%, var(--border-color-primary)); } .status-fail { background: color-mix(in srgb, var(--color-red-500, #ef4444) 16%, var(--background-fill-primary)); color: var(--body-text-color); border-color: color-mix(in srgb, var(--color-red-500, #ef4444) 45%, var(--border-color-primary)); } .model-status { padding: 8px 10px; } .final-banner { margin-top: 10px; padding: 14px; font-size: 22px; font-weight: 800; } .final-banner__reason { font-size: 13px; font-weight: 500; } @media (max-width: 900px) { .task-card__grid { grid-template-columns: 1fr; } } """ def _task_choices() -> list[tuple[str, str]]: return [(scenario["label"], scenario_id) for scenario_id, scenario in SCENARIOS.items()] def _active_scenario(task_id: str) -> dict[str, Any]: return SCENARIOS[task_id] def _serialize_reset(web_manager, observation): serialized = serialize_observation(observation) state = web_manager.env.state web_manager.episode_state.episode_id = state.episode_id web_manager.episode_state.step_count = state.step_count web_manager.episode_state.current_observation = serialized["observation"] web_manager.episode_state.action_logs = [] web_manager.episode_state.is_reset = True return serialized def _task_definition_html(task_id: str) -> str: scenario = _active_scenario(task_id) truth = scenario["ground_truth"] return ( '
' '
' f'
{html.escape(scenario["label"])}
' f'
{html.escape(scenario["difficulty"])}
' "
" f'
{html.escape(scenario["description"])}
' '
' '
Bookings
' f'
{truth["required_bookings"]}
' '
Required Tools
' f'
{html.escape(", ".join(truth["required_tool_calls"]))}
' '
Success Condition
' f'
{html.escape(truth["success_condition"])}
' "
" ) def _chatbot_rows(history: list[dict[str, Any]]) -> list[Any]: logger.info( "chatbot_rows:start uses_messages=%s history_len=%s sample_roles=%s", CHATBOT_USES_MESSAGES, len(history), [str(entry.get("role", "user")) for entry in history[:5]], ) rows = [ { "role": "assistant" if str(entry.get("role", "user")) == "assistant" else "user", "content": str(entry.get("content", "")), } for entry in history ] logger.info("chatbot_rows:done format=messages rows_len=%s", len(rows)) return rows def _build_chatbot(*, label: str, height: int): global CHATBOT_USES_MESSAGES logger.info("build_chatbot:start label=%s height=%s", label, height) try: chatbot = gr.Chatbot(label=label, type="messages", height=height) CHATBOT_USES_MESSAGES = True except TypeError: logger.warning("build_chatbot:type_messages_unsupported label=%s", label) chatbot = gr.Chatbot(label=label, height=height) # Gradio 6 in Docker still validates fallback chatbots using message dicts. CHATBOT_USES_MESSAGES = True logger.info( "build_chatbot:done label=%s detected_type=%s uses_messages=%s", label, getattr(chatbot, "type", None), CHATBOT_USES_MESSAGES, ) return chatbot def _user_data_rows(task_id: str, observation: dict[str, Any]) -> list[list[Any]]: scenario = _active_scenario(task_id) buyer = scenario["scenario_creation_config"]["expected_answers"] rows = [ [ f"buyer_{task_id}", scenario["description"], "buyer", buyer.get("user_sub_type", "flat"), buyer.get("location_pref_type", ""), ", ".join(buyer.get("areas", [])), buyer.get("budget_max"), None, buyer.get("price_range_negotiable"), buyer.get("is_price_range_fixed"), f"dietary={buyer.get('dietary')}; occupation={buyer.get('occupation')}", ] ] if scenario.get("seller_profile"): seller = scenario["seller_profile"] rows.append( [ "seller_post_dynamic_followup_1", seller.get("description", ""), "seller", "flat", "specific_area", seller.get("area", ""), None, seller.get("rent"), False, True, f"dietary={seller.get('dietary')}; fit={seller.get('occupation_requirement')}", ] ) return rows def _user_data_json(task_id: str, observation: dict[str, Any]) -> dict[str, Any]: scenario = _active_scenario(task_id) payload: dict[str, Any] = { f"buyer_{task_id}": scenario["scenario_creation_config"]["expected_answers"], } if scenario.get("seller_profile"): payload["seller_post_dynamic_followup_1"] = scenario["seller_profile"] payload["observation_flags"] = { "buyer_profile_stored": observation.get("buyer_profile_stored", False), "seller_profile_stored": observation.get("seller_profile_stored", False), } return payload def _storage_rows(task_id: str, observation: dict[str, Any]) -> list[list[Any]]: scenario = _active_scenario(task_id) truth = scenario["ground_truth"] buyer_stored = bool(observation.get("buyer_profile_stored")) seller_stored = bool(observation.get("seller_profile_stored")) rows = [ [ task_id, "stored" if buyer_stored else "pending", f"buyer_{task_id}", ", ".join(truth["required_info"]), "", ", ".join(truth["required_info"]) if buyer_stored else "", "", "{}", "" if buyer_stored else "buyer profile not stored yet", ] ] if scenario.get("seller_profile"): rows.append( [ task_id, "stored" if seller_stored else "pending", "seller_post_dynamic_followup_1", "area, rent, dietary, listing_type, occupation_requirement, calendar_slots", "", "area, rent, dietary, listing_type, occupation_requirement, calendar_slots" if seller_stored else "", "", "{}", "" if seller_stored else "seller profile not stored yet", ] ) return rows def _storage_log_html(task_id: str, observation: dict[str, Any]) -> str: headers = [ "task_id", "status", "buyer_user_id", "inserted_fields", "skipped_fields", "matched_expected_fields", "mismatched_expected_fields", "ignored_tool_args", "failure_reason", ] rows = _storage_rows(task_id, observation) head_html = "".join(f"{html.escape(header)}" for header in headers) body_html = "".join( "" + "".join(f"{html.escape('' if value is None else str(value))}" for value in row) + "" for row in rows ) return ( '
' '
' '' f'{head_html}' f"{body_html}" "
" "
" "
" ) def _html_table(headers: list[str], rows: list[list[Any]]) -> str: head_html = "".join(f"{html.escape(header)}" for header in headers) body_html = "".join( "" + "".join(f"{html.escape('' if value is None else str(value))}" for value in row) + "" for row in rows ) return ( '
' '
' '' f'{head_html}' f"{body_html}" "
" "
" "
" ) def _json_html(value: Any) -> str: return ( '
' f"
{html.escape(json.dumps(value, indent=2, sort_keys=True))}
" "
" ) def _user_data_explorer_html(task_id: str, observation: dict[str, Any]) -> str: return _html_table( [ "user_id", "user_description", "user_type", "user_sub_type", "location_mode", "location_value", "buyer_max_budget", "seller_min_price", "price_range_negotiable", "is_price_range_fixed", "optional_preferences", ], _user_data_rows(task_id, observation), ) def _scenario_check_rows(task_id: str, observation: dict[str, Any]) -> list[list[Any]]: scenario = _active_scenario(task_id) buyer = scenario["scenario_creation_config"]["expected_answers"] rows = [] checked_slots = { trace["tool"]: trace for trace in observation.get("tool_trace", []) } for post_id in scenario["task_post_ids"]: post = POSTS[post_id] match_location = post["area"] in buyer.get("areas", []) match_budget = post["rent"] <= buyer.get("budget_max", 0) match_diet = not (buyer.get("dietary") == "non-veg" and post["diet"] == "veg only") score = sum([match_location, match_budget, match_diet]) / 3 status = "compatible" if score >= 1 else "partial" if score > 0 else "incompatible" reasons = [] if not match_location: reasons.append("area mismatch") if not match_budget: reasons.append("over budget") if not match_diet: reasons.append("diet mismatch") if not reasons: reasons.append("matches scenario constraints") rows.append( [ post_id, status, round(score, 2), f"selected={'yes' if post_id in observation.get('selected_posts', []) else 'no'}; booked={'yes' if any(item['post_id']==post_id for item in observation.get('booked_visits', [])) else 'no'}", "; ".join(reasons), ] ) return rows def _visit_scheduler_rows(task_id: str, observation: dict[str, Any]) -> list[list[Any]]: rows = [] for item in observation.get("booked_visits", []): rows.append( [ f"buyer_{task_id}", f"seller_{item['post_id']}", item["post_id"], item["time"], item["time"], item["time"], "scheduled", ] ) return rows def _scenario_checks_html(task_id: str, observation: dict[str, Any]) -> str: return _html_table( ["Post ID", "Scenario Status", "Match Score", "State Flags", "Reasons"], _scenario_check_rows(task_id, observation), ) def _visit_scheduler_html(task_id: str, observation: dict[str, Any]) -> str: return _html_table( [ "buyer_user_id", "seller_user_id", "post_id", "scheduled_date", "start_time", "end_time", "status", ], _visit_scheduler_rows(task_id, observation), ) def _score_html(observation: dict[str, Any]) -> str: total_reward = float(observation.get("total_reward", 0.0)) step_reward = float(observation.get("step_reward", 0.0)) violations = observation.get("violations", []) booked = observation.get("booked_visits", []) if observation.get("done"): status_class = "status-pass" elif violations: status_class = "status-fail" else: status_class = "status-warn" point_cuts = violations or ["No explicit point cuts recorded."] awards = [] if observation.get("buyer_profile_stored"): awards.append("Buyer profile stored") if observation.get("seller_profile_stored"): awards.append("Seller profile stored") if booked: awards.extend(f"Booked {item['post_id']} at {item['time']}" for item in booked) if not awards: awards.append("No positive milestones yet.") award_html = "".join(f"
  • {html.escape(item)}
  • " for item in awards) cut_html = "".join(f"
  • {html.escape(item)}
  • " for item in point_cuts) return ( f'
    ' '
    Episode Reward
    ' f'
    {total_reward:.2f}
    ' f'
    step_reward={step_reward:.2f}\nbookings={len(booked)}\nviolations={len(violations)}
    ' '
    Positive Reasons
    " '
    Point Cuts
    " "
    " ) def _tool_log_rows(observation: dict[str, Any]) -> list[list[Any]]: rows = [] for trace in observation.get("tool_trace", []): rows.append( [ trace.get("step", ""), trace.get("tool", ""), _format_short_json(trace.get("args", {})), trace.get("message", ""), ] ) return rows def _episode_status_dump(observation: dict[str, Any]) -> dict[str, Any]: return { "status": observation.get("status", "ready"), "phase": observation.get("phase", "buyer"), "done": observation.get("done", False), "scenario_id": observation.get("scenario_id", ""), "step_reward": observation.get("step_reward", 0.0), "total_reward": observation.get("total_reward", 0.0), "last_tool_result": observation.get("last_tool_result", {}), } def _live_env_dump(observation: dict[str, Any]) -> dict[str, Any]: return observation def _post_rows(task_id: str, observation: dict[str, Any]) -> list[list[Any]]: rows = [] selected = set(observation.get("selected_posts", [])) booked = {item["post_id"] for item in observation.get("booked_visits", [])} for post_id in _active_scenario(task_id)["task_post_ids"]: post = POSTS[post_id] status = "booked" if post_id in booked else "selected" if post_id in selected else "available" rows.append([post["id"], post["area"], post["rent"], post["diet"], post["type"], post["commute_to_goregaon_mins"], status]) return rows def _model_status_html(broker_model: str, user_model: str) -> str: return f'
    Broker policy: {html.escape(broker_model)} | User simulator: {html.escape(user_model)}
    ' def _final_banner_html(observation: dict[str, Any]) -> str: if not observation.get("done"): return "" booked = observation.get("booked_visits", []) reason = "Episode completed." if booked: reason = "Booked visits: " + ", ".join(f"{item['post_id']} at {item['time']}" for item in booked) return f'
    Complete
    {html.escape(reason)}
    ' def _format_short_json(payload: Any) -> str: text = json.dumps(payload if payload not in (None, "") else {}, ensure_ascii=False) return text if len(text) <= 120 else text[:117] + "..." def _tool_choice_update(observation: dict[str, Any]): tools = observation.get("available_tools", []) return gr.update(choices=tools, value=(tools[0] if tools else None)) def _seller_chat_label(task_id: str, post_id: str | None = None) -> str: if post_id: return f"seller_id=seller_{post_id} broker chat" return f"seller_id=seller_{task_id} broker chat" def _extract_contacted_post_ids(observation: dict[str, Any]) -> list[str]: post_ids: list[str] = [] for trace in observation.get("tool_trace", []): args = trace.get("args", {}) post_id = args.get("post_id") if isinstance(post_id, str) and post_id not in post_ids: post_ids.append(post_id) for item in observation.get("booked_visits", []): post_id = item.get("post_id") if isinstance(post_id, str) and post_id not in post_ids: post_ids.append(post_id) return post_ids def _default_ui_state(task_id: str, observation: dict[str, Any] | None = None) -> dict[str, Any]: return { "task_id": task_id, "buyer_chat_filter": "__latest__", "post_chat_filter": "__active__", "observation": observation or {}, } def _normalize_ui_state(task_id: str, observation: dict[str, Any], ui_state: dict[str, Any] | None) -> dict[str, Any]: state = dict(ui_state or {}) if state.get("task_id") != task_id: state = _default_ui_state(task_id, observation) else: state.setdefault("buyer_chat_filter", "__latest__") state.setdefault("post_chat_filter", "__active__") state["task_id"] = task_id state["observation"] = observation return state def _buyer_chat_filter_choices(observation: dict[str, Any]) -> list[tuple[str, str]]: choices = [("Latest buyer chat", "__latest__")] if observation.get("buyer_conversation_history"): choices.append(("Buyer Chat 1", "buyer_chat_1")) return choices def _post_chat_filter_choices(task_id: str, observation: dict[str, Any]) -> list[tuple[str, str]]: choices: list[tuple[str, str]] = [] if observation.get("seller_conversation_history"): choices.append((_seller_chat_label(task_id), "__seller_lead__")) choices.append(("latest seller/post broker chat", "__active__")) for post_id in _extract_contacted_post_ids(observation): choices.append((_seller_chat_label(task_id, post_id), post_id)) return choices def _buyer_chat_filter_update(observation: dict[str, Any], ui_state: dict[str, Any]): choices = _buyer_chat_filter_choices(observation) valid_values = {value for _, value in choices} selected = ui_state.get("buyer_chat_filter") or "__latest__" if selected not in valid_values: selected = "__latest__" ui_state["buyer_chat_filter"] = selected return gr.update(choices=choices, value=selected) def _post_chat_filter_update(task_id: str, observation: dict[str, Any], ui_state: dict[str, Any]): choices = _post_chat_filter_choices(task_id, observation) valid_values = {value for _, value in choices} selected = ui_state.get("post_chat_filter") or "__active__" if selected not in valid_values: selected = "__active__" ui_state["post_chat_filter"] = selected return gr.update(choices=choices, value=selected) def _filtered_buyer_chat(observation: dict[str, Any], ui_state: dict[str, Any]): selected = ui_state.get("buyer_chat_filter") or "__latest__" label = "User ↔ Broker" if selected != "__latest__" and observation.get("buyer_conversation_history"): label = "User ↔ Broker (Buyer Chat 1)" return gr.update(value=_chatbot_rows(observation.get("buyer_conversation_history", [])), label=label) def _current_post_chat(task_id: str, observation: dict[str, Any], ui_state: dict[str, Any]): selected = ui_state.get("post_chat_filter") or "__active__" if selected == "__seller_lead__": return gr.update( value=_chatbot_rows(observation.get("seller_conversation_history", [])), label=_seller_chat_label(task_id), ) if selected != "__active__": return gr.update( value=_chatbot_rows(observation.get("seller_conversation_history", [])), label=_seller_chat_label(task_id, selected), ) active_post = _extract_contacted_post_ids(observation) label = _seller_chat_label(task_id, active_post[-1]) if active_post else "latest seller/post broker chat" return gr.update(value=_chatbot_rows(observation.get("seller_conversation_history", [])), label=label) def _ui_values(task_id: str, broker_model: str, user_model: str, payload: dict[str, Any], ui_state: dict[str, Any] | None = None) -> tuple[Any, ...]: observation = payload.get("observation", {}) ui_state = _normalize_ui_state(task_id, observation, ui_state) logger.info( "ui_values:start task_id=%s broker_model=%s user_model=%s done=%s phase=%s buyer_history=%s seller_history=%s tools=%s violations=%s", task_id, broker_model, user_model, observation.get("done"), observation.get("phase"), len(observation.get("buyer_conversation_history", [])), len(observation.get("seller_conversation_history", [])), len(observation.get("tool_trace", [])), len(observation.get("violations", [])), ) violations_text = "\n".join(observation.get("violations", [])) if observation.get("violations") else "No violations detected." done = bool(observation.get("done")) next_btn_update = gr.update(interactive=not done) full_btn_update = gr.update(interactive=not done) values = ( _task_definition_html(task_id), _buyer_chat_filter_update(observation, ui_state), _filtered_buyer_chat(observation, ui_state), _post_chat_filter_update(task_id, observation, ui_state), _current_post_chat(task_id, observation, ui_state), _user_data_explorer_html(task_id, observation), _json_html(_user_data_json(task_id, observation)), _storage_log_html(task_id, observation), _scenario_checks_html(task_id, observation), _visit_scheduler_html(task_id, observation), _score_html(observation), _tool_log_rows(observation), _episode_status_dump(observation), _live_env_dump(observation), violations_text, _post_rows(task_id, observation), _model_status_html(broker_model, user_model), _final_banner_html(observation), next_btn_update, full_btn_update, _tool_choice_update(observation), ui_state, ) logger.info("ui_values:done task_id=%s outputs=%s", task_id, len(values)) return values def build_flatmate_gradio_app(web_manager, action_fields, metadata, is_chat_env, title, quick_start_md): del action_fields, metadata, is_chat_env, title, quick_start_md if gr is None: # pragma: no cover raise ImportError("gradio is required to build the Flatmate UI.") default_task_id = next(iter(SCENARIOS)) logger.info("build_flatmate_gradio_app:start") with gr.Blocks(title="FlatmateEnv - Visit Scheduling Simulator") as demo: app_state = gr.State(_default_ui_state(default_task_id, {})) gr.HTML(f"") gr.Markdown( """ # 🏠 FlatmateEnv — Visit Scheduling Simulator Multi-agent flatmate visit scheduling evaluation """ ) model_status = gr.HTML(value=_model_status_html(DEFAULT_BROKER_MODEL, DEFAULT_USER_MODEL), label="Model Status") task_definition = gr.HTML(value=_task_definition_html(default_task_id), label="Task Definition") gr.Markdown("### Model Configuration") with gr.Row(): task_dropdown = gr.Dropdown(_task_choices(), label="Task", value=default_task_id) broker_model = gr.Dropdown(BROKER_MODELS, label="Broker Model", value=DEFAULT_BROKER_MODEL) user_model = gr.Dropdown(USER_MODELS, label="User Model", value=DEFAULT_USER_MODEL) with gr.Row(): with gr.Column(scale=6): gr.Markdown("## Simulation") with gr.Row(): with gr.Column(scale=1): buyer_chat_filter = gr.Dropdown( choices=[("Latest buyer chat", "__latest__")], value="__latest__", label="Buyer-Broker Chats", ) chatbot = _build_chatbot(label="User ↔ Broker", height=500) with gr.Column(scale=1): post_chat_filter = gr.Dropdown( choices=[("latest seller/post broker chat", "__active__")], value="__active__", label="Seller-Broker Chats", ) post_agent_chat = _build_chatbot(label="Broker ↔ Post Owner", height=500) with gr.Tabs(): with gr.Tab("User Data Explorer"): user_data_table = gr.HTML( value=_user_data_explorer_html(default_task_id, {}), label="User Data Explorer", ) user_data_json = gr.HTML( value=_json_html(_user_data_json(default_task_id, {})), label="User Detail Documents", ) with gr.Tab("User Detail Storage Log"): user_storage_log = gr.HTML( value=_storage_log_html(default_task_id, {}), label="User Detail Storage Log", ) with gr.Tab("Scenario Checks"): scenario_checks = gr.HTML( value=_scenario_checks_html(default_task_id, {}), label="Scenario Checks", ) with gr.Tab("Visit Scheduler"): visit_scheduler = gr.HTML( value=_visit_scheduler_html(default_task_id, {}), label="Visit Scheduler", ) with gr.Row(): next_btn = gr.Button("▶ Run Next Turn", variant="primary") full_btn = gr.Button("⚡ Run Full Simulation") with gr.Row(): reset_btn = gr.Button("🔄 Reset") tool_name = gr.Dropdown(choices=[], label="Tool", visible=False) with gr.Column(scale=4): gr.Markdown("## Live Evaluation") gr.Markdown("### Score Panel") score_panel = gr.HTML(value=_score_html({})) gr.Markdown("### Constraint Violations") violations = gr.Textbox(label="Violations & Flags", value="No violations detected.", lines=4, interactive=False) final_banner = gr.HTML(value="") gr.Markdown("### Tool Call Log") tool_log = gr.Dataframe( headers=["Turn", "Tool", "Args Summary", "Result Summary"], value=[], interactive=False, ) gr.Markdown("### Episode Status") episode_status = gr.JSON(label="Episode Failure / Status", value=_episode_status_dump({})) gr.Markdown("### Environment State") live_env = gr.JSON(label="Live Env State", value={}) with gr.Accordion("Post Database Reference", open=False): post_db = gr.Dataframe( headers=["ID", "Area", "Rent", "Diet", "Type", "Commute_Goregaon", "Status"], value=_post_rows(default_task_id, {}), interactive=False, ) common_outputs = [ task_definition, buyer_chat_filter, chatbot, post_chat_filter, post_agent_chat, user_data_table, user_data_json, user_storage_log, scenario_checks, visit_scheduler, score_panel, tool_log, episode_status, live_env, violations, post_db, model_status, final_banner, next_btn, full_btn, tool_name, app_state, ] async def reset_simulation(task_id: str, broker_model_name: str, user_model_name: str, ui_state: dict[str, Any]): logger.info( "callback:reset:start task_id=%s broker_model=%s user_model=%s", task_id, broker_model_name, user_model_name, ) try: observation = await web_manager._run_sync_in_thread_pool(web_manager.env.reset, scenario_id=task_id) logger.info("callback:reset:after_env_reset task_id=%s", task_id) serialized = _serialize_reset(web_manager, observation) logger.info( "callback:reset:serialized task_id=%s obs_keys=%s", task_id, sorted(serialized.get("observation", {}).keys()), ) await web_manager._send_state_update() logger.info("callback:reset:after_state_update task_id=%s", task_id) return _ui_values(task_id, broker_model_name, user_model_name, serialized, ui_state) except Exception: logger.exception("callback:reset:error task_id=%s", task_id) raise async def run_manual_message(task_id: str, broker_model_name: str, user_model_name: str, message: str): logger.info( "callback:manual_message:start task_id=%s message_len=%s", task_id, len(message or ""), ) try: request_payload = {"action_type": "assistant_message", "assistant_message": message} logger.info("callback:manual_message:request task_id=%s payload=%s", task_id, request_payload) serialized = await web_manager.step_environment(request_payload) logger.info( "callback:manual_message:after_step task_id=%s done=%s phase=%s", task_id, serialized.get("observation", {}).get("done"), serialized.get("observation", {}).get("phase"), ) return _ui_values(task_id, broker_model_name, user_model_name, serialized) except Exception: logger.exception("callback:manual_message:error task_id=%s", task_id) raise async def run_manual_tool(task_id: str, broker_model_name: str, user_model_name: str, selected_tool: str, raw_arguments: str): logger.info( "callback:manual_tool:start task_id=%s tool=%s raw_arguments=%s", task_id, selected_tool, raw_arguments, ) try: parsed_arguments = json.loads(raw_arguments or "{}") request_payload = { "action_type": "tool_call", "tool_name": selected_tool, "tool_arguments": parsed_arguments, } logger.info("callback:manual_tool:request task_id=%s payload=%s", task_id, request_payload) serialized = await web_manager.step_environment(request_payload) logger.info( "callback:manual_tool:after_step task_id=%s done=%s phase=%s", task_id, serialized.get("observation", {}).get("done"), serialized.get("observation", {}).get("phase"), ) return _ui_values(task_id, broker_model_name, user_model_name, serialized) except Exception: logger.exception("callback:manual_tool:error task_id=%s tool=%s", task_id, selected_tool) raise async def run_next_turn(task_id: str, broker_model_name: str, user_model_name: str, ui_state: dict[str, Any]): logger.info("callback:next_turn:start task_id=%s", task_id) try: current = dict(web_manager.episode_state.current_observation or {}) logger.info( "callback:next_turn:current task_id=%s has_current=%s current_scenario=%s done=%s phase=%s", task_id, bool(current), current.get("scenario_id"), current.get("done"), current.get("phase"), ) if not current or current.get("scenario_id") != task_id: observation = await web_manager._run_sync_in_thread_pool(web_manager.env.reset, scenario_id=task_id) logger.info("callback:next_turn:after_env_reset task_id=%s", task_id) serialized = _serialize_reset(web_manager, observation) await web_manager._send_state_update() current = serialized["observation"] request_payload = autopolicy_next_request(task_id, current) logger.info("callback:next_turn:autopolicy task_id=%s payload=%s", task_id, request_payload) if request_payload is None: return _ui_values(task_id, broker_model_name, user_model_name, {"observation": current}, ui_state) serialized = await web_manager.step_environment(request_payload) logger.info( "callback:next_turn:after_step task_id=%s done=%s phase=%s", task_id, serialized.get("observation", {}).get("done"), serialized.get("observation", {}).get("phase"), ) return _ui_values(task_id, broker_model_name, user_model_name, serialized, ui_state) except Exception: logger.exception("callback:next_turn:error task_id=%s", task_id) raise async def run_full_simulation(task_id: str, broker_model_name: str, user_model_name: str, ui_state: dict[str, Any]): logger.info("callback:full_simulation:start task_id=%s", task_id) try: current = dict(web_manager.episode_state.current_observation or {}) logger.info( "callback:full_simulation:current task_id=%s has_current=%s current_scenario=%s done=%s phase=%s", task_id, bool(current), current.get("scenario_id"), current.get("done"), current.get("phase"), ) if not current or current.get("scenario_id") != task_id: observation = await web_manager._run_sync_in_thread_pool(web_manager.env.reset, scenario_id=task_id) logger.info("callback:full_simulation:after_env_reset task_id=%s", task_id) serialized = _serialize_reset(web_manager, observation) await web_manager._send_state_update() current = serialized["observation"] last_payload = {"observation": current} for step_index in range(20): current = last_payload["observation"] logger.info( "callback:full_simulation:loop task_id=%s step_index=%s done=%s phase=%s", task_id, step_index, current.get("done"), current.get("phase"), ) if current.get("done"): break request_payload = autopolicy_next_request(task_id, current) logger.info( "callback:full_simulation:autopolicy task_id=%s step_index=%s payload=%s", task_id, step_index, request_payload, ) if request_payload is None: break last_payload = await web_manager.step_environment(request_payload) logger.info("callback:full_simulation:done task_id=%s", task_id) return _ui_values(task_id, broker_model_name, user_model_name, last_payload, ui_state) except Exception: logger.exception("callback:full_simulation:error task_id=%s", task_id) raise def set_buyer_chat_filter(ui_state: dict[str, Any], selection: str): ui_state = dict(ui_state or {}) ui_state["buyer_chat_filter"] = selection or "__latest__" observation = ui_state.get("observation", {}) task_id = ui_state.get("task_id", default_task_id) ui_state = _normalize_ui_state(task_id, observation, ui_state) return _filtered_buyer_chat(observation, ui_state), _buyer_chat_filter_update(observation, ui_state), ui_state def set_post_chat_filter(ui_state: dict[str, Any], selection: str): ui_state = dict(ui_state or {}) ui_state["post_chat_filter"] = selection or "__active__" observation = ui_state.get("observation", {}) task_id = ui_state.get("task_id", default_task_id) ui_state = _normalize_ui_state(task_id, observation, ui_state) return _current_post_chat(task_id, observation, ui_state), _post_chat_filter_update(task_id, observation, ui_state), ui_state reset_btn.click(reset_simulation, inputs=[task_dropdown, broker_model, user_model, app_state], outputs=common_outputs) next_btn.click(run_next_turn, inputs=[task_dropdown, broker_model, user_model, app_state], outputs=common_outputs) full_btn.click(run_full_simulation, inputs=[task_dropdown, broker_model, user_model, app_state], outputs=common_outputs) buyer_chat_filter.change(set_buyer_chat_filter, inputs=[app_state, buyer_chat_filter], outputs=[chatbot, buyer_chat_filter, app_state]) post_chat_filter.change(set_post_chat_filter, inputs=[app_state, post_chat_filter], outputs=[post_agent_chat, post_chat_filter, app_state]) logger.info("build_flatmate_gradio_app:done") return demo