Diffusers
Safetensors
EvalMDE / Edit2Perceive /lora /__init__.py
zeyuren2002's picture
Add files using upload-large-folder tool
7f921f4 verified
import torch
from typing import Dict, Optional, Tuple
class GeneralLoRALoader:
    """Merge LoRA weights into a model in place, and remove them again.

    How it works:
        1. Only the raw LoRA A/B matrices are retained, stored on the CPU.
        2. The weight delta ``alpha * (B @ A)`` is computed temporarily while
           loading or unloading and released right afterwards.
        3. Loading a new LoRA first unloads the one currently applied.

    Unloading subtracts a recomputed delta instead of restoring a saved copy
    of the base weights, so the restored weights can differ from the
    originals by floating-point rounding.

    Usage:
        loader = GeneralLoRALoader(device="cuda", torch_dtype=torch.bfloat16)
        # Apply a LoRA
        loader.load(model, lora_state_dict, alpha=1.0)
        # Remove the current LoRA (back to the base model)
        loader.unload(model)
        # Replace the current LoRA with another one
        loader.switch(model, new_lora_state_dict, alpha=1.0)
    """

    def __init__(self, device="cpu", torch_dtype=torch.float32):
        """Create a loader.

        Args:
            device: Device on which the temporary delta is computed.
            torch_dtype: Dtype the LoRA matrices are cast to before use.
        """
        self.device = device
        self.torch_dtype = torch_dtype
        # Raw matrices of the LoRA currently applied, kept on the CPU:
        # {module_name: (weight_up, weight_down)}
        self._current_lora_weights: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
        # Scale that was used when the current LoRA was applied.
        self._current_alpha: float = 0.0
        # Whether a LoRA is currently considered applied.
        self._lora_loaded: bool = False

    def get_name_dict(self, lora_state_dict):
        """Map target module names to their ``(lora_B key, lora_A key)`` pair.

        For every key containing ``.lora_B.`` the module name is derived by
        dropping the ``lora_B`` component, an optional adapter-name component
        right after it, a leading ``diffusion_model`` component, and the
        final component (the parameter name).

        Args:
            lora_state_dict: Mapping (or iterable) of LoRA parameter names.

        Returns:
            dict: ``{module_name: (up_key, down_key)}``.
        """
        lora_name_dict = {}
        for key in lora_state_dict:
            if ".lora_B." not in key:
                continue
            keys = key.split(".")
            marker = keys.index("lora_B")
            # "...lora_B.<adapter>.weight" carries an adapter name to drop.
            if len(keys) > marker + 2:
                keys.pop(marker + 1)
            keys.pop(marker)
            if keys[0] == "diffusion_model":
                keys.pop(0)
            # Drop the trailing parameter name.
            keys.pop(-1)
            target_name = ".".join(keys)
            lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
        return lora_name_dict

    def _compute_lora_delta(self, weight_up: torch.Tensor, weight_down: torch.Tensor, alpha: float) -> torch.Tensor:
        """Return the weight delta ``alpha * (weight_up @ weight_down)``.

        4-D inputs have their two trailing dimensions squeezed before the
        matrix product and re-added afterwards; this presumably targets 1x1
        convolution kernels (trailing dimensions of size 1).
        """
        if weight_up.dim() == 4:
            up_2d = weight_up.squeeze(3).squeeze(2)
            down_2d = weight_down.squeeze(3).squeeze(2)
            return alpha * torch.mm(up_2d, down_2d).unsqueeze(2).unsqueeze(3)
        return alpha * torch.mm(weight_up, weight_down)

    @staticmethod
    def _find_module(model: torch.nn.Module, name: str):
        """Return the submodule called *name*, or ``None`` if lookup fails."""
        try:
            return model.get_submodule(name)
        except Exception:
            # Best effort: LoRA entries that do not resolve to a submodule of
            # this model are skipped rather than treated as errors.
            return None

    def _release_cached_memory(self) -> None:
        """Empty the CUDA cache unless the working device is the CPU."""
        # Normalise first: torch.device("cpu") does not compare equal to "cpu".
        device_type = None if self.device is None else torch.device(self.device).type
        if device_type != "cpu":
            torch.cuda.empty_cache()

    def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
        """Apply LoRA weights to *model* in place.

        If a LoRA is already applied it is unloaded first. Entries whose
        target module cannot be found in *model* are skipped.

        Args:
            model: Model whose submodule weights are patched.
            state_dict_lora: Mapping of LoRA parameter names to tensors.
            alpha: Scale applied to the delta.
        """
        if self._lora_loaded:
            print("Detected existing LoRA, unloading first...")
            self.unload(model)

        lora_name_dict = self.get_name_dict(state_dict_lora)
        self._current_lora_weights.clear()
        # Publish the bookkeeping before touching the model: if an exception
        # interrupts the loop, unload() can still revert what was applied.
        self._current_alpha = alpha
        self._lora_loaded = True

        updated_num = 0
        for name, (up_key, down_key) in lora_name_dict.items():
            module = self._find_module(model, name)
            if module is None:
                continue
            weight_up = state_dict_lora[up_key].to(dtype=self.torch_dtype)
            weight_down = state_dict_lora[down_key].to(dtype=self.torch_dtype)
            # Independent CPU copies of the raw matrices, kept for unload().
            cpu_pair = (weight_up.cpu().clone(), weight_down.cpu().clone())
            # Compute the delta on the working device.
            weight_lora = self._compute_lora_delta(
                weight_up.to(device=self.device),
                weight_down.to(device=self.device),
                alpha,
            )
            # Align with the weight's device so a loader configured for one
            # device can patch a model living on another.
            module.weight.data += weight_lora.to(device=module.weight.device)
            # Record only what was really applied.
            self._current_lora_weights[name] = cpu_pair
            # Drop the temporary delta right away.
            del weight_lora
            updated_num += 1

        self._release_cached_memory()
        print(f"{updated_num} tensors are updated by LoRA.")

    def unload(self, model: torch.nn.Module):
        """Remove the current LoRA from *model* by subtracting its delta."""
        if not self._lora_loaded:
            print("No LoRA is currently loaded.")
            return

        unloaded_num = 0
        pending = self._current_lora_weights
        for name in list(pending):
            module = self._find_module(model, name)
            if module is None:
                continue
            weight_up, weight_down = pending[name]
            # Recompute the delta on the working device.
            weight_delta = self._compute_lora_delta(
                weight_up.to(device=self.device, dtype=self.torch_dtype),
                weight_down.to(device=self.device, dtype=self.torch_dtype),
                self._current_alpha,
            )
            # Subtract the delta to get back to the base weight.
            module.weight.data -= weight_delta.to(device=module.weight.device)
            # Forget the entry once reverted so a retry after a failure
            # cannot subtract the same delta twice.
            del pending[name]
            del weight_delta
            unloaded_num += 1

        pending.clear()
        self._current_alpha = 0.0
        self._lora_loaded = False

        self._release_cached_memory()
        print(f"{unloaded_num} tensors restored to base model.")

    def switch(self, model: torch.nn.Module, new_state_dict_lora, alpha=1.0):
        """Replace the current LoRA with another one.

        Equivalent to ``unload()`` followed by ``load()``.
        """
        if self._lora_loaded:
            self.unload(model)
        self.load(model, new_state_dict_lora, alpha)

    def is_loaded(self) -> bool:
        """Return whether a LoRA is currently applied."""
        return self._lora_loaded

    def get_loaded_modules(self) -> list:
        """Return the names of the modules patched by the current LoRA."""
        return list(self._current_lora_weights.keys())