Remove vision from TranslateGemma 4B

# Mount Google Drive at /content/drive (Colab only) so that paths under
# /content/drive are backed by Drive.
from google.colab import drive
drive.mount('/content/drive')

# Notebook shell escapes: install the dependencies used by the script below.
# transformers is installed from the GitHub repository rather than from PyPI.
!pip install -q git+https://github.com/huggingface/transformers.git
!pip install -q accelerate
# NOTE(review): this line lists transformers again with --upgrade, after the
# git install above — confirm which transformers build ends up installed.
!pip install --upgrade huggingface_hub transformers timm

import torch
import gc
import copy
import os
from google.colab import drive
from huggingface_hub import login

from transformers import (
    AutoModel, 
    AutoConfig, 
    AutoModelForCausalLM,
    AutoTokenizer
)

# Configuration
# Model identifier passed to the `from_pretrained` calls for the config,
# the original weights and the tokenizer.
SOURCE_MODEL_ID = "keisuke-miyako/translategemma-4b-it"
# Directory the converted model and tokenizer are saved to; the path sits
# under /content/drive, the Google Drive mount point.
OUTPUT_DIR = "/content/drive/MyDrive/translategemma-4b-it-novision"

def _extract_text_state_dict(full_sd, prefix="language_model."):
    """Remap a multimodal state dict onto causal-LM parameter names.

    Keys containing ``vision_tower`` and keys that do not start with
    *prefix* are dropped. For the remaining keys the prefix is stripped and
    the remainder is placed under ``model.`` (``layers.0...`` becomes
    ``model.layers.0...``), except keys starting with ``lm_head``, which
    stay at the root. If no ``lm_head.weight`` results, it is created from
    the embedding tensor (weight tying).

    Args:
        full_sd: Mapping of parameter name to tensor from the source model.
        prefix: Key prefix that marks language-model weights.

    Returns:
        A ``(text_sd, keys_kept, keys_dropped)`` tuple. ``keys_kept`` counts
        remapped source keys and excludes a synthesized ``lm_head.weight``.

    Raises:
        ValueError: If no key starts with *prefix*, or if ``lm_head.weight``
            is missing and no ``embed_tokens.weight`` tensor was found.
    """
    text_sd = {}
    embed_weight = None  # Remembered so it can be tied to lm_head below.
    keys_dropped = 0

    for key, value in full_sd.items():
        # The vision check comes first so a vision key is dropped no matter
        # what prefix it carries.
        if "vision_tower" in key or not key.startswith(prefix):
            keys_dropped += 1
            continue

        stripped_key = key[len(prefix):]

        if "embed_tokens.weight" in stripped_key:
            embed_weight = value

        if stripped_key.startswith("lm_head"):
            new_key = stripped_key
        else:
            new_key = f"model.{stripped_key}"
        text_sd[new_key] = value

    keys_kept = len(text_sd)
    if not keys_kept:
        # Fail with the real cause instead of the embedding error below.
        raise ValueError(
            f"No weights found under the '{prefix}' prefix; "
            "the checkpoint layout is not the expected one."
        )

    if "lm_head.weight" not in text_sd:
        print("Notice: 'lm_head.weight' not found. Creating it from embeddings (Weight Tying).")
        if embed_weight is None:
            raise ValueError("Could not find embedding weights to tie to lm_head!")
        text_sd["lm_head.weight"] = embed_weight

    return text_sd, keys_kept, keys_dropped


def convert_gemma3_multimodal_to_text(source_model_id=None, output_dir=None) -> None:
    """Convert a Gemma 3 multimodal checkpoint into a text-only causal LM.

    Loads the multimodal model on CPU in bfloat16, keeps only the weights
    under the ``language_model.`` prefix, loads them into a model built
    from the checkpoint's ``text_config`` (architecture set to
    ``Gemma3ForCausalLM``), and saves that model together with the source
    tokenizer.

    Args:
        source_model_id: Model identifier to convert. Defaults to the
            module-level ``SOURCE_MODEL_ID``.
        output_dir: Directory to save the model and tokenizer to. Defaults
            to the module-level ``OUTPUT_DIR``.

    Raises:
        ValueError: If the config has no ``text_config``, if no language
            model weights are found, or if no embedding tensor is available
            to create ``lm_head.weight``.
    """
    # Resolved at call time so that the module constants stay the defaults.
    if source_model_id is None:
        source_model_id = SOURCE_MODEL_ID
    if output_dir is None:
        output_dir = OUTPUT_DIR

    print(f"Loading config from {source_model_id}...")

    # 1. Prepare the Configuration
    full_config = AutoConfig.from_pretrained(source_model_id, trust_remote_code=True)

    if not hasattr(full_config, 'text_config'):
        raise ValueError("Config does not contain 'text_config'. Is this a Gemma 3 Multimodal model?")

    # Deep copy so the edit below does not touch the multimodal config.
    new_config = copy.deepcopy(full_config.text_config)
    new_config.architectures = ["Gemma3ForCausalLM"]

    print("Configuration prepared.")

    # 2. Load the Original Model
    print("Loading original model weights (CPU)...")
    multimodal_model = AutoModel.from_pretrained(
        source_model_id,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
        trust_remote_code=True,
    )

    # 3./4. Extract the language-model weights and handle weight tying
    print("Extracting Language Model weights...")
    full_sd = multimodal_model.state_dict()
    text_sd, keys_kept, keys_dropped = _extract_text_state_dict(full_sd)

    print(f"Extraction complete. Kept {keys_kept} keys. Dropped {keys_dropped} keys.")

    # 5. Clean up Memory
    # text_sd still references the kept tensors, so only the dropped
    # tensors can actually be reclaimed here.
    del multimodal_model
    del full_sd
    gc.collect()

    # 6. Create the Text-Only Model
    print("Instantiating new Text-Only model...")
    # Build directly in bfloat16 rather than allocating in the default
    # dtype first and casting afterwards.
    text_model = AutoModelForCausalLM.from_config(new_config, torch_dtype=torch.bfloat16)

    # strict=True: every key must match the new architecture exactly.
    print("Loading extracted weights into new architecture...")
    text_model.load_state_dict(text_sd, strict=True)

    # Kept from the original flow so the saved dtype is bfloat16 regardless.
    text_model.to(dtype=torch.bfloat16)
    text_model.eval()

    # 7. Save Model and Tokenizer
    print(f"Saving model to {output_dir}...")
    text_model.save_pretrained(output_dir)

    print("Saving tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(source_model_id)
    tokenizer.save_pretrained(output_dir)

    print(f"SUCCESS! Text-only model saved to: {output_dir}")

# Run the conversion as soon as this script/cell executes, using the
# module-level SOURCE_MODEL_ID and OUTPUT_DIR.
convert_gemma3_multimodal_to_text()
Downloads last month
4
Safetensors
Model size
4B params
Tensor type
BF16
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Model tree for keisuke-miyako/translategemma-4b-it-novision

Finetuned
(17)
this model
Quantizations
1 model