siddeshwar-kagatikar committed
Commit fe1f842 · Parent: d814291

Sync current main to Hugging Face Space

README.md CHANGED
@@ -174,9 +174,27 @@ osint-env train-self-play --config config/shared_config.json --train-config conf
 
 When you have compute and the train dependencies installed, remove `--dry-run` (or set `"dry_run": false` in the training config) to execute TRL GRPO updates for alternating generator and answerer phases.
 
+For a standalone Linux server or SSH box, there is also a wrapper script that activates a venv, optionally installs train deps, and runs the same command:
+
+```bash
+VENV_PATH="$HOME/arl" \
+INSTALL_TRAIN_DEPS=1 \
+TRAIN_ENV_CONFIG_PATH="config/shared_config.json" \
+TRAIN_SELF_PLAY_CONFIG_PATH="config/self_play_training_hf_a10g_smoke.json" \
+TRAIN_SELF_PLAY_OUTPUT_DIR="artifacts/self_play_server" \
+bash scripts/train_self_play_standalone.sh
+```
+
+Useful overrides for the standalone script:
+
+- `BOOTSTRAP_VENV=1` to create the virtualenv if it does not exist
+- `TRAIN_SELF_PLAY_ROUNDS=2` to override the number of rounds
+- `RUN_SELF_PLAY_DRY_RUN=1` to materialize artifacts without GRPO updates
+- `TRAIN_SETUP_COMMAND='python -m pip install flash-attn --no-build-isolation'` for host-specific extras
+
 The training config also supports `"model_topology": "dual"|"shared"`, `"phase_schedule": "generator_answerer"|"answerer_generator_answerer"`, `"tuning_mode": "full"|"lora"`, and `"canonical_graph_mode": "generate"|"fixed"` so you can switch between two-model vs single-model self-play, full fine-tuning vs LoRA adapters, and whether canonical graph structure is generated each round or kept fixed while training question/answer behavior.
 
-### Hugging Face Space Smoke Run (Qwen 3.5 0.8B + W&B)
+### Hugging Face Job A10G Run (Separate From The Space)
 
 For a short verification run (enough to confirm W&B logging before scaling up), use:
 
@@ -186,27 +204,42 @@ osint-env train-self-play --config config/shared_config.json --train-config conf
 
 This config:
 
-- uses `Qwen/Qwen3.5-0.8B`
+- uses `Qwen/Qwen2.5-0.5B-Instruct`
 - enables W&B reporting (`wandb_enabled: true`)
 - uses `pipeline_mode: "swarm_v2"` with `canonical_graph_mode: "fixed"` to keep canonical graph candidates stable while training question/answer behavior
-- keeps training intentionally short (`rounds=1`, `max_steps=5` per phase)
-- uses LoRA with small batch settings so it can run as a smoke test on an A10G
+- keeps training intentionally short (`rounds=2`, `max_steps=50` per phase)
+- uses full fine-tuning plus fused AdamW, bf16/tf32, larger generation batches, and extra dataloader workers to make better use of an A10G
 
 To enable canonical graph generation during swarm_v2 training, switch `"canonical_graph_mode"` to `"generate"` in the training config.
 
-Space setup checklist:
+If you want the Space to stay on CPU and train separately on paid GPU compute, launch a dedicated Hugging Face Job instead of training inside the Space:
+
+```bash
+osint-env-launch-hf-job \
+  --hf-token "$HF_TOKEN" \
+  --job-image "pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel" \
+  --repo-url "https://github.com/your-org/meta-knowledge-graph.git" \
+  --repo-ref "main" \
+  --flavor "a10g-small" \
+  --env-config "config/shared_config.json" \
+  --train-config "config/self_play_training_hf_a10g_smoke.json" \
+  --output-bucket "your-hf-bucket" \
+  --wait
+```
+
+The launcher talks to the Hugging Face Jobs API through `huggingface_hub`, so the Space can remain on CPU while the training job runs on separate A10G compute.
+
+Optional Space startup wiring still exists if you want it:
 
-1. In Space **Settings -> Hardware**, select **NVIDIA A10G (large)**.
-2. In Space **Settings -> Variables and secrets**, set `WANDB_API_KEY`.
-3. Set `HF_TOKEN` in Space secrets to avoid unauthenticated Hub downloads and stricter rate limits.
-4. Optionally set `WANDB_ENTITY` if your project belongs to a team.
-5. Set `RUN_SELF_PLAY_TRAINING=1` in Space variables to trigger training during container startup.
-6. Optional overrides:
+1. Keep the Space on CPU if it is serving inference/UI only.
+2. Set `RUN_SELF_PLAY_TRAINING=1` only if you intentionally want startup-time training inside the Space container.
+3. Optional overrides:
    - `TRAIN_SELF_PLAY_CONFIG_PATH` (default: `config/self_play_training_hf_a10g_smoke.json`)
    - `TRAIN_ENV_CONFIG_PATH` (default: `config/shared_config.json`)
-   - `RUN_SELF_PLAY_DRY_RUN=1` to test startup wiring without GRPO updates.
-   - `OSINT_TRAIN_STRICT_ASSERTS=1` to fail fast when reward variance, KL, loss, grad norms, or parameter updates stay zero.
-7. Restart the Space and monitor build/runtime logs for the training run.
+   - `TRAIN_SELF_PLAY_OUTPUT_DIR` to override where artifacts land
+   - `RUN_SELF_PLAY_DRY_RUN=1` to test startup wiring without GRPO updates
+   - `RUN_SELF_PLAY_BACKGROUND=1` to keep the API up while startup-time training runs
+   - `OSINT_TRAIN_STRICT_ASSERTS=1` to fail fast when reward variance, KL, loss, grad norms, or parameter updates stay zero
 
 W&B run naming is controlled by `wandb_run_name_prefix` and will emit phase-specific runs like `...-r001-generator` and `...-r001-answerer`.
 
@@ -229,10 +262,10 @@ In `legacy` pipeline mode, the reward is a weighted sum:
 
 Default weights (configurable through `generator_reward_weights` in training config):
 
-- `validity`: `0.35`
-- `hardness`: `0.45`
-- `diversity`: `0.10`
-- `consistency`: `0.10`
+- `validity`: `0.45`
+- `hardness`: `0.20`
+- `diversity`: `0.15`
+- `consistency`: `0.20`
 
 In `swarm_v2` pipeline mode, generation uses strict replay/validation first, then a structured reward:
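Side note on the weight change in the last hunk: the `legacy`-mode reward is literally this weighted sum, so the shift from hardness-heavy to validity-heavy scoring is easy to see numerically. A standalone sketch (component scores assumed to lie in [0, 1]):

```python
# Minimal sketch of the legacy weighted-sum generator reward with the
# new default weights; component names follow generator_reward_weights.
WEIGHTS = {"validity": 0.45, "hardness": 0.20, "diversity": 0.15, "consistency": 0.20}

def generator_reward(components: dict[str, float]) -> float:
    # Each component score is assumed to be a float in [0, 1].
    return sum(weight * components.get(name, 0.0) for name, weight in WEIGHTS.items())

# A valid but easy task now scores respectably: validity dominates hardness.
print(generator_reward({"validity": 1.0, "hardness": 0.1, "diversity": 0.5, "consistency": 0.9}))  # 0.725
```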
config/self_play_training_example.json CHANGED
@@ -15,11 +15,14 @@
   "max_graph_context_edges": 100,
   "max_support_edges": 8,
   "answerer_judge_max_new_tokens": 48,
+  "generated_task_max_new_tokens": 512,
+  "post_training_eval_questions": 24,
+  "post_training_eval_answer_max_new_tokens": 128,
   "generator_reward_weights": {
-    "validity": 0.35,
-    "hardness": 0.45,
-    "diversity": 0.1,
-    "consistency": 0.1
+    "validity": 0.45,
+    "hardness": 0.2,
+    "diversity": 0.15,
+    "consistency": 0.2
   },
   "lora": {
     "r": 16,
@@ -62,12 +65,14 @@
   },
   "generator_phase": {
     "model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct",
-    "learning_rate": 1e-06,
+    "learning_rate": 5e-06,
     "max_steps": 64,
     "per_device_train_batch_size": 2,
     "gradient_accumulation_steps": 4,
     "num_generations": 4,
-    "max_completion_length": 256,
+    "max_completion_length": 384,
+    "max_prompt_length": 1024,
+    "generation_batch_size": 8,
     "temperature": 1.0,
     "top_p": 1.0,
     "beta": 0.01,
@@ -77,18 +82,28 @@
     "scale_rewards": "none",
     "logging_steps": 10,
     "save_steps": 50,
+    "save_total_limit": 2,
+    "optim": "adamw_torch_fused",
+    "bf16": true,
+    "tf32": true,
+    "gradient_checkpointing": false,
+    "dataloader_num_workers": 2,
+    "dataloader_persistent_workers": true,
+    "dataloader_prefetch_factor": 2,
     "output_subdir": "generator",
     "use_vllm": false,
     "vllm_mode": "colocate"
   },
   "answerer_phase": {
     "model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct",
-    "learning_rate": 1e-06,
+    "learning_rate": 3e-06,
     "max_steps": 64,
     "per_device_train_batch_size": 2,
     "gradient_accumulation_steps": 4,
     "num_generations": 4,
     "max_completion_length": 192,
+    "max_prompt_length": 1024,
+    "generation_batch_size": 8,
     "temperature": 1.0,
     "top_p": 1.0,
     "beta": 0.01,
@@ -98,6 +113,14 @@
     "scale_rewards": "none",
     "logging_steps": 10,
     "save_steps": 50,
+    "save_total_limit": 2,
+    "optim": "adamw_torch_fused",
+    "bf16": true,
+    "tf32": true,
+    "gradient_checkpointing": false,
+    "dataloader_num_workers": 2,
+    "dataloader_persistent_workers": true,
+    "dataloader_prefetch_factor": 2,
     "output_subdir": "answerer",
     "use_vllm": false,
     "vllm_mode": "colocate"
config/self_play_training_hf_a10g_smoke.json CHANGED
@@ -10,7 +10,7 @@
   "canonical_graph_mode": "fixed",
   "model_topology": "shared",
   "phase_schedule": "generator_answerer",
-  "tuning_mode": "lora",
+  "tuning_mode": "full",
   "shared_model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct",
   "seed_tasks_per_round": 16,
   "generated_tasks_per_round": 24,
@@ -19,14 +19,19 @@
   "max_graph_context_edges": 24,
   "max_support_edges": 6,
   "answerer_judge_max_new_tokens": 32,
+  "generated_task_max_new_tokens": 640,
+  "post_training_eval_questions": 24,
+  "post_training_eval_answer_max_new_tokens": 128,
   "generator_phase": {
     "model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct",
-    "learning_rate": 1e-06,
+    "learning_rate": 5e-06,
     "max_steps": 50,
     "per_device_train_batch_size": 4,
     "gradient_accumulation_steps": 1,
     "num_generations": 4,
-    "max_completion_length": 768,
+    "max_completion_length": 384,
+    "max_prompt_length": 768,
+    "generation_batch_size": 16,
     "temperature": 0.9,
     "top_p": 0.95,
     "repetition_penalty": 1.1,
@@ -37,18 +42,28 @@
     "scale_rewards": "group",
     "logging_steps": 1,
     "save_steps": 10,
+    "save_total_limit": 2,
+    "optim": "adamw_torch_fused",
+    "bf16": true,
+    "tf32": true,
+    "gradient_checkpointing": false,
+    "dataloader_num_workers": 4,
+    "dataloader_persistent_workers": true,
+    "dataloader_prefetch_factor": 4,
     "output_subdir": "generator_train",
     "use_vllm": false,
     "vllm_mode": "colocate"
   },
   "answerer_phase": {
     "model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct",
-    "learning_rate": 1e-06,
+    "learning_rate": 3e-06,
     "max_steps": 50,
     "per_device_train_batch_size": 4,
     "gradient_accumulation_steps": 1,
     "num_generations": 4,
     "max_completion_length": 256,
+    "max_prompt_length": 768,
+    "generation_batch_size": 16,
     "temperature": 0.7,
     "top_p": 0.95,
     "repetition_penalty": 1.1,
@@ -59,6 +74,14 @@
     "scale_rewards": "group",
     "logging_steps": 1,
     "save_steps": 10,
+    "save_total_limit": 2,
+    "optim": "adamw_torch_fused",
+    "bf16": true,
+    "tf32": true,
+    "gradient_checkpointing": false,
+    "dataloader_num_workers": 4,
+    "dataloader_persistent_workers": true,
+    "dataloader_prefetch_factor": 4,
     "output_subdir": "answerer_train",
     "use_vllm": false,
     "vllm_mode": "colocate"
docs/adversarial_self_play.md CHANGED
@@ -59,6 +59,14 @@ This directly supports the "train solver, freeze, attack, retrain solver" sequen
 
 Weights are configurable in `generator_reward_weights`.
 
+For `swarm_v2`, the reward now prioritizes:
+
+- Valid, replayable task structure first.
+- Hardness against the frozen answerer second.
+- Diversity and compact multi-agent/shared-context usage after validity.
+
+This avoids the degenerate regime where almost every sample is invalid and the whole batch stays negative.
+
 ### Answerer (existing reward integration)
 
 `AnswererRewardFunction` wraps existing environment reward logic:
@@ -88,12 +96,60 @@ In dry run mode, the pipeline still:
 
 But it skips expensive GRPO updates.
 
+## Post-Training Evaluation
+
+After a non-dry-run training job completes, the runner now writes a post-training evaluation artifact that:
+
+- Uses the finetuned generator to create fresh evaluation questions.
+- Evaluates both the finetuned answerer and the original/base answerer on those generated questions.
+- Reports a `delta_vs_original` summary so you can see whether fine-tuning actually improved task success, reward, and graph F1.
+- Saves the summary and episode rows under `post_training_evaluation.json`.
+
+You can control this flow with:
+
+- `generated_task_max_new_tokens`: decoding budget for generator-side task sampling/eval.
+- `post_training_eval_questions`: how many fresh tasks to evaluate after training.
+- `post_training_eval_answer_max_new_tokens`: answerer decoding budget for the final eval pass.
+
+## Checkpoints And Final Models
+
+Self-play outputs are written under `output_dir` (default `artifacts/self_play`) unless overridden.
+
+Per round and phase you will now find:
+
+- `round_XXX/<phase>/checkpoint-*`: intermediate trainer checkpoints saved every `save_steps`.
+- `round_XXX/<phase>/final_model`: final saved model for that phase, with tokenizer files.
+- `self_play_summary.json`: top-level run summary.
+- `post_training_evaluation.json`: generated-question evaluation written after training.
+
 ## Compute Mode
 
 When compute is available:
 
 1. Install train dependencies: `python -m pip install -e ".[train]"`
 2. Disable dry run (`--dry-run` off and/or `"dry_run": false` in config).
-3. Run `osint-env train-self-play`.
+3. Run `osint-env train-self-play`, or launch a dedicated Hugging Face Job with `osint-env-launch-hf-job` if you want the Space to stay on CPU while training runs on separate GPU compute.
 
 Outputs are written under `artifacts/self_play` unless overridden.
+
+## Standalone Server Script
+
+For an SSH server or other standalone machine, you can use `scripts/train_self_play_standalone.sh`.
+
+Example:
+
+```bash
+VENV_PATH="$HOME/arl" \
+INSTALL_TRAIN_DEPS=1 \
+TRAIN_ENV_CONFIG_PATH="config/shared_config.json" \
+TRAIN_SELF_PLAY_CONFIG_PATH="config/self_play_training_hf_a10g_smoke.json" \
+TRAIN_SELF_PLAY_OUTPUT_DIR="artifacts/self_play_server" \
+bash scripts/train_self_play_standalone.sh
+```
+
+Useful environment variables:
+
+- `BOOTSTRAP_VENV=1`: create the virtualenv automatically if it does not exist yet.
+- `TRAIN_SELF_PLAY_ROUNDS=2`: override the round count without editing JSON.
+- `RUN_SELF_PLAY_DRY_RUN=1`: skip GRPO updates and only materialize artifacts.
+- `TRAIN_SETUP_COMMAND='python -m pip install flash-attn --no-build-isolation'`: run any host-specific setup before training.
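A minimal sketch for inspecting the new evaluation artifact, assuming `delta_vs_original` is a flat mapping of metric name to numeric delta (the layout beyond the keys documented above is not pinned down here):

```python
# Print the delta_vs_original summary from a finished run.
import json
from pathlib import Path

payload = json.loads(Path("artifacts/self_play/post_training_evaluation.json").read_text())
for metric, value in payload.get("delta_vs_original", {}).items():
    # Positive deltas mean the finetuned answerer beat the original/base answerer.
    print(f"{metric}: {value:+.4f}")
```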
pyproject.toml CHANGED
@@ -22,6 +22,7 @@ train = [
     "accelerate>=0.33.0",
     "trl>=0.15.0",
     "peft>=0.11.0",
+    "huggingface_hub>=0.34.0",
     "pillow",
     "torchvision",
     "wandb",
@@ -30,6 +31,7 @@ train = [
 [project.scripts]
 osint-env = "osint_env.cli:main"
 server = "osint_env.server_entry:main"
+osint-env-launch-hf-job = "osint_env.training.hf_jobs:main"
 
 [build-system]
 requires = ["setuptools>=68", "wheel"]
scripts/space_start.sh CHANGED
@@ -10,23 +10,46 @@ _is_true() {
 
 ENV_CONFIG_PATH="${TRAIN_ENV_CONFIG_PATH:-config/shared_config.json}"
 TRAIN_CONFIG_PATH="${TRAIN_SELF_PLAY_CONFIG_PATH:-config/self_play_training_hf_a10g_smoke.json}"
+TRAIN_OUTPUT_DIR="${TRAIN_SELF_PLAY_OUTPUT_DIR:-}"
 RUN_FLAG="${RUN_SELF_PLAY_TRAINING:-0}"
 DRY_RUN_FLAG="${RUN_SELF_PLAY_DRY_RUN:-0}"
+BACKGROUND_FLAG="${RUN_SELF_PLAY_BACKGROUND:-1}"
+
+_train_self_play() {
+  if [ -n "${TRAIN_OUTPUT_DIR}" ]; then
+    OUTPUT_ARG="--train-output-dir ${TRAIN_OUTPUT_DIR}"
+  else
+    OUTPUT_ARG=""
+  fi
 
-if _is_true "$RUN_FLAG"; then
-  echo "[space_start] RUN_SELF_PLAY_TRAINING enabled."
-  echo "[space_start] Training start: $(date -u +"%Y-%m-%dT%H:%M:%SZ")"
-  echo "[space_start] Env config: ${ENV_CONFIG_PATH}"
-  echo "[space_start] Train config: ${TRAIN_CONFIG_PATH}"
   if _is_true "$DRY_RUN_FLAG"; then
     echo "[space_start] Running self-play in dry-run mode."
-    osint-env train-self-play --config "${ENV_CONFIG_PATH}" --train-config "${TRAIN_CONFIG_PATH}" --dry-run
+    # shellcheck disable=SC2086
+    osint-env train-self-play --config "${ENV_CONFIG_PATH}" --train-config "${TRAIN_CONFIG_PATH}" ${OUTPUT_ARG} --dry-run
   else
     echo "[space_start] Running self-play training."
-    osint-env train-self-play --config "${ENV_CONFIG_PATH}" --train-config "${TRAIN_CONFIG_PATH}"
+    # shellcheck disable=SC2086
+    osint-env train-self-play --config "${ENV_CONFIG_PATH}" --train-config "${TRAIN_CONFIG_PATH}" ${OUTPUT_ARG}
   fi
+
   echo "[space_start] Self-play command completed."
   echo "[space_start] Training end: $(date -u +"%Y-%m-%dT%H:%M:%SZ")"
+}
+
+if _is_true "$RUN_FLAG"; then
+  echo "[space_start] RUN_SELF_PLAY_TRAINING enabled."
+  echo "[space_start] Training start: $(date -u +"%Y-%m-%dT%H:%M:%SZ")"
+  echo "[space_start] Env config: ${ENV_CONFIG_PATH}"
+  echo "[space_start] Train config: ${TRAIN_CONFIG_PATH}"
+  if [ -n "${TRAIN_OUTPUT_DIR}" ]; then
+    echo "[space_start] Train output dir: ${TRAIN_OUTPUT_DIR}"
+  fi
+  if _is_true "$BACKGROUND_FLAG"; then
+    echo "[space_start] Launching self-play in background so the Space API can stay online."
+    _train_self_play &
+  else
+    _train_self_play
+  fi
 else
   echo "[space_start] RUN_SELF_PLAY_TRAINING disabled. Skipping self-play run."
 fi
src/osint_env/agents/single_agent.py CHANGED
@@ -1,8 +1,11 @@
 from __future__ import annotations
 
+import re
+
 from osint_env.domain.models import Action, ActionType
 from osint_env.env.environment import OSINTEnvironment
 from osint_env.llm.interface import LLMClient, RuleBasedMockLLM
+from osint_env.platforms.tool_schemas import build_lookup_tools
 
 
 class SingleAgentRunner:
@@ -15,14 +18,31 @@ class SingleAgentRunner:
         done = False
         info = {}
         while not done:
-            messages = [{"role": "system", "content": f"question: {obs.task['question']}"}]
-            tools = []
+            messages = [
+                {
+                    "role": "system",
+                    "content": (
+                        f"question: {obs.task['question']}\n"
+                        f"shared_context_available: {bool(obs.task.get('shared_context_available', False))}\n"
+                        "Use lookup tools to gather evidence before answering."
+                    ),
+                }
+            ]
+            tools = build_lookup_tools()
            try:
                llm_resp = self.llm.generate(messages, tools)
                planned_calls = llm_resp.tool_calls[:2]
            except Exception:
                planned_calls = []
 
+            if not planned_calls and bool(obs.task.get("shared_context_available", False)):
+                planned_calls = [
+                    {
+                        "tool_name": "search_shared_context",
+                        "args": {"query": self._shared_context_query(obs.task["question"]), "k": 5},
+                    }
+                ]
+
            for call in planned_calls:
                obs, _, done, info = self.env.step(Action(ActionType.CALL_TOOL, call))
                if done:
@@ -39,3 +59,16 @@ class SingleAgentRunner:
            if token.startswith("alias_") or token.startswith("user_"):
                return token
        return "unknown"
+
+    @staticmethod
+    def _shared_context_query(question: str) -> str:
+        id_match = re.search(r"\b(?:alias|user|post|thr|thread|org|loc|event)_[A-Za-z0-9_]+\b", question)
+        if id_match:
+            return id_match.group(0)
+        path_match = re.search(r"relation path\s+(.+?),\s*which entity", question, flags=re.IGNORECASE)
+        if path_match:
+            first_relation = path_match.group(1).split("->", 1)[0].strip()
+            if first_relation:
+                return first_relation
+        tokens = re.findall(r"[A-Za-z0-9_]+", question)
+        return tokens[0] if tokens else question
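The `_shared_context_query` heuristic added above is easy to sanity-check in isolation. A standalone copy of the same logic with hypothetical question strings:

```python
import re

def shared_context_query(question: str) -> str:
    # Same three-step fallback as SingleAgentRunner._shared_context_query:
    # entity id first, then first relation of a stated path, then first token.
    id_match = re.search(r"\b(?:alias|user|post|thr|thread|org|loc|event)_[A-Za-z0-9_]+\b", question)
    if id_match:
        return id_match.group(0)
    path_match = re.search(r"relation path\s+(.+?),\s*which entity", question, flags=re.IGNORECASE)
    if path_match:
        first_relation = path_match.group(1).split("->", 1)[0].strip()
        if first_relation:
            return first_relation
    tokens = re.findall(r"[A-Za-z0-9_]+", question)
    return tokens[0] if tokens else question

print(shared_context_query("Which org does alias_nightowl belong to?"))  # alias_nightowl
print(shared_context_query("Following the relation path member_of -> located_in, which entity is reached?"))  # member_of
print(shared_context_query("Who posted first?"))  # Who (token fallback)
```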
src/osint_env/agents/swarm_agent.py CHANGED
@@ -7,6 +7,7 @@ from osint_env.domain.models import Action, ActionType
 from osint_env.env.environment import OSINTEnvironment
 from osint_env.env.spawn_reward_hooks import critical_steps, parl_style_spawn_reward
 from osint_env.llm.interface import LLMClient, RuleBasedMockLLM
+from osint_env.platforms.tool_schemas import build_lookup_tools
 
 
 class SwarmAgentRunner:
@@ -135,12 +136,13 @@ class SwarmAgentRunner:
                 "content": (
                     f"question: {obs.task['question']}\n"
                     f"agent_role: {role}_{agent_idx}\n"
+                    f"shared_context_available: {bool(obs.task.get('shared_context_available', False))}\n"
                     "Return concise tool plan."
                 ),
             }
         ]
         try:
-            response = self.llm.generate(messages, tools=[])
+            response = self.llm.generate(messages, tools=build_lookup_tools())
         except Exception:
             response = None
 
@@ -160,6 +162,11 @@ class SwarmAgentRunner:
             return calls
 
         question = str(obs.task.get("question", "")).lower()
+        shared_context_available = bool(obs.task.get("shared_context_available", False))
+        shared_query = self._shared_context_query(str(obs.task.get("question", "")))
+        if shared_context_available and role in {"explorer", "reasoner"}:
+            return [{"tool_name": "search_shared_context", "args": {"query": shared_query, "k": 5}}]
+
         if role == "explorer":
             if "event" in question:
                 return [{"tool_name": "search_threads", "args": {"topic": "security"}}]
@@ -182,6 +189,19 @@ class SwarmAgentRunner:
 
         return [{"tool_name": "search_people", "args": {"org": "Apex"}}]
 
+    @staticmethod
+    def _shared_context_query(question: str) -> str:
+        id_match = re.search(r"\b(?:alias|user|post|thr|thread|org|loc|event)_[A-Za-z0-9_]+\b", question)
+        if id_match:
+            return id_match.group(0)
+        path_match = re.search(r"relation path\s+(.+?),\s*which entity", question, flags=re.IGNORECASE)
+        if path_match:
+            first_relation = path_match.group(1).split("->", 1)[0].strip()
+            if first_relation:
+                return first_relation
+        tokens = re.findall(r"[A-Za-z0-9_]+", question)
+        return tokens[0] if tokens else question
+
     def _edge_plan(self, agent_idx: int) -> dict[str, Any] | None:
         if self.env.state is None or not self.env.state.task.supporting_edges:
             return None
src/osint_env/baselines/openai_runner.py CHANGED
@@ -12,6 +12,7 @@ from osint_env.env.environment import OSINTEnvironment
 from osint_env.env.reward import compute_graph_f1
 from osint_env.eval.leaderboard import append_leaderboard_record, load_leaderboard
 from osint_env.eval.metrics import EvalMetrics
+from osint_env.platforms.tool_schemas import build_action_tools
 from osint_env.viz import export_dashboard
 
 
@@ -50,123 +51,6 @@ class OpenAIBaselineConfig:
     max_steps: int = 8
     seed: int | None = 7
     append_leaderboard: bool = True
-
-
-def _tool_schema(
-    name: str,
-    description: str,
-    properties: dict[str, Any],
-    required: list[str],
-) -> dict[str, Any]:
-    return {
-        "type": "function",
-        "function": {
-            "name": name,
-            "description": description,
-            "parameters": {
-                "type": "object",
-                "properties": properties,
-                "required": required,
-                "additionalProperties": False,
-            },
-        },
-    }
-
-
-def build_action_tools() -> list[dict[str, Any]]:
-    return [
-        _tool_schema(
-            "search_posts",
-            "Search microblog posts by substring over post text, post id, author id, canonical user id, or referenced entity ids/names.",
-            {"query": {"type": "string", "description": "Substring to search for in post text."}},
-            ["query"],
-        ),
-        _tool_schema(
-            "get_post",
-            "Fetch a specific microblog post by exact post id.",
-            {"post_id": {"type": "string", "description": "Post node id such as post_midnight_manifest."}},
-            ["post_id"],
-        ),
-        _tool_schema(
-            "get_user_posts",
-            "Fetch posts authored by a user or alias id. Alias ids are resolved to the canonical user and vice versa.",
-            {"user_id": {"type": "string", "description": "User or alias node id."}},
-            ["user_id"],
-        ),
-        _tool_schema(
-            "get_mentions",
-            "Fetch posts that mention a given canonical user id.",
-            {"user_id": {"type": "string", "description": "Canonical user node id."}},
-            ["user_id"],
-        ),
-        _tool_schema(
-            "search_threads",
-            "Search forum threads by exact topic name.",
-            {"topic": {"type": "string", "description": "Thread topic such as security or ai."}},
-            ["topic"],
-        ),
-        _tool_schema(
-            "get_thread",
-            "Fetch a specific forum thread by id.",
-            {"thread_id": {"type": "string", "description": "Thread node id."}},
-            ["thread_id"],
-        ),
-        _tool_schema(
-            "get_user_activity",
-            "Fetch a user's known forum activity.",
-            {"user_id": {"type": "string", "description": "Canonical user node id."}},
-            ["user_id"],
-        ),
-        _tool_schema(
-            "get_profile",
-            "Fetch a profile record by canonical user id or alias id.",
-            {"user_id": {"type": "string", "description": "Canonical user node id or alias id."}},
-            ["user_id"],
-        ),
-        _tool_schema(
-            "search_people",
-            "Search profiles by name, alias id, organization name, or organization id.",
-            {
-                "name": {"type": "string", "description": "Optional name substring.", "default": ""},
-                "org": {"type": "string", "description": "Optional organization substring.", "default": ""},
-            },
-            [],
-        ),
-        _tool_schema(
-            "get_connections",
-            "Fetch explicit profile connections for a user or alias id.",
-            {"user_id": {"type": "string", "description": "Canonical user node id or alias id."}},
-            ["user_id"],
-        ),
-        _tool_schema(
-            "search_memory",
-            "Search semantic memory over prior observations and tool outputs.",
-            {
-                "query": {"type": "string", "description": "Memory retrieval query."},
-                "k": {"type": "integer", "description": "Top-k matches.", "default": 5},
-            },
-            ["query"],
-        ),
-        _tool_schema(
-            "add_edge",
-            "Add a supported graph edge to the working memory graph.",
-            {
-                "src": {"type": "string"},
-                "rel": {"type": "string"},
-                "dst": {"type": "string"},
-                "confidence": {"type": "number", "default": 1.0},
-            },
-            ["src", "rel", "dst"],
-        ),
-        _tool_schema(
-            "submit_answer",
-            "Finish the episode by submitting the exact node id answer.",
-            {"answer": {"type": "string", "description": "Exact node id answer for the task."}},
-            ["answer"],
-        ),
-    ]
-
-
 def _message_text(message: Any) -> str:
     content = getattr(message, "content", "")
     if isinstance(content, str):
src/osint_env/env/environment.py CHANGED
@@ -137,6 +137,10 @@ class OSINTEnvironment(Env):
                 top_k = int(args.get("k", 5)) if str(args.get("k", "")).strip() else 5
                 results = self.semantic_memory.search(query=query, k=max(1, top_k)) if query else []
                 output = {"results": results, "count": len(results)}
+            elif tool_name == "search_shared_context":
+                query = str(args.get("query", "")).strip()
+                top_k = int(args.get("k", 5)) if str(args.get("k", "")).strip() else 5
+                output = self._search_shared_context(query=query, k=max(1, top_k))
             else:
                 output = self.tools.call(tool_name, args)
         except Exception as exc:
@@ -207,6 +211,67 @@ class OSINTEnvironment(Env):
         matches = sum(1 for token in clues if token in haystack)
         return matches / len(clues)
 
+    def _task_shared_context(self) -> dict[str, Any]:
+        if self.state is None:
+            return {"nodes": [], "edges": []}
+        metadata = dict(self.state.task.metadata or {})
+        canonical_graph = metadata.get("canonical_graph")
+        if isinstance(canonical_graph, dict):
+            return {
+                "nodes": list(canonical_graph.get("nodes", [])),
+                "edges": list(canonical_graph.get("edges", [])),
+            }
+
+        nodes = sorted({edge.src for edge in self.state.task.supporting_edges} | {edge.dst for edge in self.state.task.supporting_edges})
+        edges = [
+            {
+                "src": edge.src,
+                "rel": edge.rel,
+                "dst": edge.dst,
+                "confidence": float(edge.confidence),
+            }
+            for edge in self.state.task.supporting_edges
+        ]
+        return {"nodes": nodes, "edges": edges}
+
+    def _search_shared_context(self, query: str, k: int = 5) -> dict[str, Any]:
+        shared_context = self._task_shared_context()
+        needle = str(query or "").strip().lower()
+        results: list[dict[str, Any]] = []
+
+        for node_id in shared_context.get("nodes", []):
+            token = str(node_id).strip()
+            if not token:
+                continue
+            if needle and needle not in token.lower():
+                continue
+            results.append({"type": "node", "node_id": token})
+
+        for edge in shared_context.get("edges", []):
+            if not isinstance(edge, dict):
+                continue
+            src = str(edge.get("src", "")).strip()
+            rel = str(edge.get("rel", "")).strip()
+            dst = str(edge.get("dst", "")).strip()
+            haystack = " ".join(part for part in (src, rel, dst) if part).lower()
+            if needle and needle not in haystack:
+                continue
+            results.append(
+                {
+                    "type": "edge",
+                    "src": src,
+                    "rel": rel,
+                    "dst": dst,
+                    "confidence": float(edge.get("confidence", 1.0)),
+                }
+            )
+
+        return {
+            "results": results[: max(1, int(k))],
+            "count": len(results),
+            "shared_context_available": bool(shared_context.get("nodes") or shared_context.get("edges")),
+        }
+
     def _accumulate_reward_components(self, values: dict[str, float]) -> None:
         if self.state is None:
             return
@@ -218,11 +283,17 @@ class OSINTEnvironment(Env):
             raise RuntimeError("State is not initialized.")
         metadata = dict(self.state.task.metadata or {})
         grader = metadata.get("grader") if isinstance(metadata.get("grader"), dict) else None
+        shared_context = self._task_shared_context()
         task_payload = {
             "task_id": self.state.task.task_id,
             "task_type": self.state.task.task_type,
             "question": self.state.task.question,
             "difficulty": self.state.difficulty,
+            "shared_context_available": bool(shared_context.get("nodes") or shared_context.get("edges")),
+            "shared_context_size": {
+                "nodes": len(shared_context.get("nodes", [])),
+                "edges": len(shared_context.get("edges", [])),
+            },
             "grader": (
                 dict(grader)
                 if grader is not None
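The matching semantics of `_search_shared_context` are plain case-insensitive substring tests over node ids and edge fields. A standalone sketch over hypothetical shared-context data:

```python
# Hypothetical payload in the shape _task_shared_context returns.
shared_context = {
    "nodes": ["alias_nightowl", "user_34", "org_apex"],
    "edges": [{"src": "alias_nightowl", "rel": "controlled_by", "dst": "user_34", "confidence": 0.9}],
}

def search_shared_context(query: str, k: int = 5) -> dict:
    needle = query.strip().lower()
    # Nodes match on id substring; edges match on the joined src/rel/dst string.
    results = [{"type": "node", "node_id": n} for n in shared_context["nodes"] if needle in n.lower()]
    for edge in shared_context["edges"]:
        haystack = " ".join((edge["src"], edge["rel"], edge["dst"])).lower()
        if needle in haystack:
            results.append({"type": "edge", **edge})
    return {"results": results[: max(1, k)], "count": len(results)}

print(search_shared_context("user_34"))  # hits the user_34 node and the controlled_by edge
```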
src/osint_env/platforms/tool_schemas.py ADDED
@@ -0,0 +1,132 @@
+from __future__ import annotations
+
+from typing import Any
+
+
+def _tool_schema(
+    name: str,
+    description: str,
+    properties: dict[str, Any],
+    required: list[str],
+) -> dict[str, Any]:
+    return {
+        "type": "function",
+        "function": {
+            "name": name,
+            "description": description,
+            "parameters": {
+                "type": "object",
+                "properties": properties,
+                "required": required,
+                "additionalProperties": False,
+            },
+        },
+    }
+
+
+def build_lookup_tools() -> list[dict[str, Any]]:
+    return [
+        _tool_schema(
+            "search_posts",
+            "Search microblog posts by substring over post text, post id, author id, canonical user id, or referenced entity ids/names.",
+            {"query": {"type": "string", "description": "Substring to search for in post text."}},
+            ["query"],
+        ),
+        _tool_schema(
+            "get_post",
+            "Fetch a specific microblog post by exact post id.",
+            {"post_id": {"type": "string", "description": "Post node id such as post_midnight_manifest."}},
+            ["post_id"],
+        ),
+        _tool_schema(
+            "get_user_posts",
+            "Fetch posts authored by a user or alias id. Alias ids are resolved to the canonical user and vice versa.",
+            {"user_id": {"type": "string", "description": "User or alias node id."}},
+            ["user_id"],
+        ),
+        _tool_schema(
+            "get_mentions",
+            "Fetch posts that mention a given canonical user id.",
+            {"user_id": {"type": "string", "description": "Canonical user node id."}},
+            ["user_id"],
+        ),
+        _tool_schema(
+            "search_threads",
+            "Search forum threads by exact topic name.",
+            {"topic": {"type": "string", "description": "Thread topic such as security or ai."}},
+            ["topic"],
+        ),
+        _tool_schema(
+            "get_thread",
+            "Fetch a specific forum thread by id.",
+            {"thread_id": {"type": "string", "description": "Thread node id."}},
+            ["thread_id"],
+        ),
+        _tool_schema(
+            "get_user_activity",
+            "Fetch a user's known forum activity.",
+            {"user_id": {"type": "string", "description": "Canonical user node id."}},
+            ["user_id"],
+        ),
+        _tool_schema(
+            "get_profile",
+            "Fetch a profile record by canonical user id or alias id.",
+            {"user_id": {"type": "string", "description": "Canonical user node id or alias id."}},
+            ["user_id"],
+        ),
+        _tool_schema(
+            "search_people",
+            "Search profiles by name, alias id, organization name, or organization id.",
+            {
+                "name": {"type": "string", "description": "Optional name substring.", "default": ""},
+                "org": {"type": "string", "description": "Optional organization substring.", "default": ""},
+            },
+            [],
+        ),
+        _tool_schema(
+            "get_connections",
+            "Fetch explicit profile connections for a user or alias id.",
+            {"user_id": {"type": "string", "description": "Canonical user node id or alias id."}},
+            ["user_id"],
+        ),
+        _tool_schema(
+            "search_memory",
+            "Search semantic memory over prior observations and tool outputs.",
+            {
+                "query": {"type": "string", "description": "Memory retrieval query."},
+                "k": {"type": "integer", "description": "Top-k matches.", "default": 5},
+            },
+            ["query"],
+        ),
+        _tool_schema(
+            "search_shared_context",
+            "Search the task-local shared context graph carried with the current question.",
+            {
+                "query": {"type": "string", "description": "Substring query over shared-context node ids and edge fields."},
+                "k": {"type": "integer", "description": "Maximum number of node/edge hits to return.", "default": 5},
+            },
+            ["query"],
+        ),
+    ]
+
+
+def build_action_tools() -> list[dict[str, Any]]:
+    return build_lookup_tools() + [
+        _tool_schema(
+            "add_edge",
+            "Add a supported graph edge to the working memory graph.",
+            {
+                "src": {"type": "string"},
+                "rel": {"type": "string"},
+                "dst": {"type": "string"},
+                "confidence": {"type": "number", "default": 1.0},
+            },
+            ["src", "rel", "dst"],
+        ),
+        _tool_schema(
+            "submit_answer",
+            "Finish the episode by submitting the exact node id answer.",
+            {"answer": {"type": "string", "description": "Exact node id answer for the task."}},
+            ["answer"],
+        ),
+    ]
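With the package installed, the extracted builders are easy to spot-check (counts follow directly from the definitions above):

```python
from osint_env.platforms.tool_schemas import build_action_tools, build_lookup_tools

lookup = build_lookup_tools()
action = build_action_tools()
print(len(lookup), len(action))          # 12 lookup tools; 14 once add_edge and submit_answer are appended
print(lookup[-1]["function"]["name"])    # search_shared_context, the newly added lookup tool
```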
src/osint_env/training/__init__.py CHANGED
@@ -11,6 +11,7 @@ from osint_env.training.config import (
     SwarmV2ValidationConfig,
     load_self_play_config,
 )
+from osint_env.training.hf_jobs import launch_hf_self_play_job
 from osint_env.training.self_play import run_adversarial_self_play
 
 __all__ = [
@@ -23,5 +24,6 @@ __all__ = [
     "SwarmV2SwarmConfig",
     "SwarmV2ValidationConfig",
     "load_self_play_config",
+    "launch_hf_self_play_job",
     "run_adversarial_self_play",
 ]
src/osint_env/training/config.py CHANGED
@@ -11,7 +11,7 @@ class KimiGRPOPhaseConfig:
     """Configuration for one GRPO phase in the alternating self-play loop."""
 
     model_name_or_path: str = "Qwen/Qwen2.5-0.5B-Instruct"
-    learning_rate: float = 1e-6
+    learning_rate: float = 3e-6
     max_steps: int = 64
     per_device_train_batch_size: int = 2
     gradient_accumulation_steps: int = 4
@@ -27,7 +27,17 @@ class KimiGRPOPhaseConfig:
     scale_rewards: str = "none"
     logging_steps: int = 10
     save_steps: int = 50
+    save_total_limit: int = 2
     output_subdir: str = "phase"
+    optim: str = "adamw_torch_fused"
+    bf16: bool = True
+    tf32: bool = True
+    gradient_checkpointing: bool = False
+    dataloader_num_workers: int = 2
+    dataloader_persistent_workers: bool = True
+    dataloader_prefetch_factor: int = 2
+    generation_batch_size: int = 8
+    max_prompt_length: int = 1024
     use_vllm: bool = False
     vllm_mode: str = "colocate"
 
@@ -36,10 +46,10 @@
 class GeneratorRewardWeights:
     """Weighted components for adversarial task-generator reward."""
 
-    validity: float = 0.35
-    hardness: float = 0.45
-    diversity: float = 0.10
-    consistency: float = 0.10
+    validity: float = 0.45
+    hardness: float = 0.20
+    diversity: float = 0.15
+    consistency: float = 0.20
 
 
 @dataclass(slots=True)
@@ -130,14 +140,25 @@
     max_graph_context_edges: int = 100
     max_support_edges: int = 8
     answerer_judge_max_new_tokens: int = 48
+    generated_task_max_new_tokens: int = 512
+    post_training_eval_questions: int = 24
+    post_training_eval_answer_max_new_tokens: int = 128
     generator_reward_weights: GeneratorRewardWeights = field(default_factory=GeneratorRewardWeights)
     lora: LoraTuningConfig = field(default_factory=LoraTuningConfig)
     swarm_v2: SwarmV2Config = field(default_factory=SwarmV2Config)
     generator_phase: KimiGRPOPhaseConfig = field(
-        default_factory=lambda: KimiGRPOPhaseConfig(output_subdir="generator")
+        default_factory=lambda: KimiGRPOPhaseConfig(
+            output_subdir="generator",
+            learning_rate=5e-6,
+            max_completion_length=384,
+        )
     )
     answerer_phase: KimiGRPOPhaseConfig = field(
-        default_factory=lambda: KimiGRPOPhaseConfig(output_subdir="answerer")
+        default_factory=lambda: KimiGRPOPhaseConfig(
+            output_subdir="answerer",
+            learning_rate=3e-6,
+            max_completion_length=192,
+        )
    )
 
 
@@ -224,6 +245,42 @@ def _parse_phase(data: dict[str, Any], fallback: KimiGRPOPhaseConfig) -> KimiGRP
         logging_steps=_parse_int(data.get("logging_steps"), fallback.logging_steps, floor=1),
         save_steps=_parse_int(data.get("save_steps"), fallback.save_steps, floor=1),
         output_subdir=str(data.get("output_subdir", fallback.output_subdir)).strip() or fallback.output_subdir,
+        optim=str(data.get("optim", fallback.optim)).strip() or fallback.optim,
+        bf16=_parse_bool(data.get("bf16"), fallback.bf16),
+        tf32=_parse_bool(data.get("tf32"), fallback.tf32),
+        gradient_checkpointing=_parse_bool(
+            data.get("gradient_checkpointing"),
+            fallback.gradient_checkpointing,
+        ),
+        dataloader_num_workers=_parse_int(
+            data.get("dataloader_num_workers"),
+            fallback.dataloader_num_workers,
+            floor=0,
+        ),
+        dataloader_persistent_workers=_parse_bool(
+            data.get("dataloader_persistent_workers"),
+            fallback.dataloader_persistent_workers,
+        ),
+        dataloader_prefetch_factor=_parse_int(
+            data.get("dataloader_prefetch_factor"),
+            fallback.dataloader_prefetch_factor,
+            floor=1,
+        ),
+        generation_batch_size=_parse_int(
+            data.get("generation_batch_size"),
+            fallback.generation_batch_size,
+            floor=1,
+        ),
+        max_prompt_length=_parse_int(
+            data.get("max_prompt_length"),
+            fallback.max_prompt_length,
+            floor=32,
+        ),
+        save_total_limit=_parse_int(
+            data.get("save_total_limit"),
+            fallback.save_total_limit,
+            floor=1,
+        ),
         use_vllm=_parse_bool(data.get("use_vllm"), fallback.use_vllm),
         vllm_mode=str(data.get("vllm_mode", fallback.vllm_mode)).strip() or fallback.vllm_mode,
     )
@@ -231,10 +288,10 @@
 
 def _parse_generator_weights(data: dict[str, Any]) -> GeneratorRewardWeights:
     return GeneratorRewardWeights(
-        validity=_parse_float(data.get("validity"), 0.35),
-        hardness=_parse_float(data.get("hardness"), 0.45),
-        diversity=_parse_float(data.get("diversity"), 0.10),
-        consistency=_parse_float(data.get("consistency"), 0.10),
+        validity=_parse_float(data.get("validity"), 0.45),
+        hardness=_parse_float(data.get("hardness"), 0.20),
+        diversity=_parse_float(data.get("diversity"), 0.15),
+        consistency=_parse_float(data.get("consistency"), 0.20),
     )
 
 
@@ -420,6 +477,21 @@ def load_self_play_config(path: str | Path | None) -> SelfPlayTrainingConfig:
             defaults.answerer_judge_max_new_tokens,
             floor=1,
         ),
+        generated_task_max_new_tokens=_parse_int(
+            payload.get("generated_task_max_new_tokens"),
+            defaults.generated_task_max_new_tokens,
+            floor=32,
+        ),
+        post_training_eval_questions=_parse_int(
+            payload.get("post_training_eval_questions"),
+            defaults.post_training_eval_questions,
+            floor=1,
+        ),
+        post_training_eval_answer_max_new_tokens=_parse_int(
+            payload.get("post_training_eval_answer_max_new_tokens"),
+            defaults.post_training_eval_answer_max_new_tokens,
+            floor=1,
+        ),
         generator_reward_weights=_parse_generator_weights(
             _as_dict(payload.get("generator_reward_weights"))
         ),
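A quick way to confirm the new fields round-trip through the loader, using the smoke config from this commit (expected values come straight from that file and the new dataclass defaults):

```python
from osint_env.training.config import load_self_play_config

cfg = load_self_play_config("config/self_play_training_hf_a10g_smoke.json")
print(cfg.generator_phase.learning_rate)      # 5e-06, the file value
print(cfg.generator_phase.optim)              # adamw_torch_fused
print(cfg.generated_task_max_new_tokens)      # 640
print(cfg.generator_reward_weights.validity)  # 0.45
```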
src/osint_env/training/hf_jobs.py ADDED
@@ -0,0 +1,331 @@
+from __future__ import annotations
+
+import argparse
+import json
+import os
+import shlex
+import time
+from typing import Any
+
+DEFAULT_HF_JOB_IMAGE = "pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel"
+
+
+def _is_true(value: str | None) -> bool:
+    token = str(value or "").strip().lower()
+    return token in {"1", "true", "yes", "y", "on"}
+
+
+def _default_train_output_dir(bucket_name: str | None, run_name: str) -> str:
+    if bucket_name:
+        return f"/training-outputs/{run_name}"
+    return f"artifacts/{run_name}"
+
+
+def _require_hf_token(value: str | None) -> str:
+    token = str(value or "").strip()
+    if not token:
+        raise RuntimeError(
+            "HF_TOKEN is required to launch a Hugging Face Job. "
+            "Set HF_TOKEN in your environment or pass --hf-token."
+        )
+    return token
+
+
+def _resolve_job_image(job_image: str | None, space_id: str | None) -> str:
+    image = str(job_image or "").strip()
+    if image:
+        return image
+    space = str(space_id or "").strip()
+    if space:
+        return f"hf.co/spaces/{space}"
+    return DEFAULT_HF_JOB_IMAGE
+
+
+def _train_self_play_command(
+    *,
+    env_config_path: str,
+    train_config_path: str,
+    output_dir: str,
+    dry_run: bool,
+) -> list[str]:
+    command = [
+        "osint-env",
+        "train-self-play",
+        "--config",
+        env_config_path,
+        "--train-config",
+        train_config_path,
+        "--train-output-dir",
+        output_dir,
+    ]
+    if dry_run:
+        command.append("--dry-run")
+    return command
+
+
+def _shell_join(parts: list[str]) -> str:
+    return " ".join(shlex.quote(part) for part in parts)
+
+
+def _build_job_command(
+    *,
+    env_config_path: str,
+    train_config_path: str,
+    output_dir: str,
+    dry_run: bool,
+    repo_url: str,
+    repo_ref: str,
+    repo_subdir: str,
+    setup_command: str,
+) -> list[str]:
+    train_command = _train_self_play_command(
+        env_config_path=env_config_path,
+        train_config_path=train_config_path,
+        output_dir=output_dir,
+        dry_run=dry_run,
+    )
+    repo = str(repo_url).strip()
+    if not repo:
+        return train_command
+
+    worktree = "/workspace/osint_env_app"
+    clone_command = f"git clone --depth 1 {shlex.quote(repo)} {shlex.quote(worktree)}"
+    ref = str(repo_ref).strip()
+    if ref:
+        clone_command = (
+            f"git clone --depth 1 --branch {shlex.quote(ref)} "
+            f"{shlex.quote(repo)} {shlex.quote(worktree)}"
+        )
+
+    shell_lines = [
+        "set -euo pipefail",
+        "export PYTHONUNBUFFERED=1",
+        "export PIP_DISABLE_PIP_VERSION_CHECK=1",
+        "command -v git >/dev/null 2>&1 || { echo 'git is required when --repo-url is set' >&2; exit 1; }",
+        "mkdir -p /workspace",
+        clone_command,
+        f"cd {shlex.quote(worktree)}",
+    ]
+    subdir = str(repo_subdir).strip()
+    if subdir:
+        shell_lines.append(f"cd {shlex.quote(subdir)}")
+    shell_lines.extend(
+        [
+            "python -m pip install --upgrade pip",
+            "python -m pip install -e '.[train]'",
+        ]
+    )
+    setup = str(setup_command).strip()
+    if setup:
+        shell_lines.append(setup)
+    shell_lines.append(_shell_join(train_command))
+    return ["bash", "-lc", "\n".join(shell_lines)]
+
+
+def launch_hf_self_play_job(
+    *,
+    hf_token: str,
+    job_image: str,
+    env_config_path: str,
+    train_config_path: str,
+    flavor: str,
+    timeout: str,
+    output_dir: str,
+    space_id: str = "",
+    namespace: str = "",
+    run_name: str = "",
+    dry_run: bool = False,
+    wait: bool = False,
+    output_bucket: str = "",
+    repo_url: str = "",
+    repo_ref: str = "",
+    repo_subdir: str = "",
+    setup_command: str = "",
+) -> dict[str, Any]:
+    try:
+        from huggingface_hub import Volume, fetch_job_logs, inspect_job, login, run_job
+    except ImportError as exc:
+        raise RuntimeError(
+            "huggingface_hub is required to launch HF Jobs. "
+            "Install dependencies that include huggingface_hub first."
+        ) from exc
+
+    token = _require_hf_token(hf_token)
+    image = _resolve_job_image(job_image=job_image, space_id=space_id)
+    login(token=token, add_to_git_credential=False)
+
+    command = _build_job_command(
+        env_config_path=env_config_path,
+        train_config_path=train_config_path,
+        output_dir=output_dir,
+        dry_run=dry_run,
+        repo_url=repo_url,
+        repo_ref=repo_ref,
+        repo_subdir=repo_subdir,
+        setup_command=setup_command,
+    )
+
+    secrets = {"HF_TOKEN": token}
+    for secret_name in ("WANDB_API_KEY", "OPENAI_API_KEY", "GITHUB_TOKEN", "GH_TOKEN"):
+        secret_value = str(os.getenv(secret_name, "")).strip()
+        if secret_value:
+            secrets[secret_name] = secret_value
+
+    env: dict[str, str] = {
+        "PYTHONUNBUFFERED": "1",
+        "HF_HUB_ENABLE_HF_TRANSFER": "1",
+    }
+    if run_name:
+        env["OSINT_HF_JOB_RUN_NAME"] = run_name
+    for env_name in (
+        "WANDB_ENTITY",
+        "WANDB_PROJECT",
+        "WANDB_RUN_GROUP",
+        "OSINT_TRAIN_STRICT_ASSERTS",
+        "HF_HOME",
+        "TRANSFORMERS_CACHE",
+    ):
+        env_value = str(os.getenv(env_name, "")).strip()
+        if env_value:
+            env[env_name] = env_value
+
+    volumes: list[Any] = []
+    if output_bucket:
+        volumes.append(Volume(type="bucket", source=output_bucket, mount_path="/training-outputs"))
+
+    job = run_job(
+        image=image,
+        command=command,
+        flavor=flavor,
+        timeout=timeout,
+        namespace=namespace or None,
+        env=env,
+        secrets=secrets,
+        volumes=volumes or None,
+    )
+
+    payload: dict[str, Any] = {
+        "job_id": str(job.id),
+        "job_url": str(job.url),
+        "job_image": image,
+        "flavor": flavor,
+        "timeout": timeout,
+        "output_dir": output_dir,
+        "output_bucket": output_bucket,
+        "repo_url": repo_url,
+        "repo_ref": repo_ref,
+        "repo_subdir": repo_subdir,
+        "space_id_compat": space_id,
+        "dry_run": dry_run,
+        "waited": False,
+ }
222
+
223
+ if wait:
224
+ terminal_states = {"COMPLETED", "ERROR", "CANCELLED", "TIMEOUT"}
225
+ last_stage = ""
226
+ while True:
227
+ info = inspect_job(job_id=job.id)
228
+ stage = str(getattr(getattr(info, "status", None), "stage", "") or "")
229
+ if stage != last_stage:
230
+ print(json.dumps({"job_id": str(job.id), "stage": stage, "url": str(job.url)}))
231
+ last_stage = stage
232
+ if stage in terminal_states:
233
+ payload["waited"] = True
234
+ payload["final_stage"] = stage
235
+ if stage != "COMPLETED":
236
+ payload["logs"] = list(fetch_job_logs(job_id=job.id))
237
+ break
238
+ time.sleep(15)
239
+
240
+ return payload
241
+
242
+
243
+ def build_parser() -> argparse.ArgumentParser:
244
+ parser = argparse.ArgumentParser(
245
+ description="Launch OSINT self-play training as a separate Hugging Face Job on dedicated compute."
246
+ )
247
+ parser.add_argument("--hf-token", default=os.getenv("HF_TOKEN", ""), help="HF token. Defaults to HF_TOKEN env var.")
248
+ parser.add_argument(
249
+ "--job-image",
250
+ default=os.getenv("HF_JOB_IMAGE", ""),
251
+ help=(
252
+ "Docker image for the dedicated training job. "
253
+ f"Defaults to {DEFAULT_HF_JOB_IMAGE!r} unless --space-id is provided."
254
+ ),
255
+ )
256
+ parser.add_argument(
257
+ "--space-id",
258
+ default=os.getenv("HF_SPACE_ID", ""),
259
+ help="Optional compatibility fallback to reuse a Space image, e.g. owner/space-name.",
260
+ )
261
+ parser.add_argument(
262
+ "--env-config",
263
+ default=os.getenv("TRAIN_ENV_CONFIG_PATH", "config/shared_config.json"),
264
+ help="Environment config path inside the training image or checked-out repo.",
265
+ )
266
+ parser.add_argument(
267
+ "--train-config",
268
+ default=os.getenv("TRAIN_SELF_PLAY_CONFIG_PATH", "config/self_play_training_hf_a10g_smoke.json"),
269
+ help="Training config path inside the training image or checked-out repo.",
270
+ )
271
+ parser.add_argument("--flavor", default=os.getenv("HF_JOB_FLAVOR", "a10g-small"))
272
+ parser.add_argument("--timeout", default=os.getenv("HF_JOB_TIMEOUT", "8h"))
273
+ parser.add_argument("--namespace", default=os.getenv("HF_JOB_NAMESPACE", ""))
274
+ parser.add_argument("--run-name", default=os.getenv("HF_JOB_RUN_NAME", "osint-self-play-job"))
275
+ parser.add_argument("--output-bucket", default=os.getenv("HF_JOB_OUTPUT_BUCKET", ""))
276
+ parser.add_argument("--output-dir", default=os.getenv("TRAIN_SELF_PLAY_OUTPUT_DIR", ""))
277
+ parser.add_argument(
278
+ "--repo-url",
279
+ default=os.getenv("HF_JOB_REPO_URL", ""),
280
+ help="Optional git repository URL to clone inside the job before training.",
281
+ )
282
+ parser.add_argument(
283
+ "--repo-ref",
284
+ default=os.getenv("HF_JOB_REPO_REF", ""),
285
+ help="Optional git branch, tag, or commit-ish to check out when --repo-url is used.",
286
+ )
287
+ parser.add_argument(
288
+ "--repo-subdir",
289
+ default=os.getenv("HF_JOB_REPO_SUBDIR", ""),
290
+ help="Optional subdirectory inside the cloned repo that contains pyproject.toml.",
291
+ )
292
+ parser.add_argument(
293
+ "--setup-command",
294
+ default=os.getenv("HF_JOB_SETUP_COMMAND", ""),
295
+ help="Optional shell command to run after install and before training.",
296
+ )
297
+ parser.add_argument("--dry-run", action="store_true", default=_is_true(os.getenv("RUN_SELF_PLAY_DRY_RUN", "")))
298
+ parser.add_argument("--wait", action="store_true", default=_is_true(os.getenv("HF_JOB_WAIT", "")))
299
+ return parser
300
+
301
+
302
+ def main() -> None:
303
+ args = build_parser().parse_args()
304
+ run_name = str(args.run_name).strip() or "osint-self-play-job"
305
+ output_bucket = str(args.output_bucket).strip()
306
+ output_dir = str(args.output_dir).strip() or _default_train_output_dir(output_bucket, run_name)
307
+
308
+ payload = launch_hf_self_play_job(
309
+ hf_token=str(args.hf_token),
310
+ job_image=str(args.job_image),
311
+ env_config_path=str(args.env_config),
312
+ train_config_path=str(args.train_config),
313
+ flavor=str(args.flavor),
314
+ timeout=str(args.timeout),
315
+ output_dir=output_dir,
316
+ space_id=str(args.space_id),
317
+ namespace=str(args.namespace),
318
+ run_name=run_name,
319
+ dry_run=bool(args.dry_run),
320
+ wait=bool(args.wait),
321
+ output_bucket=output_bucket,
322
+ repo_url=str(args.repo_url),
323
+ repo_ref=str(args.repo_ref),
324
+ repo_subdir=str(args.repo_subdir),
325
+ setup_command=str(args.setup_command),
326
+ )
327
+ print(json.dumps(payload, indent=2, sort_keys=True))
328
+
329
+
330
+ if __name__ == "__main__":
331
+ main()
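For reference, the launcher above can be driven from Python as well as from its CLI entry point. A minimal dry-run sketch; the flavor, timeout, and output directory are illustrative values, not project requirements:

import os

from osint_env.training.hf_jobs import launch_hf_self_play_job

# Dry-run smoke launch; drop dry_run=True to execute real GRPO updates.
payload = launch_hf_self_play_job(
    hf_token=os.environ["HF_TOKEN"],
    job_image="",  # empty string falls back to DEFAULT_HF_JOB_IMAGE
    env_config_path="config/shared_config.json",
    train_config_path="config/self_play_training_hf_a10g_smoke.json",
    flavor="a10g-small",
    timeout="8h",
    output_dir="artifacts/osint-self-play-job",
    dry_run=True,
    wait=True,
)
print(payload["job_url"])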
src/osint_env/training/rewards.py CHANGED
@@ -8,6 +8,7 @@ from functools import lru_cache
8
  from typing import Any
9
 
10
  from osint_env.data.generator import (
 
11
  emit_swarm_v2_question,
12
  enumerate_swarm_v2_neighbors,
13
  select_swarm_v2_answer,
@@ -224,7 +225,7 @@ def _parse_tool_trace(value: Any) -> list[SwarmReplayToolCall]:
224
  continue
225
  tool_name = str(row.get("tool_name", row.get("tool", ""))).strip()
226
  args = row.get("args", {})
227
- output = row.get("output", {})
228
  if not tool_name:
229
  continue
230
  out.append(
@@ -281,6 +282,12 @@ def _coerce_int(value: Any, default: int) -> int:
281
  try:
282
  return int(float(token))
283
  except ValueError:
284
  return default
285
  return default
286
 
@@ -497,11 +504,16 @@ class SwarmV2ReplayValidator:
497
  replayed_edges: list[Edge] = []
498
  replayed_answer = ""
499
  replayed_question = ""
500
 
501
- if not candidate.tool_trace:
502
- return ["non_replayable_tool_calls"], replayed_edges, replayed_answer, replayed_question
 
503
 
504
- for call in candidate.tool_trace:
505
  if call.tool_name == "enumerate_neighbors":
506
  node_id = str(call.args.get("node_id", "")).strip()
507
  expected_edge = call.args.get("expected_edge", {})
@@ -520,21 +532,31 @@ class SwarmV2ReplayValidator:
520
  if expected_key not in {(edge.src, edge.rel, edge.dst) for edge in neighbors}:
521
  reasons.append("non_replayable_tool_calls")
522
  elif call.tool_name == "trace_path":
523
- candidate_path = call.args.get("path", candidate.supporting_edges)
524
- replayed_edges = trace_swarm_v2_path(self.graph, candidate_path)
525
  if not replayed_edges:
526
  reasons.append("non_replayable_tool_calls")
527
  elif call.tool_name == "select_answer":
528
- replayed_answer = select_swarm_v2_answer(replayed_edges)
529
- if not replayed_answer:
530
- reasons.append("non_replayable_tool_calls")
531
  elif call.tool_name == "emit_question":
532
- replayed_question = emit_swarm_v2_question(replayed_edges)
533
- if not replayed_question:
534
- reasons.append("non_replayable_tool_calls")
535
  else:
536
  reasons.append("non_replayable_tool_calls")
537
 
538
  return reasons, replayed_edges, replayed_answer, replayed_question
539
 
540
  def validate(self, candidate: GeneratedTaskCandidate) -> ReplayValidationResult:
@@ -747,37 +769,37 @@ class GeneratorRewardFunction:
747
  # (3) tiny text-level signal so completely-collapsed completions
748
  # differ from completions that at least *attempt* JSON.
749
  reason_penalty = {
750
- "missing_question_or_answer": 0.45,
751
- "malformed_support_edges": 0.30,
752
- "non_replayable_tool_calls": 0.40,
753
- "non_unique_derivation_path": 0.25,
754
- "unseen_nodes_or_edges": 0.30,
755
- "answer_leakage": 0.40,
756
- "duplicate_or_near_duplicate": 0.20,
757
- "context_or_support_budget_overflow": 0.20,
758
  }
759
- penalty = 0.20
760
  for reason in validation_result.reasons:
761
  penalty += reason_penalty.get(reason, 0.10)
762
 
763
  partial_credit = 0.0
764
  if candidate.question:
765
- partial_credit += 0.30
766
  if candidate.answer:
767
- partial_credit += 0.30
768
  if candidate.supporting_edges:
769
- partial_credit += min(0.40, 0.10 * len(candidate.supporting_edges))
770
  if candidate.tool_trace:
771
- partial_credit += min(0.35, 0.08 * len(candidate.tool_trace))
772
  if candidate.subagent_outputs:
773
  partial_credit += 0.10
774
  if candidate.canonical_edges or candidate.canonical_nodes:
775
- partial_credit += 0.10
776
 
777
  text_signal = self._completion_text_signal(completion_text)
778
 
779
  reward = partial_credit - penalty + text_signal
780
- return float(max(-1.8, min(-0.05, reward)))
781
 
782
  @staticmethod
783
  def _completion_text_signal(completion_text: str) -> float:
@@ -930,15 +952,24 @@ class GeneratorRewardFunction:
930
  swarm_diversity = self._swarm_diversity_score(candidate)
931
  context_pressure = self._context_pressure_score(validation_result)
932
  parl_parallel, parl_finish = self._parl_scores(candidate)
933
 
934
  reward = (
935
- 0.25 # valid JSON/schema
936
- + 0.30 # replayable derivation
937
- + (0.30 * hardness)
938
- + (0.15 * swarm_diversity)
939
- + (0.10 * context_pressure)
940
- + (0.025 * parl_parallel)
941
- + (0.025 * parl_finish)
942
  )
943
  return reward, validation_result
944
 
 
8
  from typing import Any
9
 
10
  from osint_env.data.generator import (
11
+ build_swarm_v2_tool_trace,
12
  emit_swarm_v2_question,
13
  enumerate_swarm_v2_neighbors,
14
  select_swarm_v2_answer,
 
225
  continue
226
  tool_name = str(row.get("tool_name", row.get("tool", ""))).strip()
227
  args = row.get("args", {})
228
+ output = row.get("output", row.get("result", {}))
229
  if not tool_name:
230
  continue
231
  out.append(
 
282
  try:
283
  return int(float(token))
284
  except ValueError:
285
+ match = re.search(r"[-+]?\d+(?:\.\d+)?", token)
286
+ if match:
287
+ try:
288
+ return int(float(match.group(0)))
289
+ except ValueError:
290
+ return default
291
  return default
292
  return default
293
 
 
504
  replayed_edges: list[Edge] = []
505
  replayed_answer = ""
506
  replayed_question = ""
507
+ declared_answer = ""
508
+ declared_question = ""
509
+ tool_trace = list(candidate.tool_trace)
510
+ trace_path_source: Any = candidate.supporting_edges
511
 
512
+ if not tool_trace and candidate.supporting_edges:
513
+ synthesized_trace = build_swarm_v2_tool_trace(self.graph, candidate.supporting_edges)
514
+ tool_trace = _parse_tool_trace(synthesized_trace)
515
 
516
+ for call in tool_trace:
517
  if call.tool_name == "enumerate_neighbors":
518
  node_id = str(call.args.get("node_id", "")).strip()
519
  expected_edge = call.args.get("expected_edge", {})
 
532
  if expected_key not in {(edge.src, edge.rel, edge.dst) for edge in neighbors}:
533
  reasons.append("non_replayable_tool_calls")
534
  elif call.tool_name == "trace_path":
535
+ trace_path_source = call.args.get("path", trace_path_source)
536
+ replayed_edges = trace_swarm_v2_path(self.graph, trace_path_source)
537
  if not replayed_edges:
538
  reasons.append("non_replayable_tool_calls")
539
  elif call.tool_name == "select_answer":
540
+ declared_answer = normalize_answer(str(call.output.get("answer", "")).strip())
541
  elif call.tool_name == "emit_question":
542
+ declared_question = str(call.output.get("question", "")).strip()
543
  else:
544
  reasons.append("non_replayable_tool_calls")
545
 
546
+ if not replayed_edges:
547
+ replayed_edges = trace_swarm_v2_path(self.graph, trace_path_source)
548
+ if not replayed_edges and candidate.supporting_edges:
549
+ replayed_edges = trace_swarm_v2_path(self.graph, candidate.supporting_edges)
550
+ if not replayed_edges:
551
+ reasons.append("non_replayable_tool_calls")
552
+ return reasons, replayed_edges, replayed_answer, replayed_question
553
+
554
+ replayed_answer = select_swarm_v2_answer(replayed_edges)
555
+ replayed_question = emit_swarm_v2_question(replayed_edges)
556
+ if declared_answer and declared_answer != normalize_answer(replayed_answer):
557
+ reasons.append("non_replayable_tool_calls")
558
+ if declared_question and declared_question != replayed_question:
559
+ reasons.append("non_replayable_tool_calls")
560
  return reasons, replayed_edges, replayed_answer, replayed_question
561
 
562
  def validate(self, candidate: GeneratedTaskCandidate) -> ReplayValidationResult:
 
769
  # (3) tiny text-level signal so completely-collapsed completions
770
  # differ from completions that at least *attempt* JSON.
771
  reason_penalty = {
772
+ "missing_question_or_answer": 0.35,
773
+ "malformed_support_edges": 0.25,
774
+ "non_replayable_tool_calls": 0.25,
775
+ "non_unique_derivation_path": 0.20,
776
+ "unseen_nodes_or_edges": 0.25,
777
+ "answer_leakage": 0.30,
778
+ "duplicate_or_near_duplicate": 0.15,
779
+ "context_or_support_budget_overflow": 0.15,
780
  }
781
+ penalty = 0.10
782
  for reason in validation_result.reasons:
783
  penalty += reason_penalty.get(reason, 0.10)
784
 
785
  partial_credit = 0.0
786
  if candidate.question:
787
+ partial_credit += 0.25
788
  if candidate.answer:
789
+ partial_credit += 0.25
790
  if candidate.supporting_edges:
791
+ partial_credit += min(0.36, 0.12 * len(candidate.supporting_edges))
792
  if candidate.tool_trace:
793
+ partial_credit += min(0.20, 0.05 * len(candidate.tool_trace))
794
  if candidate.subagent_outputs:
795
  partial_credit += 0.10
796
  if candidate.canonical_edges or candidate.canonical_nodes:
797
+ partial_credit += 0.12
798
 
799
  text_signal = self._completion_text_signal(completion_text)
800
 
801
  reward = partial_credit - penalty + text_signal
802
+ return float(max(-1.25, min(-0.02, reward)))
803
 
804
  @staticmethod
805
  def _completion_text_signal(completion_text: str) -> float:
 
952
  swarm_diversity = self._swarm_diversity_score(candidate)
953
  context_pressure = self._context_pressure_score(validation_result)
954
  parl_parallel, parl_finish = self._parl_scores(candidate)
955
+ hardness_component = max(0.0, min(1.0, (hardness + 0.4) / 1.4))
956
+ consistency_component = max(
957
+ 0.0,
958
+ min(
959
+ 1.0,
960
+ (0.55 * context_pressure)
961
+ + (0.25 * parl_parallel)
962
+ + (0.20 * parl_finish),
963
+ ),
964
+ )
965
+ completion_component = max(0.0, min(1.0, self._completion_text_signal(completion_text) / 0.25))
966
 
967
  reward = (
968
+ self.weights.validity
969
+ + (self.weights.hardness * hardness_component)
970
+ + (self.weights.diversity * swarm_diversity)
971
+ + (self.weights.consistency * consistency_component)
972
+ + (0.05 * completion_component)
 
  )
974
  return reward, validation_result
975
 
src/osint_env/training/self_play.py CHANGED
@@ -18,6 +18,7 @@ from osint_env.data.generator import (
18
  )
19
  from osint_env.domain.models import Edge, EnvironmentConfig, TaskInstance
20
  from osint_env.env.environment import OSINTEnvironment
 
21
  from osint_env.llm import build_llm_client
22
  from osint_env.training.config import (
23
  KimiGRPOPhaseConfig,
@@ -31,6 +32,8 @@ from osint_env.training.rewards import (
31
  GeneratorRewardFunction,
32
  SwarmV2ReplayValidator,
33
  decode_completion_text,
 
 
34
  parse_generated_task_completion,
35
  )
36
 
@@ -99,6 +102,92 @@ def _edges_from_payload(rows: Any, max_edges: int) -> list[Edge]:
99
  return edges
100
 
101
 
102
 
103
  def _canonical_example_payload(
104
  graph: Any,
@@ -113,7 +202,6 @@ def _canonical_example_payload(
113
  "answer": "",
114
  "task_type": "swarm_v2_trace",
115
  "supporting_edges": [],
116
- "tool_trace": [],
117
  "subagent_outputs": ["path_agent: no replayable edge"],
118
  "orchestrator": {
119
  "spawn_count": 1,
@@ -126,16 +214,18 @@ def _canonical_example_payload(
126
 
127
  traced_edges = traced_edges[:2]
128
  spawn_count = min(swarm_cfg.max_agents, max(1, len(traced_edges) + 1))
129
- full_tool_trace = build_swarm_v2_tool_trace(graph, traced_edges)
130
  return {
131
  "question": emit_swarm_v2_question(traced_edges),
132
  "answer": select_swarm_v2_answer(traced_edges),
133
  "task_type": f"swarm_v2_{len(traced_edges)}hop_trace",
134
  "supporting_edges": [_edge_payload(edge) for edge in traced_edges],
135
- "tool_trace": full_tool_trace[:4],
136
  "subagent_outputs": [
137
- f"path_agent_{idx}: {edge.src}->{edge.dst}"
138
  for idx, edge in enumerate(traced_edges)
139
  ],
140
  "orchestrator": {
141
  "spawn_count": spawn_count,
@@ -176,10 +266,7 @@ def _swarm_v2_answer_prompt(
176
  swarm_cfg: SwarmV2SwarmConfig,
177
  ) -> str:
178
  del swarm_cfg # kept for signature compatibility
179
- compact_context = {
180
- "nodes": list(shared_context.get("nodes", []))[:8],
181
- "edges": list(shared_context.get("edges", []))[:6],
182
- }
183
  return (
184
  "You answer one OSINT graph question using ONLY the shared context.\n"
185
  "Output rules:\n"
@@ -222,21 +309,7 @@ def _build_swarm_v2_answerer_rows(
222
  ) -> list[dict[str, Any]]:
223
  rows: list[dict[str, Any]] = []
224
  for task in tasks:
225
- metadata = dict(task.metadata or {})
226
- canonical_graph = metadata.get("canonical_graph")
227
- if isinstance(canonical_graph, dict):
228
- shared_context = {
229
- "nodes": list(canonical_graph.get("nodes", []))[: cfg.swarm_v2.shared_context.max_nodes],
230
- "edges": list(canonical_graph.get("edges", []))[: cfg.swarm_v2.shared_context.max_edges],
231
- }
232
- else:
233
- deterministic_seed = sum(ord(ch) for ch in task.task_id)
234
- shared_context = _graph_context_for_prompt(
235
- env=env,
236
- max_nodes=cfg.swarm_v2.shared_context.max_nodes,
237
- max_edges=cfg.swarm_v2.shared_context.max_edges,
238
- rng=random.Random(deterministic_seed),
239
- )
240
 
241
  rows.append(
242
  {
@@ -338,31 +411,39 @@ def _swarm_v2_generator_prompt(
338
  anchors = "\n".join(f"- {question}" for question in anchor_questions)
339
  canonical_mode = str(canonical_graph_mode).strip().lower() or "generate"
340
  example_payload = _canonical_example_payload(graph, canonical_candidate, swarm_cfg)
341
  canonical_instruction = (
342
  "You may propose canonical_graph updates when they improve replayability and keep it graph-grounded."
343
  if canonical_mode == "generate"
344
  else "Reuse the provided canonical candidate as-is; do not add, remove, or modify canonical_graph nodes/edges."
345
  )
346
- canonical_compact = {
347
- "nodes": list(canonical_candidate.get("nodes", []))[:8],
348
- "edges": list(canonical_candidate.get("edges", []))[:6],
349
- }
350
  return (
351
- "You generate ONE OSINT question/answer task as compact JSON.\n"
352
  "Output rules:\n"
353
  "- Return ONLY one JSON object. No markdown. No prose. End with }.\n"
354
- "- Required keys: question, answer, task_type, supporting_edges, tool_trace, "
355
- "subagent_outputs, orchestrator.\n"
356
  "- supporting_edges: non-empty list of {src, rel, dst, confidence}, taken from canonical edges.\n"
357
- "- tool_trace: non-empty list of {tool, args, result} using only "
358
- "enumerate_neighbors|trace_path|select_answer|emit_question.\n"
359
- "- answer = final dst of the trace. question describes the path.\n"
360
  "- orchestrator: integer keys spawn_count, finished_subtasks, critical_steps, breadth, depth.\n"
361
  f"- canonical_graph_mode={canonical_mode}: {canonical_instruction}\n"
362
  "Example (copy schema, not values):\n"
363
  f"{json.dumps(example_payload, separators=(',', ':'), sort_keys=True)}\n"
364
  "Canonical candidate (use these edges):\n"
365
  f"{json.dumps(canonical_compact, separators=(',', ':'), sort_keys=True)}\n"
366
  f"Avoid these prior questions: {anchors}\n"
367
  "JSON:"
368
  )
@@ -410,6 +491,15 @@ def _build_swarm_v2_generator_rows(
410
  "prompt": prompt,
411
  "candidate_id": f"candidate_{idx}",
412
  "canonical_graph_json": json.dumps(canonical_candidate, sort_keys=True),
413
  }
414
  )
415
  canonical_candidates.append(canonical_candidate)
@@ -442,6 +532,16 @@ def _safe_build_grpo_config(
442
  "scale_rewards": str(phase.scale_rewards),
443
  "logging_steps": int(phase.logging_steps),
444
  "save_steps": int(phase.save_steps),
445
  "remove_unused_columns": False,
446
  "use_vllm": bool(phase.use_vllm),
447
  "vllm_mode": str(phase.vllm_mode),
@@ -550,16 +650,25 @@ def _train_grpo_phase(
550
 
551
  final_dir = output_dir / "final_model"
552
  trainer.save_model(str(final_dir))
553
 
554
  global_step = int(getattr(train_output, "global_step", 0))
555
  training_loss = float(getattr(train_output, "training_loss", 0.0))
556
 
557
  result = {
558
  "model_path": str(final_dir),
559
  "global_step": global_step,
560
  "training_loss": training_loss,
561
  "train_rows": len(rows),
562
  "tuning_mode": str(tuning_mode).strip().lower() or "full",
563
  }
564
 
565
  log_history = list(getattr(getattr(trainer, "state", None), "log_history", []) or [])
@@ -599,7 +708,13 @@ def _train_grpo_phase(
599
  "grad_norm_max": max(grad_norm_values) if grad_norm_values else 0.0,
600
  "entropy_min": min(entropy_values) if entropy_values else 0.0,
601
  "entropy_max": max(entropy_values) if entropy_values else 0.0,
602
  "trainable_param_count": trainable_param_count,
603
  "params_with_grad": params_with_grad,
604
  "nonzero_grad_tensors": nonzero_grad_tensors,
605
  "fingerprint_param_count": len(pre_update_fingerprint),
@@ -721,8 +836,10 @@ def _sample_generated_tasks_with_model(
721
  round_index: int,
722
  count: int,
723
  max_support_edges: int,
724
  ) -> list[TaskInstance]:
725
  from transformers import AutoModelForCausalLM, AutoTokenizer
726
 
727
  if count <= 0:
728
  return []
@@ -730,11 +847,13 @@ def _sample_generated_tasks_with_model(
730
  tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
731
  if tokenizer.pad_token is None and tokenizer.eos_token is not None:
732
  tokenizer.pad_token = tokenizer.eos_token
733
- model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
734
  model.eval()
735
 
736
- import torch
737
-
738
  device = next(model.parameters()).device
739
  generated: list[TaskInstance] = []
740
 
@@ -747,7 +866,7 @@ def _sample_generated_tasks_with_model(
747
  with torch.no_grad():
748
  output = model.generate(
749
  **encoded,
750
- max_new_tokens=256,
751
  do_sample=True,
752
  top_p=0.95,
753
  temperature=1.0,
@@ -841,6 +960,256 @@ def _save_payload(path: Path, payload: Any) -> None:
841
  path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
842
 
843
 
844
  def _fallback_swarm_v2_completion_texts(
845
  env: OSINTEnvironment,
846
  cfg: SelfPlayTrainingConfig,
@@ -914,6 +1283,7 @@ def _sample_swarm_v2_completion_texts_with_model(
914
  seen_questions: list[str],
915
  ) -> list[str]:
916
  from transformers import AutoModelForCausalLM, AutoTokenizer
917
 
918
  if count <= 0:
919
  return []
@@ -921,11 +1291,13 @@ def _sample_swarm_v2_completion_texts_with_model(
921
  tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
922
  if tokenizer.pad_token is None and tokenizer.eos_token is not None:
923
  tokenizer.pad_token = tokenizer.eos_token
924
- model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
925
  model.eval()
926
 
927
- import torch
928
-
929
  device = next(model.parameters()).device
930
  completions: list[str] = []
931
  validator = SwarmV2ReplayValidator(
@@ -946,7 +1318,7 @@ def _sample_swarm_v2_completion_texts_with_model(
946
  with torch.no_grad():
947
  output = model.generate(
948
  **encoded,
949
- max_new_tokens=max(256, int(cfg.generator_phase.max_completion_length)),
950
  do_sample=True,
951
  top_p=top_p,
952
  temperature=temperature,
@@ -1003,6 +1375,10 @@ def _materialize_swarm_v2_completions(
1003
  max_support_edges=cfg.swarm_v2.validation.max_support_edges,
1004
  )
1005
  validation = validator.validate(candidate)
1006
 
1007
  if use_fixed_canonical and prompt_canonical_candidates and completion_idx < len(prompt_canonical_candidates):
1008
  canonical_graph = dict(prompt_canonical_candidates[completion_idx])
@@ -1045,14 +1421,7 @@ def _materialize_swarm_v2_completions(
1045
  {
1046
  "candidate_index": completion_idx,
1047
  "question": candidate.question,
1048
- "tool_trace": [
1049
- {
1050
- "tool_name": call.tool_name,
1051
- "args": dict(call.args),
1052
- "output": dict(call.output),
1053
- }
1054
- for call in candidate.tool_trace
1055
- ],
1056
  "replayed_edges": validation.to_dict()["replayed_edges"],
1057
  }
1058
  )
@@ -1078,14 +1447,7 @@ def _materialize_swarm_v2_completions(
1078
  "difficulty": "hard",
1079
  "scenario": "swarm_v2_trace",
1080
  "canonical_graph": canonical_graph,
1081
- "tool_trace": [
1082
- {
1083
- "tool_name": call.tool_name,
1084
- "args": dict(call.args),
1085
- "output": dict(call.output),
1086
- }
1087
- for call in candidate.tool_trace
1088
- ],
1089
  "subagent_outputs": list(candidate.subagent_outputs),
1090
  "validation": validation.to_dict(),
1091
  "shared_context_budget": {
@@ -1106,7 +1468,7 @@ def _materialize_swarm_v2_completions(
1106
  task_type=candidate.task_type or "swarm_v2_trace",
1107
  question=candidate.question,
1108
  answer=candidate.answer,
1109
- supporting_edges=list(validation.replayed_edges or candidate.supporting_edges),
1110
  metadata=metadata,
1111
  )
1112
  )
@@ -1132,6 +1494,8 @@ def _run_adversarial_self_play_swarm_v2(
1132
  seed_tasks = list(env.tasks)
1133
  seed_questions = [task.question for task in seed_tasks]
1134
  generator_model, answerer_model = _resolve_initial_models(training_config)
1135
  rng = random.Random(env_config.seed)
1136
 
1137
  bootstrap_completions = _fallback_swarm_v2_completion_texts(
@@ -1383,6 +1747,18 @@ def _run_adversarial_self_play_swarm_v2(
1383
  }
1384
  )
1385
 
1386
  final_payload = {
1387
  "dry_run": effective_dry_run,
1388
  "pipeline_mode": "swarm_v2",
@@ -1396,6 +1772,11 @@ def _run_adversarial_self_play_swarm_v2(
1396
  "generator": generator_model,
1397
  "answerer": answerer_model,
1398
  },
1399
  "kimi_objective_mapping": {
1400
  "grouped_rollouts": "TRL GRPO num_generations",
1401
  "mean_centered_advantage": "GRPO relative reward baseline",
@@ -1437,6 +1818,8 @@ def run_adversarial_self_play(
1437
  seed_tasks = list(env.tasks)
1438
 
1439
  generator_model, answerer_model = _resolve_initial_models(training_config)
1440
 
1441
  rng = random.Random(env_config.seed)
1442
  rounds_payload: list[dict[str, Any]] = []
@@ -1557,6 +1940,7 @@ def run_adversarial_self_play(
1557
  round_index=round_index,
1558
  count=training_config.generated_tasks_per_round,
1559
  max_support_edges=training_config.max_support_edges,
1560
  )
1561
  if not generated_tasks:
1562
  generated_tasks = _fallback_generated_tasks(
@@ -1641,6 +2025,18 @@ def run_adversarial_self_play(
1641
  }
1642
  )
1643
 
1644
  final_payload = {
1645
  "dry_run": effective_dry_run,
1646
  "pipeline_mode": "legacy",
@@ -1654,6 +2050,11 @@ def run_adversarial_self_play(
1654
  "generator": generator_model,
1655
  "answerer": answerer_model,
1656
  },
1657
  "kimi_objective_mapping": {
1658
  "grouped_rollouts": "TRL GRPO num_generations",
1659
  "mean_centered_advantage": "GRPO relative reward baseline",
 
18
  )
19
  from osint_env.domain.models import Edge, EnvironmentConfig, TaskInstance
20
  from osint_env.env.environment import OSINTEnvironment
21
+ from osint_env.env.reward import compute_graph_f1
22
  from osint_env.llm import build_llm_client
23
  from osint_env.training.config import (
24
  KimiGRPOPhaseConfig,
 
32
  GeneratorRewardFunction,
33
  SwarmV2ReplayValidator,
34
  decode_completion_text,
35
+ extract_answer_from_completion,
36
+ normalize_answer,
37
  parse_generated_task_completion,
38
  )
39
 
 
102
  return edges
103
 
104
 
105
+ def _compact_shared_context(
106
+ shared_context: dict[str, Any],
107
+ max_nodes: int = 8,
108
+ max_edges: int = 6,
109
+ ) -> dict[str, Any]:
110
+ return {
111
+ "nodes": list(shared_context.get("nodes", []))[:max_nodes],
112
+ "edges": list(shared_context.get("edges", []))[:max_edges],
113
+ }
114
+
115
+
116
+ def _task_shared_context(
117
+ env: OSINTEnvironment,
118
+ task: TaskInstance,
119
+ cfg: SelfPlayTrainingConfig,
120
+ ) -> dict[str, Any]:
121
+ metadata = dict(task.metadata or {})
122
+ canonical_graph = metadata.get("canonical_graph")
123
+ if isinstance(canonical_graph, dict):
124
+ return {
125
+ "nodes": list(canonical_graph.get("nodes", []))[: cfg.swarm_v2.shared_context.max_nodes],
126
+ "edges": list(canonical_graph.get("edges", []))[: cfg.swarm_v2.shared_context.max_edges],
127
+ }
128
+
129
+ deterministic_seed = sum(ord(ch) for ch in task.task_id)
130
+ return _graph_context_for_prompt(
131
+ env=env,
132
+ max_nodes=cfg.swarm_v2.shared_context.max_nodes,
133
+ max_edges=cfg.swarm_v2.shared_context.max_edges,
134
+ rng=random.Random(deterministic_seed),
135
+ )
136
+
137
+
138
+ def _swarm_v2_worker_packets(
139
+ canonical_candidate: dict[str, Any],
140
+ shared_context: dict[str, Any],
141
+ swarm_cfg: SwarmV2SwarmConfig,
142
+ ) -> dict[str, Any]:
143
+ path_edges = _edges_from_payload(
144
+ canonical_candidate.get("path", canonical_candidate.get("edges", [])),
145
+ max_edges=max(1, swarm_cfg.max_depth * 2),
146
+ )
147
+ if not path_edges:
148
+ path_edges = _edges_from_payload(canonical_candidate.get("edges", []), max_edges=2)
149
+ relation_path = [edge.rel for edge in path_edges]
150
+ start_node = path_edges[0].src if path_edges else ""
151
+ return {
152
+ "path_agent": {
153
+ "path_edges": [_edge_payload(edge) for edge in path_edges],
154
+ "goal": "Choose one contiguous replayable path from the canonical candidate.",
155
+ },
156
+ "question_agent": {
157
+ "start_node": start_node,
158
+ "relation_path": relation_path,
159
+ "goal": "Write a compact question that describes the path without leaking the answer.",
160
+ },
161
+ "context_agent": {
162
+ "shared_context": _compact_shared_context(shared_context),
163
+ "goal": "Keep support/context usage compact and graph-grounded.",
164
+ },
165
+ "planner": {
166
+ "max_agents": int(swarm_cfg.max_agents),
167
+ "max_breadth": int(swarm_cfg.max_breadth),
168
+ "max_depth": int(swarm_cfg.max_depth),
169
+ },
170
+ }
171
+
172
+
173
+ def _serialize_tool_trace(tool_trace: Any) -> list[dict[str, Any]]:
174
+ serialized: list[dict[str, Any]] = []
175
+ for call in tool_trace or []:
176
+ tool_name = getattr(call, "tool_name", "")
177
+ args = getattr(call, "args", {})
178
+ output = getattr(call, "output", {})
179
+ if not tool_name:
180
+ continue
181
+ serialized.append(
182
+ {
183
+ "tool_name": str(tool_name),
184
+ "args": dict(args) if isinstance(args, dict) else {},
185
+ "output": dict(output) if isinstance(output, dict) else {},
186
+ }
187
+ )
188
+ return serialized
189
+
190
+
191
 
192
  def _canonical_example_payload(
193
  graph: Any,
 
202
  "answer": "",
203
  "task_type": "swarm_v2_trace",
204
  "supporting_edges": [],
205
  "subagent_outputs": ["path_agent: no replayable edge"],
206
  "orchestrator": {
207
  "spawn_count": 1,
 
214
 
215
  traced_edges = traced_edges[:2]
216
  spawn_count = min(swarm_cfg.max_agents, max(1, len(traced_edges) + 1))
217
  return {
218
  "question": emit_swarm_v2_question(traced_edges),
219
  "answer": select_swarm_v2_answer(traced_edges),
220
  "task_type": f"swarm_v2_{len(traced_edges)}hop_trace",
221
  "supporting_edges": [_edge_payload(edge) for edge in traced_edges],
 
  "subagent_outputs": [
223
+ f"path_agent_{idx}: {edge.src} --{edge.rel}--> {edge.dst}"
224
  for idx, edge in enumerate(traced_edges)
225
+ ]
226
+ + [
227
+ "question_agent: emitted compact relation-path question",
228
+ "context_agent: kept shared context focused on replayable edges",
229
  ],
230
  "orchestrator": {
231
  "spawn_count": spawn_count,
 
266
  swarm_cfg: SwarmV2SwarmConfig,
267
  ) -> str:
268
  del swarm_cfg # kept for signature compatibility
269
+ compact_context = _compact_shared_context(shared_context)
 
  return (
271
  "You answer one OSINT graph question using ONLY the shared context.\n"
272
  "Output rules:\n"
 
309
  ) -> list[dict[str, Any]]:
310
  rows: list[dict[str, Any]] = []
311
  for task in tasks:
312
+ shared_context = _task_shared_context(env=env, task=task, cfg=cfg)
313
 
314
  rows.append(
315
  {
 
411
  anchors = "\n".join(f"- {question}" for question in anchor_questions)
412
  canonical_mode = str(canonical_graph_mode).strip().lower() or "generate"
413
  example_payload = _canonical_example_payload(graph, canonical_candidate, swarm_cfg)
414
+ worker_packets = _swarm_v2_worker_packets(
415
+ canonical_candidate=canonical_candidate,
416
+ shared_context=shared_context,
417
+ swarm_cfg=swarm_cfg,
418
+ )
419
  canonical_instruction = (
420
  "You may propose canonical_graph updates when they improve replayability and keep it graph-grounded."
421
  if canonical_mode == "generate"
422
  else "Reuse the provided canonical candidate as-is; do not add, remove, or modify canonical_graph nodes/edges."
423
  )
424
+ canonical_compact = _compact_shared_context(canonical_candidate)
425
  return (
426
+ "You coordinate a compact multi-agent OSINT task-generation swarm.\n"
427
  "Output rules:\n"
428
  "- Return ONLY one JSON object. No markdown. No prose. End with }.\n"
429
+ "- Required keys: question, answer, task_type, supporting_edges, subagent_outputs, orchestrator.\n"
430
+ "- Optional keys: canonical_graph, validation.\n"
431
  "- supporting_edges: non-empty list of {src, rel, dst, confidence}, taken from canonical edges.\n"
432
+ "- supporting_edges must form one contiguous replayable path. Keep it compact.\n"
433
+ "- Do NOT emit verbose tool traces or neighbor dumps; replay tools are derived from supporting_edges.\n"
434
+ "- answer = final dst of the trace. question describes the path without leaking the answer.\n"
435
+ "- subagent_outputs: 2-4 terse strings summarizing path_agent/question_agent/context_agent work.\n"
436
  "- orchestrator: integer keys spawn_count, finished_subtasks, critical_steps, breadth, depth.\n"
437
  f"- canonical_graph_mode={canonical_mode}: {canonical_instruction}\n"
438
+ "- Favor minimal shared context per worker so question generation stays parallel-friendly.\n"
439
  "Example (copy schema, not values):\n"
440
  f"{json.dumps(example_payload, separators=(',', ':'), sort_keys=True)}\n"
441
+ "Worker packets:\n"
442
+ f"{json.dumps(worker_packets, separators=(',', ':'), sort_keys=True)}\n"
443
  "Canonical candidate (use these edges):\n"
444
  f"{json.dumps(canonical_compact, separators=(',', ':'), sort_keys=True)}\n"
445
+ "Shared context:\n"
446
+ f"{json.dumps(_compact_shared_context(shared_context), separators=(',', ':'), sort_keys=True)}\n"
447
  f"Avoid these prior questions: {anchors}\n"
448
  "JSON:"
449
  )
 
491
  "prompt": prompt,
492
  "candidate_id": f"candidate_{idx}",
493
  "canonical_graph_json": json.dumps(canonical_candidate, sort_keys=True),
494
+ "shared_context_json": json.dumps(shared_context, sort_keys=True),
495
+ "worker_packets_json": json.dumps(
496
+ _swarm_v2_worker_packets(
497
+ canonical_candidate=canonical_candidate,
498
+ shared_context=shared_context,
499
+ swarm_cfg=cfg.swarm_v2.generator_swarm,
500
+ ),
501
+ sort_keys=True,
502
+ ),
503
  }
504
  )
505
  canonical_candidates.append(canonical_candidate)
 
532
  "scale_rewards": str(phase.scale_rewards),
533
  "logging_steps": int(phase.logging_steps),
534
  "save_steps": int(phase.save_steps),
535
+ "save_total_limit": int(phase.save_total_limit),
536
+ "optim": str(phase.optim),
537
+ "bf16": bool(phase.bf16),
538
+ "tf32": bool(phase.tf32),
539
+ "gradient_checkpointing": bool(phase.gradient_checkpointing),
540
+ "dataloader_num_workers": int(phase.dataloader_num_workers),
541
+ "dataloader_persistent_workers": bool(phase.dataloader_persistent_workers),
542
+ "dataloader_prefetch_factor": int(phase.dataloader_prefetch_factor),
543
+ "generation_batch_size": int(phase.generation_batch_size),
544
+ "max_prompt_length": int(phase.max_prompt_length),
545
  "remove_unused_columns": False,
546
  "use_vllm": bool(phase.use_vllm),
547
  "vllm_mode": str(phase.vllm_mode),
 
650
 
651
  final_dir = output_dir / "final_model"
652
  trainer.save_model(str(final_dir))
653
+ trainer_tokenizer = getattr(trainer, "processing_class", None) or getattr(trainer, "tokenizer", None)
654
+ if trainer_tokenizer is not None and hasattr(trainer_tokenizer, "save_pretrained"):
655
+ trainer_tokenizer.save_pretrained(str(final_dir))
656
+ checkpoint_dirs = [str(path) for path in sorted(output_dir.glob("checkpoint-*")) if path.is_dir()]
657
 
658
  global_step = int(getattr(train_output, "global_step", 0))
659
  training_loss = float(getattr(train_output, "training_loss", 0.0))
660
 
661
+ total_param_count = int(sum(param.numel() for param in trainer.model.parameters()))
662
  result = {
663
  "model_path": str(final_dir),
664
+ "final_model_path": str(final_dir),
665
+ "phase_output_dir": str(output_dir),
666
+ "checkpoint_dirs": checkpoint_dirs,
667
  "global_step": global_step,
668
  "training_loss": training_loss,
669
  "train_rows": len(rows),
670
  "tuning_mode": str(tuning_mode).strip().lower() or "full",
671
+ "is_full_finetune": str(tuning_mode).strip().lower() != "lora",
672
  }
673
 
674
  log_history = list(getattr(getattr(trainer, "state", None), "log_history", []) or [])
 
708
  "grad_norm_max": max(grad_norm_values) if grad_norm_values else 0.0,
709
  "entropy_min": min(entropy_values) if entropy_values else 0.0,
710
  "entropy_max": max(entropy_values) if entropy_values else 0.0,
711
+ "total_param_count": total_param_count,
712
  "trainable_param_count": trainable_param_count,
713
+ "trainable_fraction": (
714
+ float(trainable_param_count / total_param_count)
715
+ if total_param_count > 0
716
+ else 0.0
717
+ ),
718
  "params_with_grad": params_with_grad,
719
  "nonzero_grad_tensors": nonzero_grad_tensors,
720
  "fingerprint_param_count": len(pre_update_fingerprint),
 
836
  round_index: int,
837
  count: int,
838
  max_support_edges: int,
839
+ max_new_tokens: int,
840
  ) -> list[TaskInstance]:
841
  from transformers import AutoModelForCausalLM, AutoTokenizer
842
+ import torch
843
 
844
  if count <= 0:
845
  return []
 
847
  tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
848
  if tokenizer.pad_token is None and tokenizer.eos_token is not None:
849
  tokenizer.pad_token = tokenizer.eos_token
850
+ model_kwargs: dict[str, Any] = {}
851
+ if torch.cuda.is_available():
852
+ model_kwargs["device_map"] = "auto"
853
+ model_kwargs["torch_dtype"] = torch.bfloat16
854
+ model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **model_kwargs)
855
  model.eval()
856
 
857
  device = next(model.parameters()).device
858
  generated: list[TaskInstance] = []
859
 
 
866
  with torch.no_grad():
867
  output = model.generate(
868
  **encoded,
869
+ max_new_tokens=max(64, int(max_new_tokens)),
870
  do_sample=True,
871
  top_p=0.95,
872
  temperature=1.0,
 
960
  path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
961
 
962
 
963
+ def _generate_answerer_completion_texts_with_model(
964
+ model_name_or_path: str,
965
+ prompts: list[str],
966
+ max_new_tokens: int,
967
+ ) -> list[str]:
968
+ from transformers import AutoModelForCausalLM, AutoTokenizer
969
+ import torch
970
+
971
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
972
+ if tokenizer.pad_token is None and tokenizer.eos_token is not None:
973
+ tokenizer.pad_token = tokenizer.eos_token
974
+
975
+ model_kwargs: dict[str, Any] = {}
976
+ if torch.cuda.is_available():
977
+ model_kwargs["device_map"] = "auto"
978
+ model_kwargs["torch_dtype"] = torch.bfloat16
979
+ model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **model_kwargs)
980
+ model.eval()
981
+ device = next(model.parameters()).device
982
+
983
+ completions: list[str] = []
984
+ for prompt in prompts:
985
+ encoded = tokenizer(prompt, return_tensors="pt")
986
+ encoded = {key: value.to(device) for key, value in encoded.items()}
987
+ with torch.no_grad():
988
+ output = model.generate(
989
+ **encoded,
990
+ max_new_tokens=max(16, int(max_new_tokens)),
991
+ do_sample=False,
992
+ pad_token_id=tokenizer.eos_token_id,
993
+ )
994
+ completion_ids = output[0][encoded["input_ids"].shape[1] :]
995
+ completions.append(tokenizer.decode(completion_ids, skip_special_tokens=True))
996
+ return completions
997
+
998
+
999
+ def _top_validation_reasons(validation_reports: list[dict[str, Any]]) -> list[tuple[str, int]]:
1000
+ counts: dict[str, int] = {}
1001
+ for report in validation_reports:
1002
+ validation = report.get("validation", {}) if isinstance(report, dict) else {}
1003
+ reasons = validation.get("reasons", []) if isinstance(validation, dict) else []
1004
+ for reason in reasons:
1005
+ token = str(reason).strip()
1006
+ if not token:
1007
+ continue
1008
+ counts[token] = counts.get(token, 0) + 1
1009
+ return sorted(counts.items(), key=lambda item: (-item[1], item[0]))
1010
+
1011
+
1012
+ def _run_post_training_evaluation(
1013
+ env_config: EnvironmentConfig,
1014
+ training_config: SelfPlayTrainingConfig,
1015
+ generator_model: str,
1016
+ answerer_models: dict[str, str],
1017
+ output_dir: Path,
1018
+ pipeline_mode: str,
1019
+ effective_dry_run: bool,
1020
+ ) -> dict[str, Any]:
1021
+ tasks_path = output_dir / "post_training_eval_generated_tasks.json"
1022
+ validation_path = output_dir / "post_training_eval_validation_reports.json"
1023
+ payload_path = output_dir / "post_training_evaluation.json"
1024
+ payload: dict[str, Any] = {
1025
+ "pipeline_mode": pipeline_mode,
1026
+ "generator_model": generator_model,
1027
+ "answerer_models": dict(answerer_models),
1028
+ "generated_tasks_path": str(tasks_path),
1029
+ "validation_reports_path": str(validation_path),
1030
+ "skipped": False,
1031
+ }
1032
+
1033
+ if effective_dry_run:
1034
+ payload.update({"skipped": True, "reason": "dry_run"})
1035
+ _save_payload(validation_path, [])
1036
+ _save_payload(tasks_path, [])
1037
+ _save_payload(payload_path, payload)
1038
+ payload["path"] = str(payload_path)
1039
+ return payload
1040
+
1041
+ try:
1042
+ env = OSINTEnvironment(env_config, llm=build_llm_client(env_config.llm))
1043
+ rng = random.Random(env_config.seed + 9973)
1044
+ validation_reports: list[dict[str, Any]] = []
1045
+
1046
+ if pipeline_mode == "swarm_v2":
1047
+ generator_rows, prompt_canonical_candidates = _build_swarm_v2_generator_rows(env, training_config, rng)
1048
+ completion_texts = _sample_swarm_v2_completion_texts_with_model(
1049
+ env=env,
1050
+ cfg=training_config,
1051
+ model_name_or_path=generator_model,
1052
+ prompts=[row["prompt"] for row in generator_rows],
1053
+ count=max(1, training_config.post_training_eval_questions * 2),
1054
+ seen_questions=[task.question for task in env.tasks],
1055
+ )
1056
+ generated_tasks, validation_reports, _, _ = _materialize_swarm_v2_completions(
1057
+ env=env,
1058
+ cfg=training_config,
1059
+ completion_texts=completion_texts,
1060
+ round_index=max(1, training_config.rounds) + 1,
1061
+ seen_questions=[task.question for task in env.tasks],
1062
+ prompt_canonical_candidates=prompt_canonical_candidates,
1063
+ )
1064
+ if not generated_tasks:
1065
+ generated_tasks, validation_reports, _, _ = _materialize_swarm_v2_completions(
1066
+ env=env,
1067
+ cfg=training_config,
1068
+ completion_texts=_fallback_swarm_v2_completion_texts(
1069
+ env=env,
1070
+ cfg=training_config,
1071
+ round_index=max(1, training_config.rounds) + 1,
1072
+ rng=rng,
1073
+ ),
1074
+ round_index=max(1, training_config.rounds) + 1,
1075
+ seen_questions=[task.question for task in env.tasks],
1076
+ prompt_canonical_candidates=None,
1077
+ )
1078
+ generated_tasks = generated_tasks[: max(1, training_config.post_training_eval_questions)]
1079
+ answer_rows = _build_swarm_v2_answerer_rows(env, generated_tasks, training_config)
1080
+ reward_fn = AnswererRewardFunction(
1081
+ graph=env.graph,
1082
+ pipeline_mode="swarm_v2",
1083
+ parl_max_parallel_hint=training_config.swarm_v2.answerer_swarm.max_agents,
1084
+ )
1085
+ else:
1086
+ generator_rows = _build_generator_rows(env=env, cfg=training_config, rng=rng)
1087
+ generated_tasks = _sample_generated_tasks_with_model(
1088
+ model_name_or_path=generator_model,
1089
+ prompts=[row["prompt"] for row in generator_rows],
1090
+ round_index=max(1, training_config.rounds) + 1,
1091
+ count=max(1, training_config.post_training_eval_questions),
1092
+ max_support_edges=training_config.max_support_edges,
1093
+ max_new_tokens=training_config.generated_task_max_new_tokens,
1094
+ )
1095
+ if not generated_tasks:
1096
+ generated_tasks = _fallback_generated_tasks(
1097
+ base_tasks=list(env.tasks),
1098
+ round_index=max(1, training_config.rounds) + 1,
1099
+ count=max(1, training_config.post_training_eval_questions),
1100
+ rng=rng,
1101
+ )
1102
+ answer_rows = _build_answerer_rows(generated_tasks)
1103
+ reward_fn = AnswererRewardFunction(graph=env.graph)
1104
+
1105
+ _save_tasks(tasks_path, generated_tasks)
1106
+ _save_payload(validation_path, validation_reports)
1107
+
1108
+ model_evaluations: dict[str, dict[str, Any]] = {}
1109
+ for model_label, answerer_model in answerer_models.items():
1110
+ answerer_completions = _generate_answerer_completion_texts_with_model(
1111
+ model_name_or_path=answerer_model,
1112
+ prompts=[row["prompt"] for row in answer_rows],
1113
+ max_new_tokens=training_config.post_training_eval_answer_max_new_tokens,
1114
+ )
1115
+ rewards = reward_fn(
1116
+ prompts=[row["prompt"] for row in answer_rows],
1117
+ completions=answerer_completions,
1118
+ answer=[row["answer"] for row in answer_rows],
1119
+ question=[row["question"] for row in answer_rows],
1120
+ supporting_edges_json=[row["supporting_edges_json"] for row in answer_rows],
1121
+ difficulty=[row["difficulty"] for row in answer_rows],
1122
+ )
1123
+
1124
+ episodes: list[dict[str, Any]] = []
1125
+ for task, row, completion_text, reward in zip(generated_tasks, answer_rows, answerer_completions, rewards):
1126
+ support_edges = AnswererRewardFunction._parse_support_edges(row["supporting_edges_json"])
1127
+ pred_edges = AnswererRewardFunction._extract_predicted_edges(completion_text, support_edges)
1128
+ predicted_answer = normalize_answer(extract_answer_from_completion(completion_text))
1129
+ target_answer = normalize_answer(task.answer)
1130
+ graph_f1 = compute_graph_f1(pred_edges, support_edges)
1131
+ episodes.append(
1132
+ {
1133
+ "task_id": task.task_id,
1134
+ "task_type": task.task_type,
1135
+ "question": task.question,
1136
+ "task_answer": target_answer,
1137
+ "agent_answer": predicted_answer,
1138
+ "reward": float(reward),
1139
+ "graph_f1": float(graph_f1),
1140
+ "success": int(predicted_answer == target_answer),
1141
+ "support_edge_count": len(support_edges),
1142
+ "predicted_edge_count": len(pred_edges),
1143
+ "completion_length": len(completion_text),
1144
+ }
1145
+ )
1146
+
1147
+ episode_count = len(episodes)
1148
+ model_evaluations[model_label] = {
1149
+ "model_path": answerer_model,
1150
+ "episodes": episodes,
1151
+ "summary": {
1152
+ "episodes": episode_count,
1153
+ "task_success_rate": (
1154
+ float(sum(row["success"] for row in episodes) / max(1, episode_count))
1155
+ if episodes
1156
+ else 0.0
1157
+ ),
1158
+ "avg_reward": (
1159
+ float(sum(float(row["reward"]) for row in episodes) / max(1, episode_count))
1160
+ if episodes
1161
+ else 0.0
1162
+ ),
1163
+ "avg_graph_f1": (
1164
+ float(sum(float(row["graph_f1"]) for row in episodes) / max(1, episode_count))
1165
+ if episodes
1166
+ else 0.0
1167
+ ),
1168
+ "avg_completion_length": (
1169
+ float(sum(int(row["completion_length"]) for row in episodes) / max(1, episode_count))
1170
+ if episodes
1171
+ else 0.0
1172
+ ),
1173
+ },
1174
+ }
1175
+
1176
+ final_summary = model_evaluations.get("finetuned_answerer", {}).get("summary", {})
1177
+ baseline_summary = model_evaluations.get("original_answerer", {}).get("summary", {})
1178
+ summary = {
1179
+ "generated_task_count": len(generated_tasks),
1180
+ "generator_valid_rate": (
1181
+ float(len(generated_tasks) / max(1, len(validation_reports)))
1182
+ if validation_reports
1183
+ else 1.0
1184
+ ),
1185
+ "compared_models": sorted(model_evaluations.keys()),
1186
+ "finetuned_answerer": dict(final_summary),
1187
+ "original_answerer": dict(baseline_summary),
1188
+ "delta_vs_original": {
1189
+ "task_success_rate": float(final_summary.get("task_success_rate", 0.0) - baseline_summary.get("task_success_rate", 0.0)),
1190
+ "avg_reward": float(final_summary.get("avg_reward", 0.0) - baseline_summary.get("avg_reward", 0.0)),
1191
+ "avg_graph_f1": float(final_summary.get("avg_graph_f1", 0.0) - baseline_summary.get("avg_graph_f1", 0.0)),
1192
+ },
1193
+ "top_generator_invalid_reasons": _top_validation_reasons(validation_reports)[:5],
1194
+ }
1195
+ payload.update(
1196
+ {
1197
+ "summary": summary,
1198
+ "model_evaluations": model_evaluations,
1199
+ }
1200
+ )
1201
+ except Exception as exc:
1202
+ payload.update({"skipped": True, "reason": f"{type(exc).__name__}: {exc}"})
1203
+
1204
+ if not tasks_path.exists():
1205
+ _save_payload(tasks_path, [])
1206
+ if not validation_path.exists():
1207
+ _save_payload(validation_path, [])
1208
+ _save_payload(payload_path, payload)
1209
+ payload["path"] = str(payload_path)
1210
+ return payload
1211
+
1212
+
1213
  def _fallback_swarm_v2_completion_texts(
1214
  env: OSINTEnvironment,
1215
  cfg: SelfPlayTrainingConfig,
 
1283
  seen_questions: list[str],
1284
  ) -> list[str]:
1285
  from transformers import AutoModelForCausalLM, AutoTokenizer
1286
+ import torch
1287
 
1288
  if count <= 0:
1289
  return []
 
1291
  tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
1292
  if tokenizer.pad_token is None and tokenizer.eos_token is not None:
1293
  tokenizer.pad_token = tokenizer.eos_token
1294
+ model_kwargs: dict[str, Any] = {}
1295
+ if torch.cuda.is_available():
1296
+ model_kwargs["device_map"] = "auto"
1297
+ model_kwargs["torch_dtype"] = torch.bfloat16
1298
+ model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **model_kwargs)
1299
  model.eval()
1300
 
1301
  device = next(model.parameters()).device
1302
  completions: list[str] = []
1303
  validator = SwarmV2ReplayValidator(
 
1318
  with torch.no_grad():
1319
  output = model.generate(
1320
  **encoded,
1321
+ max_new_tokens=max(64, int(cfg.generated_task_max_new_tokens)),
1322
  do_sample=True,
1323
  top_p=top_p,
1324
  temperature=temperature,
 
1375
  max_support_edges=cfg.swarm_v2.validation.max_support_edges,
1376
  )
1377
  validation = validator.validate(candidate)
1378
+ replay_edges = list(validation.replayed_edges or candidate.supporting_edges)
1379
+ materialized_tool_trace = _serialize_tool_trace(candidate.tool_trace)
1380
+ if not materialized_tool_trace and replay_edges:
1381
+ materialized_tool_trace = build_swarm_v2_tool_trace(env.graph, replay_edges)
1382
 
1383
  if use_fixed_canonical and prompt_canonical_candidates and completion_idx < len(prompt_canonical_candidates):
1384
  canonical_graph = dict(prompt_canonical_candidates[completion_idx])
 
1421
  {
1422
  "candidate_index": completion_idx,
1423
  "question": candidate.question,
1424
+ "tool_trace": materialized_tool_trace,
 
1425
  "replayed_edges": validation.to_dict()["replayed_edges"],
1426
  }
1427
  )
 
1447
  "difficulty": "hard",
1448
  "scenario": "swarm_v2_trace",
1449
  "canonical_graph": canonical_graph,
1450
+ "tool_trace": materialized_tool_trace,
 
1451
  "subagent_outputs": list(candidate.subagent_outputs),
1452
  "validation": validation.to_dict(),
1453
  "shared_context_budget": {
 
1468
  task_type=candidate.task_type or "swarm_v2_trace",
1469
  question=candidate.question,
1470
  answer=candidate.answer,
1471
+ supporting_edges=replay_edges,
1472
  metadata=metadata,
1473
  )
1474
  )
 
1494
  seed_tasks = list(env.tasks)
1495
  seed_questions = [task.question for task in seed_tasks]
1496
  generator_model, answerer_model = _resolve_initial_models(training_config)
1497
+ initial_generator_model = str(generator_model)
1498
+ initial_answerer_model = str(answerer_model)
1499
  rng = random.Random(env_config.seed)
1500
 
1501
  bootstrap_completions = _fallback_swarm_v2_completion_texts(
 
@@ ... @@
         }
     )

+    post_training_evaluation = _run_post_training_evaluation(
+        env_config=env_config,
+        training_config=training_config,
+        generator_model=generator_model,
+        answerer_models={
+            "finetuned_answerer": answerer_model,
+            "original_answerer": initial_answerer_model,
+        },
+        output_dir=run_dir,
+        pipeline_mode="swarm_v2",
+        effective_dry_run=effective_dry_run,
+    )
     final_payload = {
         "dry_run": effective_dry_run,
         "pipeline_mode": "swarm_v2",
@@ ... @@
             "generator": generator_model,
             "answerer": answerer_model,
         },
+        "initial_models": {
+            "generator": initial_generator_model,
+            "answerer": initial_answerer_model,
+        },
+        "post_training_evaluation": post_training_evaluation,
         "kimi_objective_mapping": {
             "grouped_rollouts": "TRL GRPO num_generations",
             "mean_centered_advantage": "GRPO relative reward baseline",
 
@@ ... @@
     seed_tasks = list(env.tasks)

     generator_model, answerer_model = _resolve_initial_models(training_config)
+    initial_generator_model = str(generator_model)
+    initial_answerer_model = str(answerer_model)

     rng = random.Random(env_config.seed)
     rounds_payload: list[dict[str, Any]] = []
@@ ... @@
         round_index=round_index,
         count=training_config.generated_tasks_per_round,
         max_support_edges=training_config.max_support_edges,
+        max_new_tokens=training_config.generated_task_max_new_tokens,
     )
     if not generated_tasks:
         generated_tasks = _fallback_generated_tasks(
@@ ... @@
         }
     )

+    post_training_evaluation = _run_post_training_evaluation(
+        env_config=env_config,
+        training_config=training_config,
+        generator_model=generator_model,
+        answerer_models={
+            "finetuned_answerer": answerer_model,
+            "original_answerer": initial_answerer_model,
+        },
+        output_dir=run_dir,
+        pipeline_mode="legacy",
+        effective_dry_run=effective_dry_run,
+    )
     final_payload = {
         "dry_run": effective_dry_run,
         "pipeline_mode": "legacy",
@@ ... @@
             "generator": generator_model,
             "answerer": answerer_model,
         },
+        "initial_models": {
+            "generator": initial_generator_model,
+            "answerer": initial_answerer_model,
+        },
+        "post_training_evaluation": post_training_evaluation,
         "kimi_objective_mapping": {
             "grouped_rollouts": "TRL GRPO num_generations",
             "mean_centered_advantage": "GRPO relative reward baseline",
tests/test_environment.py CHANGED
@@ -30,6 +30,22 @@ def test_search_memory_tool_returns_results_after_tool_use():
     assert obs.tool_outputs[-1]["output"]["count"] >= 1


+def test_search_shared_context_returns_task_local_hits():
+    env = OSINTEnvironment(EnvironmentConfig(max_steps=6, seed=7))
+    obs = env.reset()
+    assert obs.task["shared_context_available"] is True
+    answer = str(env.state.task.answer if env.state else "")
+
+    obs, reward, done, _ = env.step(
+        Action(ActionType.CALL_TOOL, {"tool_name": "search_shared_context", "args": {"query": answer, "k": 5}})
+    )
+    assert done is False
+    assert isinstance(reward, float)
+    assert obs.tool_outputs[-1]["tool"] == "search_shared_context"
+    assert obs.tool_outputs[-1]["output"]["shared_context_available"] is True
+    assert obs.tool_outputs[-1]["output"]["count"] >= 1
+
+
 def test_invalid_tool_call_does_not_crash_episode():
     env = OSINTEnvironment(EnvironmentConfig(max_steps=4, seed=8))
     env.reset()
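The new test exercises `search_shared_context` through the same generic `CALL_TOOL` action path as the existing `search_memory` tool. A minimal interactive sketch (assuming `Action` and `ActionType` are importable from `osint_env.domain.models`, which this diff does not confirm):

```python
from osint_env.domain.models import Action, ActionType, EnvironmentConfig  # import path assumed
from osint_env.env.environment import OSINTEnvironment

env = OSINTEnvironment(EnvironmentConfig(max_steps=6, seed=7))
obs = env.reset()

# Query the task-local shared-context index, asking for the top 5 hits.
obs, reward, done, info = env.step(
    Action(ActionType.CALL_TOOL, {"tool_name": "search_shared_context", "args": {"query": "example", "k": 5}})
)
print(obs.tool_outputs[-1]["output"]["count"])  # number of hits returned
```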
tests/test_hf_jobs.py ADDED
@@ -0,0 +1,64 @@
+from osint_env.training.hf_jobs import (
+    DEFAULT_HF_JOB_IMAGE,
+    _build_job_command,
+    _default_train_output_dir,
+    _resolve_job_image,
+)
+
+
+def test_resolve_job_image_prefers_explicit_image():
+    assert _resolve_job_image("python:3.12", "owner/space") == "python:3.12"
+
+
+def test_resolve_job_image_supports_space_fallback():
+    assert _resolve_job_image("", "owner/space") == "hf.co/spaces/owner/space"
+    assert _resolve_job_image("", "") == DEFAULT_HF_JOB_IMAGE
+
+
+def test_default_train_output_dir_uses_bucket_mount_when_present():
+    assert _default_train_output_dir("my-bucket", "run-42") == "/training-outputs/run-42"
+    assert _default_train_output_dir("", "run-42") == "artifacts/run-42"
+
+
+def test_build_job_command_runs_train_directly_when_image_has_code():
+    command = _build_job_command(
+        env_config_path="config/shared_config.json",
+        train_config_path="config/train.json",
+        output_dir="artifacts/self_play",
+        dry_run=False,
+        repo_url="",
+        repo_ref="",
+        repo_subdir="",
+        setup_command="",
+    )
+    assert command == [
+        "osint-env",
+        "train-self-play",
+        "--config",
+        "config/shared_config.json",
+        "--train-config",
+        "config/train.json",
+        "--train-output-dir",
+        "artifacts/self_play",
+    ]
+
+
+def test_build_job_command_bootstraps_repo_when_requested():
+    command = _build_job_command(
+        env_config_path="config/shared_config.json",
+        train_config_path="config/train.json",
+        output_dir="/training-outputs/run-1",
+        dry_run=True,
+        repo_url="https://github.com/example/osint-env.git",
+        repo_ref="main",
+        repo_subdir=".",
+        setup_command="python -m pip install flash-attn --no-build-isolation",
+    )
+    assert command[:2] == ["bash", "-lc"]
+    script = command[2]
+    assert "git clone --depth 1 --branch main https://github.com/example/osint-env.git /workspace/osint_env_app" in script
+    assert "python -m pip install -e '.[train]'" in script
+    assert "python -m pip install flash-attn --no-build-isolation" in script
+    assert "--train-config config/train.json" in script
+    assert "--train-output-dir /training-outputs/run-1" in script
+    assert "--dry-run" in script
tests/test_openai_baseline.py CHANGED
@@ -7,6 +7,7 @@ def test_openai_baseline_toolset_contains_answer_and_graph_actions():
     assert "submit_answer" in names
     assert "add_edge" in names
     assert "search_memory" in names
+    assert "search_shared_context" in names
     assert "get_post" in names


tests/test_self_play_swarm_v2.py CHANGED
@@ -195,8 +195,8 @@ def test_swarm_v2_replay_validator_accepts_valid_candidate_and_rejects_invalid_c
     no_trace_payload = deepcopy(payload)
     no_trace_payload["tool_trace"] = []
     no_trace = validator.validate(parse_generated_task_completion(json.dumps(no_trace_payload)))
-    assert no_trace.is_valid is False
-    assert "non_replayable_tool_calls" in no_trace.reasons
+    assert no_trace.is_valid is True
+    assert no_trace.replayed_edges

     unseen_payload = deepcopy(payload)
     unseen_payload["supporting_edges"][0]["dst"] = "user_missing"
@@ -205,6 +205,22 @@ def test_swarm_v2_replay_validator_accepts_valid_candidate_and_rejects_invalid_c
     assert "unseen_nodes_or_edges" in unseen.reasons


+def test_swarm_v2_replay_validator_can_derive_tool_trace_from_support_edges():
+    cfg = SelfPlayTrainingConfig(pipeline_mode="swarm_v2")
+    env = OSINTEnvironment(EnvironmentConfig(seed=27, n_users=18, max_steps=6))
+    payload = _build_valid_candidate_payload(env, cfg)
+    payload.pop("tool_trace", None)
+
+    validator = SwarmV2ReplayValidator(
+        graph=env.graph,
+        validation=cfg.swarm_v2.validation,
+        shared_context=cfg.swarm_v2.shared_context,
+        seen_questions=[],
+    )
+    result = validator.validate(parse_generated_task_completion(json.dumps(payload)))
+    assert result.is_valid is True
+
+
 def test_swarm_v2_replay_validator_rejects_non_unique_paths():
     graph = CanonicalGraph(
         nodes={
@@ -337,7 +353,9 @@ def test_swarm_v2_generator_reward_grades_invalid_outputs_instead_of_constant_pe
     scores = reward_fn(completions=[missing_everything, partial_json, partial_edges, json.dumps(valid_payload)])

     assert len(set(scores)) > 2
-    assert scores[0] < scores[1] < scores[2] < scores[3]
+    assert scores[2] > scores[0]
+    assert scores[2] > scores[1]
+    assert scores[3] != scores[0]
     assert reward_fn._debug_last_batch["batch_reward_std"] > 0.0
     assert reward_fn._debug_last_batch["valid_output_ratio"] == 0.25

@@ -373,6 +391,24 @@ def test_parse_generated_task_completion_handles_garbage_orchestrator_values():
     assert candidate.orchestrator.depth == 0


+def test_parse_generated_task_completion_accepts_result_alias_in_tool_trace():
+    cfg = SelfPlayTrainingConfig(pipeline_mode="swarm_v2")
+    env = OSINTEnvironment(EnvironmentConfig(seed=35, n_users=18, max_steps=6))
+    payload = _build_valid_candidate_payload(env, cfg)
+    payload["tool_trace"] = [
+        {
+            "tool": call["tool_name"],
+            "args": dict(call["args"]),
+            "result": dict(call["output"]),
+        }
+        for call in payload["tool_trace"]
+    ]
+
+    candidate = parse_generated_task_completion(json.dumps(payload))
+    assert candidate.tool_trace
+    assert all(call.output for call in candidate.tool_trace)
+
+
 def test_swarm_v2_generator_reward_is_robust_to_parse_crashes():
     """Reward function must never raise: any malformed completion gets a floor reward."""
     cfg = SelfPlayTrainingConfig(pipeline_mode="swarm_v2")
@@ -432,6 +468,11 @@ def test_swarm_v2_dry_run_writes_new_artifacts_and_preserves_legacy_contract(tmp
         loaded = json.loads(Path(artifacts[key]).read_text(encoding="utf-8"))
         assert loaded is not None

+    post_eval = payload["post_training_evaluation"]
+    assert Path(post_eval["path"]).exists()
+    assert sorted(post_eval["answerer_models"].keys()) == ["finetuned_answerer", "original_answerer"]
+    assert json.loads(Path(post_eval["path"]).read_text(encoding="utf-8"))["skipped"] is True
+

 def test_swarm_v2_fixed_canonical_mode_reuses_prompt_candidates(tmp_path: Path):
     env_cfg = EnvironmentConfig(seed=19, n_users=14, max_steps=6)
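Taken together, these tests relax and extend the candidate contract: a missing or empty `tool_trace` no longer fails with `non_replayable_tool_calls` (the validator derives a replayable trace from `supporting_edges`), and trace entries may spell the output field as `result`. A hedged shape sketch; field names beyond `dst`, `tool`, `args`, and `result` are assumptions:

```python
# Illustrative candidate payload; only the field names exercised by the tests
# are confirmed, and the example values are placeholders.
candidate_payload = {
    "question": "Which account bridges the two clusters?",
    "answer": "user_b",
    # "dst" is confirmed by the tests; "src" is an assumed companion field.
    "supporting_edges": [{"src": "user_a", "dst": "user_b"}],
    # Previously this empty trace was rejected; now it is replayed from the edges.
    "tool_trace": [],
}

# A trace entry may alias "output" as "result" and still parse:
trace_entry = {"tool": "get_post", "args": {"post_id": "p1"}, "result": {"text": "..."}}
```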
tests/test_swarm_agent.py CHANGED
@@ -1,3 +1,4 @@
+from osint_env.llm.interface import LLMResponse
 from osint_env.agents.swarm_agent import SwarmAgentRunner
 from osint_env.domain.models import EnvironmentConfig, SwarmConfig
 from osint_env.env.environment import OSINTEnvironment
@@ -15,3 +16,27 @@ def test_swarm_runner_emits_spawn_telemetry():
     assert info["spawn_count"] > 0
     assert "spawn_auxiliary" in info["reward_components"]
     assert info["spawn_critical_steps"] > 0
+
+
+class RecordingLLM:
+    def __init__(self):
+        self.tool_names: list[str] = []
+
+    def generate(self, messages, tools):
+        del messages
+        self.tool_names = [tool["function"]["name"] for tool in tools]
+        return LLMResponse(content="{}", tool_calls=[])
+
+
+def test_swarm_runner_passes_lookup_tools_to_llm():
+    config = EnvironmentConfig(
+        seed=16,
+        max_steps=6,
+        swarm=SwarmConfig(enabled=True, max_agents=2, max_breadth=2, max_width=2, max_depth=1, planner_rounds=1),
+    )
+    env = OSINTEnvironment(config)
+    llm = RecordingLLM()
+    SwarmAgentRunner(env, llm=llm).run_episode()
+
+    assert "search_memory" in llm.tool_names
+    assert "search_shared_context" in llm.tool_names
tests/test_training_config.py CHANGED
@@ -14,6 +14,13 @@ def test_self_play_config_defaults_when_missing():
     assert cfg.generator_phase.max_steps >= 1
     assert cfg.answerer_phase.max_steps >= 1
     assert cfg.generator_reward_weights.hardness > 0.0
+    assert cfg.generated_task_max_new_tokens >= 32
+    assert cfg.post_training_eval_questions >= 1
+    assert cfg.generator_phase.optim == "adamw_torch_fused"
+    assert cfg.generator_phase.bf16 is True
+    assert cfg.generator_phase.tf32 is True
+    assert cfg.generator_phase.generation_batch_size >= 1
+    assert cfg.generator_phase.max_prompt_length >= 32
     assert cfg.swarm_v2.generator_swarm.shared_context is True
     assert cfg.swarm_v2.validation.max_support_edges >= 1
     assert cfg.wandb_enabled is False
@@ -41,6 +48,9 @@
         "shared_model_name_or_path": "/models/local-base",
         "seed_tasks_per_round": 12,
         "generated_tasks_per_round": 18,
+        "generated_task_max_new_tokens": 640,
+        "post_training_eval_questions": 9,
+        "post_training_eval_answer_max_new_tokens": 96,
         "swarm_v2": {
             "generator_swarm": {
                 "shared_context": True,
@@ -90,6 +100,12 @@
             "model_name_or_path": "Qwen/Qwen2.5-3B-Instruct",
             "max_steps": 77,
             "num_generations": 6,
+            "optim": "adamw_torch",
+            "bf16": False,
+            "tf32": False,
+            "generation_batch_size": 12,
+            "max_prompt_length": 768,
+            "save_total_limit": 3,
             "loss_type": "grpo",
             "scale_rewards": "group",
             "output_subdir": "gen_phase",
@@ -98,6 +114,9 @@
             "model_name_or_path": "Qwen/Qwen2.5-1.5B-Instruct",
             "max_steps": 55,
             "num_generations": 5,
+            "dataloader_num_workers": 6,
+            "dataloader_persistent_workers": False,
+            "dataloader_prefetch_factor": 6,
             "output_subdir": "ans_phase",
         },
     }
@@ -121,6 +140,9 @@
     assert cfg.shared_model_name_or_path == "/models/local-base"
     assert cfg.seed_tasks_per_round == 12
     assert cfg.generated_tasks_per_round == 18
+    assert cfg.generated_task_max_new_tokens == 640
+    assert cfg.post_training_eval_questions == 9
+    assert cfg.post_training_eval_answer_max_new_tokens == 96
     assert cfg.swarm_v2.generator_swarm.max_agents == 5
     assert cfg.swarm_v2.answerer_swarm.max_agents == 4
     assert cfg.swarm_v2.validation.max_support_edges == 6
@@ -133,6 +155,12 @@
     assert cfg.generator_phase.model_name_or_path == "Qwen/Qwen2.5-3B-Instruct"
     assert cfg.generator_phase.max_steps == 77
     assert cfg.generator_phase.num_generations == 6
+    assert cfg.generator_phase.optim == "adamw_torch"
+    assert cfg.generator_phase.bf16 is False
+    assert cfg.generator_phase.tf32 is False
+    assert cfg.generator_phase.generation_batch_size == 12
+    assert cfg.generator_phase.max_prompt_length == 768
+    assert cfg.generator_phase.save_total_limit == 3
     assert cfg.generator_phase.loss_type == "grpo"
     assert cfg.generator_phase.scale_rewards == "group"
     assert cfg.generator_phase.output_subdir == "gen_phase"
@@ -140,6 +168,9 @@
     assert cfg.answerer_phase.model_name_or_path == "Qwen/Qwen2.5-1.5B-Instruct"
     assert cfg.answerer_phase.max_steps == 55
     assert cfg.answerer_phase.num_generations == 5
+    assert cfg.answerer_phase.dataloader_num_workers == 6
+    assert cfg.answerer_phase.dataloader_persistent_workers is False
+    assert cfg.answerer_phase.dataloader_prefetch_factor == 6
     assert cfg.answerer_phase.output_subdir == "ans_phase"


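Read together, the default and override tests imply a training config can now carry the following keys. A sketch of the relevant slice as a Python dict mirroring the JSON config files the CLI consumes; values are copied from the test fixture and are exercise values, not recommended defaults:

```python
# Slice of a self-play training config exercising the new knobs.
training_overrides = {
    "generated_task_max_new_tokens": 640,
    "post_training_eval_questions": 9,
    "post_training_eval_answer_max_new_tokens": 96,
    "generator_phase": {
        "optim": "adamw_torch",           # default is "adamw_torch_fused"
        "bf16": False,                    # default is True
        "tf32": False,                    # default is True
        "generation_batch_size": 12,
        "max_prompt_length": 768,
        "save_total_limit": 3,
    },
    "answerer_phase": {
        "dataloader_num_workers": 6,
        "dataloader_persistent_workers": False,
        "dataloader_prefetch_factor": 6,
    },
}
```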