"""
RhythmEnv GRPO Training Script (Meta-RL version).

Trains an LLM agent to BOTH (a) balance life meters AND (b) infer the hidden
personality of the person it's helping. Four-layer reward stack:

  format_valid     — output parseable as ACTION + 3 belief digits
  action_legal     — action is one of 10 valid types
  env_reward       — actual env reward for the chosen action (seed replay)
  belief_accuracy  — how close the belief vector is to the hidden profile

Usage (Colab T4):
    !pip install unsloth transformers trl datasets
    !python training/train.py --max_steps 1500

Setup-check (no GPU): run the smoke tests instead of starting a real run:
    python -m pytest tests/test_pipeline_smoke.py -q
"""

import argparse
import json
import os
import sys

# Make the repo root (one level above training/) importable when this script
# is run directly, e.g. `python training/train.py`.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))


def main():
    parser = argparse.ArgumentParser(description="Train RhythmEnv agent with GRPO (meta-RL)")
    parser.add_argument("--model_name", type=str, default="unsloth/Qwen2.5-3B-Instruct")
    parser.add_argument("--max_steps", type=int, default=1500,
                        help="Number of GRPO training steps (1500 recommended for meta-RL)")
    parser.add_argument("--num_episodes", type=int, default=300,
                        help="Number of episodes for dataset generation (more diversity = better meta-RL)")
    parser.add_argument("--max_samples", type=int, default=3000,
                        help="Maximum training samples")
    parser.add_argument("--num_generations", type=int, default=8,
                        help="Completions per prompt for GRPO (8 default, lower variance for continuous-profile meta-RL)")
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--beta", type=float, default=0.04,
                        help="KL penalty (TRL/DeepSeek default; raise to 0.1+ if KL diverges)")
    parser.add_argument("--lora_rank", type=int, default=8,
                        help="LoRA rank (8 = more capacity than original 4 for meta-RL)")
    parser.add_argument("--hint_fraction", type=float, default=0.0,
                        help="Fraction of dataset with profile hint visible. Default 0.0 (no hints) "
                             "to eliminate train-eval distribution mismatch. Set >0 only if you ALSO "
                             "show hints during eval.")
    parser.add_argument("--output_dir", type=str, default="outputs/rhythmenv_meta_trained")
    parser.add_argument("--report_to", type=str, default="none")
    args = parser.parse_args()

    # ---------------------------------------------------------------
    # 1. Generate dataset
    # ---------------------------------------------------------------
    print("=" * 60)
    print("Step 1: Generating training dataset (continuous profiles)")
    print("=" * 60)

    from dataset import generate_dataset
    from datasets import Dataset

    raw_samples = generate_dataset(
        num_episodes=args.num_episodes,
        strategy="mixed",
        max_samples=args.max_samples,
        hint_fraction=args.hint_fraction,
    )

    # Replay metadata so env_reward + belief_accuracy can reconstruct state
    dataset = Dataset.from_list([
        {
            "prompt": sample["prompt"],
            "seed": sample["seed"],
            "step_index": sample["step_index"],
            "action_history": sample["action_history"],
            "profile_mode": sample["profile_mode"],
        }
        for sample in raw_samples
    ])
    print(f"Dataset size: {len(dataset)}")

    # ---------------------------------------------------------------
    # 2. Load model with Unsloth
    # ---------------------------------------------------------------
    print("\n" + "=" * 60)
    print(f"Step 2: Loading model {args.model_name}")
    print("=" * 60)

    from unsloth import FastLanguageModel

    max_seq_length = 1024  # bumped from 768 to fit longer prompts with history

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_name,
        load_in_4bit=True,
        max_seq_length=max_seq_length,
    )
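    # 4-bit base weights (QLoRA-style bitsandbytes quantization) keep the 3B
    # model inside a Colab T4's 16 GB; the LoRA adapters added below are
    # trained on top of the frozen quantized base.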

    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_rank,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha=args.lora_rank * 2,
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )

    print(f"LoRA rank: {args.lora_rank}, alpha: {args.lora_rank * 2}")

    # ---------------------------------------------------------------
    # 3. Reward functions (4-layer stack including belief_accuracy)
    # ---------------------------------------------------------------
    print("\n" + "=" * 60)
    print("Step 3: Setting up reward functions")
    print("=" * 60)

    from reward_functions import format_valid, action_legal, env_reward, belief_accuracy

    reward_funcs = [format_valid, action_legal, env_reward, belief_accuracy]
    print("Using: format_valid + action_legal + env_reward + belief_accuracy")

    # ---------------------------------------------------------------
    # 4. GRPO trainer config
    # ---------------------------------------------------------------
    print("\n" + "=" * 60)
    print("Step 4: Configuring GRPO trainer")
    print("=" * 60)

    from trl import GRPOConfig, GRPOTrainer

    max_prompt_length = 600          # history + hint room
    max_completion_length = 32       # bumped from 20 to prevent silent truncation of belief digits

    # reward_weights: suppress the format/action_legal layers (small, low-variance
    # signals — too constant across a GRPO group to contribute meaningful advantage)
    # and amplify the variable signals env_reward and belief_accuracy. belief_accuracy
    # at 3.0 is the dominant learning signal.
    # Order MUST match reward_funcs above: format_valid, action_legal, env_reward, belief_accuracy
    reward_weights = [0.05, 0.05, 1.5, 3.0]
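    # Illustrative arithmetic, assuming TRL applies reward_weights as a weighted
    # sum: a completion scoring (1.0, 1.0, 0.4, 0.7) on the four layers gets
    # 0.05*1.0 + 0.05*1.0 + 1.5*0.4 + 3.0*0.7 = 2.8 total reward.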

    training_args_kwargs = dict(
        temperature=1.5,  # bumped from 1.0 to force diverse rollouts and break mode collapse
        learning_rate=args.learning_rate,
        beta=args.beta,
        max_grad_norm=0.5,
        weight_decay=0.001,
        warmup_ratio=0.1,
        lr_scheduler_type="linear",
        optim="adamw_8bit",
        logging_steps=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        num_generations=args.num_generations,
        max_prompt_length=max_prompt_length,
        max_completion_length=max_completion_length,
        max_steps=args.max_steps,
        save_steps=250,
        report_to=args.report_to,
        output_dir=args.output_dir,
    )
    # reward_weights is only available in newer TRL releases; pass it only if
    # the installed GRPOConfig accepts it
    try:
        training_args = GRPOConfig(**training_args_kwargs, reward_weights=reward_weights)
        print(f"Using GRPOConfig with reward_weights={reward_weights}")
    except TypeError:
        training_args = GRPOConfig(**training_args_kwargs)
        print("WARN: TRL version does not support reward_weights; using uniform weighting")

    print(f"max_steps={args.max_steps}, num_generations={args.num_generations}, "
          f"lr={args.learning_rate}, beta={args.beta}")
    print(f"max_prompt_length={max_prompt_length}, max_completion_length={max_completion_length}")
    print(f"hint_fraction={args.hint_fraction} (curriculum warmup)")

    # ---------------------------------------------------------------
    # 5. Train
    # ---------------------------------------------------------------
    print("\n" + "=" * 60)
    print("Step 5: Starting GRPO training")
    print("=" * 60)

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset,
    )

    trainer.train()

    # ---------------------------------------------------------------
    # 6. Save merged model
    # ---------------------------------------------------------------
    print("\n" + "=" * 60)
    print("Step 6: Saving model")
    print("=" * 60)

    model.save_pretrained_merged(
        args.output_dir,
        tokenizer,
        save_method="merged_16bit",
    )

    config_path = os.path.join(args.output_dir, "training_config.json")
    with open(config_path, "w") as f:
        json.dump(vars(args), f, indent=2)

    # Save log_history for offline plotting (job runs don't have a notebook to inspect trainer.state)
    log_path = os.path.join(args.output_dir, "log_history.json")
    with open(log_path, "w") as f:
        json.dump(trainer.state.log_history, f, indent=2)
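    # The saved history can be plotted offline, e.g. (metric key names such as
    # "reward" vary with the TRL version):
    #   hist = json.load(open(log_path))
    #   rewards = [h["reward"] for h in hist if "reward" in h]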

    print(f"Model saved to: {args.output_dir}")
    print(f"Training config saved to: {config_path}")
    print(f"Log history saved to: {log_path}")
    print("\nNext: run inference_eval.py to compare baseline vs trained")
    print("      python training/inference_eval.py --model_path " + args.output_dir)


if __name__ == "__main__":
    main()