siddeshwar-kagatikar committed
Commit 4aca4f5 · 1 Parent(s): 3e893cd

feat(training): improve self-play progress visibility and reward diagnostics

Batch generation/sampling and add explicit reward instrumentation so HF Space runs do not appear stalled and answerer reward signals remain continuously observable in logs and W&B.

Made-with: Cursor
- src/osint_env/training/rewards.py +58 -1
- src/osint_env/training/self_play.py +255 -62
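The batching change below leans on two details that are easy to get wrong with decoder-only models: prompts in a batch must be left-padded so every prompt ends flush against the generation boundary, and completions must be sliced off at the shared input length before decoding. A minimal standalone sketch of that pattern (the checkpoint name and prompts are illustrative placeholders, not from this repo):

# Minimal sketch of the left-padded batched generation pattern the diff applies.
# The checkpoint name and prompts are illustrative placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "sshleifer/tiny-gpt2"  # any decoder-only checkpoint works
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
    tokenizer.pad_token = tokenizer.eos_token
# Left padding keeps each prompt flush against the generation boundary;
# with right padding the model would continue from pad tokens instead.
tokenizer.padding_side = "left"

model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

prompts = ["Question: who is A?", "Question: where is B located?"]
encoded = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)

with torch.no_grad():
    output = model.generate(
        **encoded,
        max_new_tokens=32,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )

# With left padding, every row's prompt ends at the same index, so the
# completion is simply everything past the (padded) input length.
input_len = encoded["input_ids"].shape[1]
completions = [tokenizer.decode(row[input_len:], skip_special_tokens=True) for row in output]
print(completions)

This is why the diff forces tokenizer.padding_side = "left" before batching in every sampling helper it touches.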
src/osint_env/training/rewards.py
CHANGED
@@ -1094,6 +1094,22 @@ class AnswererRewardFunction:
         self.reward_model = build_reward_model(graph)
         self.pipeline_mode = str(pipeline_mode).strip().lower() or "legacy"
         self.parl_max_parallel_hint = max(0, int(parl_max_parallel_hint or 0))
+        # Mirror GeneratorRewardFunction observability: TRL's GRPOTrainer
+        # already logs `rewards/AnswererRewardFunction/{mean,std}` to W&B
+        # at every `logging_steps`, but we ALSO publish a per-batch debug
+        # snapshot so the [reward_debug][last_batch] line appears in stdout
+        # for the answerer phase, exactly like it does for the generator.
+        self._debug_batches_seen = 0
+        self._debug_reward_window: list[float] = []
+        self._debug_last_batch: dict[str, Any] = {}
+
+    @staticmethod
+    def _std(values: list[float]) -> float:
+        if len(values) <= 1:
+            return 0.0
+        mean = sum(values) / len(values)
+        variance = sum((value - mean) ** 2 for value in values) / len(values)
+        return variance ** 0.5

     @staticmethod
     def _parse_support_edges(value: Any) -> list[Edge]:
@@ -1172,6 +1188,8 @@ class AnswererRewardFunction:
         **kwargs: Any,
     ) -> list[float]:
         rewards: list[float] = []
+        success_count = 0
+        graph_f1_sum = 0.0

         for idx, completion in enumerate(completions):
             completion_text = decode_completion_text(completion)
@@ -1204,6 +1222,45 @@ class AnswererRewardFunction:
                 model=self.reward_model,
                 difficulty=difficulty_level,
             )
-            rewards.append(breakdown.total)
+            final_reward = self._extract_orchestrator_reward(completion_text, breakdown.total)
+            rewards.append(final_reward)
+
+            if predicted_answer and target_answer and normalize_answer(predicted_answer) == target_answer:
+                success_count += 1
+            graph_f1_sum += float(getattr(breakdown, "graph_f1", 0.0) or 0.0)
+
+        # Mirror GeneratorRewardFunction debug surface so the answerer reward
+        # is visible to the same downstream tooling (printed by
+        # `_train_grpo_phase` and forwarded to W&B by TRL).
+        self._debug_batches_seen += 1
+        self._debug_reward_window.extend(rewards)
+        self._debug_reward_window = self._debug_reward_window[-512:]
+        batch_size = max(1, len(rewards))
+        batch_mean = float(sum(rewards) / batch_size)
+        batch_std = float(self._std(rewards))
+        advantages = [float(value - batch_mean) for value in rewards]
+        self._debug_last_batch = {
+            "batch_rewards": list(rewards),
+            "batch_reward_mean": batch_mean,
+            "batch_reward_std": batch_std,
+            "advantage_proxy_min": min(advantages) if advantages else 0.0,
+            "advantage_proxy_max": max(advantages) if advantages else 0.0,
+            "advantage_proxy_std": float(self._std(advantages)),
+            "exact_match_count": int(success_count),
+            "exact_match_ratio": float(success_count / batch_size),
+            "avg_graph_f1": float(graph_f1_sum / batch_size),
+        }
+        if self._debug_batches_seen % 10 == 0:
+            window_std = self._std(self._debug_reward_window)
+            print(
+                "[reward_debug][answerer] "
+                f"batches={self._debug_batches_seen} "
+                f"window_reward_std={window_std:.6f} "
+                f"last_batch_mean={batch_mean:.6f} "
+                f"last_batch_std={batch_std:.6f} "
+                f"exact_match_ratio={self._debug_last_batch['exact_match_ratio']:.3f} "
+                f"avg_graph_f1={self._debug_last_batch['avg_graph_f1']:.3f}",
+                flush=True,
+            )

         return rewards
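For reference, the _std helper added above is a population standard deviation, and the per-batch snapshot is plain arithmetic over the reward list. A standalone sketch of the same computation, with made-up reward values:

# Standalone sketch of the per-batch snapshot arithmetic added above;
# the reward values are made-up numbers, not real training output.
def _std(values: list[float]) -> float:
    if len(values) <= 1:
        return 0.0
    mean = sum(values) / len(values)
    variance = sum((value - mean) ** 2 for value in values) / len(values)
    return variance ** 0.5

rewards = [0.2, 0.8, 0.5, 0.5]
batch_mean = sum(rewards) / max(1, len(rewards))        # 0.5
advantages = [value - batch_mean for value in rewards]  # [-0.3, 0.3, 0.0, 0.0]
print(_std(rewards))     # ~0.2121 (population std)
print(_std(advantages))  # identical: mean-shifting does not change spread

Since the advantages are just mean-shifted rewards, advantage_proxy_std always equals batch_reward_std; the duplicated field keeps the answerer debug line shaped like the generator's rather than adding new information.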
src/osint_env/training/self_play.py
CHANGED
@@ -4,6 +4,7 @@ import inspect
 import json
 import os
 import re
+import time
 from dataclasses import dataclass
 from pathlib import Path
 import random
@@ -753,7 +754,21 @@ def _train_grpo_phase(
         trainer_kwargs["peft_config"] = _build_lora_config(lora)

     phase_label = str(run_name).strip() or str(output_dir.name)
-
+    reward_class_name = type(reward_function).__name__
+    print(
+        f"[self_play] Starting phase: {phase_label} rows={len(rows)} "
+        f"max_steps={phase.max_steps}",
+        flush=True,
+    )
+    print(
+        f"[self_play][reward_setup] phase={phase_label} "
+        f"reward_function={reward_class_name} "
+        f"wandb_metric=rewards/{reward_class_name}/mean "
+        f"logging_steps={phase.logging_steps} "
+        f"num_generations={phase.num_generations} "
+        f"per_device_train_batch_size={phase.per_device_train_batch_size}",
+        flush=True,
+    )
     strict_asserts = str(os.getenv("OSINT_TRAIN_STRICT_ASSERTS", "")).strip().lower() in {"1", "true", "yes", "on"}
     trainer = GRPOTrainer(**trainer_kwargs)
     tracked_params = [
@@ -873,11 +888,35 @@ def _train_grpo_phase(

     reward_debug = getattr(reward_function, "_debug_last_batch", None)
     if isinstance(reward_debug, dict):
-        print(
+        print(
+            f"[reward_debug][last_batch] {phase_label} reward_function={reward_class_name} "
+            f"{json.dumps(reward_debug, sort_keys=True)}",
+            flush=True,
+        )
+
+    if reward_values:
+        print(
+            f"[self_play][reward_history] {phase_label} reward_function={reward_class_name} "
+            f"steps_logged={len(reward_values)} "
+            f"reward_first={reward_values[0]:.6f} "
+            f"reward_last={reward_values[-1]:.6f} "
+            f"reward_mean={(sum(reward_values) / len(reward_values)):.6f} "
+            f"reward_min={min(reward_values):.6f} "
+            f"reward_max={max(reward_values):.6f} "
+            f"wandb_metric=rewards/{reward_class_name}/mean",
+            flush=True,
+        )
+    else:
+        print(
+            f"[self_play][reward_history] {phase_label} reward_function={reward_class_name} "
+            "no_reward_logs_in_state (TRL never wrote a 'reward' field; check logging_steps / num_generations)",
+            flush=True,
+        )

     print(
         "[self_play] Finished phase: "
-        f"{phase_label} global_step={global_step} training_loss={training_loss} output={final_dir}"
+        f"{phase_label} global_step={global_step} training_loss={training_loss} output={final_dir}",
+        flush=True,
     )
     return result

@@ -956,30 +995,52 @@ def _sample_generated_tasks_with_model(
     count: int,
     max_support_edges: int,
     max_new_tokens: int,
+    batch_size: int = 4,
 ) -> list[TaskInstance]:
     from transformers import AutoModelForCausalLM, AutoTokenizer
     import torch

-    if count <= 0:
+    if count <= 0 or not prompts:
         return []

+    print(
+        f"[self_play][sample_generator_legacy] start model={model_name_or_path} "
+        f"prompts={len(prompts)} target_valid={count} max_new_tokens={max_new_tokens}",
+        flush=True,
+    )
+    load_start = time.monotonic()
     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
     if tokenizer.pad_token is None and tokenizer.eos_token is not None:
         tokenizer.pad_token = tokenizer.eos_token
+    if getattr(tokenizer, "padding_side", "right") != "left":
+        tokenizer.padding_side = "left"
     model_kwargs: dict[str, Any] = {}
     if torch.cuda.is_available():
         model_kwargs["device_map"] = "auto"
         model_kwargs["torch_dtype"] = torch.bfloat16
     model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **model_kwargs)
     model.eval()
-
     device = next(model.parameters()).device
-
+    print(
+        f"[self_play][sample_generator_legacy] model_loaded device={device} "
+        f"load_elapsed={time.monotonic() - load_start:.1f}s",
+        flush=True,
+    )

-
+    generated: list[TaskInstance] = []
+    overall_start = time.monotonic()
+    effective_batch = max(1, int(batch_size or 1))
+    processed = 0
+    for batch_start in range(0, len(prompts), effective_batch):
         if len(generated) >= count:
             break
-
+        batch_prompts = prompts[batch_start : batch_start + effective_batch]
+        encoded = tokenizer(
+            batch_prompts,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+        )
         encoded = {k: v.to(device) for k, v in encoded.items()}

         with torch.no_grad():
@@ -993,35 +1054,52 @@
                 pad_token_id=tokenizer.eos_token_id,
             )

-
-
-
-
-
+        input_len = encoded["input_ids"].shape[1]
+        for row_offset in range(len(batch_prompts)):
+            completion_ids = output[row_offset][input_len:]
+            completion = tokenizer.decode(completion_ids, skip_special_tokens=True)
+            candidate = parse_generated_task_completion(completion, max_support_edges=max_support_edges)
+            processed += 1
+            if not candidate.is_valid:
+                continue

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            metadata = {
+                "generated_by": "generator_model",
+                "round": round_index,
+                "difficulty": "hard",
+                "scenario": "adversarial_trace",
+                "grader": {
+                    "type": "difficulty_exact_match",
+                    "answer_type": "node_id",
+                    "case_sensitive": True,
+                    "reward_profile": "hard",
+                },
+            }
+            generated.append(
+                TaskInstance(
+                    task_id=f"adv_r{round_index}_{len(generated)}",
+                    task_type=candidate.task_type,
+                    question=candidate.question,
+                    answer=candidate.answer,
+                    supporting_edges=list(candidate.supporting_edges),
+                    metadata=metadata,
+                )
            )
+            if len(generated) >= count:
+                break
+
+        print(
+            f"[self_play][sample_generator_legacy] processed={processed}/{len(prompts)} "
+            f"valid={len(generated)}/{count} "
+            f"elapsed={time.monotonic() - overall_start:.1f}s",
+            flush=True,
        )

+    print(
+        f"[self_play][sample_generator_legacy] finished generated={len(generated)}/{count} "
+        f"total_elapsed={time.monotonic() - overall_start:.1f}s",
+        flush=True,
+    )
     return generated


@@ -1083,13 +1161,25 @@ def _generate_answerer_completion_texts_with_model(
     model_name_or_path: str,
     prompts: list[str],
     max_new_tokens: int,
+    batch_size: int = 4,
 ) -> list[str]:
     from transformers import AutoModelForCausalLM, AutoTokenizer
     import torch

+    if not prompts:
+        return []
+
+    print(
+        f"[self_play][sample_answerer] start model={model_name_or_path} "
+        f"prompts={len(prompts)} max_new_tokens={max_new_tokens}",
+        flush=True,
+    )
+    load_start = time.monotonic()
     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
     if tokenizer.pad_token is None and tokenizer.eos_token is not None:
         tokenizer.pad_token = tokenizer.eos_token
+    if getattr(tokenizer, "padding_side", "right") != "left":
+        tokenizer.padding_side = "left"

     model_kwargs: dict[str, Any] = {}
     if torch.cuda.is_available():
@@ -1098,10 +1188,24 @@
     model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **model_kwargs)
     model.eval()
     device = next(model.parameters()).device
+    print(
+        f"[self_play][sample_answerer] model_loaded device={device} "
+        f"load_elapsed={time.monotonic() - load_start:.1f}s",
+        flush=True,
+    )

     completions: list[str] = []
-
-
+    overall_start = time.monotonic()
+    effective_batch = max(1, int(batch_size or 1))
+    processed = 0
+    for batch_start in range(0, len(prompts), effective_batch):
+        batch_prompts = prompts[batch_start : batch_start + effective_batch]
+        encoded = tokenizer(
+            batch_prompts,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+        )
         encoded = {key: value.to(device) for key, value in encoded.items()}
         with torch.no_grad():
             output = model.generate(
@@ -1110,8 +1214,22 @@
                 do_sample=False,
                 pad_token_id=tokenizer.eos_token_id,
             )
-
-
+        input_len = encoded["input_ids"].shape[1]
+        for row_offset in range(len(batch_prompts)):
+            completion_ids = output[row_offset][input_len:]
+            completions.append(tokenizer.decode(completion_ids, skip_special_tokens=True))
+        processed += len(batch_prompts)
+        print(
+            f"[self_play][sample_answerer] processed={processed}/{len(prompts)} "
+            f"elapsed={time.monotonic() - overall_start:.1f}s",
+            flush=True,
+        )
+
+    print(
+        f"[self_play][sample_answerer] finished completions={len(completions)} "
+        f"total_elapsed={time.monotonic() - overall_start:.1f}s",
+        flush=True,
+    )
     return completions


@@ -1404,40 +1522,74 @@ def _sample_swarm_v2_completion_texts_with_model(
     from transformers import AutoModelForCausalLM, AutoTokenizer
     import torch

-    if count <= 0:
+    if count <= 0 or not prompts:
         return []

+    print(
+        f"[self_play][sample_generator] start model={model_name_or_path} "
+        f"prompts={len(prompts)} target_valid={count} "
+        f"max_new_tokens={cfg.generated_task_max_new_tokens}",
+        flush=True,
+    )
+    load_start = time.monotonic()
     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
     if tokenizer.pad_token is None and tokenizer.eos_token is not None:
         tokenizer.pad_token = tokenizer.eos_token
+    if getattr(tokenizer, "padding_side", "right") != "left":
+        tokenizer.padding_side = "left"
     model_kwargs: dict[str, Any] = {}
     if torch.cuda.is_available():
         model_kwargs["device_map"] = "auto"
         model_kwargs["torch_dtype"] = torch.bfloat16
     model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **model_kwargs)
     model.eval()
-
     device = next(model.parameters()).device
-
+    print(
+        f"[self_play][sample_generator] model_loaded device={device} "
+        f"load_elapsed={time.monotonic() - load_start:.1f}s",
+        flush=True,
+    )
+
     validator = SwarmV2ReplayValidator(
         graph=env.graph,
         validation=cfg.swarm_v2.validation,
         shared_context=cfg.swarm_v2.shared_context,
         seen_questions=seen_questions,
     )
-
-
+    completions: list[str] = []
+    valid_count = 0
+    batch_size = max(1, int(getattr(cfg.generator_phase, "generation_batch_size", 4) or 4))
+    max_new_tokens = max(64, int(cfg.generated_task_max_new_tokens))
+    decode_schedule = [(0.7, 0.9), (0.5, 0.85), (0.3, 0.8)]
+    overall_start = time.monotonic()
+
+    pending_indices = list(range(len(prompts)))
+    best_completions: dict[int, str] = {}
+    best_scores: dict[int, int] = {}
+    valid_marks: dict[int, bool] = {}
+
+    for attempt_idx, (temperature, top_p) in enumerate(decode_schedule):
+        if not pending_indices:
             break
-
-
+        attempt_start = time.monotonic()
+        next_pending: list[int] = []
+        processed = 0
+        for batch_start in range(0, len(pending_indices), batch_size):
+            batch_indices = pending_indices[batch_start : batch_start + batch_size]
+            batch_prompts = [prompts[i] for i in batch_indices]
+            encoded = tokenizer(
+                batch_prompts,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=int(getattr(cfg.generator_phase, "max_prompt_length", 1024) or 1024),
+            )
+            encoded = {key: value.to(device) for key, value in encoded.items()}

-        best_completion = ""
-        best_score = -999
-        for attempt_idx, (temperature, top_p) in enumerate([(0.7, 0.9), (0.5, 0.85), (0.3, 0.8)]):
             with torch.no_grad():
                 output = model.generate(
                     **encoded,
-                    max_new_tokens=
+                    max_new_tokens=max_new_tokens,
                     do_sample=True,
                     top_p=top_p,
                     temperature=temperature,
@@ -1445,22 +1597,63 @@
                     pad_token_id=tokenizer.eos_token_id,
                 )

-
-
-
-                completion,
-
+            input_len = encoded["input_ids"].shape[1]
+            for row_offset, prompt_idx in enumerate(batch_indices):
+                completion_ids = output[row_offset][input_len:]
+                completion = tokenizer.decode(completion_ids, skip_special_tokens=True)
+                candidate = parse_generated_task_completion(
+                    completion,
+                    max_support_edges=cfg.swarm_v2.validation.max_support_edges,
+                )
+                validation = validator.validate(candidate)
+                score = (
+                    int(bool(candidate.question))
+                    + int(bool(candidate.answer))
+                    + len(candidate.supporting_edges)
+                )
+                if validation.is_valid:
+                    if not valid_marks.get(prompt_idx):
+                        valid_count += 1
+                    valid_marks[prompt_idx] = True
+                    best_completions[prompt_idx] = completion
+                    best_scores[prompt_idx] = score
+                else:
+                    if score > best_scores.get(prompt_idx, -999):
+                        best_scores[prompt_idx] = score
+                        best_completions[prompt_idx] = completion
+                    if not valid_marks.get(prompt_idx):
+                        next_pending.append(prompt_idx)
+
+            processed += len(batch_indices)
+            print(
+                f"[self_play][sample_generator] attempt={attempt_idx + 1}/{len(decode_schedule)} "
+                f"processed={processed}/{len(pending_indices)} "
+                f"valid_so_far={valid_count}/{len(prompts)} "
+                f"target_valid={count} "
+                f"elapsed={time.monotonic() - overall_start:.1f}s",
+                flush=True,
            )
-
-            score = int(bool(candidate.question)) + int(bool(candidate.answer)) + len(candidate.supporting_edges)
-            if validation.is_valid:
-                print(f"[self_play][generation_retry] valid_completion attempt={attempt_idx + 1}")
-                best_completion = completion
+            if valid_count >= count:
                 break
-
-
-
-
+        print(
+            f"[self_play][sample_generator] attempt={attempt_idx + 1} done "
+            f"valid={valid_count}/{len(prompts)} "
+            f"attempt_elapsed={time.monotonic() - attempt_start:.1f}s",
+            flush=True,
+        )
+        if valid_count >= count:
+            break
+        pending_indices = next_pending
+
+    for prompt_idx in range(len(prompts)):
+        completions.append(best_completions.get(prompt_idx, ""))
+
+    print(
+        f"[self_play][sample_generator] finished completions={len(completions)} "
+        f"valid={valid_count}/{len(prompts)} target_valid={count} "
+        f"total_elapsed={time.monotonic() - overall_start:.1f}s",
+        flush=True,
+    )
     return completions

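The swarm_v2 sampler above amounts to a reusable generate-and-validate retry pattern: keep a pending set of prompt indices, re-sample only those at each step of a progressively more conservative (temperature, top_p) schedule, and remember the best-scoring invalid candidate as a fallback so every prompt still returns something. A stripped-down sketch with stubbed sampling and validation (sample, is_valid, and score are hypothetical stand-ins, not functions from this repo):

# Stripped-down sketch of the pending-set retry loop used in the diff.
# sample, is_valid, and score are hypothetical stubs for illustration.
import random

def sample(prompt: str, temperature: float, top_p: float) -> str:
    return f"{prompt}::t={temperature}::r={random.random():.2f}"

def is_valid(completion: str) -> bool:
    return completion.endswith(("0", "2", "4", "6", "8"))  # arbitrary check

def score(completion: str) -> int:
    return len(completion)  # cheap proxy for "closeness to valid"

prompts = [f"task-{i}" for i in range(6)]
decode_schedule = [(0.7, 0.9), (0.5, 0.85), (0.3, 0.8)]  # hotter first, then cooler

pending = list(range(len(prompts)))
best: dict[int, str] = {}
best_score: dict[int, int] = {}
valid: dict[int, bool] = {}

for temperature, top_p in decode_schedule:
    if not pending:
        break
    still_pending: list[int] = []
    for idx in pending:
        completion = sample(prompts[idx], temperature, top_p)
        if is_valid(completion):
            valid[idx] = True
            best[idx] = completion
        else:
            # Keep the highest-scoring invalid attempt as a fallback.
            if score(completion) > best_score.get(idx, -999):
                best_score[idx] = score(completion)
                best[idx] = completion
            still_pending.append(idx)
    pending = still_pending

# Every prompt gets its best attempt; unrecovered ones fall back to "".
completions = [best.get(i, "") for i in range(len(prompts))]
print(completions)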