| import sys |
| from typing import Any, Callable, Union |
|
|
| from torch import nn |
| from torch.utils.hooks import RemovableHandle |
|
|
| from ldm.modules.diffusionmodules.openaimodel import ( |
| TimestepEmbedSequential, |
| ) |
| from ldm.modules.attention import ( |
| SpatialTransformer, |
| BasicTransformerBlock, |
| CrossAttention, |
| MemoryEfficientCrossAttention, |
| ) |
| from ldm.modules.diffusionmodules.openaimodel import ( |
| ResBlock, |
| ) |
| from modules.processing import StableDiffusionProcessing |
| from modules import shared |
|
|
class ForwardHook:
    """Monkey-patch a module's ``forward`` so that ``fn`` wraps the original.

    ``fn`` is called as ``fn(module, original_forward, *args, **kwargs)``.
    Call :meth:`remove` to restore the original ``forward`` and drop all
    references held by this hook.
    """

    def __init__(self, module: nn.Module, fn: Callable[[nn.Module, Callable[..., Any], Any], Any]):
        # Keep the unpatched forward so it can be restored later.
        self.o = module.forward
        self.fn = fn
        self.module = module
        self.module.forward = self.forward

    def remove(self):
        # Restore the original forward (if still patched), then release
        # every reference so the module and callback can be collected.
        if self.module is not None and self.o is not None:
            self.module.forward = self.o
        self.module = None
        self.o = None
        self.fn = None

    def forward(self, *args, **kwargs):
        # After remove() all state is None; behave as a no-op returning None.
        if self.module is None or self.o is None:
            return None
        if self.fn is None:
            return None
        return self.fn(self.module, self.o, *args, **kwargs)
| |
|
|
class SDHook:
    """Base class for temporarily instrumenting a Stable Diffusion model.

    Subclasses override ``hook_unet`` / ``hook_vae`` / ``hook_clip`` and
    register their hooks through ``hook_layer`` / ``hook_layer_pre`` /
    ``hook_forward``.  Intended to be used as a context manager: on exit,
    every registered hook is removed and ``dispose`` is called.
    """

    def __init__(self, enabled: bool):
        self._enabled = enabled
        # Handles for everything we attached; removed and cleared on exit.
        self._handles: list[Union[RemovableHandle, ForwardHook]] = []

    @property
    def enabled(self):
        """Whether this hook is active; when False, setup/hooking are no-ops."""
        return self._enabled

    @property
    def batch_num(self):
        """Index of the batch (job) currently being processed by the webui."""
        return shared.state.job_no

    @property
    def step_num(self):
        """Current image sampling step reported by the webui."""
        return shared.state.current_image_sampling_step

    def __enter__(self):
        # Fix: return self so `with SDHook(...) as hook:` binds the hook
        # object. Previously this method returned None (the body was a dead
        # `if self.enabled: pass`), which made the `as` target always None.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.enabled:
            # Undo every registered hook, then let the subclass clean up.
            for handle in self._handles:
                handle.remove()
            self._handles.clear()
            self.dispose()

    def dispose(self):
        """Subclass hook for extra cleanup; called after hooks are removed."""
        pass

    def setup(
        self,
        p: StableDiffusionProcessing
    ):
        """Locate the UNet/VAE/CLIP submodules of ``p.sd_model`` and hook them.

        Raises:
            AssertionError: if the UNet (``diffusion_model``) cannot be found.
        """
        if not self.enabled:
            return

        # p.sd_model.model is the wrapper holding the actual diffusion UNet.
        wrapper = getattr(p.sd_model, "model", None)

        unet: Union[nn.Module, None] = getattr(wrapper, "diffusion_model", None) if wrapper is not None else None
        vae: Union[nn.Module, None] = getattr(p.sd_model, "first_stage_model", None)
        clip: Union[nn.Module, None] = getattr(p.sd_model, "cond_stage_model", None)

        assert unet is not None, "p.sd_model.diffusion_model is not found. broken model???"
        self._do_hook(p, p.sd_model, unet=unet, vae=vae, clip=clip)
        self.on_setup()

    def on_setup(self):
        """Subclass hook invoked once after all hooks have been installed."""
        pass

    def _do_hook(
        self,
        p: StableDiffusionProcessing,
        model: Any,
        unet: Union[nn.Module, None],
        vae: Union[nn.Module, None],
        clip: Union[nn.Module, None]
    ):
        """Dispatch to the per-submodule hook methods for each module present."""
        assert model is not None, "empty model???"

        if clip is not None:
            self.hook_clip(p, clip)

        if unet is not None:
            self.hook_unet(p, unet)

        if vae is not None:
            self.hook_vae(p, vae)

    def hook_vae(
        self,
        p: StableDiffusionProcessing,
        vae: nn.Module
    ):
        """Override to install hooks on the VAE (``first_stage_model``)."""
        pass

    def hook_unet(
        self,
        p: StableDiffusionProcessing,
        unet: nn.Module
    ):
        """Override to install hooks on the UNet (``diffusion_model``)."""
        pass

    def hook_clip(
        self,
        p: StableDiffusionProcessing,
        clip: nn.Module
    ):
        """Override to install hooks on the text encoder (``cond_stage_model``)."""
        pass

    def hook_layer(
        self,
        module: Union[nn.Module, Any],
        fn: Callable[[nn.Module, tuple, Any], Any]
    ):
        """Register a torch forward hook on ``module``; removed on exit."""
        if not self.enabled:
            return

        assert module is not None
        assert isinstance(module, nn.Module)
        self._handles.append(module.register_forward_hook(fn))

    def hook_layer_pre(
        self,
        module: Union[nn.Module, Any],
        fn: Callable[[nn.Module, tuple], Any]
    ):
        """Register a torch forward *pre*-hook on ``module``; removed on exit."""
        if not self.enabled:
            return

        assert module is not None
        assert isinstance(module, nn.Module)
        self._handles.append(module.register_forward_pre_hook(fn))

    def hook_forward(
        self,
        module: Union[nn.Module, Any],
        fn: Callable[[nn.Module, Callable[..., Any], Any], Any]
    ):
        """Replace ``module.forward`` with a wrapper; restored on exit.

        NOTE(review): unlike hook_layer/hook_layer_pre this does not check
        ``self.enabled``; kept as-is since setup() only runs when enabled —
        confirm no external caller relies on the unconditional behavior.
        """
        assert module is not None
        assert isinstance(module, nn.Module)
        self._handles.append(ForwardHook(module, fn))

    def log(self, msg: str):
        """Write a log line to stderr (avoids interleaving with stdout UI)."""
        print(msg, file=sys.stderr)
|
|
|
|
| |
def each_transformer(unet_block: TimestepEmbedSequential):
    """Yield every SpatialTransformer that is a direct child of *unet_block*."""
    yield from (
        child for child in unet_block.children()
        if isinstance(child, SpatialTransformer)
    )
|
|
| |
def each_basic_block(trans: SpatialTransformer):
    """Yield every BasicTransformerBlock directly inside *trans*."""
    yield from (
        child for child in trans.transformer_blocks.children()
        if isinstance(child, BasicTransformerBlock)
    )
|
|
| |
| |
def each_attns(unet_block: TimestepEmbedSequential):
    """Yield ``(trans_index, depth, attn1, attn2)`` for every transformer
    block found under *unet_block*.

    ``attn1`` is the self-attention and ``attn2`` the cross-attention of a
    BasicTransformerBlock; both must be (MemoryEfficient)CrossAttention.
    """
    attn_types = (CrossAttention, MemoryEfficientCrossAttention)
    for trans_index, trans in enumerate(each_transformer(unet_block)):
        for depth, basic in enumerate(each_basic_block(trans)):
            self_attn = basic.attn1
            cross_attn = basic.attn2
            assert isinstance(self_attn, attn_types)
            assert isinstance(cross_attn, attn_types)
            yield trans_index, depth, self_attn, cross_attn
|
|
def each_unet_attn_layers(unet: nn.Module):
    """Yield ``(name, attention_module)`` for every attention layer of the
    U-Net's input, middle and output blocks.

    Names look like ``IN00_00_00_sattn`` / ``M00_..._xattn`` / ``OUT...``,
    where 'sattn' is self-attention (attn1) and 'xattn' cross-attention (attn2).
    """

    def attn_pairs(layer_index: int, block: TimestepEmbedSequential, fmt: str):
        for trans_index, block_index, sattn, xattn in each_attns(block):
            args = {
                'layer_index': layer_index,
                'trans_index': trans_index,
                'block_index': block_index,
            }
            yield fmt.format(attn_name='sattn', **args), sattn
            yield fmt.format(attn_name='xattn', **args), xattn

    def walk(blocks: nn.ModuleList, fmt: str):
        # Only TimestepEmbedSequential children can contain transformers.
        for idx, block in enumerate(blocks.children()):
            if isinstance(block, TimestepEmbedSequential):
                yield from attn_pairs(idx, block, fmt)

    yield from walk(unet.input_blocks, 'IN{layer_index:02}_{trans_index:02}_{block_index:02}_{attn_name}')
    yield from attn_pairs(0, unet.middle_block, 'M{layer_index:02}_{trans_index:02}_{block_index:02}_{attn_name}')
    yield from walk(unet.output_blocks, 'OUT{layer_index:02}_{trans_index:02}_{block_index:02}_{attn_name}')
|
|
|
|
def each_unet_transformers(unet: nn.Module):
    """Yield ``(name, SpatialTransformer)`` for every transformer of the
    U-Net's input, middle and output blocks.

    Names look like ``IN00_00_trans`` / ``M00_00_trans`` / ``OUT...``.
    """

    def trans_pairs(layer_index: int, block: TimestepEmbedSequential, fmt: str):
        for block_index, trans in enumerate(each_transformer(block)):
            name = fmt.format(
                layer_index=layer_index,
                block_index=block_index,
                block_name='trans',
            )
            yield name, trans

    def walk(blocks: nn.ModuleList, fmt: str):
        # Only TimestepEmbedSequential children can contain transformers.
        for idx, block in enumerate(blocks.children()):
            if isinstance(block, TimestepEmbedSequential):
                yield from trans_pairs(idx, block, fmt)

    yield from walk(unet.input_blocks, 'IN{layer_index:02}_{block_index:02}_{block_name}')
    yield from trans_pairs(0, unet.middle_block, 'M{layer_index:02}_{block_index:02}_{block_name}')
    yield from walk(unet.output_blocks, 'OUT{layer_index:02}_{block_index:02}_{block_name}')
|
|
|
|
def each_resblock(unet_block: TimestepEmbedSequential):
    """Yield every ResBlock that is a direct child of *unet_block*."""
    yield from (
        child for child in unet_block.children()
        if isinstance(child, ResBlock)
    )
|
|
def each_unet_resblock(unet: nn.Module):
    """Yield ``(name, ResBlock)`` for every residual block of the U-Net's
    input, middle and output blocks.

    Names look like ``IN00_00_resblock`` / ``M00_00_resblock`` / ``OUT...``.
    """

    def res_pairs(layer_index: int, block: TimestepEmbedSequential, fmt: str):
        for block_index, res in enumerate(each_resblock(block)):
            name = fmt.format(
                layer_index=layer_index,
                block_index=block_index,
                block_name='resblock',
            )
            yield name, res

    def walk(blocks: nn.ModuleList, fmt: str):
        # Only TimestepEmbedSequential children can contain resblocks.
        for idx, block in enumerate(blocks.children()):
            if isinstance(block, TimestepEmbedSequential):
                yield from res_pairs(idx, block, fmt)

    yield from walk(unet.input_blocks, 'IN{layer_index:02}_{block_index:02}_{block_name}')
    yield from res_pairs(0, unet.middle_block, 'M{layer_index:02}_{block_index:02}_{block_name}')
    yield from walk(unet.output_blocks, 'OUT{layer_index:02}_{block_index:02}_{block_name}')
|
|
|
|