| import torch |
|
|
def get_weights(A, B, weight):
    """Compute per-token weights for two batches of embedding matrices.

    Args:
        A: Word embeddings for the first sentences (batchsize, num_tokens, emb_dim)
        B: Word embeddings for the second sentences (batchsize, num_tokens, emb_dim)
        weight: Weighting scheme, one of "L2", "L1", "no"
    Return:
        (weights_A, weights_B), each of shape (batchsize, num_tokens)
    Raises:
        NotImplementedError: if `weight` is not a known scheme
    """
    if weight == "L2":
        # Euclidean norm of each token embedding.
        weights_A = torch.linalg.norm(A, dim=2)
        weights_B = torch.linalg.norm(B, dim=2)
    elif weight == "L1":
        weights_A = torch.linalg.norm(A, dim=2, ord=1)
        weights_B = torch.linalg.norm(B, dim=2, ord=1)
    elif weight == "no":
        # Uniform weights. Match the inputs' dtype/device so this branch is
        # consistent with the norm branches (the old code always produced
        # float32 via a separate .to(device) copy).
        weights_A = torch.ones(A.shape[:2], dtype=A.dtype, device=A.device)
        weights_B = torch.ones(B.shape[:2], dtype=B.dtype, device=B.device)
    else:
        raise NotImplementedError(f"unknown weighting scheme: {weight!r}")
    return weights_A, weights_B
|
|
|
|
def pairwise_cosine_matrix(matrix1, matrix2):
    """Pairwise cosine similarities between the rows of two batched matrices.

    Args:
        matrix1: (batchsize, n1, dim)
        matrix2: (batchsize, n2, dim)
    Return:
        Cosine similarity matrix of shape (batchsize, n1, n2)
    """
    dot = torch.matmul(matrix1, matrix2.transpose(1, 2))
    matrix1_norm = torch.norm(matrix1, dim=-1, keepdim=True)
    matrix2_norm = torch.norm(matrix2, dim=-1, keepdim=True)
    norm = torch.matmul(matrix1_norm, matrix2_norm.transpose(1, 2))
    # Guard against division by zero for all-zero rows (e.g. padding tokens);
    # same eps as torch.nn.functional.cosine_similarity. Non-degenerate
    # inputs are unaffected.
    return dot / norm.clamp_min(1e-8)
|
|
|
|
def subspace_batch(A):
    """Orthonormalize the bases of a batch of linear subspaces via QR.

    Arg:
        A: Bases of a linear subspace (batchsize, num_bases, emb_dim)
    Return:
        S: Orthonormalized bases of a linear subspace (batchsize, num_bases, emb_dim)
    Example:
        >>> A = torch.randn(5, 4, 300)
        >>> subspace_batch(A)
    """
    # QR operates column-wise, so orthonormalize the (emb_dim, num_bases)
    # transpose and flip the result back to row-major bases.
    decomposition = torch.linalg.qr(A.mT)
    return decomposition.Q.mT
|
|
|
|
@torch.jit.script
def soft_membership_batch(S, v):
    """Compute soft membership degree between a subspace and a vector for a batch of vectors

    Args:
        S: Orthonormalized bases of a linear subspace (batchsize, num_bases, emb_dim)
        v: vector (batchsize, emb_dim)
    Return:
        soft_membership degree (batchsize,)
    Example:
        >>> S = torch.randn(5, 4, 300)
        >>> v = torch.randn(5, 300)
        >>> soft_membership_batch(S, v)
    """
    # Unit-normalize the query vector, then project it onto every basis
    # direction: (batchsize, num_bases, emb_dim) @ (batchsize, emb_dim, 1).
    unit = torch.nn.functional.normalize(v)
    projections = torch.matmul(S, unit.unsqueeze(2))
    # A (num_bases, 1) matrix has one singular value, equal to its 2-norm;
    # the mean over dim 1 simply extracts it per batch element.
    singular = torch.linalg.svdvals(projections.float())
    return singular.mean(1)
| |
| |
def subspace_johnson(A, B, weight="L2"):
    """Compute similarity between two vector sets (sentences)

    Args:
        A: Matrix of word embeddings for the first sentence
            (batchsize, num_bases, dim)
        B: Matrix of word embeddings for the second sentence
            (batchsize, num_bases, dim)
    Return:
        similarity between A and B (batchsize,)
    Example:
        >>> A = torch.randn(5, 3, 300)
        >>> B = torch.randn(5, 4, 300)
        >>> subspace_johnson(A, B)
    """
    def directional_score(tokens, bases, weights):
        # tokens: word embeddings; bases: orthonormalized subspace bases.
        # Weighted average of each token's soft membership in the subspace.
        memberships = [soft_membership_batch(bases, tok)
                       for tok in tokens.transpose(0, 1)]
        softm = torch.stack(memberships, dim=1)
        return torch.sum(softm * weights, 1) / torch.sum(weights, 1)

    weights_A, weights_B = get_weights(A, B, weight)

    # Symmetrized score: A against B's subspace plus B against A's subspace.
    score_ab = directional_score(A, subspace_batch(B), weights_A)
    score_ba = directional_score(B, subspace_batch(A), weights_B)
    return score_ab + score_ba
|
|
|
|
|
|
def subspace_bert_score(A, B, weight="L2"):
    """Compute similarity between two vector sets (sentences)

    Args:
        A: Matrix of word embeddings for the first sentence
            (batchsize, num_bases, dim)
        B: Matrix of word embeddings for the second sentence
            (batchsize, num_bases, dim)
    Return:
        similarity between A and B (batchsize,)
    Example:
        >>> A = torch.randn(5, 3, 300)
        >>> B = torch.randn(5, 4, 300)
        >>> subspace_bert_score(A, B)
    """
    def directional_score(tokens, bases, weights):
        # tokens: word embeddings; bases: orthonormalized subspace bases.
        # Weighted average of each token's soft membership in the subspace.
        memberships = [soft_membership_batch(bases, tok)
                       for tok in tokens.transpose(0, 1)]
        softm = torch.stack(memberships, dim=1)
        return torch.sum(softm * weights, 1) / torch.sum(weights, 1)

    weights_A, weights_B = get_weights(A, B, weight)

    # Recall: A's tokens against B's subspace; precision: the reverse.
    R = directional_score(A, subspace_batch(B), weights_A)
    P = directional_score(B, subspace_batch(A), weights_B)
    F = (2 * P * R) / (P + R)
    return P, R, F
|
|
|
|
def vanilla_bert_score(A, B, weight="L2"):
    """Compute similarity between two vector sets (sentences)

    Args:
        A: Matrix of word embeddings for the first sentence
            (batchsize, num_bases, dim)
        B: Matrix of word embeddings for the second sentence
            (batchsize, num_bases, dim)
    Return:
        similarity between A and B (batchsize,)
    Example:
        >>> A = torch.randn(5, 3, 300)
        >>> B = torch.randn(5, 4, 300)
        >>> vanilla_bert_score(A, B)
    """
    def weighted_best_match(similarities, axis, weights):
        # Greedy matching: each token takes its best cosine along `axis`,
        # then the matches are averaged with the token weights.
        best = similarities.max(dim=axis).values
        return torch.sum(best * weights, 1) / torch.sum(weights, 1)

    weights_A, weights_B = get_weights(A, B, weight)

    cosines = pairwise_cosine_matrix(A, B)

    # Recall matches A's tokens into B; precision matches B's tokens into A.
    R = weighted_best_match(cosines, 2, weights_A)
    P = weighted_best_match(cosines, 1, weights_B)
    F = (2 * P * R) / (P + R)
    return P, R, F
|
|