"""Checkpoint helpers for persisting auxiliary model state.

Writes and restores three artifacts in a best-checkpoint directory: the
input-embedding matrix (plus an optional untied LM head), a small JSON
record of the backbone identifier, and the per-layer LLOPA special-token
tensors.
"""

from __future__ import annotations

import json
from pathlib import Path

import torch

# File names used inside the best-checkpoint directory.
EMBEDDING_FILENAME = "embeddings.pt"
BACKBONE_REF_FILENAME = "backbone.json"
LLOPA_SPECIALS_FILENAME = "llopa_specials.pt"
|
|
def write_backbone_ref(best_dir: str | Path, backbone: str | None) -> None:
    """Record the backbone identifier as JSON inside ``best_dir``.

    A falsy ``backbone`` is a no-op, so callers can pass it through
    unconditionally.
    """
    if not backbone:
        return
    out = Path(best_dir)
    out.mkdir(parents=True, exist_ok=True)
    payload = {"backbone": str(backbone)}
    (out / BACKBONE_REF_FILENAME).write_text(
        json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8"
    )
|
|
def read_backbone_ref(best_dir: str | Path) -> str | None:
    """Return the backbone identifier stored by ``write_backbone_ref``.

    Returns None when the file is missing, unreadable, or malformed.
    """
    path = Path(best_dir) / BACKBONE_REF_FILENAME
    if not path.is_file():
        return None
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except Exception:
        return None
    if not isinstance(data, dict):
        return None
    val = data.get("backbone")
    return str(val) if val else None
|
|
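# Illustrative usage sketch, not part of the library surface: round-trip the
# backbone reference through a temporary directory. The model name below is a
# placeholder, not one this project necessarily uses.
def _demo_backbone_ref() -> None:
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        write_backbone_ref(tmp, "example-org/example-backbone")
        assert read_backbone_ref(tmp) == "example-org/example-backbone"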
|
|
def save_embedding_layer(model, best_dir: str | Path, *, include_lm_head: bool = True) -> bool:
    """Save the input-embedding matrix (and, optionally, an untied LM head).

    Returns True if a payload was written, False when the model exposes no
    usable embedding layer.
    """
    out = Path(best_dir)
    out.mkdir(parents=True, exist_ok=True)
    try:
        emb = model.get_input_embeddings()
    except Exception:
        emb = None
    if emb is None or not hasattr(emb, "weight"):
        return False
    payload: dict[str, torch.Tensor] = {
        "input_embeddings": emb.weight.detach().cpu()
    }
    if include_lm_head:
        head = getattr(model, "lm_head", None)
        if head is None and hasattr(model, "get_output_embeddings"):
            try:
                head = model.get_output_embeddings()
            except Exception:
                head = None
        if head is not None and hasattr(head, "weight"):
            # A tied head aliases the embedding storage; saving it again
            # would be redundant, so only persist a genuinely separate head.
            try:
                same_storage = head.weight.data_ptr() == emb.weight.data_ptr()
            except Exception:
                same_storage = False
            if not same_storage and head.weight.shape == emb.weight.shape:
                payload["lm_head"] = head.weight.detach().cpu()
    torch.save(payload, out / EMBEDDING_FILENAME)
    return True
|
|
def load_embedding_layer(model, best_dir: str | Path) -> bool:
    """Restore weights saved by ``save_embedding_layer`` into ``model``.

    Resizes the token embeddings when the stored vocabulary size differs
    from the model's. Returns True once the input embeddings are copied;
    the LM head is restored opportunistically when shapes match.
    """
    path = Path(best_dir) / EMBEDDING_FILENAME
    if not path.is_file():
        return False
    try:
        payload = torch.load(path, map_location="cpu")
    except Exception:
        return False
    if not isinstance(payload, dict):
        return False
    weight = payload.get("input_embeddings")
    if not torch.is_tensor(weight):
        return False
    try:
        cur_emb = model.get_input_embeddings()
    except Exception:
        cur_emb = None
    if cur_emb is None or not hasattr(cur_emb, "weight"):
        return False
    if cur_emb.weight.shape[0] != weight.shape[0]:
        try:
            model.resize_token_embeddings(weight.shape[0])
        except Exception:
            pass
        cur_emb = model.get_input_embeddings()
        if cur_emb.weight.shape[0] != weight.shape[0]:
            # Resize failed or is unsupported; copying would raise.
            return False
    cur_emb.weight.data.copy_(weight.to(cur_emb.weight.dtype))
    head_weight = payload.get("lm_head")
    if head_weight is not None:
        head = getattr(model, "lm_head", None)
        if head is None and hasattr(model, "get_output_embeddings"):
            try:
                head = model.get_output_embeddings()
            except Exception:
                head = None
        # Tied heads are never saved, so a missing entry is the common case.
        if head is not None and hasattr(head, "weight") and head.weight.shape == head_weight.shape:
            head.weight.data.copy_(head_weight.to(head.weight.dtype))
    return True
|
|
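# Illustrative sketch (assumptions flagged): a minimal stand-in exposing the
# small surface the save/load helpers rely on, namely get_input_embeddings()
# and an untied lm_head. Real callers would pass a transformers-style model.
def _demo_embedding_roundtrip() -> None:
    import tempfile

    import torch.nn as nn

    class _TinyModel(nn.Module):
        def __init__(self, vocab: int = 8, hidden: int = 4) -> None:
            super().__init__()
            self.embed = nn.Embedding(vocab, hidden)
            self.lm_head = nn.Linear(hidden, vocab, bias=False)  # untied head

        def get_input_embeddings(self) -> nn.Embedding:
            return self.embed

    model = _TinyModel()
    with tempfile.TemporaryDirectory() as tmp:
        assert save_embedding_layer(model, tmp)
        assert load_embedding_layer(model, tmp)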
|
|
def _get_llopa_specials_module(model):
    """Return the first (sub)module that exposes ``llopa_specials``, checking
    the model itself and common wrapper attributes; None if nothing matches."""
    for cand in (model, getattr(model, "base_model", None), getattr(model, "model", None), getattr(model, "transformer", None)):
        if cand is not None and hasattr(cand, "llopa_specials"):
            return cand
    return None
|
|
def save_llopa_specials(model, best_dir: str | Path) -> bool:
    """Persist the per-layer LLOPA special-token tensors to ``best_dir``.

    Returns False when the model does not expose ``llopa_specials`` or the
    list is empty.
    """
    out = Path(best_dir)
    out.mkdir(parents=True, exist_ok=True)
    core = _get_llopa_specials_module(model)
    if core is None:
        return False
    specials = getattr(core, "llopa_specials", None)
    if specials is None:
        return False
    try:
        tensors = [p.detach().cpu() for p in specials]
    except Exception:
        return False
    if not tensors:
        return False
    # Store light metadata alongside the tensors for sanity checks on load.
    payload = {
        "llopa_num_specials": int(getattr(core, "llopa_num_specials", 0) or 0),
        "llopa_num_layers": len(tensors),
        "tensors": tensors,
    }
    torch.save(payload, out / LLOPA_SPECIALS_FILENAME)
    return True
|
|
def load_llopa_specials(model, best_dir: str | Path) -> bool:
    """Restore LLOPA special tensors saved by ``save_llopa_specials``.

    Layer counts and shapes are validated before any copy, so a mismatched
    checkpoint cannot leave the model partially restored.
    """
    path = Path(best_dir) / LLOPA_SPECIALS_FILENAME
    if not path.is_file():
        return False
    try:
        payload = torch.load(path, map_location="cpu")
    except Exception:
        return False
    if not isinstance(payload, dict):
        return False
    tensors = payload.get("tensors")
    if not tensors:
        return False
    core = _get_llopa_specials_module(model)
    if core is None:
        return False
    specials = getattr(core, "llopa_specials", None)
    if specials is None:
        return False
    if len(specials) != len(tensors):
        return False
    if any(dst.shape != src.shape for dst, src in zip(specials, tensors)):
        return False
    try:
        for dst, src in zip(specials, tensors):
            dst.data.copy_(src.to(dst.dtype))
    except Exception:
        return False
    return True
|
|
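# Illustrative sketch (assumptions flagged): a stub module that holds
# ``llopa_specials`` as an nn.ParameterList, the attribute these helpers
# probe for. Layer count and widths are arbitrary placeholders.
def _demo_llopa_specials_roundtrip() -> None:
    import tempfile

    import torch.nn as nn

    class _Stub(nn.Module):
        def __init__(self, layers: int = 2, n_special: int = 3, hidden: int = 4) -> None:
            super().__init__()
            self.llopa_num_specials = n_special
            self.llopa_specials = nn.ParameterList(
                [nn.Parameter(torch.randn(n_special, hidden)) for _ in range(layers)]
            )

    model = _Stub()
    with tempfile.TemporaryDirectory() as tmp:
        assert save_llopa_specials(model, tmp)
        assert load_llopa_specials(model, tmp)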
|
|
def init_llopa_specials_with_mean(model, *, chunk_size: int = 8192) -> bool:
    """Initialize LLOPA special tensors from embedding-matrix statistics.

    Fits a multivariate normal to the input embeddings (mean plus an
    epsilon-scaled covariance) and samples one row per special token,
    falling back to copying the plain mean when the covariance is not
    positive definite. This mirrors the mean-resizing initialization that
    ``transformers`` applies to newly added tokens.
    """
    core = _get_llopa_specials_module(model)
    if core is None:
        return False
    specials = getattr(core, "llopa_specials", None)
    if specials is None or len(specials) == 0:
        return False

    # Locate the embedding layer: prefer the model-level accessor, then the
    # core module's accessor, then a bare ``embed_tokens`` attribute.
    emb = None
    try:
        emb = model.get_input_embeddings()
    except Exception:
        emb = None
    if emb is None and hasattr(core, "get_input_embeddings"):
        try:
            emb = core.get_input_embeddings()
        except Exception:
            emb = None
    if emb is None:
        emb = getattr(core, "embed_tokens", None)
    if emb is None or not hasattr(emb, "weight"):
        return False

    weight = emb.weight.detach()
    if weight.ndim != 2:
        return False
    old_num_tokens, hidden_size = weight.shape
    if old_num_tokens <= 0:
        return False

    # Accumulate the covariance in chunks to bound peak memory on large
    # vocabularies: cov = (1/N) * sum_i (x_i - mean)(x_i - mean)^T.
    emb_cpu = weight.float().cpu()
    mean = emb_cpu.mean(dim=0)
    cov = torch.zeros((hidden_size, hidden_size), dtype=torch.float32)
    step = max(1, int(chunk_size) if chunk_size else int(old_num_tokens))
    for start in range(0, int(old_num_tokens), step):
        chunk = emb_cpu[start : start + step]
        centered = chunk - mean
        cov += centered.t().matmul(centered)
    cov /= float(old_num_tokens)
    # Shrink the covariance by a tiny epsilon so samples stay very close to
    # the mean, then verify it is positive definite before sampling.
    epsilon = 1e-9
    cov_eps = cov * epsilon
    try:
        is_psd = torch.distributions.constraints.positive_definite.check(cov_eps).all()
    except Exception:
        is_psd = False

    with torch.no_grad():
        if bool(is_psd):
            dist = torch.distributions.multivariate_normal.MultivariateNormal(
                mean, covariance_matrix=cov_eps
            )
            for p in specials:
                if p.numel() == 0:
                    continue
                # One sample per special-token row.
                samples = dist.sample(sample_shape=(p.shape[0],))
                p.copy_(samples.to(dtype=p.dtype, device=p.device))
        else:
            # Degenerate covariance: copy the mean row into every special.
            mean_row = mean.unsqueeze(0)
            for p in specials:
                if p.numel() == 0:
                    continue
                p.copy_(mean_row.expand_as(p).to(dtype=p.dtype, device=p.device))
    return True
|
|
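# Illustrative sketch (assumptions flagged): mean-initializing the specials of
# a stub module; all sizes are placeholders. With the epsilon-scaled
# covariance, every sampled row should land very close to the embedding mean.
def _demo_mean_init() -> None:
    import torch.nn as nn

    class _Stub(nn.Module):
        def __init__(self, vocab: int = 32, hidden: int = 4) -> None:
            super().__init__()
            self.embed_tokens = nn.Embedding(vocab, hidden)
            self.llopa_specials = nn.ParameterList(
                [nn.Parameter(torch.empty(3, hidden)) for _ in range(2)]
            )

    model = _Stub()
    assert init_llopa_specials_with_mean(model, chunk_size=16)
    mean = model.embed_tokens.weight.detach().float().mean(dim=0)
    assert torch.allclose(model.llopa_specials[0].detach().mean(dim=0), mean, atol=0.1)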