dreamlessx committed on
Commit
0c568c7
·
verified ·
1 Parent(s): 28dc803

Update landmarkdiff/hyperparam.py to v0.3.2

Browse files
Files changed (1) hide show
  1. landmarkdiff/hyperparam.py +18 -35
landmarkdiff/hyperparam.py CHANGED
@@ -99,45 +99,27 @@ class SearchSpace:
99
  self.params: dict[str, ParamSpec] = {}
100
 
101
  def add_float(
102
- self,
103
- name: str,
104
- low: float,
105
- high: float,
106
- log_scale: bool = False,
107
  ) -> SearchSpace:
108
  """Add a continuous float parameter."""
109
  self.params[name] = ParamSpec(
110
- name=name,
111
- param_type="float",
112
- low=low,
113
- high=high,
114
- log_scale=log_scale,
115
  )
116
  return self
117
 
118
  def add_int(
119
- self,
120
- name: str,
121
- low: int,
122
- high: int,
123
- step: int = 1,
124
  ) -> SearchSpace:
125
  """Add an integer parameter."""
126
  self.params[name] = ParamSpec(
127
- name=name,
128
- param_type="int",
129
- low=low,
130
- high=high,
131
- step=step,
132
  )
133
  return self
134
 
135
  def add_choice(self, name: str, choices: list[Any]) -> SearchSpace:
136
  """Add a categorical parameter."""
137
  self.params[name] = ParamSpec(
138
- name=name,
139
- param_type="choice",
140
- choices=choices,
141
  )
142
  return self
143
 
@@ -222,7 +204,10 @@ class HyperparamSearch:
222
  attempts = 0
223
  while len(trials) < n_trials and attempts < max_attempts:
224
  attempts += 1
225
- config = {name: spec.sample(rng) for name, spec in self.space.params.items()}
 
 
 
226
  trial = Trial(
227
  trial_id=f"trial_{len(trials):04d}",
228
  config=config,
@@ -238,11 +223,14 @@ class HyperparamSearch:
238
  import itertools
239
 
240
  param_names = list(self.space.params.keys())
241
- param_values = [self.space.params[name].grid_values(grid_points) for name in param_names]
 
 
 
242
 
243
  trials = []
244
  for combo in itertools.product(*param_values):
245
- config = dict(zip(param_names, combo, strict=False))
246
  trial = Trial(
247
  trial_id=f"trial_{len(trials):04d}",
248
  config=config,
@@ -252,9 +240,7 @@ class HyperparamSearch:
252
  return trials
253
 
254
  def record_result(
255
- self,
256
- trial_id: str,
257
- metrics: dict[str, float],
258
  ) -> None:
259
  """Record results for a trial."""
260
  for trial in self.trials:
@@ -265,9 +251,7 @@ class HyperparamSearch:
265
  raise KeyError(f"Trial {trial_id} not found")
266
 
267
  def best_trial(
268
- self,
269
- metric: str = "loss",
270
- lower_is_better: bool = True,
271
  ) -> Trial | None:
272
  """Get the best completed trial by a metric."""
273
  completed = [t for t in self.trials if t.status == "completed" and metric in t.result]
@@ -291,8 +275,7 @@ class HyperparamSearch:
291
  with open(cfg_path, "w") as f:
292
  yaml.safe_dump(
293
  {"trial_id": trial.trial_id, **native_config},
294
- f,
295
- default_flow_style=False,
296
  )
297
 
298
  # Save summary index
@@ -338,7 +321,7 @@ class HyperparamSearch:
338
  if isinstance(val, float):
339
  parts.append(f"{val:>12.6f}")
340
  else:
341
- parts.append(f"{val!s:>12s}")
342
  for m in metric_names:
343
  val = trial.result.get(m, float("nan"))
344
  parts.append(f"{val:>12.4f}")
 
99
  self.params: dict[str, ParamSpec] = {}
100
 
101
  def add_float(
102
+ self, name: str, low: float, high: float, log_scale: bool = False,
 
 
 
 
103
  ) -> SearchSpace:
104
  """Add a continuous float parameter."""
105
  self.params[name] = ParamSpec(
106
+ name=name, param_type="float", low=low, high=high, log_scale=log_scale,
 
 
 
 
107
  )
108
  return self
109
 
110
  def add_int(
111
+ self, name: str, low: int, high: int, step: int = 1,
 
 
 
 
112
  ) -> SearchSpace:
113
  """Add an integer parameter."""
114
  self.params[name] = ParamSpec(
115
+ name=name, param_type="int", low=low, high=high, step=step,
 
 
 
 
116
  )
117
  return self
118
 
119
  def add_choice(self, name: str, choices: list[Any]) -> SearchSpace:
120
  """Add a categorical parameter."""
121
  self.params[name] = ParamSpec(
122
+ name=name, param_type="choice", choices=choices,
 
 
123
  )
124
  return self
125
 
 
204
  attempts = 0
205
  while len(trials) < n_trials and attempts < max_attempts:
206
  attempts += 1
207
+ config = {
208
+ name: spec.sample(rng)
209
+ for name, spec in self.space.params.items()
210
+ }
211
  trial = Trial(
212
  trial_id=f"trial_{len(trials):04d}",
213
  config=config,
 
223
  import itertools
224
 
225
  param_names = list(self.space.params.keys())
226
+ param_values = [
227
+ self.space.params[name].grid_values(grid_points)
228
+ for name in param_names
229
+ ]
230
 
231
  trials = []
232
  for combo in itertools.product(*param_values):
233
+ config = dict(zip(param_names, combo))
234
  trial = Trial(
235
  trial_id=f"trial_{len(trials):04d}",
236
  config=config,
 
240
  return trials
241
 
242
  def record_result(
243
+ self, trial_id: str, metrics: dict[str, float],
 
 
244
  ) -> None:
245
  """Record results for a trial."""
246
  for trial in self.trials:
 
251
  raise KeyError(f"Trial {trial_id} not found")
252
 
253
  def best_trial(
254
+ self, metric: str = "loss", lower_is_better: bool = True,
 
 
255
  ) -> Trial | None:
256
  """Get the best completed trial by a metric."""
257
  completed = [t for t in self.trials if t.status == "completed" and metric in t.result]
 
275
  with open(cfg_path, "w") as f:
276
  yaml.safe_dump(
277
  {"trial_id": trial.trial_id, **native_config},
278
+ f, default_flow_style=False,
 
279
  )
280
 
281
  # Save summary index
 
321
  if isinstance(val, float):
322
  parts.append(f"{val:>12.6f}")
323
  else:
324
+ parts.append(f"{str(val):>12s}")
325
  for m in metric_names:
326
  val = trial.result.get(m, float("nan"))
327
  parts.append(f"{val:>12.4f}")