Jayant-Kernel committed on
Commit
5232a98
·
unverified ·
1 Parent(s): 54fc539

fix: custom training loop without TRL dependency

Browse files
Files changed (1) hide show
  1. train.py +101 -168
train.py CHANGED
@@ -38,7 +38,8 @@ print("Health server started on port 7860")
38
  import torch
39
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
40
  from peft import LoraConfig, get_peft_model
41
- from trl import GRPOConfig, GRPOTrainer
 
42
  from datasets import Dataset
43
  from huggingface_hub import login
44
  import wandb
@@ -108,13 +109,9 @@ import deceit_env as _de
108
 
109
  _grader = Grader(cache_path="/tmp/deceit_grader_cache.json",
110
  openai_api_key=os.environ.get("OPENAI_API_KEY",""))
111
- _env = DeceitEnvironment(grader=_grader)
112
  _env_lock = threading.Lock()
113
 
114
- # Abstention tracking (Improvement 1)
115
- _abstain_counts = {}
116
- _episode_counts = {}
117
-
118
  def parse_action(text):
119
  text = re.sub(r"```(?:json)?\s*", "", text).strip()
120
  try:
@@ -132,61 +129,32 @@ def parse_action(text):
132
 
133
  FAIL = {"reasoning":"fail","answer":"","confidence":0.0,"abstain":True,"is_final":True}
134
 
135
- def reward_fn(completions, prompts=None, **kwargs):
136
- rewards = []
137
- for text in completions:
138
- try:
139
- parsed = parse_action(text)
140
- except:
141
- parsed = FAIL.copy()
142
-
143
- # Track abstention rate per prompt (Improvement 1)
144
- prompt_key = prompts[0][:50] if prompts else "default"
145
- _episode_counts[prompt_key] = _episode_counts.get(prompt_key, 0) + 1
146
- if parsed.get("abstain", False):
147
- _abstain_counts[prompt_key] = _abstain_counts.get(prompt_key, 0) + 1
148
-
149
- abstain_rate = _abstain_counts.get(prompt_key, 0) / max(1, _episode_counts.get(prompt_key, 1))
150
-
151
- if parsed.get("abstain", False):
152
- if abstain_rate > 0.3:
153
- rewards.append(-0.5)
154
- else:
155
- rewards.append(0.0)
156
- continue
157
-
158
- try:
159
- with _env_lock:
160
- obs = _env.reset()
161
- current = parsed
162
- total = 0.0
163
- for turn in range(obs.max_turns):
164
- if turn == obs.max_turns - 1:
165
- current["is_final"] = True
166
- action = DeceitAction(
167
- reasoning=current.get("reasoning",""),
168
- answer=current.get("answer",""),
169
- confidence=float(current.get("confidence",0.5)),
170
- abstain=bool(current.get("abstain",False)),
171
- is_final=bool(current.get("is_final",True)),
172
- )
173
- result = _env.step(action)
174
- total += result.reward
175
- if result.done:
176
- break
177
- except Exception as e:
178
- print(f"Episode error: {e}")
179
- total = -1.3
180
- rewards.append(total)
181
- return rewards
182
-
183
  data_path = pathlib.Path(_de.__file__).parent / "data" / "level1.jsonl"
184
- questions = []
185
  with open(data_path) as f:
186
  for line in f:
187
  line = line.strip()
188
  if line:
189
- questions.append(json.loads(line))
190
 
191
  def make_prompt(q):
192
  msgs = [
@@ -195,38 +163,52 @@ def make_prompt(q):
195
  ]
196
  return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
197
 
198
- train_dataset = Dataset.from_list([
199
  {"prompt": make_prompt(q["question"]), "question": q["question"]}
200
- for q in questions
201
- ])
202
 
203
- print("Starting Level 1 training...")
 
 
 
 
204
  wandb.init(project=WANDB_PROJECT, name="1.5b-level1-improved")
 
205
 
206
- trainer = GRPOTrainer(
207
- model=model,
208
- processing_class=tokenizer,
209
- reward_funcs=[reward_fn],
210
- args=GRPOConfig(
211
- output_dir="/tmp/deceit-1.5b",
212
- bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
213
- fp16=False,
214
- max_steps=1000,
215
- per_device_train_batch_size=4,
216
- num_generations=4,
217
- learning_rate=2e-5,
218
- warmup_steps=10,
219
- logging_steps=1,
220
- save_steps=100,
221
- report_to="wandb",
222
- max_completion_length=256,
223
- remove_unused_columns=False,
224
- ),
225
- train_dataset=train_dataset,
226
- )
227
- trainer.train()
 
 
 
 
 
 
 
 
 
 
228
  wandb.finish()
229
- print("Level 1 done!")
230
 
231
  # Save Level 1 checkpoint
232
  model.save_pretrained("/tmp/deceit-1.5b-l1")
@@ -249,19 +231,10 @@ with open(data_path_l2) as f:
249
 
250
  print(f"Loaded {len(questions_l2)} Level 2 questions")
251
 
252
- # Load L1 for mixing (Improvement 4)
253
- data_path_l1 = pathlib.Path(_de.__file__).parent / "data" / "level1.jsonl"
254
- questions_l1 = []
255
- with open(data_path_l1) as f:
256
- for line in f:
257
- line = line.strip()
258
- if line:
259
- questions_l1.append(json.loads(line))
260
-
261
  # Mix 70% L2 + 30% L1
262
  n_l2 = len(questions_l2)
263
  n_l1_sample = max(1, int(n_l2 * 0.3))
264
- l1_sample = random.sample(questions_l1, min(n_l1_sample, len(questions_l1)))
265
 
266
  mixed_questions = []
267
  for q in questions_l2:
@@ -289,90 +262,50 @@ def make_prompt_l2(q, distractors):
289
  ]
290
  return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
291
 
292
- train_dataset_l2 = Dataset.from_list([
293
  {"prompt": make_prompt_l2(q["question"], q.get("distractors", [])),
294
  "question": q["question"]}
295
  for q in mixed_questions
296
- ])
 
 
 
 
 
297
 
298
- _env_l2 = DeceitEnvironment(grader=_grader)
299
- _abstain_counts_l2 = {}
300
- _episode_counts_l2 = {}
301
 
302
- def reward_fn_l2(completions, prompts=None, **kwargs):
 
 
 
303
  rewards = []
304
- for text in completions:
305
- try:
306
- parsed = parse_action(text)
307
- except:
308
- parsed = FAIL.copy()
309
-
310
- prompt_key = prompts[0][:50] if prompts else "default"
311
- _episode_counts_l2[prompt_key] = _episode_counts_l2.get(prompt_key, 0) + 1
312
- if parsed.get("abstain", False):
313
- _abstain_counts_l2[prompt_key] = _abstain_counts_l2.get(prompt_key, 0) + 1
314
-
315
- abstain_rate = _abstain_counts_l2.get(prompt_key, 0) / max(1, _episode_counts_l2.get(prompt_key, 1))
316
-
317
- if parsed.get("abstain", False):
318
- if abstain_rate > 0.3:
319
- rewards.append(-0.5)
320
- else:
321
- rewards.append(0.0)
322
- continue
323
-
324
- try:
325
- with _env_lock:
326
- obs = _env_l2.reset(level=2)
327
- current = parsed
328
- total = 0.0
329
- for turn in range(obs.max_turns):
330
- if turn == obs.max_turns - 1:
331
- current["is_final"] = True
332
- action = DeceitAction(
333
- reasoning=current.get("reasoning",""),
334
- answer=current.get("answer",""),
335
- confidence=float(current.get("confidence",0.5)),
336
- abstain=bool(current.get("abstain",False)),
337
- is_final=bool(current.get("is_final",True)),
338
- )
339
- result = _env_l2.step(action)
340
- total += result.reward
341
- if result.done:
342
- break
343
- except Exception as e:
344
- print(f"L2 Episode error: {e}")
345
- total = -1.3
346
- rewards.append(total)
347
- return rewards
348
 
349
- print("Starting Level 2 training on 1.5B...")
350
- wandb.init(project=WANDB_PROJECT, name="1.5b-level2-improved")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
 
352
- trainer_l2 = GRPOTrainer(
353
- model=model,
354
- processing_class=tokenizer,
355
- reward_funcs=[reward_fn_l2],
356
- args=GRPOConfig(
357
- output_dir="/tmp/deceit-1.5b-l2",
358
- bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
359
- fp16=False,
360
- max_steps=600,
361
- per_device_train_batch_size=4,
362
- num_generations=4,
363
- learning_rate=2e-5,
364
- warmup_steps=10,
365
- logging_steps=1,
366
- save_steps=100,
367
- report_to="wandb",
368
- max_completion_length=256,
369
- remove_unused_columns=False,
370
- ),
371
- train_dataset=train_dataset_l2,
372
- )
373
- trainer_l2.train()
374
- wandb.finish()
375
  print("Level 2 training done!")
 
376
 
377
  # Save final model
378
  model.save_pretrained("/tmp/deceit-1.5b-final")
 
38
  import torch
39
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
40
  from peft import LoraConfig, get_peft_model
41
+ from torch.optim import AdamW
42
+ from torch.utils.data import DataLoader
43
  from datasets import Dataset
44
  from huggingface_hub import login
45
  import wandb
 
109
 
110
  _grader = Grader(cache_path="/tmp/deceit_grader_cache.json",
111
  openai_api_key=os.environ.get("OPENAI_API_KEY",""))
112
+ env = DeceitEnvironment(grader=_grader)
113
  _env_lock = threading.Lock()
114
 
 
 
 
 
115
  def parse_action(text):
116
  text = re.sub(r"```(?:json)?\s*", "", text).strip()
117
  try:
 
129
 
130
  FAIL = {"reasoning":"fail","answer":"","confidence":0.0,"abstain":True,"is_final":True}
131
 
132
def reward_fn_single(text, question, level=1):
    """Score one model completion by playing it as a single final action.

    Args:
        text: Raw completion text; parsed with ``parse_action``.
        question: Question the completion was generated for. It is not
            passed to the environment by this function.
            NOTE(review): confirm the environment grades against this
            question rather than one it picks itself on ``reset``.
        level: Difficulty level forwarded to ``env.reset``.

    Returns:
        0.0 when the parsed action abstains, the reward reported by
        ``env.step`` otherwise, or -0.5 when parsing or the episode raises.
    """
    try:
        parsed = parse_action(text)
        # Missing keys raise KeyError and fall through to the -0.5 penalty,
        # so malformed completions are penalised rather than scored.
        if parsed["abstain"]:
            return 0.0
        action = DeceitAction(
            reasoning=parsed["reasoning"],
            answer=parsed["answer"],
            confidence=float(parsed["confidence"]),
            abstain=False,
            is_final=True,
        )
        # Start a fresh episode *before* stepping so a failure in an earlier
        # call can never leave a stale or finished episode behind, and hold
        # the shared lock while the environment is being mutated.
        with _env_lock:
            env.reset(level=level)
            result = env.step(action)
        return result.reward
    except Exception as e:
        # Best-effort scoring: report the failure instead of hiding it, then
        # fall back to the fixed penalty so training can continue.
        print(f"Episode error: {e}")
        return -0.5
149
+
150
+ # Load Level 1 dataset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  data_path = pathlib.Path(_de.__file__).parent / "data" / "level1.jsonl"
152
+ questions_l1_raw = []
153
  with open(data_path) as f:
154
  for line in f:
155
  line = line.strip()
156
  if line:
157
+ questions_l1_raw.append(json.loads(line))
158
 
159
  def make_prompt(q):
160
  msgs = [
 
163
  ]
164
  return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
165
 
166
+ train_dataset_l1 = [
167
  {"prompt": make_prompt(q["question"]), "question": q["question"]}
168
+ for q in questions_l1_raw
169
+ ]
170
 
171
+ # Level 1 training
172
+ optimizer = AdamW(model.parameters(), lr=2e-5)
173
+ model.train()
174
+
175
+ print("Starting manual GRPO-style training...")
176
  wandb.init(project=WANDB_PROJECT, name="1.5b-level1-improved")
177
+ questions = train_dataset_l1
178
 
179
+ env.reset(level=1)
180
+
181
# Level 1 loop: sample a small batch, generate one completion per prompt,
# score it, then apply a REINFORCE update that uses the batch-mean reward as
# the baseline. Without the update below the optimizer is never stepped and
# the model weights never change.
device = next(model.parameters()).device
for step in range(300):
    batch = random.sample(questions, min(4, len(questions)))

    rollouts = []  # (full token sequence, prompt length) for each sample
    rewards = []

    for item in batch:
        # Inputs must live on the same device as the model for generate().
        inputs = tokenizer(
            item["prompt"], return_tensors="pt", truncation=True, max_length=512
        ).to(device)
        prompt_len = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                do_sample=True,
                temperature=0.7,
                pad_token_id=tokenizer.eos_token_id,
            )

        text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        rewards.append(reward_fn_single(text, item["question"], level=1))
        rollouts.append((outputs[:1], prompt_len))

    mean_reward = sum(rewards) / len(rewards)

    # Policy-gradient step. Samples are back-propagated one at a time so only
    # a single computation graph is alive at any moment.
    # NOTE(review): log-probs are taken at temperature 1.0 while sampling used
    # 0.7; confirm this mismatch is acceptable.
    optimizer.zero_grad()
    step_loss = 0.0
    has_grad = False
    for (sequence, prompt_len), reward in zip(rollouts, rewards):
        advantage = reward - mean_reward
        # Nothing to learn from a sample at the baseline or with no new tokens.
        if advantage == 0 or sequence.shape[1] <= prompt_len:
            continue
        # Logits at position t predict token t + 1, hence the shift by one.
        logits = model(input_ids=sequence).logits[:, prompt_len - 1:-1, :]
        log_probs = torch.log_softmax(logits.float(), dim=-1)
        completion_ids = sequence[:, prompt_len:].unsqueeze(-1)
        token_log_probs = log_probs.gather(-1, completion_ids).squeeze(-1)
        loss = -advantage * token_log_probs.mean() / len(batch)
        loss.backward()
        step_loss += loss.item()
        has_grad = True
    if has_grad:
        optimizer.step()

    if step % 10 == 0:
        print(f"Step {step}/300 | Mean Reward: {mean_reward:.3f} | Rewards: {rewards}")
        wandb.log({"train/reward": mean_reward, "train/global_step": step, "train/loss": step_loss})
209
+
210
+ print("Level 1 training complete")
211
  wandb.finish()
 
212
 
213
  # Save Level 1 checkpoint
214
  model.save_pretrained("/tmp/deceit-1.5b-l1")
 
231
 
232
  print(f"Loaded {len(questions_l2)} Level 2 questions")
233
 
 
 
 
 
 
 
 
 
 
234
  # Mix 70% L2 + 30% L1
235
  n_l2 = len(questions_l2)
236
  n_l1_sample = max(1, int(n_l2 * 0.3))
237
+ l1_sample = random.sample(questions_l1_raw, min(n_l1_sample, len(questions_l1_raw)))
238
 
239
  mixed_questions = []
240
  for q in questions_l2:
 
262
  ]
263
  return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
264
 
265
+ train_dataset_l2 = [
266
  {"prompt": make_prompt_l2(q["question"], q.get("distractors", [])),
267
  "question": q["question"]}
268
  for q in mixed_questions
269
+ ]
270
+
271
+ # Level 2 training
272
+ print("Starting Level 2 training on 1.5B...")
273
+ wandb.init(project=WANDB_PROJECT, name="1.5b-level2-improved")
274
+ model.train()
275
 
276
+ env.reset(level=2)
 
 
277
 
278
# Level 2 loop: sample a small batch, generate one completion per prompt,
# score it, then apply a REINFORCE update that uses the batch-mean reward as
# the baseline. Reuses the `optimizer` created for Level 1; without the
# update below the model weights never change.
device = next(model.parameters()).device
for step in range(150):
    batch = random.sample(train_dataset_l2, min(4, len(train_dataset_l2)))

    rollouts = []  # (full token sequence, prompt length) for each sample
    rewards = []

    for item in batch:
        # Inputs must live on the same device as the model for generate().
        inputs = tokenizer(
            item["prompt"], return_tensors="pt", truncation=True, max_length=512
        ).to(device)
        prompt_len = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                do_sample=True,
                temperature=0.7,
                pad_token_id=tokenizer.eos_token_id,
            )

        text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        rewards.append(reward_fn_single(text, item["question"], level=2))
        rollouts.append((outputs[:1], prompt_len))

    mean_reward = sum(rewards) / len(rewards)

    # Policy-gradient step. Samples are back-propagated one at a time so only
    # a single computation graph is alive at any moment.
    # NOTE(review): log-probs are taken at temperature 1.0 while sampling used
    # 0.7; confirm this mismatch is acceptable.
    optimizer.zero_grad()
    step_loss = 0.0
    has_grad = False
    for (sequence, prompt_len), reward in zip(rollouts, rewards):
        advantage = reward - mean_reward
        # Nothing to learn from a sample at the baseline or with no new tokens.
        if advantage == 0 or sequence.shape[1] <= prompt_len:
            continue
        # Logits at position t predict token t + 1, hence the shift by one.
        logits = model(input_ids=sequence).logits[:, prompt_len - 1:-1, :]
        log_probs = torch.log_softmax(logits.float(), dim=-1)
        completion_ids = sequence[:, prompt_len:].unsqueeze(-1)
        token_log_probs = log_probs.gather(-1, completion_ids).squeeze(-1)
        loss = -advantage * token_log_probs.mean() / len(batch)
        loss.backward()
        step_loss += loss.item()
        has_grad = True
    if has_grad:
        optimizer.step()

    if step % 10 == 0:
        print(f"Step {step}/150 | Mean Reward: {mean_reward:.3f} | Rewards: {rewards}")
        wandb.log({"train/reward_l2": mean_reward, "train/global_step_l2": step, "train/loss_l2": step_loss})
306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  print("Level 2 training done!")
308
+ wandb.finish()
309
 
310
  # Save final model
311
  model.save_pretrained("/tmp/deceit-1.5b-final")