Jayant-Kernel committed on
Commit
825578d
·
1 Parent(s): 0bdaeb6

add: Level 2 training for 1.5B model after Level 1

Browse files
Files changed (1) hide show
  1. train.py +94 -1
train.py CHANGED
@@ -175,8 +175,101 @@ trainer.train()
175
  wandb.finish()
176
  print("Training done!")
177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  model.save_pretrained("deceit-1.5b-final")
179
  tokenizer.save_pretrained("deceit-1.5b-final")
180
  model.push_to_hub(HF_REPO_ID)
181
  tokenizer.push_to_hub(HF_REPO_ID)
182
- print(f"Saved to {HF_REPO_ID}")
 
175
  wandb.finish()
176
  print("Training done!")
177
 
178
+ # Save Level 1 checkpoint
179
+ model.save_pretrained("deceit-1.5b-l1")
180
+ tokenizer.save_pretrained("deceit-1.5b-l1")
181
+ print("Level 1 checkpoint saved locally")
182
+
183
+ # Load Level 2 dataset
184
+ import pathlib as _pl2
185
+ data_path_l2 = _pl2.Path("/home/trainer/.local/lib/python3.10/site-packages/deceit_env/data/level2.jsonl")
186
+ questions_l2 = []
187
+ with open(data_path_l2) as f:
188
+ for line in f:
189
+ line = line.strip()
190
+ if line:
191
+ questions_l2.append(json.loads(line))
192
+
193
+ print(f"Loaded {len(questions_l2)} Level 2 questions")
194
+
195
+ def make_prompt_l2(q, distractors):
196
+ context = "\n".join(distractors)
197
+ msgs = [
198
+ {"role":"system","content":SYSTEM_PROMPT},
199
+ {"role":"user","content":f"Question: {q}\n\nContext:\n{context}\n\nTurn 1 of 3. Respond in JSON."},
200
+ ]
201
+ return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
202
+
203
+ train_dataset_l2 = Dataset.from_list([
204
+ {"prompt": make_prompt_l2(q["question"], q.get("distractors", [])), "question": q["question"]}
205
+ for q in questions_l2
206
+ ])
207
+
208
+ # Update env to level 2
209
+ _env_l2 = DeceitEnvironment(grader=_grader)
210
+
211
+ def reward_fn_l2(completions, prompts=None, **kwargs):
212
+ rewards = []
213
+ for text in completions:
214
+ try:
215
+ parsed = parse_action(text)
216
+ except:
217
+ parsed = FAIL.copy()
218
+ try:
219
+ with _env_lock:
220
+ obs = _env_l2.reset(level=2)
221
+ current = parsed
222
+ total = 0.0
223
+ for turn in range(obs.max_turns):
224
+ if turn == obs.max_turns - 1:
225
+ current["is_final"] = True
226
+ action = DeceitAction(
227
+ reasoning=current.get("reasoning",""),
228
+ answer=current.get("answer",""),
229
+ confidence=float(current.get("confidence",0.5)),
230
+ abstain=bool(current.get("abstain",False)),
231
+ is_final=bool(current.get("is_final",True)),
232
+ )
233
+ result = _env_l2.step(action)
234
+ total += result.reward
235
+ if result.done:
236
+ break
237
+ except Exception as e:
238
+ print(f"L2 Episode error: {e}")
239
+ total = -1.3
240
+ rewards.append(total)
241
+ return rewards
242
+
243
+ # Train Level 2
244
+ print("Starting Level 2 training on 1.5B...")
245
+ wandb.init(project=WANDB_PROJECT, name="1.5b-level2")
246
+
247
+ trainer_l2 = GRPOTrainer(
248
+ model=model,
249
+ processing_class=tokenizer,
250
+ reward_funcs=[reward_fn_l2],
251
+ args=GRPOConfig(
252
+ output_dir="./deceit-1.5b-l2",
253
+ max_steps=80,
254
+ per_device_train_batch_size=4,
255
+ num_generations=4,
256
+ learning_rate=2e-6,
257
+ warmup_steps=5,
258
+ logging_steps=1,
259
+ save_steps=40,
260
+ report_to="wandb",
261
+ max_completion_length=256,
262
+ remove_unused_columns=False,
263
+ ),
264
+ train_dataset=train_dataset_l2,
265
+ )
266
+ trainer_l2.train()
267
+ wandb.finish()
268
+ print("Level 2 training done!")
269
+
270
+ # Save final model
271
  model.save_pretrained("deceit-1.5b-final")
272
  tokenizer.save_pretrained("deceit-1.5b-final")
273
  model.push_to_hub(HF_REPO_ID)
274
  tokenizer.push_to_hub(HF_REPO_ID)
275
+ print(f"Final model saved to {HF_REPO_ID}")