"""Configuration loader for ReMDM-MiniHack. Loads YAML configs with deep-merge and CLI override support, following the Craftax config pattern. """ from __future__ import annotations import logging import os import secrets from datetime import datetime, timezone from pathlib import Path from types import SimpleNamespace import yaml logger = logging.getLogger(__name__) _PROJECT_ROOT = Path(__file__).resolve().parent.parent def _deep_merge(base: dict, override: dict) -> dict: """Recursively merge *override* into *base* (mutates *base*). Args: base: Base dictionary to merge into. override: Dictionary whose values take precedence. Returns: The merged dictionary (same object as *base*). """ for key, value in override.items(): if ( key in base and isinstance(base[key], dict) and isinstance(value, dict) ): _deep_merge(base[key], value) else: base[key] = value return base def _cast_value(value: str) -> int | float | bool | str | None: """Best-effort cast of a CLI string to a Python scalar. Args: value: Raw string from the command line. Returns: Parsed Python value (int, float, bool, str, or None). """ if value.lower() in ("true", "yes"): return True if value.lower() in ("false", "no"): return False if value.lower() == "null": return None try: return int(value) except ValueError: pass try: return float(value) except ValueError: pass return value def load_config( config_path: str | None = None, cli_overrides: dict | None = None, ) -> SimpleNamespace: """Load configuration from YAML with optional overrides. 1. Load ``configs/defaults.yaml``. 2. Deep-merge *config_path* on top (if provided and different from defaults). 3. Apply *cli_overrides* key=value pairs. 4. Auto-select device (``cuda`` if available, else ``cpu``; honour ``DEVICE`` env-var). 5. Validate invariants. Args: config_path: Path to a YAML file merged on top of defaults. ``None`` uses defaults only. cli_overrides: ``{key: value}`` pairs applied last. Returns: A ``SimpleNamespace`` containing all hyperparameters. Raises: AssertionError: If ``mask_token != action_dim`` or ``pad_token != action_dim + 1``. """ if cli_overrides is None: cli_overrides = {} defaults_path = _PROJECT_ROOT / "configs" / "defaults.yaml" with open(defaults_path, "r") as fh: cfg = yaml.safe_load(fh) if config_path is not None: config_path_resolved = Path(config_path) if not config_path_resolved.is_absolute(): config_path_resolved = _PROJECT_ROOT / config_path_resolved if config_path_resolved.resolve() != defaults_path.resolve(): with open(config_path_resolved, "r") as fh: overrides = yaml.safe_load(fh) or {} _deep_merge(cfg, overrides) for key, value in cli_overrides.items(): if isinstance(value, str): value = _cast_value(value) cfg[key] = value # Device selection env_device = os.environ.get("DEVICE") if env_device: cfg["device"] = env_device elif "device" not in cfg: try: import torch cfg["device"] = "cuda" if torch.cuda.is_available() else "cpu" except ImportError: cfg["device"] = "cpu" ns = SimpleNamespace(**cfg) # Validation assert ns.mask_token == ns.action_dim, ( f"mask_token ({ns.mask_token}) must equal action_dim ({ns.action_dim})" ) assert ns.pad_token == ns.action_dim + 1, ( f"pad_token ({ns.pad_token}) must equal action_dim + 1 " f"({ns.action_dim + 1})" ) return ns def make_run_dir(cfg: SimpleNamespace, tag: str = "run") -> Path: """Create a unique run subdirectory under ``cfg.checkpoint_dir``. Generates a directory named ``{tag}_{YYYYMMDD}_{HHMMSS}_{hex4}`` to prevent concurrent runs from overwriting each other's checkpoints. Updates ``cfg.checkpoint_dir`` in place. Args: cfg: Config namespace (``checkpoint_dir`` is mutated). tag: Prefix for the directory name (e.g. ``"dagger"``, ``"offline"``). Returns: The created directory path. """ ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") suffix = secrets.token_hex(2) run_dir = Path(cfg.checkpoint_dir).resolve() / f"{tag}_{ts}_{suffix}" run_dir.mkdir(parents=True, exist_ok=True) cfg.checkpoint_dir = str(run_dir) logger.info("Checkpoint directory: %s", run_dir) return run_dir