| import streamlit as st |
| import torch |
| import matplotlib.pyplot as plt |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import BertTokenizer, BertModel |
| import io |
|
|
|
|
class ConditionalAugmentation(nn.Module):
    """Conditioning Augmentation: map a text embedding to a Gaussian and sample it.

    A single linear layer predicts both the mean and the log-variance of the
    conditioning distribution; the forward pass draws a sample with the
    reparameterization trick (mu + eps * std), keeping sampling differentiable.
    """

    def __init__(self, text_dim, projected_dim):
        super(ConditionalAugmentation, self).__init__()
        # One projection emits mu and logvar together; they are split in forward().
        self.proj = nn.Linear(text_dim, projected_dim * 2)

    def forward(self, text_embedding):
        # Split the joint projection into mean and log-variance halves.
        mu, logvar = self.proj(text_embedding).chunk(2, dim=1)
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma
|
|
|
|
class Stage1Generator(nn.Module):
    """Stage-I generator: text embedding + noise -> coarse 64x64 RGB image.

    The text embedding goes through Conditioning Augmentation, is concatenated
    with the noise vector, projected to a 128x8x8 feature map, and upsampled
    three times (8 -> 16 -> 32 -> 64) with transposed convolutions. The output
    is tanh-scaled to [-1, 1].
    """

    def __init__(self, text_embedding_dim, noise_dim, img_size):
        super(Stage1Generator, self).__init__()
        # Fix: input width was hard-coded to 768 + noise_dim; use the
        # text_embedding_dim parameter (identical for the default 768).
        self.fc1 = nn.Linear(text_embedding_dim + noise_dim, 128 * 8 * 8)
        # Unused in forward(), but deliberately kept: removing it would change
        # the state_dict keys and break loading of existing trained checkpoints.
        self.reduced_embeddings = nn.Linear(text_embedding_dim, 128)
        self.bn1 = nn.BatchNorm1d(128 * 8 * 8)
        self.relu = nn.ReLU(inplace=True)
        self.upsample1 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.upsample2 = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(32)
        self.upsample3 = nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1)
        self.tanh = nn.Tanh()
        # Fix: was ConditionalAugmentation(768, 768); now follows the
        # configured embedding size (unchanged for the default 768).
        self.augment = ConditionalAugmentation(text_embedding_dim, text_embedding_dim)
        # NOTE(review): the output resolution is fixed at 64x64 by the layer
        # stack; img_size is stored but never used to size layers — confirm
        # callers always pass 64.
        self.img_size = img_size

    def forward(self, text_embedding, noise):
        """Generate images from text embeddings and noise.

        Args:
            text_embedding: (batch, text_embedding_dim) caption embeddings.
            noise: (batch, noise_dim) Gaussian noise vectors.

        Returns:
            (batch, 3, 64, 64) tensor with values in [-1, 1].
        """
        x = self.augment(text_embedding)
        x = torch.cat((x, noise), dim=1)
        # BatchNorm1d here requires batch size > 1 in training mode.
        x = self.relu(self.bn1(self.fc1(x)))
        x = x.view(-1, 128, 8, 8)
        x = self.relu(self.bn2(self.upsample1(x)))
        x = self.relu(self.bn3(self.upsample2(x)))
        x = self.tanh(self.upsample3(x))
        return x
|
|
|
|
# Stage-I generator: 768-dim BERT embeddings + 100-dim noise -> 64x64 images.
stage1_generator = Stage1Generator(text_embedding_dim=768, noise_dim=100, img_size=64)
|
|
|
|
|
|
class Stage2Generator(nn.Module):
    """Stage-II generator: refine a Stage-I image conditioned on the text.

    The augmented text embedding is concatenated with the flattened Stage-I
    image, projected to a 128x16x16 feature map, and upsampled three times
    (16 -> 32 -> 64 -> 128) with transposed convolutions. The output is
    tanh-scaled to [-1, 1].
    """

    def __init__(self, text_embedding_dim, img_size):
        super(Stage2Generator, self).__init__()
        # img_size is the Stage-I input resolution (e.g. 64); the flattened
        # image contributes 3 * img_size * img_size input features.
        self.fc1 = nn.Linear(text_embedding_dim + 3 * img_size * img_size, 128 * 16 * 16)
        self.bn1 = nn.BatchNorm1d(128 * 16 * 16)
        self.relu = nn.ReLU(inplace=True)
        self.upsample1 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.upsample2 = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(32)
        self.upsample3 = nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1)
        self.tanh = nn.Tanh()
        # Fix: was hard-coded ConditionalAugmentation(768, 768); now follows
        # text_embedding_dim (identical for the default 768).
        self.augment = ConditionalAugmentation(text_embedding_dim, text_embedding_dim)
        self.img_size = img_size

    def forward(self, text_embedding, stage1_img):
        """Refine a Stage-I image.

        Args:
            text_embedding: (batch, text_embedding_dim) caption embeddings.
            stage1_img: (batch, 3, img_size, img_size) Stage-I output.

        Returns:
            (batch, 3, 128, 128) tensor with values in [-1, 1].
        """
        stage1_img_flat = stage1_img.view(stage1_img.size(0), -1)
        text_embedding = self.augment(text_embedding)
        x = torch.cat((text_embedding, stage1_img_flat), dim=1)
        # BatchNorm1d here requires batch size > 1 in training mode.
        x = self.relu(self.bn1(self.fc1(x)))
        x = x.view(-1, 128, 16, 16)
        x = self.relu(self.bn2(self.upsample1(x)))
        x = self.relu(self.bn3(self.upsample2(x)))
        x = self.tanh(self.upsample3(x))
        return x
|
|
|
|
# Stage-II generator: refines the Stage-I output conditioned on the same text embedding.
stage2_generator = Stage2Generator(text_embedding_dim=768, img_size=64)

# Inference-only app: put both generators in eval mode (freezes batch-norm stats).
stage1_generator.eval()
stage2_generator.eval()
device = 'cpu'
# NOTE(review): torch.load unpickles arbitrary objects — only load these
# checkpoint files from a trusted source (consider weights_only=True).
stage1_generator.load_state_dict(torch.load('stage1Generator_weights.pth',map_location=device))
stage2_generator.load_state_dict(torch.load('stage2Generator_weights_UPDATED.pth',map_location=device))
print("Models loaded successfully")

# BERT encoder used by Tokenize() to turn captions into 768-dim embeddings.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased').eval()

print("bert loaded")
|
|
|
|
def Tokenize(sentence):
    """Encode text into mean-pooled BERT embeddings.

    Args:
        sentence: a single string (or a list of strings) to embed.

    Returns:
        A (batch, 768) tensor: the mean of BERT's last hidden states over the
        token dimension — (1, 768) for a single sentence.
    """
    encoded_input = tokenizer(sentence, return_tensors='pt', padding=True, truncation=True, max_length=64)
    with torch.no_grad():
        model_output = bert_model(**encoded_input)
    # Fix: the original squeeze()/unsqueeze(0) round-trip returned a wrong
    # (1, batch, 768) shape when given a list of sentences. Mean-pooling
    # already yields (batch, 768), which is (1, 768) for a single sentence —
    # identical to the original behavior in that case.
    return model_output.last_hidden_state.mean(dim=1)
| |
|
|
|
|
def generate_images(text_embeddings):
    """Run both generator stages on the given text embeddings.

    Samples one 100-dim noise vector, produces the coarse Stage-I image, then
    refines it with Stage-II. Returns the refined image with the batch
    dimension squeezed out.
    """
    z = torch.randn(1, 100)
    with torch.no_grad():
        coarse = stage1_generator(text_embeddings, z)
        refined = stage2_generator(text_embeddings, coarse)
        print(refined.shape)
    return refined.squeeze()
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| |
| |
| |
|
|
| |
| |
| |