File size: 397 Bytes
08c5e28
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from typing import Callable, NamedTuple

import torch


class ModuleOps(NamedTuple):
    """
    A named (predicate, transform) pair over PyTorch modules.

    ``matcher`` selects which ``torch.nn.Module`` instances an operation
    applies to, and ``mutator`` produces the module to use for a selected
    one (for example, swapping a layer for a quantized counterpart).
    Being a ``NamedTuple``, instances are immutable and can be built
    positionally as ``ModuleOps(name, matcher, mutator)``.
    """

    # Human-readable label identifying this operation.
    name: str
    # Predicate: returns True for modules this operation should act on.
    matcher: Callable[[torch.nn.Module], bool]
    # Transform: takes a matched module and returns a module.
    mutator: Callable[[torch.nn.Module], torch.nn.Module]