import torch


class LoRALayer(torch.nn.Module):
    """A single low-rank adapter: the weight update is lora_B @ lora_A, with
    lora_A of shape (rank, dim_in) and lora_B of shape (dim_out, rank)."""

    def __init__(self, dim_in, dim_out, rank, initialize=False):
        super().__init__()
        if initialize:
            # Standard LoRA init: lora_A ~ U(-1/sqrt(dim_in), 1/sqrt(dim_in)) and
            # lora_B = 0, so the adapter contributes nothing at the start of training.
            scale = (1 / dim_in) ** 0.5
            self.lora_A = torch.nn.Parameter(torch.rand((rank, dim_in)) * (scale * 2) - scale)
            self.lora_B = torch.nn.Parameter(torch.zeros((dim_out, rank)))
        else:
            # Uninitialized placeholders, intended to be overwritten by loaded weights.
            self.lora_A = torch.nn.Parameter(torch.empty((rank, dim_in)))
            self.lora_B = torch.nn.Parameter(torch.empty((dim_out, rank)))


class LoRA(torch.nn.Module):
    """One adapter: a set of LoRALayers covering every targeted projection of the base model."""

    def __init__(self, rank):
        super().__init__()
        # Name patterns of the targeted base-model modules, their shapes, and the adapter rank.
        self.lora_patterns = [
            {
                "name": "single_transformer_blocks.{block_id}.attn.to_qkv_mlp_proj",
                "num_blocks": 20,
                "dim_in": 3072,
                "dim_out": 27648,
                "rank": rank,
            },
            {
                "name": "single_transformer_blocks.{block_id}.attn.to_out",
                "num_blocks": 20,
                "dim_in": 12288,
                "dim_out": 3072,
                "rank": rank,
            },
        ]
        self.parse_lora_layers(self.lora_patterns)
    
    def parse_lora_layers(self, lora_patterns):
        # Expand each pattern into one LoRALayer per transformer block.
        names = []
        layers = []
        for lora_pattern in lora_patterns:
            for block_id in range(lora_pattern["num_blocks"]):
                name = lora_pattern["name"].format(block_id=block_id)
                layer = LoRALayer(lora_pattern["dim_in"], lora_pattern["dim_out"], lora_pattern["rank"])
                names.append(name)
                layers.append(layer)
        self.names = names
        self.layers = torch.nn.ModuleList(layers)

    def forward(self):
        # Return a flat mapping from parameter names to the adapter tensors.
        lora = {}
        for name, layer in zip(self.names, self.layers):
            lora[f"{name}.lora_A.default.weight"] = layer.lora_A
            lora[f"{name}.lora_B.default.weight"] = layer.lora_B
        return lora


class DualLoRA(torch.nn.Module):
    """A bank of independent LoRA adapters that can be selected by index and merged
    into one set of weights at forward time."""

    def __init__(self, num_loras=180):
        super().__init__()
        self.loras = torch.nn.ModuleList([LoRA(rank=4) for _ in range(num_loras)])

    @torch.no_grad()
    def process_inputs(self, lora_ids, lora_scales, require_grads=None, merge_type="concat", **kwargs):
        # Pass the selection arguments through unchanged; the real work happens in forward().
        return {"lora_ids": lora_ids, "lora_scales": lora_scales, "require_grads": require_grads, "merge_type": merge_type}
    
    def forward(self, lora_ids, lora_scales, require_grads=None, merge_type="concat", **kwargs):
        # Broadcast a single scale to every selected LoRA and train all of them by default.
        if isinstance(lora_scales, float):
            lora_scales = [lora_scales] * len(lora_ids)
        if require_grads is None:
            require_grads = [True] * len(lora_scales)
        loras = []
        for lora_id, lora_scale, require_grad in zip(lora_ids, lora_scales, require_grads):
            lora_ = self.loras[lora_id]()
            if not require_grad:
                # Detach so gradients never reach this LoRA's parameters. Wrapping the call
                # in torch.no_grad() alone is not enough: LoRA.forward only returns parameter
                # references, and the scaling below would still be tracked by autograd.
                lora_ = {key: value.detach() for key, value in lora_.items()}
            # Fold the per-LoRA scale into lora_A only, so the effective update
            # scale * (lora_B @ lora_A) keeps that magnitude.
            lora_ = {key: lora_[key] * (lora_scale if "lora_A" in key else 1) for key in lora_}
            loras.append(lora_)
        lora = {}
        if merge_type == "concat":
            # Concatenate along the rank dimension: lora_A rows on dim 0 and lora_B columns
            # on dim 1, so the merged B @ A equals the sum of the per-LoRA updates.
            for key in loras[0]:
                if "lora_A" in key:
                    lora[key] = torch.concat([lora_[key] for lora_ in loras], dim=0)
                else:
                    lora[key] = torch.concat([lora_[key] for lora_ in loras], dim=1)
        elif merge_type == "sum":
            # Element-wise sum of the adapter weights.
            for key in loras[0]:
                lora[key] = torch.stack([lora_[key] for lora_ in loras]).sum(dim=0)
        elif merge_type == "mean":
            # Average the lora_A weights across adapters; the lora_B weights are summed.
            for key in loras[0]:
                if "lora_A" in key:
                    lora[key] = torch.stack([lora_[key] for lora_ in loras]).mean(dim=0)
                else:
                    lora[key] = torch.stack([lora_[key] for lora_ in loras]).sum(dim=0)
        else:
            raise ValueError(f"Unsupported merge_type: {merge_type}")
        return {"lora": lora}


class DataAnnotator:
    """Identity data processor: returns the keyword arguments unchanged."""

    def __call__(self, **kwargs):
        return kwargs


TEMPLATE_MODEL = DualLoRA
TEMPLATE_MODEL_PATH = "model.safetensors"
TEMPLATE_DATA_PROCESSOR = DataAnnotator
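

# Usage sketch (illustrative; not part of the original module). The argument values
# below are assumptions chosen only to show the shapes produced by the concat merge.
if __name__ == "__main__":
    model = DualLoRA(num_loras=4)
    inputs = model.process_inputs(lora_ids=[0, 1], lora_scales=1.0, merge_type="concat")
    merged = model(**inputs)["lora"]
    for name, tensor in list(merged.items())[:2]:
        # With two rank-4 LoRAs concatenated, lora_A weights have 8 rows and
        # lora_B weights have 8 columns.
        print(name, tuple(tensor.shape))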