Remove the vision tower from TranslateGemma 4B
# Mount Google Drive so the converted model can be written to persistent storage.
from google.colab import drive
drive.mount('/content/drive')
# Colab/IPython shell magics (not plain Python): install transformers from the
# main branch plus the runtime dependencies this conversion needs.
!pip install -q git+https://github.com/huggingface/transformers.git
!pip install -q accelerate
!pip install --upgrade huggingface_hub transformers timm
import torch
import gc
import copy
import os
from google.colab import drive
from huggingface_hub import login
from transformers import (
AutoModel,
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer
)
# Configuration
# Hugging Face repo id of the multimodal (vision + text) Gemma 3 checkpoint to convert.
SOURCE_MODEL_ID = "keisuke-miyako/translategemma-4b-it"
# Google Drive destination for the text-only model and its tokenizer.
OUTPUT_DIR = "/content/drive/MyDrive/translategemma-4b-it-novision"
def convert_gemma3_multimodal_to_text():
    """Convert the multimodal TranslateGemma checkpoint into a text-only model.

    Loads ``SOURCE_MODEL_ID`` on CPU, drops the vision-tower weights, remaps
    the ``language_model.*`` weights to the ``Gemma3ForCausalLM`` layout,
    materializes ``lm_head.weight`` from the token embeddings when the source
    relies on weight tying, and saves the resulting text-only model plus the
    tokenizer to ``OUTPUT_DIR``.

    Raises:
        ValueError: If the source config lacks ``text_config`` (i.e. it is not
            a Gemma 3 multimodal model), or if no embedding weights are found
            to tie to ``lm_head``.
    """
    print(f"Loading config from {SOURCE_MODEL_ID}...")

    # 1. Prepare the configuration: promote the nested text_config to a
    # top-level causal-LM config.
    full_config = AutoConfig.from_pretrained(SOURCE_MODEL_ID, trust_remote_code=True)
    if not hasattr(full_config, 'text_config'):
        raise ValueError("Config does not contain 'text_config'. Is this a Gemma 3 Multimodal model?")
    # Deepcopy so mutations below don't touch the cached original config.
    new_config = copy.deepcopy(full_config.text_config)
    new_config.architectures = ["Gemma3ForCausalLM"]
    print("Configuration prepared.")

    # 2. Load the original model on CPU; bf16 keeps peak RAM manageable.
    print("Loading original model weights (CPU)...")
    multimodal_model = AutoModel.from_pretrained(
        SOURCE_MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
        trust_remote_code=True,
    )

    # 3. Build the text-only state dict from the multimodal one.
    print("Extracting Language Model weights...")
    full_sd = multimodal_model.state_dict()
    text_sd = {}
    prefix_to_remove = "language_model."
    embed_weight = None  # kept aside for lm_head weight tying below
    keys_dropped = 0
    keys_kept = 0
    for key, value in full_sd.items():
        # Drop vision tower weights entirely.
        if "vision_tower" in key:
            keys_dropped += 1
            continue
        # Keep only the language-model weights; everything else (e.g. the
        # multi-modal projector) is dropped by the else branch.
        if key.startswith(prefix_to_remove):
            stripped_key = key[len(prefix_to_remove):]
            # Remember the embeddings so lm_head can be tied if absent.
            if "embed_tokens.weight" in stripped_key:
                embed_weight = value
            # The text backbone lives under "model." in Gemma3ForCausalLM
            # (e.g. "layers.0..." -> "model.layers.0...").  An explicit
            # lm_head — if present — stays at the root.
            if stripped_key.startswith("lm_head"):
                new_key = stripped_key
            else:
                new_key = f"model.{stripped_key}"
            text_sd[new_key] = value
            keys_kept += 1
        else:
            keys_dropped += 1

    # 4. Handle weight tying: multimodal checkpoints usually omit lm_head
    # because it shares storage with the token embeddings.
    if "lm_head.weight" not in text_sd:
        print("Notice: 'lm_head.weight' not found. Creating it from embeddings (Weight Tying).")
        if embed_weight is not None:
            text_sd["lm_head.weight"] = embed_weight
        else:
            raise ValueError("Could not find embedding weights to tie to lm_head!")
    print(f"Extraction complete. Kept {keys_kept} keys. Dropped {keys_dropped} keys.")

    # 5. Release the multimodal model before instantiating the new one to
    # keep peak memory down (text_sd still holds the shared tensors).
    del multimodal_model
    del full_sd
    gc.collect()

    # 6. Instantiate the text-only architecture and load the filtered weights.
    print("Instantiating new Text-Only model...")
    text_model = AutoModelForCausalLM.from_config(new_config)
    print("Loading extracted weights into new architecture...")
    text_model.load_state_dict(text_sd, strict=True)
    # Free the state-dict mapping now that the model owns the parameters.
    del text_sd
    gc.collect()
    text_model.to(dtype=torch.bfloat16)
    text_model.eval()

    # 7. Save model and tokenizer side by side so the output dir is loadable
    # with AutoModelForCausalLM / AutoTokenizer.
    print(f"Saving model to {OUTPUT_DIR}...")
    text_model.save_pretrained(OUTPUT_DIR)
    print("Saving tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(SOURCE_MODEL_ID)
    tokenizer.save_pretrained(OUTPUT_DIR)
    print(f"SUCCESS! Text-only model saved to: {OUTPUT_DIR}")
# Run the conversion. Guarded so importing this file as a module doesn't
# trigger a multi-gigabyte download; direct execution (including Colab
# cells, where __name__ == "__main__") behaves exactly as before.
if __name__ == "__main__":
    convert_gemma3_multimodal_to_text()
Downloads last month: 4

Inference Providers (NEW): this model isn't deployed by any Inference Provider. 🙋 Ask for provider support.