File size: 397 Bytes
08c5e28 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 | from typing import Callable, NamedTuple
import torch
# Signature of a predicate that decides whether a module should be transformed.
_ModulePredicate = Callable[[torch.nn.Module], bool]
# Signature of a function that takes a module and returns a (possibly new) module.
_ModuleTransform = Callable[[torch.nn.Module], torch.nn.Module]


class ModuleOps(NamedTuple):
    """Named pairing of a module matcher with a module mutator.

    Bundles a human-readable label with two callables so that selected
    modules of a model can be transformed (for example, swapping layers
    for quantized counterparts). Field order is ``(name, matcher, mutator)``.
    """

    # Identifier for this operation.
    name: str
    # Returns True for modules this operation should apply to.
    matcher: _ModulePredicate
    # Produces the module to use in place of a matched module.
    mutator: _ModuleTransform
|