siddeshwar-kagatikar committed
Commit · fe1f842
Parent(s): d814291
Sync current main to Hugging Face Space
- README.md +51 -18
- config/self_play_training_example.json +30 -7
- config/self_play_training_hf_a10g_smoke.json +27 -4
- docs/adversarial_self_play.md +57 -1
- pyproject.toml +2 -0
- scripts/space_start.sh +30 -7
- src/osint_env/agents/single_agent.py +35 -2
- src/osint_env/agents/swarm_agent.py +21 -1
- src/osint_env/baselines/openai_runner.py +1 -117
- src/osint_env/env/environment.py +71 -0
- src/osint_env/platforms/tool_schemas.py +132 -0
- src/osint_env/training/__init__.py +2 -0
- src/osint_env/training/config.py +83 -11
- src/osint_env/training/hf_jobs.py +331 -0
- src/osint_env/training/rewards.py +65 -34
- src/osint_env/training/self_play.py +459 -58
- tests/test_environment.py +16 -0
- tests/test_hf_jobs.py +64 -0
- tests/test_openai_baseline.py +1 -0
- tests/test_self_play_swarm_v2.py +44 -3
- tests/test_swarm_agent.py +25 -0
- tests/test_training_config.py +31 -0
README.md
CHANGED
@@ -174,9 +174,27 @@ osint-env train-self-play --config config/shared_config.json --train-config conf
 
 When you have compute and the train dependencies installed, remove `--dry-run` (or set `"dry_run": false` in the training config) to execute TRL GRPO updates for alternating generator and answerer phases.
 
+For a standalone Linux server or SSH box, there is also a wrapper script that activates a venv, optionally installs train deps, and runs the same command:
+
+```bash
+VENV_PATH="$HOME/arl" \
+INSTALL_TRAIN_DEPS=1 \
+TRAIN_ENV_CONFIG_PATH="config/shared_config.json" \
+TRAIN_SELF_PLAY_CONFIG_PATH="config/self_play_training_hf_a10g_smoke.json" \
+TRAIN_SELF_PLAY_OUTPUT_DIR="artifacts/self_play_server" \
+bash scripts/train_self_play_standalone.sh
+```
+
+Useful overrides for the standalone script:
+
+- `BOOTSTRAP_VENV=1` to create the virtualenv if it does not exist
+- `TRAIN_SELF_PLAY_ROUNDS=2` to override the number of rounds
+- `RUN_SELF_PLAY_DRY_RUN=1` to materialize artifacts without GRPO updates
+- `TRAIN_SETUP_COMMAND='python -m pip install flash-attn --no-build-isolation'` for host-specific extras
+
 The training config also supports `"model_topology": "dual"|"shared"`, `"phase_schedule": "generator_answerer"|"answerer_generator_answerer"`, `"tuning_mode": "full"|"lora"`, and `"canonical_graph_mode": "generate"|"fixed"` so you can switch between two-model vs single-model self-play, full fine-tuning vs LoRA adapters, and whether canonical graph structure is generated each round or kept fixed while training question/answer behavior.
 
-### Hugging Face
+### Hugging Face Job A10G Run (Separate From The Space)
 
 For a short verification run (enough to confirm W&B logging before scaling up), use:

@@ -186,27 +204,42 @@ osint-env train-self-play --config config/shared_config.json --train-config conf
 
 This config:
 
+- uses `Qwen/Qwen2.5-0.5B-Instruct`
 - enables W&B reporting (`wandb_enabled: true`)
 - uses `pipeline_mode: "swarm_v2"` with `canonical_graph_mode: "fixed"` to keep canonical graph candidates stable while training question/answer behavior
+- keeps training intentionally short (`rounds=2`, `max_steps=50` per phase)
+- uses full fine-tuning plus fused AdamW, bf16/tf32, larger generation batches, and extra dataloader workers to make better use of an A10G
 
 To enable canonical graph generation during swarm_v2 training, switch `"canonical_graph_mode"` to `"generate"` in the training config.
 
+If you want the Space to stay on CPU and train separately on paid GPU compute, launch a dedicated Hugging Face Job instead of training inside the Space:
+
+```bash
+osint-env-launch-hf-job \
+  --hf-token "$HF_TOKEN" \
+  --job-image "pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel" \
+  --repo-url "https://github.com/your-org/meta-knowledge-graph.git" \
+  --repo-ref "main" \
+  --flavor "a10g-small" \
+  --env-config "config/shared_config.json" \
+  --train-config "config/self_play_training_hf_a10g_smoke.json" \
+  --output-bucket "your-hf-bucket" \
+  --wait
+```
+
+The launcher talks to the Hugging Face Jobs API through `huggingface_hub`, so the Space can remain on CPU while the training job runs on separate A10G compute.
+
+Optional Space startup wiring still exists if you want it:
 
+1. Keep the Space on CPU if it is serving inference/UI only.
+2. Set `RUN_SELF_PLAY_TRAINING=1` only if you intentionally want startup-time training inside the Space container.
+3. Optional overrides:
    - `TRAIN_SELF_PLAY_CONFIG_PATH` (default: `config/self_play_training_hf_a10g_smoke.json`)
    - `TRAIN_ENV_CONFIG_PATH` (default: `config/shared_config.json`)
+   - `TRAIN_SELF_PLAY_OUTPUT_DIR` to override where artifacts land
+   - `RUN_SELF_PLAY_DRY_RUN=1` to test startup wiring without GRPO updates
+   - `RUN_SELF_PLAY_BACKGROUND=1` to keep the API up while startup-time training runs
+   - `OSINT_TRAIN_STRICT_ASSERTS=1` to fail fast when reward variance, KL, loss, grad norms, or parameter updates stay zero
 
 W&B run naming is controlled by `wandb_run_name_prefix` and will emit phase-specific runs like `...-r001-generator` and `...-r001-answerer`.

@@ -229,10 +262,10 @@ In `legacy` pipeline mode, the reward is a weighted sum:
 
 Default weights (configurable through `generator_reward_weights` in training config):
 
+- `validity`: `0.45`
+- `hardness`: `0.20`
+- `diversity`: `0.15`
+- `consistency`: `0.20`
 
 In `swarm_v2` pipeline mode, generation uses strict replay/validation first, then a structured reward:
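The legacy-mode weighted sum described above can be sketched as a small helper. The function name and dict layout below are illustrative, not the project's actual API; only the default weights come from the README.

```python
# Illustrative sketch of the legacy-mode generator reward: a weighted sum of
# component scores using the default weights from the README above.
# (Hypothetical helper, not the project's real reward class.)
DEFAULT_WEIGHTS = {"validity": 0.45, "hardness": 0.20, "diversity": 0.15, "consistency": 0.20}

def legacy_generator_reward(scores: dict[str, float], weights: dict[str, float] = DEFAULT_WEIGHTS) -> float:
    # Missing components contribute 0; each score is assumed to lie in [0, 1].
    return sum(weights[name] * scores.get(name, 0.0) for name in weights)

# Example: a valid but only moderately hard, low-diversity task.
reward = legacy_generator_reward({"validity": 1.0, "hardness": 0.5, "diversity": 0.2, "consistency": 1.0})
```

Note that the default weights form a convex combination (they sum to 1.0), so the reward stays in the same [0, 1] range as its components.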
config/self_play_training_example.json
CHANGED
@@ -15,11 +15,14 @@
     "max_graph_context_edges": 100,
     "max_support_edges": 8,
     "answerer_judge_max_new_tokens": 48,
+    "generated_task_max_new_tokens": 512,
+    "post_training_eval_questions": 24,
+    "post_training_eval_answer_max_new_tokens": 128,
     "generator_reward_weights": {
+        "validity": 0.45,
+        "hardness": 0.2,
+        "diversity": 0.15,
+        "consistency": 0.2
     },
     "lora": {
         "r": 16,
@@ -62,12 +65,14 @@
     },
     "generator_phase": {
         "model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct",
+        "learning_rate": 5e-06,
         "max_steps": 64,
         "per_device_train_batch_size": 2,
         "gradient_accumulation_steps": 4,
         "num_generations": 4,
+        "max_completion_length": 384,
+        "max_prompt_length": 1024,
+        "generation_batch_size": 8,
         "temperature": 1.0,
         "top_p": 1.0,
         "beta": 0.01,
@@ -77,18 +82,28 @@
         "scale_rewards": "none",
         "logging_steps": 10,
         "save_steps": 50,
+        "save_total_limit": 2,
+        "optim": "adamw_torch_fused",
+        "bf16": true,
+        "tf32": true,
+        "gradient_checkpointing": false,
+        "dataloader_num_workers": 2,
+        "dataloader_persistent_workers": true,
+        "dataloader_prefetch_factor": 2,
         "output_subdir": "generator",
         "use_vllm": false,
         "vllm_mode": "colocate"
     },
     "answerer_phase": {
         "model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct",
+        "learning_rate": 3e-06,
         "max_steps": 64,
         "per_device_train_batch_size": 2,
         "gradient_accumulation_steps": 4,
         "num_generations": 4,
         "max_completion_length": 192,
+        "max_prompt_length": 1024,
+        "generation_batch_size": 8,
         "temperature": 1.0,
         "top_p": 1.0,
         "beta": 0.01,
@@ -98,6 +113,14 @@
         "scale_rewards": "none",
         "logging_steps": 10,
         "save_steps": 50,
+        "save_total_limit": 2,
+        "optim": "adamw_torch_fused",
+        "bf16": true,
+        "tf32": true,
+        "gradient_checkpointing": false,
+        "dataloader_num_workers": 2,
+        "dataloader_persistent_workers": true,
+        "dataloader_prefetch_factor": 2,
         "output_subdir": "answerer",
         "use_vllm": false,
         "vllm_mode": "colocate"
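A quick sanity check over a training config of this shape can catch mistyped values before a run starts. The checks below (weights summing to 1, effective batch arithmetic) are this document's suggestions, not rules the project itself enforces.

```python
import json

# Minimal sanity checks for a self-play training config shaped like the JSON
# diff above. The embedded snippet uses values from that config.
config_text = """
{
  "generator_reward_weights": {
    "validity": 0.45, "hardness": 0.2, "diversity": 0.15, "consistency": 0.2
  },
  "generator_phase": {
    "max_steps": 64,
    "per_device_train_batch_size": 2,
    "gradient_accumulation_steps": 4,
    "num_generations": 4
  }
}
"""
cfg = json.loads(config_text)

# The default reward weights sum to 1, so a typo shows up here immediately.
weights = cfg["generator_reward_weights"]
assert abs(sum(weights.values()) - 1.0) < 1e-9

# Effective batch per optimizer step on a single device.
phase = cfg["generator_phase"]
effective_batch = phase["per_device_train_batch_size"] * phase["gradient_accumulation_steps"]
```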
config/self_play_training_hf_a10g_smoke.json
CHANGED
@@ -10,7 +10,7 @@
     "canonical_graph_mode": "fixed",
     "model_topology": "shared",
     "phase_schedule": "generator_answerer",
+    "tuning_mode": "full",
     "shared_model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct",
     "seed_tasks_per_round": 16,
     "generated_tasks_per_round": 24,
@@ -19,14 +19,19 @@
     "max_graph_context_edges": 24,
     "max_support_edges": 6,
     "answerer_judge_max_new_tokens": 32,
+    "generated_task_max_new_tokens": 640,
+    "post_training_eval_questions": 24,
+    "post_training_eval_answer_max_new_tokens": 128,
     "generator_phase": {
         "model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct",
+        "learning_rate": 5e-06,
         "max_steps": 50,
         "per_device_train_batch_size": 4,
         "gradient_accumulation_steps": 1,
         "num_generations": 4,
+        "max_completion_length": 384,
+        "max_prompt_length": 768,
+        "generation_batch_size": 16,
         "temperature": 0.9,
         "top_p": 0.95,
         "repetition_penalty": 1.1,
@@ -37,18 +42,28 @@
         "scale_rewards": "group",
         "logging_steps": 1,
         "save_steps": 10,
+        "save_total_limit": 2,
+        "optim": "adamw_torch_fused",
+        "bf16": true,
+        "tf32": true,
+        "gradient_checkpointing": false,
+        "dataloader_num_workers": 4,
+        "dataloader_persistent_workers": true,
+        "dataloader_prefetch_factor": 4,
         "output_subdir": "generator_train",
         "use_vllm": false,
         "vllm_mode": "colocate"
     },
     "answerer_phase": {
         "model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct",
+        "learning_rate": 3e-06,
         "max_steps": 50,
         "per_device_train_batch_size": 4,
         "gradient_accumulation_steps": 1,
         "num_generations": 4,
         "max_completion_length": 256,
+        "max_prompt_length": 768,
+        "generation_batch_size": 16,
         "temperature": 0.7,
         "top_p": 0.95,
         "repetition_penalty": 1.1,
@@ -59,6 +74,14 @@
         "scale_rewards": "group",
         "logging_steps": 1,
         "save_steps": 10,
+        "save_total_limit": 2,
+        "optim": "adamw_torch_fused",
+        "bf16": true,
+        "tf32": true,
+        "gradient_checkpointing": false,
+        "dataloader_num_workers": 4,
+        "dataloader_persistent_workers": true,
+        "dataloader_prefetch_factor": 4,
         "output_subdir": "answerer_train",
         "use_vllm": false,
         "vllm_mode": "colocate"
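Some back-of-envelope arithmetic on the smoke config (`rounds=2`, `max_steps=50` per phase) shows why it qualifies as a short verification run. Exact TRL GRPO batching details can differ; these are just the configured upper bounds.

```python
# Budget arithmetic for the A10G smoke config above (assumed values taken
# from that config; TRL's internal batching may sample differently).
rounds = 2
max_steps = 50                  # optimizer steps per phase per round
per_device_batch = 4
grad_accum = 1
max_completion = 384            # generator-phase completion token cap

samples_per_step = per_device_batch * grad_accum
total_generator_steps = rounds * max_steps
max_tokens_per_step = samples_per_step * max_completion
```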
docs/adversarial_self_play.md
CHANGED
@@ -59,6 +59,14 @@ This directly supports the "train solver, freeze, attack, retrain solver" sequen
 
 Weights are configurable in `generator_reward_weights`.
 
+For `swarm_v2`, the reward now prioritizes:
+
+- Valid, replayable task structure first.
+- Hardness against the frozen answerer second.
+- Diversity and compact multi-agent/shared-context usage after validity.
+
+This avoids the degenerate regime where almost every sample is invalid and the whole batch stays negative.
+
 ### Answerer (existing reward integration)
 
 `AnswererRewardFunction` wraps existing environment reward logic:

@@ -88,12 +96,60 @@
 
 But it skips expensive GRPO updates.
 
+## Post-Training Evaluation
+
+After a non-dry-run training job completes, the runner now writes a post-training evaluation artifact that:
+
+- Uses the finetuned generator to create fresh evaluation questions.
+- Evaluates both the finetuned answerer and the original/base answerer on those generated questions.
+- Reports a `delta_vs_original` summary so you can see whether fine-tuning actually improved task success, reward, and graph F1.
+- Saves the summary and episode rows under `post_training_evaluation.json`.
+
+You can control this flow with:
+
+- `generated_task_max_new_tokens`: decoding budget for generator-side task sampling/eval.
+- `post_training_eval_questions`: how many fresh tasks to evaluate after training.
+- `post_training_eval_answer_max_new_tokens`: answerer decoding budget for the final eval pass.
+
+## Checkpoints And Final Models
+
+Self-play outputs are written under `output_dir` (default `artifacts/self_play`) unless overridden.
+
+Per round and phase you will now find:
+
+- `round_XXX/<phase>/checkpoint-*`: intermediate trainer checkpoints saved every `save_steps`.
+- `round_XXX/<phase>/final_model`: final saved model for that phase, with tokenizer files.
+- `self_play_summary.json`: top-level run summary.
+- `post_training_evaluation.json`: generated-question evaluation written after training.
+
 ## Compute Mode
 
 When compute is available:
 
 1. Install train dependencies: `python -m pip install -e ".[train]"`
 2. Disable dry run (`--dry-run` off and/or `"dry_run": false` in config).
+3. Run `osint-env train-self-play`, or launch a dedicated Hugging Face Job with `osint-env-launch-hf-job` if you want the Space to stay on CPU while training runs on separate GPU compute.
 
 Outputs are written under `artifacts/self_play` unless overridden.
+
+## Standalone Server Script
+
+For an SSH server or other standalone machine, you can use `scripts/train_self_play_standalone.sh`.
+
+Example:
+
+```bash
+VENV_PATH="$HOME/arl" \
+INSTALL_TRAIN_DEPS=1 \
+TRAIN_ENV_CONFIG_PATH="config/shared_config.json" \
+TRAIN_SELF_PLAY_CONFIG_PATH="config/self_play_training_hf_a10g_smoke.json" \
+TRAIN_SELF_PLAY_OUTPUT_DIR="artifacts/self_play_server" \
+bash scripts/train_self_play_standalone.sh
+```
+
+Useful environment variables:
+
+- `BOOTSTRAP_VENV=1`: create the virtualenv automatically if it does not exist yet.
+- `TRAIN_SELF_PLAY_ROUNDS=2`: override the round count without editing JSON.
+- `RUN_SELF_PLAY_DRY_RUN=1`: skip GRPO updates and only materialize artifacts.
+- `TRAIN_SETUP_COMMAND='python -m pip install flash-attn --no-build-isolation'`: run any host-specific setup before training.
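A consumer of the post-training evaluation artifact might read the `delta_vs_original` summary like this. Only the file name and the `delta_vs_original` key come from the docs; the exact nesting and field names in the snippet are assumptions for illustration, and the numbers are made up.

```python
import json

# Sketch: reading a post_training_evaluation.json artifact. The schema shown
# here is hypothetical; check the actual artifact your run produced.
artifact_text = """
{
  "summary": {
    "finetuned": {"success_rate": 0.42},
    "original": {"success_rate": 0.30},
    "delta_vs_original": {"success_rate": 0.12}
  }
}
"""
report = json.loads(artifact_text)
delta = report["summary"]["delta_vs_original"]["success_rate"]
improved = delta > 0  # positive delta means fine-tuning helped on this metric
```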
pyproject.toml
CHANGED
@@ -22,6 +22,7 @@ train = [
     "accelerate>=0.33.0",
     "trl>=0.15.0",
     "peft>=0.11.0",
+    "huggingface_hub>=0.34.0",
     "pillow",
     "torchvision",
     "wandb",
@@ -30,6 +31,7 @@
 [project.scripts]
 osint-env = "osint_env.cli:main"
 server = "osint_env.server_entry:main"
+osint-env-launch-hf-job = "osint_env.training.hf_jobs:main"
 
 [build-system]
 requires = ["setuptools>=68", "wheel"]
scripts/space_start.sh
CHANGED
@@ -10,23 +10,46 @@ _is_true() {
 
 ENV_CONFIG_PATH="${TRAIN_ENV_CONFIG_PATH:-config/shared_config.json}"
 TRAIN_CONFIG_PATH="${TRAIN_SELF_PLAY_CONFIG_PATH:-config/self_play_training_hf_a10g_smoke.json}"
+TRAIN_OUTPUT_DIR="${TRAIN_SELF_PLAY_OUTPUT_DIR:-}"
 RUN_FLAG="${RUN_SELF_PLAY_TRAINING:-0}"
 DRY_RUN_FLAG="${RUN_SELF_PLAY_DRY_RUN:-0}"
+BACKGROUND_FLAG="${RUN_SELF_PLAY_BACKGROUND:-1}"
+
+_train_self_play() {
+    if [ -n "${TRAIN_OUTPUT_DIR}" ]; then
+        OUTPUT_ARG="--train-output-dir ${TRAIN_OUTPUT_DIR}"
+    else
+        OUTPUT_ARG=""
+    fi
 
     if _is_true "$DRY_RUN_FLAG"; then
         echo "[space_start] Running self-play in dry-run mode."
+        # shellcheck disable=SC2086
+        osint-env train-self-play --config "${ENV_CONFIG_PATH}" --train-config "${TRAIN_CONFIG_PATH}" ${OUTPUT_ARG} --dry-run
     else
         echo "[space_start] Running self-play training."
+        # shellcheck disable=SC2086
+        osint-env train-self-play --config "${ENV_CONFIG_PATH}" --train-config "${TRAIN_CONFIG_PATH}" ${OUTPUT_ARG}
     fi
+
     echo "[space_start] Self-play command completed."
     echo "[space_start] Training end: $(date -u +"%Y-%m-%dT%H:%M:%SZ")"
+}
+
+if _is_true "$RUN_FLAG"; then
+    echo "[space_start] RUN_SELF_PLAY_TRAINING enabled."
+    echo "[space_start] Training start: $(date -u +"%Y-%m-%dT%H:%M:%SZ")"
+    echo "[space_start] Env config: ${ENV_CONFIG_PATH}"
+    echo "[space_start] Train config: ${TRAIN_CONFIG_PATH}"
+    if [ -n "${TRAIN_OUTPUT_DIR}" ]; then
+        echo "[space_start] Train output dir: ${TRAIN_OUTPUT_DIR}"
+    fi
+    if _is_true "$BACKGROUND_FLAG"; then
+        echo "[space_start] Launching self-play in background so the Space API can stay online."
+        _train_self_play &
+    else
+        _train_self_play
+    fi
 else
     echo "[space_start] RUN_SELF_PLAY_TRAINING disabled. Skipping self-play run."
 fi
src/osint_env/agents/single_agent.py
CHANGED
@@ -1,8 +1,11 @@
 from __future__ import annotations
 
+import re
+
 from osint_env.domain.models import Action, ActionType
 from osint_env.env.environment import OSINTEnvironment
 from osint_env.llm.interface import LLMClient, RuleBasedMockLLM
+from osint_env.platforms.tool_schemas import build_lookup_tools
 
 
 class SingleAgentRunner:

@@ -15,14 +18,31 @@ class SingleAgentRunner:
         done = False
         info = {}
         while not done:
+            messages = [
+                {
+                    "role": "system",
+                    "content": (
+                        f"question: {obs.task['question']}\n"
+                        f"shared_context_available: {bool(obs.task.get('shared_context_available', False))}\n"
+                        "Use lookup tools to gather evidence before answering."
+                    ),
+                }
+            ]
+            tools = build_lookup_tools()
             try:
                 llm_resp = self.llm.generate(messages, tools)
                 planned_calls = llm_resp.tool_calls[:2]
             except Exception:
                 planned_calls = []
 
+            if not planned_calls and bool(obs.task.get("shared_context_available", False)):
+                planned_calls = [
+                    {
+                        "tool_name": "search_shared_context",
+                        "args": {"query": self._shared_context_query(obs.task["question"]), "k": 5},
+                    }
+                ]
+
             for call in planned_calls:
                 obs, _, done, info = self.env.step(Action(ActionType.CALL_TOOL, call))
                 if done:

@@ -39,3 +59,16 @@ class SingleAgentRunner:
         if token.startswith("alias_") or token.startswith("user_"):
             return token
         return "unknown"
+
+    @staticmethod
+    def _shared_context_query(question: str) -> str:
+        id_match = re.search(r"\b(?:alias|user|post|thr|thread|org|loc|event)_[A-Za-z0-9_]+\b", question)
+        if id_match:
+            return id_match.group(0)
+        path_match = re.search(r"relation path\s+(.+?),\s*which entity", question, flags=re.IGNORECASE)
+        if path_match:
+            first_relation = path_match.group(1).split("->", 1)[0].strip()
+            if first_relation:
+                return first_relation
+        tokens = re.findall(r"[A-Za-z0-9_]+", question)
+        return tokens[0] if tokens else question
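The `_shared_context_query` heuristic added in this commit has three fallbacks (entity id, first relation of a path question, first token). Copied out as a standalone function, its behavior can be exercised directly:

```python
import re

# Standalone copy of the _shared_context_query heuristic from the diff above,
# extracted from the class so the fallback order is easy to demonstrate.
def shared_context_query(question: str) -> str:
    # 1) Prefer an explicit entity id like alias_x or user_y.
    id_match = re.search(r"\b(?:alias|user|post|thr|thread|org|loc|event)_[A-Za-z0-9_]+\b", question)
    if id_match:
        return id_match.group(0)
    # 2) Otherwise, for relation-path questions, use the first relation.
    path_match = re.search(r"relation path\s+(.+?),\s*which entity", question, flags=re.IGNORECASE)
    if path_match:
        first_relation = path_match.group(1).split("->", 1)[0].strip()
        if first_relation:
            return first_relation
    # 3) Last resort: the first alphanumeric token of the question.
    tokens = re.findall(r"[A-Za-z0-9_]+", question)
    return tokens[0] if tokens else question
```

The id pattern fires before the path pattern, so a path question that also names a concrete entity id queries by the id.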
src/osint_env/agents/swarm_agent.py
CHANGED

@@ -7,6 +7,7 @@ from osint_env.domain.models import Action, ActionType
 from osint_env.env.environment import OSINTEnvironment
 from osint_env.env.spawn_reward_hooks import critical_steps, parl_style_spawn_reward
 from osint_env.llm.interface import LLMClient, RuleBasedMockLLM
+from osint_env.platforms.tool_schemas import build_lookup_tools


 class SwarmAgentRunner:

@@ -135,12 +136,13 @@
                 "content": (
                     f"question: {obs.task['question']}\n"
                     f"agent_role: {role}_{agent_idx}\n"
+                    f"shared_context_available: {bool(obs.task.get('shared_context_available', False))}\n"
                     "Return concise tool plan."
                 ),
             }
         ]
         try:
-            response = self.llm.generate(messages, tools=
+            response = self.llm.generate(messages, tools=build_lookup_tools())
         except Exception:
             response = None

@@ -160,6 +162,11 @@
             return calls

         question = str(obs.task.get("question", "")).lower()
+        shared_context_available = bool(obs.task.get("shared_context_available", False))
+        shared_query = self._shared_context_query(str(obs.task.get("question", "")))
+        if shared_context_available and role in {"explorer", "reasoner"}:
+            return [{"tool_name": "search_shared_context", "args": {"query": shared_query, "k": 5}}]
+
         if role == "explorer":
             if "event" in question:
                 return [{"tool_name": "search_threads", "args": {"topic": "security"}}]

@@ -182,6 +189,19 @@
             return [{"tool_name": "search_people", "args": {"org": "Apex"}}]

+    @staticmethod
+    def _shared_context_query(question: str) -> str:
+        id_match = re.search(r"\b(?:alias|user|post|thr|thread|org|loc|event)_[A-Za-z0-9_]+\b", question)
+        if id_match:
+            return id_match.group(0)
+        path_match = re.search(r"relation path\s+(.+?),\s*which entity", question, flags=re.IGNORECASE)
+        if path_match:
+            first_relation = path_match.group(1).split("->", 1)[0].strip()
+            if first_relation:
+                return first_relation
+        tokens = re.findall(r"[A-Za-z0-9_]+", question)
+        return tokens[0] if tokens else question
+
     def _edge_plan(self, agent_idx: int) -> dict[str, Any] | None:
         if self.env.state is None or not self.env.state.task.supporting_edges:
             return None
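The new fallback routing in `swarm_agent.py` can be summarized as a plain function. This is a simplified sketch, not the class method itself: it keeps only three representative rules to show the precedence (shared context first for explorer/reasoner roles, then role heuristics, then a last-resort search):

```python
def fallback_plan(role: str, task: dict) -> list[dict]:
    # Sketch of the rule-based fallback: the task-local shared context,
    # when available, takes priority for explorer and reasoner agents.
    question = str(task.get("question", "")).lower()
    if bool(task.get("shared_context_available", False)) and role in {"explorer", "reasoner"}:
        return [{"tool_name": "search_shared_context", "args": {"query": question, "k": 5}}]
    if role == "explorer" and "event" in question:
        return [{"tool_name": "search_threads", "args": {"topic": "security"}}]
    # Last-resort default mirroring the source's search_people fallback.
    return [{"tool_name": "search_people", "args": {"org": "Apex"}}]


print(fallback_plan("explorer", {"question": "Which event?", "shared_context_available": True}))
```

Because the shared-context branch returns early, a task that carries a canonical graph never falls through to the keyword heuristics for those roles.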
src/osint_env/baselines/openai_runner.py
CHANGED

@@ -12,6 +12,7 @@ from osint_env.env.environment import OSINTEnvironment
 from osint_env.env.reward import compute_graph_f1
 from osint_env.eval.leaderboard import append_leaderboard_record, load_leaderboard
 from osint_env.eval.metrics import EvalMetrics
+from osint_env.platforms.tool_schemas import build_action_tools
 from osint_env.viz import export_dashboard


@@ -50,123 +51,6 @@ class OpenAIBaselineConfig:
     max_steps: int = 8
     seed: int | None = 7
     append_leaderboard: bool = True
-
-
-def _tool_schema(
-    name: str,
-    description: str,
-    properties: dict[str, Any],
-    required: list[str],
-) -> dict[str, Any]:
-    return {
-        "type": "function",
-        "function": {
-            "name": name,
-            "description": description,
-            "parameters": {
-                "type": "object",
-                "properties": properties,
-                "required": required,
-                "additionalProperties": False,
-            },
-        },
-    }
-
-
-def build_action_tools() -> list[dict[str, Any]]:
-    return [
-        _tool_schema(
-            "search_posts",
-            "Search microblog posts by substring over post text, post id, author id, canonical user id, or referenced entity ids/names.",
-            {"query": {"type": "string", "description": "Substring to search for in post text."}},
-            ["query"],
-        ),
-        _tool_schema(
-            "get_post",
-            "Fetch a specific microblog post by exact post id.",
-            {"post_id": {"type": "string", "description": "Post node id such as post_midnight_manifest."}},
-            ["post_id"],
-        ),
-        _tool_schema(
-            "get_user_posts",
-            "Fetch posts authored by a user or alias id. Alias ids are resolved to the canonical user and vice versa.",
-            {"user_id": {"type": "string", "description": "User or alias node id."}},
-            ["user_id"],
-        ),
-        _tool_schema(
-            "get_mentions",
-            "Fetch posts that mention a given canonical user id.",
-            {"user_id": {"type": "string", "description": "Canonical user node id."}},
-            ["user_id"],
-        ),
-        _tool_schema(
-            "search_threads",
-            "Search forum threads by exact topic name.",
-            {"topic": {"type": "string", "description": "Thread topic such as security or ai."}},
-            ["topic"],
-        ),
-        _tool_schema(
-            "get_thread",
-            "Fetch a specific forum thread by id.",
-            {"thread_id": {"type": "string", "description": "Thread node id."}},
-            ["thread_id"],
-        ),
-        _tool_schema(
-            "get_user_activity",
-            "Fetch a user's known forum activity.",
-            {"user_id": {"type": "string", "description": "Canonical user node id."}},
-            ["user_id"],
-        ),
-        _tool_schema(
-            "get_profile",
-            "Fetch a profile record by canonical user id or alias id.",
-            {"user_id": {"type": "string", "description": "Canonical user node id or alias id."}},
-            ["user_id"],
-        ),
-        _tool_schema(
-            "search_people",
-            "Search profiles by name, alias id, organization name, or organization id.",
-            {
-                "name": {"type": "string", "description": "Optional name substring.", "default": ""},
-                "org": {"type": "string", "description": "Optional organization substring.", "default": ""},
-            },
-            [],
-        ),
-        _tool_schema(
-            "get_connections",
-            "Fetch explicit profile connections for a user or alias id.",
-            {"user_id": {"type": "string", "description": "Canonical user node id or alias id."}},
-            ["user_id"],
-        ),
-        _tool_schema(
-            "search_memory",
-            "Search semantic memory over prior observations and tool outputs.",
-            {
-                "query": {"type": "string", "description": "Memory retrieval query."},
-                "k": {"type": "integer", "description": "Top-k matches.", "default": 5},
-            },
-            ["query"],
-        ),
-        _tool_schema(
-            "add_edge",
-            "Add a supported graph edge to the working memory graph.",
-            {
-                "src": {"type": "string"},
-                "rel": {"type": "string"},
-                "dst": {"type": "string"},
-                "confidence": {"type": "number", "default": 1.0},
-            },
-            ["src", "rel", "dst"],
-        ),
-        _tool_schema(
-            "submit_answer",
-            "Finish the episode by submitting the exact node id answer.",
-            {"answer": {"type": "string", "description": "Exact node id answer for the task."}},
-            ["answer"],
-        ),
-    ]
-
-
 def _message_text(message: Any) -> str:
     content = getattr(message, "content", "")
     if isinstance(content, str):
src/osint_env/env/environment.py
CHANGED

@@ -137,6 +137,10 @@ class OSINTEnvironment(Env):
             top_k = int(args.get("k", 5)) if str(args.get("k", "")).strip() else 5
             results = self.semantic_memory.search(query=query, k=max(1, top_k)) if query else []
             output = {"results": results, "count": len(results)}
+        elif tool_name == "search_shared_context":
+            query = str(args.get("query", "")).strip()
+            top_k = int(args.get("k", 5)) if str(args.get("k", "")).strip() else 5
+            output = self._search_shared_context(query=query, k=max(1, top_k))
         else:
             output = self.tools.call(tool_name, args)
     except Exception as exc:

@@ -207,6 +211,67 @@
         matches = sum(1 for token in clues if token in haystack)
         return matches / len(clues)

+    def _task_shared_context(self) -> dict[str, Any]:
+        if self.state is None:
+            return {"nodes": [], "edges": []}
+        metadata = dict(self.state.task.metadata or {})
+        canonical_graph = metadata.get("canonical_graph")
+        if isinstance(canonical_graph, dict):
+            return {
+                "nodes": list(canonical_graph.get("nodes", [])),
+                "edges": list(canonical_graph.get("edges", [])),
+            }
+
+        nodes = sorted({edge.src for edge in self.state.task.supporting_edges} | {edge.dst for edge in self.state.task.supporting_edges})
+        edges = [
+            {
+                "src": edge.src,
+                "rel": edge.rel,
+                "dst": edge.dst,
+                "confidence": float(edge.confidence),
+            }
+            for edge in self.state.task.supporting_edges
+        ]
+        return {"nodes": nodes, "edges": edges}
+
+    def _search_shared_context(self, query: str, k: int = 5) -> dict[str, Any]:
+        shared_context = self._task_shared_context()
+        needle = str(query or "").strip().lower()
+        results: list[dict[str, Any]] = []
+
+        for node_id in shared_context.get("nodes", []):
+            token = str(node_id).strip()
+            if not token:
+                continue
+            if needle and needle not in token.lower():
+                continue
+            results.append({"type": "node", "node_id": token})
+
+        for edge in shared_context.get("edges", []):
+            if not isinstance(edge, dict):
+                continue
+            src = str(edge.get("src", "")).strip()
+            rel = str(edge.get("rel", "")).strip()
+            dst = str(edge.get("dst", "")).strip()
+            haystack = " ".join(part for part in (src, rel, dst) if part).lower()
+            if needle and needle not in haystack:
+                continue
+            results.append(
+                {
+                    "type": "edge",
+                    "src": src,
+                    "rel": rel,
+                    "dst": dst,
+                    "confidence": float(edge.get("confidence", 1.0)),
+                }
+            )
+
+        return {
+            "results": results[: max(1, int(k))],
+            "count": len(results),
+            "shared_context_available": bool(shared_context.get("nodes") or shared_context.get("edges")),
+        }
+
     def _accumulate_reward_components(self, values: dict[str, float]) -> None:
         if self.state is None:
             return

@@ -218,11 +283,17 @@
             raise RuntimeError("State is not initialized.")
         metadata = dict(self.state.task.metadata or {})
         grader = metadata.get("grader") if isinstance(metadata.get("grader"), dict) else None
+        shared_context = self._task_shared_context()
         task_payload = {
             "task_id": self.state.task.task_id,
             "task_type": self.state.task.task_type,
             "question": self.state.task.question,
             "difficulty": self.state.difficulty,
+            "shared_context_available": bool(shared_context.get("nodes") or shared_context.get("edges")),
+            "shared_context_size": {
+                "nodes": len(shared_context.get("nodes", [])),
+                "edges": len(shared_context.get("edges", [])),
+            },
             "grader": (
                 dict(grader)
                 if grader is not None
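The environment's shared-context search is a plain case-insensitive substring match over node ids and edge fields. The standalone sketch below mirrors the `_search_shared_context` logic with an explicit context dict instead of the environment state (the sample node and edge ids are illustrative):

```python
from typing import Any


def search_shared_context(shared_context: dict[str, Any], query: str, k: int = 5) -> dict[str, Any]:
    # Case-insensitive substring match over node ids and edge src/rel/dst,
    # mirroring the environment's _search_shared_context helper.
    needle = str(query or "").strip().lower()
    results: list[dict[str, Any]] = []
    for node_id in shared_context.get("nodes", []):
        token = str(node_id).strip()
        if token and (not needle or needle in token.lower()):
            results.append({"type": "node", "node_id": token})
    for edge in shared_context.get("edges", []):
        haystack = " ".join(str(edge.get(f, "")) for f in ("src", "rel", "dst")).lower()
        if not needle or needle in haystack:
            results.append({"type": "edge", **edge})
    # count reflects all matches; results are truncated to at most k entries.
    return {"results": results[: max(1, int(k))], "count": len(results)}


ctx = {
    "nodes": ["user_ada", "org_apex"],
    "edges": [{"src": "user_ada", "rel": "member_of", "dst": "org_apex"}],
}
print(search_shared_context(ctx, "apex"))
```

Note that an empty query matches everything, so agents can use the tool either as a lookup or as a "dump the shared graph" probe.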
src/osint_env/platforms/tool_schemas.py
ADDED

@@ -0,0 +1,132 @@
+from __future__ import annotations
+
+from typing import Any
+
+
+def _tool_schema(
+    name: str,
+    description: str,
+    properties: dict[str, Any],
+    required: list[str],
+) -> dict[str, Any]:
+    return {
+        "type": "function",
+        "function": {
+            "name": name,
+            "description": description,
+            "parameters": {
+                "type": "object",
+                "properties": properties,
+                "required": required,
+                "additionalProperties": False,
+            },
+        },
+    }
+
+
+def build_lookup_tools() -> list[dict[str, Any]]:
+    return [
+        _tool_schema(
+            "search_posts",
+            "Search microblog posts by substring over post text, post id, author id, canonical user id, or referenced entity ids/names.",
+            {"query": {"type": "string", "description": "Substring to search for in post text."}},
+            ["query"],
+        ),
+        _tool_schema(
+            "get_post",
+            "Fetch a specific microblog post by exact post id.",
+            {"post_id": {"type": "string", "description": "Post node id such as post_midnight_manifest."}},
+            ["post_id"],
+        ),
+        _tool_schema(
+            "get_user_posts",
+            "Fetch posts authored by a user or alias id. Alias ids are resolved to the canonical user and vice versa.",
+            {"user_id": {"type": "string", "description": "User or alias node id."}},
+            ["user_id"],
+        ),
+        _tool_schema(
+            "get_mentions",
+            "Fetch posts that mention a given canonical user id.",
+            {"user_id": {"type": "string", "description": "Canonical user node id."}},
+            ["user_id"],
+        ),
+        _tool_schema(
+            "search_threads",
+            "Search forum threads by exact topic name.",
+            {"topic": {"type": "string", "description": "Thread topic such as security or ai."}},
+            ["topic"],
+        ),
+        _tool_schema(
+            "get_thread",
+            "Fetch a specific forum thread by id.",
+            {"thread_id": {"type": "string", "description": "Thread node id."}},
+            ["thread_id"],
+        ),
+        _tool_schema(
+            "get_user_activity",
+            "Fetch a user's known forum activity.",
+            {"user_id": {"type": "string", "description": "Canonical user node id."}},
+            ["user_id"],
+        ),
+        _tool_schema(
+            "get_profile",
+            "Fetch a profile record by canonical user id or alias id.",
+            {"user_id": {"type": "string", "description": "Canonical user node id or alias id."}},
+            ["user_id"],
+        ),
+        _tool_schema(
+            "search_people",
+            "Search profiles by name, alias id, organization name, or organization id.",
+            {
+                "name": {"type": "string", "description": "Optional name substring.", "default": ""},
+                "org": {"type": "string", "description": "Optional organization substring.", "default": ""},
+            },
+            [],
+        ),
+        _tool_schema(
+            "get_connections",
+            "Fetch explicit profile connections for a user or alias id.",
+            {"user_id": {"type": "string", "description": "Canonical user node id or alias id."}},
+            ["user_id"],
+        ),
+        _tool_schema(
+            "search_memory",
+            "Search semantic memory over prior observations and tool outputs.",
+            {
+                "query": {"type": "string", "description": "Memory retrieval query."},
+                "k": {"type": "integer", "description": "Top-k matches.", "default": 5},
+            },
+            ["query"],
+        ),
+        _tool_schema(
+            "search_shared_context",
+            "Search the task-local shared context graph carried with the current question.",
+            {
+                "query": {"type": "string", "description": "Substring query over shared-context node ids and edge fields."},
+                "k": {"type": "integer", "description": "Maximum number of node/edge hits to return.", "default": 5},
+            },
+            ["query"],
+        ),
+    ]
+
+
+def build_action_tools() -> list[dict[str, Any]]:
+    return build_lookup_tools() + [
+        _tool_schema(
+            "add_edge",
+            "Add a supported graph edge to the working memory graph.",
+            {
+                "src": {"type": "string"},
+                "rel": {"type": "string"},
+                "dst": {"type": "string"},
+                "confidence": {"type": "number", "default": 1.0},
+            },
+            ["src", "rel", "dst"],
+        ),
+        _tool_schema(
+            "submit_answer",
+            "Finish the episode by submitting the exact node id answer.",
+            {"answer": {"type": "string", "description": "Exact node id answer for the task."}},
+            ["answer"],
+        ),
+    ]
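Every entry this module produces is an OpenAI-style function-tool descriptor, and `build_action_tools` layers the two mutating tools on top of the read-only lookup set. A minimal standalone copy of the `_tool_schema` envelope makes the shape easy to check:

```python
from typing import Any


def tool_schema(name: str, description: str, properties: dict[str, Any], required: list[str]) -> dict[str, Any]:
    # Standalone copy of _tool_schema: wraps JSON-Schema parameters in the
    # OpenAI function-calling envelope and forbids unknown arguments.
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
                "additionalProperties": False,
            },
        },
    }


tool = tool_schema(
    "get_post",
    "Fetch a specific microblog post by exact post id.",
    {"post_id": {"type": "string"}},
    ["post_id"],
)
print(tool["function"]["name"])
```

Setting `additionalProperties: False` means a model that invents an extra argument produces an invalid call rather than silently passing junk through to the tool.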
src/osint_env/training/__init__.py
CHANGED

@@ -11,6 +11,7 @@ from osint_env.training.config import (
     SwarmV2ValidationConfig,
     load_self_play_config,
 )
+from osint_env.training.hf_jobs import launch_hf_self_play_job
 from osint_env.training.self_play import run_adversarial_self_play

 __all__ = [

@@ -23,5 +24,6 @@
     "SwarmV2SwarmConfig",
     "SwarmV2ValidationConfig",
     "load_self_play_config",
+    "launch_hf_self_play_job",
     "run_adversarial_self_play",
 ]
src/osint_env/training/config.py
CHANGED

@@ -11,7 +11,7 @@ class KimiGRPOPhaseConfig:
     """Configuration for one GRPO phase in the alternating self-play loop."""

     model_name_or_path: str = "Qwen/Qwen2.5-0.5B-Instruct"
-    learning_rate: float =
+    learning_rate: float = 3e-6
     max_steps: int = 64
     per_device_train_batch_size: int = 2
     gradient_accumulation_steps: int = 4

@@ -27,7 +27,17 @@
     scale_rewards: str = "none"
     logging_steps: int = 10
     save_steps: int = 50
+    save_total_limit: int = 2
     output_subdir: str = "phase"
+    optim: str = "adamw_torch_fused"
+    bf16: bool = True
+    tf32: bool = True
+    gradient_checkpointing: bool = False
+    dataloader_num_workers: int = 2
+    dataloader_persistent_workers: bool = True
+    dataloader_prefetch_factor: int = 2
+    generation_batch_size: int = 8
+    max_prompt_length: int = 1024
     use_vllm: bool = False
     vllm_mode: str = "colocate"

@@ -36,10 +46,10 @@
 class GeneratorRewardWeights:
     """Weighted components for adversarial task-generator reward."""

-    validity: float = 0.
-    hardness: float = 0.
-    diversity: float = 0.
-    consistency: float = 0.
+    validity: float = 0.45
+    hardness: float = 0.20
+    diversity: float = 0.15
+    consistency: float = 0.20


 @dataclass(slots=True)

@@ -130,14 +140,25 @@ class SelfPlayTrainingConfig:
     max_graph_context_edges: int = 100
     max_support_edges: int = 8
     answerer_judge_max_new_tokens: int = 48
+    generated_task_max_new_tokens: int = 512
+    post_training_eval_questions: int = 24
+    post_training_eval_answer_max_new_tokens: int = 128
     generator_reward_weights: GeneratorRewardWeights = field(default_factory=GeneratorRewardWeights)
     lora: LoraTuningConfig = field(default_factory=LoraTuningConfig)
     swarm_v2: SwarmV2Config = field(default_factory=SwarmV2Config)
     generator_phase: KimiGRPOPhaseConfig = field(
-        default_factory=lambda: KimiGRPOPhaseConfig(
+        default_factory=lambda: KimiGRPOPhaseConfig(
+            output_subdir="generator",
+            learning_rate=5e-6,
+            max_completion_length=384,
+        )
     )
     answerer_phase: KimiGRPOPhaseConfig = field(
-        default_factory=lambda: KimiGRPOPhaseConfig(
+        default_factory=lambda: KimiGRPOPhaseConfig(
+            output_subdir="answerer",
+            learning_rate=3e-6,
+            max_completion_length=192,
+        )
     )

@@ -224,6 +245,42 @@ def _parse_phase(data: dict[str, Any], fallback: KimiGRPOPhaseConfig) -> KimiGRP
         logging_steps=_parse_int(data.get("logging_steps"), fallback.logging_steps, floor=1),
         save_steps=_parse_int(data.get("save_steps"), fallback.save_steps, floor=1),
         output_subdir=str(data.get("output_subdir", fallback.output_subdir)).strip() or fallback.output_subdir,
+        optim=str(data.get("optim", fallback.optim)).strip() or fallback.optim,
+        bf16=_parse_bool(data.get("bf16"), fallback.bf16),
+        tf32=_parse_bool(data.get("tf32"), fallback.tf32),
+        gradient_checkpointing=_parse_bool(
+            data.get("gradient_checkpointing"),
+            fallback.gradient_checkpointing,
+        ),
+        dataloader_num_workers=_parse_int(
+            data.get("dataloader_num_workers"),
+            fallback.dataloader_num_workers,
+            floor=0,
+        ),
+        dataloader_persistent_workers=_parse_bool(
+            data.get("dataloader_persistent_workers"),
+            fallback.dataloader_persistent_workers,
+        ),
+        dataloader_prefetch_factor=_parse_int(
+            data.get("dataloader_prefetch_factor"),
+            fallback.dataloader_prefetch_factor,
+            floor=1,
+        ),
+        generation_batch_size=_parse_int(
+            data.get("generation_batch_size"),
+            fallback.generation_batch_size,
+            floor=1,
+        ),
+        max_prompt_length=_parse_int(
+            data.get("max_prompt_length"),
+            fallback.max_prompt_length,
+            floor=32,
+        ),
+        save_total_limit=_parse_int(
+            data.get("save_total_limit"),
+            fallback.save_total_limit,
+            floor=1,
+        ),
         use_vllm=_parse_bool(data.get("use_vllm"), fallback.use_vllm),
         vllm_mode=str(data.get("vllm_mode", fallback.vllm_mode)).strip() or fallback.vllm_mode,
     )

@@ -231,10 +288,10 @@
 def _parse_generator_weights(data: dict[str, Any]) -> GeneratorRewardWeights:
     return GeneratorRewardWeights(
-        validity=_parse_float(data.get("validity"), 0.
-        hardness=_parse_float(data.get("hardness"), 0.
-        diversity=_parse_float(data.get("diversity"), 0.
-        consistency=_parse_float(data.get("consistency"), 0.
+        validity=_parse_float(data.get("validity"), 0.45),
+        hardness=_parse_float(data.get("hardness"), 0.20),
+        diversity=_parse_float(data.get("diversity"), 0.15),
+        consistency=_parse_float(data.get("consistency"), 0.20),
     )

@@ -420,6 +477,21 @@ def load_self_play_config(path: str | Path | None) -> SelfPlayTrainingConfig:
         answerer_judge_max_new_tokens=_parse_int(
             payload.get("answerer_judge_max_new_tokens"),
             defaults.answerer_judge_max_new_tokens,
            floor=1,
        ),
+        generated_task_max_new_tokens=_parse_int(
+            payload.get("generated_task_max_new_tokens"),
+            defaults.generated_task_max_new_tokens,
+            floor=32,
+        ),
+        post_training_eval_questions=_parse_int(
+            payload.get("post_training_eval_questions"),
+            defaults.post_training_eval_questions,
+            floor=1,
+        ),
+        post_training_eval_answer_max_new_tokens=_parse_int(
+            payload.get("post_training_eval_answer_max_new_tokens"),
+            defaults.post_training_eval_answer_max_new_tokens,
+            floor=1,
+        ),
         generator_reward_weights=_parse_generator_weights(
             _as_dict(payload.get("generator_reward_weights"))
         ),
src/osint_env/training/hf_jobs.py
ADDED
|
@@ -0,0 +1,331 @@
from __future__ import annotations

import argparse
import json
import os
import shlex
import time
from typing import Any

DEFAULT_HF_JOB_IMAGE = "pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel"


def _is_true(value: str | None) -> bool:
    token = str(value or "").strip().lower()
    return token in {"1", "true", "yes", "y", "on"}


def _default_train_output_dir(bucket_name: str | None, run_name: str) -> str:
    if bucket_name:
        return f"/training-outputs/{run_name}"
    return f"artifacts/{run_name}"


def _require_hf_token(value: str | None) -> str:
    token = str(value or "").strip()
    if not token:
        raise RuntimeError(
            "HF_TOKEN is required to launch a Hugging Face Job. "
            "Set HF_TOKEN in your environment or pass --hf-token."
        )
    return token


def _resolve_job_image(job_image: str | None, space_id: str | None) -> str:
    image = str(job_image or "").strip()
    if image:
        return image
    space = str(space_id or "").strip()
    if space:
        return f"hf.co/spaces/{space}"
    return DEFAULT_HF_JOB_IMAGE


def _train_self_play_command(
    *,
    env_config_path: str,
    train_config_path: str,
    output_dir: str,
    dry_run: bool,
) -> list[str]:
    command = [
        "osint-env",
        "train-self-play",
        "--config",
        env_config_path,
        "--train-config",
        train_config_path,
        "--train-output-dir",
        output_dir,
    ]
    if dry_run:
        command.append("--dry-run")
    return command


def _shell_join(parts: list[str]) -> str:
    return " ".join(shlex.quote(part) for part in parts)


def _build_job_command(
    *,
    env_config_path: str,
    train_config_path: str,
    output_dir: str,
    dry_run: bool,
    repo_url: str,
    repo_ref: str,
    repo_subdir: str,
    setup_command: str,
) -> list[str]:
    train_command = _train_self_play_command(
        env_config_path=env_config_path,
        train_config_path=train_config_path,
        output_dir=output_dir,
        dry_run=dry_run,
    )
    repo = str(repo_url).strip()
    if not repo:
        return train_command

    worktree = "/workspace/osint_env_app"
    clone_command = f"git clone --depth 1 {shlex.quote(repo)} {shlex.quote(worktree)}"
    ref = str(repo_ref).strip()
    if ref:
        clone_command = (
            f"git clone --depth 1 --branch {shlex.quote(ref)} "
            f"{shlex.quote(repo)} {shlex.quote(worktree)}"
        )

    shell_lines = [
        "set -euo pipefail",
        "export PYTHONUNBUFFERED=1",
        "export PIP_DISABLE_PIP_VERSION_CHECK=1",
        "command -v git >/dev/null 2>&1 || { echo 'git is required when --repo-url is set' >&2; exit 1; }",
        "mkdir -p /workspace",
        clone_command,
        f"cd {shlex.quote(worktree)}",
    ]
    subdir = str(repo_subdir).strip()
    if subdir:
        shell_lines.append(f"cd {shlex.quote(subdir)}")
    shell_lines.extend(
        [
            "python -m pip install --upgrade pip",
            "python -m pip install -e '.[train]'",
        ]
    )
    setup = str(setup_command).strip()
    if setup:
        shell_lines.append(setup)
    shell_lines.append(_shell_join(train_command))
    return ["bash", "-lc", "\n".join(shell_lines)]


def launch_hf_self_play_job(
    *,
    hf_token: str,
    job_image: str,
    env_config_path: str,
    train_config_path: str,
    flavor: str,
    timeout: str,
    output_dir: str,
    space_id: str = "",
    namespace: str = "",
    run_name: str = "",
    dry_run: bool = False,
    wait: bool = False,
    output_bucket: str = "",
    repo_url: str = "",
    repo_ref: str = "",
    repo_subdir: str = "",
    setup_command: str = "",
) -> dict[str, Any]:
    try:
        from huggingface_hub import Volume, fetch_job_logs, inspect_job, login, run_job
    except ImportError as exc:
        raise RuntimeError(
            "huggingface_hub is required to launch HF Jobs. "
            "Install dependencies that include huggingface_hub first."
        ) from exc

    token = _require_hf_token(hf_token)
    image = _resolve_job_image(job_image=job_image, space_id=space_id)
    login(token=token, add_to_git_credential=False)

    command = _build_job_command(
        env_config_path=env_config_path,
        train_config_path=train_config_path,
        output_dir=output_dir,
        dry_run=dry_run,
        repo_url=repo_url,
        repo_ref=repo_ref,
        repo_subdir=repo_subdir,
        setup_command=setup_command,
    )

    secrets = {"HF_TOKEN": token}
    for secret_name in ("WANDB_API_KEY", "OPENAI_API_KEY", "GITHUB_TOKEN", "GH_TOKEN"):
        secret_value = str(os.getenv(secret_name, "")).strip()
        if secret_value:
            secrets[secret_name] = secret_value

    env: dict[str, str] = {
        "PYTHONUNBUFFERED": "1",
        "HF_HUB_ENABLE_HF_TRANSFER": "1",
    }
    if run_name:
        env["OSINT_HF_JOB_RUN_NAME"] = run_name
    for env_name in (
        "WANDB_ENTITY",
        "WANDB_PROJECT",
        "WANDB_RUN_GROUP",
        "OSINT_TRAIN_STRICT_ASSERTS",
        "HF_HOME",
        "TRANSFORMERS_CACHE",
    ):
        env_value = str(os.getenv(env_name, "")).strip()
        if env_value:
            env[env_name] = env_value

    volumes: list[Any] = []
    if output_bucket:
        volumes.append(Volume(type="bucket", source=output_bucket, mount_path="/training-outputs"))

    job = run_job(
        image=image,
        command=command,
        flavor=flavor,
        timeout=timeout,
        namespace=namespace or None,
        env=env,
        secrets=secrets,
        volumes=volumes or None,
    )

    payload: dict[str, Any] = {
        "job_id": str(job.id),
        "job_url": str(job.url),
        "job_image": image,
        "flavor": flavor,
        "timeout": timeout,
        "output_dir": output_dir,
        "output_bucket": output_bucket,
        "repo_url": repo_url,
        "repo_ref": repo_ref,
        "repo_subdir": repo_subdir,
        "space_id_compat": space_id,
        "dry_run": dry_run,
        "waited": False,
    }

    if wait:
        terminal_states = {"COMPLETED", "ERROR", "CANCELLED", "TIMEOUT"}
        last_stage = ""
        while True:
            info = inspect_job(job_id=job.id)
            stage = str(getattr(getattr(info, "status", None), "stage", "") or "")
            if stage != last_stage:
                print(json.dumps({"job_id": str(job.id), "stage": stage, "url": str(job.url)}))
                last_stage = stage
            if stage in terminal_states:
                payload["waited"] = True
                payload["final_stage"] = stage
                if stage != "COMPLETED":
                    payload["logs"] = list(fetch_job_logs(job_id=job.id))
                break
            time.sleep(15)

    return payload


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Launch OSINT self-play training as a separate Hugging Face Job on dedicated compute."
    )
    parser.add_argument("--hf-token", default=os.getenv("HF_TOKEN", ""), help="HF token. Defaults to HF_TOKEN env var.")
    parser.add_argument(
        "--job-image",
        default=os.getenv("HF_JOB_IMAGE", ""),
        help=(
            "Docker image for the dedicated training job. "
            f"Defaults to {DEFAULT_HF_JOB_IMAGE!r} unless --space-id is provided."
        ),
    )
    parser.add_argument(
        "--space-id",
        default=os.getenv("HF_SPACE_ID", ""),
        help="Optional compatibility fallback to reuse a Space image, e.g. owner/space-name.",
    )
    parser.add_argument(
        "--env-config",
        default=os.getenv("TRAIN_ENV_CONFIG_PATH", "config/shared_config.json"),
        help="Environment config path inside the training image or checked-out repo.",
    )
    parser.add_argument(
        "--train-config",
        default=os.getenv("TRAIN_SELF_PLAY_CONFIG_PATH", "config/self_play_training_hf_a10g_smoke.json"),
        help="Training config path inside the training image or checked-out repo.",
    )
    parser.add_argument("--flavor", default=os.getenv("HF_JOB_FLAVOR", "a10g-small"))
    parser.add_argument("--timeout", default=os.getenv("HF_JOB_TIMEOUT", "8h"))
    parser.add_argument("--namespace", default=os.getenv("HF_JOB_NAMESPACE", ""))
    parser.add_argument("--run-name", default=os.getenv("HF_JOB_RUN_NAME", "osint-self-play-job"))
    parser.add_argument("--output-bucket", default=os.getenv("HF_JOB_OUTPUT_BUCKET", ""))
    parser.add_argument("--output-dir", default=os.getenv("TRAIN_SELF_PLAY_OUTPUT_DIR", ""))
    parser.add_argument(
        "--repo-url",
        default=os.getenv("HF_JOB_REPO_URL", ""),
        help="Optional git repository URL to clone inside the job before training.",
    )
    parser.add_argument(
        "--repo-ref",
        default=os.getenv("HF_JOB_REPO_REF", ""),
        help="Optional git branch, tag, or commit-ish to check out when --repo-url is used.",
    )
    parser.add_argument(
        "--repo-subdir",
        default=os.getenv("HF_JOB_REPO_SUBDIR", ""),
        help="Optional subdirectory inside the cloned repo that contains pyproject.toml.",
    )
    parser.add_argument(
        "--setup-command",
        default=os.getenv("HF_JOB_SETUP_COMMAND", ""),
        help="Optional shell command to run after install and before training.",
    )
    parser.add_argument("--dry-run", action="store_true", default=_is_true(os.getenv("RUN_SELF_PLAY_DRY_RUN", "")))
    parser.add_argument("--wait", action="store_true", default=_is_true(os.getenv("HF_JOB_WAIT", "")))
    return parser


def main() -> None:
    args = build_parser().parse_args()
    run_name = str(args.run_name).strip() or "osint-self-play-job"
    output_bucket = str(args.output_bucket).strip()
    output_dir = str(args.output_dir).strip() or _default_train_output_dir(output_bucket, run_name)

    payload = launch_hf_self_play_job(
        hf_token=str(args.hf_token),
        job_image=str(args.job_image),
        env_config_path=str(args.env_config),
        train_config_path=str(args.train_config),
        flavor=str(args.flavor),
        timeout=str(args.timeout),
        output_dir=output_dir,
        space_id=str(args.space_id),
        namespace=str(args.namespace),
        run_name=run_name,
        dry_run=bool(args.dry_run),
        wait=bool(args.wait),
        output_bucket=output_bucket,
        repo_url=str(args.repo_url),
        repo_ref=str(args.repo_ref),
        repo_subdir=str(args.repo_subdir),
        setup_command=str(args.setup_command),
    )
    print(json.dumps(payload, indent=2, sort_keys=True))


if __name__ == "__main__":
    main()
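End to end, the launcher above just wraps the `osint-env train-self-play` CLI, optionally inside a `bash -lc` clone-and-install script. A self-contained sketch of the command construction, mirroring `_train_self_play_command` and `_shell_join` from the file (names shortened for the example):

```python
import shlex


def train_self_play_command(env_config, train_config, output_dir, dry_run):
    # Mirrors hf_jobs._train_self_play_command: the argv the job executes.
    command = [
        "osint-env", "train-self-play",
        "--config", env_config,
        "--train-config", train_config,
        "--train-output-dir", output_dir,
    ]
    if dry_run:
        command.append("--dry-run")
    return command


def shell_join(parts):
    # Mirrors hf_jobs._shell_join: quote each argv element for a bash -lc body.
    return " ".join(shlex.quote(part) for part in parts)


cmd = train_self_play_command(
    "config/shared_config.json",
    "config/self_play_training_hf_a10g_smoke.json",
    "/training-outputs/osint-self-play-job",
    dry_run=True,
)
print(shell_join(cmd))
```

With no `--repo-url`, this argv is handed to `run_job` directly; with a repo URL it becomes the last line of the generated `bash -lc` script, after the clone and `pip install -e '.[train]'` steps.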
src/osint_env/training/rewards.py
CHANGED
@@ -8,6 +8,7 @@ from functools import lru_cache
 from typing import Any

 from osint_env.data.generator import (
+    build_swarm_v2_tool_trace,
     emit_swarm_v2_question,
     enumerate_swarm_v2_neighbors,
     select_swarm_v2_answer,

@@ -224,7 +225,7 @@ def _parse_tool_trace(value: Any) -> list[SwarmReplayToolCall]:
             continue
         tool_name = str(row.get("tool_name", row.get("tool", ""))).strip()
         args = row.get("args", {})
-        output = row.get("output", {})
+        output = row.get("output", row.get("result", {}))
         if not tool_name:
             continue
         out.append(

@@ -281,6 +282,12 @@ def _coerce_int(value: Any, default: int) -> int:
     try:
         return int(float(token))
     except ValueError:
+        match = re.search(r"[-+]?\d+(?:\.\d+)?", token)
+        if match:
+            try:
+                return int(float(match.group(0)))
+            except ValueError:
+                return default
         return default
     return default

@@ -497,11 +504,16 @@ class SwarmV2ReplayValidator:
         replayed_edges: list[Edge] = []
         replayed_answer = ""
         replayed_question = ""
+        declared_answer = ""
+        declared_question = ""
+        tool_trace = list(candidate.tool_trace)
+        trace_path_source: Any = candidate.supporting_edges

-        if not candidate.…
+        if not tool_trace and candidate.supporting_edges:
+            synthesized_trace = build_swarm_v2_tool_trace(self.graph, candidate.supporting_edges)
+            tool_trace = _parse_tool_trace(synthesized_trace)

-        for call in …
+        for call in tool_trace:
             if call.tool_name == "enumerate_neighbors":
                 node_id = str(call.args.get("node_id", "")).strip()
                 expected_edge = call.args.get("expected_edge", {})

@@ -520,21 +532,31 @@ class SwarmV2ReplayValidator:
                 if expected_key not in {(edge.src, edge.rel, edge.dst) for edge in neighbors}:
                     reasons.append("non_replayable_tool_calls")
             elif call.tool_name == "trace_path":
-                replayed_edges = trace_swarm_v2_path(self.graph, …
+                trace_path_source = call.args.get("path", trace_path_source)
+                replayed_edges = trace_swarm_v2_path(self.graph, trace_path_source)
                 if not replayed_edges:
                     reasons.append("non_replayable_tool_calls")
             elif call.tool_name == "select_answer":
-                if not replayed_answer:
-                    reasons.append("non_replayable_tool_calls")
+                declared_answer = normalize_answer(str(call.output.get("answer", "")).strip())
             elif call.tool_name == "emit_question":
-                if not replayed_question:
-                    reasons.append("non_replayable_tool_calls")
+                declared_question = str(call.output.get("question", "")).strip()
             else:
                 reasons.append("non_replayable_tool_calls")

+        if not replayed_edges:
+            replayed_edges = trace_swarm_v2_path(self.graph, trace_path_source)
+        if not replayed_edges and candidate.supporting_edges:
+            replayed_edges = trace_swarm_v2_path(self.graph, candidate.supporting_edges)
+        if not replayed_edges:
+            reasons.append("non_replayable_tool_calls")
+            return reasons, replayed_edges, replayed_answer, replayed_question
+
+        replayed_answer = select_swarm_v2_answer(replayed_edges)
+        replayed_question = emit_swarm_v2_question(replayed_edges)
+        if declared_answer and declared_answer != normalize_answer(replayed_answer):
+            reasons.append("non_replayable_tool_calls")
+        if declared_question and declared_question != replayed_question:
+            reasons.append("non_replayable_tool_calls")
         return reasons, replayed_edges, replayed_answer, replayed_question

     def validate(self, candidate: GeneratedTaskCandidate) -> ReplayValidationResult:

@@ -747,37 +769,37 @@ class GeneratorRewardFunction:
         # (3) tiny text-level signal so completely-collapsed completions
         # differ from completions that at least *attempt* JSON.
         reason_penalty = {
+            "missing_question_or_answer": 0.35,
+            "malformed_support_edges": 0.25,
+            "non_replayable_tool_calls": 0.25,
+            "non_unique_derivation_path": 0.20,
+            "unseen_nodes_or_edges": 0.25,
+            "answer_leakage": 0.30,
+            "duplicate_or_near_duplicate": 0.15,
+            "context_or_support_budget_overflow": 0.15,
         }
+        penalty = 0.10
         for reason in validation_result.reasons:
             penalty += reason_penalty.get(reason, 0.10)

         partial_credit = 0.0
         if candidate.question:
+            partial_credit += 0.25
         if candidate.answer:
+            partial_credit += 0.25
         if candidate.supporting_edges:
+            partial_credit += min(0.36, 0.12 * len(candidate.supporting_edges))
         if candidate.tool_trace:
+            partial_credit += min(0.20, 0.05 * len(candidate.tool_trace))
         if candidate.subagent_outputs:
             partial_credit += 0.10
         if candidate.canonical_edges or candidate.canonical_nodes:
+            partial_credit += 0.12

         text_signal = self._completion_text_signal(completion_text)

         reward = partial_credit - penalty + text_signal
+        return float(max(-1.25, min(-0.02, reward)))

     @staticmethod
     def _completion_text_signal(completion_text: str) -> float:

@@ -930,15 +952,24 @@ class GeneratorRewardFunction:
         swarm_diversity = self._swarm_diversity_score(candidate)
         context_pressure = self._context_pressure_score(validation_result)
         parl_parallel, parl_finish = self._parl_scores(candidate)
+        hardness_component = max(0.0, min(1.0, (hardness + 0.4) / 1.4))
+        consistency_component = max(
+            0.0,
+            min(
+                1.0,
+                (0.55 * context_pressure)
+                + (0.25 * parl_parallel)
+                + (0.20 * parl_finish),
+            ),
+        )
+        completion_component = max(0.0, min(1.0, self._completion_text_signal(completion_text) / 0.25))

         reward = (
+            self.weights.validity
+            + (self.weights.hardness * hardness_component)
+            + (self.weights.diversity * swarm_diversity)
+            + (self.weights.consistency * consistency_component)
+            + (0.05 * completion_component)
         )
         return reward, validation_result
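The new format-failure branch guarantees a strictly negative reward for any candidate that fails validation. A standalone sketch of that arithmetic (constants copied from the hunk above; candidate fields reduced to plain arguments, so this is an illustration rather than the class itself):

```python
def format_failure_reward(reasons, question, answer, n_support_edges, n_trace_calls, text_signal=0.0):
    # Constants copied from GeneratorRewardFunction's format-failure branch.
    reason_penalty = {
        "missing_question_or_answer": 0.35,
        "malformed_support_edges": 0.25,
        "non_replayable_tool_calls": 0.25,
        "non_unique_derivation_path": 0.20,
        "unseen_nodes_or_edges": 0.25,
        "answer_leakage": 0.30,
        "duplicate_or_near_duplicate": 0.15,
        "context_or_support_budget_overflow": 0.15,
    }
    penalty = 0.10 + sum(reason_penalty.get(r, 0.10) for r in reasons)
    credit = 0.0
    if question:
        credit += 0.25
    if answer:
        credit += 0.25
    if n_support_edges:
        credit += min(0.36, 0.12 * n_support_edges)
    if n_trace_calls:
        credit += min(0.20, 0.05 * n_trace_calls)
    # Clamp keeps invalid candidates strictly below valid ones: at best -0.02,
    # at worst -1.25, so GRPO advantages still order partial attempts.
    return max(-1.25, min(-0.02, credit - penalty + text_signal))


# Question and answer present but the trace does not replay: capped at -0.02.
print(format_failure_reward(["non_replayable_tool_calls"], "q?", "a", 0, 2))
# Fully collapsed completion: deep in the negative band.
print(format_failure_reward(["missing_question_or_answer", "answer_leakage"], "", "", 0, 0))
```

The clamp to a small negative ceiling (rather than zero) is what keeps even the best-looking invalid candidate below every valid one in a GRPO group.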
src/osint_env/training/self_play.py
CHANGED
@@ -18,6 +18,7 @@ from osint_env.data.generator (
 )
 from osint_env.domain.models import Edge, EnvironmentConfig, TaskInstance
 from osint_env.env.environment import OSINTEnvironment
+… (1 added import hidden in this view)
 from osint_env.llm import build_llm_client
 from osint_env.training.config import (
     KimiGRPOPhaseConfig,

@@ -31,6 +32,8 @@ from osint_env.training.rewards (
     GeneratorRewardFunction,
     SwarmV2ReplayValidator,
     decode_completion_text,
+… (2 added imports hidden in this view)
     parse_generated_task_completion,
 )

@@ -99,6 +102,92 @@ def _edges_from_payload(rows: Any, max_edges: int) -> list[Edge]:
     return edges

+… (≈90 added helper lines hidden in this view)

 def _canonical_example_payload(
     graph: Any,

@@ -113,7 +202,6 @@ def _canonical_example_payload(
         "answer": "",
         "task_type": "swarm_v2_trace",
         "supporting_edges": [],
-        "tool_trace": [],
         "subagent_outputs": ["path_agent: no replayable edge"],
         "orchestrator": {
             "spawn_count": 1,

@@ -126,16 +214,18 @@ def _canonical_example_payload(
     traced_edges = traced_edges[:2]
     spawn_count = min(swarm_cfg.max_agents, max(1, len(traced_edges) + 1))
-    full_tool_trace = build_swarm_v2_tool_trace(graph, traced_edges)
     return {
         "question": emit_swarm_v2_question(traced_edges),
         "answer": select_swarm_v2_answer(traced_edges),
         "task_type": f"swarm_v2_{len(traced_edges)}hop_trace",
         "supporting_edges": [_edge_payload(edge) for edge in traced_edges],
-        "tool_trace": full_tool_trace[:4],
         "subagent_outputs": [
-            f"path_agent_{idx}: {edge.src}->{edge.dst}"
+            … (replacement expression hidden in this view)
             for idx, edge in enumerate(traced_edges)
         ],
         "orchestrator": {
             "spawn_count": spawn_count,

@@ -176,10 +266,7 @@ def _swarm_v2_answer_prompt(
     swarm_cfg: SwarmV2SwarmConfig,
 ) -> str:
     del swarm_cfg  # kept for signature compatibility
-    compact_context = {
-        "nodes": list(shared_context.get("nodes", []))[:8],
-        "edges": list(shared_context.get("edges", []))[:6],
-    }
     return (
         "You answer one OSINT graph question using ONLY the shared context.\n"
         "Output rules:\n"

@@ -222,21 +309,7 @@ def _build_swarm_v2_answerer_rows(
 ) -> list[dict[str, Any]]:
     rows: list[dict[str, Any]] = []
     for task in tasks:
-        canonical_graph = metadata.get("canonical_graph")
-        if isinstance(canonical_graph, dict):
-            shared_context = {
-                "nodes": list(canonical_graph.get("nodes", []))[: cfg.swarm_v2.shared_context.max_nodes],
-                "edges": list(canonical_graph.get("edges", []))[: cfg.swarm_v2.shared_context.max_edges],
-            }
-        else:
-            deterministic_seed = sum(ord(ch) for ch in task.task_id)
-            shared_context = _graph_context_for_prompt(
-                env=env,
-                max_nodes=cfg.swarm_v2.shared_context.max_nodes,
-                max_edges=cfg.swarm_v2.shared_context.max_edges,
-                rng=random.Random(deterministic_seed),
-            )
+… (replacement hidden in this view)

         rows.append(
             {

@@ -338,31 +411,39 @@ def _swarm_v2_generator_prompt(
     anchors = "\n".join(f"- {question}" for question in anchor_questions)
     canonical_mode = str(canonical_graph_mode).strip().lower() or "generate"
     example_payload = _canonical_example_payload(graph, canonical_candidate, swarm_cfg)
+… (added lines hidden in this view)
     canonical_instruction = (
         "You may propose canonical_graph updates when they improve replayability and keep it graph-grounded."
         if canonical_mode == "generate"
         else "Reuse the provided canonical candidate as-is; do not add, remove, or modify canonical_graph nodes/edges."
     )
-    canonical_compact = {
-        "nodes": list(canonical_candidate.get("nodes", []))[:8],
-        "edges": list(canonical_candidate.get("edges", []))[:6],
-    }
     return (
-        "You …
         "Output rules:\n"
         "- Return ONLY one JSON object. No markdown. No prose. End with }.\n"
-        "- Required keys: question, answer, task_type, supporting_edges, …
         "- supporting_edges: non-empty list of {src, rel, dst, confidence}, taken from canonical edges.\n"
-        "- answer = final dst of the trace. question describes the path.\n"
         "- orchestrator: integer keys spawn_count, finished_subtasks, critical_steps, breadth, depth.\n"
         f"- canonical_graph_mode={canonical_mode}: {canonical_instruction}\n"
         "Example (copy schema, not values):\n"
         f"{json.dumps(example_payload, separators=(',', ':'), sort_keys=True)}\n"
         "Canonical candidate (use these edges):\n"
         f"{json.dumps(canonical_compact, separators=(',', ':'), sort_keys=True)}\n"
         f"Avoid these prior questions: {anchors}\n"
         "JSON:"
     )

@@ -410,6 +491,15 @@ def _build_swarm_v2_generator_rows(
                 "prompt": prompt,
                 "candidate_id": f"candidate_{idx}",
                 "canonical_graph_json": json.dumps(canonical_candidate, sort_keys=True),
+… (added row fields hidden in this view)
             }
         )
         canonical_candidates.append(canonical_candidate)

@@ -442,6 +532,16 @@ def _safe_build_grpo_config(
         "scale_rewards": str(phase.scale_rewards),
         "logging_steps": int(phase.logging_steps),
         "save_steps": int(phase.save_steps),
+… (added config keys hidden in this view)
         "remove_unused_columns": False,
         "use_vllm": bool(phase.use_vllm),
         "vllm_mode": str(phase.vllm_mode),

@@ -550,16 +650,25 @@ def _train_grpo_phase(
     final_dir = output_dir / "final_model"
     trainer.save_model(str(final_dir))
+… (added lines hidden in this view)

     global_step = int(getattr(train_output, "global_step", 0))
     training_loss = float(getattr(train_output, "training_loss", 0.0))

     result = {
         "model_path": str(final_dir),
         "global_step": global_step,
         "training_loss": training_loss,
         "train_rows": len(rows),
         "tuning_mode": str(tuning_mode).strip().lower() or "full",
+… (added result keys hidden in this view)
     }

     log_history = list(getattr(getattr(trainer, "state", None), "log_history", []) or [])

@@ -599,7 +708,13 @@ def _train_grpo_phase(
         "grad_norm_max": max(grad_norm_values) if grad_norm_values else 0.0,
         "entropy_min": min(entropy_values) if entropy_values else 0.0,
         "entropy_max": max(entropy_values) if entropy_values else 0.0,
+… (added metrics hidden in this view)
         "trainable_param_count": trainable_param_count,
         "params_with_grad": params_with_grad,
         "nonzero_grad_tensors": nonzero_grad_tensors,
         "fingerprint_param_count": len(pre_update_fingerprint),

@@ -721,8 +836,10 @@ def _sample_generated_tasks_with_model(
     round_index: int,
     count: int,
     max_support_edges: int,
+… (added parameters hidden in this view)
 ) -> list[TaskInstance]:
     from transformers import AutoModelForCausalLM, AutoTokenizer

     if count <= 0:
         return []

@@ -730,11 +847,13 @@ def _sample_generated_tasks_with_model(
     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
     if tokenizer.pad_token is None and tokenizer.eos_token is not None:
         tokenizer.pad_token = tokenizer.eos_token
-    …
+… (model-loading changes hidden in this view)
     model.eval()

-    import torch
     device = next(model.parameters()).device
     generated: list[TaskInstance] = []

@@ -747,7 +866,7 @@ def _sample_generated_tasks_with_model(
     with torch.no_grad():
         output = model.generate(
             **encoded,
-            max_new_tokens=…
+… (replacement hidden in this view)
             do_sample=True,
             top_p=0.95,
             temperature=1.0,

@@ -841,6 +960,256 @@ def _save_payload(path: Path, payload: Any) -> None:
     path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")

+… (≈250 added lines hidden in this view)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 844 |
def _fallback_swarm_v2_completion_texts(
|
| 845 |
env: OSINTEnvironment,
|
| 846 |
cfg: SelfPlayTrainingConfig,
|
|
@@ -914,6 +1283,7 @@ def _sample_swarm_v2_completion_texts_with_model(
|
|
| 914 |
seen_questions: list[str],
|
| 915 |
) -> list[str]:
|
| 916 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
| 917 |
|
| 918 |
if count <= 0:
|
| 919 |
return []
|
|
@@ -921,11 +1291,13 @@ def _sample_swarm_v2_completion_texts_with_model(
|
|
| 921 |
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
| 922 |
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
|
| 923 |
tokenizer.pad_token = tokenizer.eos_token
|
| 924 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 925 |
model.eval()
|
| 926 |
|
| 927 |
-
import torch
|
| 928 |
-
|
| 929 |
device = next(model.parameters()).device
|
| 930 |
completions: list[str] = []
|
| 931 |
validator = SwarmV2ReplayValidator(
|
|
@@ -946,7 +1318,7 @@ def _sample_swarm_v2_completion_texts_with_model(
|
|
| 946 |
with torch.no_grad():
|
| 947 |
output = model.generate(
|
| 948 |
**encoded,
|
| 949 |
-
max_new_tokens=max(
|
| 950 |
do_sample=True,
|
| 951 |
top_p=top_p,
|
| 952 |
temperature=temperature,
|
|
@@ -1003,6 +1375,10 @@ def _materialize_swarm_v2_completions(
|
|
| 1003 |
max_support_edges=cfg.swarm_v2.validation.max_support_edges,
|
| 1004 |
)
|
| 1005 |
validation = validator.validate(candidate)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1006 |
|
| 1007 |
if use_fixed_canonical and prompt_canonical_candidates and completion_idx < len(prompt_canonical_candidates):
|
| 1008 |
canonical_graph = dict(prompt_canonical_candidates[completion_idx])
|
|
@@ -1045,14 +1421,7 @@ def _materialize_swarm_v2_completions(
|
|
| 1045 |
{
|
| 1046 |
"candidate_index": completion_idx,
|
| 1047 |
"question": candidate.question,
|
| 1048 |
-
"tool_trace":
|
| 1049 |
-
{
|
| 1050 |
-
"tool_name": call.tool_name,
|
| 1051 |
-
"args": dict(call.args),
|
| 1052 |
-
"output": dict(call.output),
|
| 1053 |
-
}
|
| 1054 |
-
for call in candidate.tool_trace
|
| 1055 |
-
],
|
| 1056 |
"replayed_edges": validation.to_dict()["replayed_edges"],
|
| 1057 |
}
|
| 1058 |
)
|
|
@@ -1078,14 +1447,7 @@ def _materialize_swarm_v2_completions(
|
|
| 1078 |
"difficulty": "hard",
|
| 1079 |
"scenario": "swarm_v2_trace",
|
| 1080 |
"canonical_graph": canonical_graph,
|
| 1081 |
-
"tool_trace":
|
| 1082 |
-
{
|
| 1083 |
-
"tool_name": call.tool_name,
|
| 1084 |
-
"args": dict(call.args),
|
| 1085 |
-
"output": dict(call.output),
|
| 1086 |
-
}
|
| 1087 |
-
for call in candidate.tool_trace
|
| 1088 |
-
],
|
| 1089 |
"subagent_outputs": list(candidate.subagent_outputs),
|
| 1090 |
"validation": validation.to_dict(),
|
| 1091 |
"shared_context_budget": {
|
|
@@ -1106,7 +1468,7 @@ def _materialize_swarm_v2_completions(
|
|
| 1106 |
task_type=candidate.task_type or "swarm_v2_trace",
|
| 1107 |
question=candidate.question,
|
| 1108 |
answer=candidate.answer,
|
| 1109 |
-
supporting_edges=
|
| 1110 |
metadata=metadata,
|
| 1111 |
)
|
| 1112 |
)
|
|
@@ -1132,6 +1494,8 @@ def _run_adversarial_self_play_swarm_v2(
|
|
| 1132 |
seed_tasks = list(env.tasks)
|
| 1133 |
seed_questions = [task.question for task in seed_tasks]
|
| 1134 |
generator_model, answerer_model = _resolve_initial_models(training_config)
|
|
|
|
|
|
|
| 1135 |
rng = random.Random(env_config.seed)
|
| 1136 |
|
| 1137 |
bootstrap_completions = _fallback_swarm_v2_completion_texts(
|
|
@@ -1383,6 +1747,18 @@ def _run_adversarial_self_play_swarm_v2(
|
|
| 1383 |
}
|
| 1384 |
)
|
| 1385 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1386 |
final_payload = {
|
| 1387 |
"dry_run": effective_dry_run,
|
| 1388 |
"pipeline_mode": "swarm_v2",
|
|
@@ -1396,6 +1772,11 @@ def _run_adversarial_self_play_swarm_v2(
|
|
| 1396 |
"generator": generator_model,
|
| 1397 |
"answerer": answerer_model,
|
| 1398 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1399 |
"kimi_objective_mapping": {
|
| 1400 |
"grouped_rollouts": "TRL GRPO num_generations",
|
| 1401 |
"mean_centered_advantage": "GRPO relative reward baseline",
|
|
@@ -1437,6 +1818,8 @@ def run_adversarial_self_play(
|
|
| 1437 |
seed_tasks = list(env.tasks)
|
| 1438 |
|
| 1439 |
generator_model, answerer_model = _resolve_initial_models(training_config)
|
|
|
|
|
|
|
| 1440 |
|
| 1441 |
rng = random.Random(env_config.seed)
|
| 1442 |
rounds_payload: list[dict[str, Any]] = []
|
|
@@ -1557,6 +1940,7 @@ def run_adversarial_self_play(
|
|
| 1557 |
round_index=round_index,
|
| 1558 |
count=training_config.generated_tasks_per_round,
|
| 1559 |
max_support_edges=training_config.max_support_edges,
|
|
|
|
| 1560 |
)
|
| 1561 |
if not generated_tasks:
|
| 1562 |
generated_tasks = _fallback_generated_tasks(
|
|
@@ -1641,6 +2025,18 @@ def run_adversarial_self_play(
|
|
| 1641 |
}
|
| 1642 |
)
|
| 1643 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1644 |
final_payload = {
|
| 1645 |
"dry_run": effective_dry_run,
|
| 1646 |
"pipeline_mode": "legacy",
|
|
@@ -1654,6 +2050,11 @@ def run_adversarial_self_play(
|
|
| 1654 |
"generator": generator_model,
|
| 1655 |
"answerer": answerer_model,
|
| 1656 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1657 |
"kimi_objective_mapping": {
|
| 1658 |
"grouped_rollouts": "TRL GRPO num_generations",
|
| 1659 |
"mean_centered_advantage": "GRPO relative reward baseline",
|
|
|
|
| 18 |
)
|
| 19 |
from osint_env.domain.models import Edge, EnvironmentConfig, TaskInstance
|
| 20 |
from osint_env.env.environment import OSINTEnvironment
|
| 21 |
+
from osint_env.env.reward import compute_graph_f1
|
| 22 |
from osint_env.llm import build_llm_client
|
| 23 |
from osint_env.training.config import (
|
| 24 |
KimiGRPOPhaseConfig,
|
|
|
|
| 32 |
GeneratorRewardFunction,
|
| 33 |
SwarmV2ReplayValidator,
|
| 34 |
decode_completion_text,
|
| 35 |
+
extract_answer_from_completion,
|
| 36 |
+
normalize_answer,
|
| 37 |
parse_generated_task_completion,
|
| 38 |
)
|
| 39 |
|
|
|
|
| 102 |
return edges
|
| 103 |
|
| 104 |
|
| 105 |
+
def _compact_shared_context(
|
| 106 |
+
shared_context: dict[str, Any],
|
| 107 |
+
max_nodes: int = 8,
|
| 108 |
+
max_edges: int = 6,
|
| 109 |
+
) -> dict[str, Any]:
|
| 110 |
+
return {
|
| 111 |
+
"nodes": list(shared_context.get("nodes", []))[:max_nodes],
|
| 112 |
+
"edges": list(shared_context.get("edges", []))[:max_edges],
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _task_shared_context(
|
| 117 |
+
env: OSINTEnvironment,
|
| 118 |
+
task: TaskInstance,
|
| 119 |
+
cfg: SelfPlayTrainingConfig,
|
| 120 |
+
) -> dict[str, Any]:
|
| 121 |
+
metadata = dict(task.metadata or {})
|
| 122 |
+
canonical_graph = metadata.get("canonical_graph")
|
| 123 |
+
if isinstance(canonical_graph, dict):
|
| 124 |
+
return {
|
| 125 |
+
"nodes": list(canonical_graph.get("nodes", []))[: cfg.swarm_v2.shared_context.max_nodes],
|
| 126 |
+
"edges": list(canonical_graph.get("edges", []))[: cfg.swarm_v2.shared_context.max_edges],
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
deterministic_seed = sum(ord(ch) for ch in task.task_id)
|
| 130 |
+
return _graph_context_for_prompt(
|
| 131 |
+
env=env,
|
| 132 |
+
max_nodes=cfg.swarm_v2.shared_context.max_nodes,
|
| 133 |
+
max_edges=cfg.swarm_v2.shared_context.max_edges,
|
| 134 |
+
rng=random.Random(deterministic_seed),
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _swarm_v2_worker_packets(
|
| 139 |
+
canonical_candidate: dict[str, Any],
|
| 140 |
+
shared_context: dict[str, Any],
|
| 141 |
+
swarm_cfg: SwarmV2SwarmConfig,
|
| 142 |
+
) -> dict[str, Any]:
|
| 143 |
+
path_edges = _edges_from_payload(
|
| 144 |
+
canonical_candidate.get("path", canonical_candidate.get("edges", [])),
|
| 145 |
+
max_edges=max(1, swarm_cfg.max_depth * 2),
|
| 146 |
+
)
|
| 147 |
+
if not path_edges:
|
| 148 |
+
path_edges = _edges_from_payload(canonical_candidate.get("edges", []), max_edges=2)
|
| 149 |
+
relation_path = [edge.rel for edge in path_edges]
|
| 150 |
+
start_node = path_edges[0].src if path_edges else ""
|
| 151 |
+
return {
|
| 152 |
+
"path_agent": {
|
| 153 |
+
"path_edges": [_edge_payload(edge) for edge in path_edges],
|
| 154 |
+
"goal": "Choose one contiguous replayable path from the canonical candidate.",
|
| 155 |
+
},
|
| 156 |
+
"question_agent": {
|
| 157 |
+
"start_node": start_node,
|
| 158 |
+
"relation_path": relation_path,
|
| 159 |
+
"goal": "Write a compact question that describes the path without leaking the answer.",
|
| 160 |
+
},
|
| 161 |
+
"context_agent": {
|
| 162 |
+
"shared_context": _compact_shared_context(shared_context),
|
| 163 |
+
"goal": "Keep support/context usage compact and graph-grounded.",
|
| 164 |
+
},
|
| 165 |
+
"planner": {
|
| 166 |
+
"max_agents": int(swarm_cfg.max_agents),
|
| 167 |
+
"max_breadth": int(swarm_cfg.max_breadth),
|
| 168 |
+
"max_depth": int(swarm_cfg.max_depth),
|
| 169 |
+
},
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def _serialize_tool_trace(tool_trace: Any) -> list[dict[str, Any]]:
|
| 174 |
+
serialized: list[dict[str, Any]] = []
|
| 175 |
+
for call in tool_trace or []:
|
| 176 |
+
tool_name = getattr(call, "tool_name", "")
|
| 177 |
+
args = getattr(call, "args", {})
|
| 178 |
+
output = getattr(call, "output", {})
|
| 179 |
+
if not tool_name:
|
| 180 |
+
continue
|
| 181 |
+
serialized.append(
|
| 182 |
+
{
|
| 183 |
+
"tool_name": str(tool_name),
|
| 184 |
+
"args": dict(args) if isinstance(args, dict) else {},
|
| 185 |
+
"output": dict(output) if isinstance(output, dict) else {},
|
| 186 |
+
}
|
| 187 |
+
)
|
| 188 |
+
return serialized
|
| 189 |
+
|
| 190 |
+
|
| 191 |
|
| 192 |
def _canonical_example_payload(
|
| 193 |
graph: Any,
|
|
|
|
| 202 |
"answer": "",
|
| 203 |
"task_type": "swarm_v2_trace",
|
| 204 |
"supporting_edges": [],
|
|
|
|
| 205 |
"subagent_outputs": ["path_agent: no replayable edge"],
|
| 206 |
"orchestrator": {
|
| 207 |
"spawn_count": 1,
|
|
|
|
| 214 |
|
| 215 |
traced_edges = traced_edges[:2]
|
| 216 |
spawn_count = min(swarm_cfg.max_agents, max(1, len(traced_edges) + 1))
|
|
|
|
| 217 |
return {
|
| 218 |
"question": emit_swarm_v2_question(traced_edges),
|
| 219 |
"answer": select_swarm_v2_answer(traced_edges),
|
| 220 |
"task_type": f"swarm_v2_{len(traced_edges)}hop_trace",
|
| 221 |
"supporting_edges": [_edge_payload(edge) for edge in traced_edges],
|
|
|
|
| 222 |
"subagent_outputs": [
|
| 223 |
+
f"path_agent_{idx}: {edge.src} --{edge.rel}--> {edge.dst}"
|
| 224 |
for idx, edge in enumerate(traced_edges)
|
| 225 |
+
]
|
| 226 |
+
+ [
|
| 227 |
+
"question_agent: emitted compact relation-path question",
|
| 228 |
+
"context_agent: kept shared context focused on replayable edges",
|
| 229 |
],
|
| 230 |
"orchestrator": {
|
| 231 |
"spawn_count": spawn_count,
|
|
|
|
| 266 |
swarm_cfg: SwarmV2SwarmConfig,
|
| 267 |
) -> str:
|
| 268 |
del swarm_cfg # kept for signature compatibility
|
| 269 |
+
compact_context = _compact_shared_context(shared_context)
|
|
|
|
|
|
|
|
|
|
| 270 |
return (
|
| 271 |
"You answer one OSINT graph question using ONLY the shared context.\n"
|
| 272 |
"Output rules:\n"
|
|
|
|
| 309 |
) -> list[dict[str, Any]]:
|
| 310 |
rows: list[dict[str, Any]] = []
|
| 311 |
for task in tasks:
|
| 312 |
+
shared_context = _task_shared_context(env=env, task=task, cfg=cfg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
rows.append(
|
| 315 |
{
|
|
|
|
| 411 |
anchors = "\n".join(f"- {question}" for question in anchor_questions)
|
| 412 |
canonical_mode = str(canonical_graph_mode).strip().lower() or "generate"
|
| 413 |
example_payload = _canonical_example_payload(graph, canonical_candidate, swarm_cfg)
|
| 414 |
+
worker_packets = _swarm_v2_worker_packets(
|
| 415 |
+
canonical_candidate=canonical_candidate,
|
| 416 |
+
shared_context=shared_context,
|
| 417 |
+
swarm_cfg=swarm_cfg,
|
| 418 |
+
)
|
| 419 |
canonical_instruction = (
|
| 420 |
"You may propose canonical_graph updates when they improve replayability and keep it graph-grounded."
|
| 421 |
if canonical_mode == "generate"
|
| 422 |
else "Reuse the provided canonical candidate as-is; do not add, remove, or modify canonical_graph nodes/edges."
|
| 423 |
)
|
| 424 |
+
canonical_compact = _compact_shared_context(canonical_candidate)
|
|
|
|
|
|
|
|
|
|
| 425 |
return (
|
| 426 |
+
"You coordinate a compact multi-agent OSINT task-generation swarm.\n"
|
| 427 |
"Output rules:\n"
|
| 428 |
"- Return ONLY one JSON object. No markdown. No prose. End with }.\n"
|
| 429 |
+
"- Required keys: question, answer, task_type, supporting_edges, subagent_outputs, orchestrator.\n"
|
| 430 |
+
"- Optional keys: canonical_graph, validation.\n"
|
| 431 |
"- supporting_edges: non-empty list of {src, rel, dst, confidence}, taken from canonical edges.\n"
|
| 432 |
+
"- supporting_edges must form one contiguous replayable path. Keep it compact.\n"
|
| 433 |
+
"- Do NOT emit verbose tool traces or neighbor dumps; replay tools are derived from supporting_edges.\n"
|
| 434 |
+
"- answer = final dst of the trace. question describes the path without leaking the answer.\n"
|
| 435 |
+
"- subagent_outputs: 2-4 terse strings summarizing path_agent/question_agent/context_agent work.\n"
|
| 436 |
"- orchestrator: integer keys spawn_count, finished_subtasks, critical_steps, breadth, depth.\n"
|
| 437 |
f"- canonical_graph_mode={canonical_mode}: {canonical_instruction}\n"
|
| 438 |
+
"- Favor minimal shared context per worker so question generation stays parallel-friendly.\n"
|
| 439 |
"Example (copy schema, not values):\n"
|
| 440 |
f"{json.dumps(example_payload, separators=(',', ':'), sort_keys=True)}\n"
|
| 441 |
+
"Worker packets:\n"
|
| 442 |
+
f"{json.dumps(worker_packets, separators=(',', ':'), sort_keys=True)}\n"
|
| 443 |
"Canonical candidate (use these edges):\n"
|
| 444 |
f"{json.dumps(canonical_compact, separators=(',', ':'), sort_keys=True)}\n"
|
| 445 |
+
"Shared context:\n"
|
| 446 |
+
f"{json.dumps(_compact_shared_context(shared_context), separators=(',', ':'), sort_keys=True)}\n"
|
| 447 |
f"Avoid these prior questions: {anchors}\n"
|
| 448 |
"JSON:"
|
| 449 |
)
|
|
|
|
| 491 |
"prompt": prompt,
|
| 492 |
"candidate_id": f"candidate_{idx}",
|
| 493 |
"canonical_graph_json": json.dumps(canonical_candidate, sort_keys=True),
|
| 494 |
+
"shared_context_json": json.dumps(shared_context, sort_keys=True),
|
| 495 |
+
"worker_packets_json": json.dumps(
|
| 496 |
+
_swarm_v2_worker_packets(
|
| 497 |
+
canonical_candidate=canonical_candidate,
|
| 498 |
+
shared_context=shared_context,
|
| 499 |
+
swarm_cfg=cfg.swarm_v2.generator_swarm,
|
| 500 |
+
),
|
| 501 |
+
sort_keys=True,
|
| 502 |
+
),
|
| 503 |
}
|
| 504 |
)
|
| 505 |
canonical_candidates.append(canonical_candidate)
|
|
|
|
| 532 |
"scale_rewards": str(phase.scale_rewards),
|
| 533 |
"logging_steps": int(phase.logging_steps),
|
| 534 |
"save_steps": int(phase.save_steps),
|
| 535 |
+
"save_total_limit": int(phase.save_total_limit),
|
| 536 |
+
"optim": str(phase.optim),
|
| 537 |
+
"bf16": bool(phase.bf16),
|
| 538 |
+
"tf32": bool(phase.tf32),
|
| 539 |
+
"gradient_checkpointing": bool(phase.gradient_checkpointing),
|
| 540 |
+
"dataloader_num_workers": int(phase.dataloader_num_workers),
|
| 541 |
+
"dataloader_persistent_workers": bool(phase.dataloader_persistent_workers),
|
| 542 |
+
"dataloader_prefetch_factor": int(phase.dataloader_prefetch_factor),
|
| 543 |
+
"generation_batch_size": int(phase.generation_batch_size),
|
| 544 |
+
"max_prompt_length": int(phase.max_prompt_length),
|
| 545 |
"remove_unused_columns": False,
|
| 546 |
"use_vllm": bool(phase.use_vllm),
|
| 547 |
"vllm_mode": str(phase.vllm_mode),
|
|
|
|
| 650 |
|
| 651 |
final_dir = output_dir / "final_model"
|
| 652 |
trainer.save_model(str(final_dir))
|
| 653 |
+
trainer_tokenizer = getattr(trainer, "processing_class", None) or getattr(trainer, "tokenizer", None)
|
| 654 |
+
if trainer_tokenizer is not None and hasattr(trainer_tokenizer, "save_pretrained"):
|
| 655 |
+
trainer_tokenizer.save_pretrained(str(final_dir))
|
| 656 |
+
checkpoint_dirs = [str(path) for path in sorted(output_dir.glob("checkpoint-*")) if path.is_dir()]
|
| 657 |
|
| 658 |
global_step = int(getattr(train_output, "global_step", 0))
|
| 659 |
training_loss = float(getattr(train_output, "training_loss", 0.0))
|
| 660 |
|
| 661 |
+
total_param_count = int(sum(param.numel() for param in trainer.model.parameters()))
|
| 662 |
result = {
|
| 663 |
"model_path": str(final_dir),
|
| 664 |
+
"final_model_path": str(final_dir),
|
| 665 |
+
"phase_output_dir": str(output_dir),
|
| 666 |
+
"checkpoint_dirs": checkpoint_dirs,
|
| 667 |
"global_step": global_step,
|
| 668 |
"training_loss": training_loss,
|
| 669 |
"train_rows": len(rows),
|
| 670 |
"tuning_mode": str(tuning_mode).strip().lower() or "full",
|
| 671 |
+
"is_full_finetune": str(tuning_mode).strip().lower() != "lora",
|
| 672 |
}
|
| 673 |
|
| 674 |
log_history = list(getattr(getattr(trainer, "state", None), "log_history", []) or [])
|
|
|
|
| 708 |
"grad_norm_max": max(grad_norm_values) if grad_norm_values else 0.0,
|
| 709 |
"entropy_min": min(entropy_values) if entropy_values else 0.0,
|
| 710 |
"entropy_max": max(entropy_values) if entropy_values else 0.0,
|
| 711 |
+
"total_param_count": total_param_count,
|
| 712 |
"trainable_param_count": trainable_param_count,
|
| 713 |
+
"trainable_fraction": (
|
| 714 |
+
float(trainable_param_count / total_param_count)
|
| 715 |
+
if total_param_count > 0
|
| 716 |
+
else 0.0
|
| 717 |
+
),
|
| 718 |
"params_with_grad": params_with_grad,
|
| 719 |
"nonzero_grad_tensors": nonzero_grad_tensors,
|
| 720 |
"fingerprint_param_count": len(pre_update_fingerprint),
|
|
|
|
| 836 |
round_index: int,
|
| 837 |
count: int,
|
| 838 |
max_support_edges: int,
|
| 839 |
+
max_new_tokens: int,
|
| 840 |
) -> list[TaskInstance]:
|
| 841 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 842 |
+
import torch
|
| 843 |
|
| 844 |
if count <= 0:
|
| 845 |
return []
|
|
|
|
| 847 |
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
| 848 |
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
|
| 849 |
tokenizer.pad_token = tokenizer.eos_token
|
| 850 |
+
model_kwargs: dict[str, Any] = {}
|
| 851 |
+
if torch.cuda.is_available():
|
| 852 |
+
model_kwargs["device_map"] = "auto"
|
| 853 |
+
model_kwargs["torch_dtype"] = torch.bfloat16
|
| 854 |
+
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **model_kwargs)
|
| 855 |
model.eval()
|
| 856 |
|
|
|
|
|
|
|
| 857 |
device = next(model.parameters()).device
|
| 858 |
generated: list[TaskInstance] = []
|
| 859 |
|
|
|
|
| 866 |
with torch.no_grad():
|
| 867 |
output = model.generate(
|
| 868 |
**encoded,
|
| 869 |
+
max_new_tokens=max(64, int(max_new_tokens)),
|
| 870 |
do_sample=True,
|
| 871 |
top_p=0.95,
|
| 872 |
temperature=1.0,
|
|
|
|
| 960 |
path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
|
| 961 |
|
| 962 |
|
| 963 |
+
def _generate_answerer_completion_texts_with_model(
|
| 964 |
+
model_name_or_path: str,
|
| 965 |
+
prompts: list[str],
|
| 966 |
+
max_new_tokens: int,
|
| 967 |
+
) -> list[str]:
|
| 968 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 969 |
+
import torch
|
| 970 |
+
|
| 971 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
| 972 |
+
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
|
| 973 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 974 |
+
|
| 975 |
+
model_kwargs: dict[str, Any] = {}
|
| 976 |
+
if torch.cuda.is_available():
|
| 977 |
+
model_kwargs["device_map"] = "auto"
|
| 978 |
+
model_kwargs["torch_dtype"] = torch.bfloat16
|
| 979 |
+
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **model_kwargs)
|
| 980 |
+
model.eval()
|
| 981 |
+
device = next(model.parameters()).device
|
| 982 |
+
|
| 983 |
+
completions: list[str] = []
|
| 984 |
+
for prompt in prompts:
|
| 985 |
+
encoded = tokenizer(prompt, return_tensors="pt")
|
| 986 |
+
encoded = {key: value.to(device) for key, value in encoded.items()}
|
| 987 |
+
with torch.no_grad():
|
| 988 |
+
output = model.generate(
|
| 989 |
+
**encoded,
|
| 990 |
+
max_new_tokens=max(16, int(max_new_tokens)),
|
| 991 |
+
do_sample=False,
|
| 992 |
+
pad_token_id=tokenizer.eos_token_id,
|
| 993 |
+
)
|
| 994 |
+
completion_ids = output[0][encoded["input_ids"].shape[1] :]
|
| 995 |
+
completions.append(tokenizer.decode(completion_ids, skip_special_tokens=True))
|
| 996 |
+
return completions
|
| 997 |
+
|
| 998 |
+
|
| 999 |
+
def _top_validation_reasons(validation_reports: list[dict[str, Any]]) -> list[tuple[str, int]]:
|
| 1000 |
+
counts: dict[str, int] = {}
|
| 1001 |
+
for report in validation_reports:
|
| 1002 |
+
validation = report.get("validation", {}) if isinstance(report, dict) else {}
|
| 1003 |
+
reasons = validation.get("reasons", []) if isinstance(validation, dict) else []
|
| 1004 |
+
for reason in reasons:
|
| 1005 |
+
token = str(reason).strip()
|
| 1006 |
+
if not token:
|
| 1007 |
+
continue
|
| 1008 |
+
counts[token] = counts.get(token, 0) + 1
|
| 1009 |
+
return sorted(counts.items(), key=lambda item: (-item[1], item[0]))
|
| 1010 |
+
|
| 1011 |
+
|
| 1012 |
+
def _run_post_training_evaluation(
|
| 1013 |
+
env_config: EnvironmentConfig,
|
| 1014 |
+
training_config: SelfPlayTrainingConfig,
|
| 1015 |
+
generator_model: str,
|
| 1016 |
+
answerer_models: dict[str, str],
|
| 1017 |
+
output_dir: Path,
|
| 1018 |
+
pipeline_mode: str,
|
| 1019 |
+
effective_dry_run: bool,
|
| 1020 |
+
) -> dict[str, Any]:
|
| 1021 |
+
tasks_path = output_dir / "post_training_eval_generated_tasks.json"
|
| 1022 |
+
validation_path = output_dir / "post_training_eval_validation_reports.json"
|
| 1023 |
+
payload_path = output_dir / "post_training_evaluation.json"
|
| 1024 |
+
payload: dict[str, Any] = {
|
| 1025 |
+
"pipeline_mode": pipeline_mode,
|
| 1026 |
+
"generator_model": generator_model,
|
| 1027 |
+
"answerer_models": dict(answerer_models),
|
| 1028 |
+
"generated_tasks_path": str(tasks_path),
|
| 1029 |
+
"validation_reports_path": str(validation_path),
|
| 1030 |
+
"skipped": False,
|
| 1031 |
+
}
|
| 1032 |
+
|
| 1033 |
+
if effective_dry_run:
|
| 1034 |
+
payload.update({"skipped": True, "reason": "dry_run"})
|
| 1035 |
+
_save_payload(validation_path, [])
|
| 1036 |
+
_save_payload(tasks_path, [])
|
| 1037 |
+
_save_payload(payload_path, payload)
|
| 1038 |
+
payload["path"] = str(payload_path)
|
| 1039 |
+
return payload
|
| 1040 |
+
|
| 1041 |
+
try:
|
| 1042 |
+
env = OSINTEnvironment(env_config, llm=build_llm_client(env_config.llm))
|
| 1043 |
+
rng = random.Random(env_config.seed + 9973)
|
| 1044 |
+
validation_reports: list[dict[str, Any]] = []
|
| 1045 |
+
|
| 1046 |
+
if pipeline_mode == "swarm_v2":
|
| 1047 |
+
generator_rows, prompt_canonical_candidates = _build_swarm_v2_generator_rows(env, training_config, rng)
|
| 1048 |
+
completion_texts = _sample_swarm_v2_completion_texts_with_model(
|
| 1049 |
+
env=env,
|
| 1050 |
+
cfg=training_config,
|
| 1051 |
+
model_name_or_path=generator_model,
|
| 1052 |
+
prompts=[row["prompt"] for row in generator_rows],
|
| 1053 |
+
count=max(1, training_config.post_training_eval_questions * 2),
|
| 1054 |
+
seen_questions=[task.question for task in env.tasks],
|
| 1055 |
+
)
|
| 1056 |
+
generated_tasks, validation_reports, _, _ = _materialize_swarm_v2_completions(
|
| 1057 |
+
env=env,
|
| 1058 |
+
cfg=training_config,
|
| 1059 |
+
completion_texts=completion_texts,
|
| 1060 |
+
round_index=max(1, training_config.rounds) + 1,
|
| 1061 |
+
seen_questions=[task.question for task in env.tasks],
|
| 1062 |
+
prompt_canonical_candidates=prompt_canonical_candidates,
|
| 1063 |
+
)
|
| 1064 |
+
if not generated_tasks:
|
| 1065 |
+
generated_tasks, validation_reports, _, _ = _materialize_swarm_v2_completions(
|
| 1066 |
+
env=env,
|
| 1067 |
+
cfg=training_config,
|
| 1068 |
+
completion_texts=_fallback_swarm_v2_completion_texts(
|
| 1069 |
+
env=env,
|
| 1070 |
+
cfg=training_config,
|
| 1071 |
+
round_index=max(1, training_config.rounds) + 1,
|
| 1072 |
+
rng=rng,
|
| 1073 |
+
),
|
| 1074 |
+
round_index=max(1, training_config.rounds) + 1,
|
| 1075 |
+
seen_questions=[task.question for task in env.tasks],
|
| 1076 |
+
prompt_canonical_candidates=None,
|
| 1077 |
+
)
|
| 1078 |
+
generated_tasks = generated_tasks[: max(1, training_config.post_training_eval_questions)]
|
| 1079 |
+
answer_rows = _build_swarm_v2_answerer_rows(env, generated_tasks, training_config)
|
| 1080 |
+
reward_fn = AnswererRewardFunction(
|
| 1081 |
+
graph=env.graph,
|
| 1082 |
+
pipeline_mode="swarm_v2",
|
| 1083 |
+
parl_max_parallel_hint=training_config.swarm_v2.answerer_swarm.max_agents,
|
| 1084 |
+
)
|
| 1085 |
+
else:
|
| 1086 |
+
generator_rows = _build_generator_rows(env=env, cfg=training_config, rng=rng)
|
| 1087 |
+
generated_tasks = _sample_generated_tasks_with_model(
|
| 1088 |
+
model_name_or_path=generator_model,
|
| 1089 |
+
prompts=[row["prompt"] for row in generator_rows],
|
| 1090 |
+
round_index=max(1, training_config.rounds) + 1,
|
| 1091 |
+
count=max(1, training_config.post_training_eval_questions),
|
| 1092 |
+
max_support_edges=training_config.max_support_edges,
|
| 1093 |
+
max_new_tokens=training_config.generated_task_max_new_tokens,
|
| 1094 |
+
)
|
| 1095 |
+
if not generated_tasks:
|
| 1096 |
+
generated_tasks = _fallback_generated_tasks(
|
| 1097 |
+
base_tasks=list(env.tasks),
|
| 1098 |
+
round_index=max(1, training_config.rounds) + 1,
|
| 1099 |
+
count=max(1, training_config.post_training_eval_questions),
|
| 1100 |
+
rng=rng,
|
| 1101 |
+
)
|
| 1102 |
+
answer_rows = _build_answerer_rows(generated_tasks)
|
| 1103 |
+
reward_fn = AnswererRewardFunction(graph=env.graph)
|
| 1104 |
+
|
| 1105 |
+
_save_tasks(tasks_path, generated_tasks)
|
| 1106 |
+
_save_payload(validation_path, validation_reports)
|
| 1107 |
+
|
| 1108 |
+
model_evaluations: dict[str, dict[str, Any]] = {}
|
| 1109 |
+
for model_label, answerer_model in answerer_models.items():
|
| 1110 |
+
answerer_completions = _generate_answerer_completion_texts_with_model(
|
| 1111 |
+
model_name_or_path=answerer_model,
|
| 1112 |
+
prompts=[row["prompt"] for row in answer_rows],
|
| 1113 |
+
max_new_tokens=training_config.post_training_eval_answer_max_new_tokens,
|
| 1114 |
+
)
|
| 1115 |
+
rewards = reward_fn(
|
| 1116 |
+
prompts=[row["prompt"] for row in answer_rows],
|
| 1117 |
+
completions=answerer_completions,
|
| 1118 |
+
answer=[row["answer"] for row in answer_rows],
|
| 1119 |
+
question=[row["question"] for row in answer_rows],
|
| 1120 |
+
supporting_edges_json=[row["supporting_edges_json"] for row in answer_rows],
|
| 1121 |
+
difficulty=[row["difficulty"] for row in answer_rows],
|
| 1122 |
+
)
|
| 1123 |
+
|
| 1124 |
+
episodes: list[dict[str, Any]] = []
|
| 1125 |
+
for task, row, completion_text, reward in zip(generated_tasks, answer_rows, answerer_completions, rewards):
|
| 1126 |
+
support_edges = AnswererRewardFunction._parse_support_edges(row["supporting_edges_json"])
|
| 1127 |
+
pred_edges = AnswererRewardFunction._extract_predicted_edges(completion_text, support_edges)
|
| 1128 |
+
predicted_answer = normalize_answer(extract_answer_from_completion(completion_text))
|
| 1129 |
+
@@ +1129 @@
+            target_answer = normalize_answer(task.answer)
+            graph_f1 = compute_graph_f1(pred_edges, support_edges)
+            episodes.append(
+                {
+                    "task_id": task.task_id,
+                    "task_type": task.task_type,
+                    "question": task.question,
+                    "task_answer": target_answer,
+                    "agent_answer": predicted_answer,
+                    "reward": float(reward),
+                    "graph_f1": float(graph_f1),
+                    "success": int(predicted_answer == target_answer),
+                    "support_edge_count": len(support_edges),
+                    "predicted_edge_count": len(pred_edges),
+                    "completion_length": len(completion_text),
+                }
+            )
+
+            episode_count = len(episodes)
+            model_evaluations[model_label] = {
+                "model_path": answerer_model,
+                "episodes": episodes,
+                "summary": {
+                    "episodes": episode_count,
+                    "task_success_rate": (
+                        float(sum(row["success"] for row in episodes) / max(1, episode_count))
+                        if episodes
+                        else 0.0
+                    ),
+                    "avg_reward": (
+                        float(sum(float(row["reward"]) for row in episodes) / max(1, episode_count))
+                        if episodes
+                        else 0.0
+                    ),
+                    "avg_graph_f1": (
+                        float(sum(float(row["graph_f1"]) for row in episodes) / max(1, episode_count))
+                        if episodes
+                        else 0.0
+                    ),
+                    "avg_completion_length": (
+                        float(sum(int(row["completion_length"]) for row in episodes) / max(1, episode_count))
+                        if episodes
+                        else 0.0
+                    ),
+                },
+            }
+
+        final_summary = model_evaluations.get("finetuned_answerer", {}).get("summary", {})
+        baseline_summary = model_evaluations.get("original_answerer", {}).get("summary", {})
+        summary = {
+            "generated_task_count": len(generated_tasks),
+            "generator_valid_rate": (
+                float(len(generated_tasks) / max(1, len(validation_reports)))
+                if validation_reports
+                else 1.0
+            ),
+            "compared_models": sorted(model_evaluations.keys()),
+            "finetuned_answerer": dict(final_summary),
+            "original_answerer": dict(baseline_summary),
+            "delta_vs_original": {
+                "task_success_rate": float(final_summary.get("task_success_rate", 0.0) - baseline_summary.get("task_success_rate", 0.0)),
+                "avg_reward": float(final_summary.get("avg_reward", 0.0) - baseline_summary.get("avg_reward", 0.0)),
+                "avg_graph_f1": float(final_summary.get("avg_graph_f1", 0.0) - baseline_summary.get("avg_graph_f1", 0.0)),
+            },
+            "top_generator_invalid_reasons": _top_validation_reasons(validation_reports)[:5],
+        }
+        payload.update(
+            {
+                "summary": summary,
+                "model_evaluations": model_evaluations,
+            }
+        )
+    except Exception as exc:
+        payload.update({"skipped": True, "reason": f"{type(exc).__name__}: {exc}"})
+
+    if not tasks_path.exists():
+        _save_payload(tasks_path, [])
+    if not validation_path.exists():
+        _save_payload(validation_path, [])
+    _save_payload(payload_path, payload)
+    payload["path"] = str(payload_path)
+    return payload
+
+
 def _fallback_swarm_v2_completion_texts(
     env: OSINTEnvironment,
     cfg: SelfPlayTrainingConfig,
@@ +1283 @@
     seen_questions: list[str],
 ) -> list[str]:
     from transformers import AutoModelForCausalLM, AutoTokenizer
+    import torch
 
     if count <= 0:
         return []
@@ +1291 @@
     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
     if tokenizer.pad_token is None and tokenizer.eos_token is not None:
         tokenizer.pad_token = tokenizer.eos_token
+    model_kwargs: dict[str, Any] = {}
+    if torch.cuda.is_available():
+        model_kwargs["device_map"] = "auto"
+        model_kwargs["torch_dtype"] = torch.bfloat16
+    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **model_kwargs)
     model.eval()
 
     device = next(model.parameters()).device
     completions: list[str] = []
     validator = SwarmV2ReplayValidator(
@@ +1318 @@
         with torch.no_grad():
             output = model.generate(
                 **encoded,
+                max_new_tokens=max(64, int(cfg.generated_task_max_new_tokens)),
                 do_sample=True,
                 top_p=top_p,
                 temperature=temperature,
@@ +1375 @@
             max_support_edges=cfg.swarm_v2.validation.max_support_edges,
         )
         validation = validator.validate(candidate)
+        replay_edges = list(validation.replayed_edges or candidate.supporting_edges)
+        materialized_tool_trace = _serialize_tool_trace(candidate.tool_trace)
+        if not materialized_tool_trace and replay_edges:
+            materialized_tool_trace = build_swarm_v2_tool_trace(env.graph, replay_edges)
 
         if use_fixed_canonical and prompt_canonical_candidates and completion_idx < len(prompt_canonical_candidates):
             canonical_graph = dict(prompt_canonical_candidates[completion_idx])
@@ +1421 @@
             {
                 "candidate_index": completion_idx,
                 "question": candidate.question,
+                "tool_trace": materialized_tool_trace,
                 "replayed_edges": validation.to_dict()["replayed_edges"],
             }
         )
@@ +1447 @@
             "difficulty": "hard",
             "scenario": "swarm_v2_trace",
             "canonical_graph": canonical_graph,
+            "tool_trace": materialized_tool_trace,
             "subagent_outputs": list(candidate.subagent_outputs),
             "validation": validation.to_dict(),
             "shared_context_budget": {
@@ +1468 @@
                 task_type=candidate.task_type or "swarm_v2_trace",
                 question=candidate.question,
                 answer=candidate.answer,
+                supporting_edges=replay_edges,
                 metadata=metadata,
             )
         )
@@ +1494 @@
     seed_tasks = list(env.tasks)
     seed_questions = [task.question for task in seed_tasks]
     generator_model, answerer_model = _resolve_initial_models(training_config)
+    initial_generator_model = str(generator_model)
+    initial_answerer_model = str(answerer_model)
     rng = random.Random(env_config.seed)
 
     bootstrap_completions = _fallback_swarm_v2_completion_texts(
@@ +1747 @@
         }
     )
 
+    post_training_evaluation = _run_post_training_evaluation(
+        env_config=env_config,
+        training_config=training_config,
+        generator_model=generator_model,
+        answerer_models={
+            "finetuned_answerer": answerer_model,
+            "original_answerer": initial_answerer_model,
+        },
+        output_dir=run_dir,
+        pipeline_mode="swarm_v2",
+        effective_dry_run=effective_dry_run,
+    )
     final_payload = {
         "dry_run": effective_dry_run,
         "pipeline_mode": "swarm_v2",
@@ +1772 @@
             "generator": generator_model,
             "answerer": answerer_model,
         },
+        "initial_models": {
+            "generator": initial_generator_model,
+            "answerer": initial_answerer_model,
+        },
+        "post_training_evaluation": post_training_evaluation,
         "kimi_objective_mapping": {
             "grouped_rollouts": "TRL GRPO num_generations",
             "mean_centered_advantage": "GRPO relative reward baseline",
@@ +1818 @@
     seed_tasks = list(env.tasks)
 
     generator_model, answerer_model = _resolve_initial_models(training_config)
+    initial_generator_model = str(generator_model)
+    initial_answerer_model = str(answerer_model)
 
     rng = random.Random(env_config.seed)
     rounds_payload: list[dict[str, Any]] = []
@@ +1940 @@
             round_index=round_index,
             count=training_config.generated_tasks_per_round,
             max_support_edges=training_config.max_support_edges,
+            max_new_tokens=training_config.generated_task_max_new_tokens,
         )
         if not generated_tasks:
             generated_tasks = _fallback_generated_tasks(
@@ +2025 @@
         }
     )
 
+    post_training_evaluation = _run_post_training_evaluation(
+        env_config=env_config,
+        training_config=training_config,
+        generator_model=generator_model,
+        answerer_models={
+            "finetuned_answerer": answerer_model,
+            "original_answerer": initial_answerer_model,
+        },
+        output_dir=run_dir,
+        pipeline_mode="legacy",
+        effective_dry_run=effective_dry_run,
+    )
    final_payload = {
         "dry_run": effective_dry_run,
         "pipeline_mode": "legacy",
@@ +2050 @@
             "generator": generator_model,
             "answerer": answerer_model,
         },
+        "initial_models": {
+            "generator": initial_generator_model,
+            "answerer": initial_answerer_model,
+        },
+        "post_training_evaluation": post_training_evaluation,
         "kimi_objective_mapping": {
             "grouped_rollouts": "TRL GRPO num_generations",
             "mean_centered_advantage": "GRPO relative reward baseline",
tests/test_environment.py
CHANGED
@@ -30,6 +30,22 @@ def test_search_memory_tool_returns_results_after_tool_use():
     assert obs.tool_outputs[-1]["output"]["count"] >= 1
 
 
+def test_search_shared_context_returns_task_local_hits():
+    env = OSINTEnvironment(EnvironmentConfig(max_steps=6, seed=7))
+    obs = env.reset()
+    assert obs.task["shared_context_available"] is True
+    answer = str(env.state.task.answer if env.state else "")
+
+    obs, reward, done, _ = env.step(
+        Action(ActionType.CALL_TOOL, {"tool_name": "search_shared_context", "args": {"query": answer, "k": 5}})
+    )
+    assert done is False
+    assert isinstance(reward, float)
+    assert obs.tool_outputs[-1]["tool"] == "search_shared_context"
+    assert obs.tool_outputs[-1]["output"]["shared_context_available"] is True
+    assert obs.tool_outputs[-1]["output"]["count"] >= 1
+
+
 def test_invalid_tool_call_does_not_crash_episode():
     env = OSINTEnvironment(EnvironmentConfig(max_steps=4, seed=8))
     env.reset()
tests/test_hf_jobs.py
ADDED
@@ -0,0 +1,64 @@
+from osint_env.training.hf_jobs import (
+    DEFAULT_HF_JOB_IMAGE,
+    _build_job_command,
+    _default_train_output_dir,
+    _resolve_job_image,
+)
+
+
+def test_resolve_job_image_prefers_explicit_image():
+    assert _resolve_job_image("python:3.12", "owner/space") == "python:3.12"
+
+
+def test_resolve_job_image_supports_space_fallback():
+    assert _resolve_job_image("", "owner/space") == "hf.co/spaces/owner/space"
+    assert _resolve_job_image("", "") == DEFAULT_HF_JOB_IMAGE
+
+
+def test_default_train_output_dir_uses_bucket_mount_when_present():
+    assert _default_train_output_dir("my-bucket", "run-42") == "/training-outputs/run-42"
+    assert _default_train_output_dir("", "run-42") == "artifacts/run-42"
+
+
+def test_build_job_command_runs_train_directly_when_image_has_code():
+    command = _build_job_command(
+        env_config_path="config/shared_config.json",
+        train_config_path="config/train.json",
+        output_dir="artifacts/self_play",
+        dry_run=False,
+        repo_url="",
+        repo_ref="",
+        repo_subdir="",
+        setup_command="",
+    )
+    assert command == [
+        "osint-env",
+        "train-self-play",
+        "--config",
+        "config/shared_config.json",
+        "--train-config",
+        "config/train.json",
+        "--train-output-dir",
+        "artifacts/self_play",
+    ]
+
+
+def test_build_job_command_bootstraps_repo_when_requested():
+    command = _build_job_command(
+        env_config_path="config/shared_config.json",
+        train_config_path="config/train.json",
+        output_dir="/training-outputs/run-1",
+        dry_run=True,
+        repo_url="https://github.com/example/osint-env.git",
+        repo_ref="main",
+        repo_subdir=".",
+        setup_command="python -m pip install flash-attn --no-build-isolation",
+    )
+    assert command[:2] == ["bash", "-lc"]
+    script = command[2]
+    assert "git clone --depth 1 --branch main https://github.com/example/osint-env.git /workspace/osint_env_app" in script
+    assert "python -m pip install -e '.[train]'" in script
+    assert "python -m pip install flash-attn --no-build-isolation" in script
+    assert "--train-config config/train.json" in script
+    assert "--train-output-dir /training-outputs/run-1" in script
+    assert "--dry-run" in script
tests/test_openai_baseline.py
CHANGED
@@ -7,6 +7,7 @@ def test_openai_baseline_toolset_contains_answer_and_graph_actions():
     assert "submit_answer" in names
     assert "add_edge" in names
     assert "search_memory" in names
+    assert "search_shared_context" in names
     assert "get_post" in names
 
 
tests/test_self_play_swarm_v2.py
CHANGED
@@ -195,8 +195,8 @@ def test_swarm_v2_replay_validator_accepts_valid_candidate_and_rejects_invalid_c
     no_trace_payload = deepcopy(payload)
     no_trace_payload["tool_trace"] = []
     no_trace = validator.validate(parse_generated_task_completion(json.dumps(no_trace_payload)))
-    assert no_trace.is_valid is
-    assert
+    assert no_trace.is_valid is True
+    assert no_trace.replayed_edges
 
     unseen_payload = deepcopy(payload)
     unseen_payload["supporting_edges"][0]["dst"] = "user_missing"
@@ -205,6 +205,22 @@
     assert "unseen_nodes_or_edges" in unseen.reasons
 
 
+def test_swarm_v2_replay_validator_can_derive_tool_trace_from_support_edges():
+    cfg = SelfPlayTrainingConfig(pipeline_mode="swarm_v2")
+    env = OSINTEnvironment(EnvironmentConfig(seed=27, n_users=18, max_steps=6))
+    payload = _build_valid_candidate_payload(env, cfg)
+    payload.pop("tool_trace", None)
+
+    validator = SwarmV2ReplayValidator(
+        graph=env.graph,
+        validation=cfg.swarm_v2.validation,
+        shared_context=cfg.swarm_v2.shared_context,
+        seen_questions=[],
+    )
+    result = validator.validate(parse_generated_task_completion(json.dumps(payload)))
+    assert result.is_valid is True
+
+
 def test_swarm_v2_replay_validator_rejects_non_unique_paths():
     graph = CanonicalGraph(
         nodes={
@@ -337,7 +353,9 @@ def test_swarm_v2_generator_reward_grades_invalid_outputs_instead_of_constant_pe
     scores = reward_fn(completions=[missing_everything, partial_json, partial_edges, json.dumps(valid_payload)])
 
     assert len(set(scores)) > 2
-    assert scores[
+    assert scores[2] > scores[0]
+    assert scores[2] > scores[1]
+    assert scores[3] != scores[0]
     assert reward_fn._debug_last_batch["batch_reward_std"] > 0.0
     assert reward_fn._debug_last_batch["valid_output_ratio"] == 0.25
@@ -373,6 +391,24 @@
     assert candidate.orchestrator.depth == 0
 
 
+def test_parse_generated_task_completion_accepts_result_alias_in_tool_trace():
+    cfg = SelfPlayTrainingConfig(pipeline_mode="swarm_v2")
+    env = OSINTEnvironment(EnvironmentConfig(seed=35, n_users=18, max_steps=6))
+    payload = _build_valid_candidate_payload(env, cfg)
+    payload["tool_trace"] = [
+        {
+            "tool": call["tool_name"],
+            "args": dict(call["args"]),
+            "result": dict(call["output"]),
+        }
+        for call in payload["tool_trace"]
+    ]
+
+    candidate = parse_generated_task_completion(json.dumps(payload))
+    assert candidate.tool_trace
+    assert all(call.output for call in candidate.tool_trace)
+
+
 def test_swarm_v2_generator_reward_is_robust_to_parse_crashes():
     """Reward function must never raise: any malformed completion gets a floor reward."""
     cfg = SelfPlayTrainingConfig(pipeline_mode="swarm_v2")
@@ -432,6 +468,11 @@ def test_swarm_v2_dry_run_writes_new_artifacts_and_preserves_legacy_contract(tmp
     loaded = json.loads(Path(artifacts[key]).read_text(encoding="utf-8"))
     assert loaded is not None
 
+    post_eval = payload["post_training_evaluation"]
+    assert Path(post_eval["path"]).exists()
+    assert sorted(post_eval["answerer_models"].keys()) == ["finetuned_answerer", "original_answerer"]
+    assert json.loads(Path(post_eval["path"]).read_text(encoding="utf-8"))["skipped"] is True
+
 
 def test_swarm_v2_fixed_canonical_mode_reuses_prompt_candidates(tmp_path: Path):
     env_cfg = EnvironmentConfig(seed=19, n_users=14, max_steps=6)
tests/test_swarm_agent.py
CHANGED
@@ -1,3 +1,4 @@
+from osint_env.llm.interface import LLMResponse
 from osint_env.agents.swarm_agent import SwarmAgentRunner
 from osint_env.domain.models import EnvironmentConfig, SwarmConfig
 from osint_env.env.environment import OSINTEnvironment
@@ -15,3 +16,27 @@ def test_swarm_runner_emits_spawn_telemetry():
     assert info["spawn_count"] > 0
     assert "spawn_auxiliary" in info["reward_components"]
     assert info["spawn_critical_steps"] > 0
+
+
+class RecordingLLM:
+    def __init__(self):
+        self.tool_names: list[str] = []
+
+    def generate(self, messages, tools):
+        del messages
+        self.tool_names = [tool["function"]["name"] for tool in tools]
+        return LLMResponse(content="{}", tool_calls=[])
+
+
+def test_swarm_runner_passes_lookup_tools_to_llm():
+    config = EnvironmentConfig(
+        seed=16,
+        max_steps=6,
+        swarm=SwarmConfig(enabled=True, max_agents=2, max_breadth=2, max_width=2, max_depth=1, planner_rounds=1),
+    )
+    env = OSINTEnvironment(config)
+    llm = RecordingLLM()
+    SwarmAgentRunner(env, llm=llm).run_episode()
+
+    assert "search_memory" in llm.tool_names
+    assert "search_shared_context" in llm.tool_names
tests/test_training_config.py
CHANGED
@@ -14,6 +14,13 @@ def test_self_play_config_defaults_when_missing():
     assert cfg.generator_phase.max_steps >= 1
     assert cfg.answerer_phase.max_steps >= 1
     assert cfg.generator_reward_weights.hardness > 0.0
+    assert cfg.generated_task_max_new_tokens >= 32
+    assert cfg.post_training_eval_questions >= 1
+    assert cfg.generator_phase.optim == "adamw_torch_fused"
+    assert cfg.generator_phase.bf16 is True
+    assert cfg.generator_phase.tf32 is True
+    assert cfg.generator_phase.generation_batch_size >= 1
+    assert cfg.generator_phase.max_prompt_length >= 32
     assert cfg.swarm_v2.generator_swarm.shared_context is True
     assert cfg.swarm_v2.validation.max_support_edges >= 1
     assert cfg.wandb_enabled is False
@@ -41,6 +48,9 @@ def test_self_play_config_parses_overrides(tmp_path: Path):
         "shared_model_name_or_path": "/models/local-base",
         "seed_tasks_per_round": 12,
         "generated_tasks_per_round": 18,
+        "generated_task_max_new_tokens": 640,
+        "post_training_eval_questions": 9,
+        "post_training_eval_answer_max_new_tokens": 96,
         "swarm_v2": {
             "generator_swarm": {
                 "shared_context": True,
@@ -90,6 +100,12 @@
            "model_name_or_path": "Qwen/Qwen2.5-3B-Instruct",
            "max_steps": 77,
            "num_generations": 6,
+           "optim": "adamw_torch",
+           "bf16": False,
+           "tf32": False,
+           "generation_batch_size": 12,
+           "max_prompt_length": 768,
+           "save_total_limit": 3,
            "loss_type": "grpo",
            "scale_rewards": "group",
            "output_subdir": "gen_phase",
@@ -98,6 +114,9 @@
            "model_name_or_path": "Qwen/Qwen2.5-1.5B-Instruct",
            "max_steps": 55,
            "num_generations": 5,
+           "dataloader_num_workers": 6,
+           "dataloader_persistent_workers": False,
+           "dataloader_prefetch_factor": 6,
            "output_subdir": "ans_phase",
        },
    }
@@ -121,6 +140,9 @@
     assert cfg.shared_model_name_or_path == "/models/local-base"
     assert cfg.seed_tasks_per_round == 12
     assert cfg.generated_tasks_per_round == 18
+    assert cfg.generated_task_max_new_tokens == 640
+    assert cfg.post_training_eval_questions == 9
+    assert cfg.post_training_eval_answer_max_new_tokens == 96
     assert cfg.swarm_v2.generator_swarm.max_agents == 5
     assert cfg.swarm_v2.answerer_swarm.max_agents == 4
     assert cfg.swarm_v2.validation.max_support_edges == 6
@@ -133,6 +155,12 @@
     assert cfg.generator_phase.model_name_or_path == "Qwen/Qwen2.5-3B-Instruct"
     assert cfg.generator_phase.max_steps == 77
     assert cfg.generator_phase.num_generations == 6
+    assert cfg.generator_phase.optim == "adamw_torch"
+    assert cfg.generator_phase.bf16 is False
+    assert cfg.generator_phase.tf32 is False
+    assert cfg.generator_phase.generation_batch_size == 12
+    assert cfg.generator_phase.max_prompt_length == 768
+    assert cfg.generator_phase.save_total_limit == 3
     assert cfg.generator_phase.loss_type == "grpo"
     assert cfg.generator_phase.scale_rewards == "group"
     assert cfg.generator_phase.output_subdir == "gen_phase"
@@ -140,6 +168,9 @@
     assert cfg.answerer_phase.model_name_or_path == "Qwen/Qwen2.5-1.5B-Instruct"
     assert cfg.answerer_phase.max_steps == 55
     assert cfg.answerer_phase.num_generations == 5
+    assert cfg.answerer_phase.dataloader_num_workers == 6
+    assert cfg.answerer_phase.dataloader_persistent_workers is False
+    assert cfg.answerer_phase.dataloader_prefetch_factor == 6
     assert cfg.answerer_phase.output_subdir == "ans_phase"