import torch

class LoRALayer(torch.nn.Module):
    """A single low-rank pair: the delta applied to a weight is lora_B @ lora_A."""

    def __init__(self, dim_in, dim_out, rank, initialize=False):
        super().__init__()
        if initialize:
            # lora_A: uniform in [-1/sqrt(dim_in), 1/sqrt(dim_in)];
            # lora_B: zeros, so the initial delta lora_B @ lora_A is zero.
            scale = (1 / dim_in) ** 0.5
            self.lora_A = torch.nn.Parameter(torch.rand((rank, dim_in)) * (scale * 2) - scale)
            self.lora_B = torch.nn.Parameter(torch.zeros((dim_out, rank)))
        else:
            # Left uninitialized: values are expected to be loaded from a checkpoint.
            self.lora_A = torch.nn.Parameter(torch.empty((rank, dim_in)))
            self.lora_B = torch.nn.Parameter(torch.empty((dim_out, rank)))
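
# Sketch: how a LoRALayer's factors would patch a base weight. Illustrative only;
# the name `apply_lora_delta` and the `alpha` argument are assumptions, not part
# of the template above, and the real merge happens in whatever framework
# consumes the state-dict produced further down.
def apply_lora_delta(weight, layer, alpha=1.0):
    # weight: (dim_out, dim_in); lora_B @ lora_A has the same shape,
    # so the patched weight is W + alpha * B @ A.
    return weight + alpha * layer.lora_B @ layer.lora_A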

class LoRA(torch.nn.Module):
    """One full adapter: a LoRALayer for each targeted projection in every block."""

    def __init__(self, rank):
        super().__init__()
        # Target modules as dotted paths; {block_id} is filled in per block.
        self.lora_patterns = [
            {
                "name": "single_transformer_blocks.{block_id}.attn.to_qkv_mlp_proj",
                "num_blocks": 20,
                "dim_in": 3072,
                "dim_out": 27648,
                "rank": rank,
            },
            {
                "name": "single_transformer_blocks.{block_id}.attn.to_out",
                "num_blocks": 20,
                "dim_in": 12288,
                "dim_out": 3072,
                "rank": rank,
            },
        ]
        self.parse_lora_layers(self.lora_patterns)

    def parse_lora_layers(self, lora_patterns):
        # Expand each pattern into one named LoRALayer per block.
        names = []
        layers = []
        for lora_pattern in lora_patterns:
            for block_id in range(lora_pattern["num_blocks"]):
                name = lora_pattern["name"].format(block_id=block_id)
                layer = LoRALayer(lora_pattern["dim_in"], lora_pattern["dim_out"], lora_pattern["rank"])
                names.append(name)
                layers.append(layer)
        self.names = names
        self.layers = torch.nn.ModuleList(layers)

    def forward(self):
        # Return the parameters as a flat state-dict keyed in PEFT's
        # "<module>.lora_{A,B}.default.weight" format.
        lora = {}
        for name, layer in zip(self.names, self.layers):
            lora[f"{name}.lora_A.default.weight"] = layer.lora_A
            lora[f"{name}.lora_B.default.weight"] = layer.lora_B
        return lora
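
# Sketch: what LoRA.forward() returns. A rank-4 adapter over the two patterns
# above yields 40 A/B pairs (80 entries). The helper name `_inspect_lora_keys`
# is hypothetical, for illustration only.
def _inspect_lora_keys():
    state = LoRA(rank=4)()
    # e.g. "single_transformer_blocks.0.attn.to_qkv_mlp_proj.lora_A.default.weight" -> (4, 3072)
    #      "single_transformer_blocks.0.attn.to_qkv_mlp_proj.lora_B.default.weight" -> (27648, 4)
    for key, value in list(state.items())[:4]:
        print(key, tuple(value.shape))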

class DualLoRA(torch.nn.Module):
    """A bank of LoRAs that can be selected by id, scaled, and merged into one."""

    def __init__(self, num_loras=180):
        super().__init__()
        self.loras = torch.nn.ModuleList([LoRA(rank=4) for _ in range(num_loras)])

    @torch.no_grad()
    def process_inputs(self, lora_ids, lora_scales, require_grads=None, merge_type="concat", **kwargs):
        # Pass-through packaging of the forward arguments.
        return {"lora_ids": lora_ids, "lora_scales": lora_scales, "require_grads": require_grads, "merge_type": merge_type}

    def forward(self, lora_ids, lora_scales, require_grads=None, merge_type="concat", **kwargs):
        # Broadcast a single scale to all selected LoRAs.
        if isinstance(lora_scales, float):
            lora_scales = [lora_scales] * len(lora_ids)
        if require_grads is None:
            require_grads = [True] * len(lora_scales)
        loras = []
        for lora_id, lora_scale, require_grad in zip(lora_ids, lora_scales, require_grads):
            lora_ = self.loras[lora_id]()
            if not require_grad:
                # Detach so no gradient flows back into frozen LoRAs. A bare
                # torch.no_grad() around the call is not enough here: the
                # forward returns the parameters themselves, and the scaling
                # below would still record them in the autograd graph.
                lora_ = {key: value.detach() for key, value in lora_.items()}
            # Fold the per-LoRA scale into lora_A only, so the product
            # lora_B @ (scale * lora_A) carries the scale exactly once.
            lora_ = {key: lora_[key] * (lora_scale if "lora_A" in key else 1) for key in lora_}
            loras.append(lora_)
        lora = {}
        if merge_type == "concat":
            # Concatenate along the rank axis: A matrices (rank, dim_in) on dim 0,
            # B matrices (dim_out, rank) on dim 1, giving one adapter of summed rank.
            for key in loras[0]:
                if "lora_A" in key:
                    lora[key] = torch.concat([lora_[key] for lora_ in loras], dim=0)
                else:
                    lora[key] = torch.concat([lora_[key] for lora_ in loras], dim=1)
        elif merge_type == "sum":
            # Elementwise sum of the factors (sums the factors, not the deltas).
            for key in loras[0]:
                lora[key] = torch.stack([lora_[key] for lora_ in loras]).sum(dim=0)
        elif merge_type == "mean":
            # Mean over the A factors, elementwise sum over the B factors.
            for key in loras[0]:
                if "lora_A" in key:
                    lora[key] = torch.stack([lora_[key] for lora_ in loras]).mean(dim=0)
                else:
                    lora[key] = torch.stack([lora_[key] for lora_ in loras]).sum(dim=0)
        else:
            raise ValueError(f"Unsupported merge_type: {merge_type}")
        return {"lora": lora}

class DataAnnotator:
    """Identity data processor: returns the sample fields unchanged."""

    def __call__(self, **kwargs):
        return kwargs

# Template exports, presumably picked up by the surrounding framework:
# the model class, its checkpoint path, and the data processor.
TEMPLATE_MODEL = DualLoRA
TEMPLATE_MODEL_PATH = "model.safetensors"
TEMPLATE_DATA_PROCESSOR = DataAnnotator
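
# Sketch of end-to-end use when running this module directly. Whether weights
# are first loaded from TEMPLATE_MODEL_PATH is up to the surrounding framework;
# num_loras=4 here is just to keep the demo small (the default is 180).
if __name__ == "__main__":
    model = TEMPLATE_MODEL(num_loras=4)
    # Merge LoRAs 0 and 2 at half strength; keep LoRA 2 frozen.
    out = model(lora_ids=[0, 2], lora_scales=0.5, require_grads=[True, False])
    for key, value in list(out["lora"].items())[:2]:
        print(key, tuple(value.shape))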