| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| from contextlib import contextmanager |
| from typing import Sequence |
|
|
| import click |
| import numpy as np |
| from omegaconf import DictConfig, ListConfig, OmegaConf |
|
|
| from nemo.collections.common.data.lhotse.cutset import get_parser_fn |
|
|
|
|
@click.command()
# required=True makes click reject an invocation without any input config,
# instead of silently writing an empty output file.
@click.argument("input_cfgs", type=click.Path(exists=True, dir_okay=False), nargs=-1, required=True)
@click.argument("output_cfg", type=click.Path())
@click.option(
    "-t",
    "--temperature",
    type=float,
    default=None,
    multiple=True,
    help="Temperature for re-weighting datasets. 1 is a neutral value. "
    "Lower temperature over-samples smaller datasets, and vice versa. "
    "Can be specified multiple times to apply a different temperature to each group level in the YAML config.",
)
@click.option(
    "-s",
    "--strategy",
    type=click.Choice(["num_hours", "num_examples"]),
    default="num_hours",
    help="Strategy for choosing weights for each dataset.",
)
def estimate_data_weights(
    input_cfgs: tuple[str, ...], output_cfg: str, temperature: tuple[float, ...], strategy: str
) -> None:
    """
    Read a YAML specification of datasets from INPUT_CFGS, compute their weights, and save the result in OUTPUT_CFG.
    The weight for each entry is determined by the selected strategy: the number of hours (default)
    or the number of examples in a given dataset.

    If more than one config is provided as input, we will concatenate them and output a single merged config.

    Optionally, apply temperature re-weighting to balance the datasets (specify TEMPERATURE lower than 1).
    """
    # NOTE: the docstring above doubles as the ``--help`` text rendered by click.

    # Concatenate the top-level entries of every input config into a single list
    # (assumes each input YAML has a list at its top level -- TODO confirm).
    data = ListConfig([])
    for icfg in input_cfgs:
        data.extend(OmegaConf.load(icfg))

    # click passes ``multiple=True`` options as a tuple; normalize it to None / scalar / per-level list.
    temperature = parse_temperature(temperature)

    validate(data)
    # Each step below mutates ``data`` in place.
    count(data, weight_key=strategy)
    aggregate_group_weights(data)
    reweight(data, temperature=temperature)
    OmegaConf.save(data, output_cfg)
|
|
|
|
def validate(entry: DictConfig | ListConfig, _level: int = 0):
    """
    Recursively check the structure of a data config.

    Every non-list entry must have a ``type`` key, and every entry of type ``group``
    must have an ``input_cfg`` key whose items are validated in turn.

    Args:
        entry: a list of entries, or a single entry.
        _level: current nesting depth; only used to build error messages.

    Raises:
        AssertionError: if an entry has no ``type`` key, or a group has no ``input_cfg`` key.
    """
    if isinstance(entry, ListConfig):
        for subentry in entry:
            validate(subentry, _level + 1)
        return

    # Raised explicitly instead of using ``assert`` so that the check survives ``python -O``;
    # the exception type stays the same.
    if "type" not in entry:
        raise AssertionError(
            f"Invalid YAML data config at nesting level {_level}: missing key 'type' in entry={entry}"
        )

    if entry.type != "group":
        return

    # Report a group without children here, rather than failing later with an opaque lookup error.
    if "input_cfg" not in entry:
        raise AssertionError(
            f"Invalid YAML data config at nesting level {_level}: missing key 'input_cfg' in group entry={entry}"
        )

    for subentry in entry["input_cfg"]:
        validate(subentry, _level + 1)
|
|
|
|
def count(entry: DictConfig | ListConfig, weight_key: str) -> None:
    """
    Recursively iterate over every dataset and store its size under the ``weight`` key.

    Lists and entries of type ``group`` are only traversed; every other entry is opened with the
    parser returned by ``get_parser_fn(entry.type)`` and fully iterated to collect statistics.

    Args:
        entry: a list of entries, or a single entry (group or dataset).
        weight_key: which statistic becomes the weight, ``"num_hours"`` or ``"num_examples"``.

    Raises:
        RuntimeError: if ``weight_key`` is ``"num_hours"`` and the dataset's total duration is zero.
    """
    if isinstance(entry, ListConfig):
        for subentry in entry:
            count(subentry, weight_key=weight_key)
        return
    if entry.type == "group":
        for subentry in entry["input_cfg"]:
            count(subentry, weight_key=weight_key)
        return

    total_seconds = 0.0
    num_examples = 0
    # Build the iterable and consume it while the temporary iteration options are set on the entry.
    with quick_iter_options(entry):
        # The parser returns a pair; only its first item (the iterable) is needed here.
        iterable, _ = get_parser_fn(entry.type)(entry)
        for example in iterable:
            # Examples without a usable duration still count as examples but add no time.
            duration = getattr(example, "duration", None)
            if duration is not None:
                total_seconds += duration
            num_examples += 1

    # Durations are summed and then divided by 3600 to express the total in hours.
    stats = {"num_hours": total_seconds / 3600.0, "num_examples": num_examples}

    if weight_key == "num_hours" and stats[weight_key] == 0.0:
        if num_examples == 0:
            # Distinguish an empty dataset from one whose examples carry no duration.
            raise RuntimeError(
                f"Cannot set weights based on 'num_hours': the dataset yielded no examples. " f"Details: {entry=}"
            )
        raise RuntimeError(
            f"Cannot set weights based on 'num_hours': at least one dataset has examples without 'duration' property. "
            f"Details: {entry=}"
        )

    entry["weight"] = stats[weight_key]
|
|
|
|
def aggregate_group_weights(entry: DictConfig | ListConfig) -> None:
    """
    Recursively set the ``weight`` of every group to the sum of its children's weights.

    Non-group entries are left untouched; they are expected to already carry a ``weight``.

    Args:
        entry: a list of entries, or a single entry (group or dataset).
    """
    if isinstance(entry, ListConfig):
        for subentry in entry:
            aggregate_group_weights(subentry)
        return

    if entry.type != "group":
        return

    children = entry["input_cfg"]
    for subentry in children:
        # Always descend: a nested group may carry a stale ``weight`` (e.g. from a config produced
        # by an earlier run) and must be recomputed just like top-level groups are.
        # The call returns immediately for non-group children.
        aggregate_group_weights(subentry)

    entry.weight = sum(subentry["weight"] for subentry in children)
|
|
|
|
def reweight(entry: DictConfig | ListConfig, temperature: None | float | list[float]) -> None:
    """
    Recursively apply temperature re-weighting to the children of lists and groups, in place.

    At each level the children's weights are replaced with the output of ``temperature_reweighting``.
    Nested levels are processed before the level that contains them.

    Args:
        entry: a list of entries, or a single entry; non-group entries are ignored.
        temperature: ``None`` disables re-weighting. A scalar is used at every nesting level.
            A sequence supplies one temperature per level: its first item is used here and the
            remainder is passed down; once it runs out, deeper levels are left unchanged.
    """
    # Only ``None`` or an exhausted per-level sequence disables re-weighting. A scalar 0.0 is a
    # valid temperature and is applied, exactly like a 0.0 that appears inside a sequence.
    if temperature is None:
        return
    if isinstance(temperature, Sequence):
        if len(temperature) == 0:
            return
        temperature, *next_temperatures = temperature
    else:
        next_temperatures = temperature

    # Lists and groups are handled identically once their children are located.
    if isinstance(entry, ListConfig):
        children = entry
    elif entry.type == "group":
        children = entry["input_cfg"]
    else:
        return

    for subentry in children:
        reweight(subentry, temperature=next_temperatures)

    new_weights = temperature_reweighting([se.weight for se in children], temperature=temperature)
    for se, nw in zip(children, new_weights):
        se.weight = nw
|
|
|
|
def temperature_reweighting(weights: list[float], temperature: float = 1.0):
    """
    Raise each weight to the power of ``temperature`` and normalize the result to sum to one:
    ``w_i ^ alpha / sum(w_i ^ alpha)``.

    Args:
        weights: the weights to rescale.
        temperature: the exponent ``alpha``; 1.0 only normalizes.

    Returns:
        A list of floats with the same length as ``weights``.
    """
    # Force a floating-point array: NumPy refuses to raise an integer array (e.g. example counts)
    # to a negative integer power.
    scaled = np.asarray(weights, dtype=np.float64) ** temperature
    return (scaled / scaled.sum()).tolist()
|
|
|
|
@contextmanager
def quick_iter_options(entry: DictConfig):
    """
    Temporarily set ``metadata_only`` and ``force_finite`` to ``True`` on ``entry``.

    Both keys are removed again when the ``with`` block exits, including when it exits
    because of an exception.

    Args:
        entry: the dataset config to modify in place.

    Yields:
        The same ``entry`` object, with both options set.
    """
    option_names = ("metadata_only", "force_finite")
    for name in option_names:
        entry[name] = True
    try:
        yield entry
    finally:
        # Clean up even if the body raised, so the temporary keys never leak into the saved config.
        for name in option_names:
            del entry[name]
|
|
|
|
| def parse_temperature(value: list[float]) -> float | list[float] | None: |
| match value: |
| case 0: |
| return None |
| case 1: |
| return value[0] |
| case _: |
| return value |
|
|
|
|
# Script entry point: run the click command only when this file is executed directly.
if __name__ == "__main__":
    estimate_data_weights()
|
|