# Shared configuration loading for osint_env (environment, seeding, and runtime defaults from JSON).
| from __future__ import annotations | |
| import copy | |
| import json | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Any | |
| from osint_env.domain.models import ( | |
| EnvironmentConfig, | |
| LLMConfig, | |
| NodeType, | |
| SeedingConfig, | |
| SeedEdgeSpec, | |
| SeedNodeSpec, | |
| SeedQuestionSpec, | |
| SpawnRewardConfig, | |
| SwarmConfig, | |
| ) | |
@dataclass
class RuntimeDefaults:
    """Default runner settings: episode count and artifact output paths.

    NOTE(review): the ``@dataclass`` decorator was missing; ``_parse_runtime``
    constructs this class with keyword arguments, which requires the generated
    ``__init__``.
    """

    # Episodes per run when the config file omits a value.
    default_episodes: int = 20
    # Output locations for run artifacts.
    leaderboard_path: str = "artifacts/leaderboard.json"
    dashboard_path: str = "artifacts/osint_dashboard.html"
    sweep_dashboard_dir: str = "artifacts/sweep_dashboards"
@dataclass
class SharedConfig:
    """Bundle of the environment configuration plus runner defaults.

    NOTE(review): the ``@dataclass`` decorator was missing; without it,
    ``field(default_factory=...)`` leaves raw ``dataclasses.Field`` objects as
    class attributes and ``load_shared_config``'s keyword construction fails.
    """

    # default_factory gives each SharedConfig its own fresh instances,
    # avoiding shared mutable state between configs.
    environment: EnvironmentConfig = field(default_factory=EnvironmentConfig)
    runtime: RuntimeDefaults = field(default_factory=RuntimeDefaults)
def clone_environment_config(config: EnvironmentConfig) -> EnvironmentConfig:
    """Return an independent deep copy of *config*, safe to mutate."""
    cloned = copy.deepcopy(config)
    return cloned
| def _as_dict(value: Any) -> dict[str, Any]: | |
| return value if isinstance(value, dict) else {} | |
| def _parse_int(value: Any, default: int) -> int: | |
| try: | |
| return int(value) | |
| except (TypeError, ValueError): | |
| return default | |
| def _parse_float(value: Any, default: float) -> float: | |
| try: | |
| return float(value) | |
| except (TypeError, ValueError): | |
| return default | |
| def _parse_bool(value: Any, default: bool) -> bool: | |
| if isinstance(value, bool): | |
| return value | |
| if isinstance(value, str): | |
| lowered = value.strip().lower() | |
| if lowered in {"1", "true", "yes", "y", "on"}: | |
| return True | |
| if lowered in {"0", "false", "no", "n", "off"}: | |
| return False | |
| return default | |
| def _parse_str_list(value: Any, default: list[str]) -> list[str]: | |
| if isinstance(value, list): | |
| items = [str(item).strip() for item in value if str(item).strip()] | |
| return items or list(default) | |
| if isinstance(value, str): | |
| items = [part.strip() for part in value.split(",") if part.strip()] | |
| return items or list(default) | |
| return list(default) | |
def _infer_node_type(node_id: str) -> NodeType:
    """Guess a NodeType from the id's prefix (text before the first underscore).

    Unknown prefixes default to ``NodeType.USER``.
    """
    prefix_map = {
        "user": NodeType.USER,
        "alias": NodeType.ALIAS,
        "org": NodeType.ORG,
        "loc": NodeType.LOCATION,
        "location": NodeType.LOCATION,
        "post": NodeType.POST,
        "thr": NodeType.THREAD,
        "thread": NodeType.THREAD,
        "event": NodeType.EVENT,
    }
    head, _, _ = str(node_id).partition("_")
    return prefix_map.get(head.lower(), NodeType.USER)
def _parse_node_type(value: Any, node_id: str) -> NodeType:
    """Resolve *value* to a NodeType.

    Accepts an existing NodeType or its string value; anything missing or
    unrecognized falls back to inference from *node_id*'s prefix.
    """
    if isinstance(value, NodeType):
        return value
    if isinstance(value, str):
        try:
            return NodeType(value.strip().lower())
        except ValueError:
            pass
    return _infer_node_type(node_id)
def _parse_seed_edge(item: dict[str, Any]) -> SeedEdgeSpec | None:
    """Build a SeedEdgeSpec from a mapping.

    Returns None when any of src/rel/dst is missing or blank; confidence
    defaults to 1.0 when absent or unparsable.
    """
    parts = {key: str(item.get(key, "")).strip() for key in ("src", "rel", "dst")}
    if not all(parts.values()):
        return None
    return SeedEdgeSpec(
        src=parts["src"],
        rel=parts["rel"],
        dst=parts["dst"],
        confidence=_parse_float(item.get("confidence", 1.0), 1.0),
    )
def _parse_seeding(data: dict[str, Any]) -> SeedingConfig:
    """Build a SeedingConfig from a raw mapping.

    Parses three optional list sections -- ``seeded_nodes``, ``seeded_edges``
    and ``seeded_questions`` -- silently skipping malformed entries, then reads
    the LLM-generation knobs with clamped defaults.
    """
    # Nodes: require a non-empty node_id; node_type is inferred from the id
    # prefix when missing or unrecognized.
    seeded_nodes: list[SeedNodeSpec] = []
    for item in data.get("seeded_nodes", []):
        row = _as_dict(item)
        node_id = str(row.get("node_id", "")).strip()
        if not node_id:
            continue
        node_type = _parse_node_type(row.get("node_type"), node_id)
        attrs = _as_dict(row.get("attrs"))
        seeded_nodes.append(SeedNodeSpec(node_id=node_id, node_type=node_type, attrs=attrs))
    # Edges: _parse_seed_edge returns None for rows missing src/rel/dst.
    seeded_edges: list[SeedEdgeSpec] = []
    for item in data.get("seeded_edges", []):
        edge = _parse_seed_edge(_as_dict(item))
        if edge is not None:
            seeded_edges.append(edge)
    # Questions: require non-empty question text; a blank/absent answer is
    # normalized to None (presumably meaning "answer unknown" -- verify against
    # SeedQuestionSpec consumers).
    seeded_questions: list[SeedQuestionSpec] = []
    for item in data.get("seeded_questions", []):
        row = _as_dict(item)
        question = str(row.get("question", "")).strip()
        if not question:
            continue
        answer_val = row.get("answer")
        answer = str(answer_val).strip() if answer_val is not None and str(answer_val).strip() else None
        task_type = str(row.get("task_type", "seeded")).strip() or "seeded"
        support_edges: list[SeedEdgeSpec] = []
        for edge_item in row.get("supporting_edges", []):
            edge = _parse_seed_edge(_as_dict(edge_item))
            if edge is not None:
                support_edges.append(edge)
        metadata = _as_dict(row.get("metadata"))
        seeded_questions.append(
            SeedQuestionSpec(
                question=question,
                answer=answer,
                task_type=task_type,
                supporting_edges=support_edges,
                metadata=metadata,
            )
        )
    # Generation knobs: budgets clamped to >= 0; worker/retry counts to >= 1.
    return SeedingConfig(
        seeded_nodes=seeded_nodes,
        seeded_edges=seeded_edges,
        seeded_questions=seeded_questions,
        llm_generate_remaining_graph=_parse_bool(data.get("llm_generate_remaining_graph"), True),
        llm_generate_remaining_tasks=_parse_bool(data.get("llm_generate_remaining_tasks"), True),
        llm_generated_edge_budget=max(0, _parse_int(data.get("llm_generated_edge_budget"), 6)),
        llm_generated_task_budget=max(0, _parse_int(data.get("llm_generated_task_budget"), 8)),
        llm_generation_parallel=_parse_bool(data.get("llm_generation_parallel"), True),
        llm_generation_workers=max(1, _parse_int(data.get("llm_generation_workers"), 3)),
        llm_generation_retries=max(1, _parse_int(data.get("llm_generation_retries"), 2)),
        allow_template_fallback_on_llm_failure=_parse_bool(
            data.get("allow_template_fallback_on_llm_failure"),
            False,
        ),
    )
def load_seeding_config(path: str | Path) -> SeedingConfig:
    """Load a SeedingConfig from a JSON file.

    The seeding section may be the whole document or nested under a top-level
    ``"seeding"`` key. Raises ValueError when the document is not an object.
    """
    raw = json.loads(Path(path).read_text(encoding="utf-8"))
    if not isinstance(raw, dict):
        raise ValueError("Seed file must contain a JSON object.")
    section = _as_dict(raw.get("seeding", raw))
    return _parse_seeding(section)
def _parse_environment(payload: dict[str, Any]) -> EnvironmentConfig:
    """Build an EnvironmentConfig from a raw JSON payload.

    Each section (dataset, swarm, spawn_reward, seeding, llm) may live either
    at the top level of *payload* or nested under ``"environment"``; the
    top-level key wins when both are present.
    """
    env_data = _as_dict(payload.get("environment", payload))
    dataset_data = _as_dict(payload.get("dataset", env_data.get("dataset", {})))
    swarm_data = _as_dict(payload.get("swarm", env_data.get("swarm", {})))
    spawn_data = _as_dict(payload.get("spawn_reward", env_data.get("spawn_reward", {})))
    seeding_data = _as_dict(payload.get("seeding", env_data.get("seeding", {})))
    llm_data = _as_dict(payload.get("llm", env_data.get("llm", {})))
    # Dataset selector: unrecognized values fall back to "canonical".
    dataset_mode = str(dataset_data.get("mode", env_data.get("dataset_mode", "canonical"))).strip().lower()
    if dataset_mode not in {"canonical", "metaqa"}:
        dataset_mode = "canonical"
    metaqa_variant = str(dataset_data.get("metaqa_variant", env_data.get("metaqa_variant", "vanilla"))).strip().lower()
    if metaqa_variant not in {"vanilla", "ntm"}:
        metaqa_variant = "vanilla"
    # Rates are clamped to [0, 1]; counts to their respective minimums.
    env = EnvironmentConfig(
        n_users=max(4, _parse_int(env_data.get("n_users"), 40)),
        alias_density=max(0.0, min(1.0, _parse_float(env_data.get("alias_density"), 0.35))),
        noise_level=max(0.0, min(1.0, _parse_float(env_data.get("noise_level"), 0.15))),
        red_herring_rate=max(0.0, min(1.0, _parse_float(env_data.get("red_herring_rate"), 0.1))),
        max_steps=max(2, _parse_int(env_data.get("max_steps"), 18)),
        seed=_parse_int(env_data.get("seed"), 7),
        dataset_mode=dataset_mode,
        metaqa_root=str(dataset_data.get("metaqa_root", env_data.get("metaqa_root", "metaQA"))).strip() or "metaQA",
        metaqa_kb_path=str(dataset_data.get("metaqa_kb_path", env_data.get("metaqa_kb_path", ""))).strip(),
        metaqa_variant=metaqa_variant,
        metaqa_hops=_parse_str_list(
            dataset_data.get("metaqa_hops", env_data.get("metaqa_hops", ["1-hop", "2-hop", "3-hop"])),
            ["1-hop", "2-hop", "3-hop"],
        ),
        metaqa_splits=_parse_str_list(
            dataset_data.get("metaqa_splits", env_data.get("metaqa_splits", ["train", "dev", "test"])),
            ["train", "dev", "test"],
        ),
    )
    # Sub-configs are assigned after construction rather than passed to
    # EnvironmentConfig's __init__.
    env.swarm = SwarmConfig(
        enabled=_parse_bool(swarm_data.get("enabled"), False),
        max_agents=max(1, _parse_int(swarm_data.get("max_agents"), 3)),
        max_breadth=max(1, _parse_int(swarm_data.get("max_breadth"), 2)),
        max_width=max(1, _parse_int(swarm_data.get("max_width"), 2)),
        max_depth=max(1, _parse_int(swarm_data.get("max_depth"), 2)),
        planner_rounds=max(1, _parse_int(swarm_data.get("planner_rounds"), 2)),
        tools_per_agent=max(1, _parse_int(swarm_data.get("tools_per_agent"), 1)),
    )
    env.spawn_reward = SpawnRewardConfig(
        lambda_parallel=max(0.0, _parse_float(spawn_data.get("lambda_parallel"), 0.15)),
        lambda_finish=max(0.0, _parse_float(spawn_data.get("lambda_finish"), 0.2)),
        anneal=max(0.0, min(1.0, _parse_float(spawn_data.get("anneal"), 1.0))),
        max_parallel_hint=max(1, _parse_int(spawn_data.get("max_parallel_hint"), 3)),
    )
    env.seeding = _parse_seeding(seeding_data)
    # LLM settings: string fields fall back to their defaults when blank.
    env.llm = LLMConfig(
        provider=str(llm_data.get("provider", "mock")).strip() or "mock",
        model=str(llm_data.get("model", "qwen3:2b")).strip() or "qwen3:2b",
        temperature=_parse_float(llm_data.get("temperature"), 0.1),
        max_tokens=max(1, _parse_int(llm_data.get("max_tokens"), 256)),
        timeout_seconds=max(1, _parse_int(llm_data.get("timeout_seconds"), 240)),
        ollama_base_url=str(llm_data.get("ollama_base_url", "http://127.0.0.1:11434")).strip()
        or "http://127.0.0.1:11434",
        openai_base_url=str(llm_data.get("openai_base_url", "https://api.openai.com/v1")).strip()
        or "https://api.openai.com/v1",
        openai_api_key_env=str(llm_data.get("openai_api_key_env", "OPENAI_API_KEY")).strip() or "OPENAI_API_KEY",
        openai_api_key=str(llm_data.get("openai_api_key", "")).strip(),
    )
    return env
def _parse_runtime(payload: dict[str, Any]) -> RuntimeDefaults:
    """Extract runner defaults from the optional ``"runtime"`` section."""
    section = _as_dict(payload.get("runtime", {}))
    episodes = max(1, _parse_int(section.get("default_episodes"), 20))
    return RuntimeDefaults(
        default_episodes=episodes,
        leaderboard_path=str(section.get("leaderboard_path", "artifacts/leaderboard.json")),
        dashboard_path=str(section.get("dashboard_path", "artifacts/osint_dashboard.html")),
        sweep_dashboard_dir=str(section.get("sweep_dashboard_dir", "artifacts/sweep_dashboards")),
    )
def load_shared_config(path: str | Path | None) -> SharedConfig:
    """Load a SharedConfig from a JSON file.

    A falsy *path* or a nonexistent file yields all defaults; a document that
    is not a JSON object raises ValueError.
    """
    if not path:
        return SharedConfig()
    config_file = Path(path)
    if not config_file.exists():
        return SharedConfig()
    data = json.loads(config_file.read_text(encoding="utf-8"))
    if not isinstance(data, dict):
        raise ValueError("Shared config file must contain a JSON object.")
    return SharedConfig(
        environment=_parse_environment(data),
        runtime=_parse_runtime(data),
    )