Spaces:
Running
Running
| import numpy as np | |
| import torch | |
| from torch.utils.data import DataLoader, Subset | |
| from utils.helpers import create_multimodal_model | |
| from models import SingleTransformer | |
| from scipy.sparse import csr_matrix | |
def filter_idx(dataset, idx):
    """
    Keep only the indices whose sample has no all-zero modality.

    A sample passes when each of its RNA, ATAC and flux rows contains at
    least one non-zero entry.

    Args:
        dataset: Dataset object exposing ``rna_data``, ``atac_data`` and
            ``flux_data`` arrays with samples in rows.
        idx: Iterable of integer sample indices to filter.

    Returns:
        List of the indices from ``idx`` whose sample passes the check.
    """
    modalities = (dataset.rna_data, dataset.atac_data, dataset.flux_data)
    nonzero_rows = [(m != 0).any(axis=1) for m in modalities]
    keep = nonzero_rows[0] & nonzero_rows[1] & nonzero_rows[2]
    return [sample for sample in idx if keep[sample]]
def analyze_cls_attention(id, fold_results, dataset, model_config, device, indices,
                          average_heads=True, return_flow_attention=False):
    """
    Extract attention weights over each fold's validation samples.

    Args:
        id: The type of data to use. Must be one of 'RNA', 'ATAC', 'Flux', 'Multi'.
        fold_results: List of dicts per fold with keys 'val_idx' and 'best_model_path'.
        dataset: Dataset object containing the data.
        model_config: Dictionary containing the model configuration.
        device: Device to run the model on.
        indices: Collection of sample indices to restrict the analysis to.
        average_heads: Whether to average the attention weights across heads
            (ignored when return_flow_attention is True). Defaults to True.
        return_flow_attention: When True, collect the per-layer attention maps
            for each modality plus the CLS block instead of CLS-only weights.

    Returns:
        If return_flow_attention is False: numpy array of shape
        (n_samples, seq_len) or (n_samples, num_heads, seq_len).
        Otherwise: dict keyed per modality ('rna', 'atac', 'flux', 'cls'),
        each value a list of per-layer tensors concatenated over all samples.

    Raises:
        ValueError: If id is invalid, or no validation samples matched.
    """
    if id not in ['RNA', 'ATAC', 'Flux', 'Multi']:
        raise ValueError("id must be one of 'RNA', 'ATAC', 'Flux', 'Multi'")
    # Set membership is O(1); the original list scan was O(len(indices)) per item.
    index_set = set(indices)
    all_attention_weights = []
    for fold in fold_results:
        val_idx = [i for i in fold['val_idx'] if i in index_set]
        if id == 'Multi':
            # Multimodal models need every modality to be non-all-zero.
            val_idx = filter_idx(dataset, val_idx)
        if len(val_idx) == 0:
            print('No samples of the specified type in the validation set. Skipping...')
            continue
        val_loader = DataLoader(Subset(dataset, val_idx), batch_size=32, shuffle=False)
        if id == 'Multi':
            model = create_multimodal_model(model_config, device, use_mlm=False)
        else:
            model = SingleTransformer(id=id, **model_config).to(device)
        state_dict = torch.load(fold['best_model_path'], map_location='cpu')
        model.load_state_dict(state_dict)
        model.eval()
        with torch.no_grad():
            for batch in val_loader:
                x, b, _ = batch
                if isinstance(x, list):
                    # Multimodal batch: (rna, atac, flux) tuple on the device.
                    x = tuple(t.to(device) for t in x)
                else:
                    x = x.to(device)
                b = b.to(device)
                _, _, attention_weights = model(x, b, return_attention=True,
                                                return_flow_attention=return_flow_attention)
                if not return_flow_attention:
                    # (batch, num_heads, 1, seq_len) -> drop the singleton query dim.
                    attention_weights = attention_weights.squeeze(-2)
                    if average_heads:
                        # (batch, num_heads, seq_len) -> (batch, seq_len)
                        attention_weights = attention_weights.mean(dim=1)
                all_attention_weights.append(attention_weights)
    if not all_attention_weights:
        raise ValueError('No validation samples matched the given indices.')
    if not return_flow_attention:
        # (n_samples, seq_len) or (n_samples, num_heads, seq_len)
        return np.concatenate(all_attention_weights, axis=0)
    # Flow mode: each element is a dict of per-layer attention lists.
    # Concatenate the batches layer by layer for every key, moving to CPU
    # without mutating the caller's nested lists.
    att_w = {}
    for key in all_attention_weights[0].keys():
        per_batch = []
        for batch_row in all_attention_weights:
            layers = batch_row[key]
            if isinstance(layers, list):
                per_batch.append([t.cpu() for t in layers])
            else:
                # Single tensor (e.g. one CLS layer): normalize to a 1-layer list.
                per_batch.append([layers.cpu()])
        num_layers = len(per_batch[0])
        att_w[key] = [torch.cat([batch[i] for batch in per_batch], axis=0)
                      for i in range(num_layers)]
    return att_w
| # def compute_attention_rollout(attention_weights): | |
| # num_layers = len(attention_weights) | |
| # combined_attention = torch.eye(attention_weights[0].size(-1)).to(attention_weights[0].device) | |
| # for layer in range(num_layers): | |
| # layer_attention = attention_weights[layer].mean(dim=1) # Average over heads | |
| # combined_attention = torch.matmul(layer_attention, combined_attention) | |
| # return combined_attention | |
def compute_attention_rollout(attention_weights):
    """
    Compute the attention rollout for a batch of samples.

    Args:
        attention_weights: list (length = num_layers) of tensors of shape
            (batch, num_heads, seq_len, seq_len). Heads are averaged per layer
            before the rollout product is taken.

    Returns:
        Tensor of shape (batch, seq_len, seq_len) giving the effective
        attention from each token (typically CLS) to every token.
    """
    first = attention_weights[0]
    batch, _, tokens, _ = first.shape
    # Seed the rollout with one identity matrix per sample.
    rollout = torch.eye(tokens, device=first.device).unsqueeze(0).repeat(batch, 1, 1)
    for layer_attn in attention_weights:
        head_avg = layer_attn.mean(dim=1)  # average over heads
        rollout = torch.bmm(head_avg, rollout)  # per-sample matmul
    return rollout
def multimodal_attention_rollout(all_attention_weights):
    """
    Combine per-modality attention rollouts with the CLS cross-attention.

    Args:
        all_attention_weights: dict with per-layer attention lists under
            'rna', 'atac' and 'flux', and the CLS attention under 'cls'
            (first entry assumed shaped (batch, heads, 1, total_tokens) —
            spanning the concatenated modality tokens).

    Returns:
        Tensor of shape (batch, total_tokens): the effective CLS attention
        over every token of every modality.
    """
    modalities = ('rna', 'atac', 'flux')
    rollouts = {m: compute_attention_rollout(all_attention_weights[m])
                for m in modalities}
    # Head-averaged CLS attention over the concatenated token axis.
    cls_attn = all_attention_weights['cls'][0].mean(dim=1).squeeze(1)
    sizes = [rollouts[m].size(1) for m in modalities]
    cls_parts = cls_attn.split(sizes, dim=1)
    # Weight each modality's rollout by its slice of the CLS attention.
    pieces = [part.unsqueeze(1) @ rollouts[m]
              for part, m in zip(cls_parts, modalities)]
    return torch.cat(pieces, dim=2).squeeze(1)  # [samples, tokens]
def print_top_features(attention_weights, feature_names, top_n=5, modality=None):
    """
    Print the top_n features ranked by mean attention across samples.

    Args:
        attention_weights: (n_samples, n_features) torch tensor or numpy array.
        feature_names: sequence of feature labels, indexed by column.
        top_n: how many features to display.
        modality: label used only in the printed header.
    """
    print(f"\nTop {top_n} attended features ({modality} samples):")
    mean_attn = attention_weights.mean(axis=0)
    if hasattr(attention_weights, 'numpy'):  # torch tensor -> numpy
        mean_attn = mean_attn.numpy()
    for col in mean_attn.argsort()[-top_n:][::-1]:
        print(f"{feature_names[col]}: {mean_attn[col]:.4f}")
| def get_top_features(attention_weights, feature_names, top_n=100, modality=None): | |
| ls = [] | |
| avg_attention = attention_weights.mean(axis=0).numpy() if hasattr(attention_weights, 'numpy') else attention_weights.mean(axis=0) | |
| if top_n: | |
| top_indices = avg_attention.argsort()[-top_n:][::-1] | |
| else: | |
| top_indices = avg_attention.argsort()[::-1] | |
| for i in top_indices: | |
| ls.append((feature_names[i],avg_attention[i])) | |
| return ls | |
| from scipy.sparse.csgraph import maximum_flow | |
| def compute_attention_flow(attention_weights): | |
| num_layers = len(attention_weights) | |
| num_tokens = attention_weights[0].size(-1) | |
| # Create adjacency matrix for the flow network | |
| adj_matrix = np.zeros((num_layers * num_tokens, num_layers * num_tokens)) | |
| for i in range(num_layers - 1): | |
| layer_attention = attention_weights[i].mean(dim=1).cpu().numpy() # Average over heads | |
| start_idx = i * num_tokens | |
| end_idx = (i + 1) * num_tokens | |
| adj_matrix[start_idx:end_idx, end_idx:(end_idx + num_tokens)] = layer_attention | |
| for i in range(num_layers - 1): | |
| start_idx = i * num_tokens | |
| end_idx = (i + 1) * num_tokens | |
| adj_matrix[start_idx:end_idx, end_idx:(end_idx + num_tokens)] += np.eye(num_tokens) | |
| flows = np.zeros((num_tokens, num_tokens)) | |
| for i in range(num_tokens): | |
| source = i | |
| for j in range(num_tokens): | |
| sink = (num_layers - 1) * num_tokens + j | |
| _, flow = maximum_flow(csr_matrix(adj_matrix), source, sink) | |
| flows[i, j] = flow | |
| return torch.tensor(flows, device=attention_weights[0].device) | |
def multimodal_attention_flow(all_attention_weights):
    """
    Combine per-modality attention flows with the CLS cross-attention.

    Args:
        all_attention_weights: dict with per-layer attention lists under
            'rna', 'atac' and 'flux', and the CLS attention under 'cls'
            (first entry assumed shaped (batch, heads, 1, total_tokens)).

    Returns:
        Tensor of shape (batch, total_tokens): CLS-weighted, row-normalized
        attention flow per token.

    NOTE(review): compute_attention_flow builds its result with torch.tensor
    on a numpy float64 array while the CLS attention is typically float32 —
    confirm the dtypes agree before the matmul.
    """
    modalities = ('rna', 'atac', 'flux')
    flow_by_mod = {m: compute_attention_flow(all_attention_weights[m])
                   for m in modalities}
    # Head-averaged CLS attention over the concatenated token axis.
    cls_attn = all_attention_weights['cls'][0].mean(dim=1).squeeze(1)
    sizes = [flow_by_mod[m].size(1) for m in modalities]
    cls_parts = dict(zip(modalities, cls_attn.split(sizes, dim=1)))
    weighted = []
    for m in modalities:
        # Normalize each flow matrix row-wise, then weight by the CLS slice.
        row_sums = flow_by_mod[m].sum(dim=1, keepdim=True)
        weighted.append(cls_parts[m].unsqueeze(1) @ (flow_by_mod[m] / row_sums))
    return torch.cat(weighted, dim=2).squeeze(1)