AniketAsla commited on
Commit
37ec93d
·
verified ·
1 Parent(s): 4398db7

deploy: update train/post_training_eval.py

Browse files
Files changed (1) hide show
  1. train/post_training_eval.py +194 -194
train/post_training_eval.py CHANGED
@@ -1,194 +1,194 @@
1
- """
2
- post_training_eval.py — Re-run before/after component eval without GRPO training.
3
-
4
- Use when:
5
- - Training finished but the process exited before save_training_artifacts(), or
6
- - You want fresh eval plots/JSON from an existing checkpoint.
7
-
8
- Prerequisites:
9
- - Live ClaimCourt / DebateFloor env at ENV_BASE_URL (or let this script start uvicorn on :7860).
10
- - Checkpoint folder from training (default ./debatefloor_checkpoint).
11
-
12
- Match training episode count so eval episodes are drawn from the same pool as train_minimal:
13
- EPISODES=10000 EPOCHS=2 BATCH_SIZE=4 python train/post_training_eval.py
14
-
15
- Optional: larger held-out eval (more stable headline numbers):
16
- EVAL_EPISODES=18 python train/post_training_eval.py
17
-
18
- Usage:
19
- cd repo-root
20
- set PYTHONPATH=.
21
- python train/post_training_eval.py
22
- python train/post_training_eval.py --checkpoint path/to/merged_model
23
- """
24
- from __future__ import annotations
25
-
26
- import argparse
27
- import json
28
- import os
29
- import sys
30
- from pathlib import Path
31
- from types import SimpleNamespace
32
-
33
- # Repo root = parent of train/
34
- REPO_ROOT = Path(__file__).resolve().parent.parent
35
- os.chdir(REPO_ROOT)
36
- sys.path.insert(0, str(REPO_ROOT))
37
-
38
- os.environ.setdefault("PYTHONUNBUFFERED", "1")
39
-
40
-
41
- def _parse_args() -> argparse.Namespace:
42
- p = argparse.ArgumentParser(description="Post-training eval only (refresh reports + docs plots).")
43
- p.add_argument(
44
- "--checkpoint",
45
- default=os.environ.get("CHECKPOINT_PATH", "debatefloor_checkpoint"),
46
- help="HF-style folder with config + weights (default: ./debatefloor_checkpoint)",
47
- )
48
- p.add_argument(
49
- "--fresh-summary",
50
- action="store_true",
51
- help="Do not merge log_history from reports/training_summary.json (eval-only; empty reward curve).",
52
- )
53
- return p.parse_args()
54
-
55
-
56
- def run_eval(
57
- checkpoint: str | Path,
58
- *,
59
- fresh_summary: bool = False,
60
- stop_env_server: bool | None = None,
61
- ) -> None:
62
- """
63
- Run before/after component eval and write reports + docs plots.
64
-
65
- stop_env_server: if True, terminate subprocess uvicorn started here.
66
- If False, leave running. If None (default), stop only if we started it
67
- (same as CLI behaviour).
68
- """
69
- ckpt = Path(checkpoint).resolve()
70
- if not ckpt.is_dir():
71
- raise FileNotFoundError(f"checkpoint directory not found: {ckpt}")
72
-
73
- import torch
74
-
75
- import train.train_minimal as tm
76
-
77
- tm.EPISODES = int(os.environ.get("EPISODES", str(tm.EPISODES)))
78
- tm.EPOCHS = int(os.environ.get("EPOCHS", str(tm.EPOCHS)))
79
- tm.BATCH_SIZE = int(os.environ.get("BATCH_SIZE", str(tm.BATCH_SIZE)))
80
- tm.ENV_BASE_URL = os.environ.get("ENV_BASE_URL", tm.ENV_BASE_URL)
81
- tm.MODEL_NAME = os.environ.get("MODEL_NAME", tm.MODEL_NAME)
82
- tm.HAS_BF16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
83
- tm.USE_FP16 = torch.cuda.is_available() and not tm.HAS_BF16
84
- tm.DTYPE = torch.bfloat16 if tm.HAS_BF16 else torch.float16
85
-
86
- server_proc = tm._start_env_server_if_needed(tm.ENV_BASE_URL)
87
- _we_started_env = server_proc is not None
88
- if stop_env_server is None:
89
- stop_env_server = _we_started_env
90
-
91
- print(f"[OK] Env: {tm.ENV_BASE_URL} | EPISODES={tm.EPISODES} EVAL_EPISODES={tm.EVAL_EPISODES}")
92
-
93
- episode_pool = tm.generate_episode_pool(count=tm.EPISODES + (tm.EVAL_EPISODES * 4))
94
- eval_episodes = tm._select_eval_episodes(episode_pool[tm.EPISODES :])
95
- print(f" Eval pool: {len(eval_episodes)} episodes")
96
-
97
- if tm.USE_UNSLOTH:
98
- print(f"Loading base via Unsloth: {tm.MODEL_NAME}")
99
- model, tok = tm.FastLanguageModel.from_pretrained(
100
- model_name=tm.MODEL_NAME,
101
- max_seq_length=512,
102
- dtype=None,
103
- load_in_4bit=True,
104
- )
105
- tm.FastLanguageModel.for_inference(model)
106
- else:
107
- from transformers import AutoModelForCausalLM, AutoTokenizer
108
-
109
- print(f"Loading base via transformers: {tm.MODEL_NAME}")
110
- tok = AutoTokenizer.from_pretrained(tm.MODEL_NAME)
111
- if tok.pad_token is None:
112
- tok.pad_token = tok.eos_token
113
- model = AutoModelForCausalLM.from_pretrained(
114
- tm.MODEL_NAME,
115
- torch_dtype=tm.DTYPE,
116
- device_map="auto",
117
- )
118
-
119
- tm._tok_ref = tok
120
- print("Baseline eval (before)...")
121
- before_eval = tm.evaluate_component_shift(model, tok, eval_episodes)
122
- print(f" Before: {before_eval['means']}")
123
-
124
- del model
125
- if torch.cuda.is_available():
126
- torch.cuda.empty_cache()
127
-
128
- from transformers import AutoModelForCausalLM, AutoTokenizer
129
-
130
- print(f"Loading checkpoint: {ckpt}")
131
- tok_ft = AutoTokenizer.from_pretrained(str(ckpt))
132
- if tok_ft.pad_token is None:
133
- tok_ft.pad_token = tok_ft.eos_token
134
- model_ft = AutoModelForCausalLM.from_pretrained(
135
- str(ckpt),
136
- torch_dtype=tm.DTYPE,
137
- device_map="auto",
138
- )
139
- tm._tok_ref = tok_ft
140
-
141
- print("Post-training eval (after)...")
142
- after_eval = tm.evaluate_component_shift(model_ft, tok_ft, eval_episodes)
143
- print(f" After: {after_eval['means']}")
144
-
145
- log_history: list = []
146
- global_step = 0
147
- training_loss = 0.0
148
- summary_path = Path("reports/training_summary.json")
149
- if not fresh_summary and summary_path.exists():
150
- try:
151
- prev = json.loads(summary_path.read_text(encoding="utf-8"))
152
- log_history = list(prev.get("log_history") or [])
153
- global_step = int(prev.get("global_step") or 0)
154
- training_loss = float(prev.get("training_loss") or 0.0)
155
- print(f" Preserved {len(log_history)} log_history rows from existing summary.")
156
- except Exception as exc:
157
- print(f" [WARN] Could not read prior summary: {exc}")
158
-
159
- trainer = SimpleNamespace(state=SimpleNamespace(log_history=log_history))
160
- result = SimpleNamespace(global_step=global_step, training_loss=training_loss)
161
-
162
- tm.save_training_artifacts(
163
- trainer,
164
- result,
165
- before_eval["means"],
166
- after_eval["means"],
167
- )
168
- print("[OK] Updated reports/training_summary.json, docs/*.svg, reports/component_shift_summary.json")
169
-
170
- if stop_env_server and server_proc is not None:
171
- server_proc.terminate()
172
- try:
173
- server_proc.wait(timeout=5)
174
- except Exception:
175
- server_proc.kill()
176
- print("[STOP] Stopped subprocess env server.")
177
-
178
-
179
- def main() -> None:
180
- args = _parse_args()
181
- ckpt = Path(args.checkpoint).resolve()
182
- if not ckpt.is_dir():
183
- print(f"ERROR: checkpoint directory not found: {ckpt}")
184
- print("Train first (saves ./debatefloor_checkpoint) or pass --checkpoint /path/to/model")
185
- sys.exit(1)
186
- try:
187
- run_eval(ckpt, fresh_summary=args.fresh_summary)
188
- except Exception as exc:
189
- print(f"ERROR: {type(exc).__name__}: {exc}")
190
- raise
191
-
192
-
193
- if __name__ == "__main__":
194
- main()
 
1
+ """
2
+ post_training_eval.py — Re-run before/after component eval without GRPO training.
3
+
4
+ Use when:
5
+ - Training finished but the process exited before save_training_artifacts(), or
6
+ - You want fresh eval plots/JSON from an existing checkpoint.
7
+
8
+ Prerequisites:
9
+ - Live ClaimCourt / DebateFloor env at ENV_BASE_URL (or let this script start uvicorn on :7860).
10
+ - Checkpoint folder from training (default ./debatefloor_checkpoint).
11
+
12
+ Match training episode count so eval episodes are drawn from the same pool as train_minimal:
13
+ EPISODES=10000 EPOCHS=2 BATCH_SIZE=4 python train/post_training_eval.py
14
+
15
+ Optional: larger held-out eval (more stable headline numbers):
16
+ EVAL_EPISODES=18 python train/post_training_eval.py
17
+
18
+ Usage:
19
+ cd repo-root
20
+ set PYTHONPATH=.
21
+ python train/post_training_eval.py
22
+ python train/post_training_eval.py --checkpoint path/to/merged_model
23
+ """
24
+ from __future__ import annotations
25
+
26
+ import argparse
27
+ import json
28
+ import os
29
+ import sys
30
+ from pathlib import Path
31
+ from types import SimpleNamespace
32
+
33
+ # Repo root = parent of train/
34
+ REPO_ROOT = Path(__file__).resolve().parent.parent
35
+ os.chdir(REPO_ROOT)
36
+ sys.path.insert(0, str(REPO_ROOT))
37
+
38
+ os.environ.setdefault("PYTHONUNBUFFERED", "1")
39
+
40
+
41
+ def _parse_args() -> argparse.Namespace:
42
+ p = argparse.ArgumentParser(description="Post-training eval only (refresh reports + docs plots).")
43
+ p.add_argument(
44
+ "--checkpoint",
45
+ default=os.environ.get("CHECKPOINT_PATH", "debatefloor_checkpoint"),
46
+ help="HF-style folder with config + weights (default: ./debatefloor_checkpoint)",
47
+ )
48
+ p.add_argument(
49
+ "--fresh-summary",
50
+ action="store_true",
51
+ help="Do not merge log_history from reports/training_summary.json (eval-only; empty reward curve).",
52
+ )
53
+ return p.parse_args()
54
+
55
+
56
+ def run_eval(
57
+ checkpoint: str | Path,
58
+ *,
59
+ fresh_summary: bool = False,
60
+ stop_env_server: bool | None = None,
61
+ ) -> None:
62
+ """
63
+ Run before/after component eval and write reports + docs plots.
64
+
65
+ stop_env_server: if True, terminate subprocess uvicorn started here.
66
+ If False, leave running. If None (default), stop only if we started it
67
+ (same as CLI behaviour).
68
+ """
69
+ ckpt = Path(checkpoint).resolve()
70
+ if not ckpt.is_dir():
71
+ raise FileNotFoundError(f"checkpoint directory not found: {ckpt}")
72
+
73
+ import torch
74
+
75
+ import train.train_minimal as tm
76
+
77
+ tm.EPISODES = int(os.environ.get("EPISODES", str(tm.EPISODES)))
78
+ tm.EPOCHS = int(os.environ.get("EPOCHS", str(tm.EPOCHS)))
79
+ tm.BATCH_SIZE = int(os.environ.get("BATCH_SIZE", str(tm.BATCH_SIZE)))
80
+ tm.ENV_BASE_URL = os.environ.get("ENV_BASE_URL", tm.ENV_BASE_URL)
81
+ tm.MODEL_NAME = os.environ.get("MODEL_NAME", tm.MODEL_NAME)
82
+ tm.HAS_BF16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
83
+ tm.USE_FP16 = torch.cuda.is_available() and not tm.HAS_BF16
84
+ tm.DTYPE = torch.bfloat16 if tm.HAS_BF16 else torch.float16
85
+
86
+ server_proc = tm._start_env_server_if_needed(tm.ENV_BASE_URL)
87
+ _we_started_env = server_proc is not None
88
+ if stop_env_server is None:
89
+ stop_env_server = _we_started_env
90
+
91
+ print(f"[OK] Env: {tm.ENV_BASE_URL} | EPISODES={tm.EPISODES} EVAL_EPISODES={tm.EVAL_EPISODES}")
92
+
93
+ episode_pool = tm.generate_episode_pool(count=tm.EPISODES + (tm.EVAL_EPISODES * 4))
94
+ eval_episodes = tm._select_eval_episodes(episode_pool[tm.EPISODES :])
95
+ print(f" Eval pool: {len(eval_episodes)} episodes")
96
+
97
+ if tm.USE_UNSLOTH:
98
+ print(f"Loading base via Unsloth: {tm.MODEL_NAME}")
99
+ model, tok = tm.FastLanguageModel.from_pretrained(
100
+ model_name=tm.MODEL_NAME,
101
+ max_seq_length=512,
102
+ dtype=None,
103
+ load_in_4bit=True,
104
+ )
105
+ tm.FastLanguageModel.for_inference(model)
106
+ else:
107
+ from transformers import AutoModelForCausalLM, AutoTokenizer
108
+
109
+ print(f"Loading base via transformers: {tm.MODEL_NAME}")
110
+ tok = AutoTokenizer.from_pretrained(tm.MODEL_NAME)
111
+ if tok.pad_token is None:
112
+ tok.pad_token = tok.eos_token
113
+ model = AutoModelForCausalLM.from_pretrained(
114
+ tm.MODEL_NAME,
115
+ torch_dtype=tm.DTYPE,
116
+ device_map="auto",
117
+ )
118
+
119
+ tm._tok_ref = tok
120
+ print("Baseline eval (before)...")
121
+ before_eval = tm.evaluate_component_shift(model, tok, eval_episodes)
122
+ print(f" Before: {before_eval['means']}")
123
+
124
+ del model
125
+ if torch.cuda.is_available():
126
+ torch.cuda.empty_cache()
127
+
128
+ from transformers import AutoModelForCausalLM, AutoTokenizer
129
+
130
+ print(f"Loading checkpoint: {ckpt}")
131
+ tok_ft = AutoTokenizer.from_pretrained(str(ckpt))
132
+ if tok_ft.pad_token is None:
133
+ tok_ft.pad_token = tok_ft.eos_token
134
+ model_ft = AutoModelForCausalLM.from_pretrained(
135
+ str(ckpt),
136
+ torch_dtype=tm.DTYPE,
137
+ device_map="auto",
138
+ )
139
+ tm._tok_ref = tok_ft
140
+
141
+ print("Post-training eval (after)...")
142
+ after_eval = tm.evaluate_component_shift(model_ft, tok_ft, eval_episodes)
143
+ print(f" After: {after_eval['means']}")
144
+
145
+ log_history: list = []
146
+ global_step = 0
147
+ training_loss = 0.0
148
+ summary_path = Path("reports/training_summary.json")
149
+ if not fresh_summary and summary_path.exists():
150
+ try:
151
+ prev = json.loads(summary_path.read_text(encoding="utf-8"))
152
+ log_history = list(prev.get("log_history") or [])
153
+ global_step = int(prev.get("global_step") or 0)
154
+ training_loss = float(prev.get("training_loss") or 0.0)
155
+ print(f" Preserved {len(log_history)} log_history rows from existing summary.")
156
+ except Exception as exc:
157
+ print(f" [WARN] Could not read prior summary: {exc}")
158
+
159
+ trainer = SimpleNamespace(state=SimpleNamespace(log_history=log_history))
160
+ result = SimpleNamespace(global_step=global_step, training_loss=training_loss)
161
+
162
+ tm.save_training_artifacts(
163
+ trainer,
164
+ result,
165
+ before_eval["means"],
166
+ after_eval["means"],
167
+ )
168
+ print("[OK] Updated reports/training_summary.json, docs/*.svg, reports/component_shift_summary.json")
169
+
170
+ if stop_env_server and server_proc is not None:
171
+ server_proc.terminate()
172
+ try:
173
+ server_proc.wait(timeout=5)
174
+ except Exception:
175
+ server_proc.kill()
176
+ print("[STOP] Stopped subprocess env server.")
177
+
178
+
179
+ def main() -> None:
180
+ args = _parse_args()
181
+ ckpt = Path(args.checkpoint).resolve()
182
+ if not ckpt.is_dir():
183
+ print(f"ERROR: checkpoint directory not found: {ckpt}")
184
+ print("Train first (saves ./debatefloor_checkpoint) or pass --checkpoint /path/to/model")
185
+ sys.exit(1)
186
+ try:
187
+ run_eval(ckpt, fresh_summary=args.fresh_summary)
188
+ except Exception as exc:
189
+ print(f"ERROR: {type(exc).__name__}: {exc}")
190
+ raise
191
+
192
+
193
+ if __name__ == "__main__":
194
+ main()