Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from tqdm import tqdm | |
| from models.transformer_model import GraphTransformer | |
| from diffusion.noise_schedule import DiscreteUniformTransition, PredefinedNoiseScheduleDiscrete | |
| from diffusion import diffusion_utils | |
| import utils | |
| import networkx as nx | |
| from sentence_transformers import SentenceTransformer | |
| import pytorch_lightning as pl | |
| from transformers import BertTokenizer, BertForSequenceClassification | |
class LGGMText2Graph_Demo(pl.LightningModule):
    """Text-conditioned graph generator built on a discrete diffusion model.

    Only sampling is implemented here (no training steps are defined in this
    class). Typical use: construct, load weights, call one of the
    ``init_prompt_encoder_*`` methods, then the matching ``generate_*`` method.
    """

    def __init__(self, cfg, input_dims, output_dims, cond_dims, cond_emb,
                 nodes_dist, node_types, edge_types, extra_features, data_loaders):
        """Build the denoising network, noise schedule and transition model.

        Args:
            cfg: config object; reads ``cfg.model.diffusion_steps``,
                ``n_layers``, ``hidden_mlp_dims``, ``hidden_dims`` and
                ``diffusion_noise_schedule`` here, and ``transition`` when
                sampling.
            input_dims: dict with keys ``'X'``, ``'E'``, ``'y'``; forwarded to
                ``GraphTransformer``.
            output_dims: dict with keys ``'X'``, ``'E'``, ``'y'``; number of
                node / edge / graph-level output classes.
            cond_dims: conditioning-embedding width, forwarded to
                ``GraphTransformer``.
            cond_emb, node_types, edge_types, data_loaders: accepted for
                interface compatibility; not used by this class.
            nodes_dist: object exposing ``sample_n(batch_size, device)``; used
                when ``sample_batch`` is given no node count.
            extra_features: callable mapping the noisy-data dict to an object
                with ``.X``, ``.E``, ``.y`` extra-feature tensors.
        """
        super().__init__()
        self.cfg = cfg
        self.T = cfg.model.diffusion_steps

        self.Xdim = input_dims['X']
        self.Edim = input_dims['E']
        self.ydim = input_dims['y']
        self.Xdim_output = output_dims['X']
        self.Edim_output = output_dims['E']
        self.ydim_output = output_dims['y']
        self.node_dist = nodes_dist
        self.extra_features = extra_features

        self.model = GraphTransformer(n_layers=cfg.model.n_layers,
                                      input_dims=input_dims,
                                      hidden_mlp_dims=cfg.model.hidden_mlp_dims,
                                      hidden_dims=cfg.model.hidden_dims,
                                      output_dims=output_dims,
                                      cond_dims=cond_dims,
                                      act_fn_in=nn.ReLU(),
                                      act_fn_out=nn.ReLU()).to(self.device)
        self.noise_schedule = PredefinedNoiseScheduleDiscrete(
            cfg.model.diffusion_noise_schedule,
            timesteps=cfg.model.diffusion_steps).to(self.device)
        self.transition_model = DiscreteUniformTransition(x_classes=self.Xdim_output,
                                                          e_classes=self.Edim_output,
                                                          y_classes=self.ydim_output)

        # Uniform distribution over each class set; sample_batch draws the
        # initial noise z_T from it.
        x_limit = torch.ones(self.Xdim_output) / self.Xdim_output
        e_limit = torch.ones(self.Edim_output) / self.Edim_output
        y_limit = torch.ones(self.ydim_output) / self.ydim_output
        self.limit_dist = utils.PlaceHolder(X=x_limit, E=e_limit, y=y_limit)

    @staticmethod
    def _to_nx_graphs(samples):
        """Convert ``[node_types, edge_types]`` pairs to ``networkx`` graphs.

        Every non-zero entry of ``edge_types`` becomes an edge; node types are
        not carried over.
        """
        return [nx.from_numpy_array(edge_types.bool().cpu().numpy())
                for _node_types, edge_types in samples]

    def generate_basic(self, text, num_nodes, *, num_samples=5) -> list:
        """Sample graphs conditioned on ``text`` using the SentenceTransformer encoder.

        Requires ``init_prompt_encoder_basic()`` to have been called first.

        Args:
            text: prompt string.
            num_nodes: nodes per graph. An int-like value (same size for all
                graphs), an integer tensor of shape (num_samples,), or None to
                draw sizes from ``nodes_dist``.
            num_samples: number of graphs to sample (default 5).

        Returns:
            list of ``networkx.Graph``, one per sample.
        """
        prompt_emb = torch.tensor(self.text_encoder.encode([text])).to(self.device)
        samples = self.sample_batch(num_samples, cond_emb=prompt_emb, num_nodes=num_nodes)
        return self._to_nx_graphs(samples)

    def generate_pretrained(self, text, num_nodes, *, num_samples=3) -> list:
        """Sample graphs conditioned on ``text`` using the BERT checkpoint encoder.

        Requires ``init_prompt_encoder_pretrained()`` to have been called first.

        Args:
            text: prompt string.
            num_nodes: nodes per graph (see ``generate_basic``).
            num_samples: number of graphs to sample (default 3).

        Returns:
            list of ``networkx.Graph``, one per sample.
        """
        encoded_input = self.tokenizer(text, return_tensors='pt', padding=True,
                                       truncation=True, max_length=512)
        encoded_input = {key: val.to(self.text_encoder.device)
                         for key, val in encoded_input.items()}
        with torch.no_grad():
            # First-token vector ([CLS] for BERT tokenizers) of the last hidden
            # layer is used as the prompt embedding.
            prompt_emb = self.text_encoder(**encoded_input).hidden_states[-1][:, 0]
        samples = self.sample_batch(num_samples, cond_emb=prompt_emb.to(self.device),
                                    num_nodes=num_nodes)
        return self._to_nx_graphs(samples)

    def init_prompt_encoder_basic(self):
        """Load the ``all-MiniLM-L6-v2`` SentenceTransformer as the text encoder."""
        self.text_encoder = SentenceTransformer("all-MiniLM-L6-v2")

    def init_prompt_encoder_pretrained(self):
        """Load the BERT tokenizer and a BERT classifier from a local checkpoint.

        Hidden states are requested so ``generate_pretrained`` can read the
        last layer; the model is placed on CPU.
        """
        model_name = "./checkpoint-900"  # or "bert-base-uncased" if starting from the base model
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.text_encoder = BertForSequenceClassification.from_pretrained(
            model_name, num_labels=8, output_hidden_states=True, device_map='cpu')

    @torch.no_grad()  # sampling is not differentiable; avoid building autograd graphs
    def sample_batch(self, batch_size: int, cond_emb=None, num_nodes=None):
        """Run the full reverse diffusion chain and return the sampled graphs.

        Args:
            batch_size: number of graphs to sample.
            cond_emb: conditioning embedding, either a single row shared by all
                samples or one row per sample. It must not be None, because the
                denoiser concatenates it to its inputs.
            num_nodes: None (draw sizes from ``nodes_dist``), an int-like value
                (same size for every graph), or an integer tensor of shape
                (batch_size,) or a scalar tensor.

        Returns:
            list of ``[node_types, edge_types]`` CPU tensors, one pair per graph.
            Each pair is sliced to that graph's node count n: node_types[:n] and
            edge_types[:n, :n], as produced by
            ``PlaceHolder.mask(..., collapse=True)``.
        """
        if num_nodes is None:
            n_nodes = self.node_dist.sample_n(batch_size, self.device)
        elif torch.is_tensor(num_nodes):
            n_nodes = num_nodes.to(self.device).long().reshape(-1).expand(batch_size)
        else:
            # Accept any int-like value (e.g. a float from a UI widget). The old
            # `type(...) == int` check left n_nodes unbound for such inputs.
            n_nodes = torch.full((batch_size,), int(num_nodes),
                                 device=self.device, dtype=torch.int)
        n_max = int(n_nodes.max())

        # node_mask[b, i] is True for the first n_nodes[b] node slots.
        arange = torch.arange(n_max, device=self.device).unsqueeze(0).expand(batch_size, -1)
        node_mask = arange < n_nodes.unsqueeze(1)

        # Initial noise z_T drawn from the limit distribution.
        z_T = diffusion_utils.sample_discrete_feature_noise(limit_dist=self.limit_dist,
                                                           node_mask=node_mask,
                                                           transition=self.cfg.model.transition)
        X, E, y = z_T.X, z_T.E, z_T.y

        # Sample p(z_s | z_t) for s = T-1, ..., 0 with t = s + 1. The explicit
        # descending range has a len(), so tqdm can display a total.
        for s_int in tqdm(range(self.T - 1, -1, -1)):
            s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
            t_array = s_array + 1
            s_norm = s_array / self.T
            t_norm = t_array / self.T
            sampled_s = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask, cond_emb)
            X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

        sampled_s = sampled_s.mask(node_mask, collapse=True)
        X, E = sampled_s.X, sampled_s.E

        graph_list = []
        for i in range(batch_size):
            n = int(n_nodes[i])
            graph_list.append([X[i, :n].cpu(), E[i, :n, :n].cpu()])
        return graph_list

    def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, node_mask, cond_emb):
        """Sample z_s ~ p(z_s | z_t) for one reverse diffusion step.

        Args:
            s, t: normalized timesteps, shape (bs, 1).
            X_t: node features at step t, shape (bs, n, dx).
            E_t: edge features at step t, shape (bs, n, n, de).
            y_t: graph-level features at step t; only its batch size, dtype and
                device are used for the output.
            node_mask: boolean mask of valid nodes, shape (bs, n).
            cond_emb: conditioning embedding. A 1-D vector or a row count other
                than bs is repeated bs times; a (bs, d) tensor is used per
                sample as-is.

        Returns:
            ``utils.PlaceHolder`` with one-hot X (bs, n, dx) and E (bs, n, n, de)
            for step s and an empty y of shape (bs, 0). It is masked by
            ``node_mask`` and cast with ``type_as(y_t)``.
        """
        bs, n, _ = X_t.shape
        beta_t = self.noise_schedule(t_normalized=t)  # (bs, 1)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)

        # Cumulative transitions up to t and s, plus the one-step transition Q_t.
        Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device)
        Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, self.device)
        Qt = self.transition_model.get_Qt(beta_t, self.device)

        # Broadcast a shared prompt embedding to the batch. A per-sample (bs, d)
        # embedding is kept as-is; previously it was repeated to (bs*bs, d).
        if cond_emb.dim() == 1 or cond_emb.shape[0] != bs:
            cond_emb = cond_emb.repeat(bs, 1)

        noisy_data = {'X_t': X_t, 'E_t': E_t, 'y_t': y_t, 't': t,
                      'node_mask': node_mask, 'cond_emb': cond_emb}
        extra_data = self.compute_extra_data(noisy_data)
        pred = self.forward(noisy_data, extra_data, node_mask)

        # Predicted distribution over clean classes x0.
        pred_X = F.softmax(pred.X, dim=-1)  # bs, n, d0
        pred_E = F.softmax(pred.E, dim=-1)  # bs, n, n, d0

        p_s_and_t_given_0_X = diffusion_utils.compute_batched_over0_posterior_distribution(
            X_t=X_t, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X)
        p_s_and_t_given_0_E = diffusion_utils.compute_batched_over0_posterior_distribution(
            X_t=E_t, Qt=Qt.E, Qsb=Qsb.E, Qtb=Qtb.E)

        # Weight the per-x0 terms by the predicted x0 probabilities and sum
        # over x0. The tensors here are bs, N, d0, d_t-1.
        weighted_X = pred_X.unsqueeze(-1) * p_s_and_t_given_0_X  # bs, n, d0, d_t-1
        unnormalized_prob_X = weighted_X.sum(dim=2)  # bs, n, d_t-1
        # All-zero rows (presumably padded nodes) get a small constant so the
        # normalization below never divides by zero.
        unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
        prob_X = unnormalized_prob_X / torch.sum(unnormalized_prob_X, dim=-1, keepdim=True)

        pred_E = pred_E.reshape((bs, -1, pred_E.shape[-1]))
        weighted_E = pred_E.unsqueeze(-1) * p_s_and_t_given_0_E  # bs, N, d0, d_t-1
        unnormalized_prob_E = weighted_E.sum(dim=-2)
        unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
        prob_E = unnormalized_prob_E / torch.sum(unnormalized_prob_E, dim=-1, keepdim=True)
        prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])

        assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all()
        assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all()

        sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask)
        X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float()
        E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float()

        assert (E_s == torch.transpose(E_s, 1, 2)).all()
        assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)

        # Build the empty y on the same device/dtype as y_t.
        out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=y_t.new_zeros((y_t.shape[0], 0)))
        return out_one_hot.mask(node_mask).type_as(y_t)

    def compute_extra_data(self, noisy_data):
        """Compute extra features for the network input at a sampling step.

        Calls ``self.extra_features(noisy_data)`` and appends the normalized
        timestep ``noisy_data['t']`` as extra graph-level (y) columns.

        Returns:
            ``utils.PlaceHolder`` with the extra X, E and y features.
        """
        extra_features = self.extra_features(noisy_data)
        extra_y = torch.cat((extra_features.y, noisy_data['t']), dim=1)
        return utils.PlaceHolder(X=extra_features.X, E=extra_features.E, y=extra_y)

    def forward(self, noisy_data, extra_data, node_mask):
        """Run the denoiser on noisy inputs plus extra and conditioning features.

        The conditioning embedding ``noisy_data['cond_emb']`` (bs, d) is tiled
        over nodes (bs, n, d) and node pairs (bs, n, n, d). The tiles are
        concatenated to X_t and E_t along the feature axis; y_t is stacked with
        the extra y features.
        """
        cond = noisy_data['cond_emb'].to(self.device)
        n = noisy_data['X_t'].shape[1]
        cond_nodes = cond.unsqueeze(1).expand(-1, n, -1)
        cond_edges = cond.unsqueeze(1).unsqueeze(2).expand(-1, n, n, -1)

        X = torch.cat((noisy_data['X_t'], extra_data.X, cond_nodes), dim=2).float()
        E = torch.cat((noisy_data['E_t'], extra_data.E, cond_edges), dim=3).float()
        y = torch.hstack((noisy_data['y_t'], extra_data.y)).float()
        return self.model(X, E, y, node_mask)