singhanshuman committed
Commit 619b31e · verified · 1 Parent(s): b3a58bf

Upload models/clip_vae.py with huggingface_hub

Files changed (1)
  1. models/clip_vae.py +116 -0
models/clip_vae.py ADDED
@@ -0,0 +1,116 @@
+"""
+CLIP-conditioned VAE.
+
+The encoder concatenates a projected CLIP text embedding with the CNN
+image features before the latent bottleneck. The decoder is identical
+to the baseline VAE.
+
+CLIP is loaded once and kept frozen; only the linear projector and the
+rest of the VAE are trained.
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from models.vae import ConvDecoder
+
+CLIP_DIM = 512  # openai/clip-vit-base-patch32 text embedding size
+PROJ_DIM = 64   # projected text feature size
+
+
+class TextProjector(nn.Module):
+    def __init__(self, in_dim: int = CLIP_DIM, out_dim: int = PROJ_DIM):
+        super().__init__()
+        self.fc = nn.Linear(in_dim, out_dim)
+
+    def forward(self, text_emb: torch.Tensor) -> torch.Tensor:
+        return F.relu(self.fc(text_emb), inplace=True)
+
+
+class ClipCondEncoder(nn.Module):
+    def __init__(
+        self,
+        z_dim: int = 4,
+        img_channels: int = 3,
+        proj_dim: int = PROJ_DIM,
+    ):
+        super().__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(img_channels, 32, 4, stride=2, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(32, 64, 4, stride=2, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(64, 128, 4, stride=2, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(128, 256, 4, stride=2, padding=1),
+            nn.ReLU(inplace=True),
+        )
+        self.text_proj = TextProjector(CLIP_DIM, proj_dim)
+        feat_dim = 256 * 4 * 4 + proj_dim
+        self.fc_mu = nn.Linear(feat_dim, z_dim)
+        self.fc_logvar = nn.Linear(feat_dim, z_dim)
+
+    def forward(self, x: torch.Tensor, text_emb: torch.Tensor):
+        img_feat = self.conv(x).flatten(1)
+        txt_feat = self.text_proj(text_emb)
+        h = torch.cat([img_feat, txt_feat], dim=1)
+        return self.fc_mu(h), self.fc_logvar(h)
+
+
+class ClipVAE(nn.Module):
+    def __init__(self, z_dim: int = 4, img_channels: int = 3):
+        super().__init__()
+        self.encoder = ClipCondEncoder(z_dim, img_channels)
+        self.decoder = ConvDecoder(z_dim, img_channels)
+        self.z_dim = z_dim
+
+    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
+        if self.training:
+            std = (0.5 * logvar).exp()
+            return mu + std * torch.randn_like(std)
+        return mu
+
+    def encode(self, x: torch.Tensor, text_emb: torch.Tensor):
+        mu, logvar = self.encoder(x, text_emb)
+        z = self.reparameterize(mu, logvar)
+        return z, mu, logvar
+
+    def decode(self, z: torch.Tensor) -> torch.Tensor:
+        return self.decoder(z)
+
+    def forward(self, x: torch.Tensor, text_emb: torch.Tensor):
+        z, mu, logvar = self.encode(x, text_emb)
+        recon = self.decode(z)
+        return recon, mu, logvar
+
+
+# ------------------------------------------------------------------
+# CLIP helper: load the frozen text encoder once and reuse it
+# ------------------------------------------------------------------
+
+_clip_model = None
+_clip_tokenizer = None
+
+
+def get_clip(device: torch.device):
+    global _clip_model, _clip_tokenizer
+    if _clip_model is None:
+        from transformers import CLIPModel, CLIPTokenizer
+        _clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+        _clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+        _clip_model.eval()
+        for p in _clip_model.parameters():
+            p.requires_grad_(False)
+    return _clip_model.to(device), _clip_tokenizer
+
+
+@torch.no_grad()
+def encode_text(texts: list, device: torch.device) -> torch.Tensor:
+    model, tokenizer = get_clip(device)
+    tokens = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(device)  # truncate to CLIP's 77-token context
+    out = model.get_text_features(**tokens)
+    # get_text_features returns a plain tensor in newer transformers
+    emb = out if isinstance(out, torch.Tensor) else out.pooler_output
+    emb = emb / emb.norm(dim=-1, keepdim=True)  # L2-normalise
+    return emb.float()
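
A minimal smoke-test sketch of how the uploaded module would be used (not part of the commit): the 64x64 input resolution is an assumption implied by the 256 * 4 * 4 flattened feature size, the prompts are made up, and the reconstruction + KL objective is the standard VAE loss rather than anything defined in this file.

import torch
import torch.nn.functional as F

from models.clip_vae import ClipVAE, encode_text

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ClipVAE(z_dim=4).to(device)

# Dummy batch: two RGB images at the assumed 64x64 resolution, plus two captions.
images = torch.rand(2, 3, 64, 64, device=device)
text_emb = encode_text(["a red square", "a blue circle"], device)  # shape (2, 512)

recon, mu, logvar = model(images, text_emb)  # mu, logvar: shape (2, 4)

# Standard VAE objective (an assumption, not specified in this file):
# per-sample reconstruction error plus KL divergence to the unit Gaussian prior.
recon_loss = F.mse_loss(recon, images, reduction="sum") / images.size(0)
kl = -0.5 * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))
loss = recon_loss + kl

Because CLIP is frozen and encode_text runs under torch.no_grad(), only the ClipVAE parameters (including the TextProjector) receive gradients when loss.backward() is called.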