dreamlessx committed on
Commit
d2c39e0
·
verified ·
1 Parent(s): a681eea

Upload landmarkdiff/hyperparam.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. landmarkdiff/hyperparam.py +330 -0
landmarkdiff/hyperparam.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hyperparameter search utilities for systematic ControlNet tuning.
2
+
3
+ Supports grid search and random search (no adaptive/Bayesian strategy is implemented)
4
+ over training hyperparameters. Generates YAML configs for each trial and
5
+ tracks results for comparison.
6
+
7
+ Usage:
8
+ from landmarkdiff.hyperparam import HyperparamSearch, SearchSpace
9
+
10
+ space = SearchSpace()
11
+ space.add_float("learning_rate", 1e-6, 1e-4, log_scale=True)
12
+ space.add_choice("optimizer", ["adamw", "adam8bit"])
13
+ space.add_int("batch_size", 2, 8, step=2)
14
+
15
+ search = HyperparamSearch(space, output_dir="hp_search")
16
+ for trial in search.generate_trials(strategy="random", n_trials=20):
17
+ print(trial.config)
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import hashlib
23
+ import json
24
+ import math
25
+ from dataclasses import dataclass, field
26
+ from pathlib import Path
27
+ from typing import Any, Iterator
28
+
29
+
30
+ def _to_native(val: Any) -> Any:
31
+ """Convert numpy/non-standard types to native Python for YAML serialization."""
32
+ if hasattr(val, "item"): # numpy scalar
33
+ return val.item()
34
+ return val
35
+
36
+
37
@dataclass
class ParamSpec:
    """Specification for a single hyperparameter.

    ``param_type`` selects which of the other fields are used: ``"float"``
    and ``"int"`` read ``low``/``high`` (plus ``log_scale`` for floats and
    ``step`` for ints), ``"choice"`` reads ``choices``.
    """

    name: str
    param_type: str  # "float", "int", "choice"
    low: float | None = None  # inclusive lower bound (float/int)
    high: float | None = None  # inclusive upper bound for ints and grids
    step: float | None = None  # int stride; only honoured when > 1
    log_scale: bool = False  # sample / space floats in log space
    choices: list[Any] | None = None  # candidate values for "choice"

    def _log_bounds(self) -> tuple[float, float]:
        """Return ``(log(low), log(high))``, validating the bounds first.

        Raises:
            ValueError: If either bound is missing or not strictly positive.
        """
        if self.low is None or self.high is None or self.low <= 0 or self.high <= 0:
            raise ValueError(
                f"Parameter {self.name!r}: log_scale requires positive low/high, "
                f"got low={self.low!r}, high={self.high!r}"
            )
        return math.log(self.low), math.log(self.high)

    def sample(self, rng) -> Any:
        """Sample a value from this parameter spec.

        Args:
            rng: Random generator providing ``uniform(low, high)`` and
                ``integers(low, high)`` with an exclusive upper bound
                (e.g. ``numpy.random.default_rng()``).

        Returns:
            A native ``float``/``int``, or one of the ``choices`` objects.

        Raises:
            ValueError: For an unknown ``param_type``, an empty choice list,
                or non-positive bounds combined with ``log_scale``.
        """
        if self.param_type == "choice":
            if not self.choices:
                raise ValueError(f"Parameter {self.name!r} has no choices to sample from")
            # Pick by index instead of rng.choice(): numpy can coerce
            # mixed-type choices into a single array dtype and hands back
            # numpy scalars rather than the original Python objects.
            return self.choices[int(rng.integers(0, len(self.choices)))]

        if self.param_type == "float":
            if self.log_scale:
                log_low, log_high = self._log_bounds()
                return float(math.exp(rng.uniform(log_low, log_high)))
            return float(rng.uniform(self.low, self.high))

        if self.param_type == "int":
            if self.step and self.step > 1:
                # Number of lattice points low, low+step, ... that are <= high.
                n_steps = int((self.high - self.low) / self.step) + 1
                idx = int(rng.integers(0, n_steps))
                return int(self.low + idx * self.step)
            # +1 because the generator's upper bound is exclusive.
            return int(rng.integers(int(self.low), int(self.high) + 1))

        raise ValueError(f"Unknown param type: {self.param_type}")

    def grid_values(self, n_points: int = 5) -> list[Any]:
        """Generate grid values for this parameter.

        Args:
            n_points: Number of points for float parameters (ignored for
                ``"choice"`` and ``"int"``, which enumerate every value).
                ``1`` yields just ``low``; values ``<= 0`` yield ``[]``.

        Returns:
            List of grid values; empty for an unknown ``param_type``.

        Raises:
            ValueError: If ``log_scale`` is set with non-positive bounds.
        """
        if self.param_type == "choice":
            return list(self.choices or [])

        if self.param_type == "int":
            if self.step and self.step > 1:
                vals = []
                v = self.low
                while v <= self.high:
                    vals.append(int(v))
                    v += self.step
                return vals
            return list(range(int(self.low), int(self.high) + 1))

        if self.param_type == "float":
            if n_points <= 0:
                return []
            if n_points == 1:
                # A single-point grid has no spacing to divide by; use the
                # lower bound instead of dividing by (n_points - 1) == 0.
                if self.log_scale:
                    self._log_bounds()  # still validate the bounds
                return [float(self.low)]
            if self.log_scale:
                log_low, log_high = self._log_bounds()
                return [
                    float(math.exp(log_low + i * (log_high - log_low) / (n_points - 1)))
                    for i in range(n_points)
                ]
            return [
                float(self.low + i * (self.high - self.low) / (n_points - 1))
                for i in range(n_points)
            ]

        return []
93
+
94
+
95
class SearchSpace:
    """Define the hyperparameter search space.

    Parameters are registered by name in ``params``; registering the same
    name again replaces the earlier spec. Every ``add_*`` method returns
    ``self`` so calls can be chained.
    """

    def __init__(self) -> None:
        self.params: dict[str, ParamSpec] = {}

    def add_float(
        self, name: str, low: float, high: float, log_scale: bool = False,
    ) -> SearchSpace:
        """Add a continuous float parameter.

        Args:
            name: Parameter name.
            low: Lower bound.
            high: Upper bound; must not be below ``low``.
            log_scale: Sample and space values in log space (requires
                strictly positive bounds).

        Raises:
            ValueError: If ``low > high``, or ``log_scale`` is set while
                ``low`` is not strictly positive.
        """
        if low > high:
            raise ValueError(f"Parameter {name!r}: low ({low}) must be <= high ({high})")
        if log_scale and low <= 0:
            # math.log would otherwise fail later with "math domain error".
            raise ValueError(f"Parameter {name!r}: log_scale requires low > 0, got {low}")
        self.params[name] = ParamSpec(
            name=name, param_type="float", low=low, high=high, log_scale=log_scale,
        )
        return self

    def add_int(
        self, name: str, low: int, high: int, step: int = 1,
    ) -> SearchSpace:
        """Add an integer parameter.

        Args:
            name: Parameter name.
            low: Lower bound (inclusive).
            high: Upper bound (inclusive); must not be below ``low``.
            step: Stride between admissible values; must be at least 1.

        Raises:
            ValueError: If ``low > high`` or ``step < 1``.
        """
        if low > high:
            raise ValueError(f"Parameter {name!r}: low ({low}) must be <= high ({high})")
        if step < 1:
            raise ValueError(f"Parameter {name!r}: step must be >= 1, got {step}")
        self.params[name] = ParamSpec(
            name=name, param_type="int", low=low, high=high, step=step,
        )
        return self

    def add_choice(self, name: str, choices: list[Any]) -> SearchSpace:
        """Add a categorical parameter.

        Args:
            name: Parameter name.
            choices: Non-empty collection of candidate values. It is copied,
                so later mutation of the caller's list does not alter the
                search space.

        Raises:
            ValueError: If ``choices`` is empty.
        """
        options = list(choices)
        if not options:
            raise ValueError(f"Parameter {name!r}: choices must not be empty")
        self.params[name] = ParamSpec(
            name=name, param_type="choice", choices=options,
        )
        return self

    def __len__(self) -> int:
        """Return the number of registered parameters."""
        return len(self.params)

    def __contains__(self, name: str) -> bool:
        """Return whether a parameter called ``name`` is registered."""
        return name in self.params
131
+
132
+
133
@dataclass
class Trial:
    """One hyperparameter configuration together with its outcome."""

    trial_id: str
    config: dict[str, Any]
    result: dict[str, float] = field(default_factory=dict)
    status: str = "pending"  # pending, running, completed, failed

    @property
    def config_hash(self) -> str:
        """Return an 8-character hex fingerprint of ``config``.

        The config is serialized as key-sorted JSON (values JSON cannot
        encode are stringified), so key order does not affect the result.
        Intended for deduplication.
        """
        canonical = json.dumps(self.config, default=str, sort_keys=True)
        digest = hashlib.md5(canonical.encode()).hexdigest()
        return digest[:8]
147
+
148
+
149
class HyperparamSearch:
    """Hyperparameter search engine.

    Generates trial configurations from a search space, records their
    results, and writes the configs to disk.

    Args:
        space: Search space definition.
        output_dir: Directory to save trial configs and results.
        seed: Random seed for reproducibility.
    """

    def __init__(
        self,
        space: SearchSpace,
        output_dir: str | Path = "hp_search",
        seed: int = 42,
    ) -> None:
        self.space = space
        self.output_dir = Path(output_dir)
        self.seed = seed
        self.trials: list[Trial] = []

    def generate_trials(
        self,
        strategy: str = "random",
        n_trials: int = 20,
        grid_points: int = 5,
    ) -> list[Trial]:
        """Generate trial configurations and register them on ``self.trials``.

        Trial IDs continue from the number of trials already registered, so
        calling this method more than once never produces duplicate IDs.

        Args:
            strategy: "random" or "grid".
            n_trials: Number of trials for random search.
            grid_points: Points per continuous dimension for grid search.

        Returns:
            List of the newly created Trial objects.

        Raises:
            ValueError: If ``strategy`` is not "random" or "grid".
        """
        # Offset IDs by what is already registered; restarting at 0 would
        # make record_result() hit the wrong trial and save_configs()
        # overwrite earlier YAML files.
        start_index = len(self.trials)
        if strategy == "grid":
            trials = self._grid_search(grid_points, start_index)
        elif strategy == "random":
            trials = self._random_search(n_trials, start_index)
        else:
            raise ValueError(f"Unknown strategy: {strategy}. Use 'random' or 'grid'.")

        self.trials.extend(trials)
        return trials

    def _random_search(self, n_trials: int, start_index: int = 0) -> list[Trial]:
        """Generate random trial configs.

        Configs are deduplicated by ``config_hash`` within the call. At most
        ``n_trials * 10`` samples are drawn, so fewer than ``n_trials``
        trials may be returned when the space has few distinct configs.

        Args:
            n_trials: Number of unique trials requested.
            start_index: Number used for the first generated trial ID.
        """
        import numpy as np

        # The first batch is seeded with the plain seed. Later batches mix in
        # the offset so they do not replay the exact same sample stream.
        seed = self.seed
        if start_index and seed is not None:
            seed = [seed, start_index]
        rng = np.random.default_rng(seed)

        seen_hashes: set[str] = set()
        trials: list[Trial] = []

        max_attempts = n_trials * 10
        attempts = 0
        while len(trials) < n_trials and attempts < max_attempts:
            attempts += 1
            config = {
                name: spec.sample(rng)
                for name, spec in self.space.params.items()
            }
            trial = Trial(
                trial_id=f"trial_{start_index + len(trials):04d}",
                config=config,
            )
            if trial.config_hash not in seen_hashes:
                seen_hashes.add(trial.config_hash)
                trials.append(trial)

        return trials

    def _grid_search(self, grid_points: int, start_index: int = 0) -> list[Trial]:
        """Generate one trial per point of the Cartesian product of grids.

        Args:
            grid_points: Points per continuous dimension.
            start_index: Number used for the first generated trial ID.
        """
        import itertools

        param_names = list(self.space.params)
        param_values = [
            self.space.params[name].grid_values(grid_points)
            for name in param_names
        ]

        return [
            Trial(
                trial_id=f"trial_{start_index + offset:04d}",
                config=dict(zip(param_names, combo)),
            )
            for offset, combo in enumerate(itertools.product(*param_values))
        ]

    def record_result(
        self, trial_id: str, metrics: dict[str, float],
    ) -> None:
        """Record results for a trial and mark it completed.

        Args:
            trial_id: ID of a trial present in ``self.trials``.
            metrics: Metric name to value mapping stored as the result.

        Raises:
            KeyError: If no trial has the given ID.
        """
        for trial in self.trials:
            if trial.trial_id == trial_id:
                trial.result = metrics
                trial.status = "completed"
                return
        raise KeyError(f"Trial {trial_id} not found")

    @staticmethod
    def _is_nan(value: Any) -> bool:
        """Return True when ``value`` is a NaN number (False if non-numeric)."""
        try:
            return math.isnan(value)
        except TypeError:
            return False

    def best_trial(
        self, metric: str = "loss", lower_is_better: bool = True,
    ) -> Trial | None:
        """Get the best completed trial by a metric.

        Trials whose value for ``metric`` is NaN are ignored: NaN compares
        false against everything, so ``min``/``max`` could otherwise return
        such a trial depending purely on list order.

        Args:
            metric: Name of the metric to rank by.
            lower_is_better: Pick the minimum when True, else the maximum.

        Returns:
            The best trial, or None when no completed trial carries a
            non-NaN value for ``metric``.
        """
        candidates = [
            t for t in self.trials
            if t.status == "completed"
            and metric in t.result
            and not self._is_nan(t.result[metric])
        ]
        if not candidates:
            return None
        pick = min if lower_is_better else max
        return pick(candidates, key=lambda t: t.result[metric])

    def save_configs(self) -> Path:
        """Save all trial configs as YAML files.

        Writes ``<trial_id>.yaml`` per trial plus a ``search_index.json``
        summary of the seed, trial count and parameter specs.

        Returns:
            Output directory path.
        """
        import yaml

        self.output_dir.mkdir(parents=True, exist_ok=True)
        for trial in self.trials:
            cfg_path = self.output_dir / f"{trial.trial_id}.yaml"
            # Convert numpy types to native Python for YAML serialization
            native_config = {k: _to_native(v) for k, v in trial.config.items()}
            with open(cfg_path, "w", encoding="utf-8") as f:
                yaml.safe_dump(
                    {"trial_id": trial.trial_id, **native_config},
                    f, default_flow_style=False,
                )

        # Save summary index
        index = {
            "seed": self.seed,
            "n_trials": len(self.trials),
            "params": {
                name: {
                    "type": spec.param_type,
                    "low": spec.low,
                    "high": spec.high,
                    "step": spec.step,
                    "choices": spec.choices,
                    "log_scale": spec.log_scale,
                }
                for name, spec in self.space.params.items()
            },
        }
        with open(self.output_dir / "search_index.json", "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, default=str)

        return self.output_dir

    def results_table(self) -> str:
        """Format results of completed trials as a text table.

        Columns are the trial ID, the search-space parameters (sorted by
        name) and every metric seen in any completed trial (sorted by name).
        Metrics a trial lacks are shown as ``nan``.

        Returns:
            The table, or "No completed trials." when there are none.
        """
        completed = [t for t in self.trials if t.status == "completed"]
        if not completed:
            return "No completed trials."

        # Collect all metric names
        metric_names = sorted(set().union(*(t.result.keys() for t in completed)))
        param_names = sorted(self.space.params.keys())

        # Header
        cols = ["Trial"] + param_names + metric_names
        lines = [" | ".join(f"{c:>12s}" for c in cols)]
        lines.append("-" * len(lines[0]))

        # Rows
        for trial in completed:
            parts = [f"{trial.trial_id:>12s}"]
            for p in param_names:
                val = trial.config.get(p, "")
                if isinstance(val, float):
                    parts.append(f"{val:>12.6f}")
                else:
                    parts.append(f"{str(val):>12s}")
            for m in metric_names:
                val = trial.result.get(m, float("nan"))
                try:
                    parts.append(f"{val:>12.4f}")
                except (TypeError, ValueError):
                    # Non-numeric metric (e.g. None or a string): show it
                    # as text rather than failing the whole table.
                    parts.append(f"{str(val):>12s}")
            lines.append(" | ".join(parts))

        return "\n".join(lines)