kushalExplores commited on
Commit
dbb1ce2
·
verified ·
1 Parent(s): 0594d27

Add step-2 GRPO notebook and hidden-flex fix

Browse files
inference.py CHANGED
@@ -71,6 +71,30 @@ SYSTEM_PROMPT = textwrap.dedent(
71
  - If a tool can perform the next required operation, call the tool immediately.
72
  - Do not send acknowledgement or progress messages such as "I will search now" when a tool call is needed.
73
  - Prefer safe, incremental progress toward storing user details, matching listings, and booking visits.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  """
75
  ).strip()
76
 
@@ -371,12 +395,20 @@ def build_user_prompt(step: int, observation: Any) -> str:
371
  Step: {step}
372
  Phase: {observation.phase}
373
  Status: {observation.status}
 
 
 
 
 
 
374
  Available tools: {observation.available_tools}
375
  Last tool result: {json.dumps(last_tool_result, ensure_ascii=False)}
376
  Prerequisites satisfied: {json.dumps(observation.prerequisites_satisfied, ensure_ascii=False)}
377
  Recent tool calls: {json.dumps(observation.recent_tool_calls, ensure_ascii=False)}
378
  Booked visits: {observation.booked_visits}
379
 
 
 
380
  Buyer/Broker transcript:
381
  {json.dumps(observation.buyer_conversation_history[-8:], ensure_ascii=False)}
382
 
 
71
  - If a tool can perform the next required operation, call the tool immediately.
72
  - Do not send acknowledgement or progress messages such as "I will search now" when a tool call is needed.
73
  - Prefer safe, incremental progress toward storing user details, matching listings, and booking visits.
74
+ - Use exact tool argument names from the prompt. Never invent aliases such as visit_time.
75
+ - Treat negative reward, violations, and feedback_summary as corrective feedback for the next action.
76
+ """
77
+ ).strip()
78
+
79
+
80
+ TOOL_CONTRACT_PROMPT = textwrap.dedent(
81
+ """
82
+ Tool argument contract:
83
+ - store_user_details: tool_arguments can be {} after required buyer fields are gathered.
84
+ - search_posts: tool_arguments can be {}.
85
+ - match_location_preference: {"post_ids":["post_id", ...]}.
86
+ - get_commute_time: {"post_ids":["post_id", ...]}.
87
+ - check_calendar_slots: {"post_ids":["post_id", ...]}.
88
+ - shortlist: {"post_ids":["post_id", ...]}.
89
+ - contact_poster: {"post_id":"post_id","time_text":"exact slot from check_calendar_slots"}. This shares the buyer profile with the seller/poster and asks them to confirm profile fit plus visit time.
90
+ - book_viewing: {"post_id":"post_id","time_text":"same exact slot confirmed by buyer and poster"}.
91
+
92
+ Booking workflow:
93
+ 1. Ask for missing buyer fields before store_user_details.
94
+ 2. Store buyer details, search posts, match location, get commute time, then check calendar slots.
95
+ 3. Ask the buyer to confirm one exact slot from check_calendar_slots.
96
+ 4. Call contact_poster with post_id and time_text for that same slot.
97
+ 5. Only after buyer_confirmed and poster_confirmed are true, call book_viewing with post_id and time_text.
98
  """
99
  ).strip()
100
 
 
395
  Step: {step}
396
  Phase: {observation.phase}
397
  Status: {observation.status}
398
+ Feedback summary: {observation.feedback_summary}
399
+ Environment message: {observation.message}
400
+ Step reward: {observation.step_reward}
401
+ Total reward: {observation.total_reward}
402
+ Violations: {observation.violations}
403
+ Remaining required fields: {observation.remaining_required_fields}
404
  Available tools: {observation.available_tools}
405
  Last tool result: {json.dumps(last_tool_result, ensure_ascii=False)}
406
  Prerequisites satisfied: {json.dumps(observation.prerequisites_satisfied, ensure_ascii=False)}
407
  Recent tool calls: {json.dumps(observation.recent_tool_calls, ensure_ascii=False)}
408
  Booked visits: {observation.booked_visits}
409
 
410
+ {TOOL_CONTRACT_PROMPT}
411
+
412
  Buyer/Broker transcript:
413
  {json.dumps(observation.buyer_conversation_history[-8:], ensure_ascii=False)}
414
 
server/episode.py CHANGED
@@ -95,6 +95,7 @@ class FlatmateEpisode:
95
  self._commutes_checked: dict[str, int] = {}
96
  self._poster_confirmations: dict[str, str] = {}
97
  self._client_confirmations: dict[str, str] = {}
 
98
  self._seller_confirmations: dict[str, str] = {}
99
  self._buyer_offer_confirmations: dict[str, str] = {}
100
  self._dynamic_post_id: str | None = None
@@ -136,6 +137,7 @@ class FlatmateEpisode:
136
  self._commutes_checked = {}
137
  self._poster_confirmations = {}
138
  self._client_confirmations = {}
 
139
  self._seller_confirmations = {}
140
  self._buyer_offer_confirmations = {}
141
  self._dynamic_post_id = None
@@ -259,6 +261,12 @@ class FlatmateEpisode:
259
  slots.extend(profile["hidden_additional_availability"])
260
  return slots
261
 
 
 
 
 
 
 
262
  def _record_violation(self, text: str) -> None:
263
  if text not in self._violations:
264
  self._violations.append(text)
@@ -496,8 +504,10 @@ class FlatmateEpisode:
496
  self._state.gathered_fields.append("hidden_flex_revealed")
497
  if alternatives_offered:
498
  if "sunday 5pm" in lowered:
 
499
  return "I can make Sunday 5pm work, so I confirm Sunday 5pm."
500
  if "saturday 1pm" in lowered:
 
501
  return "Saturday 1pm works for me too, so I confirm Saturday 1pm."
502
 
503
  # Scenario 2: waitlist — fire cancellation notification on first message after add_to_waitlist
@@ -861,6 +871,16 @@ class FlatmateEpisode:
861
  self._state.selected_posts = post_ids
862
  return {"tool": "shortlist", "success": True, "message": "Posts shortlisted.", "selected_posts": post_ids}
863
 
 
 
 
 
 
 
 
 
 
 
864
  def _tool_contact_poster(self, arguments: dict[str, Any]) -> dict[str, Any]:
865
  post_id = arguments.get("post_id", "")
866
  time_text = arguments.get("time_text", "")
@@ -872,20 +892,34 @@ class FlatmateEpisode:
872
  return {"tool": "contact_poster", "success": False, "message": "Time must come from check_calendar_slots."}
873
  self._seller_history.append(
874
  {
875
- "role": "user",
876
- "content": f"Client selected {post_id}. Can we visit at {time_text}?",
 
 
 
877
  }
878
  )
879
  self._poster_confirmations[post_id] = time_text
880
- poster_message = f"Yes, confirmed. {time_text} works for the visit."
881
- self._seller_history.append({"role": "assistant", "content": poster_message})
882
- return {"tool": "contact_poster", "success": True, "message": f"Poster confirmed {time_text}.", "post_id": post_id, "time_text": time_text}
 
 
 
 
 
 
 
 
 
883
 
884
  def _tool_book_viewing(self, arguments: dict[str, Any]) -> dict[str, Any]:
885
  post_id = arguments.get("post_id", "")
886
  time_text = arguments.get("time_text", "")
887
  if post_id not in self._poster_confirmations or self._poster_confirmations[post_id] != time_text:
888
  return {"tool": "book_viewing", "success": False, "message": "Poster has not explicitly confirmed this time."}
 
 
889
  if post_id not in self._client_confirmations or self._client_confirmations[post_id] != time_text:
890
  return {"tool": "book_viewing", "success": False, "message": "Client has not explicitly confirmed this time."}
891
  if self._scenario["task_id"] == "task_visit_multi" and post_id not in self._state.selected_posts:
@@ -940,8 +974,15 @@ class FlatmateEpisode:
940
  config = self._scenario["scenario_creation_config"].get("negotiation_config", {})
941
  seller_floor = config.get("seller_floor", 0)
942
  self._negotiation_rounds_seller += 1
 
 
 
 
 
 
943
  if proposed_rent >= seller_floor:
944
  self._seller_price_accepted = proposed_rent
 
945
  return {
946
  "tool": "propose_price_to_seller",
947
  "success": True,
@@ -950,6 +991,7 @@ class FlatmateEpisode:
950
  "proposed_rent": proposed_rent,
951
  }
952
  hint = " Maybe a small discount is possible." if self._negotiation_rounds_seller >= 2 else ""
 
953
  return {
954
  "tool": "propose_price_to_seller",
955
  "success": True,
@@ -1167,9 +1209,9 @@ class FlatmateEpisode:
1167
  post = self._resolve_post(post_id)
1168
  if not post or time_text not in post["calendar_slots"]:
1169
  return {"tool": "confirm_seller_match", "success": False, "message": "Selected seller slot is invalid."}
1170
- self._seller_history.append({"role": "user", "content": f"Can we confirm {time_text} for {post_id}?"})
1171
  self._seller_confirmations[post_id] = time_text
1172
- self._seller_history.append({"role": "assistant", "content": f"Confirmed, {time_text} works from the seller side."})
1173
  return {"tool": "confirm_seller_match", "success": True, "message": f"Seller confirmed {time_text}.", "post_id": post_id, "time_text": time_text}
1174
 
1175
  def _tool_offer_matched_listing_to_buyer(self, arguments: dict[str, Any]) -> dict[str, Any]:
 
95
  self._commutes_checked: dict[str, int] = {}
96
  self._poster_confirmations: dict[str, str] = {}
97
  self._client_confirmations: dict[str, str] = {}
98
+ self._seller_profile_fit_confirmations: dict[str, bool] = {}
99
  self._seller_confirmations: dict[str, str] = {}
100
  self._buyer_offer_confirmations: dict[str, str] = {}
101
  self._dynamic_post_id: str | None = None
 
137
  self._commutes_checked = {}
138
  self._poster_confirmations = {}
139
  self._client_confirmations = {}
140
+ self._seller_profile_fit_confirmations = {}
141
  self._seller_confirmations = {}
142
  self._buyer_offer_confirmations = {}
143
  self._dynamic_post_id = None
 
261
  slots.extend(profile["hidden_additional_availability"])
262
  return slots
263
 
264
+ def _record_client_confirmation_for_slot(self, slot: str) -> None:
265
+ for post_id, checked_slots in self._slots_checked.items():
266
+ if slot in checked_slots:
267
+ self._client_confirmations[post_id] = slot
268
+ return
269
+
270
  def _record_violation(self, text: str) -> None:
271
  if text not in self._violations:
272
  self._violations.append(text)
 
504
  self._state.gathered_fields.append("hidden_flex_revealed")
505
  if alternatives_offered:
506
  if "sunday 5pm" in lowered:
507
+ self._record_client_confirmation_for_slot("Sunday 5pm")
508
  return "I can make Sunday 5pm work, so I confirm Sunday 5pm."
509
  if "saturday 1pm" in lowered:
510
+ self._record_client_confirmation_for_slot("Saturday 1pm")
511
  return "Saturday 1pm works for me too, so I confirm Saturday 1pm."
512
 
513
  # Scenario 2: waitlist — fire cancellation notification on first message after add_to_waitlist
 
871
  self._state.selected_posts = post_ids
872
  return {"tool": "shortlist", "success": True, "message": "Posts shortlisted.", "selected_posts": post_ids}
873
 
874
+ def _buyer_profile_summary_for_seller(self) -> str:
875
+ profile = self._scenario["buyer_profile"]
876
+ return (
877
+ f"buyer profile: budget up to Rs. {profile['budget_max']}; "
878
+ f"dietary preference {profile['dietary']}; "
879
+ f"preferred areas {', '.join(profile['areas'])}; "
880
+ f"occupation {profile['occupation']}; "
881
+ f"visit availability {', '.join(profile['visit_availability'])}"
882
+ )
883
+
884
  def _tool_contact_poster(self, arguments: dict[str, Any]) -> dict[str, Any]:
885
  post_id = arguments.get("post_id", "")
886
  time_text = arguments.get("time_text", "")
 
892
  return {"tool": "contact_poster", "success": False, "message": "Time must come from check_calendar_slots."}
893
  self._seller_history.append(
894
  {
895
+ "role": "assistant",
896
+ "content": (
897
+ f"Client selected {post_id}. Please review this {self._buyer_profile_summary_for_seller()}. "
898
+ f"Can you confirm the buyer profile is acceptable and that we can visit at {time_text}?"
899
+ ),
900
  }
901
  )
902
  self._poster_confirmations[post_id] = time_text
903
+ self._seller_profile_fit_confirmations[post_id] = True
904
+ poster_message = f"Yes, confirmed. The buyer profile is acceptable and {time_text} works for the visit."
905
+ self._seller_history.append({"role": "user", "content": poster_message})
906
+ return {
907
+ "tool": "contact_poster",
908
+ "success": True,
909
+ "message": f"Poster confirmed buyer profile fit and {time_text}.",
910
+ "post_id": post_id,
911
+ "time_text": time_text,
912
+ "buyer_profile_shared": True,
913
+ "seller_profile_fit_confirmed": True,
914
+ }
915
 
916
  def _tool_book_viewing(self, arguments: dict[str, Any]) -> dict[str, Any]:
917
  post_id = arguments.get("post_id", "")
918
  time_text = arguments.get("time_text", "")
919
  if post_id not in self._poster_confirmations or self._poster_confirmations[post_id] != time_text:
920
  return {"tool": "book_viewing", "success": False, "message": "Poster has not explicitly confirmed this time."}
921
+ if not self._seller_profile_fit_confirmations.get(post_id):
922
+ return {"tool": "book_viewing", "success": False, "message": "Poster has not confirmed the buyer profile fit."}
923
  if post_id not in self._client_confirmations or self._client_confirmations[post_id] != time_text:
924
  return {"tool": "book_viewing", "success": False, "message": "Client has not explicitly confirmed this time."}
925
  if self._scenario["task_id"] == "task_visit_multi" and post_id not in self._state.selected_posts:
 
974
  config = self._scenario["scenario_creation_config"].get("negotiation_config", {})
975
  seller_floor = config.get("seller_floor", 0)
976
  self._negotiation_rounds_seller += 1
977
+ self._seller_history.append(
978
+ {
979
+ "role": "assistant",
980
+ "content": f"The buyer is interested in {post_id}. Would you accept Rs. {proposed_rent}?",
981
+ }
982
+ )
983
  if proposed_rent >= seller_floor:
984
  self._seller_price_accepted = proposed_rent
985
+ self._seller_history.append({"role": "user", "content": f"Yes, I can accept Rs. {proposed_rent}."})
986
  return {
987
  "tool": "propose_price_to_seller",
988
  "success": True,
 
991
  "proposed_rent": proposed_rent,
992
  }
993
  hint = " Maybe a small discount is possible." if self._negotiation_rounds_seller >= 2 else ""
994
+ self._seller_history.append({"role": "user", "content": f"I can't go as low as Rs. {proposed_rent}.{hint}"})
995
  return {
996
  "tool": "propose_price_to_seller",
997
  "success": True,
 
1209
  post = self._resolve_post(post_id)
1210
  if not post or time_text not in post["calendar_slots"]:
1211
  return {"tool": "confirm_seller_match", "success": False, "message": "Selected seller slot is invalid."}
1212
+ self._seller_history.append({"role": "assistant", "content": f"Can we confirm {time_text} for {post_id}?"})
1213
  self._seller_confirmations[post_id] = time_text
1214
+ self._seller_history.append({"role": "user", "content": f"Confirmed, {time_text} works from the seller side."})
1215
  return {"tool": "confirm_seller_match", "success": True, "message": f"Seller confirmed {time_text}.", "post_id": post_id, "time_text": time_text}
1216
 
1217
  def _tool_offer_matched_listing_to_buyer(self, arguments: dict[str, Any]) -> dict[str, Any]:
tests/test_flatmate_rl.py CHANGED
@@ -200,8 +200,17 @@ def test_single_visit_scenario_books_one_visit() -> None:
200
  assert final_obs.done is True
201
  assert final_obs.booked_visits == [{"post_id": "post_023", "time": "Saturday 11am"}]
202
  assert len(final_obs.seller_conversation_history) >= 2
203
- assert "Can we visit at Saturday 11am" in final_obs.seller_conversation_history[0]["content"]
 
 
 
 
 
 
204
  assert "Saturday 11am works for the visit" in final_obs.seller_conversation_history[1]["content"]
 
 
 
205
 
206
 
207
  def test_buyer_answers_diet_and_availability_when_broker_asks_for_both() -> None:
@@ -430,6 +439,10 @@ def test_hidden_flex_requires_alternative_slot_to_unlock_backup_availability() -
430
  obs = _msg(env, "No Tuesday slot matches. I can offer Saturday 1pm or Sunday 5pm instead.")
431
  assert "confirm" in obs.last_user_message.lower()
432
  assert "Sunday 5pm" in obs.last_user_message or "Saturday 1pm" in obs.last_user_message
 
 
 
 
433
 
434
 
435
  def test_multi_visit_scenario_books_two_visits() -> None:
@@ -533,3 +546,5 @@ def test_negotiation_heuristic_confirms_deal_with_agreed_rent() -> None:
533
  assert obs.status == "completed"
534
  assert obs.booked_visits == [{"post_id": "post_155", "time": "negotiated_deal", "agreed_rent": 21000}]
535
  assert obs.last_tool_result["tool"] == "confirm_negotiated_deal"
 
 
 
200
  assert final_obs.done is True
201
  assert final_obs.booked_visits == [{"post_id": "post_023", "time": "Saturday 11am"}]
202
  assert len(final_obs.seller_conversation_history) >= 2
203
+ assert final_obs.seller_conversation_history[0]["role"] == "assistant"
204
+ assert final_obs.seller_conversation_history[1]["role"] == "user"
205
+ assert "buyer profile" in final_obs.seller_conversation_history[0]["content"]
206
+ assert "budget up to Rs. 20000" in final_obs.seller_conversation_history[0]["content"]
207
+ assert "Can you confirm the buyer profile is acceptable" in final_obs.seller_conversation_history[0]["content"]
208
+ assert "Saturday 11am" in final_obs.seller_conversation_history[0]["content"]
209
+ assert "buyer profile is acceptable" in final_obs.seller_conversation_history[1]["content"]
210
  assert "Saturday 11am works for the visit" in final_obs.seller_conversation_history[1]["content"]
211
+ contact_result = next(result for result in final_obs.tool_results if result["tool"] == "contact_poster")
212
+ assert contact_result["buyer_profile_shared"] is True
213
+ assert contact_result["seller_profile_fit_confirmed"] is True
214
 
215
 
216
  def test_buyer_answers_diet_and_availability_when_broker_asks_for_both() -> None:
 
439
  obs = _msg(env, "No Tuesday slot matches. I can offer Saturday 1pm or Sunday 5pm instead.")
440
  assert "confirm" in obs.last_user_message.lower()
441
  assert "Sunday 5pm" in obs.last_user_message or "Saturday 1pm" in obs.last_user_message
442
+ _tool(env, "contact_poster", post_id="post_023", time_text="Sunday 5pm")
443
+ obs = _tool(env, "book_viewing", post_id="post_023", time_text="Sunday 5pm")
444
+ assert obs.done is True
445
+ assert obs.booked_visits == [{"post_id": "post_023", "time": "Sunday 5pm"}]
446
 
447
 
448
  def test_multi_visit_scenario_books_two_visits() -> None:
 
546
  assert obs.status == "completed"
547
  assert obs.booked_visits == [{"post_id": "post_155", "time": "negotiated_deal", "agreed_rent": 21000}]
548
  assert obs.last_tool_result["tool"] == "confirm_negotiated_deal"
549
+ assert any("Would you accept Rs. 21000" in item["content"] for item in obs.seller_conversation_history)
550
+ assert any("I can accept Rs. 21000" in item["content"] for item in obs.seller_conversation_history)
tests/test_reward_regression.py CHANGED
@@ -9,7 +9,7 @@ from flatmate_rl.server.heuristic_policy import expected_policy_action
9
 
10
  HEURISTIC_BASELINES = {
11
  "task_visit_single": 0.70,
12
- "task_visit_single_hidden_flex": -1.70,
13
  "task_visit_multi": 1.10,
14
  "task_visit_single_seller_followup": 0.90,
15
  }
 
9
 
10
  HEURISTIC_BASELINES = {
11
  "task_visit_single": 0.70,
12
+ "task_visit_single_hidden_flex": 0.90,
13
  "task_visit_multi": 1.10,
14
  "task_visit_single_seller_followup": 0.90,
15
  }
train_flatmate_rl_grpo_step2.ipynb ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Flatmate RL GRPO Step-2 Curriculum\n",
8
+ "\n",
9
+ "This notebook is a minimal GRPO starter for `flatmate_rl`.\n",
10
+ "It only trains the first two workflow steps:\n",
11
+ "\n",
12
+ "1. ask for the missing buyer details\n",
13
+ "2. store the buyer profile\n",
14
+ "\n",
15
+ "The goal is to keep the reward simple enough to bootstrap the broker policy before training on later booking steps."
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "code",
20
+ "execution_count": null,
21
+ "metadata": {},
22
+ "outputs": [],
23
+ "source": [
24
+ "%pip install -q trl transformers accelerate datasets peft bitsandbytes sentencepiece\n",
25
+ "\n",
26
+ "from __future__ import annotations\n",
27
+ "\n",
28
+ "import json\n",
29
+ "import sys\n",
30
+ "from pathlib import Path\n",
31
+ "\n",
32
+ "repo_root = Path.cwd().resolve().parent\n",
33
+ "if str(repo_root) not in sys.path:\n",
34
+ " sys.path.insert(0, str(repo_root))\n",
35
+ "\n",
36
+ "from datasets import Dataset\n",
37
+ "from flatmate_rl import FlatmateRlAction\n",
38
+ "from flatmate_rl.server.flatmate_rl_environment import FlatmateRlEnvironment\n",
39
+ "from flatmate_rl.server.heuristic_policy import expected_policy_action\n",
40
+ "\n",
41
+ "print('imports ready')"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": null,
47
+ "metadata": {},
48
+ "outputs": [],
49
+ "source": [
50
+ "TARGET_SCENARIOS = [\n",
51
+ " 'task_visit_single',\n",
52
+ " 'task_visit_single_hidden_flex',\n",
53
+ " 'task_visit_multi',\n",
54
+ " 'task_visit_single_seller_followup',\n",
55
+ "]\n",
56
+ "\n",
57
+ "def format_prompt(obs, step: int) -> str:\n",
58
+ " visible_state = {\n",
59
+ " 'step': step,\n",
60
+ " 'phase': obs.phase,\n",
61
+ " 'status': obs.status,\n",
62
+ " 'remaining_required_fields': obs.remaining_required_fields,\n",
63
+ " 'available_tools': obs.available_tools,\n",
64
+ " 'feedback_summary': obs.feedback_summary,\n",
65
+ " 'message': obs.message,\n",
66
+ " 'last_tool_result': obs.last_tool_result,\n",
67
+ " 'buyer_history': obs.buyer_conversation_history[-4:],\n",
68
+ " 'seller_history': obs.seller_conversation_history[-4:],\n",
69
+ " }\n",
70
+ "\n",
71
+ " return (\n",
72
+ " 'Return exactly one JSON object.\\\\n'\n",
73
+ " 'Schema: {\"action_type\":\"assistant_message\",\"assistant_message\":\"...\"} or '\n",
74
+ " '{\"action_type\":\"tool_call\",\"tool_name\":\"...\",\"tool_arguments\":{...}}\\\\n\\\\n'\n",
75
+ " f'Observation:\\n{json.dumps(visible_state, ensure_ascii=False, indent=2)}\\n'\n",
76
+ " 'Return JSON only.'\n",
77
+ " )\n",
78
+ "\n",
79
+ "rows = []\n",
80
+ "for scenario_id in TARGET_SCENARIOS:\n",
81
+ " env = FlatmateRlEnvironment()\n",
82
+ " obs = env.reset(scenario_id=scenario_id)\n",
83
+ " for step in (1, 2):\n",
84
+ " payload = expected_policy_action(scenario_id, obs.model_dump())\n",
85
+ " if payload is None:\n",
86
+ " break\n",
87
+ " rows.append(\n",
88
+ " {\n",
89
+ " 'scenario_id': scenario_id,\n",
90
+ " 'step': step,\n",
91
+ " 'prompt': format_prompt(obs, step),\n",
92
+ " 'expected_action': payload,\n",
93
+ " }\n",
94
+ " )\n",
95
+ " obs = env.step(FlatmateRlAction.model_validate(payload))\n",
96
+ "\n",
97
+ "train_ds = Dataset.from_list(rows)\n",
98
+ "train_ds[:2]\n"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "execution_count": null,
104
+ "metadata": {},
105
+ "outputs": [],
106
+ "source": [
107
+ "def score_completion(example, completion_text: str) -> float:\n",
108
+ " try:\n",
109
+ " action = json.loads(completion_text)\n",
110
+ " except json.JSONDecodeError:\n",
111
+ " return -0.25\n",
112
+ "\n",
113
+ " step = int(example['step'])\n",
114
+ " expected = example['expected_action']\n",
115
+ "\n",
116
+ " if step == 1:\n",
117
+ " message = str(action.get('assistant_message', '')).lower()\n",
118
+ " if action.get('action_type') == 'assistant_message' and 'diet' in message and 'availability' in message:\n",
119
+ " return 1.0\n",
120
+ " return -0.1\n",
121
+ "\n",
122
+ " if step == 2:\n",
123
+ " if action.get('action_type') == 'tool_call' and action.get('tool_name') == expected.get('tool_name'):\n",
124
+ " return 1.0\n",
125
+ " return -0.2\n",
126
+ "\n",
127
+ " return 0.0\n",
128
+ "\n",
129
+ "for row in rows[:2]:\n",
130
+ " print(row['scenario_id'], row['step'], score_completion(row, json.dumps(row['expected_action'])))\n"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": null,
136
+ "metadata": {},
137
+ "outputs": [],
138
+ "source": [
139
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
140
+ "\n",
141
+ "model_name = 'Qwen/Qwen2.5-0.5B-Instruct'\n",
142
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
143
+ "model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto')\n",
144
+ "\n",
145
+ "from peft import LoraConfig\n",
146
+ "from trl import GRPOConfig, GRPOTrainer\n",
147
+ "\n",
148
+ "grpo_args = GRPOConfig(\n",
149
+ " output_dir='flatmate_grpo_step2',\n",
150
+ " learning_rate=1e-5,\n",
151
+ " per_device_train_batch_size=1,\n",
152
+ " gradient_accumulation_steps=4,\n",
153
+ " max_prompt_length=1024,\n",
154
+ " max_completion_length=256,\n",
155
+ " num_generations=4,\n",
156
+ " logging_steps=1,\n",
157
+ " save_steps=25,\n",
158
+ ")\n",
159
+ "\n",
160
+ "lora_config = LoraConfig(\n",
161
+ " r=8,\n",
162
+ " lora_alpha=16,\n",
163
+ " lora_dropout=0.05,\n",
164
+ " bias='none',\n",
165
+ " task_type='CAUSAL_LM',\n",
166
+ ")\n",
167
+ "\n",
168
+ "def reward_func(prompts, completions, **kwargs):\n",
169
+ " rewards = []\n",
170
+ " examples = kwargs['examples']\n",
171
+ " for example, completion in zip(examples, completions):\n",
172
+ " rewards.append(score_completion(example, completion))\n",
173
+ " return rewards\n",
174
+ "\n",
175
+ "# Starter training block.\n",
176
+ "# If your installed TRL version expects a slightly different GRPOTrainer signature,\n",
177
+ "# keep the dataset, reward, and LoRA config from above and adapt only the constructor call.\n",
178
+ "trainer = GRPOTrainer(\n",
179
+ " model=model,\n",
180
+ " tokenizer=tokenizer,\n",
181
+ " args=grpo_args,\n",
182
+ " train_dataset=train_ds,\n",
183
+ " reward_funcs=[reward_func],\n",
184
+ " peft_config=lora_config,\n",
185
+ ")\n",
186
+ "\n",
187
+ "# trainer.train()\n",
188
+ "print('GRPO trainer configured for the step-1/step-2 curriculum')\n"
189
+ ]
190
+ }
191
+ ],
192
+ "metadata": {
193
+ "kernelspec": {
194
+ "display_name": "Python 3",
195
+ "language": "python",
196
+ "name": "python3"
197
+ },
198
+ "language_info": {
199
+ "name": "python",
200
+ "version": "3.12"
201
+ }
202
+ },
203
+ "nbformat": 4,
204
+ "nbformat_minor": 5
205
+ }