File size: 786 Bytes
abd08dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch

def decomposite(tensor_A: torch.Tensor, tensor_B: torch.Tensor, rank: int):
    """Re-factor the product ``tensor_B @ tensor_A`` into a pair of rank ``rank``.

    The product is approximated with a randomized low-rank SVD
    ``U @ diag(S) @ V.T`` and split so that ``new_B @ new_A`` approximates
    the original ``tensor_B @ tensor_A``.

    Args:
        tensor_A: Right-hand factor (presumably the LoRA "down" matrix of
            shape ``(r, in_features)`` -- confirm against the caller).
        tensor_B: Left-hand factor (presumably the LoRA "up" matrix of
            shape ``(out_features, r)`` -- confirm against the caller).
        rank: Number of components to keep. ``torch.pca_lowrank`` rejects
            values larger than the smaller dimension of the product.

    Returns:
        ``(new_A, new_B)`` where ``new_A`` is ``V.T`` (``rank`` rows) and
        ``new_B`` is ``U`` scaled column-wise by ``S`` (``rank`` columns).
        Both are contiguous and carry the dtype and device of the input
        ``tensor_A``.

    Note:
        The factorisation is randomized, so results can differ between
        calls unless the torch RNG is seeded.
    """
    dtype, device = tensor_A.dtype, tensor_A.device
    # Multiply in float32 so half-precision inputs are not rounded in the
    # product before it is decomposed.
    weight = tensor_B.float() @ tensor_A.float()
    # center=False: the default (True) subtracts the column means first, in
    # which case the factors would reconstruct ``weight - mean`` rather than
    # ``weight`` itself.
    U, S, V = torch.pca_lowrank(weight, q=rank, center=False)
    new_A = V.T.to(dtype=dtype, device=device).contiguous()
    # Broadcasting U * S scales column j of U by S[j], same as U @ diag(S).
    new_B = (U * S).to(dtype=dtype, device=device).contiguous()
    return new_A, new_B

def reset_lora_rank(lora, rank):
    """Build a new mapping whose LoRA factor pairs are re-factored to ``rank``.

    For every key of ``lora`` containing ``".lora_A."``, the entry under
    that key and the entry under the same key with ``".lora_A."`` replaced
    by ``".lora_B."`` are passed to ``decomposite``; the two returned
    tensors are stored under those same two keys in the result.

    Entries of ``lora`` that belong to no such pair are not copied into the
    returned dict. A ``".lora_A."`` key without its ``".lora_B."``
    counterpart raises ``KeyError`` from the lookup.

    Args:
        lora: Mapping from parameter names to tensors.
        rank: Target rank forwarded to ``decomposite``.

    Returns:
        A new dict holding only the re-factored A/B entries.
    """
    a_tag, b_tag = ".lora_A.", ".lora_B."
    resized = {}
    # Snapshot the matching names up front, before any tensor work starts.
    a_names = [name for name in lora.keys() if a_tag in name]
    for a_name in a_names:
        b_name = a_name.replace(a_tag, b_tag)
        new_a, new_b = decomposite(lora[a_name], lora[b_name], rank)
        resized[a_name] = new_a
        resized[b_name] = new_b
    return resized