Jayant-Kernel committed
Commit 0efac4a · unverified · 1 parent: 5232a98

fix: proper GRPO with trl 0.12.2 no-deps + force hub downgrade

Files changed (2)
  1. Dockerfile +8 -11
  2. train.py +165 -101
Dockerfile CHANGED
@@ -12,20 +12,17 @@ WORKDIR /app
 
 RUN pip install --no-cache-dir torch==2.3.0 --index-url https://download.pytorch.org/whl/cu121
 
-RUN pip install --no-cache-dir \
-    "huggingface_hub==0.24.7" \
-    "transformers==4.45.2" \
-    "tokenizers==0.20.3" \
-    "accelerate==0.34.2" \
-    "peft==0.12.0" \
-    "datasets==2.21.0" \
-    "bitsandbytes==0.44.0" \
-    wandb matplotlib Pillow
+RUN python -c "import torch; print('CUDA:', torch.cuda.is_available()); print('Version:', torch.version.cuda)"
+
+RUN pip install --no-cache-dir "huggingface_hub==0.24.7"
+
+RUN pip install --no-cache-dir "transformers==4.45.2" "accelerate==0.34.2" "peft==0.12.0" "datasets==2.21.0" "bitsandbytes==0.44.0" wandb matplotlib Pillow
 
 RUN pip install --no-cache-dir "trl==0.12.2" --no-deps
 
-RUN pip install --no-cache-dir \
-    git+https://github.com/Jayant-kernel/DECEIT-the-ai-truth-environment-.git
+RUN pip install --no-cache-dir "accelerate==0.34.2"
+
+RUN pip install --no-cache-dir git+https://github.com/Jayant-kernel/DECEIT-the-ai-truth-environment-.git
 
 RUN pip install --no-cache-dir --force-reinstall "huggingface_hub==0.24.7"
train.py CHANGED
@@ -38,8 +38,7 @@ print("Health server started on port 7860")
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 from peft import LoraConfig, get_peft_model
-from torch.optim import AdamW
-from torch.utils.data import DataLoader
+from trl import GRPOConfig, GRPOTrainer
 from datasets import Dataset
 from huggingface_hub import login
 import wandb
@@ -109,9 +108,12 @@ import deceit_env as _de
 
 _grader = Grader(cache_path="/tmp/deceit_grader_cache.json",
                  openai_api_key=os.environ.get("OPENAI_API_KEY",""))
-env = DeceitEnvironment(grader=_grader)
+_env = DeceitEnvironment(grader=_grader)
 _env_lock = threading.Lock()
 
+_abstain_counts = {}
+_episode_counts = {}
+
 def parse_action(text):
     text = re.sub(r"```(?:json)?\s*", "", text).strip()
     try:
@@ -129,32 +131,60 @@ def parse_action(text)
 
 FAIL = {"reasoning":"fail","answer":"","confidence":0.0,"abstain":True,"is_final":True}
 
-def reward_fn_single(text, question, level=1):
-    try:
-        parsed = parse_action(text)
-        if parsed["abstain"]:
-            return 0.0
-        action = DeceitAction(
-            reasoning=parsed["reasoning"],
-            answer=parsed["answer"],
-            confidence=parsed["confidence"],
-            abstain=False,
-            is_final=True,
-        )
-        result = env.step(action)
-        env.reset(level=level)
-        return result.reward
-    except Exception as e:
-        return -0.5
-
-# Load Level 1 dataset
+def reward_fn(completions, prompts=None, **kwargs):
+    rewards = []
+    for text in completions:
+        try:
+            parsed = parse_action(text)
+        except:
+            parsed = FAIL.copy()
+
+        prompt_key = prompts[0][:50] if prompts else "default"
+        _episode_counts[prompt_key] = _episode_counts.get(prompt_key, 0) + 1
+        if parsed.get("abstain", False):
+            _abstain_counts[prompt_key] = _abstain_counts.get(prompt_key, 0) + 1
+
+        abstain_rate = _abstain_counts.get(prompt_key, 0) / max(1, _episode_counts.get(prompt_key, 1))
+
+        if parsed.get("abstain", False):
+            if abstain_rate > 0.3:
+                rewards.append(-0.5)
+            else:
+                rewards.append(0.0)
+            continue
+
+        try:
+            with _env_lock:
+                obs = _env.reset()
+                current = parsed
+                total = 0.0
+                for turn in range(obs.max_turns):
+                    if turn == obs.max_turns - 1:
+                        current["is_final"] = True
+                    action = DeceitAction(
+                        reasoning=current.get("reasoning",""),
+                        answer=current.get("answer",""),
+                        confidence=float(current.get("confidence",0.5)),
+                        abstain=bool(current.get("abstain",False)),
+                        is_final=bool(current.get("is_final",True)),
+                    )
+                    result = _env.step(action)
+                    total += result.reward
+                    if result.done:
+                        break
+        except Exception as e:
+            print(f"Episode error: {e}")
+            total = -1.3
+        rewards.append(total)
+    return rewards
+
 data_path = pathlib.Path(_de.__file__).parent / "data" / "level1.jsonl"
-questions_l1_raw = []
+questions = []
 with open(data_path) as f:
     for line in f:
         line = line.strip()
         if line:
-            questions_l1_raw.append(json.loads(line))
+            questions.append(json.loads(line))
 
 def make_prompt(q):
     msgs = [
@@ -163,52 +193,38 @@ def make_prompt(q):
     ]
     return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
 
-train_dataset_l1 = [
+train_dataset = Dataset.from_list([
     {"prompt": make_prompt(q["question"]), "question": q["question"]}
-    for q in questions_l1_raw
-]
+    for q in questions
+])
 
-# Level 1 training
-optimizer = AdamW(model.parameters(), lr=2e-5)
-model.train()
-
-print("Starting manual GRPO-style training...")
+print("Starting Level 1 training...")
 wandb.init(project=WANDB_PROJECT, name="1.5b-level1-improved")
-questions = train_dataset_l1
-
-env.reset(level=1)
-
-for step in range(300):
-    batch = random.sample(questions, min(4, len(questions)))
-
-    total_loss = torch.tensor(0.0, requires_grad=False)
-    rewards = []
-
-    for item in batch:
-        prompt = item["prompt"]
-        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
-
-        with torch.no_grad():
-            outputs = model.generate(
-                **inputs,
-                max_new_tokens=150,
-                do_sample=True,
-                temperature=0.7,
-                pad_token_id=tokenizer.eos_token_id
-            )
-
-        text = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
-        reward = reward_fn_single(text, item["question"], level=1)
-        rewards.append(reward)
 
-    mean_reward = sum(rewards) / len(rewards)
-
-    if step % 10 == 0:
-        print(f"Step {step}/300 | Mean Reward: {mean_reward:.3f} | Rewards: {rewards}")
-    wandb.log({"train/reward": mean_reward, "train/global_step": step})
-
-print("Level 1 training complete")
+trainer = GRPOTrainer(
+    model=model,
+    processing_class=tokenizer,
+    reward_funcs=[reward_fn],
+    args=GRPOConfig(
+        output_dir="/tmp/deceit-1.5b",
+        bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
+        fp16=False,
+        max_steps=1000,
+        per_device_train_batch_size=4,
+        num_generations=4,
+        learning_rate=2e-5,
+        warmup_steps=10,
+        logging_steps=1,
+        save_steps=100,
+        report_to="wandb",
+        max_completion_length=256,
+        remove_unused_columns=False,
+    ),
+    train_dataset=train_dataset,
+)
+trainer.train()
 wandb.finish()
+print("Level 1 done!")
 
 # Save Level 1 checkpoint
 model.save_pretrained("/tmp/deceit-1.5b-l1")
@@ -231,10 +247,18 @@ with open(data_path_l2) as f:
 
 print(f"Loaded {len(questions_l2)} Level 2 questions")
 
+data_path_l1 = pathlib.Path(_de.__file__).parent / "data" / "level1.jsonl"
+questions_l1 = []
+with open(data_path_l1) as f:
+    for line in f:
+        line = line.strip()
+        if line:
+            questions_l1.append(json.loads(line))
+
 # Mix 70% L2 + 30% L1
 n_l2 = len(questions_l2)
 n_l1_sample = max(1, int(n_l2 * 0.3))
-l1_sample = random.sample(questions_l1_raw, min(n_l1_sample, len(questions_l1_raw)))
+l1_sample = random.sample(questions_l1, min(n_l1_sample, len(questions_l1)))
 
 mixed_questions = []
 for q in questions_l2:
@@ -262,50 +286,90 @@ def make_prompt_l2(q, distractors):
     ]
     return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
 
-train_dataset_l2 = [
+train_dataset_l2 = Dataset.from_list([
     {"prompt": make_prompt_l2(q["question"], q.get("distractors", [])),
      "question": q["question"]}
    for q in mixed_questions
-]
-
-# Level 2 training
-print("Starting Level 2 training on 1.5B...")
-wandb.init(project=WANDB_PROJECT, name="1.5b-level2-improved")
-model.train()
+])
 
-env.reset(level=2)
+_env_l2 = DeceitEnvironment(grader=_grader)
+_abstain_counts_l2 = {}
+_episode_counts_l2 = {}
 
-for step in range(150):
-    batch = random.sample(train_dataset_l2, min(4, len(train_dataset_l2)))
-
-    total_loss = torch.tensor(0.0, requires_grad=False)
+def reward_fn_l2(completions, prompts=None, **kwargs):
     rewards = []
+    for text in completions:
+        try:
+            parsed = parse_action(text)
+        except:
+            parsed = FAIL.copy()
+
+        prompt_key = prompts[0][:50] if prompts else "default"
+        _episode_counts_l2[prompt_key] = _episode_counts_l2.get(prompt_key, 0) + 1
+        if parsed.get("abstain", False):
+            _abstain_counts_l2[prompt_key] = _abstain_counts_l2.get(prompt_key, 0) + 1
+
+        abstain_rate = _abstain_counts_l2.get(prompt_key, 0) / max(1, _episode_counts_l2.get(prompt_key, 1))
+
+        if parsed.get("abstain", False):
+            if abstain_rate > 0.3:
+                rewards.append(-0.5)
+            else:
+                rewards.append(0.0)
+            continue
+
+        try:
+            with _env_lock:
+                obs = _env_l2.reset(level=2)
+                current = parsed
+                total = 0.0
+                for turn in range(obs.max_turns):
+                    if turn == obs.max_turns - 1:
+                        current["is_final"] = True
+                    action = DeceitAction(
+                        reasoning=current.get("reasoning",""),
+                        answer=current.get("answer",""),
+                        confidence=float(current.get("confidence",0.5)),
+                        abstain=bool(current.get("abstain",False)),
+                        is_final=bool(current.get("is_final",True)),
+                    )
+                    result = _env_l2.step(action)
+                    total += result.reward
+                    if result.done:
+                        break
+        except Exception as e:
+            print(f"L2 Episode error: {e}")
+            total = -1.3
+        rewards.append(total)
+    return rewards
 
-    for item in batch:
-        prompt = item["prompt"]
-        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
-
-        with torch.no_grad():
-            outputs = model.generate(
-                **inputs,
-                max_new_tokens=150,
-                do_sample=True,
-                temperature=0.7,
-                pad_token_id=tokenizer.eos_token_id
-            )
-
-        text = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
-        reward = reward_fn_single(text, item["question"], level=2)
-        rewards.append(reward)
-
-    mean_reward = sum(rewards) / len(rewards)
-
-    if step % 10 == 0:
-        print(f"Step {step}/150 | Mean Reward: {mean_reward:.3f} | Rewards: {rewards}")
-    wandb.log({"train/reward_l2": mean_reward, "train/global_step_l2": step})
+print("Starting Level 2 training on 1.5B...")
+wandb.init(project=WANDB_PROJECT, name="1.5b-level2-improved")
 
-print("Level 2 training done!")
+trainer_l2 = GRPOTrainer(
+    model=model,
+    processing_class=tokenizer,
+    reward_funcs=[reward_fn_l2],
+    args=GRPOConfig(
+        output_dir="/tmp/deceit-1.5b-l2",
+        bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
+        fp16=False,
+        max_steps=600,
+        per_device_train_batch_size=4,
+        num_generations=4,
+        learning_rate=2e-5,
+        warmup_steps=10,
+        logging_steps=1,
+        save_steps=100,
+        report_to="wandb",
+        max_completion_length=256,
+        remove_unused_columns=False,
+    ),
+    train_dataset=train_dataset_l2,
+)
+trainer_l2.train()
 wandb.finish()
+print("Level 2 training done!")
 
 # Save final model
 model.save_pretrained("/tmp/deceit-1.5b-final")
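The piece of the new reward functions that is easiest to misread is the per-prompt abstain shaping: an abstention scores 0.0 while the abstain rate for that prompt key is still at or below 0.3, and -0.5 once it goes above, while non-abstaining completions are scored by the environment roll-out instead. A minimal sketch that re-implements just that bookkeeping, with no DeceitEnvironment, grader, or GPU required (the prompt and history below are made up):

```python
# abstain_shaping_demo.py - hypothetical sketch, not part of this commit.
# Reproduces only the per-prompt abstain bookkeeping from reward_fn /
# reward_fn_l2 to show how repeated abstentions on one prompt key flip
# the reward from 0.0 to -0.5 once the abstain rate exceeds 0.3.
_abstain_counts = {}
_episode_counts = {}

def abstain_reward(prompt, abstained):
    key = prompt[:50]  # same truncated prompt key as in train.py
    _episode_counts[key] = _episode_counts.get(key, 0) + 1
    if abstained:
        _abstain_counts[key] = _abstain_counts.get(key, 0) + 1
    rate = _abstain_counts.get(key, 0) / max(1, _episode_counts.get(key, 1))
    if abstained:
        return -0.5 if rate > 0.3 else 0.0
    return None  # non-abstain completions go through the environment roll-out

history = [False, False, False, True, True]  # five generations for one prompt
for abstained in history:
    print(abstained, abstain_reward("Q: example question?", abstained))
# -> None, None, None, 0.0 (rate 0.25), -0.5 (rate 0.4)
```

Note that prompt_key is taken from prompts[0][:50], so every completion in a reward call shares the first prompt's key; that only matches per-prompt bookkeeping if each call covers a single generation group.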