siddeshwar-kagatikar committed
Commit 4aca4f5 · Parent: 3e893cd

feat(training): improve self-play progress visibility and reward diagnostics

Batch model generation and sampling, and add explicit reward instrumentation, so that HF Space runs no longer appear stalled and the answerer reward signal stays continuously observable in logs and W&B.

Made-with: Cursor

src/osint_env/training/rewards.py CHANGED
@@ -1094,6 +1094,22 @@ class AnswererRewardFunction:
         self.reward_model = build_reward_model(graph)
         self.pipeline_mode = str(pipeline_mode).strip().lower() or "legacy"
         self.parl_max_parallel_hint = max(0, int(parl_max_parallel_hint or 0))
+        # Mirror GeneratorRewardFunction observability: TRL's GRPOTrainer
+        # already logs `rewards/AnswererRewardFunction/{mean,std}` to W&B
+        # at every `logging_steps`, but we ALSO publish a per-batch debug
+        # snapshot so the [reward_debug][last_batch] line appears in stdout
+        # for the answerer phase, exactly like it does for the generator.
+        self._debug_batches_seen = 0
+        self._debug_reward_window: list[float] = []
+        self._debug_last_batch: dict[str, Any] = {}
+
+    @staticmethod
+    def _std(values: list[float]) -> float:
+        if len(values) <= 1:
+            return 0.0
+        mean = sum(values) / len(values)
+        variance = sum((value - mean) ** 2 for value in values) / len(values)
+        return variance ** 0.5
 
     @staticmethod
     def _parse_support_edges(value: Any) -> list[Edge]:
@@ -1172,6 +1188,8 @@ class AnswererRewardFunction:
         **kwargs: Any,
     ) -> list[float]:
         rewards: list[float] = []
+        success_count = 0
+        graph_f1_sum = 0.0
 
         for idx, completion in enumerate(completions):
             completion_text = decode_completion_text(completion)
@@ -1204,6 +1222,45 @@ class AnswererRewardFunction:
                 model=self.reward_model,
                 difficulty=difficulty_level,
             )
-            rewards.append(self._extract_orchestrator_reward(completion_text, breakdown.total))
+            final_reward = self._extract_orchestrator_reward(completion_text, breakdown.total)
+            rewards.append(final_reward)
+
+            if predicted_answer and target_answer and normalize_answer(predicted_answer) == target_answer:
+                success_count += 1
+            graph_f1_sum += float(getattr(breakdown, "graph_f1", 0.0) or 0.0)
+
+        # Mirror GeneratorRewardFunction debug surface so the answerer reward
+        # is visible to the same downstream tooling (printed by
+        # `_train_grpo_phase` and forwarded to W&B by TRL).
+        self._debug_batches_seen += 1
+        self._debug_reward_window.extend(rewards)
+        self._debug_reward_window = self._debug_reward_window[-512:]
+        batch_size = max(1, len(rewards))
+        batch_mean = float(sum(rewards) / batch_size)
+        batch_std = float(self._std(rewards))
+        advantages = [float(value - batch_mean) for value in rewards]
+        self._debug_last_batch = {
+            "batch_rewards": list(rewards),
+            "batch_reward_mean": batch_mean,
+            "batch_reward_std": batch_std,
+            "advantage_proxy_min": min(advantages) if advantages else 0.0,
+            "advantage_proxy_max": max(advantages) if advantages else 0.0,
+            "advantage_proxy_std": float(self._std(advantages)),
+            "exact_match_count": int(success_count),
+            "exact_match_ratio": float(success_count / batch_size),
+            "avg_graph_f1": float(graph_f1_sum / batch_size),
+        }
+        if self._debug_batches_seen % 10 == 0:
+            window_std = self._std(self._debug_reward_window)
+            print(
+                "[reward_debug][answerer] "
+                f"batches={self._debug_batches_seen} "
+                f"window_reward_std={window_std:.6f} "
+                f"last_batch_mean={batch_mean:.6f} "
+                f"last_batch_std={batch_std:.6f} "
+                f"exact_match_ratio={self._debug_last_batch['exact_match_ratio']:.3f} "
+                f"avg_graph_f1={self._debug_last_batch['avg_graph_f1']:.3f}",
+                flush=True,
+            )
 
         return rewards
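
The snapshot above follows a small, reusable pattern: a population standard deviation over the current reward batch, a bounded rolling window of recent rewards, and a throttled stdout line. A minimal standalone sketch of that pattern, stdlib only (the class and method names here are illustrative, not the repository's API):

class RewardDebugMirror:
    """Toy version of the per-batch reward snapshot added in this commit."""

    def __init__(self) -> None:
        self.batches_seen = 0
        self.window: list[float] = []

    @staticmethod
    def _std(values: list[float]) -> float:
        # Population standard deviation; 0.0 for fewer than two samples.
        if len(values) <= 1:
            return 0.0
        mean = sum(values) / len(values)
        return (sum((v - mean) ** 2 for v in values) / len(values)) ** 0.5

    def observe(self, rewards: list[float]) -> None:
        self.batches_seen += 1
        self.window = (self.window + list(rewards))[-512:]  # bounded history
        if self.batches_seen % 10 == 0:  # throttle stdout
            print(
                f"[reward_debug] batches={self.batches_seen} "
                f"window_std={self._std(self.window):.6f}",
                flush=True,
            )

mirror = RewardDebugMirror()
mirror.observe([0.2, 0.5, 0.8])  # one GRPO reward batch

A near-zero window standard deviation over many batches is the useful signal here: it means every completion in the group earns the same reward, so GRPO's group-relative advantages collapse to zero and the policy stops learning.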
src/osint_env/training/self_play.py CHANGED
@@ -4,6 +4,7 @@ import inspect
 import json
 import os
 import re
+import time
 from dataclasses import dataclass
 from pathlib import Path
 import random
@@ -753,7 +754,21 @@ def _train_grpo_phase(
         trainer_kwargs["peft_config"] = _build_lora_config(lora)
 
     phase_label = str(run_name).strip() or str(output_dir.name)
-    print(f"[self_play] Starting phase: {phase_label} rows={len(rows)} max_steps={phase.max_steps}")
+    reward_class_name = type(reward_function).__name__
+    print(
+        f"[self_play] Starting phase: {phase_label} rows={len(rows)} "
+        f"max_steps={phase.max_steps}",
+        flush=True,
+    )
+    print(
+        f"[self_play][reward_setup] phase={phase_label} "
+        f"reward_function={reward_class_name} "
+        f"wandb_metric=rewards/{reward_class_name}/mean "
+        f"logging_steps={phase.logging_steps} "
+        f"num_generations={phase.num_generations} "
+        f"per_device_train_batch_size={phase.per_device_train_batch_size}",
+        flush=True,
+    )
     strict_asserts = str(os.getenv("OSINT_TRAIN_STRICT_ASSERTS", "")).strip().lower() in {"1", "true", "yes", "on"}
     trainer = GRPOTrainer(**trainer_kwargs)
     tracked_params = [
@@ -873,11 +888,35 @@
 
     reward_debug = getattr(reward_function, "_debug_last_batch", None)
     if isinstance(reward_debug, dict):
-        print(f"[reward_debug][last_batch] {phase_label} {json.dumps(reward_debug, sort_keys=True)}")
+        print(
+            f"[reward_debug][last_batch] {phase_label} reward_function={reward_class_name} "
+            f"{json.dumps(reward_debug, sort_keys=True)}",
+            flush=True,
+        )
+
+    if reward_values:
+        print(
+            f"[self_play][reward_history] {phase_label} reward_function={reward_class_name} "
+            f"steps_logged={len(reward_values)} "
+            f"reward_first={reward_values[0]:.6f} "
+            f"reward_last={reward_values[-1]:.6f} "
+            f"reward_mean={(sum(reward_values) / len(reward_values)):.6f} "
+            f"reward_min={min(reward_values):.6f} "
+            f"reward_max={max(reward_values):.6f} "
+            f"wandb_metric=rewards/{reward_class_name}/mean",
+            flush=True,
+        )
+    else:
+        print(
+            f"[self_play][reward_history] {phase_label} reward_function={reward_class_name} "
+            "no_reward_logs_in_state (TRL never wrote a 'reward' field; check logging_steps / num_generations)",
+            flush=True,
+        )
 
     print(
         "[self_play] Finished phase: "
-        f"{phase_label} global_step={global_step} training_loss={training_loss} output={final_dir}"
+        f"{phase_label} global_step={global_step} training_loss={training_loss} output={final_dir}",
+        flush=True,
     )
     return result
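
The `reward_values` series summarized by the new `[self_play][reward_history]` line is collected elsewhere in this function and is not shown in the hunk. One plausible way to gather it, sketched here with a hypothetical `collect_reward_values` helper, is to scan the trainer's `state.log_history` for the `reward` field that TRL writes at each logging step (the else-branch message above assumes exactly this key):

def collect_reward_values(trainer) -> list[float]:
    # `state.log_history` is a list of metric dicts; only entries emitted
    # at a logging step carry a "reward" field.
    values: list[float] = []
    for entry in trainer.state.log_history:
        reward = entry.get("reward")
        if isinstance(reward, (int, float)):
            values.append(float(reward))
    return values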
 
@@ -956,30 +995,52 @@ def _sample_generated_tasks_with_model(
     count: int,
     max_support_edges: int,
     max_new_tokens: int,
+    batch_size: int = 4,
 ) -> list[TaskInstance]:
     from transformers import AutoModelForCausalLM, AutoTokenizer
     import torch
 
-    if count <= 0:
+    if count <= 0 or not prompts:
         return []
 
+    print(
+        f"[self_play][sample_generator_legacy] start model={model_name_or_path} "
+        f"prompts={len(prompts)} target_valid={count} max_new_tokens={max_new_tokens}",
+        flush=True,
+    )
+    load_start = time.monotonic()
     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
     if tokenizer.pad_token is None and tokenizer.eos_token is not None:
         tokenizer.pad_token = tokenizer.eos_token
+    if getattr(tokenizer, "padding_side", "right") != "left":
+        tokenizer.padding_side = "left"
     model_kwargs: dict[str, Any] = {}
     if torch.cuda.is_available():
         model_kwargs["device_map"] = "auto"
         model_kwargs["torch_dtype"] = torch.bfloat16
     model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **model_kwargs)
     model.eval()
-
     device = next(model.parameters()).device
-    generated: list[TaskInstance] = []
+    print(
+        f"[self_play][sample_generator_legacy] model_loaded device={device} "
+        f"load_elapsed={time.monotonic() - load_start:.1f}s",
+        flush=True,
+    )
 
-    for prompt in prompts:
+    generated: list[TaskInstance] = []
+    overall_start = time.monotonic()
+    effective_batch = max(1, int(batch_size or 1))
+    processed = 0
+    for batch_start in range(0, len(prompts), effective_batch):
         if len(generated) >= count:
             break
-        encoded = tokenizer(prompt, return_tensors="pt")
+        batch_prompts = prompts[batch_start : batch_start + effective_batch]
+        encoded = tokenizer(
+            batch_prompts,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+        )
         encoded = {k: v.to(device) for k, v in encoded.items()}
 
         with torch.no_grad():
@@ -993,35 +1054,52 @@ def _sample_generated_tasks_with_model(
                 pad_token_id=tokenizer.eos_token_id,
             )
 
-        completion_ids = output[0][encoded["input_ids"].shape[1] :]
-        completion = tokenizer.decode(completion_ids, skip_special_tokens=True)
-        candidate = parse_generated_task_completion(completion, max_support_edges=max_support_edges)
-        if not candidate.is_valid:
-            continue
-
-        metadata = {
-            "generated_by": "generator_model",
-            "round": round_index,
-            "difficulty": "hard",
-            "scenario": "adversarial_trace",
-            "grader": {
-                "type": "difficulty_exact_match",
-                "answer_type": "node_id",
-                "case_sensitive": True,
-                "reward_profile": "hard",
-            },
-        }
-        generated.append(
-            TaskInstance(
-                task_id=f"adv_r{round_index}_{len(generated)}",
-                task_type=candidate.task_type,
-                question=candidate.question,
-                answer=candidate.answer,
-                supporting_edges=list(candidate.supporting_edges),
-                metadata=metadata,
-            )
-        )
+        input_len = encoded["input_ids"].shape[1]
+        for row_offset in range(len(batch_prompts)):
+            completion_ids = output[row_offset][input_len:]
+            completion = tokenizer.decode(completion_ids, skip_special_tokens=True)
+            candidate = parse_generated_task_completion(completion, max_support_edges=max_support_edges)
+            processed += 1
+            if not candidate.is_valid:
+                continue
+
+            metadata = {
+                "generated_by": "generator_model",
+                "round": round_index,
+                "difficulty": "hard",
+                "scenario": "adversarial_trace",
+                "grader": {
+                    "type": "difficulty_exact_match",
+                    "answer_type": "node_id",
+                    "case_sensitive": True,
+                    "reward_profile": "hard",
+                },
+            }
+            generated.append(
+                TaskInstance(
+                    task_id=f"adv_r{round_index}_{len(generated)}",
+                    task_type=candidate.task_type,
+                    question=candidate.question,
+                    answer=candidate.answer,
+                    supporting_edges=list(candidate.supporting_edges),
+                    metadata=metadata,
+                )
+            )
+            if len(generated) >= count:
+                break
+
+        print(
+            f"[self_play][sample_generator_legacy] processed={processed}/{len(prompts)} "
+            f"valid={len(generated)}/{count} "
+            f"elapsed={time.monotonic() - overall_start:.1f}s",
+            flush=True,
+        )
 
+    print(
+        f"[self_play][sample_generator_legacy] finished generated={len(generated)}/{count} "
+        f"total_elapsed={time.monotonic() - overall_start:.1f}s",
+        flush=True,
+    )
     return generated
@@ -1083,13 +1161,25 @@ def _generate_answerer_completion_texts_with_model(
     model_name_or_path: str,
     prompts: list[str],
     max_new_tokens: int,
+    batch_size: int = 4,
 ) -> list[str]:
     from transformers import AutoModelForCausalLM, AutoTokenizer
     import torch
 
+    if not prompts:
+        return []
+
+    print(
+        f"[self_play][sample_answerer] start model={model_name_or_path} "
+        f"prompts={len(prompts)} max_new_tokens={max_new_tokens}",
+        flush=True,
+    )
+    load_start = time.monotonic()
     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
     if tokenizer.pad_token is None and tokenizer.eos_token is not None:
         tokenizer.pad_token = tokenizer.eos_token
+    if getattr(tokenizer, "padding_side", "right") != "left":
+        tokenizer.padding_side = "left"
 
     model_kwargs: dict[str, Any] = {}
     if torch.cuda.is_available():
@@ -1098,10 +1188,24 @@ def _generate_answerer_completion_texts_with_model(
     model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **model_kwargs)
     model.eval()
     device = next(model.parameters()).device
+    print(
+        f"[self_play][sample_answerer] model_loaded device={device} "
+        f"load_elapsed={time.monotonic() - load_start:.1f}s",
+        flush=True,
+    )
 
     completions: list[str] = []
-    for prompt in prompts:
-        encoded = tokenizer(prompt, return_tensors="pt")
+    overall_start = time.monotonic()
+    effective_batch = max(1, int(batch_size or 1))
+    processed = 0
+    for batch_start in range(0, len(prompts), effective_batch):
+        batch_prompts = prompts[batch_start : batch_start + effective_batch]
+        encoded = tokenizer(
+            batch_prompts,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+        )
         encoded = {key: value.to(device) for key, value in encoded.items()}
         with torch.no_grad():
             output = model.generate(
@@ -1110,8 +1214,22 @@
                 do_sample=False,
                 pad_token_id=tokenizer.eos_token_id,
             )
-        completion_ids = output[0][encoded["input_ids"].shape[1] :]
-        completions.append(tokenizer.decode(completion_ids, skip_special_tokens=True))
+        input_len = encoded["input_ids"].shape[1]
+        for row_offset in range(len(batch_prompts)):
+            completion_ids = output[row_offset][input_len:]
+            completions.append(tokenizer.decode(completion_ids, skip_special_tokens=True))
+        processed += len(batch_prompts)
+        print(
+            f"[self_play][sample_answerer] processed={processed}/{len(prompts)} "
+            f"elapsed={time.monotonic() - overall_start:.1f}s",
+            flush=True,
+        )
+
+    print(
+        f"[self_play][sample_answerer] finished completions={len(completions)} "
+        f"total_elapsed={time.monotonic() - overall_start:.1f}s",
+        flush=True,
+    )
     return completions
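
All of the batched samplers in this file rely on the same two facts about decoder-only models: padding must be on the left so that no generation happens after pad tokens, and with left padding every row's completion starts at the shared padded prompt length, which is what makes the `output[row][input_len:]` slice valid for all rows at once. A self-contained sketch of that idiom; the `gpt2` checkpoint is only a stand-in for any small causal LM:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token
tokenizer.padding_side = "left"            # pads go BEFORE each prompt
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

prompts = ["The capital of France is", "Water boils at"]
encoded = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
    output = model.generate(
        **encoded,  # input_ids + attention_mask (the mask hides the pads)
        max_new_tokens=8,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )

# With left padding, every completion starts at the same offset:
# the padded prompt length.
input_len = encoded["input_ids"].shape[1]
for row in output:
    print(tokenizer.decode(row[input_len:], skip_special_tokens=True))

With right padding the completion offset would differ per row and the model would condition on trailing pad tokens, which is exactly why these functions force `tokenizer.padding_side = "left"` before batching.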
 
@@ -1404,40 +1522,74 @@ def _sample_swarm_v2_completion_texts_with_model(
     from transformers import AutoModelForCausalLM, AutoTokenizer
     import torch
 
-    if count <= 0:
+    if count <= 0 or not prompts:
         return []
 
+    print(
+        f"[self_play][sample_generator] start model={model_name_or_path} "
+        f"prompts={len(prompts)} target_valid={count} "
+        f"max_new_tokens={cfg.generated_task_max_new_tokens}",
+        flush=True,
+    )
+    load_start = time.monotonic()
     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
     if tokenizer.pad_token is None and tokenizer.eos_token is not None:
         tokenizer.pad_token = tokenizer.eos_token
+    if getattr(tokenizer, "padding_side", "right") != "left":
+        tokenizer.padding_side = "left"
     model_kwargs: dict[str, Any] = {}
     if torch.cuda.is_available():
         model_kwargs["device_map"] = "auto"
         model_kwargs["torch_dtype"] = torch.bfloat16
     model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **model_kwargs)
     model.eval()
-
     device = next(model.parameters()).device
-    completions: list[str] = []
+    print(
+        f"[self_play][sample_generator] model_loaded device={device} "
+        f"load_elapsed={time.monotonic() - load_start:.1f}s",
+        flush=True,
+    )
+
     validator = SwarmV2ReplayValidator(
         graph=env.graph,
         validation=cfg.swarm_v2.validation,
         shared_context=cfg.swarm_v2.shared_context,
         seen_questions=seen_questions,
     )
-    for prompt in prompts:
-        if len(completions) >= count:
-            break
-        encoded = tokenizer(prompt, return_tensors="pt")
-        encoded = {key: value.to(device) for key, value in encoded.items()}
-
-        best_completion = ""
-        best_score = -999
-        for attempt_idx, (temperature, top_p) in enumerate([(0.7, 0.9), (0.5, 0.85), (0.3, 0.8)]):
+    completions: list[str] = []
+    valid_count = 0
+    batch_size = max(1, int(getattr(cfg.generator_phase, "generation_batch_size", 4) or 4))
+    max_new_tokens = max(64, int(cfg.generated_task_max_new_tokens))
+    decode_schedule = [(0.7, 0.9), (0.5, 0.85), (0.3, 0.8)]
+    overall_start = time.monotonic()
+
+    pending_indices = list(range(len(prompts)))
+    best_completions: dict[int, str] = {}
+    best_scores: dict[int, int] = {}
+    valid_marks: dict[int, bool] = {}
+
+    for attempt_idx, (temperature, top_p) in enumerate(decode_schedule):
+        if not pending_indices:
+            break
+        attempt_start = time.monotonic()
+        next_pending: list[int] = []
+        processed = 0
+        for batch_start in range(0, len(pending_indices), batch_size):
+            batch_indices = pending_indices[batch_start : batch_start + batch_size]
+            batch_prompts = [prompts[i] for i in batch_indices]
+            encoded = tokenizer(
+                batch_prompts,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=int(getattr(cfg.generator_phase, "max_prompt_length", 1024) or 1024),
+            )
+            encoded = {key: value.to(device) for key, value in encoded.items()}
 
             with torch.no_grad():
                 output = model.generate(
                     **encoded,
-                    max_new_tokens=max(64, int(cfg.generated_task_max_new_tokens)),
+                    max_new_tokens=max_new_tokens,
                     do_sample=True,
                     top_p=top_p,
                     temperature=temperature,
@@ -1445,22 +1597,63 @@
                     pad_token_id=tokenizer.eos_token_id,
                 )
 
-            completion_ids = output[0][encoded["input_ids"].shape[1] :]
-            completion = tokenizer.decode(completion_ids, skip_special_tokens=True)
-            candidate = parse_generated_task_completion(
-                completion,
-                max_support_edges=cfg.swarm_v2.validation.max_support_edges,
-            )
-            validation = validator.validate(candidate)
-            score = int(bool(candidate.question)) + int(bool(candidate.answer)) + len(candidate.supporting_edges)
-            if validation.is_valid:
-                print(f"[self_play][generation_retry] valid_completion attempt={attempt_idx + 1}")
-                best_completion = completion
-                break
-            if score > best_score:
-                best_score = score
-                best_completion = completion
-        completions.append(best_completion)
+            input_len = encoded["input_ids"].shape[1]
+            for row_offset, prompt_idx in enumerate(batch_indices):
+                completion_ids = output[row_offset][input_len:]
+                completion = tokenizer.decode(completion_ids, skip_special_tokens=True)
+                candidate = parse_generated_task_completion(
+                    completion,
+                    max_support_edges=cfg.swarm_v2.validation.max_support_edges,
+                )
+                validation = validator.validate(candidate)
+                score = (
+                    int(bool(candidate.question))
+                    + int(bool(candidate.answer))
+                    + len(candidate.supporting_edges)
+                )
+                if validation.is_valid:
+                    if not valid_marks.get(prompt_idx):
+                        valid_count += 1
+                    valid_marks[prompt_idx] = True
+                    best_completions[prompt_idx] = completion
+                    best_scores[prompt_idx] = score
+                else:
+                    if score > best_scores.get(prompt_idx, -999):
+                        best_scores[prompt_idx] = score
+                        best_completions[prompt_idx] = completion
+                    if not valid_marks.get(prompt_idx):
+                        next_pending.append(prompt_idx)
+
+            processed += len(batch_indices)
+            print(
+                f"[self_play][sample_generator] attempt={attempt_idx + 1}/{len(decode_schedule)} "
+                f"processed={processed}/{len(pending_indices)} "
+                f"valid_so_far={valid_count}/{len(prompts)} "
+                f"target_valid={count} "
+                f"elapsed={time.monotonic() - overall_start:.1f}s",
+                flush=True,
+            )
+            if valid_count >= count:
+                break
+        print(
+            f"[self_play][sample_generator] attempt={attempt_idx + 1} done "
+            f"valid={valid_count}/{len(prompts)} "
+            f"attempt_elapsed={time.monotonic() - attempt_start:.1f}s",
+            flush=True,
+        )
+        if valid_count >= count:
+            break
+        pending_indices = next_pending
+
+    for prompt_idx in range(len(prompts)):
+        completions.append(best_completions.get(prompt_idx, ""))
+
+    print(
+        f"[self_play][sample_generator] finished completions={len(completions)} "
+        f"valid={valid_count}/{len(prompts)} target_valid={count} "
+        f"total_elapsed={time.monotonic() - overall_start:.1f}s",
+        flush=True,
+    )
    return completions
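
The swarm-v2 rework inverts the old control flow: instead of retrying each prompt in isolation, each (temperature, top_p) rung of the schedule gets one batched pass over only the prompts that are still invalid, and the best-scoring invalid completion is kept as a fallback. A compact sketch of just that schedule, with `generate_batch` and `judge` as hypothetical stand-ins for model decoding and validation:

from typing import Callable

def retry_schedule(
    prompts: list[str],
    generate_batch: Callable[[list[str], float, float], list[str]],
    judge: Callable[[str], tuple[bool, int]],  # returns (is_valid, score)
    schedule: tuple[tuple[float, float], ...] = ((0.7, 0.9), (0.5, 0.85), (0.3, 0.8)),
) -> list[str]:
    pending = list(range(len(prompts)))    # prompt indices still invalid
    best: dict[int, tuple[int, str]] = {}  # idx -> (score, completion)
    for temperature, top_p in schedule:
        if not pending:
            break  # every prompt already has a valid completion
        outputs = generate_batch([prompts[i] for i in pending], temperature, top_p)
        next_pending: list[int] = []
        for idx, completion in zip(pending, outputs):
            is_valid, score = judge(completion)
            if is_valid:
                best[idx] = (score, completion)  # settled; drop from pending
            else:
                if score > best.get(idx, (-999, ""))[0]:
                    best[idx] = (score, completion)  # best fallback so far
                next_pending.append(idx)
        pending = next_pending
    return [best.get(i, (0, ""))[1] for i in range(len(prompts))]

Progressively lowering temperature and top_p trades diversity for parseability: the later rungs only run for prompts whose sampled completions failed validation, so the schedule spends extra compute exactly where the generator is struggling.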
 
 