import torch


class LoRALayer(torch.nn.Module):
    """A single low-rank adapter: holds the A (rank x dim_in) and B (dim_out x rank) factors."""

    def __init__(self, dim_in, dim_out, rank, initialize=False):
        super().__init__()
        if initialize:
            # A ~ Uniform(-scale, scale) with scale = 1/sqrt(dim_in); B starts at zero,
            # so the adapter's initial contribution B @ A is zero.
            scale = (1 / dim_in) ** 0.5
            self.lora_A = torch.nn.Parameter(torch.rand((rank, dim_in)) * (2 * scale) - scale)
            self.lora_B = torch.nn.Parameter(torch.zeros((dim_out, rank)))
        else:
            # Leave the factors uninitialized; weights are expected to be loaded
            # from a checkpoint (see TEMPLATE_MODEL_PATH below).
            self.lora_A = torch.nn.Parameter(torch.empty((rank, dim_in)))
            self.lora_B = torch.nn.Parameter(torch.empty((dim_out, rank)))


class LoRA(torch.nn.Module):
    """A full set of adapters covering the target layers of the base model."""

    def __init__(self, rank):
        super().__init__()
        # Each pattern describes one family of target layers: the name template,
        # how many blocks it spans, and the layer's input/output widths.
        self.lora_patterns = [
            {
                "name": "single_transformer_blocks.{block_id}.attn.to_qkv_mlp_proj",
                "num_blocks": 20,
                "dim_in": 3072,
                "dim_out": 27648,
                "rank": rank,
            },
            {
                "name": "single_transformer_blocks.{block_id}.attn.to_out",
                "num_blocks": 20,
                "dim_in": 12288,
                "dim_out": 3072,
                "rank": rank,
            },
        ]
        self.parse_lora_layers(self.lora_patterns)

    def parse_lora_layers(self, lora_patterns):
        # Expand each pattern into one LoRALayer per block.
        names = []
        layers = []
        for lora_pattern in lora_patterns:
            for block_id in range(lora_pattern["num_blocks"]):
                names.append(lora_pattern["name"].format(block_id=block_id))
                layers.append(LoRALayer(lora_pattern["dim_in"], lora_pattern["dim_out"], lora_pattern["rank"]))
        self.names = names
        self.layers = torch.nn.ModuleList(layers)

    def forward(self):
        # Emit the adapters as a PEFT-style state dict:
        # "<layer name>.lora_A.default.weight" / "<layer name>.lora_B.default.weight".
        lora = {}
        for name, layer in zip(self.names, self.layers):
            lora[f"{name}.lora_A.default.weight"] = layer.lora_A
            lora[f"{name}.lora_B.default.weight"] = layer.lora_B
        return lora


class DualLoRA(torch.nn.Module):
    """A bank of LoRA adapter sets that can be selected, scaled, and merged per call."""

    def __init__(self, num_loras=180):
        super().__init__()
        self.loras = torch.nn.ModuleList([LoRA(rank=4) for _ in range(num_loras)])

    @torch.no_grad()
    def process_inputs(self, lora_ids, lora_scales, require_grads=None, merge_type="concat", **kwargs):
        # Pass the merge configuration through unchanged.
        return {"lora_ids": lora_ids, "lora_scales": lora_scales, "require_grads": require_grads, "merge_type": merge_type}

    def forward(self, lora_ids, lora_scales, require_grads=None, merge_type="concat", **kwargs):
        # A single float scale is broadcast to every selected adapter.
        if isinstance(lora_scales, float):
            lora_scales = [lora_scales] * len(lora_ids)
        if require_grads is None:
            require_grads = [True] * len(lora_scales)

        # Collect the selected adapters, folding each scale into its A matrix.
        loras = []
        for lora_id, lora_scale, require_grad in zip(lora_ids, lora_scales, require_grads):
            lora_ = self.loras[lora_id]()
            if not require_grad:
                # Detach frozen adapters so no gradient flows back to their
                # parameters through the scaling and merge ops below. (A bare
                # no_grad wrapper around the dict lookup would not achieve this,
                # since the graph is built by the later operations.)
                lora_ = {key: value.detach() for key, value in lora_.items()}
            lora_ = {key: value * lora_scale if "lora_A" in key else value for key, value in lora_.items()}
            loras.append(lora_)

        # Merge the selected adapters into a single state dict.
        lora = {}
        if merge_type == "concat":
            # Concatenate along the rank axis: A stacks on dim 0, B on dim 1,
            # so the merged B @ A equals the sum of the individual updates.
            for key in loras[0]:
                dim = 0 if "lora_A" in key else 1
                lora[key] = torch.cat([lora_[key] for lora_ in loras], dim=dim)
        elif merge_type == "sum":
            for key in loras[0]:
                lora[key] = torch.stack([lora_[key] for lora_ in loras]).sum(dim=0)
        elif merge_type == "mean":
            # Average the A matrices but sum the B matrices.
            for key in loras[0]:
                stacked = torch.stack([lora_[key] for lora_ in loras])
                lora[key] = stacked.mean(dim=0) if "lora_A" in key else stacked.sum(dim=0)
        else:
            raise ValueError(f"Unsupported merge_type: {merge_type}")
        return {"lora": lora}


class DataAnnotator:
    """Identity data processor: passes the raw batch fields through unchanged."""

    def __call__(self, **kwargs):
        return kwargs


# Template entry points: the model class, its checkpoint path, and the data processor.
TEMPLATE_MODEL = DualLoRA
TEMPLATE_MODEL_PATH = "model.safetensors"
TEMPLATE_DATA_PROCESSOR = DataAnnotator
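

# --- Usage sketch (illustrative only, not part of the template API) ---
# A minimal shape check of the merge semantics above, under the assumption that
# uninitialized (unloaded) weights are acceptable for a dry run: merging two
# rank-4 adapters with "concat" should yield a rank-8 A matrix per target layer.
if __name__ == "__main__":
    model = DualLoRA(num_loras=4)
    out = model(lora_ids=[0, 2], lora_scales=0.8, merge_type="concat")
    key = "single_transformer_blocks.0.attn.to_qkv_mlp_proj.lora_A.default.weight"
    print(out["lora"][key].shape)  # torch.Size([8, 3072]): two rank-4 adapters concatenated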