"""Configuration loading for the OSINT environment.

Parses JSON payloads into the project's typed config objects
(``EnvironmentConfig``, ``SeedingConfig``, ``RuntimeDefaults``, ...).
Parsing is deliberately defensive: malformed or missing values fall back
to documented defaults instead of raising, so a partially written config
file still produces a usable configuration.
"""

from __future__ import annotations

import copy
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

from osint_env.domain.models import (
    EnvironmentConfig,
    LLMConfig,
    NodeType,
    SeedingConfig,
    SeedEdgeSpec,
    SeedNodeSpec,
    SeedQuestionSpec,
    SpawnRewardConfig,
    SwarmConfig,
)


@dataclass(slots=True)
class RuntimeDefaults:
    """Runner-level defaults: episode count and artifact output paths."""

    default_episodes: int = 20
    leaderboard_path: str = "artifacts/leaderboard.json"
    dashboard_path: str = "artifacts/osint_dashboard.html"
    sweep_dashboard_dir: str = "artifacts/sweep_dashboards"


@dataclass(slots=True)
class SharedConfig:
    """Environment + runtime configuration loaded from a single file."""

    environment: EnvironmentConfig = field(default_factory=EnvironmentConfig)
    runtime: RuntimeDefaults = field(default_factory=RuntimeDefaults)


def clone_environment_config(config: EnvironmentConfig) -> EnvironmentConfig:
    """Return an independent deep copy of *config*."""
    return copy.deepcopy(config)


def _as_dict(value: Any) -> dict[str, Any]:
    """Return *value* if it is a dict, otherwise an empty dict."""
    return value if isinstance(value, dict) else {}


def _as_list(value: Any) -> list[Any]:
    """Return *value* if it is a list, otherwise an empty list.

    Guards iteration over JSON fields that may be ``null`` or a scalar
    (iterating ``None`` would raise; iterating a string would yield chars).
    """
    return value if isinstance(value, list) else []


def _parse_int(value: Any, default: int) -> int:
    """Coerce *value* to ``int``; fall back to *default* on failure."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


def _parse_float(value: Any, default: float) -> float:
    """Coerce *value* to ``float``; fall back to *default* on failure."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _parse_bool(value: Any, default: bool) -> bool:
    """Coerce *value* to ``bool``.

    Accepts real booleans and common truthy/falsy string spellings
    ("1"/"true"/"yes"/"y"/"on" and their negatives, case-insensitive);
    anything else yields *default*.
    """
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        lowered = value.strip().lower()
        if lowered in {"1", "true", "yes", "y", "on"}:
            return True
        if lowered in {"0", "false", "no", "n", "off"}:
            return False
    return default


def _parse_str_list(value: Any, default: list[str]) -> list[str]:
    """Coerce *value* into a list of non-empty stripped strings.

    Accepts a JSON list (items are str-ified) or a comma-separated
    string. An empty result, or any other type, yields a copy of
    *default* so callers never share the default list instance.
    """
    if isinstance(value, list):
        items = [str(item).strip() for item in value if str(item).strip()]
        return items or list(default)
    if isinstance(value, str):
        items = [part.strip() for part in value.split(",") if part.strip()]
        return items or list(default)
    return list(default)


# Maps the prefix of a node id (text before the first "_") to a node type.
_NODE_TYPE_BY_PREFIX: dict[str, NodeType] = {
    "user": NodeType.USER,
    "alias": NodeType.ALIAS,
    "org": NodeType.ORG,
    "loc": NodeType.LOCATION,
    "location": NodeType.LOCATION,
    "post": NodeType.POST,
    "thr": NodeType.THREAD,
    "thread": NodeType.THREAD,
    "event": NodeType.EVENT,
}


def _infer_node_type(node_id: str) -> NodeType:
    """Guess a node's type from its id prefix; default to ``USER``."""
    prefix = str(node_id).split("_", 1)[0].lower()
    return _NODE_TYPE_BY_PREFIX.get(prefix, NodeType.USER)


def _parse_node_type(value: Any, node_id: str) -> NodeType:
    """Resolve an explicit node-type value, inferring from *node_id* on failure."""
    if isinstance(value, NodeType):
        return value
    if isinstance(value, str):
        raw = value.strip().lower()
        try:
            return NodeType(raw)
        except ValueError:
            return _infer_node_type(node_id)
    return _infer_node_type(node_id)


def _parse_seed_edge(item: dict[str, Any]) -> SeedEdgeSpec | None:
    """Parse one seeded-edge mapping; return ``None`` if any endpoint is missing."""
    src = str(item.get("src", "")).strip()
    rel = str(item.get("rel", "")).strip()
    dst = str(item.get("dst", "")).strip()
    if not src or not rel or not dst:
        return None
    confidence = _parse_float(item.get("confidence", 1.0), 1.0)
    return SeedEdgeSpec(src=src, rel=rel, dst=dst, confidence=confidence)


def _parse_seed_question(row: dict[str, Any]) -> SeedQuestionSpec | None:
    """Parse one seeded-question mapping; return ``None`` when the question text is empty."""
    question = str(row.get("question", "")).strip()
    if not question:
        return None
    raw_answer = row.get("answer")
    answer: str | None = None
    if raw_answer is not None:
        # Empty/whitespace-only answers are normalized to "no answer".
        answer = str(raw_answer).strip() or None
    task_type = str(row.get("task_type", "seeded")).strip() or "seeded"
    supporting_edges = [
        edge
        for edge in (
            _parse_seed_edge(_as_dict(edge_item))
            for edge_item in _as_list(row.get("supporting_edges"))
        )
        if edge is not None
    ]
    return SeedQuestionSpec(
        question=question,
        answer=answer,
        task_type=task_type,
        supporting_edges=supporting_edges,
        metadata=_as_dict(row.get("metadata")),
    )


def _parse_seeding(data: dict[str, Any]) -> SeedingConfig:
    """Build a ``SeedingConfig`` from a (possibly partial) JSON mapping.

    Invalid entries (nodes without ids, edges missing endpoints, questions
    without text) are skipped rather than raising. Fields that are not
    lists at all (e.g. an explicit JSON ``null``) are treated as empty.
    """
    seeded_nodes: list[SeedNodeSpec] = []
    for item in _as_list(data.get("seeded_nodes")):
        row = _as_dict(item)
        node_id = str(row.get("node_id", "")).strip()
        if not node_id:
            continue
        seeded_nodes.append(
            SeedNodeSpec(
                node_id=node_id,
                node_type=_parse_node_type(row.get("node_type"), node_id),
                attrs=_as_dict(row.get("attrs")),
            )
        )

    seeded_edges = [
        edge
        for edge in (
            _parse_seed_edge(_as_dict(item))
            for item in _as_list(data.get("seeded_edges"))
        )
        if edge is not None
    ]

    seeded_questions = [
        spec
        for spec in (
            _parse_seed_question(_as_dict(item))
            for item in _as_list(data.get("seeded_questions"))
        )
        if spec is not None
    ]

    return SeedingConfig(
        seeded_nodes=seeded_nodes,
        seeded_edges=seeded_edges,
        seeded_questions=seeded_questions,
        llm_generate_remaining_graph=_parse_bool(data.get("llm_generate_remaining_graph"), True),
        llm_generate_remaining_tasks=_parse_bool(data.get("llm_generate_remaining_tasks"), True),
        llm_generated_edge_budget=max(0, _parse_int(data.get("llm_generated_edge_budget"), 6)),
        llm_generated_task_budget=max(0, _parse_int(data.get("llm_generated_task_budget"), 8)),
        llm_generation_parallel=_parse_bool(data.get("llm_generation_parallel"), True),
        llm_generation_workers=max(1, _parse_int(data.get("llm_generation_workers"), 3)),
        llm_generation_retries=max(1, _parse_int(data.get("llm_generation_retries"), 2)),
        allow_template_fallback_on_llm_failure=_parse_bool(
            data.get("allow_template_fallback_on_llm_failure"),
            False,
        ),
    )


def load_seeding_config(path: str | Path) -> SeedingConfig:
    """Load a ``SeedingConfig`` from a JSON file.

    The seeding data may live under a top-level ``"seeding"`` key or be
    the whole document.

    Raises:
        ValueError: if the file's top-level JSON value is not an object.
    """
    payload = json.loads(Path(path).read_text(encoding="utf-8"))
    if not isinstance(payload, dict):
        raise ValueError("Seed file must contain a JSON object.")
    source = _as_dict(payload.get("seeding", payload))
    return _parse_seeding(source)


def _parse_swarm(swarm_data: dict[str, Any]) -> SwarmConfig:
    """Parse the swarm sub-config; all counts are clamped to be >= 1."""
    return SwarmConfig(
        enabled=_parse_bool(swarm_data.get("enabled"), False),
        max_agents=max(1, _parse_int(swarm_data.get("max_agents"), 3)),
        max_breadth=max(1, _parse_int(swarm_data.get("max_breadth"), 2)),
        max_width=max(1, _parse_int(swarm_data.get("max_width"), 2)),
        max_depth=max(1, _parse_int(swarm_data.get("max_depth"), 2)),
        planner_rounds=max(1, _parse_int(swarm_data.get("planner_rounds"), 2)),
        tools_per_agent=max(1, _parse_int(swarm_data.get("tools_per_agent"), 1)),
    )


def _parse_spawn_reward(spawn_data: dict[str, Any]) -> SpawnRewardConfig:
    """Parse the spawn-reward sub-config; lambdas >= 0, anneal clamped to [0, 1]."""
    return SpawnRewardConfig(
        lambda_parallel=max(0.0, _parse_float(spawn_data.get("lambda_parallel"), 0.15)),
        lambda_finish=max(0.0, _parse_float(spawn_data.get("lambda_finish"), 0.2)),
        anneal=max(0.0, min(1.0, _parse_float(spawn_data.get("anneal"), 1.0))),
        max_parallel_hint=max(1, _parse_int(spawn_data.get("max_parallel_hint"), 3)),
    )


def _parse_llm(llm_data: dict[str, Any]) -> LLMConfig:
    """Parse the LLM sub-config; blank strings fall back to the defaults."""
    return LLMConfig(
        provider=str(llm_data.get("provider", "mock")).strip() or "mock",
        model=str(llm_data.get("model", "qwen3:2b")).strip() or "qwen3:2b",
        temperature=_parse_float(llm_data.get("temperature"), 0.1),
        max_tokens=max(1, _parse_int(llm_data.get("max_tokens"), 256)),
        timeout_seconds=max(1, _parse_int(llm_data.get("timeout_seconds"), 240)),
        ollama_base_url=str(llm_data.get("ollama_base_url", "http://127.0.0.1:11434")).strip()
        or "http://127.0.0.1:11434",
        openai_base_url=str(llm_data.get("openai_base_url", "https://api.openai.com/v1")).strip()
        or "https://api.openai.com/v1",
        openai_api_key_env=str(llm_data.get("openai_api_key_env", "OPENAI_API_KEY")).strip()
        or "OPENAI_API_KEY",
        openai_api_key=str(llm_data.get("openai_api_key", "")).strip(),
    )


def _parse_environment(payload: dict[str, Any]) -> EnvironmentConfig:
    """Build an ``EnvironmentConfig`` from a parsed JSON document.

    Each sub-section (``dataset``, ``swarm``, ``spawn_reward``,
    ``seeding``, ``llm``) may appear at the top level or nested under
    ``"environment"``; the top-level key wins when both are present.
    Numeric values are clamped to sensible ranges and enum-like strings
    fall back to their defaults when unrecognized.
    """
    env_data = _as_dict(payload.get("environment", payload))
    dataset_data = _as_dict(payload.get("dataset", env_data.get("dataset", {})))
    swarm_data = _as_dict(payload.get("swarm", env_data.get("swarm", {})))
    spawn_data = _as_dict(payload.get("spawn_reward", env_data.get("spawn_reward", {})))
    seeding_data = _as_dict(payload.get("seeding", env_data.get("seeding", {})))
    llm_data = _as_dict(payload.get("llm", env_data.get("llm", {})))

    dataset_mode = str(
        dataset_data.get("mode", env_data.get("dataset_mode", "canonical"))
    ).strip().lower()
    if dataset_mode not in {"canonical", "metaqa"}:
        dataset_mode = "canonical"

    metaqa_variant = str(
        dataset_data.get("metaqa_variant", env_data.get("metaqa_variant", "vanilla"))
    ).strip().lower()
    if metaqa_variant not in {"vanilla", "ntm"}:
        metaqa_variant = "vanilla"

    env = EnvironmentConfig(
        n_users=max(4, _parse_int(env_data.get("n_users"), 40)),
        alias_density=max(0.0, min(1.0, _parse_float(env_data.get("alias_density"), 0.35))),
        noise_level=max(0.0, min(1.0, _parse_float(env_data.get("noise_level"), 0.15))),
        red_herring_rate=max(0.0, min(1.0, _parse_float(env_data.get("red_herring_rate"), 0.1))),
        max_steps=max(2, _parse_int(env_data.get("max_steps"), 18)),
        seed=_parse_int(env_data.get("seed"), 7),
        dataset_mode=dataset_mode,
        metaqa_root=str(dataset_data.get("metaqa_root", env_data.get("metaqa_root", "metaQA"))).strip()
        or "metaQA",
        metaqa_kb_path=str(
            dataset_data.get("metaqa_kb_path", env_data.get("metaqa_kb_path", ""))
        ).strip(),
        metaqa_variant=metaqa_variant,
        metaqa_hops=_parse_str_list(
            dataset_data.get("metaqa_hops", env_data.get("metaqa_hops", ["1-hop", "2-hop", "3-hop"])),
            ["1-hop", "2-hop", "3-hop"],
        ),
        metaqa_splits=_parse_str_list(
            dataset_data.get("metaqa_splits", env_data.get("metaqa_splits", ["train", "dev", "test"])),
            ["train", "dev", "test"],
        ),
    )
    env.swarm = _parse_swarm(swarm_data)
    env.spawn_reward = _parse_spawn_reward(spawn_data)
    env.seeding = _parse_seeding(seeding_data)
    env.llm = _parse_llm(llm_data)
    return env


def _parse_runtime(payload: dict[str, Any]) -> RuntimeDefaults:
    """Build ``RuntimeDefaults`` from the document's ``"runtime"`` section."""
    runtime = _as_dict(payload.get("runtime", {}))
    return RuntimeDefaults(
        default_episodes=max(1, _parse_int(runtime.get("default_episodes"), 20)),
        leaderboard_path=str(runtime.get("leaderboard_path", "artifacts/leaderboard.json")),
        dashboard_path=str(runtime.get("dashboard_path", "artifacts/osint_dashboard.html")),
        sweep_dashboard_dir=str(runtime.get("sweep_dashboard_dir", "artifacts/sweep_dashboards")),
    )


def load_shared_config(path: str | Path | None) -> SharedConfig:
    """Load a ``SharedConfig`` from a JSON file.

    A falsy *path* or a nonexistent file yields all defaults (best-effort
    by design — the caller does not need a config file to exist).

    Raises:
        ValueError: if the file exists but its top-level JSON value is
            not an object.
    """
    if not path:
        return SharedConfig()
    file_path = Path(path)
    if not file_path.exists():
        return SharedConfig()
    payload = json.loads(file_path.read_text(encoding="utf-8"))
    if not isinstance(payload, dict):
        raise ValueError("Shared config file must contain a JSON object.")
    return SharedConfig(environment=_parse_environment(payload), runtime=_parse_runtime(payload))