This model holds the additional LoRA (Low-Rank Adaptation) layers used to fine-tune the meta-llama/Llama-3.2-1B model. The LoRA layers hold 13.5 million parameters, which were the only parameters trained during the fine-tuning process.

The base Llama model has 1.2 billion parameters.

Integrating the LoRA layers' weights with the base model

The base model weights were frozen during the fine-tuning process, so we can download the base model directly from the Hub and integrate its weights with this model's weights. First, load the Llama base model:

from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = 'meta-llama/Llama-3.2-1B'
base_model = AutoModelForCausalLM.from_pretrained(model_id)

Collect the dimensions (in_features and out_features) of every linear layer in the base model; together with the rank and alpha hyperparameters, these determine the shape of each LoRA layer.

import torch
import torch.nn as nn

def get_base_model_parameter(model, rank, alpha):
  """Recursively collect [in_features, out_features, rank, alpha] for every Linear layer."""
  base_model_parameters = []
  for name, module in model.named_children():
    if isinstance(module, torch.nn.Linear):
      base_model_parameters.append(
        [module.in_features, module.out_features, rank, alpha]
      )
    else:
      # Recurse into non-Linear containers (attention blocks, MLPs, ...)
      base_model_parameters.extend(get_base_model_parameter(module, rank, alpha))

  return base_model_parameters

base_model_parameters = get_base_model_parameter(base_model, rank=16, alpha=16)
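As a sanity check on the 13.5 million figure: each LoRA adapter adds rank × (in_dim + out_dim) parameters (A is in_dim × rank, B is rank × out_dim). The count can be recomputed from the linear-layer shapes in the architecture printout, without loading the model (a sketch; the dimensions are copied from the Llama-3.2-1B module dump):

```python
# Each LoRA adapter adds rank * (in_dim + out_dim) parameters:
# matrix A is (in_dim x rank), matrix B is (rank x out_dim).
rank = 16

# (in_features, out_features) of every nn.Linear in Llama-3.2-1B,
# taken from the model printout: 16 decoder layers plus the lm_head.
per_decoder_layer = [
    (2048, 2048),  # q_proj
    (2048, 512),   # k_proj
    (2048, 512),   # v_proj
    (2048, 2048),  # o_proj
    (2048, 8192),  # gate_proj
    (2048, 8192),  # up_proj
    (8192, 2048),  # down_proj
]
linear_dims = per_decoder_layer * 16 + [(2048, 128256)]  # lm_head

lora_params = sum(rank * (in_dim + out_dim) for in_dim, out_dim in linear_dims)
print(lora_params)  # 13357056, i.e. roughly the 13.5M trainable parameters
```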

Load the fine-tuned weights from the Hub into the LoRA layers.

import torch
import math
from huggingface_hub import PyTorchModelHubMixin

class LoRALayer(torch.nn.Module):
  def __init__(self, in_dim, out_dim, rank, alpha):
    super().__init__()
    self.A = torch.nn.Parameter(torch.empty(in_dim, rank, dtype=torch.bfloat16))
    torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
    self.B = torch.nn.Parameter(torch.zeros(rank, out_dim, dtype=torch.bfloat16))
    self.alpha = alpha
    self.rank = rank

  def forward(self, x):
    x = (self.alpha / self.rank) * (x @ self.A @ self.B)
    return x
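A quick way to convince yourself the layer is wired correctly: because B starts as an all-zero matrix, a freshly constructed LoRALayer maps any input to zeros, so adding it to a frozen linear layer leaves the base model's behavior untouched until the fine-tuned weights are loaded. A self-contained check with small, arbitrary dimensions (the class is re-declared so the snippet runs standalone):

```python
import math
import torch

class LoRALayer(torch.nn.Module):
  def __init__(self, in_dim, out_dim, rank, alpha):
    super().__init__()
    self.A = torch.nn.Parameter(torch.empty(in_dim, rank, dtype=torch.bfloat16))
    torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
    self.B = torch.nn.Parameter(torch.zeros(rank, out_dim, dtype=torch.bfloat16))
    self.alpha = alpha
    self.rank = rank

  def forward(self, x):
    return (self.alpha / self.rank) * (x @ self.A @ self.B)

lora = LoRALayer(in_dim=8, out_dim=4, rank=2, alpha=4)
x = torch.randn(3, 8, dtype=torch.bfloat16)
out = lora(x)
print(out.shape)               # torch.Size([3, 4])
print(out.abs().sum().item())  # 0.0 -- B is zero-initialized
```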


class LlamaSummarizationLoRALayers(nn.Module, PyTorchModelHubMixin):
  def __init__(self):
    super().__init__()
    self.lora_layers = nn.ModuleList(
      [
        LoRALayer(
          base_model_parameter[0],
          base_model_parameter[1],
          base_model_parameter[2],
          base_model_parameter[3]
        ) 
        for base_model_parameter in base_model_parameters
      ]
    )

model_id = 'SauravP97/llama-3.2-1B-summarization-loraa-layers'
llama_summarizer_lora_layers = LlamaSummarizationLoRALayers.from_pretrained(model_id)

Integrate the LoRA layers with the Llama base model.

class LinearWithLoRA(torch.nn.Module):
  def __init__(self, linear, lora_layer):
    super().__init__()
    self.linear = linear
    self.lora = lora_layer

  def forward(self, x):
    return self.linear(x) + self.lora(x)
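Since the adapter's B matrix is zero at initialization, wrapping a Linear in LinearWithLoRA is output-preserving until the pre-trained adapter weights are loaded. A small standalone check with assumed toy dimensions (both classes repeated so the snippet runs on its own):

```python
import math
import torch

class LoRALayer(torch.nn.Module):
  def __init__(self, in_dim, out_dim, rank, alpha):
    super().__init__()
    self.A = torch.nn.Parameter(torch.empty(in_dim, rank, dtype=torch.bfloat16))
    torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
    self.B = torch.nn.Parameter(torch.zeros(rank, out_dim, dtype=torch.bfloat16))
    self.alpha = alpha
    self.rank = rank

  def forward(self, x):
    return (self.alpha / self.rank) * (x @ self.A @ self.B)

class LinearWithLoRA(torch.nn.Module):
  def __init__(self, linear, lora_layer):
    super().__init__()
    self.linear = linear
    self.lora = lora_layer

  def forward(self, x):
    return self.linear(x) + self.lora(x)

linear = torch.nn.Linear(8, 4, bias=False, dtype=torch.bfloat16)
wrapped = LinearWithLoRA(linear, LoRALayer(8, 4, rank=2, alpha=4))

x = torch.randn(3, 8, dtype=torch.bfloat16)
# The LoRA branch contributes exactly zero at init, so outputs match.
print(torch.equal(wrapped(x), linear(x)))  # True
```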

def replace_linear_with_preloaded_lora(model, lora_layers, cur_index):
  for name, module in model.named_children():
    if isinstance(module, torch.nn.Linear):
      print(f'Initialized with Pre-loaded LoRA Layer: {cur_index}')
      setattr(model, name, LinearWithLoRA(module, lora_layers[cur_index]))
      cur_index = cur_index + 1
    else:
      cur_index = replace_linear_with_preloaded_lora(module, lora_layers, cur_index)
  
  return cur_index

replace_linear_with_preloaded_lora(base_model, llama_summarizer_lora_layers.lora_layers, 0)
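Note that replace_linear_with_preloaded_lora walks the module tree in the same named_children order as get_base_model_parameter, threading a running index through the recursion so each Linear is paired with its matching LoRA layer; the return value is the total number of layers wrapped (113 for this model: 7 per decoder layer × 16 layers, plus lm_head). A toy sketch of the traversal, using a hypothetical zero-output stand-in for the adapter:

```python
import torch

class ZeroAdapter(torch.nn.Module):
  # Stand-in for a freshly initialized LoRALayer, which also outputs zeros.
  def forward(self, x):
    return 0.0

class LinearWithLoRA(torch.nn.Module):
  def __init__(self, linear, lora_layer):
    super().__init__()
    self.linear = linear
    self.lora = lora_layer

  def forward(self, x):
    return self.linear(x) + self.lora(x)

def replace_linear_with_preloaded_lora(model, lora_layers, cur_index):
  # Depth-first traversal, mirroring the helper above (print omitted).
  for name, module in model.named_children():
    if isinstance(module, torch.nn.Linear):
      setattr(model, name, LinearWithLoRA(module, lora_layers[cur_index]))
      cur_index = cur_index + 1
    else:
      cur_index = replace_linear_with_preloaded_lora(module, lora_layers, cur_index)
  return cur_index

toy = torch.nn.Sequential(
  torch.nn.Linear(4, 8),
  torch.nn.ReLU(),
  torch.nn.Sequential(torch.nn.Linear(8, 2)),  # nested, to exercise the recursion
)
count = replace_linear_with_preloaded_lora(toy, [ZeroAdapter(), ZeroAdapter()], 0)
print(count)                         # 2 -- both Linears wrapped, including the nested one
print(toy(torch.randn(1, 4)).shape)  # torch.Size([1, 2]) -- the model still runs
```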

Now that we have our integrated model (with the fine-tuned LoRA weights), we can use it for summarization tasks.

import torch

def generate_content(prompt, model):
  encoded_ids = tokenizer.encode(prompt)
  encoded_ids = torch.tensor(encoded_ids, dtype=torch.int64, device=device)

  encoded_ids = encoded_ids.view(1, encoded_ids.shape[-1])

  generated_ids = model.generate(
    encoded_ids, max_length=500, pad_token_id=tokenizer.eos_token_id
  )

  # Decode the first (and only) sequence in the generated batch.
  return tokenizer.decode(generated_ids[0])

model_id = 'meta-llama/Llama-3.2-1B'
tokenizer = AutoTokenizer.from_pretrained(model_id)

tokenizer.pad_token_id = 1131
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
base_model.to(device)  # keep the model on the same device as the input ids

prompt2 = '''<|begin_of_text|> Once upon a time, there was a little bird.
The bird was white and had a pretty song. One day, the bird saw a big pit.
The pit was deep and dark.
The bird wanted to see what was inside, so it flew down to take a look.
As the bird got closer, the pit began to rise up and swallow the bird.
The bird tried to fly away, but it was too late. The pit had taken the bird down.
The moral of the story is that we should be careful when we see something dangerous.
We should not go near it, even if we are curious. It is better to be safe than sorry.
Summary:
'''

print(generate_content(prompt2, base_model))

Output of the above summarization task:

<|begin_of_text|><|begin_of_text|> Once upon a time, there was a little bird.
The bird was white and had a pretty song. One day, the bird saw a big pit.
The pit was deep and dark.
The bird wanted to see what was inside, so it flew down to take a look.
As the bird got closer, the pit began to rise up and swallow the bird.
The bird tried to fly away, but it was too late. The pit had taken the bird down.
The moral of the story is that we should be careful when we see something dangerous.
We should not go near it, even if we are curious. It is better to be safe than sorry.
Summary:
A curious little bird flies down to a deep pit and gets swallowed by it,
teaching the moral to be careful when exploring dangerous things. <|end_of_text|>

Let's look at the model architecture before and after integration with the LoRA layers.

Initial Llama base model architecture:

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=2048, out_features=128256, bias=False)
)

Updated Llama model architecture after integration with the LoRA layers:

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): LinearWithLoRA(
            (linear): Linear(in_features=2048, out_features=2048, bias=False)
            (lora): LoRALayer()
          )
          (k_proj): LinearWithLoRA(
            (linear): Linear(in_features=2048, out_features=512, bias=False)
            (lora): LoRALayer()
          )
          (v_proj): LinearWithLoRA(
            (linear): Linear(in_features=2048, out_features=512, bias=False)
            (lora): LoRALayer()
          )
          (o_proj): LinearWithLoRA(
            (linear): Linear(in_features=2048, out_features=2048, bias=False)
            (lora): LoRALayer()
          )
        )
        (mlp): LlamaMLP(
          (gate_proj): LinearWithLoRA(
            (linear): Linear(in_features=2048, out_features=8192, bias=False)
            (lora): LoRALayer()
          )
          (up_proj): LinearWithLoRA(
            (linear): Linear(in_features=2048, out_features=8192, bias=False)
            (lora): LoRALayer()
          )
          (down_proj): LinearWithLoRA(
            (linear): Linear(in_features=8192, out_features=2048, bias=False)
            (lora): LoRALayer()
          )
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): LinearWithLoRA(
    (linear): Linear(in_features=2048, out_features=128256, bias=False)
    (lora): LoRALayer()
  )
)