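"""Configuration dataclasses and JSON loading helpers for the adversarial
self-play (GRPO) training pipeline."""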
from __future__ import annotations
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any


@dataclass(slots=True)
class KimiGRPOPhaseConfig:
"""Configuration for one GRPO phase in the alternating self-play loop."""
model_name_or_path: str = "Qwen/Qwen2.5-0.5B-Instruct"
learning_rate: float = 3e-6
max_steps: int = 64
per_device_train_batch_size: int = 2
gradient_accumulation_steps: int = 4
num_generations: int = 4
max_completion_length: int = 256
temperature: float = 1.0
top_p: float = 1.0
repetition_penalty: float = 1.0
beta: float = 0.01
epsilon: float = 0.2
num_iterations: int = 1
loss_type: str = "dapo"
scale_rewards: str = "none"
logging_steps: int = 10
save_steps: int = 50
save_total_limit: int = 2
output_subdir: str = "phase"
optim: str = "adamw_torch_fused"
bf16: bool = True
tf32: bool = True
gradient_checkpointing: bool = False
dataloader_num_workers: int = 2
dataloader_persistent_workers: bool = True
dataloader_prefetch_factor: int = 2
generation_batch_size: int = 8
max_prompt_length: int = 1024
use_vllm: bool = False
vllm_mode: str = "colocate"
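    # With the defaults above, one optimizer step consumes
    # per_device_train_batch_size * gradient_accumulation_steps = 2 * 4 = 8
    # completions per device; num_generations sets how many completions GRPO
    # samples per prompt to form each reward group.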


@dataclass(slots=True)
class GeneratorRewardWeights:
"""Weighted components for adversarial task-generator reward."""
validity: float = 0.45
hardness: float = 0.20
diversity: float = 0.15
consistency: float = 0.20
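    # Note: the default component weights sum to 1.0 (0.45 + 0.20 + 0.15 + 0.20).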


@dataclass(slots=True)
class LoraTuningConfig:
"""LoRA hyperparameters for parameter-efficient GRPO updates."""
r: int = 16
alpha: int = 32
dropout: float = 0.05
target_modules: list[str] = field(default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"])
bias: str = "none"
task_type: str = "CAUSAL_LM"


@dataclass(slots=True)
class SwarmV2SwarmConfig:
"""Config for one orchestrated swarm role inside the swarm_v2 pipeline."""
shared_context: bool = True
max_agents: int = 4
max_breadth: int = 3
max_depth: int = 2
planner_rounds: int = 2
tools_per_agent: int = 2


@dataclass(slots=True)
class SwarmV2ValidationConfig:
"""Validation and replay limits for swarm_v2 task generation."""
max_support_edges: int = 8
max_path_hops: int = 4
max_context_nodes: int = 14
max_context_edges: int = 8
duplicate_similarity_threshold: float = 0.8


@dataclass(slots=True)
class SwarmV2SharedContextConfig:
"""Shared context budgets used by both generator and answerer swarms."""
shared_by_default: bool = True
max_nodes: int = 14
max_edges: int = 8
target_pressure: float = 0.85


@dataclass(slots=True)
class SwarmV2Config:
"""Config block for the config-gated Swarm Self-Play v2 pipeline."""
generator_swarm: SwarmV2SwarmConfig = field(default_factory=SwarmV2SwarmConfig)
answerer_swarm: SwarmV2SwarmConfig = field(
default_factory=lambda: SwarmV2SwarmConfig(
shared_context=True,
max_agents=3,
max_breadth=2,
max_depth=2,
planner_rounds=2,
tools_per_agent=2,
)
)
validation: SwarmV2ValidationConfig = field(default_factory=SwarmV2ValidationConfig)
shared_context: SwarmV2SharedContextConfig = field(default_factory=SwarmV2SharedContextConfig)
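    # A "swarm_v2" JSON block mirrors this structure; a minimal sketch
    # (every key is optional):
    #
    #     {"generator_swarm": {"max_agents": 4},
    #      "validation": {"max_path_hops": 3},
    #      "shared_context": {"target_pressure": 0.9}}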


@dataclass(slots=True)
class SelfPlayTrainingConfig:
"""Top-level adversarial self-play training configuration."""
rounds: int = 3
output_dir: str = "artifacts/self_play"
dry_run: bool = True
wandb_enabled: bool = False
wandb_project: str = "osint-self-play"
wandb_entity: str = ""
wandb_run_name_prefix: str = "self-play"
canonical_graph_mode: str = "generate"
pipeline_mode: str = "legacy"
model_topology: str = "dual"
phase_schedule: str = "generator_answerer"
tuning_mode: str = "full"
shared_model_name_or_path: str = ""
seed_tasks_per_round: int = 16
generated_tasks_per_round: int = 24
generator_prompts_per_round: int = 24
max_graph_context_nodes: int = 100
max_graph_context_edges: int = 100
max_support_edges: int = 8
answerer_judge_max_new_tokens: int = 48
generated_task_max_new_tokens: int = 512
post_training_eval_questions: int = 24
post_training_eval_answer_max_new_tokens: int = 128
generator_reward_weights: GeneratorRewardWeights = field(default_factory=GeneratorRewardWeights)
lora: LoraTuningConfig = field(default_factory=LoraTuningConfig)
swarm_v2: SwarmV2Config = field(default_factory=SwarmV2Config)
generator_phase: KimiGRPOPhaseConfig = field(
default_factory=lambda: KimiGRPOPhaseConfig(
output_subdir="generator",
learning_rate=5e-6,
max_completion_length=384,
)
)
answerer_phase: KimiGRPOPhaseConfig = field(
default_factory=lambda: KimiGRPOPhaseConfig(
output_subdir="answerer",
learning_rate=3e-6,
max_completion_length=192,
)
)
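    # A config file consumed by load_self_play_config() might look like this
    # (a minimal sketch; all keys are optional and fall back to the defaults
    # above):
    #
    #     {
    #         "rounds": 5,
    #         "pipeline_mode": "swarm_v2",
    #         "tuning_mode": "lora",
    #         "generator_phase": {"learning_rate": 5e-6, "max_steps": 128},
    #         "lora": {"r": 8, "target_modules": "q_proj,v_proj"}
    #     }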


def _as_dict(value: Any) -> dict[str, Any]:
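    """Return value if it is a dict, otherwise an empty dict."""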
return value if isinstance(value, dict) else {}


def _parse_int(value: Any, default: int, floor: int | None = None) -> int:
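    """Coerce value to int, falling back to default; apply floor as a lower bound if given."""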
try:
out = int(value)
except (TypeError, ValueError):
out = default
if floor is not None:
out = max(floor, out)
return out


def _parse_float(value: Any, default: float) -> float:
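    """Coerce value to float, returning default on failure."""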
try:
return float(value)
except (TypeError, ValueError):
return default


def _parse_bool(value: Any, default: bool) -> bool:
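    """Parse a bool or a common truthy/falsy string; return default otherwise."""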
if isinstance(value, bool):
return value
if isinstance(value, str):
token = value.strip().lower()
if token in {"1", "true", "yes", "y", "on"}:
return True
if token in {"0", "false", "no", "n", "off"}:
return False
return default


def _parse_str_choice(value: Any, default: str, allowed: set[str]) -> str:
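    """Return value stripped and lowercased if it is an allowed choice, else default."""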
token = str(value).strip().lower()
if token in allowed:
return token
return default


def _parse_str_list(value: Any, fallback: list[str]) -> list[str]:
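    """Parse a list of non-empty strings from a list or a comma-separated string."""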
if isinstance(value, list):
out = [str(item).strip() for item in value if str(item).strip()]
return out or list(fallback)
if isinstance(value, str):
out = [item.strip() for item in value.split(",") if item.strip()]
return out or list(fallback)
return list(fallback)


def _parse_phase(data: dict[str, Any], fallback: KimiGRPOPhaseConfig) -> KimiGRPOPhaseConfig:
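    """Build a KimiGRPOPhaseConfig from raw JSON, using fallback for missing keys."""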
return KimiGRPOPhaseConfig(
model_name_or_path=str(data.get("model_name_or_path", fallback.model_name_or_path)).strip()
or fallback.model_name_or_path,
learning_rate=_parse_float(data.get("learning_rate"), fallback.learning_rate),
max_steps=_parse_int(data.get("max_steps"), fallback.max_steps, floor=1),
per_device_train_batch_size=_parse_int(
data.get("per_device_train_batch_size"),
fallback.per_device_train_batch_size,
floor=1,
),
gradient_accumulation_steps=_parse_int(
data.get("gradient_accumulation_steps"),
fallback.gradient_accumulation_steps,
floor=1,
),
num_generations=_parse_int(data.get("num_generations"), fallback.num_generations, floor=1),
max_completion_length=_parse_int(
data.get("max_completion_length"),
fallback.max_completion_length,
floor=1,
),
temperature=_parse_float(data.get("temperature"), fallback.temperature),
top_p=_parse_float(data.get("top_p"), fallback.top_p),
repetition_penalty=_parse_float(data.get("repetition_penalty"), fallback.repetition_penalty),
beta=_parse_float(data.get("beta"), fallback.beta),
epsilon=_parse_float(data.get("epsilon"), fallback.epsilon),
num_iterations=_parse_int(data.get("num_iterations"), fallback.num_iterations, floor=1),
loss_type=str(data.get("loss_type", fallback.loss_type)).strip() or fallback.loss_type,
scale_rewards=str(data.get("scale_rewards", fallback.scale_rewards)).strip() or fallback.scale_rewards,
logging_steps=_parse_int(data.get("logging_steps"), fallback.logging_steps, floor=1),
save_steps=_parse_int(data.get("save_steps"), fallback.save_steps, floor=1),
output_subdir=str(data.get("output_subdir", fallback.output_subdir)).strip() or fallback.output_subdir,
optim=str(data.get("optim", fallback.optim)).strip() or fallback.optim,
bf16=_parse_bool(data.get("bf16"), fallback.bf16),
tf32=_parse_bool(data.get("tf32"), fallback.tf32),
gradient_checkpointing=_parse_bool(
data.get("gradient_checkpointing"),
fallback.gradient_checkpointing,
),
dataloader_num_workers=_parse_int(
data.get("dataloader_num_workers"),
fallback.dataloader_num_workers,
floor=0,
),
dataloader_persistent_workers=_parse_bool(
data.get("dataloader_persistent_workers"),
fallback.dataloader_persistent_workers,
),
dataloader_prefetch_factor=_parse_int(
data.get("dataloader_prefetch_factor"),
fallback.dataloader_prefetch_factor,
floor=1,
),
generation_batch_size=_parse_int(
data.get("generation_batch_size"),
fallback.generation_batch_size,
floor=1,
),
max_prompt_length=_parse_int(
data.get("max_prompt_length"),
fallback.max_prompt_length,
floor=32,
),
save_total_limit=_parse_int(
data.get("save_total_limit"),
fallback.save_total_limit,
floor=1,
),
use_vllm=_parse_bool(data.get("use_vllm"), fallback.use_vllm),
vllm_mode=str(data.get("vllm_mode", fallback.vllm_mode)).strip() or fallback.vllm_mode,
)


def _parse_generator_weights(data: dict[str, Any]) -> GeneratorRewardWeights:
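    """Parse the weighted reward components for the task generator."""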
    fallback = GeneratorRewardWeights()
    return GeneratorRewardWeights(
        validity=_parse_float(data.get("validity"), fallback.validity),
        hardness=_parse_float(data.get("hardness"), fallback.hardness),
        diversity=_parse_float(data.get("diversity"), fallback.diversity),
        consistency=_parse_float(data.get("consistency"), fallback.consistency),
    )


def _parse_lora_config(data: dict[str, Any], fallback: LoraTuningConfig) -> LoraTuningConfig:
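    """Parse LoRA hyperparameters, using fallback for missing or invalid keys."""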
return LoraTuningConfig(
r=_parse_int(data.get("r"), fallback.r, floor=1),
alpha=_parse_int(data.get("alpha"), fallback.alpha, floor=1),
dropout=_parse_float(data.get("dropout"), fallback.dropout),
target_modules=_parse_str_list(data.get("target_modules"), fallback.target_modules),
bias=str(data.get("bias", fallback.bias)).strip() or fallback.bias,
task_type=str(data.get("task_type", fallback.task_type)).strip() or fallback.task_type,
)


def _parse_swarm_v2_swarm_config(
data: dict[str, Any],
fallback: SwarmV2SwarmConfig,
) -> SwarmV2SwarmConfig:
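    """Parse the per-swarm orchestration limits for one swarm_v2 role."""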
return SwarmV2SwarmConfig(
shared_context=_parse_bool(data.get("shared_context"), fallback.shared_context),
max_agents=_parse_int(data.get("max_agents"), fallback.max_agents, floor=1),
max_breadth=_parse_int(data.get("max_breadth"), fallback.max_breadth, floor=1),
max_depth=_parse_int(data.get("max_depth"), fallback.max_depth, floor=1),
planner_rounds=_parse_int(data.get("planner_rounds"), fallback.planner_rounds, floor=1),
tools_per_agent=_parse_int(data.get("tools_per_agent"), fallback.tools_per_agent, floor=1),
)


def _parse_swarm_v2_validation_config(
data: dict[str, Any],
fallback: SwarmV2ValidationConfig,
legacy_max_support_edges: int,
) -> SwarmV2ValidationConfig:
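    """Parse swarm_v2 validation limits.

    The legacy top-level max_support_edges is inherited when the swarm_v2
    block does not override it.
    """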
    # Prefer an explicit swarm_v2 value; otherwise inherit the legacy
    # top-level max_support_edges setting.
    if "max_support_edges" in data:
        max_support_edges = _parse_int(
            data["max_support_edges"], fallback.max_support_edges, floor=1
        )
    else:
        max_support_edges = legacy_max_support_edges
    return SwarmV2ValidationConfig(
        max_support_edges=max_support_edges,
max_path_hops=_parse_int(data.get("max_path_hops"), fallback.max_path_hops, floor=1),
max_context_nodes=_parse_int(data.get("max_context_nodes"), fallback.max_context_nodes, floor=1),
max_context_edges=_parse_int(data.get("max_context_edges"), fallback.max_context_edges, floor=1),
duplicate_similarity_threshold=max(
0.0,
min(
1.0,
_parse_float(
data.get("duplicate_similarity_threshold"),
fallback.duplicate_similarity_threshold,
),
),
),
)


def _parse_swarm_v2_shared_context_config(
data: dict[str, Any],
fallback: SwarmV2SharedContextConfig,
) -> SwarmV2SharedContextConfig:
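    """Parse shared-context budgets; target_pressure is clamped to [0, 1]."""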
return SwarmV2SharedContextConfig(
shared_by_default=_parse_bool(data.get("shared_by_default"), fallback.shared_by_default),
max_nodes=_parse_int(data.get("max_nodes"), fallback.max_nodes, floor=1),
max_edges=_parse_int(data.get("max_edges"), fallback.max_edges, floor=1),
target_pressure=max(0.0, min(1.0, _parse_float(data.get("target_pressure"), fallback.target_pressure))),
)


def _parse_swarm_v2_config(
data: dict[str, Any],
fallback: SwarmV2Config,
legacy_max_support_edges: int,
) -> SwarmV2Config:
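    """Assemble the full swarm_v2 config from its nested JSON blocks."""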
return SwarmV2Config(
generator_swarm=_parse_swarm_v2_swarm_config(
_as_dict(data.get("generator_swarm")),
fallback.generator_swarm,
),
answerer_swarm=_parse_swarm_v2_swarm_config(
_as_dict(data.get("answerer_swarm")),
fallback.answerer_swarm,
),
validation=_parse_swarm_v2_validation_config(
_as_dict(data.get("validation")),
fallback.validation,
legacy_max_support_edges=legacy_max_support_edges,
),
shared_context=_parse_swarm_v2_shared_context_config(
_as_dict(data.get("shared_context")),
fallback.shared_context,
),
)


def load_self_play_config(path: str | Path | None) -> SelfPlayTrainingConfig:
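    """Load a SelfPlayTrainingConfig from a JSON file.

    Returns the dataclass defaults when path is falsy or the file is missing;
    raises ValueError when the file holds invalid JSON or a non-object payload.
    """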
if not path:
return SelfPlayTrainingConfig()
file_path = Path(path)
if not file_path.exists():
return SelfPlayTrainingConfig()
payload = json.loads(file_path.read_text(encoding="utf-8"))
if not isinstance(payload, dict):
raise ValueError("Self-play config file must contain a JSON object.")
defaults = SelfPlayTrainingConfig()
generator_phase = _parse_phase(_as_dict(payload.get("generator_phase")), defaults.generator_phase)
answerer_phase = _parse_phase(_as_dict(payload.get("answerer_phase")), defaults.answerer_phase)
lora_cfg = _parse_lora_config(_as_dict(payload.get("lora")), defaults.lora)
legacy_max_support_edges = _parse_int(payload.get("max_support_edges"), defaults.max_support_edges, floor=1)
swarm_v2_cfg = _parse_swarm_v2_config(
_as_dict(payload.get("swarm_v2")),
defaults.swarm_v2,
legacy_max_support_edges=legacy_max_support_edges,
)
return SelfPlayTrainingConfig(
rounds=_parse_int(payload.get("rounds"), defaults.rounds, floor=1),
output_dir=str(payload.get("output_dir", defaults.output_dir)).strip() or defaults.output_dir,
dry_run=_parse_bool(payload.get("dry_run"), defaults.dry_run),
wandb_enabled=_parse_bool(payload.get("wandb_enabled"), defaults.wandb_enabled),
wandb_project=str(payload.get("wandb_project", defaults.wandb_project)).strip() or defaults.wandb_project,
wandb_entity=str(payload.get("wandb_entity", defaults.wandb_entity)).strip(),
wandb_run_name_prefix=str(payload.get("wandb_run_name_prefix", defaults.wandb_run_name_prefix)).strip()
or defaults.wandb_run_name_prefix,
canonical_graph_mode=_parse_str_choice(
payload.get("canonical_graph_mode"),
defaults.canonical_graph_mode,
{"generate", "fixed"},
),
pipeline_mode=_parse_str_choice(
payload.get("pipeline_mode"),
defaults.pipeline_mode,
{"legacy", "swarm_v2"},
),
model_topology=_parse_str_choice(
payload.get("model_topology"),
defaults.model_topology,
{"dual", "shared"},
),
phase_schedule=_parse_str_choice(
payload.get("phase_schedule"),
defaults.phase_schedule,
{"generator_answerer", "answerer_generator_answerer"},
),
tuning_mode=_parse_str_choice(
payload.get("tuning_mode"),
defaults.tuning_mode,
{"full", "lora"},
),
shared_model_name_or_path=str(
payload.get("shared_model_name_or_path", defaults.shared_model_name_or_path)
).strip(),
seed_tasks_per_round=_parse_int(
payload.get("seed_tasks_per_round"),
defaults.seed_tasks_per_round,
floor=1,
),
generated_tasks_per_round=_parse_int(
payload.get("generated_tasks_per_round"),
defaults.generated_tasks_per_round,
floor=1,
),
generator_prompts_per_round=_parse_int(
payload.get("generator_prompts_per_round"),
defaults.generator_prompts_per_round,
floor=1,
),
max_graph_context_nodes=_parse_int(
payload.get("max_graph_context_nodes"),
defaults.max_graph_context_nodes,
floor=1,
),
max_graph_context_edges=_parse_int(
payload.get("max_graph_context_edges"),
defaults.max_graph_context_edges,
floor=1,
),
max_support_edges=legacy_max_support_edges,
answerer_judge_max_new_tokens=_parse_int(
payload.get("answerer_judge_max_new_tokens"),
defaults.answerer_judge_max_new_tokens,
floor=1,
),
generated_task_max_new_tokens=_parse_int(
payload.get("generated_task_max_new_tokens"),
defaults.generated_task_max_new_tokens,
floor=32,
),
post_training_eval_questions=_parse_int(
payload.get("post_training_eval_questions"),
defaults.post_training_eval_questions,
floor=1,
),
post_training_eval_answer_max_new_tokens=_parse_int(
payload.get("post_training_eval_answer_max_new_tokens"),
defaults.post_training_eval_answer_max_new_tokens,
floor=1,
),
generator_reward_weights=_parse_generator_weights(
_as_dict(payload.get("generator_reward_weights"))
),
lora=lora_cfg,
swarm_v2=swarm_v2_cfg,
generator_phase=generator_phase,
answerer_phase=answerer_phase,
)
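# Example usage (a minimal sketch; "configs/self_play.json" is a hypothetical
# path):
#
#     config = load_self_play_config("configs/self_play.json")
#     out_dir = Path(config.output_dir) / config.generator_phase.output_subdir
#
# A missing file, absent keys, and malformed values all degrade gracefully to
# the dataclass defaults defined above.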