| """Modified from https://github.com/kijai/ComfyUI-MochiWrapper |
| """ |
| import importlib.util |
|
|
| import torch |
| import torch.nn as nn |
|
|
def replace_parameters_by_name(module, name_keywords, device):
    """Swap matching ``nn.Parameter`` attributes for plain tensors on ``device``.

    Walks ``module`` and all of its descendants. Every directly-owned
    parameter whose name contains at least one of ``name_keywords`` is
    removed from the module and re-attached under the same attribute name as
    a regular tensor (the parameter's ``.data``) moved to ``device``.
    Afterwards that attribute is no longer reported by ``named_parameters()``.

    Args:
        module: Root ``torch.nn.Module``; modified in place.
        name_keywords: Substrings to look for in parameter names. A single
            string is treated as one keyword rather than as characters.
        device: Target device, forwarded to ``Tensor.to(device=...)``.
    """
    # Normalise once: a bare string would otherwise be matched character by
    # character, and a one-shot iterator would be exhausted after the first
    # parameter.
    if isinstance(name_keywords, str):
        name_keywords = (name_keywords,)
    else:
        name_keywords = tuple(name_keywords)

    # Snapshot the parameter list first because the loop mutates the module.
    for param_name, param in list(module.named_parameters(recurse=False)):
        if not isinstance(param, nn.Parameter):
            continue
        if not any(keyword in param_name for keyword in name_keywords):
            continue
        tensor = param.data
        # Drop the registered parameter so the name can be rebound to a
        # plain tensor.
        delattr(module, param_name)
        setattr(module, param_name, tensor.to(device=device))

    for child in module.children():
        replace_parameters_by_name(child, name_keywords, device)
|
|
| def convert_model_weight_to_float8(model, exclude_module_name=['embed_tokens'], device=None): |
| for name, module in model.named_modules(): |
| flag = False |
| for _exclude_module_name in exclude_module_name: |
| if _exclude_module_name in name: |
| flag = True |
| if flag: |
| continue |
| for param_name, param in module.named_parameters(): |
| flag = False |
| for _exclude_module_name in exclude_module_name: |
| if _exclude_module_name in param_name: |
| flag = True |
| if flag: |
| continue |
| param.data = param.data.to(torch.float8_e4m3fn) |
|
|
def autocast_model_forward(cls, origin_dtype, *inputs, **kwargs):
    """Run ``cls.original_forward`` with the module temporarily in ``origin_dtype``.

    The module is cast to ``origin_dtype``, the saved forward is called, and
    the module is cast back to the dtype its ``weight`` had on entry.

    Args:
        cls: Module instance exposing ``weight``, ``to`` and
            ``original_forward``.
        origin_dtype: dtype used for the computation.
        *inputs: Positional forward arguments. Floating-point tensors are
            cast to ``origin_dtype``; anything else (integer/bool tensors,
            ``None``, Python scalars, ...) is passed through unchanged.
        **kwargs: Keyword forward arguments, passed through unchanged.

    Returns:
        Whatever ``cls.original_forward`` returns.
    """
    weight_dtype = cls.weight.dtype
    cls.to(origin_dtype)
    try:
        # Only floating-point tensors take part in the dtype switch; casting
        # index tensors or calling .to() on non-tensors would break the call.
        cast_inputs = [
            arg.to(origin_dtype)
            if isinstance(arg, torch.Tensor) and arg.is_floating_point()
            else arg
            for arg in inputs
        ]
        return cls.original_forward(*cast_inputs, **kwargs)
    finally:
        # Restore the storage dtype even when the forward raises, so a
        # failed call does not leave the module upcast.
        cls.to(weight_dtype)
|
|
def convert_weight_dtype_wrapper(module, origin_dtype, exclude_module_name=("embed_tokens",)):
    """Patch weighted submodules so their forward runs in ``origin_dtype``.

    For every descendant of ``module`` (the root itself is skipped) that has
    a non-``None`` ``weight`` attribute and whose name matches none of the
    exclusions, the current ``forward`` is saved on the instance as
    ``original_forward`` and ``forward`` is replaced by a wrapper that
    delegates to ``autocast_model_forward``.

    Calling this function again is safe: an already-saved
    ``original_forward`` is kept, and only the wrapper (and therefore the
    ``origin_dtype`` it uses) is refreshed.

    Args:
        module: Root ``torch.nn.Module``; its descendants are patched in place.
        origin_dtype: dtype handed to ``autocast_model_forward`` on each call.
        exclude_module_name: Substrings of submodule names that must not be
            wrapped. A single string is treated as one keyword.
    """
    if isinstance(exclude_module_name, str):
        exclude_module_name = (exclude_module_name,)
    else:
        exclude_module_name = tuple(exclude_module_name)

    def _make_forward(target):
        # Closure factory: binds this specific submodule without exposing it
        # as a keyword argument that a caller's kwargs could override.
        def _forward(*inputs, **kwargs):
            return autocast_model_forward(target, origin_dtype, *inputs, **kwargs)

        return _forward

    for name, submodule in module.named_modules():
        if name == "" or any(keyword in name for keyword in exclude_module_name):
            continue
        if getattr(submodule, "weight", None) is None:
            continue
        # Save the real forward only once; on a repeated call the current
        # ``forward`` is already a wrapper, and storing it would make the
        # wrapper call itself.
        if "original_forward" not in vars(submodule):
            setattr(submodule, "original_forward", submodule.forward)
        setattr(submodule, "forward", _make_forward(submodule))