Daankular commited on
Commit
a713ba2
·
verified ·
1 Parent(s): aba2408

Update aoti.py - r3gm base with LoRA gallery

Browse files
Files changed (1) hide show
  1. aoti.py +35 -0
aoti.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ """
3
+
4
+ from typing import cast
5
+
6
+ import torch
7
+ from huggingface_hub import hf_hub_download
8
+ from spaces.zero.torch.aoti import ZeroGPUCompiledModel
9
+ from spaces.zero.torch.aoti import ZeroGPUWeights
10
+ from torch._functorch._aot_autograd.subclass_parametrization import unwrap_tensor_subclass_parameters
11
+
12
+
13
+ def _shallow_clone_module(module: torch.nn.Module) -> torch.nn.Module:
14
+ clone = object.__new__(module.__class__)
15
+ clone.__dict__ = module.__dict__.copy()
16
+ clone._parameters = module._parameters.copy()
17
+ clone._buffers = module._buffers.copy()
18
+ clone._modules = {k: _shallow_clone_module(v) for k, v in module._modules.items() if v is not None}
19
+ return clone
20
+
21
+
22
+ def aoti_blocks_load(module: torch.nn.Module, repo_id: str, variant: str | None = None):
23
+ repeated_blocks = cast(list[str], module._repeated_blocks)
24
+ aoti_files = {name: hf_hub_download(
25
+ repo_id=repo_id,
26
+ filename='package.pt2',
27
+ subfolder=name if variant is None else f'{name}.{variant}',
28
+ ) for name in repeated_blocks}
29
+ for block_name, aoti_file in aoti_files.items():
30
+ for block in module.modules():
31
+ if block.__class__.__name__ == block_name:
32
+ block_ = _shallow_clone_module(block)
33
+ unwrap_tensor_subclass_parameters(block_)
34
+ weights = ZeroGPUWeights(block_.state_dict())
35
+ block.forward = ZeroGPUCompiledModel(aoti_file, weights)