Jayant-Kernel committed on
Commit ·
825578d
1
Parent(s): 0bdaeb6
add: Level 2 training for 1.5B model after Level 1
Browse files
train.py
CHANGED
|
@@ -175,8 +175,101 @@ trainer.train()
|
|
| 175 |
wandb.finish()
|
| 176 |
print("Training done!")
|
| 177 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
model.save_pretrained("deceit-1.5b-final")
|
| 179 |
tokenizer.save_pretrained("deceit-1.5b-final")
|
| 180 |
model.push_to_hub(HF_REPO_ID)
|
| 181 |
tokenizer.push_to_hub(HF_REPO_ID)
|
| 182 |
-
print(f"
|
|
|
|
| 175 |
wandb.finish()
|
| 176 |
print("Training done!")
|
| 177 |
|
| 178 |
# Persist the Level 1 checkpoint locally before starting Level 2 training.
l1_checkpoint_dir = "deceit-1.5b-l1"
model.save_pretrained(l1_checkpoint_dir)
tokenizer.save_pretrained(l1_checkpoint_dir)
print("Level 1 checkpoint saved locally")
|
| 182 |
+
|
| 183 |
# Load the Level 2 question set from the environment package's bundled JSONL file.
import pathlib as _pl2

data_path_l2 = _pl2.Path(
    "/home/trainer/.local/lib/python3.10/site-packages/deceit_env/data/level2.jsonl"
)
with open(data_path_l2) as f:
    # One JSON object per non-blank line.
    questions_l2 = [json.loads(raw) for raw in map(str.strip, f) if raw]
print(f"Loaded {len(questions_l2)} Level 2 questions")
|
| 194 |
+
|
| 195 |
def make_prompt_l2(q, distractors):
    """Build the Level 2 chat prompt: system prompt plus the question and its
    distractor context, rendered through the tokenizer's chat template with a
    generation prompt appended."""
    distractor_block = "\n".join(distractors)
    user_text = f"Question: {q}\n\nContext:\n{distractor_block}\n\nTurn 1 of 3. Respond in JSON."
    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_text},
    ]
    return tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
|
| 202 |
+
|
| 203 |
# Materialize the Level 2 training rows: one prompt/question pair per question.
_l2_rows = []
for _q in questions_l2:
    _l2_rows.append(
        {
            "prompt": make_prompt_l2(_q["question"], _q.get("distractors", [])),
            "question": _q["question"],
        }
    )
train_dataset_l2 = Dataset.from_list(_l2_rows)
|
| 207 |
+
|
| 208 |
# Fresh environment instance for Level 2; episodes are reset with level=2 below.
_env_l2 = DeceitEnvironment(grader=_grader)


def reward_fn_l2(completions, prompts=None, **kwargs):
    """GRPO reward function for Level 2 episodes.

    For each completion: parse the model's JSON action (falling back to the
    FAIL action on a parse error), then replay that action for up to
    ``obs.max_turns`` turns of a fresh level-2 episode, accumulating the
    per-step reward. Any episode-level failure yields a fixed -1.3 penalty
    instead of crashing training.

    Returns a list of float rewards, one per completion.
    """
    rewards = []
    for text in completions:
        try:
            parsed = parse_action(text)
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are not silently swallowed during training.
        except Exception:
            parsed = FAIL.copy()
        try:
            with _env_lock:  # environment is shared; serialize episode rollouts
                obs = _env_l2.reset(level=2)
                current = parsed
                total = 0.0
                for turn in range(obs.max_turns):
                    # Force a final answer on the last allowed turn.
                    if turn == obs.max_turns - 1:
                        current["is_final"] = True
                    action = DeceitAction(
                        reasoning=current.get("reasoning", ""),
                        answer=current.get("answer", ""),
                        confidence=float(current.get("confidence", 0.5)),
                        abstain=bool(current.get("abstain", False)),
                        is_final=bool(current.get("is_final", True)),
                    )
                    result = _env_l2.step(action)
                    total += result.reward
                    if result.done:
                        break
        except Exception as e:
            print(f"L2 Episode error: {e}")
            total = -1.3
        rewards.append(total)
    return rewards
|
| 242 |
+
|
| 243 |
# Train Level 2 with GRPO; metrics are reported to wandb.
print("Starting Level 2 training on 1.5B...")
wandb.init(project=WANDB_PROJECT, name="1.5b-level2")

_l2_config = GRPOConfig(
    output_dir="./deceit-1.5b-l2",
    max_steps=80,
    per_device_train_batch_size=4,
    num_generations=4,
    learning_rate=2e-6,
    warmup_steps=5,
    logging_steps=1,
    save_steps=40,
    report_to="wandb",
    max_completion_length=256,
    remove_unused_columns=False,
)
trainer_l2 = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[reward_fn_l2],
    args=_l2_config,
    train_dataset=train_dataset_l2,
)
trainer_l2.train()
wandb.finish()
print("Level 2 training done!")
|
| 269 |
+
|
| 270 |
# Save the final (post-Level-2) weights locally, then push model and
# tokenizer to the Hub repo.
final_dir = "deceit-1.5b-final"
model.save_pretrained(final_dir)
tokenizer.save_pretrained(final_dir)
model.push_to_hub(HF_REPO_ID)
tokenizer.push_to_hub(HF_REPO_ID)
print(f"Final model saved to {HF_REPO_ID}")
|