"""
Extract the language model (text-only) weights from Gemma 4 multimodal safetensors.

- Filters keys containing 'language_model'
- Renames: model.language_model.X -> model.X
- Saves as sharded safetensors (10GB per shard)
- Generates model.safetensors.index.json

Shards are planned from the safetensors headers (no tensor data is read for
planning) and then loaded/saved one at a time, so peak memory is roughly one
output shard instead of the whole language model.
"""

import glob
import json
import os
import struct

import torch
from safetensors import safe_open
from safetensors.torch import save_file

SRC_DIR = "/workspace/llm/gemma-4-31B-it"
DST_DIR = "/workspace/llm/gemma-4-31B-Text"
MAX_SHARD_SIZE = 10 * 1024 * 1024 * 1024  # 10GB

SRC_PREFIX = "model.language_model."
DST_PREFIX = "model."


def _rename_key(key):
    """Return the text-only name for `key`, or None if it should be dropped.

    Keys not containing 'language_model' are dropped. Keys starting with
    SRC_PREFIX have that prefix replaced by DST_PREFIX; any other kept key is
    passed through unchanged.
    """
    if "language_model" not in key:
        return None
    if key.startswith(SRC_PREFIX):
        # Anchored prefix swap; str.replace would also rewrite mid-key matches.
        return DST_PREFIX + key[len(SRC_PREFIX):]
    return key


def _read_tensor_sizes(path):
    """Return {tensor_name: byte_size} for a safetensors file without loading data.

    The safetensors header is an 8-byte little-endian length followed by a JSON
    object mapping each tensor name to its dtype, shape and data_offsets
    (plus an optional "__metadata__" entry, which is skipped).
    """
    with open(path, "rb") as fh:
        (header_len,) = struct.unpack("<Q", fh.read(8))
        header = json.loads(fh.read(header_len))
    header.pop("__metadata__", None)
    return {
        name: info["data_offsets"][1] - info["data_offsets"][0]
        for name, info in header.items()
    }


def _plan_shards(entries, max_shard_size):
    """Greedily group (key, nbytes) pairs into lists of keys of at most max_shard_size bytes.

    A single tensor larger than max_shard_size gets a shard of its own.
    """
    shards = []
    current = []
    current_size = 0
    for key, nbytes in entries:
        if current and current_size + nbytes > max_shard_size:
            shards.append(current)
            current = []
            current_size = 0
        current.append(key)
        current_size += nbytes
    if current:
        shards.append(current)
    return shards


def _load_tensors(keys, sources):
    """Load tensors for `keys` (output names) from their source files, opening each file once.

    `sources` maps output name -> (source file path, source tensor name).
    The returned dict preserves the order of `keys`.
    """
    by_file = {}
    for new_key in keys:
        path, src_key = sources[new_key]
        by_file.setdefault(path, []).append((new_key, src_key))

    loaded = {}
    for path, pairs in by_file.items():
        with safe_open(path, framework="pt", device="cpu") as f:
            for new_key, src_key in pairs:
                loaded[new_key] = f.get_tensor(src_key)
    return {key: loaded[key] for key in keys}


def main():
    """Write the renamed language-model tensors as sharded safetensors plus an index.

    Raises:
        FileNotFoundError: if SRC_DIR contains no .safetensors files.
        ValueError: if no language_model tensors are found, or two source keys
            map to the same output key.
    """
    os.makedirs(DST_DIR, exist_ok=True)

    src_files = sorted(glob.glob(os.path.join(SRC_DIR, "*.safetensors")))
    print(f"Source files: {len(src_files)}")
    if not src_files:
        raise FileNotFoundError(f"No .safetensors files found in {SRC_DIR}")

    # Step 1: Index language_model tensors with renamed keys (headers only)
    sources = {}  # output key -> (source path, source key)
    sizes = {}  # output key -> size in bytes
    for path in src_files:
        print(f"Reading {os.path.basename(path)}...")
        for key, nbytes in _read_tensor_sizes(path).items():
            new_key = _rename_key(key)
            if new_key is None:
                continue
            if new_key in sources:
                # Without this check one tensor would silently overwrite another.
                raise ValueError(
                    f"Duplicate output key {new_key!r} "
                    f"(from {key!r} in {path} and {sources[new_key][1]!r} in {sources[new_key][0]})"
                )
            sources[new_key] = (path, key)
            sizes[new_key] = nbytes
    print(f"Extracted {len(sources)} tensors")
    if not sources:
        raise ValueError(f"No 'language_model' tensors found in {SRC_DIR}")

    # Step 2: Split into shards by size
    shards = _plan_shards(((key, sizes[key]) for key in sorted(sources)), MAX_SHARD_SIZE)
    print(f"Splitting into {len(shards)} shards")

    # Step 3: Load, save and release each shard in turn; build weight_map
    total_shards = len(shards)
    weight_map = {}
    for i, shard_keys in enumerate(shards):
        filename = f"model-{i+1:05d}-of-{total_shards:05d}.safetensors"
        filepath = os.path.join(DST_DIR, filename)
        shard_size = sum(sizes[key] for key in shard_keys)
        print(f"Saving {filename} ({shard_size / 1e9:.2f} GB, {len(shard_keys)} tensors)...")
        shard = _load_tensors(shard_keys, sources)
        # "format": "pt" matches the metadata Hugging Face save_pretrained writes.
        save_file(shard, filepath, metadata={"format": "pt"})
        # Free this shard before loading the next one, so at most one shard is resident.
        del shard
        for key in shard_keys:
            weight_map[key] = filename

    # Step 4: Write index file
    index = {
        "metadata": {"total_size": sum(sizes.values())},
        "weight_map": weight_map,
    }
    index_path = os.path.join(DST_DIR, "model.safetensors.index.json")
    with open(index_path, "w", encoding="utf-8") as f:
        json.dump(index, f, indent=2)
    print(f"Done! Index written to {index_path}")


if __name__ == "__main__":
    main()