Spaces:
Running
Running
Update landmarkdiff/hyperparam.py to v0.3.2
Browse files- landmarkdiff/hyperparam.py +18 -35
landmarkdiff/hyperparam.py
CHANGED
|
@@ -99,45 +99,27 @@ class SearchSpace:
|
|
| 99 |
self.params: dict[str, ParamSpec] = {}
|
| 100 |
|
| 101 |
def add_float(
|
| 102 |
-
self,
|
| 103 |
-
name: str,
|
| 104 |
-
low: float,
|
| 105 |
-
high: float,
|
| 106 |
-
log_scale: bool = False,
|
| 107 |
) -> SearchSpace:
|
| 108 |
"""Add a continuous float parameter."""
|
| 109 |
self.params[name] = ParamSpec(
|
| 110 |
-
name=name,
|
| 111 |
-
param_type="float",
|
| 112 |
-
low=low,
|
| 113 |
-
high=high,
|
| 114 |
-
log_scale=log_scale,
|
| 115 |
)
|
| 116 |
return self
|
| 117 |
|
| 118 |
def add_int(
|
| 119 |
-
self,
|
| 120 |
-
name: str,
|
| 121 |
-
low: int,
|
| 122 |
-
high: int,
|
| 123 |
-
step: int = 1,
|
| 124 |
) -> SearchSpace:
|
| 125 |
"""Add an integer parameter."""
|
| 126 |
self.params[name] = ParamSpec(
|
| 127 |
-
name=name,
|
| 128 |
-
param_type="int",
|
| 129 |
-
low=low,
|
| 130 |
-
high=high,
|
| 131 |
-
step=step,
|
| 132 |
)
|
| 133 |
return self
|
| 134 |
|
| 135 |
def add_choice(self, name: str, choices: list[Any]) -> SearchSpace:
|
| 136 |
"""Add a categorical parameter."""
|
| 137 |
self.params[name] = ParamSpec(
|
| 138 |
-
name=name,
|
| 139 |
-
param_type="choice",
|
| 140 |
-
choices=choices,
|
| 141 |
)
|
| 142 |
return self
|
| 143 |
|
|
@@ -222,7 +204,10 @@ class HyperparamSearch:
|
|
| 222 |
attempts = 0
|
| 223 |
while len(trials) < n_trials and attempts < max_attempts:
|
| 224 |
attempts += 1
|
| 225 |
-
config = {
|
|
|
|
|
|
|
|
|
|
| 226 |
trial = Trial(
|
| 227 |
trial_id=f"trial_{len(trials):04d}",
|
| 228 |
config=config,
|
|
@@ -238,11 +223,14 @@ class HyperparamSearch:
|
|
| 238 |
import itertools
|
| 239 |
|
| 240 |
param_names = list(self.space.params.keys())
|
| 241 |
-
param_values = [
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
trials = []
|
| 244 |
for combo in itertools.product(*param_values):
|
| 245 |
-
config = dict(zip(param_names, combo
|
| 246 |
trial = Trial(
|
| 247 |
trial_id=f"trial_{len(trials):04d}",
|
| 248 |
config=config,
|
|
@@ -252,9 +240,7 @@ class HyperparamSearch:
|
|
| 252 |
return trials
|
| 253 |
|
| 254 |
def record_result(
|
| 255 |
-
self,
|
| 256 |
-
trial_id: str,
|
| 257 |
-
metrics: dict[str, float],
|
| 258 |
) -> None:
|
| 259 |
"""Record results for a trial."""
|
| 260 |
for trial in self.trials:
|
|
@@ -265,9 +251,7 @@ class HyperparamSearch:
|
|
| 265 |
raise KeyError(f"Trial {trial_id} not found")
|
| 266 |
|
| 267 |
def best_trial(
|
| 268 |
-
self,
|
| 269 |
-
metric: str = "loss",
|
| 270 |
-
lower_is_better: bool = True,
|
| 271 |
) -> Trial | None:
|
| 272 |
"""Get the best completed trial by a metric."""
|
| 273 |
completed = [t for t in self.trials if t.status == "completed" and metric in t.result]
|
|
@@ -291,8 +275,7 @@ class HyperparamSearch:
|
|
| 291 |
with open(cfg_path, "w") as f:
|
| 292 |
yaml.safe_dump(
|
| 293 |
{"trial_id": trial.trial_id, **native_config},
|
| 294 |
-
f,
|
| 295 |
-
default_flow_style=False,
|
| 296 |
)
|
| 297 |
|
| 298 |
# Save summary index
|
|
@@ -338,7 +321,7 @@ class HyperparamSearch:
|
|
| 338 |
if isinstance(val, float):
|
| 339 |
parts.append(f"{val:>12.6f}")
|
| 340 |
else:
|
| 341 |
-
parts.append(f"{val
|
| 342 |
for m in metric_names:
|
| 343 |
val = trial.result.get(m, float("nan"))
|
| 344 |
parts.append(f"{val:>12.4f}")
|
|
|
|
| 99 |
self.params: dict[str, ParamSpec] = {}
|
| 100 |
|
| 101 |
def add_float(
|
| 102 |
+
self, name: str, low: float, high: float, log_scale: bool = False,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
) -> SearchSpace:
|
| 104 |
"""Add a continuous float parameter."""
|
| 105 |
self.params[name] = ParamSpec(
|
| 106 |
+
name=name, param_type="float", low=low, high=high, log_scale=log_scale,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
)
|
| 108 |
return self
|
| 109 |
|
| 110 |
def add_int(
|
| 111 |
+
self, name: str, low: int, high: int, step: int = 1,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
) -> SearchSpace:
|
| 113 |
"""Add an integer parameter."""
|
| 114 |
self.params[name] = ParamSpec(
|
| 115 |
+
name=name, param_type="int", low=low, high=high, step=step,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
)
|
| 117 |
return self
|
| 118 |
|
| 119 |
def add_choice(self, name: str, choices: list[Any]) -> SearchSpace:
|
| 120 |
"""Add a categorical parameter."""
|
| 121 |
self.params[name] = ParamSpec(
|
| 122 |
+
name=name, param_type="choice", choices=choices,
|
|
|
|
|
|
|
| 123 |
)
|
| 124 |
return self
|
| 125 |
|
|
|
|
| 204 |
attempts = 0
|
| 205 |
while len(trials) < n_trials and attempts < max_attempts:
|
| 206 |
attempts += 1
|
| 207 |
+
config = {
|
| 208 |
+
name: spec.sample(rng)
|
| 209 |
+
for name, spec in self.space.params.items()
|
| 210 |
+
}
|
| 211 |
trial = Trial(
|
| 212 |
trial_id=f"trial_{len(trials):04d}",
|
| 213 |
config=config,
|
|
|
|
| 223 |
import itertools
|
| 224 |
|
| 225 |
param_names = list(self.space.params.keys())
|
| 226 |
+
param_values = [
|
| 227 |
+
self.space.params[name].grid_values(grid_points)
|
| 228 |
+
for name in param_names
|
| 229 |
+
]
|
| 230 |
|
| 231 |
trials = []
|
| 232 |
for combo in itertools.product(*param_values):
|
| 233 |
+
config = dict(zip(param_names, combo))
|
| 234 |
trial = Trial(
|
| 235 |
trial_id=f"trial_{len(trials):04d}",
|
| 236 |
config=config,
|
|
|
|
| 240 |
return trials
|
| 241 |
|
| 242 |
def record_result(
|
| 243 |
+
self, trial_id: str, metrics: dict[str, float],
|
|
|
|
|
|
|
| 244 |
) -> None:
|
| 245 |
"""Record results for a trial."""
|
| 246 |
for trial in self.trials:
|
|
|
|
| 251 |
raise KeyError(f"Trial {trial_id} not found")
|
| 252 |
|
| 253 |
def best_trial(
|
| 254 |
+
self, metric: str = "loss", lower_is_better: bool = True,
|
|
|
|
|
|
|
| 255 |
) -> Trial | None:
|
| 256 |
"""Get the best completed trial by a metric."""
|
| 257 |
completed = [t for t in self.trials if t.status == "completed" and metric in t.result]
|
|
|
|
| 275 |
with open(cfg_path, "w") as f:
|
| 276 |
yaml.safe_dump(
|
| 277 |
{"trial_id": trial.trial_id, **native_config},
|
| 278 |
+
f, default_flow_style=False,
|
|
|
|
| 279 |
)
|
| 280 |
|
| 281 |
# Save summary index
|
|
|
|
| 321 |
if isinstance(val, float):
|
| 322 |
parts.append(f"{val:>12.6f}")
|
| 323 |
else:
|
| 324 |
+
parts.append(f"{str(val):>12s}")
|
| 325 |
for m in metric_names:
|
| 326 |
val = trial.result.get(m, float("nan"))
|
| 327 |
parts.append(f"{val:>12.4f}")
|