from pathlib import Path
from typing import Any
import torch
def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor:
"""Root-mean-square (RMS) normalize `x` over its last dimension.
Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized
shape and forwards `weight` and `eps`.
"""
return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps)
def check_config_value(config: dict, key: str, expected: Any) -> None:  # noqa: ANN401
    """Ensure the value stored under `key` in `config` equals `expected`.

    The value is looked up with `dict.get`, so a missing key is compared as
    `None`.

    Raises:
        ValueError: If the looked-up value differs from `expected`.
    """
    if (found := config.get(key)) != expected:
        message = f"Config value {key} is {found}, expected {expected}"
        raise ValueError(message)
def to_velocity(
sample: torch.Tensor,
sigma: float | torch.Tensor,
denoised_sample: torch.Tensor,
calc_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Convert the sample and its denoised version to velocity.
Returns:
Velocity
"""
if isinstance(sigma, torch.Tensor):
sigma = sigma.to(calc_dtype).item()
if sigma == 0:
raise ValueError("Sigma can't be 0.0")
return ((sample.to(calc_dtype) - denoised_sample.to(calc_dtype)) / sigma).to(sample.dtype)
def to_denoised(
sample: torch.Tensor,
velocity: torch.Tensor,
sigma: float | torch.Tensor,
calc_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Convert the sample and its denoising velocity to denoised sample.
Returns:
Denoised sample
"""
if isinstance(sigma, torch.Tensor):
sigma = sigma.to(calc_dtype)
return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype)
def find_matching_file(root_path: str, pattern: str) -> Path:
    """
    Recursively search for files matching a glob pattern and return the first match.

    Args:
        root_path: Directory under which the recursive search starts.
        pattern: Glob pattern passed to `Path.rglob`.

    Returns:
        The first path yielded by `Path(root_path).rglob(pattern)`.

    Raises:
        FileNotFoundError: If nothing under `root_path` matches `pattern`.
    """
    # Pull only the first result from the lazy iterator instead of collecting
    # every match, so traversal stops as soon as one file is found.
    first_match = next(iter(Path(root_path).rglob(pattern)), None)
    if first_match is None:
        raise FileNotFoundError(f"No files matching pattern '{pattern}' found under {root_path}")
    return first_match
|