muskan singh and Claude Opus 4.7 committed
Commit 2ab0fe0 · 1 Parent(s): 7a0b2ce
fix: pin trl<=0.24, multi-step reward, lower LR, reduce NUM_GEN
- requirements.txt: pin trl>=0.18.2,<=0.24.0 (trl 1.x breaks Unsloth patches → silent crash at step 21)
- train.py: multi-step reward fn (REWARD_STEPS=2) — cumulative score not single-step
- train.py: NUM_GEN 4→2 to halve VRAM pressure from G×reward_steps inference calls
- train.py: LR 5e-5→8e-6 (5e-5 was unstable, caused reward oscillation)
- train.py: switch to max_steps=150 training (more reliable than epoch-based)
- train.py: model.config.max_length=None to silence max_new_tokens warning
- train.py: reward_funcs=[orgos_reward_fn] as list (required by TRL)
- train.py: BATCH_SIZE 4→1 with GRAD_ACCUM=2 (matches memory budget with multi-step reward)
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
- requirements.txt +1 -1
- train.py +101 -51
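
Background for two of the bullets above (the reward_funcs list and the workflow_id column): TRL's GRPOTrainer calls each reward function with the batch of prompts and completions and forwards any extra dataset columns as keyword arguments, expecting one float back per completion. The sketch below is illustrative only and not part of this commit; toy_reward_fn and its toy scoring are made-up stand-ins for orgos_reward_fn.

    import json
    from typing import List

    def toy_reward_fn(completions: List[str], prompts: List[str] = None, **kwargs) -> List[float]:
        # Extra dataset columns (here "workflow_id") arrive via **kwargs,
        # one value per completion in the current batch.
        workflow_ids = kwargs.get("workflow_id", ["A"] * len(completions))
        rewards = []
        for completion, wf_id in zip(completions, workflow_ids):
            try:
                json.loads(completion)  # toy check: is the completion valid JSON?
                rewards.append(1.0 if wf_id in {"A", "B", "C"} else 0.5)
            except json.JSONDecodeError:
                rewards.append(-0.1)
        return rewards  # exactly one float per completion

    # Wiring (mirrors the diff below): pass the callable inside a list, and keep
    # "workflow_id" as a column of the training dataset so the trainer forwards it.
    #   GRPOTrainer(model=model, args=grpo_config,
    #               reward_funcs=[toy_reward_fn],
    #               train_dataset=Dataset.from_list([{"prompt": "...", "workflow_id": "A"}]))
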
requirements.txt  CHANGED

@@ -11,7 +11,7 @@ aiofiles>=23.0.0
 torch
 transformers
 datasets
-trl
+trl>=0.18.2,<=0.24.0
 unsloth
 matplotlib
 numpy
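
The trl pin above guards against a failure that only surfaced mid-run (silent crash at step 21). A small runtime guard, shown here as an illustrative addition that is not part of this commit, could fail fast if the pin is ever lost; it assumes the packaging library is installed alongside trl:

    from importlib.metadata import version

    from packaging.version import Version

    trl_version = Version(version("trl"))
    if not (Version("0.18.2") <= trl_version <= Version("0.24.0")):
        raise RuntimeError(
            f"trl=={trl_version} is outside the tested 0.18.2-0.24.0 range; "
            "trl 1.x is known to break the Unsloth GRPO patches."
        )
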
train.py  CHANGED

@@ -1,6 +1,6 @@
 """
 OrgOS GRPO Training Script
-
+Runs headlessly on HuggingFace Spaces (A100/T4 GPU).
 
 Outputs:
     training_log.txt — structured training log for submission

@@ -35,20 +35,22 @@ from unsloth import FastLanguageModel
 # Config
 # ------------------------------------------------------------------
 
-MODEL_NAME = os.environ.get("MODEL_NAME", "
+MODEL_NAME = os.environ.get("MODEL_NAME", "unsloth/Qwen2.5-3B-Instruct-bnb-4bit")
 ENV_URL = "http://localhost:8000"
 LOG_FILE = "training_log.txt"
 N_PROMPTS_PER_WORKFLOW = 20
 N_EVAL = 10
-
-BATCH_SIZE = 4
+MAX_TRAIN_STEPS = 150  # step-based training (more reliable than epoch-based on Spaces)
+BATCH_SIZE = 1
 GRAD_ACCUM = 2
-LR = 5e-5
-NUM_GEN = 4
-TEMPERATURE = 0.
+LR = 8e-6  # stable LR — 5e-5 was too high
+NUM_GEN = 2  # candidates per prompt — keep low to save VRAM
+TEMPERATURE = 0.9
 BETA = 0.04
 LORA_R = 16
 MAX_SEQ_LEN = 4096
+MAX_COMPLETION_LENGTH = 256
+REWARD_STEPS = 2  # multi-step rollout depth in reward fn
 
 # ------------------------------------------------------------------
 # Logger

@@ -77,7 +79,6 @@ def start_env_server():
         stdout=None,
         stderr=None,
     )
-    # Wait until healthy
    for _ in range(20):
        time.sleep(2)
        try:

@@ -112,6 +113,9 @@ def load_model():
        use_gradient_checkpointing = "unsloth",
        random_state = 42,
    )
+    # Clear max_length to avoid max_new_tokens vs max_length warnings during generate()
+    model.config.max_length = None
+
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    tlog(f"[TRAIN_CONFIG] model={MODEL_NAME} lora_r={LORA_R} "
         f"max_seq_len={MAX_SEQ_LEN} trainable_params={trainable:,} quantization=4bit")

@@ -160,6 +164,12 @@ CRITICAL RULES:
 6. Stop when pending_steps is empty or done=true.
 """
 
+WORKFLOW_APPS = {
+    "A": {"jira", "zendesk", "salesforce", "workday"},
+    "B": {"zendesk", "salesforce", "workday"},
+    "C": {"jira", "zendesk", "salesforce"},
+}
+
 
 def obs_to_text(obs: dict) -> str:
     hints = obs.get("schema_hints", {})

@@ -186,25 +196,15 @@ def obs_to_text(obs: dict) -> str:
        "",
        "=== APP STATES ===",
    ]
-
-
-        "A": {"jira", "zendesk", "salesforce", "workday"},
-        "B": {"zendesk", "salesforce", "workday"},
-        "C": {"jira", "zendesk", "salesforce"},
-    }
-    relevant = WORKFLOW_APPS.get(
-        obs.get("workflow_id", "A"),
-        {"jira", "zendesk", "salesforce", "workday"},
-    )
+    relevant = WORKFLOW_APPS.get(obs.get("workflow_id", "A"),
+                                 {"jira", "zendesk", "salesforce", "workday"})
    for app_name, view in obs.get("app_states", {}).items():
        if app_name not in relevant:
            continue
-        lines.append(f" [{app_name.upper()}]")
        view_str = str(view)
        if len(view_str) > 600:
            view_str = view_str[:600] + "...[truncated]"
-        lines.
-        lines.append("")
+        lines += [f" [{app_name.upper()}]", f" {view_str}", ""]
    return "\n".join(lines)
 
 

@@ -244,36 +244,83 @@ def build_prompt_dataset(tokenizer) -> Dataset:
        rows.append({
            "prompt": build_prompt(obs_text, tokenizer),
            "workflow_id": wf,
-            "obs_text": obs_text,
        })
    tlog(f"[TRAIN_CONFIG] algorithm=GRPO prompts={len(rows)} "
         f"workflows=A,B,C prompts_per_workflow={N_PROMPTS_PER_WORKFLOW}")
+    sample_tokens = None  # set below after tokenizer is available
    return Dataset.from_list(rows)
 
 
 # ------------------------------------------------------------------
-# Reward function
+# Reward function — multi-step live environment rollout
 # ------------------------------------------------------------------
-
-
+# The model reference is set in main() before training starts.
+_reward_model = None
+_reward_tokenizer = None
+
+
+def orgos_reward_fn(completions: List[str], prompts: List[str] = None, **kwargs) -> List[float]:
+    """
+    For each GRPO candidate:
+      1. Parse as JSON action.
+      2. Reset env and apply the action (step 1).
+      3. Continue REWARD_STEPS-1 more greedy steps with the current model.
+      4. Return cumulative episode score — not just single-step reward.
+
+    Multi-step signal prevents the model from collapsing to always outputting
+    list_tickets (which gives a small single-step reward but never advances the workflow).
+    """
    workflow_ids = kwargs.get("workflow_id", ["A"] * len(completions))
    rewards = []
+
    for completion, wf_id in zip(completions, workflow_ids):
        action = parse_action(completion)
        if action is None:
            rewards.append(-0.1)
            continue
        try:
-
+            # Reset env and apply the GRPO-generated action (step 1)
+            obs = httpx.post(f"{ENV_URL}/reset",
+                             json={"workflow_id": wf_id}, timeout=10).json()["observation"]
            result = httpx.post(f"{ENV_URL}/step", json=action, timeout=10).json()
-
+            obs = result["observation"]
+
+            # Continue REWARD_STEPS-1 more steps with current model (greedy)
+            if _reward_model is not None:
+                for _ in range(REWARD_STEPS - 1):
+                    if obs.get("done"):
+                        break
+                    prompt_text = build_prompt(obs_to_text(obs), _reward_tokenizer)
+                    inputs = _reward_tokenizer(
+                        prompt_text, return_tensors="pt"
+                    ).to(_reward_model.device)
+                    with torch.no_grad():
+                        out = _reward_model.generate(
+                            **inputs,
+                            max_new_tokens = 128,
+                            do_sample = False,
+                            pad_token_id = _reward_tokenizer.eos_token_id,
+                        )
+                    cont_str = _reward_tokenizer.decode(
+                        out[0][inputs["input_ids"].shape[1]:],
+                        skip_special_tokens=True,
+                    ).strip()
+                    cont_action = parse_action(cont_str)
+                    if cont_action is None:
+                        break
+                    result = httpx.post(f"{ENV_URL}/step",
+                                        json=cont_action, timeout=10).json()
+                    obs = result["observation"]
+
+            rewards.append(float(obs.get("current_score", 0.001)))
        except Exception:
            rewards.append(-0.1)
+
    return rewards
 
 
 # ------------------------------------------------------------------
-# Episode evaluation
+# Episode evaluation (stateless — each step is a fresh single-turn prompt)
 # ------------------------------------------------------------------
 
 def run_episode_with_model(model, tokenizer, workflow_id: str, max_steps: int = 15) -> float:

@@ -284,20 +331,14 @@ def run_episode_with_model(model, tokenizer, workflow_id: str, max_steps: int =
        if obs["done"]:
            break
 
-        # Stateless single-turn prompt — matches the GRPO training format.
-        # obs["message"] already carries last-action feedback, so no history needed.
        obs_text = obs_to_text(obs)
-
-
-
-        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        inputs = tokenizer(text, return_tensors="pt").to(model.device)
+        text = build_prompt(obs_text, tokenizer)
+        inputs = tokenizer(text, return_tensors="pt").to(model.device)
 
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens = 256,
-                temperature = 0.0,
                do_sample = False,
                pad_token_id = tokenizer.eos_token_id,
            )

@@ -314,7 +355,7 @@ def run_episode_with_model(model, tokenizer, workflow_id: str, max_steps: int =
        if obs["done"]:
            break
 
-    return obs.get("current_score", 0.001)
+    return float(obs.get("current_score", 0.001))
 
 
 def evaluate(model, tokenizer, phase: str) -> dict:

@@ -325,10 +366,10 @@ def evaluate(model, tokenizer, phase: str) -> dict:
            score = run_episode_with_model(model, tokenizer, wf)
            scores[wf].append(score)
            tlog(f"[EVAL] phase={phase} workflow={wf} episode={ep+1} score={score:.4f}")
-        wf_mean = np.mean(scores[wf])
+        wf_mean = float(np.mean(scores[wf]))
        tlog(f"[EVAL_WORKFLOW] phase={phase} workflow={wf} "
             f"mean={wf_mean:.4f} min={min(scores[wf]):.4f} max={max(scores[wf]):.4f}")
-    overall = np.mean([s for v in scores.values() for s in v])
+    overall = float(np.mean([s for v in scores.values() for s in v]))
    tlog(f"[EVAL_END] phase={phase} overall_mean={overall:.4f}")
    return scores
 

@@ -426,46 +467,55 @@ class OrgOSLogCallback(TrainerCallback):
 # ------------------------------------------------------------------
 
 def main():
+    global _reward_model, _reward_tokenizer
+
    server_proc = start_env_server()
 
    try:
        model, tokenizer = load_model()
 
+        # Wire up the reward function's model reference (used for multi-step rollouts)
+        _reward_model = model
+        _reward_tokenizer = tokenizer
+
        prompt_dataset = build_prompt_dataset(tokenizer)
+        tok_len = len(tokenizer(prompt_dataset[0]["prompt"]).input_ids)
+        tlog(f"[PROMPT_DEBUG] first_prompt_tokens={tok_len}")
 
        # Sanity-check reward function
        test_r = orgos_reward_fn(
-            completions
-
-            prompts
-            workflow_id
+            completions = ['{"app": "zendesk", "operation": "list_tickets", "args": {}}',
+                           "not json"],
+            prompts = ["", ""],
+            workflow_id = ["A", "A"],
        )
        tlog(f"[REWARD_FN_CHECK] valid_action={test_r[0]:.4f} invalid_action={test_r[1]:.4f}")
 
        # Baseline evaluation
        FastLanguageModel.for_inference(model)
        baseline_scores = evaluate(model, tokenizer, phase="baseline")
-        baseline_mean = np.mean([s for v in baseline_scores.values() for s in v])
+        baseline_mean = float(np.mean([s for v in baseline_scores.values() for s in v]))
 
        # GRPO training
-
-        tlog(f"[TRAIN_CONFIG]
+        FastLanguageModel.for_training(model)
+        tlog(f"[TRAIN_CONFIG] max_steps={MAX_TRAIN_STEPS} batch_size={BATCH_SIZE} "
             f"grad_accum={GRAD_ACCUM} lr={LR} num_generations={NUM_GEN} "
-             f"temperature={TEMPERATURE} beta_kl={BETA}")
+             f"temperature={TEMPERATURE} beta_kl={BETA} reward_steps={REWARD_STEPS}")
 
        grpo_config = GRPOConfig(
            output_dir = "./orgos_grpo_ckpt",
-            num_train_epochs =
+            num_train_epochs = 1,
+            max_steps = MAX_TRAIN_STEPS,
            per_device_train_batch_size = BATCH_SIZE,
            gradient_accumulation_steps = GRAD_ACCUM,
            learning_rate = LR,
            warmup_steps = 10,
            logging_steps = 5,
-            save_steps = 100,
            bf16 = torch.cuda.is_bf16_supported(),
            fp16 = not torch.cuda.is_bf16_supported(),
            max_grad_norm = 1.0,
            num_generations = NUM_GEN,
+            max_new_tokens = MAX_COMPLETION_LENGTH,
            temperature = TEMPERATURE,
            beta = BETA,
            report_to = "none",

@@ -475,7 +525,7 @@ def main():
        trainer = GRPOTrainer(
            model = model,
            args = grpo_config,
-            reward_funcs = orgos_reward_fn,
+            reward_funcs = [orgos_reward_fn],
            train_dataset = prompt_dataset,
            processing_class = tokenizer,
            callbacks = [OrgOSLogCallback()],

@@ -490,7 +540,7 @@ def main():
        # Post-training evaluation
        FastLanguageModel.for_inference(model)
        post_scores = evaluate(model, tokenizer, phase="post_training")
-        post_mean = np.mean([s for v in post_scores.values() for s in v])
+        post_mean = float(np.mean([s for v in post_scores.values() for s in v]))
        improvement = post_mean - baseline_mean
 
        tlog(