| from __future__ import annotations |
|
|
| """sedd_wrapper.py |
| ========================================= |
| This module provides a minimal HuggingFace-compatible wrapper around the |
| `SEDD` architecture that is implemented in :pyfile:`model/transformer.py`. |
| |
| The wrapper closely follows the design used in the Aero implementation that |
| lives in this code-base (see :pyfile:`configuration_aero.py` and |
| :pyfile:`modeling_aero.py`). Concretely we expose three public objects: |
| |
* ``SEDDConfig`` — a :class:`transformers.PretrainedConfig` subclass that
  stores the hyper-parameters needed to instantiate a ``SEDD`` model.
* ``SEDDModel`` — a :class:`transformers.PreTrainedModel` subclass that
  internally contains an instance of the original ``SEDD`` network and maps
  from ``input_ids`` + ``sigma`` to the vocabulary logits.
* ``SEDDOutput`` — a thin :class:`transformers.modeling_outputs.ModelOutput`
  dataclass that mirrors the usual "logits / loss" structure.
| |
| With this wrapper a trained model checkpoint can be pushed to / loaded from |
| 🤗 Hub via ``SEDDModel.push_to_hub`` / ``SEDDModel.from_pretrained`` the same |
| way as any other ``transformers`` model. |
| """ |
|
|
| from dataclasses import dataclass |
| from typing import Optional, Tuple, List, Dict, Any, Union |
|
|
| import torch |
| from torch import nn |
| from transformers.configuration_utils import PretrainedConfig |
| from transformers.modeling_outputs import ModelOutput |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.utils import logging |
|
|
| |
| from model.transformer import SEDD as _OrigSEDD |
|
|
| try: |
| from omegaconf import OmegaConf |
| except ImportError: |
| OmegaConf = None |
|
|
| logger = logging.get_logger(__name__) |
|
|
| |
| |
| |
|
|
|
|
class SEDDConfig(PretrainedConfig):
    """Configuration class for the SEDD architecture.

    The defaults roughly reproduce the "small" configuration shipped in
    ``configs/model/small.yaml``.  Keys present in the original Hydra config
    but not required for instantiation (e.g. training hyper-parameters) are
    intentionally left out here; users may still persist them as *extra*
    fields in the serialized JSON if they wish.
    """

    model_type: str = "sedd"

    def __init__(
        self,
        *,
        tokens: int = 50257,
        graph_type: str = "absorb",
        model_hidden_size: int = 768,
        model_cond_dim: int = 128,
        model_length: int = 1024,
        model_n_blocks: int = 12,
        model_n_heads: int = 12,
        model_scale_by_sigma: bool = True,
        model_dropout: float = 0.10,
        tie_word_embeddings: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

        # Vocabulary size and diffusion-graph type (``graph.type`` in Hydra).
        self.tokens = tokens
        self.graph_type = graph_type

        # Flattened ``model.*`` hyper-parameters from the Hydra config.
        self.model_hidden_size = model_hidden_size
        self.model_cond_dim = model_cond_dim
        self.model_length = model_length
        self.model_n_blocks = model_n_blocks
        self.model_n_heads = model_n_heads
        self.model_scale_by_sigma = model_scale_by_sigma
        self.model_dropout = model_dropout

    def to_hydra(self):
        """Rebuild the nested OmegaConf structure that the reference
        ``SEDD`` implementation expects from this *flat* config.

        Raises
        ------
        RuntimeError
            If ``omegaconf`` is not installed.
        """
        if OmegaConf is None:
            raise RuntimeError("`omegaconf` is required to build a Hydra config")

        graph_section = {"type": self.graph_type}
        model_section = {
            "hidden_size": self.model_hidden_size,
            "cond_dim": self.model_cond_dim,
            "length": self.model_length,
            "n_blocks": self.model_n_blocks,
            "n_heads": self.model_n_heads,
            "scale_by_sigma": self.model_scale_by_sigma,
            "dropout": self.model_dropout,
        }
        return OmegaConf.create(
            {
                "tokens": self.tokens,
                "graph": graph_section,
                "model": model_section,
            }
        )
|
|
| |
| |
| |
|
|
|
|
@dataclass
class SEDDOutput(ModelOutput):
    """Standard output for :class:`SEDDModel`.

    Attributes
    ----------
    loss:
        *Optional* scalar returned when ``labels`` are provided.
    logits:
        The raw vocabulary logits computed by the model of shape
        ``(batch_size, sequence_length, vocab_size)``.
    """

    # Both fields use ``Optional[...]`` for consistency (the original mixed
    # ``Optional`` and ``X | None``); ``ModelOutput`` fields need defaults.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
|
|
| |
| |
| |
|
|
|
|
class SEDDModel(PreTrainedModel):
    """HuggingFace *Transformers* wrapper around the original ``SEDD`` model."""

    config_class = SEDDConfig
    base_model_prefix = "score_model"
    _no_split_modules: List[str] = [
        "DDiTBlock",
    ]

    def __init__(self, config: SEDDConfig):
        """Build the wrapped ``SEDD`` network from *config*.

        Raises
        ------
        RuntimeError
            If ``omegaconf`` is not installed (the reference ``SEDD``
            implementation consumes a nested OmegaConf object).
        """
        super().__init__(config)

        if OmegaConf is None:
            raise RuntimeError("`omegaconf` is required to instantiate SEDD")

        # The reference implementation expects the nested Hydra layout, so
        # convert the flat HF config first.
        hydra_cfg = config.to_hydra()
        self.score_model = _OrigSEDD(hydra_cfg)

        # Standard HF weight-init / bookkeeping hook.
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        sigma: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Any,
    ) -> Union[SEDDOutput, Tuple]:
        """Run a forward pass.

        Parameters
        ----------
        input_ids:
            Token indices of shape ``(batch_size, seq_len)``.
        sigma:
            Noise level ("time-step") of shape ``(batch_size,)``.
        labels:
            *Optional* label tensor used to compute a cross-entropy training
            loss.  If provided the returned :class:`SEDDOutput` will contain a
            ``loss`` field.
        return_dict:
            Whether to return a :class:`SEDDOutput` instead of a plain tuple.
            Defaults to ``self.config.return_dict`` (HF convention; the
            original only honored the config flag).
        """
        logits = self.score_model(indices=input_ids, sigma=sigma)

        loss: Optional[torch.Tensor] = None
        if labels is not None:
            # Flatten (batch, seq, vocab) -> (batch*seq, vocab) for CE.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        if return_dict is None:
            return_dict = self.config.return_dict
        if not return_dict:
            output: Tuple[Any, ...] = (logits,)
            return ((loss,) + output) if loss is not None else output

        return SEDDOutput(loss=loss, logits=logits)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        *model_args: Any,
        **kwargs: Any,
    ) -> "SEDDModel":
        """Overrides the default method to allow loading legacy SEDD checkpoints
        whose weights are saved via ``torch.save({'model': state_dict, ...})``.
        """
        try:
            # First try the standard ``transformers`` loading path.
            return super().from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        except (EnvironmentError, RuntimeError) as e:
            logger.info(
                "Falling back to legacy SEDD checkpoint format because standard "
                "loading raised: %s", e,
            )

        return cls._from_legacy_checkpoint(
            pretrained_model_name_or_path, *model_args, **kwargs
        )

    @classmethod
    def _from_legacy_checkpoint(
        cls,
        pretrained_model_name_or_path: str,
        *model_args: Any,
        **kwargs: Any,
    ) -> "SEDDModel":
        """Instantiate the model and load weights from a legacy
        ``checkpoints-meta/checkpoint.pth`` file.

        Raises
        ------
        FileNotFoundError
            If no legacy checkpoint file exists under
            *pretrained_model_name_or_path*.
        """
        import os

        config = kwargs.pop("config", None) or SEDDConfig.from_pretrained(
            pretrained_model_name_or_path
        )
        # NOTE(review): any leftover ``from_pretrained``-only kwargs (e.g.
        # ``torch_dtype``) reaching ``__init__`` here would raise TypeError;
        # the legacy path expects constructor kwargs only.
        model = cls(config, *model_args, **kwargs)

        checkpoint_path = os.path.join(
            pretrained_model_name_or_path, "checkpoints-meta", "checkpoint.pth"
        )
        if not os.path.isfile(checkpoint_path):
            raise FileNotFoundError(
                "Could not find legacy SEDD checkpoint at " f"{checkpoint_path}"
            )

        # SECURITY: ``torch.load`` unpickles arbitrary objects — only load
        # checkpoints from trusted sources.  ``weights_only=False`` is pinned
        # explicitly because legacy checkpoints may contain non-tensor state
        # and torch >= 2.6 flipped the default to ``True``.
        ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        state_dict = ckpt.get("model", ckpt)
        # Strip a possible DistributedDataParallel "module." prefix.  Use
        # ``removeprefix`` (not ``replace``) so keys containing "module."
        # mid-string are left intact.
        state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}

        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing:
            logger.warning("Missing keys when loading SEDD weights: %s", missing)
        if unexpected:
            logger.warning(
                "Unexpected keys when loading SEDD weights: %s", unexpected
            )
        return model
|
|
| |
| |
| |
|
|
| __all__ = [ |
| "SEDDConfig", |
| "SEDDModel", |
| "SEDDOutput", |
| ] |
|
|