import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import maximum_flow

from utils.helpers import create_multimodal_model
from models import SingleTransformer


def filter_idx(dataset, idx):
    """Drop indices whose sample has an all-zero modality.

    A sample is kept only if each of its three modalities (RNA, ATAC, flux)
    contains at least one non-zero value.

    Args:
        dataset: Dataset exposing ``rna_data``, ``atac_data`` and
            ``flux_data`` as 2-D arrays indexed by sample.
        idx: Iterable of sample indices to filter.

    Returns:
        List of indices from ``idx`` whose sample passes the check.
    """
    rna = dataset.rna_data
    atac = dataset.atac_data
    flux = dataset.flux_data
    # True for samples where every modality has at least one non-zero entry.
    mask = (rna != 0).any(axis=1) & (atac != 0).any(axis=1) & (flux != 0).any(axis=1)
    return [i for i in idx if mask[i]]


def analyze_cls_attention(id, fold_results, dataset, model_config, device,
                          indices, average_heads=True,
                          return_flow_attention=False):
    """Extract attention weights over the validation set of each fold.

    Args:
        id: Which model/data type to use; one of 'RNA', 'ATAC', 'Flux',
            'Multi'. (Name kept for interface compatibility even though it
            shadows the builtin.)
        fold_results: List of per-fold dicts providing 'val_idx' and
            'best_model_path'.
        dataset: Dataset object containing the data.
        model_config: Dict of model constructor kwargs.
        device: Device to run the model on.
        indices: Collection of sample indices; each fold's validation set is
            restricted to these.
        average_heads: Average attention across heads (only used when
            ``return_flow_attention`` is False). Defaults to True.
        return_flow_attention: If True, return per-layer attention maps per
            modality instead of the pooled CLS attention. Defaults to False.

    Returns:
        If ``return_flow_attention`` is False: numpy array of shape
        (n_samples, seq_len) or (n_samples, num_heads, seq_len).
        Otherwise: dict with keys 'rna', 'atac', 'flux', 'cls', each mapped
        to a list of per-layer tensors concatenated over all validation
        samples.

    Raises:
        ValueError: If ``id`` is not one of the supported values.
        RuntimeError: If no fold yields any usable validation sample.
    """
    if id not in ('RNA', 'ATAC', 'Flux', 'Multi'):
        raise ValueError("id must be one of 'RNA', 'ATAC', 'Flux', 'Multi'")

    indices = set(indices)  # O(1) membership instead of O(n) list scans
    all_attention_weights = []
    for fold in fold_results:
        val_idx = [i for i in fold['val_idx'] if i in indices]
        if id == 'Multi':
            # Multimodal models need every modality populated.
            val_idx = filter_idx(dataset, val_idx)
        if len(val_idx) == 0:
            print('No samples of the specified type in the validation set. Skipping...')
            continue

        val_loader = DataLoader(Subset(dataset, val_idx), batch_size=32,
                                shuffle=False)

        if id == 'Multi':
            model = create_multimodal_model(model_config, device, use_mlm=False)
        else:
            model = SingleTransformer(id=id, **model_config).to(device)
        state_dict = torch.load(fold['best_model_path'], map_location='cpu')
        model.load_state_dict(state_dict)
        model.eval()

        with torch.no_grad():
            for batch in val_loader:
                x, b, _ = batch
                if isinstance(x, list):
                    # Multimodal batch: (rna, atac, flux) tensors.
                    x = tuple(t.to(device) for t in x)
                else:
                    x = x.to(device)
                b = b.to(device)
                _, _, attention_weights = model(
                    x, b, return_attention=True,
                    return_flow_attention=return_flow_attention)
                if not return_flow_attention:
                    if average_heads:
                        # (batch, num_heads, 1, seq_len) -> (batch, seq_len)
                        attention_weights = attention_weights.squeeze(-2).mean(dim=1)
                    else:
                        # (batch, num_heads, 1, seq_len) -> (batch, num_heads, seq_len)
                        attention_weights = attention_weights.squeeze(-2)
                    # Move to host memory: np.concatenate below would raise on
                    # CUDA tensors, and the docstring promises a numpy array.
                    attention_weights = attention_weights.detach().cpu().numpy()
                all_attention_weights.append(attention_weights)

    if not all_attention_weights:
        raise RuntimeError('No validation samples found in any fold.')

    if not return_flow_attention:
        # (n_samples, seq_len) or (n_samples, num_heads, seq_len)
        return np.concatenate(all_attention_weights, axis=0)

    # Flow-attention mode: each batch produced a dict of per-layer tensors;
    # concatenate the batches layer by layer for every modality.
    att_w = {'rna': [], 'atac': [], 'flux': [], 'cls': []}
    first = all_attention_weights[0]
    num_layers_mlm = len(first['rna'])
    num_layers_cls = len(first['cls']) if isinstance(first['cls'], list) else 1
    for key in first.keys():
        # One entry per batch, each a list of per-layer CPU tensors.
        key_all_attentions = []
        for batch_row in all_attention_weights:
            layers = batch_row[key]
            if isinstance(layers, list):
                key_all_attentions.append([t.cpu() for t in layers])
            else:
                key_all_attentions.append([layers.cpu()])
        num_layers = num_layers_cls if key == 'cls' else num_layers_mlm
        att_w[key] = [
            torch.cat([batch_layers[i] for batch_layers in key_all_attentions],
                      dim=0)
            for i in range(num_layers)
        ]
    return att_w


def compute_attention_rollout(attention_weights):
    """Compute the attention rollout for a batch of samples.

    Args:
        attention_weights: List (length num_layers) of tensors of shape
            (batch, num_heads, seq_len, seq_len).

    Returns:
        Tensor of shape (batch, seq_len, seq_len): the effective attention
        from each token (typically CLS) to all tokens after propagating
        through every layer.
    """
    batch_size, _, seq_len, _ = attention_weights[0].shape
    # Start from the identity: before any layer each token attends to itself.
    rollout = torch.eye(seq_len, device=attention_weights[0].device)
    rollout = rollout.unsqueeze(0).repeat(batch_size, 1, 1)
    for layer_attention in attention_weights:
        # Average over heads -> (batch, seq_len, seq_len), then propagate.
        rollout = torch.bmm(layer_attention.mean(dim=1), rollout)
    return rollout


def multimodal_attention_rollout(all_attention_weights):
    """Combine per-modality rollouts with the CLS cross-modal attention.

    Args:
        all_attention_weights: Dict with keys 'rna', 'atac', 'flux' (lists of
            per-layer attention tensors) and 'cls' (CLS attention over the
            concatenated modality tokens).

    Returns:
        Tensor of shape (batch, total_tokens): effective CLS attention over
        all modality tokens.
    """
    rna_rollout = compute_attention_rollout(all_attention_weights['rna'])
    atac_rollout = compute_attention_rollout(all_attention_weights['atac'])
    flux_rollout = compute_attention_rollout(all_attention_weights['flux'])
    # Average CLS attention over heads.
    cls_attention = all_attention_weights['cls'][0].mean(dim=1).squeeze(1)
    # Split the CLS attention into per-modality segments.
    rna_cls_attn, atac_cls_attn, flux_cls_attn = cls_attention.split(
        [rna_rollout.size(1), atac_rollout.size(1), flux_rollout.size(1)],
        dim=1)
    final_rollout = torch.cat([
        rna_cls_attn.unsqueeze(1) @ rna_rollout,
        atac_cls_attn.unsqueeze(1) @ atac_rollout,
        flux_cls_attn.unsqueeze(1) @ flux_rollout,
    ], dim=2)
    return final_rollout.squeeze(1)  # (batch, total_tokens)


def _mean_attention(attention_weights):
    """Mean attention over samples; accepts a torch tensor or numpy array."""
    if hasattr(attention_weights, 'numpy'):
        return attention_weights.mean(axis=0).numpy()
    return attention_weights.mean(axis=0)


def print_top_features(attention_weights, feature_names, top_n=5,
                       modality=None):
    """Print the ``top_n`` features with the highest mean attention.

    Args:
        attention_weights: (n_samples, n_features) tensor or array.
        feature_names: Sequence of feature names indexed like the columns.
        top_n: Number of features to print. Defaults to 5.
        modality: Label used only in the printed header.
    """
    print(f"\nTop {top_n} attended features ({modality} samples):")
    avg_attention = _mean_attention(attention_weights)
    # Ascending argsort; take the last top_n and reverse for descending order.
    top_indices = avg_attention.argsort()[-top_n:][::-1]
    for i in top_indices:
        print(f"{feature_names[i]}: {avg_attention[i]:.4f}")


def get_top_features(attention_weights, feature_names, top_n=100,
                     modality=None):
    """Return (feature_name, mean_attention) pairs sorted by attention.

    Args:
        attention_weights: (n_samples, n_features) tensor or array.
        feature_names: Sequence of feature names indexed like the columns.
        top_n: Number of features to return; a falsy value returns all.
        modality: Unused; kept for interface compatibility.

    Returns:
        List of (name, score) tuples in descending attention order.
    """
    avg_attention = _mean_attention(attention_weights)
    order = avg_attention.argsort()[::-1]
    if top_n:
        order = order[:top_n]
    return [(feature_names[i], avg_attention[i]) for i in order]


# Scale factor turning float attention weights into the integer edge
# capacities that scipy's maximum_flow requires.
_FLOW_CAPACITY_SCALE = 10 ** 6


def compute_attention_flow(attention_weights):
    """Compute token-to-token attention flow via max-flow through the layers.

    Builds a layered graph (one node per token per layer) whose edge
    capacities are the head-averaged attention weights plus identity (skip)
    connections, then computes the maximum flow from every first-layer token
    to every last-layer token.

    NOTE(review): the slice assignment assumes ``mean(dim=1)`` yields a
    (num_tokens, num_tokens) matrix, i.e. the per-layer tensors carry no
    batch dimension (or batch size 1) — confirm against the caller.

    Args:
        attention_weights: List (length num_layers) of attention tensors.

    Returns:
        Tensor of shape (num_tokens, num_tokens) where entry (i, j) is the
        max flow from input token i to output token j.
    """
    num_layers = len(attention_weights)
    num_tokens = attention_weights[0].size(-1)

    adj_matrix = np.zeros((num_layers * num_tokens, num_layers * num_tokens))
    for i in range(num_layers - 1):
        layer_attention = attention_weights[i].mean(dim=1).cpu().numpy()  # avg heads
        start = i * num_tokens
        end = (i + 1) * num_tokens
        adj_matrix[start:end, end:end + num_tokens] = layer_attention
        # Skip connection between corresponding tokens of adjacent layers.
        adj_matrix[start:end, end:end + num_tokens] += np.eye(num_tokens)

    # scipy's maximum_flow requires integer capacities, so scale the float
    # weights up and undo the scaling on the resulting flow values. The
    # graph never changes, so build the sparse matrix once, not per pair.
    graph = csr_matrix(
        np.round(adj_matrix * _FLOW_CAPACITY_SCALE).astype(np.int64))

    flows = np.zeros((num_tokens, num_tokens))
    sink_base = (num_layers - 1) * num_tokens
    for i in range(num_tokens):
        for j in range(num_tokens):
            # maximum_flow returns a MaximumFlowResult (not unpackable);
            # the scalar flow is its ``flow_value`` attribute.
            result = maximum_flow(graph, i, sink_base + j)
            flows[i, j] = result.flow_value / _FLOW_CAPACITY_SCALE
    return torch.tensor(flows, device=attention_weights[0].device)


def multimodal_attention_flow(all_attention_weights):
    """Combine per-modality attention flows with the CLS attention.

    Analogous to :func:`multimodal_attention_rollout`, but uses max-flow for
    the within-modality propagation and row-normalises each flow matrix.

    Args:
        all_attention_weights: Dict with keys 'rna', 'atac', 'flux' (lists of
            per-layer attention tensors) and 'cls'.

    Returns:
        Tensor of shape (batch, total_tokens).
    """
    rna_flow = compute_attention_flow(all_attention_weights['rna'])
    atac_flow = compute_attention_flow(all_attention_weights['atac'])
    flux_flow = compute_attention_flow(all_attention_weights['flux'])
    # Average CLS attention over heads.
    cls_attention = all_attention_weights['cls'][0].mean(dim=1).squeeze(1)
    rna_cls_attn, atac_cls_attn, flux_cls_attn = cls_attention.split(
        [rna_flow.size(1), atac_flow.size(1), flux_flow.size(1)], dim=1)

    # Row-normalise each modality's flow matrix.
    rna_flow = rna_flow / rna_flow.sum(dim=1, keepdim=True)
    atac_flow = atac_flow / atac_flow.sum(dim=1, keepdim=True)
    flux_flow = flux_flow / flux_flow.sum(dim=1, keepdim=True)

    # NOTE(review): compute_attention_flow returns float64 tensors while the
    # CLS attention is typically float32 — confirm dtypes agree before matmul.
    final_flow = torch.cat([
        rna_cls_attn.unsqueeze(1) @ rna_flow,
        atac_cls_attn.unsqueeze(1) @ atac_flow,
        flux_cls_attn.unsqueeze(1) @ flux_flow,
    ], dim=2)
    return final_flow.squeeze(1)