| """ | |
| CLIP-conditioned VAE. | |
| The encoder concatenates a projected CLIP text embedding with the CNN | |
| image features before the latent bottleneck. The decoder is identical | |
| to the baseline VAE. | |
| CLIP is loaded once and kept frozen; only the linear projector and the | |
| rest of the VAE are trained. | |
| """ | |
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.vae import ConvDecoder

CLIP_DIM = 512  # openai/clip-vit-base-patch32 text embedding size
PROJ_DIM = 64   # projected text feature size


class TextProjector(nn.Module):
    """Project the CLIP text embedding down to a small conditioning vector."""

    def __init__(self, in_dim: int = CLIP_DIM, out_dim: int = PROJ_DIM):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)

    def forward(self, text_emb: torch.Tensor) -> torch.Tensor:
        return F.relu(self.fc(text_emb), inplace=True)


class ClipCondEncoder(nn.Module):
    """CNN image encoder whose flattened features are concatenated with the projected text embedding."""

    def __init__(
        self,
        z_dim: int = 4,
        img_channels: int = 3,
        proj_dim: int = PROJ_DIM,
    ):
        super().__init__()
        # Four stride-2 convolutions halve the spatial size each time, so
        # 64x64 inputs end up as a 256-channel 4x4 feature map.
        self.conv = nn.Sequential(
            nn.Conv2d(img_channels, 32, 4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.ReLU(inplace=True),
        )
        self.text_proj = TextProjector(CLIP_DIM, proj_dim)
        feat_dim = 256 * 4 * 4 + proj_dim  # flattened image features + projected text
        self.fc_mu = nn.Linear(feat_dim, z_dim)
        self.fc_logvar = nn.Linear(feat_dim, z_dim)

    def forward(self, x: torch.Tensor, text_emb: torch.Tensor):
        img_feat = self.conv(x).flatten(1)
        txt_feat = self.text_proj(text_emb)
        h = torch.cat([img_feat, txt_feat], dim=1)
        return self.fc_mu(h), self.fc_logvar(h)


class ClipVAE(nn.Module):
    """Text-conditioned encoder plus the baseline convolutional decoder."""

    def __init__(self, z_dim: int = 4, img_channels: int = 3):
        super().__init__()
        self.encoder = ClipCondEncoder(z_dim, img_channels)
        self.decoder = ConvDecoder(z_dim, img_channels)
        self.z_dim = z_dim

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        # Sample z ~ N(mu, sigma^2) via the reparameterization trick during
        # training; fall back to the deterministic mean at eval time.
        if self.training:
            std = (0.5 * logvar).exp()
            return mu + std * torch.randn_like(std)
        return mu

    def encode(self, x: torch.Tensor, text_emb: torch.Tensor):
        mu, logvar = self.encoder(x, text_emb)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return self.decoder(z)

    def forward(self, x: torch.Tensor, text_emb: torch.Tensor):
        z, mu, logvar = self.encode(x, text_emb)
        recon = self.decode(z)
        return recon, mu, logvar
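

# The training objective is not defined in this module. For reference, a
# minimal sketch of a standard VAE loss (MSE reconstruction plus the
# closed-form KL to N(0, I)); the `beta` weighting is an assumed
# hyper-parameter, not necessarily what this repo uses.
def vae_loss(
    recon: torch.Tensor,
    x: torch.Tensor,
    mu: torch.Tensor,
    logvar: torch.Tensor,
    beta: float = 1.0,
) -> torch.Tensor:
    recon_loss = F.mse_loss(recon, x, reduction="sum") / x.size(0)
    # KL(q(z|x) || N(0, I)) in closed form, averaged over the batch.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
    return recon_loss + beta * kl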


# ------------------------------------------------------------------
# CLIP helper — load once, freeze, cache text embeddings
# ------------------------------------------------------------------
_clip_model = None
_clip_tokenizer = None


def get_clip(device: torch.device):
    global _clip_model, _clip_tokenizer
    if _clip_model is None:
        from transformers import CLIPModel, CLIPTokenizer
        _clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
        _clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        _clip_model.eval()
        for p in _clip_model.parameters():
            p.requires_grad_(False)
    return _clip_model.to(device), _clip_tokenizer


def encode_text(texts: list, device: torch.device) -> torch.Tensor:
    model, tokenizer = get_clip(device)
    tokens = tokenizer(texts, padding=True, return_tensors="pt").to(device)
    out = model.get_text_features(**tokens)
    # get_text_features returns a plain tensor in newer transformers
    emb = out if isinstance(out, torch.Tensor) else out.pooler_output
    emb = emb / emb.norm(dim=-1, keepdim=True)  # L2-normalise
    return emb.float()
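

# Minimal usage sketch. Assumptions: 64x64 RGB inputs (implied by the
# hard-coded 256 * 4 * 4 feature size in ClipCondEncoder); the captions and
# the random batch below are placeholders, not data from this repo.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae = ClipVAE(z_dim=4, img_channels=3).to(device)

    captions = ["a red square", "a blue circle"]
    text_emb = encode_text(captions, device)            # (2, CLIP_DIM), L2-normalised
    images = torch.randn(2, 3, 64, 64, device=device)   # dummy image batch

    recon, mu, logvar = vae(images, text_emb)
    print(recon.shape, mu.shape, logvar.shape)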