"""
Dataset generator for RhythmEnv GRPO training (meta-RL version).

Plays episodes under a continuously sampled profile per seed and emits
observation prompts at each step, paired with the replay metadata
(seed, step_index, action_history) the reward functions need to
reconstruct env state deterministically.

The system prompt asks for "S M W ACTION_NAME" — three belief digits then
the action. A `hint_fraction` slice of episodes carries a true-belief hint
in the prompt as a curriculum warmup; the rest force pure inference.
"""

import sys
import os
import random

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from models import ActionType, RhythmAction
from server.rhythm_environment import RhythmEnvironment, MAX_STEPS, METERS

SLOT_NAMES = ["Morning", "Afternoon", "Evening", "Night"]
DAY_NAMES = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"]

SYSTEM_PROMPT = """You are a life-management agent helping a person whose preferences are HIDDEN.
You see 5 life meters and a rolling history of recent steps. The same action
affects different people differently — you must INFER who you're helping from
rewards, meter changes, and per-meter ANOMALY signals.

Each step, do TWO things:

1. Reason briefly about what the observations imply about the person.
   Focus on:
     - Anomalies (actual delta vs neutral-profile expectation): big positive
       social_serenity / connection responses → high S; big morning cognition
       gains → high M; productive work giving vitality back → high W
     - Current meter state: any meter under 0.15 needs urgent recovery
     - What action best fits BOTH the inferred profile and the current state

2. Output your final answer on the LAST line in this exact format:
       S M W ACTION_NAME
   where S, M, W are belief digits 0-9 (0=low, 9=high) representing your best
   estimate of social_pref, morning_pref, work_pref. ACTION_NAME is one of:
   DEEP_WORK, ADMIN_WORK, LEARN, SLEEP, EXERCISE, MEDITATE, FAMILY_TIME,
   SOCIALIZE, ME_TIME, BINGE_WATCH

Wrap your reasoning in <reasoning>...</reasoning> tags. Keep reasoning under
120 tokens. The final answer line MUST be the last line of your response.

Belief→action quick reference:
- High S (extrovert): SOCIALIZE, FAMILY_TIME boost connection cheaply
- High M (morning person): DEEP_WORK / LEARN in early slots gets bonus cognition
- High W (workaholic): DEEP_WORK, LEARN drive progress and may energize
- Low S (introvert): MEDITATE, ME_TIME for solo recharge; avoid SOCIALIZE
- Low M (night owl): DEEP_WORK / LEARN in evening/night slots
- Watch crashes: any meter under 0.10 = -0.30 penalty per crashed meter
- Connection decays passively — actively maintain via SOCIALIZE/FAMILY_TIME
- Don't repeat the same action 3+ times in a row — repetition penalty applies

Strategy: probe varied actions in the first ~5 steps to gather profile evidence,
then exploit your sharpened belief by picking actions that match the inferred
profile + current meter state.

Example output:
<reasoning>
Last step's socialize gave V-0.12 (anom -0.06, much worse than neutral) — high
social drain, suggests low S. Morning DEEP_WORK earlier gave bonus cognition
(anom +0.04) → high M. Vitality at 0.6 still ok, serenity dropping. With low S +
high M, MEDITATE is the recovery play that fits.
</reasoning>
2 8 5 MEDITATE"""
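

# Minimal, illustrative parser for the answer line the system prompt demands.
# The real reward functions own the authoritative parser; this sketch assumes
# the ActionType member names match the ACTION_NAME tokens listed above.
def parse_final_answer(completion: str):
    """Return (s, m, w, action_name) from the last line, or None if malformed."""
    lines = completion.strip().splitlines()
    if not lines:
        return None
    parts = lines[-1].split()
    if len(parts) != 4:
        return None
    s, m, w, name = parts
    if not all(len(d) == 1 and d.isdigit() for d in (s, m, w)):
        return None
    if name not in {a.name for a in ActionType}:
        return None
    return int(s), int(m), int(w), name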


def format_observation_prompt(obs, profile_hint: dict | None = None) -> str:
    """Format an observation into a user prompt for the LLM.

    If profile_hint is provided (curriculum's "visible" phase), include it in
    the prompt so the agent learns the *skill* of using profile signals
    before having to infer them from scratch.
    """
    day_name = DAY_NAMES[obs.day] if obs.day < 7 else f"Day {obs.day}"
    slot_name = SLOT_NAMES[obs.slot] if obs.slot < 4 else f"Slot {obs.slot}"
    event_str = f"\nActive event: {obs.active_event}" if obs.active_event else ""

    history_lines = []
    for h in (obs.step_history or [])[-5:]:  # last 5 only to fit prompt budget
        # Per-meter anomalies (actual_delta − expected_under_neutral_profile)
        # are the cleanest profile-inference signal — they show how this person's
        # response DEVIATES from the average person. Surfacing them here in the
        # prompt is what gives the agent a fingerprint to learn from.
        anom_str = (
            f" [anom V{h.vitality_anomaly:+.2f} C{h.cognition_anomaly:+.2f} "
            f"P{h.progress_anomaly:+.2f} S{h.serenity_anomaly:+.2f} "
            f"Cn{h.connection_anomaly:+.2f}]"
        )
        history_lines.append(
            f"  step {h.step}: {h.action} -> reward {h.reward:+.2f} "
            f"(V{h.vitality_delta:+.2f} C{h.cognition_delta:+.2f} "
            f"P{h.progress_delta:+.2f} S{h.serenity_delta:+.2f} Cn{h.connection_delta:+.2f})"
            f"{anom_str}"
        )
    history_str = ""
    if history_lines:
        history_str = (
            "\n\nRecent history (anom = how this person deviated from neutral baseline):\n"
            + "\n".join(history_lines)
        )

    hint_str = ""
    if profile_hint is not None:
        hint_str = (
            f"\n\nKnown about this person (training hint):\n"
            f"  social_pref={profile_hint['social_pref']:.2f}, "
            f"morning_pref={profile_hint['morning_pref']:.2f}, "
            f"work_pref={profile_hint['work_pref']:.2f}"
        )

    return (
        f"Step: {obs.timestep}/{MAX_STEPS} ({day_name} {slot_name})\n"
        f"Remaining steps: {obs.remaining_steps}\n\n"
        f"Meters:\n"
        f"  Vitality:   {obs.vitality:.2f}\n"
        f"  Cognition:  {obs.cognition:.2f}\n"
        f"  Progress:   {obs.progress:.2f}\n"
        f"  Serenity:   {obs.serenity:.2f}\n"
        f"  Connection: {obs.connection:.2f}"
        f"{event_str}"
        f"{history_str}"
        f"{hint_str}\n\n"
        f"Output your belief, then your action (format: S M W ACTION_NAME):"
    )


def generate_episode_samples(
    seed: int,
    strategy: str = "random",
    profile_mode: str = "continuous",
    show_profile_hint: bool = False,
) -> list:
    """Play one episode and return a list of training samples.

    Each sample includes the prompt + replay metadata (seed, step_index,
    action_history, profile_mode) so reward functions can deterministically
    reconstruct the env state.

    Args:
        seed: Episode seed (also determines profile when profile_mode=continuous).
        strategy: "random" or "heuristic" — used to roll out the episode for
            state diversity. The agent's training generations replace these
            actions; we only need the prefix history for replay.
        profile_mode: "continuous" (sampled per seed) or "discrete" (1 of 3
            hardcoded profiles).
        show_profile_hint: If True, include the true belief vector in the prompt.
            Use during the curriculum's "visible" warmup phase.
    """
    env = RhythmEnvironment()
    obs = env.reset(seed=seed, profile_mode=profile_mode)
    profile_hint = env.get_profile_hint() if show_profile_hint else None
    rng = random.Random(seed + 1000)  # offset keeps the rollout RNG distinct from the env's seed
    actions_taken = []
    samples = []
    all_actions = list(ActionType)

    for step in range(MAX_STEPS):
        if obs.done:
            break

        prompt = format_observation_prompt(obs, profile_hint=profile_hint)

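        # Record the sample BEFORE stepping: seed plus the action prefix so far
        # is exactly what a reward function needs to replay back to this state.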
        samples.append({
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            "seed": seed,
            "step_index": step,
            "action_history": list(actions_taken),
            "profile_mode": profile_mode,
            "show_profile_hint": show_profile_hint,
        })

        if strategy == "random":
            action_type = rng.choice(all_actions)
        elif strategy == "heuristic":
            action_type = heuristic_action(obs)
        else:  # unknown strategy name: fall back to random
            action_type = rng.choice(all_actions)

        action = RhythmAction(action_type=action_type)
        actions_taken.append(action_type.value)
        obs = env.step(action)

    return samples
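

# Replay sketch (illustrative, not used by this module): how a reward function
# can rebuild the exact env state behind a sample from its replay metadata.
# Assumes ActionType(value) round-trips the strings stored in action_history.
def replay_to_step(seed: int, action_history: list, profile_mode: str = "continuous"):
    env = RhythmEnvironment()
    obs = env.reset(seed=seed, profile_mode=profile_mode)
    for value in action_history:
        obs = env.step(RhythmAction(action_type=ActionType(value)))
    return env, obs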


def heuristic_action(obs) -> ActionType:
    """Priority-based heuristic baseline (profile-blind).

    Used both during dataset generation (to roll out diverse states) and
    by inference_eval as the heuristic baseline strategy.
    """
    slot = obs.slot
    v, c, p, s, cn = obs.vitality, obs.cognition, obs.progress, obs.serenity, obs.connection

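    # Emergency tier: recover any meter near crash (penalty hits under 0.10).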
    if v < 0.15:
        return ActionType.SLEEP
    if s < 0.15:
        return ActionType.MEDITATE
    if cn < 0.15:
        return ActionType.FAMILY_TIME
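    # Routine tier: slot-appropriate defaults (night sleep, morning focus, afternoon upkeep).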
    if slot == 3:
        return ActionType.SLEEP
    if slot == 0:
        return ActionType.DEEP_WORK if (v > 0.4 and c > 0.3) else ActionType.EXERCISE
    if slot == 1:
        if cn < 0.3:
            return ActionType.FAMILY_TIME
        if p < 0.3 and v > 0.3:
            return ActionType.LEARN
        return ActionType.ADMIN_WORK
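    # Maintenance tier (evening slot): top up connection, then serenity.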
    if cn < 0.4:
        return ActionType.SOCIALIZE
    if s < 0.5:
        return ActionType.ME_TIME
    return ActionType.MEDITATE


def generate_dataset(
    num_episodes: int = 200,
    strategy: str = "mixed",
    max_samples: int = 2000,
    profile_mode: str = "continuous",
    hint_fraction: float = 0.2,
) -> list:
    """Generate a training dataset by playing multiple episodes.

    Curriculum is baked into the dataset: hint_fraction of samples have the
    true profile visible (visible-phase warmup). After shuffle, GRPOTrainer
    sees a mix early on; we can sort to put hint samples first if needed.

    Args:
        num_episodes: Number of episodes to play.
        strategy: "random", "heuristic", or "mixed" (alternating).
        max_samples: Maximum samples to return.
        profile_mode: "continuous" (default, meta-RL) or "discrete" (3 profiles).
        hint_fraction: Fraction of episodes to play with profile hint visible.
    """
    all_samples = []
    n_hint_episodes = int(num_episodes * hint_fraction)

    for i in range(num_episodes):
        seed = i
        if strategy == "mixed":
            s = "heuristic" if i % 2 == 0 else "random"
        else:
            s = strategy
        show_hint = i < n_hint_episodes

        episode_samples = generate_episode_samples(
            seed=seed,
            strategy=s,
            profile_mode=profile_mode,
            show_profile_hint=show_hint,
        )
        all_samples.extend(episode_samples)

        if len(all_samples) >= max_samples:
            break

    # Shuffle (curriculum is per-sample via show_profile_hint flag, not order)
    random.shuffle(all_samples)
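    # If a strict visible-first curriculum is preferred over the shuffled mix,
    # a stable sort on the hint flag would do it (sketch):
    #   all_samples.sort(key=lambda s: not s["show_profile_hint"])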
    all_samples = all_samples[:max_samples]

    n_hint = sum(1 for s in all_samples if s["show_profile_hint"])
    print(
        f"Generated {len(all_samples)} samples from {min(i+1, num_episodes)} episodes "
        f"({n_hint} with profile hint, {len(all_samples) - n_hint} without)"
    )
    return all_samples


if __name__ == "__main__":
    samples = generate_dataset(num_episodes=20, strategy="mixed", max_samples=80, hint_fraction=0.5)
    print(f"\nFirst sample (with hint):")
    hinted = next((s for s in samples if s["show_profile_hint"]), None)
    if hinted:
        print(hinted["prompt"][1]["content"])
        print(f"\nseed={hinted['seed']}, step={hinted['step_index']}, mode={hinted['profile_mode']}")

    print(f"\nFirst sample (without hint):")
    plain = next((s for s in samples if not s["show_profile_hint"]), None)
    if plain:
        print(plain["prompt"][1]["content"])