Spaces:
Running
Running
Commit ·
37edd09
1
Parent(s): c1adced
Fix 5 bugs: inference mode reset, step_counts in curriculum, adapter-only save (x3), DEMO001 false defence claim, episode_id in /reset
Browse files
- server/app.py +2 -2
- server/dataset.py +1 -1
- training/train_grpo.py +19 -10
server/app.py
CHANGED
|
@@ -69,12 +69,12 @@ def health():
|
|
| 69 |
|
| 70 |
|
| 71 |
@app.post("/reset")
|
| 72 |
-
def reset(stage: int = 1, session_id: str = None, seed: int = None):
|
| 73 |
if session_id is None:
|
| 74 |
session_id = str(uuid.uuid4())
|
| 75 |
env = get_or_create_env(session_id)
|
| 76 |
env.set_stage(stage)
|
| 77 |
-
obs = env.reset(stage=stage, seed=seed)
|
| 78 |
return {
|
| 79 |
"session_id": session_id,
|
| 80 |
"observation": obs.model_dump(),
|
|
|
|
| 69 |
|
| 70 |
|
| 71 |
@app.post("/reset")
|
| 72 |
+
def reset(stage: int = 1, session_id: str = None, seed: int = None, episode_id: str = None):
|
| 73 |
if session_id is None:
|
| 74 |
session_id = str(uuid.uuid4())
|
| 75 |
env = get_or_create_env(session_id)
|
| 76 |
env.set_stage(stage)
|
| 77 |
+
obs = env.reset(stage=stage, seed=seed, episode_id=episode_id)
|
| 78 |
return {
|
| 79 |
"session_id": session_id,
|
| 80 |
"observation": obs.model_dump(),
|
server/dataset.py
CHANGED
|
@@ -74,7 +74,7 @@ class BailDataset:
|
|
| 74 |
"Investigation is still pending and accused may tamper with evidence.",
|
| 75 |
],
|
| 76 |
"defence_arguments": [
|
| 77 |
-
"Accused has been in custody for 8 months",
|
| 78 |
"No prior criminal record. Permanent resident of Delhi with family ties.",
|
| 79 |
"No evidence of flight risk or evidence tampering.",
|
| 80 |
],
|
|
|
|
| 74 |
"Investigation is still pending and accused may tamper with evidence.",
|
| 75 |
],
|
| 76 |
"defence_arguments": [
|
| 77 |
+
"Accused has been in custody for 8 months; threshold under BNSS 479 for a 7-year offence is 42 months — not yet met. Bail is sought on community ties and clean record, not statutory default.",
|
| 78 |
"No prior criminal record. Permanent resident of Delhi with family ties.",
|
| 79 |
"No evidence of flight risk or evidence tampering.",
|
| 80 |
],
|
training/train_grpo.py
CHANGED
|
@@ -656,10 +656,10 @@ def train(
|
|
| 656 |
results_path.write_text(json.dumps(results, indent=2))
|
| 657 |
print(f"\nResults saved to {results_path}")
|
| 658 |
|
| 659 |
-
#
|
| 660 |
-
model.save_pretrained(output_dir)
|
| 661 |
tokenizer.save_pretrained(output_dir)
|
| 662 |
-
print(f"\nModel saved to {output_dir}")
|
| 663 |
return results
|
| 664 |
|
| 665 |
|
|
@@ -904,7 +904,10 @@ def train_curriculum(
|
|
| 904 |
|
| 905 |
def reward_fn(completions: List[str], episode: List[str], **kwargs) -> List[float]:
|
| 906 |
ep_objs = [json.loads(e) for e in episode]
|
| 907 |
-
|
|
|
|
|
|
|
|
|
|
| 908 |
|
| 909 |
stage_output = f"{output_dir}/stage_{stage}"
|
| 910 |
config = GRPOConfig(
|
|
@@ -924,6 +927,11 @@ def train_curriculum(
|
|
| 924 |
remove_unused_columns=False,
|
| 925 |
)
|
| 926 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 927 |
trainer = GRPOTrainer(
|
| 928 |
model=model,
|
| 929 |
processing_class=tokenizer,
|
|
@@ -963,10 +971,11 @@ def train_curriculum(
|
|
| 963 |
print(f" ✗ Stage {stage} below threshold ({post_reward:.2f} < {threshold:.2f})")
|
| 964 |
print(f" → Continuing to next stage anyway (curriculum mode)")
|
| 965 |
|
| 966 |
-
# Save
|
| 967 |
-
|
|
|
|
| 968 |
tokenizer.save_pretrained(stage_output)
|
| 969 |
-
print(f" Checkpoint saved: {stage_output}")
|
| 970 |
|
| 971 |
# ── Final summary ──
|
| 972 |
print(f"\n{'═' * 60}")
|
|
@@ -978,11 +987,11 @@ def train_curriculum(
|
|
| 978 |
f"(Δ = {r['delta']:+.4f})")
|
| 979 |
print(f" Total traces harvested: {len(accumulated_traces)}")
|
| 980 |
|
| 981 |
-
# Save final model
|
| 982 |
final_dir = f"{output_dir}/final"
|
| 983 |
-
model.save_pretrained(final_dir)
|
| 984 |
tokenizer.save_pretrained(final_dir)
|
| 985 |
-
print(f"\n Final model saved: {final_dir}")
|
| 986 |
|
| 987 |
# Save results
|
| 988 |
results_path = Path(output_dir) / "curriculum_results.json"
|
|
|
|
| 656 |
results_path.write_text(json.dumps(results, indent=2))
|
| 657 |
print(f"\nResults saved to {results_path}")
|
| 658 |
|
| 659 |
+
# Save LoRA adapters only — safe for 4-bit quantized models
|
| 660 |
+
model.save_pretrained(output_dir, save_adapters_only=True)
|
| 661 |
tokenizer.save_pretrained(output_dir)
|
| 662 |
+
print(f"\nModel adapters saved to {output_dir}")
|
| 663 |
return results
|
| 664 |
|
| 665 |
|
|
|
|
| 904 |
|
| 905 |
def reward_fn(completions: List[str], episode: List[str], **kwargs) -> List[float]:
|
| 906 |
ep_objs = [json.loads(e) for e in episode]
|
| 907 |
+
# Pass step_count=1 for curriculum training (single-shot XML, no multi-step env loop)
|
| 908 |
+
# This keeps efficiency contribution honest rather than silently 0.0
|
| 909 |
+
step_counts = [1] * len(completions)
|
| 910 |
+
return combined_reward(completions, ep_objs, step_counts=step_counts)
|
| 911 |
|
| 912 |
stage_output = f"{output_dir}/stage_{stage}"
|
| 913 |
config = GRPOConfig(
|
|
|
|
| 927 |
remove_unused_columns=False,
|
| 928 |
)
|
| 929 |
|
| 930 |
+
# ── Switch model back to training mode before trainer.train() ──
|
| 931 |
+
# evaluate_on_stage calls FastLanguageModel.for_inference(model);
|
| 932 |
+
# without this reset, stages 2-4 train in inference mode silently.
|
| 933 |
+
FastLanguageModel.for_training(model)
|
| 934 |
+
|
| 935 |
trainer = GRPOTrainer(
|
| 936 |
model=model,
|
| 937 |
processing_class=tokenizer,
|
|
|
|
| 971 |
print(f" ✗ Stage {stage} below threshold ({post_reward:.2f} < {threshold:.2f})")
|
| 972 |
print(f" → Continuing to next stage anyway (curriculum mode)")
|
| 973 |
|
| 974 |
+
# Save LoRA adapters only — safe for 4-bit models (save_pretrained_merged
|
| 975 |
+
# requires a full merge which can OOM on T4)
|
| 976 |
+
model.save_pretrained(stage_output, save_adapters_only=True)
|
| 977 |
tokenizer.save_pretrained(stage_output)
|
| 978 |
+
print(f" Checkpoint saved (adapters): {stage_output}")
|
| 979 |
|
| 980 |
# ── Final summary ──
|
| 981 |
print(f"\n{'═' * 60}")
|
|
|
|
| 987 |
f"(Δ = {r['delta']:+.4f})")
|
| 988 |
print(f" Total traces harvested: {len(accumulated_traces)}")
|
| 989 |
|
| 990 |
+
# Save final model (adapters only — merge separately if needed)
|
| 991 |
final_dir = f"{output_dir}/final"
|
| 992 |
+
model.save_pretrained(final_dir, save_adapters_only=True)
|
| 993 |
tokenizer.save_pretrained(final_dir)
|
| 994 |
+
print(f"\n Final model saved (adapters): {final_dir}")
|
| 995 |
|
| 996 |
# Save results
|
| 997 |
results_path = Path(output_dir) / "curriculum_results.json"
|