File size: 2,019 Bytes
b7ddfc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
import yaml
from pathlib import Path

@dataclass
class ModelConfig:
    """Model settings with their built-in defaults.

    Values found under the ``model`` section of the YAML config replace the
    defaults declared here (see ``Config.load_from_yaml``).
    """

    # NOTE(review): field meanings below are inferred from the names only;
    # confirm against the model code that consumes them.
    n_layers: int = 2  # presumably the number of stacked layers
    n_heads: int = 4  # presumably the number of attention heads
    d_model: int = 128  # presumably the hidden/embedding width
    max_length: int = 30  # presumably the maximum sequence/context length
    # No default size: None until something else assigns it.
    # presumably derived from the environment's observation space — TODO confirm
    state_dim: Optional[int] = None
    action_dim: int = 7  # presumably the size of the discrete action space — confirm

@dataclass
class DataConfig:
    """Data-collection settings with their built-in defaults.

    Values found under the ``data`` section of the YAML config replace the
    defaults declared here (see ``Config.load_from_yaml``).
    """

    # presumably a Gymnasium/MiniGrid environment id — verify against the env factory
    env_id: str = "MiniGrid-Empty-8x8-v0"
    num_episodes: int = 1000  # presumably how many episodes to collect
    # Free-form string; the set of accepted values is not visible in this file.
    collection_method: str = "PPO-Teacher"

@dataclass
class TrainConfig:
    """Training settings with their built-in defaults.

    Values found under the ``train`` section of the YAML config replace the
    defaults declared here (see ``Config.load_from_yaml``).
    """

    lr: float = 1e-4  # presumably the optimizer learning rate
    epochs: int = 10  # presumably the number of passes over the training data
    seed: int = 42  # presumably the RNG seed; which RNGs it seeds is not shown here

@dataclass
class SAEConfig:
    """SAE settings with their built-in defaults.

    Values found under the ``sae`` section of the YAML config replace the
    defaults declared here (see ``Config.load_from_yaml``).

    NOTE(review): "SAE" presumably stands for sparse autoencoder — confirm.
    """

    # presumably latent width as a multiple of the model width — TODO confirm
    expansion_factor: int = 8
    k: int = 32  # presumably the number of active latents kept (top-k) — confirm
    l1_coeff: float = 0.0005  # presumably the weight of an L1 sparsity penalty
    # lr, epochs and num_episodes are separate fields from TrainConfig.lr,
    # TrainConfig.epochs and DataConfig.num_episodes and carry different defaults.
    lr: float = 3e-4
    epochs: int = 5
    batch_size: int = 1024
    num_episodes: int = 100

@dataclass
class Config:
    """Top-level configuration grouping the model, data, train and sae sections.

    Each section is created from its own dataclass defaults, so ``Config()``
    is fully usable without any file on disk.
    """

    model: ModelConfig = field(default_factory=ModelConfig)
    data: DataConfig = field(default_factory=DataConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
    sae: SAEConfig = field(default_factory=SAEConfig)

    @classmethod
    def load_from_yaml(cls, yaml_path: str = "config.yaml") -> "Config":
        """Load configuration from a YAML file, overriding defaults.

        Only the top-level sections ``model``, ``data``, ``train`` and ``sae``
        are read. Within a section, keys that match a declared field of the
        corresponding dataclass replace that field's default; any other key is
        ignored. Values are assigned as parsed by YAML, without type coercion.

        Args:
            yaml_path: Path to the YAML file. A relative path is resolved
                against the current working directory.

        Returns:
            A new ``Config``. If the file does not exist, is empty, or omits a
            section (or leaves it blank), the affected sections keep their
            defaults.

        Raises:
            TypeError: If the document root, or a non-empty section, is not a
                mapping.
            yaml.YAMLError: If the file is not valid YAML.
        """
        # Local import: the module header only brings in `dataclass` and `field`.
        from dataclasses import fields as dataclass_fields

        path = Path(yaml_path)
        config = cls()
        if not path.exists():
            return config

        # YAML is UTF-8 by specification; do not depend on the locale encoding.
        with open(path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)

        # An empty (or comment-only) file parses to None: nothing to override.
        if data is None:
            return config
        if not isinstance(data, dict):
            raise TypeError(
                f"Top level of {path} must be a mapping, got {type(data).__name__}"
            )

        def update_dataclass(dc_obj: Any, dc_dict: Dict[str, Any]) -> None:
            # Restrict overrides to declared fields so that a key such as
            # "__class__" or "__init__" can never clobber a non-field attribute.
            allowed = {f.name for f in dataclass_fields(dc_obj)}
            for key, value in dc_dict.items():
                if key in allowed:
                    setattr(dc_obj, key, value)

        for section_name in ("model", "data", "train", "sae"):
            section = data.get(section_name)
            # Missing section, or a header left blank (`model:`), parses to None.
            if section is None:
                continue
            if not isinstance(section, dict):
                raise TypeError(
                    f"Section '{section_name}' in {path} must be a mapping, "
                    f"got {type(section).__name__}"
                )
            update_dataclass(getattr(config, section_name), section)

        return config

# Global config instance for easy access.
# Built once, at import time, from "config.yaml" resolved against the current
# working directory; when that file does not exist every section keeps its
# dataclass defaults.
cfg = Config.load_from_yaml()