Spaces:
Running on Zero
Running on Zero
| from pathlib import Path | |
| from typing import Any | |
| import torch | |
| def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor: | |
| """Root-mean-square (RMS) normalize `x` over its last dimension. | |
| Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized | |
| shape and forwards `weight` and `eps`. | |
| """ | |
| return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps) | |
def check_config_value(config: dict, key: str, expected: Any) -> None:  # noqa: ANN401
    """Validate that `config[key]` holds `expected`.

    The value is looked up with `dict.get`, so a missing key is compared as
    `None`.

    Raises:
        ValueError: If the looked-up value differs from `expected`.
    """
    found = config.get(key)
    mismatch = found != expected
    if mismatch:
        raise ValueError(f"Config value {key} is {found}, expected {expected}")
| def to_velocity( | |
| sample: torch.Tensor, | |
| sigma: float | torch.Tensor, | |
| denoised_sample: torch.Tensor, | |
| calc_dtype: torch.dtype = torch.float32, | |
| ) -> torch.Tensor: | |
| """ | |
| Convert the sample and its denoised version to velocity. | |
| Returns: | |
| Velocity | |
| """ | |
| if isinstance(sigma, torch.Tensor): | |
| sigma = sigma.to(calc_dtype).item() | |
| if sigma == 0: | |
| raise ValueError("Sigma can't be 0.0") | |
| return ((sample.to(calc_dtype) - denoised_sample.to(calc_dtype)) / sigma).to(sample.dtype) | |
| def to_denoised( | |
| sample: torch.Tensor, | |
| velocity: torch.Tensor, | |
| sigma: float | torch.Tensor, | |
| calc_dtype: torch.dtype = torch.float32, | |
| ) -> torch.Tensor: | |
| """ | |
| Convert the sample and its denoising velocity to denoised sample. | |
| Returns: | |
| Denoised sample | |
| """ | |
| if isinstance(sigma, torch.Tensor): | |
| sigma = sigma.to(calc_dtype) | |
| return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype) | |
def find_matching_file(root_path: str, pattern: str) -> Path:
    """
    Recursively search for files matching a glob pattern and return the first match.

    Args:
        root_path: Directory under which to search.
        pattern: Glob pattern handed to `Path.rglob`.

    Returns:
        The first path yielded by `Path(root_path).rglob(pattern)`. When several
        paths match, which one comes first depends on the traversal order.

    Raises:
        FileNotFoundError: If nothing under `root_path` matches `pattern`.
    """
    # Take the first hit lazily instead of walking the whole tree and
    # materializing every match just to index element 0.
    first_match = next(iter(Path(root_path).rglob(pattern)), None)
    if first_match is None:
        raise FileNotFoundError(f"No files matching pattern '{pattern}' found under {root_path}")
    return first_match