siddeshwar-kagatikar committed
Commit 55c5f82 · 1 Parent(s): b0b586a

Add per-batch generation liveness logging

The generator and answerer samplers only printed progress *after* each
batch completed. With prompts=24, generation_batch_size=16, and
max_new_tokens=640 on an L40S, the very first batch can sit silently for
60-90 seconds while CUDA kernel autotuning and KV cache allocation run,
which made it look as though training had hung right after the
"model_loaded" line.

Print a "generating..." line *before* each model.generate() call and a
"generate_elapsed=<seconds>s" line right after (folded into the existing
processed= line in the answerer sampler), in all three samplers
(_sample_swarm_v2_completion_texts_with_model,
_sample_generated_tasks_with_model,
_generate_answerer_completion_texts_with_model). Every batch now has a
visible start and end with timings, so liveness is obvious even during
the slow first-batch warmup.

Made-with: Cursor
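
For reference, the pattern the commit inlines at each call site, distilled
into a standalone sketch. The batch_liveness helper and the literal numbers
below are illustrative only; the actual commit prints directly around each
model.generate() call rather than using a context manager:

    import time
    from contextlib import contextmanager

    @contextmanager
    def batch_liveness(tag: str, batch_start: int, total: int, batch_size: int):
        # Start line goes out *before* the long-running call, so the process
        # is visibly alive even while the first batch stalls on CUDA kernel
        # autotuning and KV cache allocation. flush=True matters when stdout
        # is piped or captured by a training launcher.
        print(
            f"[self_play][{tag}] batch_start={batch_start}/{total} "
            f"batch_size={batch_size} generating...",
            flush=True,
        )
        t0 = time.monotonic()
        yield
        # End line reports per-batch wall time right after the call returns.
        print(
            f"[self_play][{tag}] batch_start={batch_start} "
            f"generate_elapsed={time.monotonic() - t0:.1f}s",
            flush=True,
        )

    # Usage: wrap the slow call; time.sleep() stands in for model.generate().
    with batch_liveness("sample_legacy", batch_start=0, total=24, batch_size=16):
        time.sleep(1.0)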

Files changed (1)
  1. src/osint_env/training/self_play.py +35 -0
src/osint_env/training/self_play.py CHANGED
@@ -1309,6 +1309,13 @@ def _sample_generated_tasks_with_model(
         )
         encoded = {k: v.to(device) for k, v in encoded.items()}

+        print(
+            f"[self_play][sample_legacy] batch_start={batch_start}/{len(prompts)} "
+            f"batch_size={len(batch_prompts)} max_new_tokens={max(64, int(max_new_tokens))} "
+            f"input_len={encoded['input_ids'].shape[1]} generating...",
+            flush=True,
+        )
+        batch_t0 = time.monotonic()
         with torch.no_grad():
             output = model.generate(
                 **encoded,
@@ -1319,6 +1326,11 @@ def _sample_generated_tasks_with_model(
                 num_return_sequences=1,
                 pad_token_id=tokenizer.eos_token_id,
             )
+        print(
+            f"[self_play][sample_legacy] batch_start={batch_start} "
+            f"generate_elapsed={time.monotonic() - batch_t0:.1f}s",
+            flush=True,
+        )

         input_len = encoded["input_ids"].shape[1]
         for row_offset in range(len(batch_prompts)):
@@ -1473,6 +1485,13 @@ def _generate_answerer_completion_texts_with_model(
             truncation=True,
         )
         encoded = {key: value.to(device) for key, value in encoded.items()}
+        print(
+            f"[self_play][sample_answerer] batch_start={batch_start}/{len(prompts)} "
+            f"batch_size={len(batch_prompts)} max_new_tokens={max(16, int(max_new_tokens))} "
+            f"input_len={encoded['input_ids'].shape[1]} generating...",
+            flush=True,
+        )
+        batch_t0 = time.monotonic()
         with torch.no_grad():
             output = model.generate(
                 **encoded,
@@ -1487,6 +1506,7 @@ def _generate_answerer_completion_texts_with_model(
         processed += len(batch_prompts)
         print(
             f"[self_play][sample_answerer] processed={processed}/{len(prompts)} "
+            f"generate_elapsed={time.monotonic() - batch_t0:.1f}s "
             f"elapsed={time.monotonic() - overall_start:.1f}s",
             flush=True,
         )
@@ -1852,6 +1872,16 @@ def _sample_swarm_v2_completion_texts_with_model(
         )
         encoded = {key: value.to(device) for key, value in encoded.items()}

+        print(
+            f"[self_play][sample_generator] attempt={attempt_idx + 1}/{len(decode_schedule)} "
+            f"batch_start={batch_start}/{len(pending_indices)} "
+            f"batch_size={len(batch_indices)} "
+            f"max_new_tokens={max_new_tokens} "
+            f"input_len={encoded['input_ids'].shape[1]} "
+            f"generating...",
+            flush=True,
+        )
+        batch_t0 = time.monotonic()
         with torch.no_grad():
             output = model.generate(
                 **encoded,
@@ -1862,6 +1892,11 @@ def _sample_swarm_v2_completion_texts_with_model(
                 num_return_sequences=1,
                 pad_token_id=tokenizer.eos_token_id,
             )
+        print(
+            f"[self_play][sample_generator] attempt={attempt_idx + 1}/{len(decode_schedule)} "
+            f"batch_start={batch_start} generate_elapsed={time.monotonic() - batch_t0:.1f}s",
+            flush=True,
+        )

         input_len = encoded["input_ids"].shape[1]
         for row_offset, prompt_idx in enumerate(batch_indices):
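
With the run described in the commit message (prompts=24,
generation_batch_size=16, max_new_tokens=640), the legacy sampler's first
batch would now bracket the silent warmup with lines shaped like these;
the input_len and elapsed values here are illustrative, not measured:

    [self_play][sample_legacy] batch_start=0/24 batch_size=16 max_new_tokens=640 input_len=812 generating...
    [self_play][sample_legacy] batch_start=0 generate_elapsed=84.2s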