# model_tools/embed_12B.py
# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: LGPL-3.0-only
## Version 2D by Naphula
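"""Permuted embedding construction for tokenizer merges.

Builds, for every model in the merge, an embedding matrix aligned with the
merged tokenizer's vocabulary, averaging or explicitly sourcing embeddings for
tokens a model is missing. This variant adds bounds-check audits so undersized
donor embedding tensors are skipped with a warning instead of raising an
index error.
"""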
import logging
from typing import Dict, Optional
import torch
from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.tokenizer.build import BuildTokenizer, TokenizerInfo
from mergekit.tokenizer.config import (
ModelTokenEmbedding,
TokenEmbeddingConfig,
ZeroEmbedding,
)
class PermutedEmbeddings(Task[Dict[ModelReference, torch.Tensor]]):
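    """Task that produces a vocab-aligned embedding matrix for each model in the merge."""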
gather_tensors: GatherTensors
tokenizer_task: BuildTokenizer
tokens: Optional[ImmutableMap[str, TokenEmbeddingConfig]]
pad_to_multiple_of: Optional[int]
base_model: Optional[ModelReference]
def arguments(self) -> Dict[str, Task]:
return {"tokenizer_info": self.tokenizer_task, "tensors": self.gather_tensors}
def execute(
self, tokenizer_info: TokenizerInfo, tensors: Dict[ModelReference, torch.Tensor]
) -> Dict[ModelReference, torch.Tensor]:
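        """Return a new embedding matrix per model, permuted to the merged tokenizer's vocabulary."""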
tokenizer = tokenizer_info.tokenizer
permutations = tokenizer_info.permutations
models = set(tensors.keys())
if self.base_model:
models.add(self.base_model)
models = list(models)
vocab = tokenizer.get_vocab()
vocab_size = len(vocab)
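        # If requested, round the output vocabulary size up to the next multiple of pad_to_multiple_of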
if self.pad_to_multiple_of and vocab_size % self.pad_to_multiple_of:
vocab_size = (
vocab_size // self.pad_to_multiple_of + 1
) * self.pad_to_multiple_of
embed_size = tensors[models[0]].shape[1]
assert all(
t.shape[1] == embed_size for t in tensors.values()
), "Embedding sizes must match"
dtype = tensors[models[0]].dtype
device = tensors[models[0]].device
token_configs = dict(**(self.tokens or {}))
tokens_to_average = self.assign_embedding_sources(
permutations, models, vocab, token_configs
)
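        # Precompute a default embedding for each token that needs one: the mean of all
        # donors that have it, or the value dictated by an explicit per-token config.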
default_embeds = {}
for token, token_id in vocab.items():
embed = torch.zeros(embed_size, dtype=dtype, device=device)
if token in tokens_to_average:
count = 0
for model in models:
p = permutations[model]
if p[token_id] < 0:
continue
# --- AUDIT & BOUNDS CHECK ---
if p[token_id] >= tensors[model].shape[0]:
logging.warning(f"[VOCAB AUDIT] Model '{model}' is missing token '{token}' (ID: {token_id}). "
f"Donor size: {tensors[model].shape[0]}, Requested Index: {p[token_id]}. Skipping.")
continue
# ----------------------------
embed += tensors[model][p[token_id]]
count += 1
                if count > 0:
                    embed /= count
                else:
                    # All donors were skipped by the bounds check; keep the zero vector and flag it.
                    logging.warning(
                        f"[VOCAB AUDIT] No in-bounds donor embedding for token {repr(token)}; leaving it zeroed."
                    )
elif cfg := token_configs.get(token, None):
cfg: TokenEmbeddingConfig
embed = self.compute_default_embedding(
tokenizer_info, tensors, permutations, token, token_id, cfg
)
else:
continue
default_embeds[token] = embed
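        # Assemble the output: copy each model's own row where it has the token,
        # otherwise fall back to the precomputed default embedding.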
result = {}
for model in models:
p = permutations[model]
old_embed = tensors[model]
new_embed = torch.zeros(
(vocab_size, embed_size), dtype=dtype, device=device
)
for token, token_id in vocab.items():
force = False
if token in token_configs:
force = token_configs[token].force
if p[token_id] >= 0 and not force:
# --- BOUNDS CHECK FOR RESULT MAPPING ---
if p[token_id] < old_embed.shape[0]:
new_embed[token_id, :] = old_embed[p[token_id]]
else:
# Fallback to the averaged/default version if the donor is too small
new_embed[token_id, :] = default_embeds.get(token, torch.zeros_like(new_embed[0]))
# ---------------------------------------
elif token in default_embeds:
new_embed[token_id, :] = default_embeds[token]
else:
logging.error(
f"No embedding for token {repr(token)} in model {model}!"
)
if vocab_size > len(vocab):
# as suggested by https://nlp.stanford.edu/~johnhew/vocab-expansion.html
avg_embed = torch.mean(new_embed[: len(vocab), :], dim=0)
new_embed[len(vocab) :, :] = avg_embed
result[model] = new_embed
return result
def assign_embedding_sources(
self,
permutations: Dict[ModelReference, Dict[int, int]],
models: list[ModelReference],
vocab: Dict[str, int],
token_configs: Dict[str, TokenEmbeddingConfig],
):
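        """Decide where each token's embedding comes from.

        Tokens found in exactly one model are sourced from that model, tokens found
        in no model get a zero embedding, tokens present in the base model use the
        base model, and the remainder are returned as the set to average across donors.
        """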
permutation_list = [permutations[model] for model in models]
tokens_to_average = set()
# find tokens that are only present in one model
for token, token_id in vocab.items():
if token in token_configs:
continue
has_token = [p[token_id] >= 0 for p in permutation_list]
num_present = sum(int(x) for x in has_token)
if num_present == 1:
donor_model = models[has_token.index(True)]
token_configs[token] = TokenEmbeddingConfig(source=donor_model)
continue
if num_present == 0:
token_configs[token] = TokenEmbeddingConfig(source=ZeroEmbedding())
logging.warning(f"Token {repr(token)} not found in any model")
continue
if num_present > 0 and self.base_model is not None:
if permutations[self.base_model][token_id] >= 0:
token_configs[token] = TokenEmbeddingConfig(source=self.base_model)
continue
tokens_to_average.add(token)
return tokens_to_average
def compute_default_embedding(
self,
tokenizer_info: TokenizerInfo,
tensors: Dict[ModelReference, torch.Tensor],
permutations: Dict[ModelReference, Dict[int, int]],
token: str,
token_id: int,
cfg: TokenEmbeddingConfig,
) -> torch.Tensor:
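        """Compute the embedding for a token with an explicit source configuration
        (zero, a specific token in a specific model, or a whole model)."""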
        if isinstance(cfg.source, ZeroEmbedding):
            # Explicit zero embedding: build it with the same width/dtype/device as the merged tensors.
            ref = next(iter(tensors.values()))
            embed = torch.zeros(ref.shape[1], dtype=ref.dtype, device=ref.device)
elif isinstance(cfg.source, ModelTokenEmbedding):
model = cfg.source.model
assert (
model in permutations
), f"Model {model} referenced but not part of merge"
p = permutations[model]
src_token_id = cfg.source.token_id
if src_token_id is None:
src_token = cfg.source.token
assert (
src_token in tokenizer_info.original_vocabs[model]
), f"Token {repr(src_token)} not found in model {model}"
src_token_id = tokenizer_info.original_vocabs[model][src_token]
assert (
src_token_id >= 0 and src_token_id < tensors[model].shape[0]
), f"Token ID {src_token_id} out of range for model {model}"
embed = tensors[model][src_token_id]
elif isinstance(cfg.source, ModelReference):
model = cfg.source
p = permutations[model]
assert p[token_id] >= 0, f"Token {repr(token)} not found in model {model}"
# --- BOUNDS CHECK FOR DEFAULT EMBED ---
if p[token_id] >= tensors[model].shape[0]:
logging.warning(f"[VOCAB AUDIT] Default source model '{model}' is missing token '{token}' from its physical tensor. Falling back to zero.")
return torch.zeros(tensors[model].shape[1], dtype=tensors[model].dtype, device=tensors[model].device)
# --------------------------------------
embed = tensors[model][p[token_id]]
else:
raise NotImplementedError(cfg)
return embed