| """ |
| Text encoder for conditioning the diffusion model |
| Uses a simple transformer architecture |
| """ |
|
|
| import math |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| class PositionalEncoding(nn.Module): |
| """Sinusoidal positional encoding""" |
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()

        # Standard sinusoidal table: PE(pos, 2i) = sin(pos / 10000^(2i/d)),
        # PE(pos, 2i+1) = cos(pos / 10000^(2i/d)). Assumes d_model is even.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model) so it broadcasts over the batch

        # Registered as a buffer: moves with the module across devices, but is not trained
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, seq_len, d_model); add encodings for the first seq_len positions
        return x + self.pe[:, :x.size(1)]


| class TransformerEncoderLayer(nn.Module): |
| """Single transformer encoder layer""" |
| def __init__( |
| self, |
| d_model: int, |
| num_heads: int, |
| dim_feedforward: int = 2048, |
| dropout: float = 0.1, |
| ): |
        super().__init__()

        # Self-attention; batch_first so inputs are (B, seq_len, d_model)
        self.self_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )

        # Position-wise feed-forward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
| |
    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Self-attention block with residual connection (post-norm)
        x2, _ = self.self_attn(x, x, x, key_padding_mask=mask)
        x = x + self.dropout1(x2)
        x = self.norm1(x)

        # Feed-forward block with GELU activation
        x2 = self.linear2(self.dropout(F.gelu(self.linear1(x))))
        x = x + self.dropout2(x2)
        x = self.norm2(x)

        return x


class TextEncoder(nn.Module):
    """
    Transformer-based text encoder for conditioning.
    Similar to the CLIP text encoder, but simplified.
    """
    def __init__(
        self,
        vocab_size: int = 49408,
        max_length: int = 77,
        embed_dim: int = 512,
        num_layers: int = 6,
        num_heads: int = 8,
        dropout: float = 0.1,
        pad_token_id: int = 2,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.max_length = max_length
        self.embed_dim = embed_dim
        # Token ID treated as padding when building the attention mask (tokenizer-dependent)
        self.pad_token_id = pad_token_id

        # Token embedding
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)

        # Sinusoidal positional encoding
        self.pos_encoding = PositionalEncoding(embed_dim, max_length)

        # Stack of transformer encoder layers
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(
                d_model=embed_dim,
                num_heads=num_heads,
                dim_feedforward=embed_dim * 4,
                dropout=dropout,
            )
            for _ in range(num_layers)
        ])

        # Final layer norm
        self.final_norm = nn.LayerNorm(embed_dim)

        self._init_weights()
| |
| def _init_weights(self): |
| """Initialize weights""" |
| nn.init.normal_(self.token_embedding.weight, std=0.02) |
| |
    def forward(
        self,
        tokens: torch.Tensor,
        return_pooled: bool = False,
    ) -> torch.Tensor:
        """
        Forward pass.
        Args:
            tokens: Token IDs (B, seq_len)
            return_pooled: Whether to return the pooled output (first token)
        Returns:
            Text embeddings (B, seq_len, embed_dim), or (B, embed_dim) if pooled
        """
        # Embed tokens and add positional information
        x = self.token_embedding(tokens)
        x = self.pos_encoding(x)

        # True at padding positions; these keys are ignored by self-attention
        padding_mask = (tokens == self.pad_token_id)

        for layer in self.layers:
            x = layer(x, mask=padding_mask)

        x = self.final_norm(x)

        if return_pooled:
            # Pool by taking the first token's embedding
            # (CLIP instead pools at the EOT token position)
            return x[:, 0]

        return x


class FrozenCLIPTextEncoder(nn.Module):
    """
    Wrapper around a pretrained CLIP text encoder (if available).
    Falls back to the custom TextEncoder if CLIP is not available.
    """
    def __init__(
        self,
        embed_dim: int = 512,
        max_length: int = 77,
    ):
        super().__init__()

        self.embed_dim = embed_dim
        self.max_length = max_length

        try:
            from transformers import CLIPTextModel, CLIPTokenizer

            self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
            self.model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

            # Freeze CLIP's weights and keep it in eval mode
            self.model.eval()
            for param in self.model.parameters():
                param.requires_grad = False

            # Project CLIP's hidden size to the target embedding dim if they differ
            clip_dim = self.model.config.hidden_size
            if clip_dim != embed_dim:
                self.proj = nn.Linear(clip_dim, embed_dim)
            else:
                self.proj = nn.Identity()

            self.use_clip = True
            print("Using pretrained CLIP text encoder")

        except Exception as e:
            print(f"CLIP not available ({e}), using custom text encoder")
            self.model = TextEncoder(
                embed_dim=embed_dim,
                max_length=max_length,
            )
            self.proj = nn.Identity()
            self.use_clip = False
| |
    def forward(
        self,
        tokens: torch.Tensor,
        text: Optional[list] = None,
    ) -> torch.Tensor:
        """
        Forward pass.
        Args:
            tokens: Pre-tokenized token IDs (B, seq_len) - used when no raw text is given
            text: List of text strings - tokenized with CLIP's tokenizer if available
        Returns:
            Text embeddings (B, seq_len, embed_dim)
        """
        if self.use_clip:
            if text is not None:
                # Tokenize raw strings with the CLIP tokenizer
                inputs = self.tokenizer(
                    text,
                    padding="max_length",
                    max_length=self.max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                inputs = {k: v.to(next(self.model.parameters()).device) for k, v in inputs.items()}
            else:
                # Fall back to caller-provided token IDs
                inputs = {"input_ids": tokens}

            # CLIP is frozen, so no gradients are needed
            with torch.no_grad():
                outputs = self.model(**inputs)
            hidden_states = outputs.last_hidden_state

            return self.proj(hidden_states)

        # Custom TextEncoder path (returns a tensor directly)
        return self.proj(self.model(tokens))


| def create_text_encoder(config, use_clip: bool = True): |
| """Create text encoder from config (default: pretrained CLIP)""" |
| if use_clip: |
| return FrozenCLIPTextEncoder( |
| embed_dim=config.text_embed_dim, |
| max_length=config.max_text_length, |
| ) |
| else: |
| return TextEncoder( |
| vocab_size=config.vocab_size, |
| max_length=config.max_text_length, |
| embed_dim=config.text_embed_dim, |
| ) |
|
|
|
|
| if __name__ == "__main__": |
| |
| encoder = TextEncoder( |
| vocab_size=49408, |
| max_length=77, |
| embed_dim=512, |
| num_layers=6, |
| num_heads=8, |
| ) |
| |
| |
| tokens = torch.randint(0, 49408, (2, 77)) |
| |
| |
| output = encoder(tokens) |
| print(f"Input shape: {tokens.shape}") |
| print(f"Output shape: {output.shape}") |
| print(f"Parameters: {sum(p.numel() for p in encoder.parameters()):,}") |
|
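    # Sketch: build an encoder through create_text_encoder without CLIP.
    # SimpleNamespace stands in for the project's real config object here;
    # the attribute names (vocab_size, max_text_length, text_embed_dim) are
    # the ones create_text_encoder reads.
    from types import SimpleNamespace

    config = SimpleNamespace(vocab_size=49408, max_text_length=77, text_embed_dim=512)
    fallback_encoder = create_text_encoder(config, use_clip=False)
    print(f"Fallback encoder output shape: {fallback_encoder(tokens).shape}")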
|