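# NOTE: the three lines below appear to be web-page residue rather than code;
# they are kept as comments so that the module parses.
# kaveh's picture
# init
# ef814bf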
import torch
from utils.helpers import get_token_embeddings
def compute_similarity_matrix(model, dataset, device):
    """
    Compute the pairwise cosine similarity between mean token embeddings.

    The embeddings returned by ``get_token_embeddings`` are averaged over
    their first dimension, each resulting row is L2-normalized, and the
    normalized matrix is multiplied by its own transpose, so entry
    ``[i, j]`` is the cosine similarity between mean embedding ``i`` and
    mean embedding ``j``.

    Args:
        model (torch.nn.Module): Model.
        dataset (torch.utils.data.Dataset): Dataset.
        device (str): Device to use.

    Returns:
        np.ndarray: Square similarity matrix with one row/column per row of
            the averaged embeddings, i.e. ``(seq_len, seq_len)`` if the
            embeddings are ``(n_samples, seq_len, d_model)``. Rows whose mean
            embedding has zero norm yield a similarity of 0 (not NaN).
    """
    # Assumed shape: (n_samples, seq_len, d_model) -- the helper is defined
    # elsewhere; TODO confirm. torch.mm below requires the averaged result to
    # be 2-D, so the embeddings must be 3-D.
    embeddings = get_token_embeddings(model, dataset, device)

    # Average over the first (sample) dimension.
    mean_token_embeddings = embeddings.mean(dim=0)

    # L2-normalize each row so that the dot products below are cosine
    # similarities. `normalize` divides by max(norm, eps), which keeps an
    # all-zero row at zero instead of turning it into NaN via 0 / 0.
    unit_embeddings = torch.nn.functional.normalize(
        mean_token_embeddings, p=2, dim=1
    )

    # All-pairs dot products in a single matrix multiplication.
    similarity_matrix = torch.mm(unit_embeddings, unit_embeddings.T)

    # detach() so the NumPy conversion also works when the embeddings are
    # still attached to the autograd graph; cpu() because NumPy cannot read
    # accelerator memory.
    return similarity_matrix.detach().cpu().numpy()