mahir-m01 commited on
Commit
ead6cb5
·
1 Parent(s): 8848300

feat(hf): export manifest, MVP save@50, long checkpoints@50-200

Browse files

- mvp: save_steps=50, save_total_limit=1, final lora-adapter + export_manifest
- long: save_steps=50, save_total_limit=4 to retain 50/100/150/200
- write export_manifest.json (config, checkpoint dirs) before Hub upload

Made-with: Cursor

Files changed (1) hide show
  1. hf/v6_train.py +39 -4
hf/v6_train.py CHANGED
@@ -79,7 +79,10 @@ class Cfg:
79
  max_eval: int = 16
80
  warmup_steps: int = 10
81
  out_dir: str = "runs/v6-mvp"
82
- save_steps: int = 25
 
 
 
83
 
84
  def _default_profile() -> str:
85
  if p := os.environ.get("ARMGYM_PROFILE"):
@@ -98,11 +101,19 @@ def _apply_profile(cfg: Cfg) -> Cfg:
98
  cfg.steps = 200
99
  cfg.out_dir = "runs/v6-200"
100
  cfg.save_steps = 50
 
101
  return cfg
102
 
103
  cfg = _apply_profile(Cfg())
104
- log.info("Profile: %s hub=%s out=%s steps=%d",
105
- _default_profile(), cfg.hub_model_id, cfg.out_dir, cfg.steps)
 
 
 
 
 
 
 
106
  log.info("Config: model=%s steps=%d G=%d temp=%.1f",
107
  cfg.model_id, cfg.steps, cfg.num_generations, cfg.temperature)
108
 
@@ -650,7 +661,7 @@ def build_grpo_config(cfg):
650
  remove_unused_columns=False,
651
  logging_steps=1,
652
  save_steps=cfg.save_steps,
653
- save_total_limit=1,
654
  report_to="none",
655
  )
656
  while True:
@@ -664,6 +675,28 @@ def build_grpo_config(cfg):
664
  p.pop(m.group(1), None)
665
 
666
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
667
  # ── TRAIN ─────────────────────────────────────────────────────────────────────
668
  out = Path(cfg.out_dir)
669
  out.mkdir(parents=True, exist_ok=True)
@@ -731,6 +764,8 @@ finally:
731
  except Exception as e:
732
  log.error("Failed to save LoRA adapter: %s", e)
733
 
 
 
734
  if hf_tok := os.environ.get("HF_TOKEN"):
735
  try:
736
  from huggingface_hub import HfApi
 
79
  max_eval: int = 16
80
  warmup_steps: int = 10
81
  out_dir: str = "runs/v6-mvp"
82
+ # mvp: one checkpoint at step 50 (end of 50-step run) + final lora-adapter
83
+ save_steps: int = 50
84
+ # long profile overrides to 4 so steps 50/100/150/200 are all retained
85
+ save_total_limit: int = 1
86
 
87
  def _default_profile() -> str:
88
  if p := os.environ.get("ARMGYM_PROFILE"):
 
101
  cfg.steps = 200
102
  cfg.out_dir = "runs/v6-200"
103
  cfg.save_steps = 50
104
+ cfg.save_total_limit = 4
105
  return cfg
106
 
107
  cfg = _apply_profile(Cfg())
108
+ log.info(
109
+ "Profile: %s hub=%s out=%s steps=%d save_every=%d keep_ckpt=%d",
110
+ _default_profile(),
111
+ cfg.hub_model_id,
112
+ cfg.out_dir,
113
+ cfg.steps,
114
+ cfg.save_steps,
115
+ cfg.save_total_limit,
116
+ )
117
  log.info("Config: model=%s steps=%d G=%d temp=%.1f",
118
  cfg.model_id, cfg.steps, cfg.num_generations, cfg.temperature)
119
 
 
661
  remove_unused_columns=False,
662
  logging_steps=1,
663
  save_steps=cfg.save_steps,
664
+ save_total_limit=cfg.save_total_limit,
665
  report_to="none",
666
  )
667
  while True:
 
675
  p.pop(m.group(1), None)
676
 
677
 
678
+ def _write_export_manifest(out: Path, cfg: Cfg) -> None:
679
+ """Record config + produced paths for HF upload and post-hoc analysis."""
680
+ checkpoints = sorted(
681
+ p.name for p in out.iterdir() if p.is_dir() and p.name.startswith("checkpoint-")
682
+ )
683
+ manifest: dict = {
684
+ "profile": _default_profile(),
685
+ "export_time_epoch": int(time.time()),
686
+ "train_config": asdict(cfg),
687
+ "artifacts": {
688
+ "config_json": "config.json",
689
+ "log_csv": "log.csv",
690
+ "lora_adapter_dir": "lora-adapter",
691
+ "checkpoint_dirs": checkpoints,
692
+ },
693
+ }
694
+ (out / "export_manifest.json").write_text(
695
+ json.dumps(manifest, indent=2) + "\n", encoding="utf-8"
696
+ )
697
+ log.info("Wrote export manifest (%d checkpoints): %s", len(checkpoints), out / "export_manifest.json")
698
+
699
+
700
  # ── TRAIN ─────────────────────────────────────────────────────────────────────
701
  out = Path(cfg.out_dir)
702
  out.mkdir(parents=True, exist_ok=True)
 
764
  except Exception as e:
765
  log.error("Failed to save LoRA adapter: %s", e)
766
 
767
+ _write_export_manifest(out, cfg)
768
+
769
  if hf_tok := os.environ.get("HF_TOKEN"):
770
  try:
771
  from huggingface_hub import HfApi