import torch


class LoRALayer(torch.nn.Module):
    """A single low-rank adapter; the weight update is lora_B @ lora_A."""

    def __init__(self, dim_in, dim_out, rank, initialize=False):
        super().__init__()
        if initialize:
            # Standard LoRA init: A ~ U(-1/sqrt(dim_in), 1/sqrt(dim_in)), B = 0,
            # so the adapter starts as a no-op (lora_B @ lora_A == 0).
            scale = (1 / dim_in) ** 0.5
            self.lora_A = torch.nn.Parameter(torch.rand((rank, dim_in)) * (scale * 2) - scale)
            self.lora_B = torch.nn.Parameter(torch.zeros((dim_out, rank)))
        else:
            # Leave the parameters uninitialized; they are expected to be
            # filled in externally (e.g. loaded from a checkpoint).
            self.lora_A = torch.nn.Parameter(torch.empty((rank, dim_in)))
            self.lora_B = torch.nn.Parameter(torch.empty((dim_out, rank)))


class LoRA(torch.nn.Module):
    """A set of LoRALayers covering the attention projections of 20 single transformer blocks."""

    def __init__(self, rank):
        super().__init__()
        # Each pattern describes one family of target modules; "{block_id}"
        # is expanded into a concrete module name per block.
        self.lora_patterns = [
            {
                "name": "single_transformer_blocks.{block_id}.attn.to_qkv_mlp_proj",
                "num_blocks": 20,
                "dim_in": 3072,
                "dim_out": 27648,
                "rank": rank,
            },
            {
                "name": "single_transformer_blocks.{block_id}.attn.to_out",
                "num_blocks": 20,
                "dim_in": 12288,
                "dim_out": 3072,
                "rank": rank,
            },
        ]
        self.parse_lora_layers(self.lora_patterns)

    def parse_lora_layers(self, lora_patterns):
        # Instantiate one LoRALayer per (pattern, block) pair. The layers are
        # created uninitialized here; weights come from the loaded checkpoint.
        names = []
        layers = []
        for lora_pattern in lora_patterns:
            for block_id in range(lora_pattern["num_blocks"]):
                name = lora_pattern["name"].format(block_id=block_id)
                layer = LoRALayer(lora_pattern["dim_in"], lora_pattern["dim_out"], lora_pattern["rank"])
                names.append(name)
                layers.append(layer)
        self.names = names
        self.layers = torch.nn.ModuleList(layers)

    def forward(self):
        # Export the adapter parameters as a flat, PEFT-style state dict keyed by
        # "<module>.lora_A.default.weight" / "<module>.lora_B.default.weight".
        lora = {}
        for name, layer in zip(self.names, self.layers):
            lora[f"{name}.lora_A.default.weight"] = layer.lora_A
            lora[f"{name}.lora_B.default.weight"] = layer.lora_B
        return lora


class DualLoRA(torch.nn.Module):
    """A bank of independent LoRA adapters that can be selected, scaled, and merged at runtime."""

    def __init__(self, num_loras=180):
        super().__init__()
        self.loras = torch.nn.ModuleList([LoRA(rank=4) for _ in range(num_loras)])

    @torch.no_grad()
    def process_inputs(self, lora_ids, lora_scales, require_grads=None, merge_type="concat", **kwargs):
        # Pass the selection arguments through unchanged; merging happens in forward().
        return {"lora_ids": lora_ids, "lora_scales": lora_scales, "require_grads": require_grads, "merge_type": merge_type}

    def forward(self, lora_ids, lora_scales, require_grads=None, merge_type="concat", **kwargs):
        # Accept a single scalar scale and broadcast it over all selected adapters.
        if isinstance(lora_scales, (int, float)):
            lora_scales = [lora_scales] * len(lora_ids)
        if require_grads is None:
            require_grads = [True] * len(lora_scales)
        loras = []
        for lora_id, lora_scale, require_grad in zip(lora_ids, lora_scales, require_grads):
            if not require_grad:
                # Detach frozen adapters from the autograd graph.
                with torch.no_grad():
                    lora_ = self.loras[lora_id]()
            else:
                lora_ = self.loras[lora_id]()
            # Fold the per-adapter scale into lora_A only, so the product lora_B @ lora_A is scaled once.
            lora_ = {key: lora_[key] * (lora_scale if "lora_A" in key else 1) for key in lora_}
            loras.append(lora_)
        lora = {}
        if merge_type == "concat":
            # Concatenate along the rank dimension: the merged B_cat @ A_cat equals
            # the sum of the individual B_i @ A_i updates.
            for key in loras[0]:
                if "lora_A" in key:
                    lora[key] = torch.concat([lora_[key] for lora_ in loras], dim=0)
                else:
                    lora[key] = torch.concat([lora_[key] for lora_ in loras], dim=1)
        elif merge_type == "sum":
            # Sum the A and B factors elementwise (not the same as summing the B_i @ A_i products).
            for key in loras[0]:
                lora[key] = torch.stack([lora_[key] for lora_ in loras]).sum(dim=0)
        elif merge_type == "mean":
            # Average the A factors while summing the B factors.
            for key in loras[0]:
                if "lora_A" in key:
                    lora[key] = torch.stack([lora_[key] for lora_ in loras]).mean(dim=0)
                else:
                    lora[key] = torch.stack([lora_[key] for lora_ in loras]).sum(dim=0)
        else:
            raise ValueError(f"Unsupported merge_type: {merge_type}")
        return {"lora": lora}


class DataAnnotator:
    """Identity data processor: returns the inputs unchanged."""

    def __call__(self, **kwargs):
        return kwargs


# Template entry points: model class, model weight path, and data processor.
TEMPLATE_MODEL = DualLoRA
TEMPLATE_MODEL_PATH = "model.safetensors"
TEMPLATE_DATA_PROCESSOR = DataAnnotator
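

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the template interface).
    # It builds a small adapter bank, merges two adapters with "concat", and checks
    # that concatenating along the rank dimension reproduces the sum of the
    # individual (scaled) low-rank updates. num_loras=4 and the random init below
    # are assumptions for the demo; the template normally loads its weights from
    # TEMPLATE_MODEL_PATH instead.
    model = DualLoRA(num_loras=4)
    for p in model.parameters():
        torch.nn.init.normal_(p, std=0.02)

    inputs = model.process_inputs(lora_ids=[0, 2], lora_scales=0.5, merge_type="concat")
    merged = model(**inputs)["lora"]

    key_a = "single_transformer_blocks.0.attn.to_out.lora_A.default.weight"
    key_b = "single_transformer_blocks.0.attn.to_out.lora_B.default.weight"
    print(merged[key_a].shape)  # torch.Size([8, 12288]) -> two rank-4 adapters stacked
    print(merged[key_b].shape)  # torch.Size([3072, 8])

    # Merged update equals the sum of the per-adapter updates (scale folded into lora_A).
    delta_merged = merged[key_b] @ merged[key_a]
    delta_sum = sum(model.loras[i]()[key_b] @ (model.loras[i]()[key_a] * 0.5) for i in (0, 2))
    print(torch.allclose(delta_merged, delta_sum, atol=1e-5))  # True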