| import torch |
| from abc import ABC |
| from torch import nn |
| from torch.nn import functional as F |
| from .diffusion import DiffusionLoss, DDIMSampler, DDPMSampler |
| from .transformer import TransformerModel |
| from .mamba import MambaModel |
| from .lstm import LstmModel |
| from .gatemlp import GMLPModel |
|
|
|
|
|
|
|
|
class ModelDiffusion(nn.Module, ABC):
    """Base class for diffusion-based sequence generators.

    Subclasses overwrite the shared ``config`` dict and attach a sequence
    backbone as ``self.model``.  The backbone output conditions a
    ``DiffusionLoss``, which owns both the training objective and the
    reverse (sampling) process.
    """
    # Shared hyper-parameter dict; subclasses replace this class attribute
    # before instantiation.  It is also handed to DiffusionLoss below.
    config = {}


    def __init__(self, sequence_length):
        super().__init__()
        # DiffusionLoss reads its hyper-parameters from the same shared dict.
        DiffusionLoss.config = self.config
        self.criteria = DiffusionLoss()
        # Without a post-projection, the backbone width must match the condition width.
        if self.config.get("post_d_model") is None:
            assert self.config["d_model"] == self.config["condition_dim"]
        self.sequence_length = sequence_length

        # Projects the raw condition vector (d_condition) to the backbone width (d_model).
        self.to_condition = nn.Linear(self.config["d_condition"], self.config["d_model"])
        # One learned embedding per permutation, initialized to a uniform constant.
        self.to_permutation_state = nn.Embedding(self.config["num_permutation"], self.config["d_model"])
        self.to_permutation_state.weight = \
            nn.Parameter(torch.ones_like(self.to_permutation_state.weight) / self.config["d_model"])


    def forward(self, output_shape=None, x_0=None, condition=None, permutation_state=None, **kwargs):
        """Compute the diffusion training loss, or sample when ``kwargs["sample"]`` is truthy.

        Args:
            output_shape: shape hint forwarded to the backbone model.
            x_0: clean target tokens for the diffusion loss (training only).
            condition: optional (batch, d_condition) conditioning matrix.
            permutation_state: permutation index tensor; ``None`` disables it
                during training, ``False`` disables it during sampling.
        """
        if condition is not None:
            assert len(condition.shape) == 2
            assert condition.shape[-1] == self.config["d_condition"]
            condition = self.to_condition(condition.to(self.device)[:, None, :])
        else:
            # Unconditional fallback: a zero scalar through the projection.
            # NOTE(review): this only type-checks when d_condition == 1 — confirm.
            condition = self.to_condition(torch.zeros(size=(1, 1, 1), device=self.device))

        if kwargs.get("sample"):
            # Sampling path: draw a random permutation embedding unless the
            # caller explicitly passes permutation_state=False.
            if permutation_state is not False:
                permutation_state = torch.randint(0, self.to_permutation_state.num_embeddings, (1,), device=self.device)
                permutation_state = self.to_permutation_state(permutation_state)[:, None, :]
            else:
                permutation_state = 0.
            return self.sample(x=None, condition=condition+permutation_state)
        else:
            # Training path: the caller supplies the permutation index (or None).
            if permutation_state is not None:
                permutation_state = self.to_permutation_state(permutation_state)[:, None, :]
            else:
                permutation_state = 0.

        # Backbone turns the conditioning sequence into per-token diffusion conditions.
        c = self.model(output_shape, condition+permutation_state)
        loss = self.criteria(x=x_0, c=c, **kwargs)
        return loss


    @torch.no_grad()
    def sample(self, x=None, condition=None):
        """Run the reverse diffusion starting from Gaussian noise (or a given ``x``)."""
        z = self.model([1, self.sequence_length, self.config["d_model"]], condition)
        if x is None:
            # NOTE(review): "model_dim" here vs "d_model" above — presumably the
            # per-token width of the generated tokens; confirm this config key exists.
            x = torch.randn((1, self.sequence_length, self.config["model_dim"]), device=z.device)
        x = self.criteria.sample(x, z)
        return x


    @property
    def device(self):
        """Device that this module's parameters live on."""
        return next(self.parameters()).device
|
|
|
|
class ModelMSELoss(nn.Module, ABC):
    """Base class for one-shot sequence generators trained with a NaN-aware MSE.

    Subclasses populate the shared ``config`` dict and attach a backbone as
    ``self.model``.  Training regresses the backbone output directly onto the
    target tokens; sampling is a single deterministic backbone pass.
    """

    # Shared hyper-parameter dict, overwritten by subclasses before instantiation.
    config = {}

    def __init__(self, sequence_length):
        super().__init__()
        # Without a post-projection, the backbone width must match the condition width.
        if self.config.get("post_d_model") is None:
            assert self.config["d_model"] == self.config["condition_dim"]
        self.sequence_length = sequence_length

        self.to_condition = nn.Linear(self.config["d_condition"], self.config["d_model"])
        self.to_permutation_state = nn.Embedding(self.config["num_permutation"], self.config["d_model"])
        # Start every permutation embedding at the same uniform value.
        uniform_weight = torch.ones_like(self.to_permutation_state.weight) / self.config["d_model"]
        self.to_permutation_state.weight = nn.Parameter(uniform_weight)

    def forward(self, output_shape=None, x_0=None, condition=None, permutation_state=None, **kwargs):
        """Return the masked-MSE loss, or a generated sequence when ``kwargs["sample"]`` is truthy.

        Args:
            output_shape: shape hint forwarded to the backbone model.
            x_0: target tokens; NaN entries are masked out of the loss.
            condition: optional (batch, d_condition) conditioning matrix.
            permutation_state: permutation index tensor; ``None`` disables it
                during training, ``False`` disables it during sampling.
        """
        if condition is None:
            # Unconditional fallback: a zero scalar through the projection.
            # NOTE(review): this only type-checks when d_condition == 1 — confirm.
            condition = self.to_condition(torch.zeros(size=(1, 1, 1), device=self.device))
        else:
            assert len(condition.shape) == 2
            assert condition.shape[-1] == self.config["d_condition"]
            condition = self.to_condition(condition.to(self.device)[:, None, :])

        if kwargs.get("sample"):
            # Sampling path: draw a random permutation embedding unless the
            # caller explicitly passes permutation_state=False.
            if permutation_state is False:
                perm = 0.
            else:
                choice = torch.randint(0, self.to_permutation_state.num_embeddings, (1,), device=self.device)
                perm = self.to_permutation_state(choice)[:, None, :]
            return self.sample(x=None, condition=condition + perm)

        # Training path: the caller supplies the permutation index (or None).
        perm = self.to_permutation_state(permutation_state)[:, None, :] if permutation_state is not None else 0.

        c = self.model(output_shape, condition + perm)
        assert c.shape[-1] == x_0.shape[-1], "d_model should be equal to dim_per_token"

        # Positions that are NaN in the target are excluded from the mean.
        nan_mask = torch.isnan(x_0)
        target = torch.nan_to_num(x_0, 0.)
        loss = F.mse_loss(c, target, reduction="none")
        loss[nan_mask] = torch.nan
        return loss.nanmean()

    @torch.no_grad()
    def sample(self, x=None, condition=None):
        """Generate a sequence with a single backbone forward pass (no diffusion)."""
        return self.model([1, self.sequence_length, self.config["d_model"]], condition)

    @property
    def device(self):
        """Device that this module's parameters live on."""
        return next(self.parameters()).device
|
|
|
|
|
|
|
|
class MambaDiffusion(ModelDiffusion):
    """Diffusion generator whose sequence backbone is a Mamba model."""

    def __init__(self, sequence_length, positional_embedding):
        super().__init__(sequence_length=sequence_length)
        # The backbone reads its hyper-parameters from the shared config dict.
        backbone_cls = MambaModel
        backbone_cls.config = self.config
        self.model = backbone_cls(positional_embedding=positional_embedding)
|
|
|
|
class TransformerDiffusion(ModelDiffusion):
    """Diffusion generator whose sequence backbone is a Transformer."""

    def __init__(self, sequence_length, positional_embedding):
        super().__init__(sequence_length=sequence_length)
        # The backbone reads its hyper-parameters from the shared config dict.
        backbone_cls = TransformerModel
        backbone_cls.config = self.config
        self.model = backbone_cls(positional_embedding=positional_embedding)
|
|
|
|
class LstmDiffusion(ModelDiffusion):
    """Diffusion generator whose sequence backbone is an LSTM."""

    def __init__(self, sequence_length, positional_embedding):
        super().__init__(sequence_length=sequence_length)
        # The backbone reads its hyper-parameters from the shared config dict.
        backbone_cls = LstmModel
        backbone_cls.config = self.config
        self.model = backbone_cls(positional_embedding=positional_embedding)
|
|
|
|
class GMLPDiffusion(ModelDiffusion):
    """Diffusion generator whose sequence backbone is a gated MLP."""

    def __init__(self, sequence_length, positional_embedding):
        super().__init__(sequence_length=sequence_length)
        # The backbone reads its hyper-parameters from the shared config dict.
        backbone_cls = GMLPModel
        backbone_cls.config = self.config
        self.model = backbone_cls(positional_embedding=positional_embedding)
|
|
|
|
|
|
|
|
class MambaMSELoss(ModelMSELoss):
    """MSE-trained generator whose sequence backbone is a Mamba model."""

    def __init__(self, sequence_length, positional_embedding):
        super().__init__(sequence_length=sequence_length)
        # The backbone reads its hyper-parameters from the shared config dict.
        backbone_cls = MambaModel
        backbone_cls.config = self.config
        self.model = backbone_cls(positional_embedding=positional_embedding)
|
|
|
|
|
|
|
|
class ClassConditionMambaDiffusion(MambaDiffusion):
    """Mamba diffusion conditioned on a class vector, gated to the sequence tail.

    The parent's ``to_condition`` linear module is deleted and shadowed by the
    bound method ``_to_condition``, which projects the condition and zeroes it
    everywhere except the last 8 sequence positions.
    """
    def __init__(self, sequence_length, positional_embedding, input_class=10):
        super().__init__(sequence_length, positional_embedding)
        # Maps the class input (size input_class, e.g. a one-hot) to d_condition.
        self.get_condition = nn.Sequential(
            nn.Linear(input_class, self.config["d_condition"]),
            nn.SiLU(),
        )
        # Re-created with default (normal) init, discarding the parent's
        # uniform initialization — presumably intentional; TODO confirm.
        self.to_permutation_state = nn.Embedding(self.config["num_permutation"], self.config["d_model"])

        self.to_condition_linear = nn.Linear(self.config["d_condition"], self.config["d_model"])
        # Binary position gate: the condition reaches only the last 8 tokens.
        to_condition_gate = torch.zeros(size=(1, sequence_length, 1))
        to_condition_gate[:, -8:, :] = 1.
        self.register_buffer("to_condition_gate", to_condition_gate)

        # Remove the registered submodule, then shadow the name with a bound
        # method so the parent's forward() calls _to_condition instead.
        del self.to_condition
        self.to_condition = self._to_condition


    def forward(self, output_shape=None, x_0=None, condition=None, **kwargs):
        """Embed the class vector, then defer to the parent diffusion forward.

        NOTE(review): unlike the parent, ``condition`` must not be None here.
        """
        condition = self.get_condition(condition.to(self.device))
        return super().forward(output_shape=output_shape, x_0=x_0, condition=condition, **kwargs)


    def _to_condition(self, x):
        # x is (batch, 1, d_condition) after the parent's [:, None, :] expansion;
        # the gate broadcast yields (batch, sequence_length, d_model), zero
        # everywhere except the final 8 positions.
        assert len(x.shape) == 3
        x = self.to_condition_linear(x)
        x = x * self.to_condition_gate
        return x
|
|
|
|
class ClassConditionMambaDiffusionFull(MambaDiffusion):
    """Class-conditioned Mamba diffusion whose condition spans the full sequence.

    Supports a ``pre_training`` mode (via forward kwargs) in which the class
    condition is replaced by zeros.  The parent's ``to_condition`` module is
    deleted; ``forward`` installs a bound-method replacement on every call.
    """
    def __init__(self, sequence_length, positional_embedding, input_class=10, init_noise_intensity=1e-4):
        # NOTE(review): init_noise_intensity is accepted but never used in this
        # class — confirm whether it is consumed elsewhere or dead.
        super().__init__(sequence_length, positional_embedding)
        # Maps the class input (size input_class) to a normalized d_condition vector.
        self.get_condition = nn.Sequential(
            nn.Linear(input_class, self.config["d_condition"]),
            nn.LayerNorm(self.config["d_condition"]),
        )
        # Re-created with default (normal) init, discarding the parent's
        # uniform initialization — presumably intentional; TODO confirm.
        self.to_permutation_state = nn.Embedding(self.config["num_permutation"], self.config["d_model"])

        self.to_condition_linear = nn.Linear(self.config["d_condition"], self.config["d_model"])
        # Expands the single condition token to sequence_length channels with
        # two same-padded 1-D convolutions (kernel 9, stride 1, padding 4).
        self.to_condition_conv = nn.Sequential(
            nn.Conv1d(1, sequence_length, 9, 1, 4),
            nn.GroupNorm(num_groups=1, num_channels=sequence_length),
            nn.Conv1d(sequence_length, sequence_length, 9, 1, 4),
        )

        # Remove the parent's registered submodule; forward() shadows the name.
        del self.to_condition


    def forward(self, output_shape=None, x_0=None, condition=None, **kwargs):
        """Defer to the parent forward, zeroing the condition during pre-training."""
        if kwargs.get("pre_training"):
            # Pre-training ignores the class condition entirely.
            self.to_condition = self._zero_condition
            condition = None
        else:
            self.to_condition = self._to_condition
            condition = self.get_condition(condition.to(self.device))
        return super().forward(output_shape=output_shape, x_0=x_0, condition=condition, **kwargs)


    def _to_condition(self, x):
        # x is (batch, 1, d_condition) after the parent's [:, None, :] expansion.
        assert len(x.shape) == 3
        x = self.to_condition_linear(x)
        x = self.to_condition_conv(x)
        return x


    def _zero_condition(self, x):
        # All-zero condition covering the whole sequence (pre-training path).
        return torch.zeros(size=(x.shape[0], self.sequence_length, self.config["d_model"]), device=x.device)