Spaces:
Sleeping
Sleeping
Add step-2 GRPO notebook and hidden-flex fix
Browse files- inference.py +32 -0
- server/episode.py +49 -7
- tests/test_flatmate_rl.py +16 -1
- tests/test_reward_regression.py +1 -1
- train_flatmate_rl_grpo_step2.ipynb +205 -0
inference.py
CHANGED
|
@@ -71,6 +71,30 @@ SYSTEM_PROMPT = textwrap.dedent(
|
|
| 71 |
- If a tool can perform the next required operation, call the tool immediately.
|
| 72 |
- Do not send acknowledgement or progress messages such as "I will search now" when a tool call is needed.
|
| 73 |
- Prefer safe, incremental progress toward storing user details, matching listings, and booking visits.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
"""
|
| 75 |
).strip()
|
| 76 |
|
|
@@ -371,12 +395,20 @@ def build_user_prompt(step: int, observation: Any) -> str:
|
|
| 371 |
Step: {step}
|
| 372 |
Phase: {observation.phase}
|
| 373 |
Status: {observation.status}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
Available tools: {observation.available_tools}
|
| 375 |
Last tool result: {json.dumps(last_tool_result, ensure_ascii=False)}
|
| 376 |
Prerequisites satisfied: {json.dumps(observation.prerequisites_satisfied, ensure_ascii=False)}
|
| 377 |
Recent tool calls: {json.dumps(observation.recent_tool_calls, ensure_ascii=False)}
|
| 378 |
Booked visits: {observation.booked_visits}
|
| 379 |
|
|
|
|
|
|
|
| 380 |
Buyer/Broker transcript:
|
| 381 |
{json.dumps(observation.buyer_conversation_history[-8:], ensure_ascii=False)}
|
| 382 |
|
|
|
|
| 71 |
- If a tool can perform the next required operation, call the tool immediately.
|
| 72 |
- Do not send acknowledgement or progress messages such as "I will search now" when a tool call is needed.
|
| 73 |
- Prefer safe, incremental progress toward storing user details, matching listings, and booking visits.
|
| 74 |
+
- Use exact tool argument names from the prompt. Never invent aliases such as visit_time.
|
| 75 |
+
- Treat negative reward, violations, and feedback_summary as corrective feedback for the next action.
|
| 76 |
+
"""
|
| 77 |
+
).strip()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
TOOL_CONTRACT_PROMPT = textwrap.dedent(
|
| 81 |
+
"""
|
| 82 |
+
Tool argument contract:
|
| 83 |
+
- store_user_details: tool_arguments can be {} after required buyer fields are gathered.
|
| 84 |
+
- search_posts: tool_arguments can be {}.
|
| 85 |
+
- match_location_preference: {"post_ids":["post_id", ...]}.
|
| 86 |
+
- get_commute_time: {"post_ids":["post_id", ...]}.
|
| 87 |
+
- check_calendar_slots: {"post_ids":["post_id", ...]}.
|
| 88 |
+
- shortlist: {"post_ids":["post_id", ...]}.
|
| 89 |
+
- contact_poster: {"post_id":"post_id","time_text":"exact slot from check_calendar_slots"}. This shares the buyer profile with the seller/poster and asks them to confirm profile fit plus visit time.
|
| 90 |
+
- book_viewing: {"post_id":"post_id","time_text":"same exact slot confirmed by buyer and poster"}.
|
| 91 |
+
|
| 92 |
+
Booking workflow:
|
| 93 |
+
1. Ask for missing buyer fields before store_user_details.
|
| 94 |
+
2. Store buyer details, search posts, match location, get commute time, then check calendar slots.
|
| 95 |
+
3. Ask the buyer to confirm one exact slot from check_calendar_slots.
|
| 96 |
+
4. Call contact_poster with post_id and time_text for that same slot.
|
| 97 |
+
5. Only after buyer_confirmed and poster_confirmed are true, call book_viewing with post_id and time_text.
|
| 98 |
"""
|
| 99 |
).strip()
|
| 100 |
|
|
|
|
| 395 |
Step: {step}
|
| 396 |
Phase: {observation.phase}
|
| 397 |
Status: {observation.status}
|
| 398 |
+
Feedback summary: {observation.feedback_summary}
|
| 399 |
+
Environment message: {observation.message}
|
| 400 |
+
Step reward: {observation.step_reward}
|
| 401 |
+
Total reward: {observation.total_reward}
|
| 402 |
+
Violations: {observation.violations}
|
| 403 |
+
Remaining required fields: {observation.remaining_required_fields}
|
| 404 |
Available tools: {observation.available_tools}
|
| 405 |
Last tool result: {json.dumps(last_tool_result, ensure_ascii=False)}
|
| 406 |
Prerequisites satisfied: {json.dumps(observation.prerequisites_satisfied, ensure_ascii=False)}
|
| 407 |
Recent tool calls: {json.dumps(observation.recent_tool_calls, ensure_ascii=False)}
|
| 408 |
Booked visits: {observation.booked_visits}
|
| 409 |
|
| 410 |
+
{TOOL_CONTRACT_PROMPT}
|
| 411 |
+
|
| 412 |
Buyer/Broker transcript:
|
| 413 |
{json.dumps(observation.buyer_conversation_history[-8:], ensure_ascii=False)}
|
| 414 |
|
server/episode.py
CHANGED
|
@@ -95,6 +95,7 @@ class FlatmateEpisode:
|
|
| 95 |
self._commutes_checked: dict[str, int] = {}
|
| 96 |
self._poster_confirmations: dict[str, str] = {}
|
| 97 |
self._client_confirmations: dict[str, str] = {}
|
|
|
|
| 98 |
self._seller_confirmations: dict[str, str] = {}
|
| 99 |
self._buyer_offer_confirmations: dict[str, str] = {}
|
| 100 |
self._dynamic_post_id: str | None = None
|
|
@@ -136,6 +137,7 @@ class FlatmateEpisode:
|
|
| 136 |
self._commutes_checked = {}
|
| 137 |
self._poster_confirmations = {}
|
| 138 |
self._client_confirmations = {}
|
|
|
|
| 139 |
self._seller_confirmations = {}
|
| 140 |
self._buyer_offer_confirmations = {}
|
| 141 |
self._dynamic_post_id = None
|
|
@@ -259,6 +261,12 @@ class FlatmateEpisode:
|
|
| 259 |
slots.extend(profile["hidden_additional_availability"])
|
| 260 |
return slots
|
| 261 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
def _record_violation(self, text: str) -> None:
|
| 263 |
if text not in self._violations:
|
| 264 |
self._violations.append(text)
|
|
@@ -496,8 +504,10 @@ class FlatmateEpisode:
|
|
| 496 |
self._state.gathered_fields.append("hidden_flex_revealed")
|
| 497 |
if alternatives_offered:
|
| 498 |
if "sunday 5pm" in lowered:
|
|
|
|
| 499 |
return "I can make Sunday 5pm work, so I confirm Sunday 5pm."
|
| 500 |
if "saturday 1pm" in lowered:
|
|
|
|
| 501 |
return "Saturday 1pm works for me too, so I confirm Saturday 1pm."
|
| 502 |
|
| 503 |
# Scenario 2: waitlist — fire cancellation notification on first message after add_to_waitlist
|
|
@@ -861,6 +871,16 @@ class FlatmateEpisode:
|
|
| 861 |
self._state.selected_posts = post_ids
|
| 862 |
return {"tool": "shortlist", "success": True, "message": "Posts shortlisted.", "selected_posts": post_ids}
|
| 863 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 864 |
def _tool_contact_poster(self, arguments: dict[str, Any]) -> dict[str, Any]:
|
| 865 |
post_id = arguments.get("post_id", "")
|
| 866 |
time_text = arguments.get("time_text", "")
|
|
@@ -872,20 +892,34 @@ class FlatmateEpisode:
|
|
| 872 |
return {"tool": "contact_poster", "success": False, "message": "Time must come from check_calendar_slots."}
|
| 873 |
self._seller_history.append(
|
| 874 |
{
|
| 875 |
-
"role": "
|
| 876 |
-
"content":
|
|
|
|
|
|
|
|
|
|
| 877 |
}
|
| 878 |
)
|
| 879 |
self._poster_confirmations[post_id] = time_text
|
| 880 |
-
|
| 881 |
-
|
| 882 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 883 |
|
| 884 |
def _tool_book_viewing(self, arguments: dict[str, Any]) -> dict[str, Any]:
|
| 885 |
post_id = arguments.get("post_id", "")
|
| 886 |
time_text = arguments.get("time_text", "")
|
| 887 |
if post_id not in self._poster_confirmations or self._poster_confirmations[post_id] != time_text:
|
| 888 |
return {"tool": "book_viewing", "success": False, "message": "Poster has not explicitly confirmed this time."}
|
|
|
|
|
|
|
| 889 |
if post_id not in self._client_confirmations or self._client_confirmations[post_id] != time_text:
|
| 890 |
return {"tool": "book_viewing", "success": False, "message": "Client has not explicitly confirmed this time."}
|
| 891 |
if self._scenario["task_id"] == "task_visit_multi" and post_id not in self._state.selected_posts:
|
|
@@ -940,8 +974,15 @@ class FlatmateEpisode:
|
|
| 940 |
config = self._scenario["scenario_creation_config"].get("negotiation_config", {})
|
| 941 |
seller_floor = config.get("seller_floor", 0)
|
| 942 |
self._negotiation_rounds_seller += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 943 |
if proposed_rent >= seller_floor:
|
| 944 |
self._seller_price_accepted = proposed_rent
|
|
|
|
| 945 |
return {
|
| 946 |
"tool": "propose_price_to_seller",
|
| 947 |
"success": True,
|
|
@@ -950,6 +991,7 @@ class FlatmateEpisode:
|
|
| 950 |
"proposed_rent": proposed_rent,
|
| 951 |
}
|
| 952 |
hint = " Maybe a small discount is possible." if self._negotiation_rounds_seller >= 2 else ""
|
|
|
|
| 953 |
return {
|
| 954 |
"tool": "propose_price_to_seller",
|
| 955 |
"success": True,
|
|
@@ -1167,9 +1209,9 @@ class FlatmateEpisode:
|
|
| 1167 |
post = self._resolve_post(post_id)
|
| 1168 |
if not post or time_text not in post["calendar_slots"]:
|
| 1169 |
return {"tool": "confirm_seller_match", "success": False, "message": "Selected seller slot is invalid."}
|
| 1170 |
-
self._seller_history.append({"role": "
|
| 1171 |
self._seller_confirmations[post_id] = time_text
|
| 1172 |
-
self._seller_history.append({"role": "
|
| 1173 |
return {"tool": "confirm_seller_match", "success": True, "message": f"Seller confirmed {time_text}.", "post_id": post_id, "time_text": time_text}
|
| 1174 |
|
| 1175 |
def _tool_offer_matched_listing_to_buyer(self, arguments: dict[str, Any]) -> dict[str, Any]:
|
|
|
|
| 95 |
self._commutes_checked: dict[str, int] = {}
|
| 96 |
self._poster_confirmations: dict[str, str] = {}
|
| 97 |
self._client_confirmations: dict[str, str] = {}
|
| 98 |
+
self._seller_profile_fit_confirmations: dict[str, bool] = {}
|
| 99 |
self._seller_confirmations: dict[str, str] = {}
|
| 100 |
self._buyer_offer_confirmations: dict[str, str] = {}
|
| 101 |
self._dynamic_post_id: str | None = None
|
|
|
|
| 137 |
self._commutes_checked = {}
|
| 138 |
self._poster_confirmations = {}
|
| 139 |
self._client_confirmations = {}
|
| 140 |
+
self._seller_profile_fit_confirmations = {}
|
| 141 |
self._seller_confirmations = {}
|
| 142 |
self._buyer_offer_confirmations = {}
|
| 143 |
self._dynamic_post_id = None
|
|
|
|
| 261 |
slots.extend(profile["hidden_additional_availability"])
|
| 262 |
return slots
|
| 263 |
|
| 264 |
+
def _record_client_confirmation_for_slot(self, slot: str) -> None:
|
| 265 |
+
for post_id, checked_slots in self._slots_checked.items():
|
| 266 |
+
if slot in checked_slots:
|
| 267 |
+
self._client_confirmations[post_id] = slot
|
| 268 |
+
return
|
| 269 |
+
|
| 270 |
def _record_violation(self, text: str) -> None:
|
| 271 |
if text not in self._violations:
|
| 272 |
self._violations.append(text)
|
|
|
|
| 504 |
self._state.gathered_fields.append("hidden_flex_revealed")
|
| 505 |
if alternatives_offered:
|
| 506 |
if "sunday 5pm" in lowered:
|
| 507 |
+
self._record_client_confirmation_for_slot("Sunday 5pm")
|
| 508 |
return "I can make Sunday 5pm work, so I confirm Sunday 5pm."
|
| 509 |
if "saturday 1pm" in lowered:
|
| 510 |
+
self._record_client_confirmation_for_slot("Saturday 1pm")
|
| 511 |
return "Saturday 1pm works for me too, so I confirm Saturday 1pm."
|
| 512 |
|
| 513 |
# Scenario 2: waitlist — fire cancellation notification on first message after add_to_waitlist
|
|
|
|
| 871 |
self._state.selected_posts = post_ids
|
| 872 |
return {"tool": "shortlist", "success": True, "message": "Posts shortlisted.", "selected_posts": post_ids}
|
| 873 |
|
| 874 |
+
def _buyer_profile_summary_for_seller(self) -> str:
|
| 875 |
+
profile = self._scenario["buyer_profile"]
|
| 876 |
+
return (
|
| 877 |
+
f"buyer profile: budget up to Rs. {profile['budget_max']}; "
|
| 878 |
+
f"dietary preference {profile['dietary']}; "
|
| 879 |
+
f"preferred areas {', '.join(profile['areas'])}; "
|
| 880 |
+
f"occupation {profile['occupation']}; "
|
| 881 |
+
f"visit availability {', '.join(profile['visit_availability'])}"
|
| 882 |
+
)
|
| 883 |
+
|
| 884 |
def _tool_contact_poster(self, arguments: dict[str, Any]) -> dict[str, Any]:
|
| 885 |
post_id = arguments.get("post_id", "")
|
| 886 |
time_text = arguments.get("time_text", "")
|
|
|
|
| 892 |
return {"tool": "contact_poster", "success": False, "message": "Time must come from check_calendar_slots."}
|
| 893 |
self._seller_history.append(
|
| 894 |
{
|
| 895 |
+
"role": "assistant",
|
| 896 |
+
"content": (
|
| 897 |
+
f"Client selected {post_id}. Please review this {self._buyer_profile_summary_for_seller()}. "
|
| 898 |
+
f"Can you confirm the buyer profile is acceptable and that we can visit at {time_text}?"
|
| 899 |
+
),
|
| 900 |
}
|
| 901 |
)
|
| 902 |
self._poster_confirmations[post_id] = time_text
|
| 903 |
+
self._seller_profile_fit_confirmations[post_id] = True
|
| 904 |
+
poster_message = f"Yes, confirmed. The buyer profile is acceptable and {time_text} works for the visit."
|
| 905 |
+
self._seller_history.append({"role": "user", "content": poster_message})
|
| 906 |
+
return {
|
| 907 |
+
"tool": "contact_poster",
|
| 908 |
+
"success": True,
|
| 909 |
+
"message": f"Poster confirmed buyer profile fit and {time_text}.",
|
| 910 |
+
"post_id": post_id,
|
| 911 |
+
"time_text": time_text,
|
| 912 |
+
"buyer_profile_shared": True,
|
| 913 |
+
"seller_profile_fit_confirmed": True,
|
| 914 |
+
}
|
| 915 |
|
| 916 |
def _tool_book_viewing(self, arguments: dict[str, Any]) -> dict[str, Any]:
|
| 917 |
post_id = arguments.get("post_id", "")
|
| 918 |
time_text = arguments.get("time_text", "")
|
| 919 |
if post_id not in self._poster_confirmations or self._poster_confirmations[post_id] != time_text:
|
| 920 |
return {"tool": "book_viewing", "success": False, "message": "Poster has not explicitly confirmed this time."}
|
| 921 |
+
if not self._seller_profile_fit_confirmations.get(post_id):
|
| 922 |
+
return {"tool": "book_viewing", "success": False, "message": "Poster has not confirmed the buyer profile fit."}
|
| 923 |
if post_id not in self._client_confirmations or self._client_confirmations[post_id] != time_text:
|
| 924 |
return {"tool": "book_viewing", "success": False, "message": "Client has not explicitly confirmed this time."}
|
| 925 |
if self._scenario["task_id"] == "task_visit_multi" and post_id not in self._state.selected_posts:
|
|
|
|
| 974 |
config = self._scenario["scenario_creation_config"].get("negotiation_config", {})
|
| 975 |
seller_floor = config.get("seller_floor", 0)
|
| 976 |
self._negotiation_rounds_seller += 1
|
| 977 |
+
self._seller_history.append(
|
| 978 |
+
{
|
| 979 |
+
"role": "assistant",
|
| 980 |
+
"content": f"The buyer is interested in {post_id}. Would you accept Rs. {proposed_rent}?",
|
| 981 |
+
}
|
| 982 |
+
)
|
| 983 |
if proposed_rent >= seller_floor:
|
| 984 |
self._seller_price_accepted = proposed_rent
|
| 985 |
+
self._seller_history.append({"role": "user", "content": f"Yes, I can accept Rs. {proposed_rent}."})
|
| 986 |
return {
|
| 987 |
"tool": "propose_price_to_seller",
|
| 988 |
"success": True,
|
|
|
|
| 991 |
"proposed_rent": proposed_rent,
|
| 992 |
}
|
| 993 |
hint = " Maybe a small discount is possible." if self._negotiation_rounds_seller >= 2 else ""
|
| 994 |
+
self._seller_history.append({"role": "user", "content": f"I can't go as low as Rs. {proposed_rent}.{hint}"})
|
| 995 |
return {
|
| 996 |
"tool": "propose_price_to_seller",
|
| 997 |
"success": True,
|
|
|
|
| 1209 |
post = self._resolve_post(post_id)
|
| 1210 |
if not post or time_text not in post["calendar_slots"]:
|
| 1211 |
return {"tool": "confirm_seller_match", "success": False, "message": "Selected seller slot is invalid."}
|
| 1212 |
+
self._seller_history.append({"role": "assistant", "content": f"Can we confirm {time_text} for {post_id}?"})
|
| 1213 |
self._seller_confirmations[post_id] = time_text
|
| 1214 |
+
self._seller_history.append({"role": "user", "content": f"Confirmed, {time_text} works from the seller side."})
|
| 1215 |
return {"tool": "confirm_seller_match", "success": True, "message": f"Seller confirmed {time_text}.", "post_id": post_id, "time_text": time_text}
|
| 1216 |
|
| 1217 |
def _tool_offer_matched_listing_to_buyer(self, arguments: dict[str, Any]) -> dict[str, Any]:
|
tests/test_flatmate_rl.py
CHANGED
|
@@ -200,8 +200,17 @@ def test_single_visit_scenario_books_one_visit() -> None:
|
|
| 200 |
assert final_obs.done is True
|
| 201 |
assert final_obs.booked_visits == [{"post_id": "post_023", "time": "Saturday 11am"}]
|
| 202 |
assert len(final_obs.seller_conversation_history) >= 2
|
| 203 |
-
assert
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
assert "Saturday 11am works for the visit" in final_obs.seller_conversation_history[1]["content"]
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
|
| 207 |
def test_buyer_answers_diet_and_availability_when_broker_asks_for_both() -> None:
|
|
@@ -430,6 +439,10 @@ def test_hidden_flex_requires_alternative_slot_to_unlock_backup_availability() -
|
|
| 430 |
obs = _msg(env, "No Tuesday slot matches. I can offer Saturday 1pm or Sunday 5pm instead.")
|
| 431 |
assert "confirm" in obs.last_user_message.lower()
|
| 432 |
assert "Sunday 5pm" in obs.last_user_message or "Saturday 1pm" in obs.last_user_message
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
|
| 434 |
|
| 435 |
def test_multi_visit_scenario_books_two_visits() -> None:
|
|
@@ -533,3 +546,5 @@ def test_negotiation_heuristic_confirms_deal_with_agreed_rent() -> None:
|
|
| 533 |
assert obs.status == "completed"
|
| 534 |
assert obs.booked_visits == [{"post_id": "post_155", "time": "negotiated_deal", "agreed_rent": 21000}]
|
| 535 |
assert obs.last_tool_result["tool"] == "confirm_negotiated_deal"
|
|
|
|
|
|
|
|
|
| 200 |
assert final_obs.done is True
|
| 201 |
assert final_obs.booked_visits == [{"post_id": "post_023", "time": "Saturday 11am"}]
|
| 202 |
assert len(final_obs.seller_conversation_history) >= 2
|
| 203 |
+
assert final_obs.seller_conversation_history[0]["role"] == "assistant"
|
| 204 |
+
assert final_obs.seller_conversation_history[1]["role"] == "user"
|
| 205 |
+
assert "buyer profile" in final_obs.seller_conversation_history[0]["content"]
|
| 206 |
+
assert "budget up to Rs. 20000" in final_obs.seller_conversation_history[0]["content"]
|
| 207 |
+
assert "Can you confirm the buyer profile is acceptable" in final_obs.seller_conversation_history[0]["content"]
|
| 208 |
+
assert "Saturday 11am" in final_obs.seller_conversation_history[0]["content"]
|
| 209 |
+
assert "buyer profile is acceptable" in final_obs.seller_conversation_history[1]["content"]
|
| 210 |
assert "Saturday 11am works for the visit" in final_obs.seller_conversation_history[1]["content"]
|
| 211 |
+
contact_result = next(result for result in final_obs.tool_results if result["tool"] == "contact_poster")
|
| 212 |
+
assert contact_result["buyer_profile_shared"] is True
|
| 213 |
+
assert contact_result["seller_profile_fit_confirmed"] is True
|
| 214 |
|
| 215 |
|
| 216 |
def test_buyer_answers_diet_and_availability_when_broker_asks_for_both() -> None:
|
|
|
|
| 439 |
obs = _msg(env, "No Tuesday slot matches. I can offer Saturday 1pm or Sunday 5pm instead.")
|
| 440 |
assert "confirm" in obs.last_user_message.lower()
|
| 441 |
assert "Sunday 5pm" in obs.last_user_message or "Saturday 1pm" in obs.last_user_message
|
| 442 |
+
_tool(env, "contact_poster", post_id="post_023", time_text="Sunday 5pm")
|
| 443 |
+
obs = _tool(env, "book_viewing", post_id="post_023", time_text="Sunday 5pm")
|
| 444 |
+
assert obs.done is True
|
| 445 |
+
assert obs.booked_visits == [{"post_id": "post_023", "time": "Sunday 5pm"}]
|
| 446 |
|
| 447 |
|
| 448 |
def test_multi_visit_scenario_books_two_visits() -> None:
|
|
|
|
| 546 |
assert obs.status == "completed"
|
| 547 |
assert obs.booked_visits == [{"post_id": "post_155", "time": "negotiated_deal", "agreed_rent": 21000}]
|
| 548 |
assert obs.last_tool_result["tool"] == "confirm_negotiated_deal"
|
| 549 |
+
assert any("Would you accept Rs. 21000" in item["content"] for item in obs.seller_conversation_history)
|
| 550 |
+
assert any("I can accept Rs. 21000" in item["content"] for item in obs.seller_conversation_history)
|
tests/test_reward_regression.py
CHANGED
|
@@ -9,7 +9,7 @@ from flatmate_rl.server.heuristic_policy import expected_policy_action
|
|
| 9 |
|
| 10 |
HEURISTIC_BASELINES = {
|
| 11 |
"task_visit_single": 0.70,
|
| 12 |
-
"task_visit_single_hidden_flex":
|
| 13 |
"task_visit_multi": 1.10,
|
| 14 |
"task_visit_single_seller_followup": 0.90,
|
| 15 |
}
|
|
|
|
| 9 |
|
| 10 |
HEURISTIC_BASELINES = {
|
| 11 |
"task_visit_single": 0.70,
|
| 12 |
+
"task_visit_single_hidden_flex": 0.90,
|
| 13 |
"task_visit_multi": 1.10,
|
| 14 |
"task_visit_single_seller_followup": 0.90,
|
| 15 |
}
|
train_flatmate_rl_grpo_step2.ipynb
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# Flatmate RL GRPO Step-2 Curriculum\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"This notebook is a minimal GRPO starter for `flatmate_rl`.\n",
|
| 10 |
+
"It only trains the first two workflow steps:\n",
|
| 11 |
+
"\n",
|
| 12 |
+
"1. ask for the missing buyer details\n",
|
| 13 |
+
"2. store the buyer profile\n",
|
| 14 |
+
"\n",
|
| 15 |
+
"The goal is to keep the reward simple enough to bootstrap the broker policy before training on later booking steps."
|
| 16 |
+
]
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"cell_type": "code",
|
| 20 |
+
"execution_count": null,
|
| 21 |
+
"metadata": {},
|
| 22 |
+
"outputs": [],
|
| 23 |
+
"source": [
|
| 24 |
+
"%pip install -q trl transformers accelerate datasets peft bitsandbytes sentencepiece\n",
|
| 25 |
+
"\n",
|
| 26 |
+
"from __future__ import annotations\n",
|
| 27 |
+
"\n",
|
| 28 |
+
"import json\n",
|
| 29 |
+
"import sys\n",
|
| 30 |
+
"from pathlib import Path\n",
|
| 31 |
+
"\n",
|
| 32 |
+
"repo_root = Path.cwd().resolve().parent\n",
|
| 33 |
+
"if str(repo_root) not in sys.path:\n",
|
| 34 |
+
" sys.path.insert(0, str(repo_root))\n",
|
| 35 |
+
"\n",
|
| 36 |
+
"from datasets import Dataset\n",
|
| 37 |
+
"from flatmate_rl import FlatmateRlAction\n",
|
| 38 |
+
"from flatmate_rl.server.flatmate_rl_environment import FlatmateRlEnvironment\n",
|
| 39 |
+
"from flatmate_rl.server.heuristic_policy import expected_policy_action\n",
|
| 40 |
+
"\n",
|
| 41 |
+
"print('imports ready')"
|
| 42 |
+
]
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"cell_type": "code",
|
| 46 |
+
"execution_count": null,
|
| 47 |
+
"metadata": {},
|
| 48 |
+
"outputs": [],
|
| 49 |
+
"source": [
|
| 50 |
+
"TARGET_SCENARIOS = [\n",
|
| 51 |
+
" 'task_visit_single',\n",
|
| 52 |
+
" 'task_visit_single_hidden_flex',\n",
|
| 53 |
+
" 'task_visit_multi',\n",
|
| 54 |
+
" 'task_visit_single_seller_followup',\n",
|
| 55 |
+
"]\n",
|
| 56 |
+
"\n",
|
| 57 |
+
"def format_prompt(obs, step: int) -> str:\n",
|
| 58 |
+
" visible_state = {\n",
|
| 59 |
+
" 'step': step,\n",
|
| 60 |
+
" 'phase': obs.phase,\n",
|
| 61 |
+
" 'status': obs.status,\n",
|
| 62 |
+
" 'remaining_required_fields': obs.remaining_required_fields,\n",
|
| 63 |
+
" 'available_tools': obs.available_tools,\n",
|
| 64 |
+
" 'feedback_summary': obs.feedback_summary,\n",
|
| 65 |
+
" 'message': obs.message,\n",
|
| 66 |
+
" 'last_tool_result': obs.last_tool_result,\n",
|
| 67 |
+
" 'buyer_history': obs.buyer_conversation_history[-4:],\n",
|
| 68 |
+
" 'seller_history': obs.seller_conversation_history[-4:],\n",
|
| 69 |
+
" }\n",
|
| 70 |
+
"\n",
|
| 71 |
+
" return (\n",
|
| 72 |
+
" 'Return exactly one JSON object.\\\\n'\n",
|
| 73 |
+
" 'Schema: {\"action_type\":\"assistant_message\",\"assistant_message\":\"...\"} or '\n",
|
| 74 |
+
" '{\"action_type\":\"tool_call\",\"tool_name\":\"...\",\"tool_arguments\":{...}}\\\\n\\\\n'\n",
|
| 75 |
+
" f'Observation:\\n{json.dumps(visible_state, ensure_ascii=False, indent=2)}\\n'\n",
|
| 76 |
+
" 'Return JSON only.'\n",
|
| 77 |
+
" )\n",
|
| 78 |
+
"\n",
|
| 79 |
+
"rows = []\n",
|
| 80 |
+
"for scenario_id in TARGET_SCENARIOS:\n",
|
| 81 |
+
" env = FlatmateRlEnvironment()\n",
|
| 82 |
+
" obs = env.reset(scenario_id=scenario_id)\n",
|
| 83 |
+
" for step in (1, 2):\n",
|
| 84 |
+
" payload = expected_policy_action(scenario_id, obs.model_dump())\n",
|
| 85 |
+
" if payload is None:\n",
|
| 86 |
+
" break\n",
|
| 87 |
+
" rows.append(\n",
|
| 88 |
+
" {\n",
|
| 89 |
+
" 'scenario_id': scenario_id,\n",
|
| 90 |
+
" 'step': step,\n",
|
| 91 |
+
" 'prompt': format_prompt(obs, step),\n",
|
| 92 |
+
" 'expected_action': payload,\n",
|
| 93 |
+
" }\n",
|
| 94 |
+
" )\n",
|
| 95 |
+
" obs = env.step(FlatmateRlAction.model_validate(payload))\n",
|
| 96 |
+
"\n",
|
| 97 |
+
"train_ds = Dataset.from_list(rows)\n",
|
| 98 |
+
"train_ds[:2]\n"
|
| 99 |
+
]
|
| 100 |
+
},
|
| 101 |
+
{
|
| 102 |
+
"cell_type": "code",
|
| 103 |
+
"execution_count": null,
|
| 104 |
+
"metadata": {},
|
| 105 |
+
"outputs": [],
|
| 106 |
+
"source": [
|
| 107 |
+
"def score_completion(example, completion_text: str) -> float:\n",
|
| 108 |
+
" try:\n",
|
| 109 |
+
" action = json.loads(completion_text)\n",
|
| 110 |
+
" except json.JSONDecodeError:\n",
|
| 111 |
+
" return -0.25\n",
|
| 112 |
+
"\n",
|
| 113 |
+
" step = int(example['step'])\n",
|
| 114 |
+
" expected = example['expected_action']\n",
|
| 115 |
+
"\n",
|
| 116 |
+
" if step == 1:\n",
|
| 117 |
+
" message = str(action.get('assistant_message', '')).lower()\n",
|
| 118 |
+
" if action.get('action_type') == 'assistant_message' and 'diet' in message and 'availability' in message:\n",
|
| 119 |
+
" return 1.0\n",
|
| 120 |
+
" return -0.1\n",
|
| 121 |
+
"\n",
|
| 122 |
+
" if step == 2:\n",
|
| 123 |
+
" if action.get('action_type') == 'tool_call' and action.get('tool_name') == expected.get('tool_name'):\n",
|
| 124 |
+
" return 1.0\n",
|
| 125 |
+
" return -0.2\n",
|
| 126 |
+
"\n",
|
| 127 |
+
" return 0.0\n",
|
| 128 |
+
"\n",
|
| 129 |
+
"for row in rows[:2]:\n",
|
| 130 |
+
" print(row['scenario_id'], row['step'], score_completion(row, json.dumps(row['expected_action'])))\n"
|
| 131 |
+
]
|
| 132 |
+
},
|
| 133 |
+
{
|
| 134 |
+
"cell_type": "code",
|
| 135 |
+
"execution_count": null,
|
| 136 |
+
"metadata": {},
|
| 137 |
+
"outputs": [],
|
| 138 |
+
"source": [
|
| 139 |
+
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
| 140 |
+
"\n",
|
| 141 |
+
"model_name = 'Qwen/Qwen2.5-0.5B-Instruct'\n",
|
| 142 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
|
| 143 |
+
"model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto')\n",
|
| 144 |
+
"\n",
|
| 145 |
+
"from peft import LoraConfig\n",
|
| 146 |
+
"from trl import GRPOConfig, GRPOTrainer\n",
|
| 147 |
+
"\n",
|
| 148 |
+
"grpo_args = GRPOConfig(\n",
|
| 149 |
+
" output_dir='flatmate_grpo_step2',\n",
|
| 150 |
+
" learning_rate=1e-5,\n",
|
| 151 |
+
" per_device_train_batch_size=1,\n",
|
| 152 |
+
" gradient_accumulation_steps=4,\n",
|
| 153 |
+
" max_prompt_length=1024,\n",
|
| 154 |
+
" max_completion_length=256,\n",
|
| 155 |
+
" num_generations=4,\n",
|
| 156 |
+
" logging_steps=1,\n",
|
| 157 |
+
" save_steps=25,\n",
|
| 158 |
+
")\n",
|
| 159 |
+
"\n",
|
| 160 |
+
"lora_config = LoraConfig(\n",
|
| 161 |
+
" r=8,\n",
|
| 162 |
+
" lora_alpha=16,\n",
|
| 163 |
+
" lora_dropout=0.05,\n",
|
| 164 |
+
" bias='none',\n",
|
| 165 |
+
" task_type='CAUSAL_LM',\n",
|
| 166 |
+
")\n",
|
| 167 |
+
"\n",
|
| 168 |
+
"def reward_func(prompts, completions, **kwargs):\n",
|
| 169 |
+
" rewards = []\n",
|
| 170 |
+
" examples = kwargs['examples']\n",
|
| 171 |
+
" for example, completion in zip(examples, completions):\n",
|
| 172 |
+
" rewards.append(score_completion(example, completion))\n",
|
| 173 |
+
" return rewards\n",
|
| 174 |
+
"\n",
|
| 175 |
+
"# Starter training block.\n",
|
| 176 |
+
"# If your installed TRL version expects a slightly different GRPOTrainer signature,\n",
|
| 177 |
+
"# keep the dataset, reward, and LoRA config from above and adapt only the constructor call.\n",
|
| 178 |
+
"trainer = GRPOTrainer(\n",
|
| 179 |
+
" model=model,\n",
|
| 180 |
+
" tokenizer=tokenizer,\n",
|
| 181 |
+
" args=grpo_args,\n",
|
| 182 |
+
" train_dataset=train_ds,\n",
|
| 183 |
+
" reward_funcs=[reward_func],\n",
|
| 184 |
+
" peft_config=lora_config,\n",
|
| 185 |
+
")\n",
|
| 186 |
+
"\n",
|
| 187 |
+
"# trainer.train()\n",
|
| 188 |
+
"print('GRPO trainer configured for the step-1/step-2 curriculum')\n"
|
| 189 |
+
]
|
| 190 |
+
}
|
| 191 |
+
],
|
| 192 |
+
"metadata": {
|
| 193 |
+
"kernelspec": {
|
| 194 |
+
"display_name": "Python 3",
|
| 195 |
+
"language": "python",
|
| 196 |
+
"name": "python3"
|
| 197 |
+
},
|
| 198 |
+
"language_info": {
|
| 199 |
+
"name": "python",
|
| 200 |
+
"version": "3.12"
|
| 201 |
+
}
|
| 202 |
+
},
|
| 203 |
+
"nbformat": 4,
|
| 204 |
+
"nbformat_minor": 5
|
| 205 |
+
}
|