# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: LGPL-3.0-only
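
"""Permute model embedding matrices to line up with a merged tokenizer.

``PermutedEmbeddings`` rearranges each model's embedding rows to match the
merged vocabulary, filling in tokens a model lacks from configured or
inferred sources.
"""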

import logging
from typing import Dict, List, Optional

import torch

from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.tokenizer.build import BuildTokenizer, TokenizerInfo
from mergekit.tokenizer.config import (
    ModelTokenEmbedding,
    TokenEmbeddingConfig,
    ZeroEmbedding,
)


class PermutedEmbeddings(Task[Dict[ModelReference, torch.Tensor]]):
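    """Produce an embedding matrix permuted to the merged vocabulary for each model.

    Rows are rearranged to match the merged tokenizer; tokens a model does not
    have are filled from the sources chosen by ``assign_embedding_sources``.
    """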

    gather_tensors: GatherTensors
    tokenizer_task: BuildTokenizer
    tokens: Optional[ImmutableMap[str, TokenEmbeddingConfig]]
    pad_to_multiple_of: Optional[int]
    base_model: Optional[ModelReference]

    def arguments(self) -> Dict[str, Task]:
        return {"tokenizer_info": self.tokenizer_task, "tensors": self.gather_tensors}

    def execute(
        self, tokenizer_info: TokenizerInfo, tensors: Dict[ModelReference, torch.Tensor]
    ) -> Dict[ModelReference, torch.Tensor]:
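        """Return a permuted (and possibly padded) embedding matrix per model."""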
        tokenizer = tokenizer_info.tokenizer
        permutations = tokenizer_info.permutations

        models = set(tensors.keys())
        if self.base_model:
            models.add(self.base_model)
        models = list(models)

        vocab = tokenizer.get_vocab()
        vocab_size = len(vocab)
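        # optionally round the output vocabulary up to a multiple (padding rows
        # are initialized to the mean embedding at the end of this method)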
        if self.pad_to_multiple_of and vocab_size % self.pad_to_multiple_of:
            vocab_size = (
                vocab_size // self.pad_to_multiple_of + 1
            ) * self.pad_to_multiple_of

        embed_size = tensors[models[0]].shape[1]
        assert all(
            t.shape[1] == embed_size for t in tensors.values()
        ), "Embedding sizes must match"

        dtype = tensors[models[0]].dtype
        device = tensors[models[0]].device

        token_configs = dict(**(self.tokens or {}))
        tokens_to_average = self.assign_embedding_sources(
            permutations, models, vocab, token_configs
        )
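
        # precompute a fallback embedding for every token that needs one: either
        # the mean across the models that have it, or its configured source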
        default_embeds = {}
        for token, token_id in vocab.items():
            embed = torch.zeros(embed_size, dtype=dtype, device=device)
            if token in tokens_to_average:
                count = 0
                for model in models:
                    p = permutations[model]
                    if p[token_id] < 0:
                        continue
                    embed += tensors[model][p[token_id]]
                    count += 1
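                # assign_embedding_sources only marks tokens present in two or
                # more models for averaging, so count is always positive here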
                embed /= count
            elif cfg := token_configs.get(token, None):
                cfg: TokenEmbeddingConfig
                embed = self.compute_default_embedding(
                    tokenizer_info,
                    tensors,
                    permutations,
                    token,
                    token_id,
                    cfg,
                    embed_size,
                    dtype,
                    device,
                )
            else:
                continue
            default_embeds[token] = embed
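
        # assemble each model's output matrix: keep rows the model already has
        # (unless the token config forces an override), otherwise use the fallback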
        result = {}
        for model in models:
            p = permutations[model]
            old_embed = tensors[model]
            new_embed = torch.zeros(
                (vocab_size, embed_size), dtype=dtype, device=device
            )
            for token, token_id in vocab.items():
                force = False
                if token in token_configs:
                    force = token_configs[token].force
                if p[token_id] >= 0 and not force:
                    new_embed[token_id, :] = old_embed[p[token_id]]
                elif token in default_embeds:
                    new_embed[token_id, :] = default_embeds[token]
                else:
                    logging.error(
                        f"No embedding for token {repr(token)} in model {model}!"
                    )

            if vocab_size > len(vocab):
                # initialize padding rows to the mean embedding, as suggested by
                # https://nlp.stanford.edu/~johnhew/vocab-expansion.html
                avg_embed = torch.mean(new_embed[: len(vocab), :], dim=0)
                new_embed[len(vocab) :, :] = avg_embed

            result[model] = new_embed
        return result

    def assign_embedding_sources(
        self,
        permutations: Dict[ModelReference, Dict[int, int]],
        models: List[ModelReference],
        vocab: Dict[str, int],
        token_configs: Dict[str, TokenEmbeddingConfig],
    ):
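        """Decide where each merged-vocabulary token's embedding should come from.

        Mutates ``token_configs`` in place: tokens found in exactly one model are
        sourced from that model, tokens found in none get a zero embedding, and
        tokens present in the base model default to it. Returns the set of tokens
        whose embeddings should be averaged across all models that have them.
        """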
        permutation_list = [permutations[model] for model in models]
        tokens_to_average = set()

        # classify each token without an explicit config by how many models have it
        for token, token_id in vocab.items():
            if token in token_configs:
                continue
            has_token = [p[token_id] >= 0 for p in permutation_list]
            num_present = sum(int(x) for x in has_token)
            if num_present == 1:
                donor_model = models[has_token.index(True)]
                token_configs[token] = TokenEmbeddingConfig(source=donor_model)
                continue
            if num_present == 0:
                token_configs[token] = TokenEmbeddingConfig(
                    source=ZeroEmbedding(kind="zero")
                )
                logging.warning(f"Token {repr(token)} not found in any model")
                continue
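            # num_present >= 2 past this point; prefer the base model when it has the token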
            if self.base_model is not None:
                if permutations[self.base_model][token_id] >= 0:
                    token_configs[token] = TokenEmbeddingConfig(source=self.base_model)
                    continue
            tokens_to_average.add(token)
        return tokens_to_average

    def compute_default_embedding(
        self,
        tokenizer_info: TokenizerInfo,
        tensors: Dict[ModelReference, torch.Tensor],
        permutations: Dict[ModelReference, Dict[int, int]],
        token: str,
        token_id: int,
        cfg: TokenEmbeddingConfig,
        embed_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> torch.Tensor:
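        """Build a token's embedding from its configured source: a zero vector,
        a specific token in a specific model, or the token's own row in a
        named model."""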
        if isinstance(cfg.source, ZeroEmbedding):
            embed = torch.zeros(embed_size, dtype=dtype, device=device)
        elif isinstance(cfg.source, ModelTokenEmbedding):
            model = cfg.source.model
            assert (
                model in permutations
            ), f"Model {model} referenced but not part of merge"
            src_token_id = cfg.source.token_id
            if src_token_id is None:
                src_token = cfg.source.token
                assert (
                    src_token in tokenizer_info.original_vocabs[model]
                ), f"Token {repr(src_token)} not found in model {model}"
                src_token_id = tokenizer_info.original_vocabs[model][src_token]
            assert (
                0 <= src_token_id < tensors[model].shape[0]
            ), f"Token ID {src_token_id} out of range for model {model}"
            embed = tensors[model][src_token_id]
        elif isinstance(cfg.source, ModelReference):
            model = cfg.source
            p = permutations[model]
            assert p[token_id] >= 0, f"Token {repr(token)} not found in model {model}"
            embed = tensors[model][p[token_id]]
        else:
            raise NotImplementedError(cfg)
        return embed