AniketAsla commited on
Commit
726b8bb
·
verified ·
1 Parent(s): 16966ba

deploy: update train/jobs_run.py

Browse files
Files changed (1) hide show
  1. train/jobs_run.py +437 -437
train/jobs_run.py CHANGED
@@ -1,437 +1,437 @@
1
- """
2
- jobs_run.py — single-entry driver for HF Jobs.
3
-
4
- Designed to run inside `pytorch/pytorch:2.4.0-cuda12.1-cudnn9-runtime` on HF
5
- Jobs (L4/A10G/A100). Submits as:
6
-
7
- hf jobs run \\
8
- --flavor l4x1 \\
9
- --timeout 12h \\
10
- --secret HF_TOKEN=hf_xxx \\
11
- --secret WANDB_API_KEY=wandb_xxx \\
12
- --env EPISODES=10000 \\
13
- --env EPOCHS=2 \\
14
- --env DISABLE_VARIANCE_GUARD=1 \\
15
- --image pytorch/pytorch:2.4.0-cuda12.1-cudnn9-runtime \\
16
- python train/jobs_run.py
17
-
18
- Phases (each one logs a clear banner so you can grep the log):
19
-
20
- [1/6] Install deps from train/requirements.txt + root requirements.txt
21
- [2/6] Boot env server (uvicorn) on 127.0.0.1:7860
22
- [3/6] Wait for /health == healthy
23
- [4/6] Run train.train_minimal.main()
24
- [5/6] Push checkpoint + reports/ + docs/ to the HF model repo
25
- [6/6] Cleanly exit (kills env server so billing stops)
26
-
27
- Eval-only job (fast — refresh README metrics from Hub checkpoint, no GRPO):
28
-
29
- hf jobs run ... \\
30
- --env JOBS_EVAL_ONLY=1 \\
31
- --env EPISODES=10000 \\
32
- --env EVAL_EPISODES=18 \\
33
- --secret HF_TOKEN=hf_xxx \\
34
- python train/jobs_run.py
35
-
36
- If training finished but reports/ were not updated, run locally (with checkpoint + env):
37
- EPISODES=<same as job> python train/post_training_eval.py
38
-
39
- Environment variables consumed:
40
-
41
- Required:
42
- HF_TOKEN — HF write token (used to push checkpoint)
43
- Optional (with defaults):
44
- WANDB_API_KEY — enables WandB logging if set
45
- WANDB_ENTITY — wandb entity (default: aniketaslaliya-lnmiit)
46
- EPISODES — training episodes (default: 10000)
47
- EPOCHS — training epochs (default: 2)
48
- BATCH_SIZE — per-device batch (default: 4)
49
- NUM_GENERATIONS — GRPO group size (default: 4)
50
- GRAD_ACCUM — gradient accumulation steps (default: 2)
51
- MAX_COMPLETION_LENGTH — output token cap (default: 80)
52
- MAX_PROMPT_LENGTH — prompt token cap (default: 512)
53
- DISABLE_VARIANCE_GUARD — bypass CF-1 guard (default: 1)
54
- HF_MODEL_REPO — where to push the trained model
55
- (default: AniketAsla/debatefloor-grpo-qwen2.5-0.5b-instruct)
56
- JOBS_EVAL_ONLY — if 1: skip training; download checkpoint from HF_MODEL_REPO,
57
- run post-training eval, upload reports + docs only (fast).
58
- EVAL_EPISODES — optional; larger = more stable eval means (e.g. 18).
59
- """
60
- from __future__ import annotations
61
-
62
- import functools
63
- import os
64
- import signal
65
- import subprocess
66
- import sys
67
- import time
68
- from pathlib import Path
69
-
70
- # Force unbuffered stdout/stderr so HF Jobs log viewer shows every line in
71
- # real time. Without this, prints sit in a 4KB buffer and the user only sees
72
- # "Job started" for several minutes — making working jobs look broken.
73
- os.environ["PYTHONUNBUFFERED"] = "1"
74
- try:
75
- sys.stdout.reconfigure(line_buffering=True)
76
- sys.stderr.reconfigure(line_buffering=True)
77
- except AttributeError:
78
- pass
79
- print = functools.partial(print, flush=True) # noqa: A001 — intentional shadow
80
-
81
- # Heartbeat: a single line every minute so the user knows the job is alive
82
- # even during slow phases (pip install, model download, dataset prep).
83
- _HEARTBEAT_START = time.time()
84
-
85
-
86
- def _hb(label: str) -> None:
87
- elapsed = int(time.time() - _HEARTBEAT_START)
88
- mm, ss = divmod(elapsed, 60)
89
- print(f"[heartbeat +{mm:02d}:{ss:02d}] {label}")
90
-
91
-
92
- # ── [0/6] Bootstrap the repo (when running as a one-shot script) ────────────
93
- # When this file is executed via `python -c "exec(...)"` or downloaded as a
94
- # raw script, it has no surrounding repo. Detect that and `git clone` ourselves
95
- # so the rest of the script sees the real layout.
96
- _BOOTSTRAP_MARKER = Path(__file__).resolve().parent.parent / "app" / "main.py"
97
- if not _BOOTSTRAP_MARKER.exists():
98
- print("[0/6] Bootstrap: no repo on disk, cloning from GitHub", flush=True)
99
- _clone_dir = Path("/tmp/debatefloor")
100
- if not _clone_dir.exists():
101
- subprocess.check_call(
102
- ["git", "clone", "--depth", "1",
103
- "https://github.com/AniketAslaliya/debateFloor.git",
104
- str(_clone_dir)]
105
- )
106
- os.chdir(_clone_dir)
107
- REPO_ROOT = _clone_dir
108
- else:
109
- REPO_ROOT = Path(__file__).resolve().parent.parent
110
- os.chdir(REPO_ROOT)
111
-
112
- sys.path.insert(0, str(REPO_ROOT))
113
-
114
- _hb("driver script started")
115
- print("=" * 70)
116
- print("[1/6] Installing pinned deps from requirements files")
117
- print("=" * 70)
118
-
119
-
120
- def _pip_install(*args: str) -> None:
121
- cmd = [sys.executable, "-m", "pip", "install", "--quiet", *args]
122
- print(f" $ {' '.join(cmd)}")
123
- subprocess.check_call(cmd)
124
-
125
-
126
- _pip_install("--upgrade", "pip")
127
- _hb("upgraded pip")
128
- _pip_install("-r", "requirements.txt")
129
- _hb("installed root requirements.txt")
130
- _pip_install("-r", "train/requirements.txt")
131
- _hb("installed train/requirements.txt")
132
-
133
- # ── [1.4/6] Purge torchvision AND evict it from sys.modules.
134
- #
135
- # Two-part problem:
136
- # (1) The HF Jobs base image claims 'pytorch:2.4.0-cuda12.1' but actually
137
- # ships torch 2.11.0+cu130, so any torchvision pin we make is wrong.
138
- # (2) Even after `pip uninstall torchvision`, Python keeps the partially-
139
- # loaded torchvision modules in sys.modules from earlier `pip install`
140
- # work, so `import transformers` still hits the broken cached state and
141
- # fails with "partially initialized module 'torchvision' has no
142
- # attribute 'extension'".
143
- #
144
- # Fix: uninstall the package AND surgically evict every torchvision.* entry
145
- # from sys.modules so the next import attempt sees a clean slate.
146
- print("\n Purging torchvision (text-only training, not needed)...")
147
- try:
148
- subprocess.check_call(
149
- [sys.executable, "-m", "pip", "uninstall", "-y", "-q", "torchvision"]
150
- )
151
- print(" Removed torchvision package from environment")
152
- except subprocess.CalledProcessError:
153
- print(" torchvision not installed — nothing to remove")
154
-
155
- _evicted = [k for k in list(sys.modules) if k == "torchvision" or k.startswith("torchvision.")]
156
- for _k in _evicted:
157
- del sys.modules[_k]
158
- if _evicted:
159
- print(f" Evicted {len(_evicted)} torchvision modules from sys.modules cache")
160
-
161
- # Also evict any partially-loaded transformers modules that might have already
162
- # tried to import torchvision and cached a broken state (e.g. from this script
163
- # importing `requests` earlier, which doesn't touch transformers, but be safe).
164
- _tf_evicted = [k for k in list(sys.modules) if k == "transformers" or k.startswith("transformers.")]
165
- for _k in _tf_evicted:
166
- del sys.modules[_k]
167
- if _tf_evicted:
168
- print(f" Evicted {len(_tf_evicted)} transformers modules from sys.modules cache")
169
-
170
- # Tell transformers to be tolerant of missing optional vision deps (defense in
171
- # depth; the uninstall + sys.modules eviction is what actually fixes it).
172
- os.environ.setdefault("TRANSFORMERS_NO_ADVISORY_WARNINGS", "1")
173
-
174
- # ── [1.5/6] Sanity-check critical imports BEFORE we boot the env + load model.
175
- print("\n Sanity-checking critical imports...")
176
- _failed = []
177
- for _mod, _from in [
178
- ("torch", None),
179
- ("transformers", "PreTrainedModel"), # forces full transformers init
180
- ("trl", "GRPOConfig"), # forces grpo_trainer import
181
- ("peft", "LoraConfig"),
182
- ("accelerate", "Accelerator"),
183
- ("datasets", "Dataset"),
184
- ("wandb", None),
185
- ]:
186
- try:
187
- if _from:
188
- _m = __import__(_mod, fromlist=[_from])
189
- getattr(_m, _from)
190
- else:
191
- __import__(_mod)
192
- try:
193
- _v = __import__(_mod).__version__
194
- except Exception:
195
- _v = "?"
196
- print(f" ok {_mod:14s} {_v}")
197
- except Exception as _e:
198
- print(f" FAIL {_mod:14s} → {type(_e).__name__}: {_e}")
199
- _failed.append((_mod, _from, _e))
200
-
201
- if _failed:
202
- print("\n Sanity check failed — aborting before model download.")
203
- raise SystemExit(1)
204
-
205
- print(" All critical imports OK.\n")
206
- _hb("import sanity check passed")
207
- print(" Deps installed.\n")
208
-
209
-
210
- # ── [2/6] Boot the env server in the background ─────────────────────────────
211
- import requests as _requests # imported AFTER pip install -r requirements.txt
212
-
213
- print("=" * 70)
214
- print("[2/6] Booting DebateFloor env server on 127.0.0.1:7860")
215
- print("=" * 70)
216
-
217
- ENV_BASE_URL = "http://127.0.0.1:7860"
218
- _log_path = Path("/tmp/uvicorn_debatefloor.log")
219
- _log_file = open(_log_path, "w")
220
-
221
- env_proc = subprocess.Popen(
222
- [
223
- sys.executable,
224
- "-m",
225
- "uvicorn",
226
- "app.main:app",
227
- "--host",
228
- "127.0.0.1",
229
- "--port",
230
- "7860",
231
- "--log-level",
232
- "warning",
233
- ],
234
- cwd=str(REPO_ROOT),
235
- stdout=_log_file,
236
- stderr=subprocess.STDOUT,
237
- )
238
- print(f" uvicorn PID = {env_proc.pid}")
239
-
240
-
241
- # ── [3/6] Wait for /health ──────────────────────────────────────────────────
242
- print("\n" + "=" * 70)
243
- print("[3/6] Waiting for env server /health")
244
- print("=" * 70)
245
-
246
-
247
- def _wait_for_env(max_tries: int = 60) -> None:
248
- for i in range(max_tries):
249
- if env_proc.poll() is not None:
250
- log = _log_path.read_text()[-4000:]
251
- raise RuntimeError(f"uvicorn died before /health was ready. Log:\n{log}")
252
- try:
253
- r = _requests.get(f"{ENV_BASE_URL}/health", timeout=3)
254
- if r.status_code == 200 and r.json().get("status") == "healthy":
255
- print(f" Healthy after {i + 1} attempts.")
256
- return
257
- except Exception:
258
- pass
259
- time.sleep(2)
260
- log = _log_path.read_text()[-4000:]
261
- raise RuntimeError(f"Env never became healthy. Log:\n{log}")
262
-
263
-
264
- _wait_for_env()
265
- _hb("env server is healthy and accepting requests")
266
-
267
-
268
- # ── [4/6] Run training ─────────────��────────────────────────────────────────
269
- print("\n" + "=" * 70)
270
- print("[4/6] Running train.train_minimal.main()")
271
- print("=" * 70)
272
- _hb("starting training phase — model download may take 1–2 min on first run")
273
-
274
- # Surface key config so the log shows what we ran with
275
- EPISODES = int(os.environ.get("EPISODES", "10000"))
276
- EPOCHS = int(os.environ.get("EPOCHS", "2"))
277
- BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "4"))
278
- print(f" EPISODES={EPISODES} EPOCHS={EPOCHS} BATCH_SIZE={BATCH_SIZE}")
279
- print(f" NUM_GENERATIONS={os.environ.get('NUM_GENERATIONS', '4')}")
280
- print(f" GRAD_ACCUM={os.environ.get('GRAD_ACCUM', '2')}")
281
- print(f" MAX_COMPLETION_LENGTH={os.environ.get('MAX_COMPLETION_LENGTH', '80')}")
282
- print(
283
- f" DISABLE_VARIANCE_GUARD={os.environ.get('DISABLE_VARIANCE_GUARD', '1')}"
284
- )
285
- os.environ.setdefault("DISABLE_VARIANCE_GUARD", "1")
286
- os.environ.setdefault("NUM_GENERATIONS", "4")
287
- os.environ.setdefault("GRAD_ACCUM", "2")
288
- os.environ.setdefault("MAX_COMPLETION_LENGTH", "80")
289
- os.environ.setdefault("MAX_PROMPT_LENGTH", "512")
290
- os.environ["ENV_BASE_URL"] = ENV_BASE_URL
291
-
292
- import train.train_minimal as tm # noqa: E402
293
-
294
- tm.MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-0.5B-Instruct")
295
- tm.EPISODES = EPISODES
296
- tm.EPOCHS = EPOCHS
297
- tm.BATCH_SIZE = BATCH_SIZE
298
- tm.USE_WANDB = bool(os.environ.get("WANDB_API_KEY", ""))
299
- tm.WANDB_KEY = os.environ.get("WANDB_API_KEY", "")
300
- tm.WANDB_ENTITY = os.environ.get("WANDB_ENTITY", "aniketaslaliya-lnmiit")
301
- tm.ENV_BASE_URL = ENV_BASE_URL
302
-
303
- import torch # noqa: E402
304
-
305
- tm.HAS_BF16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
306
- tm.USE_FP16 = torch.cuda.is_available() and not tm.HAS_BF16
307
- tm.DTYPE = torch.bfloat16 if tm.HAS_BF16 else torch.float16
308
- print(f" GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
309
- print(f" dtype: {tm.DTYPE} | Unsloth: {tm.USE_UNSLOTH}\n")
310
-
311
- _ee = os.getenv("EVAL_EPISODES", "").strip()
312
- if _ee:
313
- tm.EVAL_EPISODES = int(_ee)
314
- print(f" EVAL_EPISODES={tm.EVAL_EPISODES} (env override)\n")
315
-
316
- HF_TOKEN = os.environ.get("HF_TOKEN", "")
317
- HF_MODEL_REPO = os.environ.get(
318
- "HF_MODEL_REPO",
319
- "AniketAsla/debatefloor-grpo-qwen2.5-0.5b-instruct",
320
- )
321
-
322
- train_exit_code = 0
323
- EVAL_ONLY = os.getenv("JOBS_EVAL_ONLY", "").strip().lower() in ("1", "true", "yes")
324
-
325
- if EVAL_ONLY:
326
- print("\n" + "=" * 70)
327
- print("[4/6] JOBS_EVAL_ONLY=1 — skip GRPO; Hub checkpoint + post-training eval")
328
- print("=" * 70)
329
- if not HF_TOKEN:
330
- print(" ERROR: JOBS_EVAL_ONLY requires HF_TOKEN (download checkpoint).")
331
- train_exit_code = 1
332
- else:
333
- try:
334
- import shutil
335
-
336
- from huggingface_hub import snapshot_download
337
-
338
- ckpt_dl = REPO_ROOT / "debatefloor_checkpoint"
339
- if ckpt_dl.exists():
340
- shutil.rmtree(ckpt_dl)
341
- print(f" snapshot_download {HF_MODEL_REPO} -> {ckpt_dl}")
342
- snapshot_download(
343
- repo_id=HF_MODEL_REPO,
344
- repo_type="model",
345
- local_dir=str(ckpt_dl),
346
- token=HF_TOKEN,
347
- ignore_patterns=[
348
- "reports/**",
349
- "docs/**",
350
- "*.md",
351
- ".gitattributes",
352
- ],
353
- )
354
- from train.post_training_eval import run_eval # noqa: E402
355
-
356
- run_eval(ckpt_dl, fresh_summary=False, stop_env_server=False)
357
- print(" Eval-only run completed.")
358
- except Exception as exc:
359
- train_exit_code = 1
360
- print(f" JOBS_EVAL_ONLY raised: {type(exc).__name__}: {exc}")
361
- import traceback
362
-
363
- traceback.print_exc()
364
- else:
365
- try:
366
- tm.main()
367
- print(" Training completed.")
368
- except Exception as exc: # don't crash the whole job — we still want artifacts
369
- train_exit_code = 1
370
- print(f" Training raised: {type(exc).__name__}: {exc}")
371
- import traceback
372
-
373
- traceback.print_exc()
374
-
375
-
376
- # ── [5/6] Push artifacts to the HF Hub model repo ───────────────────────────
377
- print("\n" + "=" * 70)
378
- print("[5/6] Uploading artifacts to HF Hub")
379
- print("=" * 70)
380
-
381
- if not HF_TOKEN:
382
- print(" HF_TOKEN not set — skipping upload (artifacts remain in job storage).")
383
- else:
384
- try:
385
- from huggingface_hub import HfApi, login
386
-
387
- login(token=HF_TOKEN, add_to_git_credential=False)
388
- api = HfApi(token=HF_TOKEN)
389
- api.create_repo(repo_id=HF_MODEL_REPO, repo_type="model", exist_ok=True)
390
-
391
- ckpt_dir = Path("./debatefloor_checkpoint")
392
- if EVAL_ONLY:
393
- print(" JOBS_EVAL_ONLY: skipping checkpoint upload (weights already on Hub).")
394
- elif ckpt_dir.exists() and any(ckpt_dir.iterdir()):
395
- print(f" Uploading checkpoint folder -> {HF_MODEL_REPO}")
396
- api.upload_folder(
397
- folder_path=str(ckpt_dir),
398
- repo_id=HF_MODEL_REPO,
399
- repo_type="model",
400
- commit_message=f"GRPO HF Jobs run: {EPISODES} episodes x {EPOCHS} epochs",
401
- )
402
- else:
403
- print(" No ./debatefloor_checkpoint to upload (training may have failed early).")
404
-
405
- for artifact in [
406
- "reports/training_summary.json",
407
- "reports/component_shift_summary.json",
408
- "docs/reward_curve.svg",
409
- "docs/component_shift.svg",
410
- ]:
411
- p = Path(artifact)
412
- if p.exists():
413
- print(f" Uploading {artifact}")
414
- api.upload_file(
415
- path_or_fileobj=str(p),
416
- path_in_repo=artifact,
417
- repo_id=HF_MODEL_REPO,
418
- repo_type="model",
419
- commit_message=f"Update {artifact} from HF Jobs run",
420
- )
421
- else:
422
- print(f" Skipping {artifact} (not found)")
423
- except Exception as exc:
424
- print(f" Upload step raised: {type(exc).__name__}: {exc}")
425
-
426
-
427
- # ── [6/6] Clean shutdown so HF Jobs stops billing ───────────────────────────
428
- print("\n" + "=" * 70)
429
- print("[6/6] Shutting down env server cleanly")
430
- print("=" * 70)
431
- try:
432
- env_proc.send_signal(signal.SIGTERM)
433
- env_proc.wait(timeout=10)
434
- except Exception:
435
- env_proc.kill()
436
- print(" Done.")
437
- sys.exit(train_exit_code)
 
1
+ """
2
+ jobs_run.py — single-entry driver for HF Jobs.
3
+
4
+ Designed to run inside `pytorch/pytorch:2.4.0-cuda12.1-cudnn9-runtime` on HF
5
+ Jobs (L4/A10G/A100). Submits as:
6
+
7
+ hf jobs run \\
8
+ --flavor l4x1 \\
9
+ --timeout 12h \\
10
+ --secret HF_TOKEN=hf_xxx \\
11
+ --secret WANDB_API_KEY=wandb_xxx \\
12
+ --env EPISODES=10000 \\
13
+ --env EPOCHS=2 \\
14
+ --env DISABLE_VARIANCE_GUARD=1 \\
15
+ --image pytorch/pytorch:2.4.0-cuda12.1-cudnn9-runtime \\
16
+ python train/jobs_run.py
17
+
18
+ Phases (each one logs a clear banner so you can grep the log):
19
+
20
+ [1/6] Install deps from train/requirements.txt + root requirements.txt
21
+ [2/6] Boot env server (uvicorn) on 127.0.0.1:7860
22
+ [3/6] Wait for /health == healthy
23
+ [4/6] Run train.train_minimal.main()
24
+ [5/6] Push checkpoint + reports/ + docs/ to the HF model repo
25
+ [6/6] Cleanly exit (kills env server so billing stops)
26
+
27
+ Eval-only job (fast — refresh README metrics from Hub checkpoint, no GRPO):
28
+
29
+ hf jobs run ... \\
30
+ --env JOBS_EVAL_ONLY=1 \\
31
+ --env EPISODES=10000 \\
32
+ --env EVAL_EPISODES=18 \\
33
+ --secret HF_TOKEN=hf_xxx \\
34
+ python train/jobs_run.py
35
+
36
+ If training finished but reports/ were not updated, run locally (with checkpoint + env):
37
+ EPISODES=<same as job> python train/post_training_eval.py
38
+
39
+ Environment variables consumed:
40
+
41
+ Required:
42
+ HF_TOKEN — HF write token (used to push checkpoint)
43
+ Optional (with defaults):
44
+ WANDB_API_KEY — enables WandB logging if set
45
+ WANDB_ENTITY — wandb entity (default: aniketaslaliya-lnmiit)
46
+ EPISODES — training episodes (default: 10000)
47
+ EPOCHS — training epochs (default: 2)
48
+ BATCH_SIZE — per-device batch (default: 4)
49
+ NUM_GENERATIONS — GRPO group size (default: 4)
50
+ GRAD_ACCUM — gradient accumulation steps (default: 2)
51
+ MAX_COMPLETION_LENGTH — output token cap (default: 80)
52
+ MAX_PROMPT_LENGTH — prompt token cap (default: 512)
53
+ DISABLE_VARIANCE_GUARD — bypass CF-1 guard (default: 1)
54
+ HF_MODEL_REPO — where to push the trained model
55
+ (default: AniketAsla/debatefloor-grpo-qwen2.5-0.5b-instruct)
56
+ JOBS_EVAL_ONLY — if 1: skip training; download checkpoint from HF_MODEL_REPO,
57
+ run post-training eval, upload reports + docs only (fast).
58
+ EVAL_EPISODES — optional; larger = more stable eval means (e.g. 18).
59
+ """
60
+ from __future__ import annotations
61
+
62
+ import functools
63
+ import os
64
+ import signal
65
+ import subprocess
66
+ import sys
67
+ import time
68
+ from pathlib import Path
69
+
70
+ # Force unbuffered stdout/stderr so HF Jobs log viewer shows every line in
71
+ # real time. Without this, prints sit in a 4KB buffer and the user only sees
72
+ # "Job started" for several minutes — making working jobs look broken.
73
+ os.environ["PYTHONUNBUFFERED"] = "1"
74
+ try:
75
+ sys.stdout.reconfigure(line_buffering=True)
76
+ sys.stderr.reconfigure(line_buffering=True)
77
+ except AttributeError:
78
+ pass
79
+ print = functools.partial(print, flush=True) # noqa: A001 — intentional shadow
80
+
81
+ # Heartbeat: a single line every minute so the user knows the job is alive
82
+ # even during slow phases (pip install, model download, dataset prep).
83
+ _HEARTBEAT_START = time.time()
84
+
85
+
86
+ def _hb(label: str) -> None:
87
+ elapsed = int(time.time() - _HEARTBEAT_START)
88
+ mm, ss = divmod(elapsed, 60)
89
+ print(f"[heartbeat +{mm:02d}:{ss:02d}] {label}")
90
+
91
+
92
+ # ── [0/6] Bootstrap the repo (when running as a one-shot script) ────────────
93
+ # When this file is executed via `python -c "exec(...)"` or downloaded as a
94
+ # raw script, it has no surrounding repo. Detect that and `git clone` ourselves
95
+ # so the rest of the script sees the real layout.
96
+ _BOOTSTRAP_MARKER = Path(__file__).resolve().parent.parent / "app" / "main.py"
97
+ if not _BOOTSTRAP_MARKER.exists():
98
+ print("[0/6] Bootstrap: no repo on disk, cloning from GitHub", flush=True)
99
+ _clone_dir = Path("/tmp/debatefloor")
100
+ if not _clone_dir.exists():
101
+ subprocess.check_call(
102
+ ["git", "clone", "--depth", "1",
103
+ "https://github.com/AniketAslaliya/debateFloor.git",
104
+ str(_clone_dir)]
105
+ )
106
+ os.chdir(_clone_dir)
107
+ REPO_ROOT = _clone_dir
108
+ else:
109
+ REPO_ROOT = Path(__file__).resolve().parent.parent
110
+ os.chdir(REPO_ROOT)
111
+
112
+ sys.path.insert(0, str(REPO_ROOT))
113
+
114
+ _hb("driver script started")
115
+ print("=" * 70)
116
+ print("[1/6] Installing pinned deps from requirements files")
117
+ print("=" * 70)
118
+
119
+
120
+ def _pip_install(*args: str) -> None:
121
+ cmd = [sys.executable, "-m", "pip", "install", "--quiet", *args]
122
+ print(f" $ {' '.join(cmd)}")
123
+ subprocess.check_call(cmd)
124
+
125
+
126
+ _pip_install("--upgrade", "pip")
127
+ _hb("upgraded pip")
128
+ _pip_install("-r", "requirements.txt")
129
+ _hb("installed root requirements.txt")
130
+ _pip_install("-r", "train/requirements.txt")
131
+ _hb("installed train/requirements.txt")
132
+
133
+ # ── [1.4/6] Purge torchvision AND evict it from sys.modules.
134
+ #
135
+ # Two-part problem:
136
+ # (1) The HF Jobs base image claims 'pytorch:2.4.0-cuda12.1' but actually
137
+ # ships torch 2.11.0+cu130, so any torchvision pin we make is wrong.
138
+ # (2) Even after `pip uninstall torchvision`, Python keeps the partially-
139
+ # loaded torchvision modules in sys.modules from earlier `pip install`
140
+ # work, so `import transformers` still hits the broken cached state and
141
+ # fails with "partially initialized module 'torchvision' has no
142
+ # attribute 'extension'".
143
+ #
144
+ # Fix: uninstall the package AND surgically evict every torchvision.* entry
145
+ # from sys.modules so the next import attempt sees a clean slate.
146
+ print("\n Purging torchvision (text-only training, not needed)...")
147
+ try:
148
+ subprocess.check_call(
149
+ [sys.executable, "-m", "pip", "uninstall", "-y", "-q", "torchvision"]
150
+ )
151
+ print(" Removed torchvision package from environment")
152
+ except subprocess.CalledProcessError:
153
+ print(" torchvision not installed — nothing to remove")
154
+
155
+ _evicted = [k for k in list(sys.modules) if k == "torchvision" or k.startswith("torchvision.")]
156
+ for _k in _evicted:
157
+ del sys.modules[_k]
158
+ if _evicted:
159
+ print(f" Evicted {len(_evicted)} torchvision modules from sys.modules cache")
160
+
161
+ # Also evict any partially-loaded transformers modules that might have already
162
+ # tried to import torchvision and cached a broken state (e.g. from this script
163
+ # importing `requests` earlier, which doesn't touch transformers, but be safe).
164
+ _tf_evicted = [k for k in list(sys.modules) if k == "transformers" or k.startswith("transformers.")]
165
+ for _k in _tf_evicted:
166
+ del sys.modules[_k]
167
+ if _tf_evicted:
168
+ print(f" Evicted {len(_tf_evicted)} transformers modules from sys.modules cache")
169
+
170
+ # Tell transformers to be tolerant of missing optional vision deps (defense in
171
+ # depth; the uninstall + sys.modules eviction is what actually fixes it).
172
+ os.environ.setdefault("TRANSFORMERS_NO_ADVISORY_WARNINGS", "1")
173
+
174
+ # ── [1.5/6] Sanity-check critical imports BEFORE we boot the env + load model.
175
+ print("\n Sanity-checking critical imports...")
176
+ _failed = []
177
+ for _mod, _from in [
178
+ ("torch", None),
179
+ ("transformers", "PreTrainedModel"), # forces full transformers init
180
+ ("trl", "GRPOConfig"), # forces grpo_trainer import
181
+ ("peft", "LoraConfig"),
182
+ ("accelerate", "Accelerator"),
183
+ ("datasets", "Dataset"),
184
+ ("wandb", None),
185
+ ]:
186
+ try:
187
+ if _from:
188
+ _m = __import__(_mod, fromlist=[_from])
189
+ getattr(_m, _from)
190
+ else:
191
+ __import__(_mod)
192
+ try:
193
+ _v = __import__(_mod).__version__
194
+ except Exception:
195
+ _v = "?"
196
+ print(f" ok {_mod:14s} {_v}")
197
+ except Exception as _e:
198
+ print(f" FAIL {_mod:14s} → {type(_e).__name__}: {_e}")
199
+ _failed.append((_mod, _from, _e))
200
+
201
+ if _failed:
202
+ print("\n Sanity check failed — aborting before model download.")
203
+ raise SystemExit(1)
204
+
205
+ print(" All critical imports OK.\n")
206
+ _hb("import sanity check passed")
207
+ print(" Deps installed.\n")
208
+
209
+
210
+ # ── [2/6] Boot the env server in the background ─────────────────────────────
211
+ import requests as _requests # imported AFTER pip install -r requirements.txt
212
+
213
+ print("=" * 70)
214
+ print("[2/6] Booting DebateFloor env server on 127.0.0.1:7860")
215
+ print("=" * 70)
216
+
217
+ ENV_BASE_URL = "http://127.0.0.1:7860"
218
+ _log_path = Path("/tmp/uvicorn_debatefloor.log")
219
+ _log_file = open(_log_path, "w")
220
+
221
+ env_proc = subprocess.Popen(
222
+ [
223
+ sys.executable,
224
+ "-m",
225
+ "uvicorn",
226
+ "app.main:app",
227
+ "--host",
228
+ "127.0.0.1",
229
+ "--port",
230
+ "7860",
231
+ "--log-level",
232
+ "warning",
233
+ ],
234
+ cwd=str(REPO_ROOT),
235
+ stdout=_log_file,
236
+ stderr=subprocess.STDOUT,
237
+ )
238
+ print(f" uvicorn PID = {env_proc.pid}")
239
+
240
+
241
+ # ── [3/6] Wait for /health ──────────────────────────────────────────────────
242
+ print("\n" + "=" * 70)
243
+ print("[3/6] Waiting for env server /health")
244
+ print("=" * 70)
245
+
246
+
247
+ def _wait_for_env(max_tries: int = 60) -> None:
248
+ for i in range(max_tries):
249
+ if env_proc.poll() is not None:
250
+ log = _log_path.read_text()[-4000:]
251
+ raise RuntimeError(f"uvicorn died before /health was ready. Log:\n{log}")
252
+ try:
253
+ r = _requests.get(f"{ENV_BASE_URL}/health", timeout=3)
254
+ if r.status_code == 200 and r.json().get("status") == "healthy":
255
+ print(f" Healthy after {i + 1} attempts.")
256
+ return
257
+ except Exception:
258
+ pass
259
+ time.sleep(2)
260
+ log = _log_path.read_text()[-4000:]
261
+ raise RuntimeError(f"Env never became healthy. Log:\n{log}")
262
+
263
+
264
+ _wait_for_env()
265
+ _hb("env server is healthy and accepting requests")
266
+
267
+
268
+ # ── [4/6] Run training ─────────────────────────────────────────────────────
269
+ print("\n" + "=" * 70)
270
+ print("[4/6] Running train.train_minimal.main()")
271
+ print("=" * 70)
272
+ _hb("starting training phase — model download may take 1–2 min on first run")
273
+
274
+ # Surface key config so the log shows what we ran with
275
+ EPISODES = int(os.environ.get("EPISODES", "10000"))
276
+ EPOCHS = int(os.environ.get("EPOCHS", "2"))
277
+ BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "4"))
278
+ print(f" EPISODES={EPISODES} EPOCHS={EPOCHS} BATCH_SIZE={BATCH_SIZE}")
279
+ print(f" NUM_GENERATIONS={os.environ.get('NUM_GENERATIONS', '4')}")
280
+ print(f" GRAD_ACCUM={os.environ.get('GRAD_ACCUM', '2')}")
281
+ print(f" MAX_COMPLETION_LENGTH={os.environ.get('MAX_COMPLETION_LENGTH', '80')}")
282
+ print(
283
+ f" DISABLE_VARIANCE_GUARD={os.environ.get('DISABLE_VARIANCE_GUARD', '1')}"
284
+ )
285
+ os.environ.setdefault("DISABLE_VARIANCE_GUARD", "1")
286
+ os.environ.setdefault("NUM_GENERATIONS", "4")
287
+ os.environ.setdefault("GRAD_ACCUM", "2")
288
+ os.environ.setdefault("MAX_COMPLETION_LENGTH", "80")
289
+ os.environ.setdefault("MAX_PROMPT_LENGTH", "512")
290
+ os.environ["ENV_BASE_URL"] = ENV_BASE_URL
291
+
292
+ import train.train_minimal as tm # noqa: E402
293
+
294
+ tm.MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-0.5B-Instruct")
295
+ tm.EPISODES = EPISODES
296
+ tm.EPOCHS = EPOCHS
297
+ tm.BATCH_SIZE = BATCH_SIZE
298
+ tm.USE_WANDB = bool(os.environ.get("WANDB_API_KEY", ""))
299
+ tm.WANDB_KEY = os.environ.get("WANDB_API_KEY", "")
300
+ tm.WANDB_ENTITY = os.environ.get("WANDB_ENTITY", "aniketaslaliya-lnmiit")
301
+ tm.ENV_BASE_URL = ENV_BASE_URL
302
+
303
+ import torch # noqa: E402
304
+
305
+ tm.HAS_BF16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
306
+ tm.USE_FP16 = torch.cuda.is_available() and not tm.HAS_BF16
307
+ tm.DTYPE = torch.bfloat16 if tm.HAS_BF16 else torch.float16
308
+ print(f" GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
309
+ print(f" dtype: {tm.DTYPE} | Unsloth: {tm.USE_UNSLOTH}\n")
310
+
311
+ _ee = os.getenv("EVAL_EPISODES", "").strip()
312
+ if _ee:
313
+ tm.EVAL_EPISODES = int(_ee)
314
+ print(f" EVAL_EPISODES={tm.EVAL_EPISODES} (env override)\n")
315
+
316
+ HF_TOKEN = os.environ.get("HF_TOKEN", "")
317
+ HF_MODEL_REPO = os.environ.get(
318
+ "HF_MODEL_REPO",
319
+ "AniketAsla/debatefloor-grpo-qwen2.5-0.5b-instruct",
320
+ )
321
+
322
+ train_exit_code = 0
323
+ EVAL_ONLY = os.getenv("JOBS_EVAL_ONLY", "").strip().lower() in ("1", "true", "yes")
324
+
325
+ if EVAL_ONLY:
326
+ print("\n" + "=" * 70)
327
+ print("[4/6] JOBS_EVAL_ONLY=1 — skip GRPO; Hub checkpoint + post-training eval")
328
+ print("=" * 70)
329
+ if not HF_TOKEN:
330
+ print(" ERROR: JOBS_EVAL_ONLY requires HF_TOKEN (download checkpoint).")
331
+ train_exit_code = 1
332
+ else:
333
+ try:
334
+ import shutil
335
+
336
+ from huggingface_hub import snapshot_download
337
+
338
+ ckpt_dl = REPO_ROOT / "debatefloor_checkpoint"
339
+ if ckpt_dl.exists():
340
+ shutil.rmtree(ckpt_dl)
341
+ print(f" snapshot_download {HF_MODEL_REPO} -> {ckpt_dl}")
342
+ snapshot_download(
343
+ repo_id=HF_MODEL_REPO,
344
+ repo_type="model",
345
+ local_dir=str(ckpt_dl),
346
+ token=HF_TOKEN,
347
+ ignore_patterns=[
348
+ "reports/**",
349
+ "docs/**",
350
+ "*.md",
351
+ ".gitattributes",
352
+ ],
353
+ )
354
+ from train.post_training_eval import run_eval # noqa: E402
355
+
356
+ run_eval(ckpt_dl, fresh_summary=False, stop_env_server=False)
357
+ print(" Eval-only run completed.")
358
+ except Exception as exc:
359
+ train_exit_code = 1
360
+ print(f" JOBS_EVAL_ONLY raised: {type(exc).__name__}: {exc}")
361
+ import traceback
362
+
363
+ traceback.print_exc()
364
+ else:
365
+ try:
366
+ tm.main()
367
+ print(" Training completed.")
368
+ except Exception as exc: # don't crash the whole job — we still want artifacts
369
+ train_exit_code = 1
370
+ print(f" Training raised: {type(exc).__name__}: {exc}")
371
+ import traceback
372
+
373
+ traceback.print_exc()
374
+
375
+
376
+ # ── [5/6] Push artifacts to the HF Hub model repo ───────────────────────────
377
+ print("\n" + "=" * 70)
378
+ print("[5/6] Uploading artifacts to HF Hub")
379
+ print("=" * 70)
380
+
381
+ if not HF_TOKEN:
382
+ print(" HF_TOKEN not set — skipping upload (artifacts remain in job storage).")
383
+ else:
384
+ try:
385
+ from huggingface_hub import HfApi, login
386
+
387
+ login(token=HF_TOKEN, add_to_git_credential=False)
388
+ api = HfApi(token=HF_TOKEN)
389
+ api.create_repo(repo_id=HF_MODEL_REPO, repo_type="model", exist_ok=True)
390
+
391
+ ckpt_dir = Path("./debatefloor_checkpoint")
392
+ if EVAL_ONLY:
393
+ print(" JOBS_EVAL_ONLY: skipping checkpoint upload (weights already on Hub).")
394
+ elif ckpt_dir.exists() and any(ckpt_dir.iterdir()):
395
+ print(f" Uploading checkpoint folder -> {HF_MODEL_REPO}")
396
+ api.upload_folder(
397
+ folder_path=str(ckpt_dir),
398
+ repo_id=HF_MODEL_REPO,
399
+ repo_type="model",
400
+ commit_message=f"GRPO HF Jobs run: {EPISODES} episodes x {EPOCHS} epochs",
401
+ )
402
+ else:
403
+ print(" No ./debatefloor_checkpoint to upload (training may have failed early).")
404
+
405
+ for artifact in [
406
+ "reports/training_summary.json",
407
+ "reports/component_shift_summary.json",
408
+ "docs/reward_curve.svg",
409
+ "docs/component_shift.svg",
410
+ ]:
411
+ p = Path(artifact)
412
+ if p.exists():
413
+ print(f" Uploading {artifact}")
414
+ api.upload_file(
415
+ path_or_fileobj=str(p),
416
+ path_in_repo=artifact,
417
+ repo_id=HF_MODEL_REPO,
418
+ repo_type="model",
419
+ commit_message=f"Update {artifact} from HF Jobs run",
420
+ )
421
+ else:
422
+ print(f" Skipping {artifact} (not found)")
423
+ except Exception as exc:
424
+ print(f" Upload step raised: {type(exc).__name__}: {exc}")
425
+
426
+
427
+ # ── [6/6] Clean shutdown so HF Jobs stops billing ───────────────────────────
428
+ print("\n" + "=" * 70)
429
+ print("[6/6] Shutting down env server cleanly")
430
+ print("=" * 70)
431
+ try:
432
+ env_proc.send_signal(signal.SIGTERM)
433
+ env_proc.wait(timeout=10)
434
+ except Exception:
435
+ env_proc.kill()
436
+ print(" Done.")
437
+ sys.exit(train_exit_code)