Draken1606 committed on
Commit
37edd09
·
1 Parent(s): c1adced

Fix 5 bugs: inference mode reset, step_counts in curriculum, adapter-only save (x3), DEMO001 false defence claim, episode_id in /reset

Browse files
Files changed (3) hide show
  1. server/app.py +2 -2
  2. server/dataset.py +1 -1
  3. training/train_grpo.py +19 -10
server/app.py CHANGED
@@ -69,12 +69,12 @@ def health():
69
 
70
 
71
  @app.post("/reset")
72
- def reset(stage: int = 1, session_id: str = None, seed: int = None):
73
  if session_id is None:
74
  session_id = str(uuid.uuid4())
75
  env = get_or_create_env(session_id)
76
  env.set_stage(stage)
77
- obs = env.reset(stage=stage, seed=seed)
78
  return {
79
  "session_id": session_id,
80
  "observation": obs.model_dump(),
 
69
 
70
 
71
  @app.post("/reset")
72
+ def reset(stage: int = 1, session_id: str = None, seed: int = None, episode_id: str = None):
73
  if session_id is None:
74
  session_id = str(uuid.uuid4())
75
  env = get_or_create_env(session_id)
76
  env.set_stage(stage)
77
+ obs = env.reset(stage=stage, seed=seed, episode_id=episode_id)
78
  return {
79
  "session_id": session_id,
80
  "observation": obs.model_dump(),
server/dataset.py CHANGED
@@ -74,7 +74,7 @@ class BailDataset:
74
  "Investigation is still pending and accused may tamper with evidence.",
75
  ],
76
  "defence_arguments": [
77
- "Accused has been in custody for 8 months on a 7-year max offence — already served more than half the equivalent.",
78
  "No prior criminal record. Permanent resident of Delhi with family ties.",
79
  "No evidence of flight risk or evidence tampering.",
80
  ],
 
74
  "Investigation is still pending and accused may tamper with evidence.",
75
  ],
76
  "defence_arguments": [
77
+ "Accused has been in custody for 8 months; threshold under BNSS 479 for a 7-year offence is 42 months not yet met. Bail is sought on community ties and clean record, not statutory default.",
78
  "No prior criminal record. Permanent resident of Delhi with family ties.",
79
  "No evidence of flight risk or evidence tampering.",
80
  ],
training/train_grpo.py CHANGED
@@ -656,10 +656,10 @@ def train(
656
  results_path.write_text(json.dumps(results, indent=2))
657
  print(f"\nResults saved to {results_path}")
658
 
659
- # ── Save model ────────────────────────────────────────────
660
- model.save_pretrained(output_dir)
661
  tokenizer.save_pretrained(output_dir)
662
- print(f"\nModel saved to {output_dir}")
663
  return results
664
 
665
 
@@ -904,7 +904,10 @@ def train_curriculum(
904
 
905
  def reward_fn(completions: List[str], episode: List[str], **kwargs) -> List[float]:
906
  ep_objs = [json.loads(e) for e in episode]
907
- return combined_reward(completions, ep_objs)
 
 
 
908
 
909
  stage_output = f"{output_dir}/stage_{stage}"
910
  config = GRPOConfig(
@@ -924,6 +927,11 @@ def train_curriculum(
924
  remove_unused_columns=False,
925
  )
926
 
 
 
 
 
 
927
  trainer = GRPOTrainer(
928
  model=model,
929
  processing_class=tokenizer,
@@ -963,10 +971,11 @@ def train_curriculum(
963
  print(f" ✗ Stage {stage} below threshold ({post_reward:.2f} < {threshold:.2f})")
964
  print(f" → Continuing to next stage anyway (curriculum mode)")
965
 
966
- # Save checkpoint after each stage
967
- model.save_pretrained(stage_output)
 
968
  tokenizer.save_pretrained(stage_output)
969
- print(f" Checkpoint saved: {stage_output}")
970
 
971
  # ── Final summary ──
972
  print(f"\n{'═' * 60}")
@@ -978,11 +987,11 @@ def train_curriculum(
978
  f"(Δ = {r['delta']:+.4f})")
979
  print(f" Total traces harvested: {len(accumulated_traces)}")
980
 
981
- # Save final model
982
  final_dir = f"{output_dir}/final"
983
- model.save_pretrained(final_dir)
984
  tokenizer.save_pretrained(final_dir)
985
- print(f"\n Final model saved: {final_dir}")
986
 
987
  # Save results
988
  results_path = Path(output_dir) / "curriculum_results.json"
 
656
  results_path.write_text(json.dumps(results, indent=2))
657
  print(f"\nResults saved to {results_path}")
658
 
659
+ # Save LoRA adapters only — safe for 4-bit quantized models
660
+ model.save_pretrained(output_dir, save_adapters_only=True)
661
  tokenizer.save_pretrained(output_dir)
662
+ print(f"\nModel adapters saved to {output_dir}")
663
  return results
664
 
665
 
 
904
 
905
  def reward_fn(completions: List[str], episode: List[str], **kwargs) -> List[float]:
906
  ep_objs = [json.loads(e) for e in episode]
907
+ # Pass step_count=1 for curriculum training (single-shot XML, no multi-step env loop)
908
+ # This keeps efficiency contribution honest rather than silently 0.0
909
+ step_counts = [1] * len(completions)
910
+ return combined_reward(completions, ep_objs, step_counts=step_counts)
911
 
912
  stage_output = f"{output_dir}/stage_{stage}"
913
  config = GRPOConfig(
 
927
  remove_unused_columns=False,
928
  )
929
 
930
+ # ── Switch model back to training mode before trainer.train() ──
931
+ # evaluate_on_stage calls FastLanguageModel.for_inference(model);
932
+ # without this reset, stages 2-4 train in inference mode silently.
933
+ FastLanguageModel.for_training(model)
934
+
935
  trainer = GRPOTrainer(
936
  model=model,
937
  processing_class=tokenizer,
 
971
  print(f" ✗ Stage {stage} below threshold ({post_reward:.2f} < {threshold:.2f})")
972
  print(f" → Continuing to next stage anyway (curriculum mode)")
973
 
974
+ # Save LoRA adapters only — safe for 4-bit models (save_pretrained_merged
975
+ # requires a full merge which can OOM on T4)
976
+ model.save_pretrained(stage_output, save_adapters_only=True)
977
  tokenizer.save_pretrained(stage_output)
978
+ print(f" Checkpoint saved (adapters): {stage_output}")
979
 
980
  # ── Final summary ──
981
  print(f"\n{'═' * 60}")
 
987
  f"(Δ = {r['delta']:+.4f})")
988
  print(f" Total traces harvested: {len(accumulated_traces)}")
989
 
990
+ # Save final model (adapters only — merge separately if needed)
991
  final_dir = f"{output_dir}/final"
992
+ model.save_pretrained(final_dir, save_adapters_only=True)
993
  tokenizer.save_pretrained(final_dir)
994
+ print(f"\n Final model saved (adapters): {final_dir}")
995
 
996
  # Save results
997
  results_path = Path(output_dir) / "curriculum_results.json"