|
|
| import torch |
| from transformers import AutoModelForCausalLM, Mistral3ForConditionalGeneration, AutoTokenizer |
| from mistral_common.tokens.tokenizers.mistral import MistralTokenizer |
| from tqdm import tqdm |
|
|
def _fix_generation_config(model):
    """Clear a temperature left on a non-sampling generation config, then validate it."""
    gen_config = getattr(model, "generation_config", None)
    if gen_config is None:
        return

    if hasattr(gen_config, "do_sample") and hasattr(gen_config, "temperature"):
        if not gen_config.do_sample and gen_config.temperature is not None:
            # NOTE(review): a temperature alongside do_sample=False is presumably
            # what fails validation when the model is saved -- confirm.
            gen_config.temperature = None
            print(" - Removed temperature setting (keeping do_sample=False)")

    # Best effort: a config that still fails validation is reported, not fatal.
    try:
        gen_config.validate()
        print(" - Generation config is now valid")
    except Exception as e:
        print(f" - Warning: Generation config validation failed: {e}")


def _count_layers(state, prefix):
    """Return the number of indexed layers whose keys in *state* start with *prefix*."""
    highest = -1
    for key in state:
        if not key.startswith(prefix):
            continue
        index = key[len(prefix):].split(".", 1)[0]
        if index.isdigit():
            highest = max(highest, int(index))
    return highest + 1


def _copy_weights(source_state, target_state, mappings):
    """Point each target key in *mappings* at the matching source tensor."""
    for source_key, target_key in mappings:
        if source_key not in source_state or target_key not in target_state:
            raise KeyError(f"Missing key: {source_key} or {target_key}")
        # No clone needed: load_state_dict() copies the values into the target
        # model's own parameters, so the source tensor is only read.
        target_state[target_key] = source_state[source_key]


def copy_devstral_weights_to_mistral(devstral_id, mistral_id, output_path):
    """
    Copy Devstral language model weights to Mistral-Small model,
    preserving Mistral's vision components.

    Both checkpoints are loaded on CPU in bfloat16. The embedding, final norm
    and per-layer attention/MLP/layernorm weights of the Devstral model replace
    the corresponding ``model.language_model.*`` weights of the Mistral-Small
    model; every other Mistral-Small weight is left untouched. The result is
    written to ``output_path`` in safetensors format.

    Args:
        devstral_id: Model id or path passed to
            ``AutoModelForCausalLM.from_pretrained``.
        mistral_id: Model id or path passed to
            ``Mistral3ForConditionalGeneration.from_pretrained``.
        output_path: Directory the merged model is saved to.

    Returns:
        The merged Mistral-Small model, after it has been saved.

    Raises:
        KeyError: If a mapped weight is missing from either checkpoint.
        ValueError: If the two checkpoints have different decoder depths.
    """
    print(f"Loading Devstral model from {devstral_id}...")
    devstral_model = AutoModelForCausalLM.from_pretrained(
        devstral_id,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )

    print(f"Loading Mistral-Small model from {mistral_id}...")
    mistral_model = Mistral3ForConditionalGeneration.from_pretrained(
        mistral_id,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )

    print("Fixing generation configuration...")
    _fix_generation_config(mistral_model)

    devstral_state = devstral_model.state_dict()
    mistral_state = mistral_model.state_dict()

    print("Copying weights from Devstral to Mistral-Small...")

    # NOTE(review): lm_head.weight is not part of this mapping. If the
    # checkpoints do not tie it to the embeddings, the output head stays
    # Mistral-Small's -- confirm this is intended.
    weight_mappings = [
        ("model.embed_tokens.weight", "model.language_model.embed_tokens.weight"),
        ("model.norm.weight", "model.language_model.norm.weight"),
    ]
    for devstral_key, mistral_key in weight_mappings:
        print(f"Copying {devstral_key} to {mistral_key}")
    _copy_weights(devstral_state, mistral_state, weight_mappings)

    # Derive the decoder depth from the checkpoints instead of hard-coding it,
    # and refuse to produce a partially merged model when the depths differ.
    num_layers = _count_layers(devstral_state, "model.layers.")
    target_layers = _count_layers(mistral_state, "model.language_model.layers.")
    if num_layers != target_layers:
        raise ValueError(
            f"Layer count mismatch: Devstral has {num_layers} layers, "
            f"Mistral-Small has {target_layers}"
        )

    layer_weight_names = (
        "input_layernorm.weight",
        "mlp.down_proj.weight",
        "mlp.gate_proj.weight",
        "mlp.up_proj.weight",
        "post_attention_layernorm.weight",
        "self_attn.k_proj.weight",
        "self_attn.o_proj.weight",
        "self_attn.q_proj.weight",
        "self_attn.v_proj.weight",
    )
    for i in tqdm(range(num_layers), desc="Copying layer weights"):
        layer_mappings = [
            (f"model.layers.{i}.{name}", f"model.language_model.layers.{i}.{name}")
            for name in layer_weight_names
        ]
        _copy_weights(devstral_state, mistral_state, layer_mappings)

    print("Saving updated Mistral-Small model...")

    mistral_model.load_state_dict(mistral_state)
    mistral_model.save_pretrained(output_path, safe_serialization=True)

    # The caller binds the result, so hand back the merged model.
    return mistral_model
|
|
if __name__ == "__main__":
    # Script entry point: the source checkpoint for the language-model weights,
    # the checkpoint whose vision components are kept, and where to save the merge.
    devstral_id, mistral_id, output_path = (
        "mistralai/Devstral-Small-2507",
        "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
        "./Devstral-Vision-Small-2507",
    )

    model = copy_devstral_weights_to_mistral(
        devstral_id,
        mistral_id,
        output_path,
    )