def calc_cg_score_gnn_with_sampling(
    A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False
):
    """Compute a CG-score for every edge, one edge at a time.

    For each label group, the nodes of that label (targets +1) and a random
    subset of the remaining nodes (targets -1) are represented by the
    row-normalised aggregated features ``A @ X``.  A kernel matrix ``H`` is
    built from their pairwise inner products and the quadratic form
    ``y^T H^{-1} y`` is evaluated.  The score of edge ``(i, j)`` is the
    baseline quadratic form minus the one obtained after removing that edge's
    contribution from the aggregated features of ``i`` and ``j``.

    Only edges ``(i, j)`` with ``j > i`` are visited, while processing the
    label group of node ``i``; the result is mirrored to ``(j, i)``.

    Args:
        A: torch.Tensor
            Dense adjacency matrix, indexed as ``A[i, j]`` (N x N).
        X: torch.Tensor
            Node feature matrix (N x F).
        labels: torch.Tensor
            One label per node; each entry must be convertible with
            ``.item()``.
        device: torch.device or str
            Device on which the tensor computations run.
        rep_num: int
            Number of sampling repetitions; scores are averaged over them.
        unbalance_ratio: float
            Number of negative samples drawn per positive sample (capped by
            the number of available negatives).
        sub_term: bool
            If True, return the whole score dictionary instead of only the
            ``"vi"`` matrix.

    Returns:
        The N x N ``numpy`` array of averaged edge scores, or, when
        ``sub_term`` is True, a dict with keys ``"vi"``, ``"ab"``, ``"a2"``,
        ``"b2"`` and ``"times"``.  Only ``"vi"`` and ``"times"`` are written
        by this implementation; the other matrices stay zero.
    """
    N = A.shape[0]
    cg_scores = {
        name: np.zeros((N, N)) for name in ("vi", "ab", "a2", "b2", "times")
    }

    A = A.to(device)
    X = X.to(device)

    def unit_rows(mat):
        # Same epsilon-guarded L2 normalisation for a matrix or a single row.
        return mat / (torch.norm(mat, dim=-1, keepdim=True) + 1e-8)

    def quadratic_error(samples, y):
        # H_ij = s_ij * (pi - arccos(s_ij)) / (2 * pi) with s_ij the clamped
        # inner product; the diagonal is pinned to 0.5 plus a 1e-6 ridge so
        # the inverse exists.  (Presumably the Gram matrix from the CG-score
        # paper -- confirm against the reference implementation.)
        gram = torch.clamp(torch.matmul(samples, samples.T), min=-1.0, max=1.0)
        H = gram * (np.pi - torch.acos(gram)) / (2 * np.pi)
        H.fill_diagonal_(0.5)
        H += 1e-6 * torch.eye(H.size(0), dtype=H.dtype, device=H.device)
        return y @ (torch.inverse(H) @ y)

    # Nothing here needs gradients; disabling autograd for the whole
    # computation avoids building graphs when A or X require grad.
    with torch.no_grad():
        # Loop-invariant work: aggregated features and label grouping do not
        # depend on the repetition, only the random sampling does.
        AX = torch.matmul(A, X)
        norm_AX = unit_rows(AX)

        # Group node ids by label, keeping first-appearance order of labels so
        # the random draws below happen in the same sequence as before.
        members = defaultdict(list)
        for node, lbl in enumerate(labels.detach().cpu()):
            members[lbl.item()].append(node)
        data_idx = {
            lbl: torch.tensor(nodes, dtype=torch.long, device=device)
            for lbl, nodes in members.items()
        }

        # Negative pool per label = all nodes carrying any other label.  With
        # a single class the pool is empty instead of crashing torch.cat.
        no_nodes = torch.empty(0, dtype=torch.long, device=device)
        neg_pool = {
            lbl: torch.cat(
                [data_idx[other] for other in data_idx if other != lbl]
                or [no_nodes]
            )
            for lbl in data_idx
        }

        for _ in range(rep_num):
            for curr_label, curr_indices in tqdm(data_idx.items(), desc="Label groups"):
                curr_num = len(curr_indices)

                # All positives are used, in a random order.
                order = np.random.choice(curr_num, curr_num, replace=False)
                chosen_curr_indices = curr_indices[
                    torch.as_tensor(order, dtype=torch.long, device=device)
                ]

                pool = neg_pool[curr_label]
                neg_num = min(int(curr_num * unbalance_ratio), len(pool))
                # randperm is drawn on the CPU generator, then moved.
                picked = torch.randperm(len(pool))[:neg_num].to(device)
                chosen_neg_indices = pool[picked]

                sample_idx = torch.cat([chosen_curr_indices, chosen_neg_indices])
                base = norm_AX[sample_idx]
                # Targets share dtype/device with the samples so the matrix
                # products below work for float64 inputs as well.
                y = torch.cat([base.new_ones(curr_num), -base.new_ones(neg_num)])
                original_error = quadratic_error(base, y)

                # Node id -> row inside `base`; nodes outside the sample set
                # are simply absent.
                row_of = {node: row for row, node in enumerate(sample_idx.tolist())}

                for idx_i in tqdm(
                    chosen_curr_indices.tolist(), desc=f"Nodes in label {curr_label}"
                ):
                    # Neighbours j > idx_i, found with one nonzero() call
                    # instead of one tensor comparison per candidate j.
                    neighbours = (
                        torch.nonzero(A[idx_i, idx_i + 1:]).flatten() + idx_i + 1
                    ).tolist()
                    for j in neighbours:
                        # Aggregated features of both endpoints without the
                        # edge's contribution.
                        AX1_i = AX[idx_i] - A[idx_i, j] * X[j]
                        AX1_j = AX[j] - A[j, idx_i] * X[idx_i]

                        # Only the sampled rows matter, so copy those rather
                        # than the full N x F matrix.
                        perturbed = base.clone()
                        perturbed[row_of[idx_i]] = unit_rows(AX1_i)
                        row_j = row_of.get(j)
                        if row_j is not None:
                            perturbed[row_j] = unit_rows(AX1_j)

                        error_A1 = quadratic_error(perturbed, y)

                        score = (original_error - error_A1).item()
                        cg_scores["vi"][idx_i, j] += score
                        cg_scores["vi"][j, idx_i] = cg_scores["vi"][idx_i, j]
                        cg_scores["times"][idx_i, j] += 1
                        cg_scores["times"][j, idx_i] += 1

    # Average accumulated values over the number of visits; unvisited entries
    # are divided by 1 and therefore stay 0.
    visits = np.where(cg_scores["times"] > 0, cg_scores["times"], 1)
    for key in cg_scores:
        if key != "times":
            cg_scores[key] = cg_scores[key] / visits

    return cg_scores if sub_term else cg_scores["vi"]
|
|
|
|
def calc_cg_score_gnn_with_sampling(
    A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False, batch_size=64
):
    """Compute a CG-score for every edge, evaluating edges in batches.

    For each label group, the nodes of that label (targets +1) and a random
    subset of the remaining nodes (targets -1) are represented by the
    row-normalised aggregated features ``A @ X``.  A kernel matrix ``H`` is
    built from their pairwise inner products and the quadratic form
    ``y^T H^{-1} y`` is evaluated.  The score of edge ``(i, j)`` is the
    baseline quadratic form minus the one obtained after removing that edge's
    contribution from the aggregated features of ``i`` and ``j``.  Up to
    ``batch_size`` edges are perturbed and inverted together.

    Only edges ``(i, j)`` with ``j > i`` are visited, while processing the
    label group of node ``i``; the result is mirrored to ``(j, i)``.

    Args:
        A: torch.Tensor
            Dense adjacency matrix, indexed as ``A[i, j]`` (N x N).
        X: torch.Tensor
            Node feature matrix (N x F).
        labels: torch.Tensor
            One label per node; each entry must be convertible with
            ``.item()``.
        device: torch.device or str
            Device on which the tensor computations run.
        rep_num: int
            Number of sampling repetitions; scores are averaged over them.
        unbalance_ratio: float
            Number of negative samples drawn per positive sample (capped by
            the number of available negatives).
        sub_term: bool
            If True, return the whole score dictionary instead of only the
            ``"vi"`` matrix.
        batch_size: int
            Number of edges whose perturbed kernel matrices are built and
            inverted in one batched call.

    Returns:
        The N x N ``numpy`` array of averaged edge scores, or, when
        ``sub_term`` is True, a dict with keys ``"vi"``, ``"ab"``, ``"a2"``,
        ``"b2"`` and ``"times"``.  Only ``"vi"`` and ``"times"`` are written
        by this implementation; the other matrices stay zero.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    N = A.shape[0]
    cg_scores = {
        name: np.zeros((N, N)) for name in ("vi", "ab", "a2", "b2", "times")
    }

    A = A.to(device)
    X = X.to(device)

    def unit_rows(mat):
        # Epsilon-guarded L2 normalisation along the feature axis.
        return mat / (torch.norm(mat, dim=-1, keepdim=True) + 1e-8)

    def quadratic_error(samples, y):
        # Works for one sample matrix (S x F) and for a batch (B x S x F).
        # H_ij = s_ij * (pi - arccos(s_ij)) / (2 * pi) with s_ij the clamped
        # inner product.  The diagonal is set to 0.5 and THEN the 1e-6 ridge
        # is added, so baseline and perturbed matrices are regularised
        # identically (writing 0.5 after the ridge would silently drop it).
        # (Presumably the Gram matrix from the CG-score paper -- confirm
        # against the reference implementation.)
        gram = torch.clamp(
            torch.matmul(samples, samples.transpose(-1, -2)), min=-1.0, max=1.0
        )
        H = gram * (np.pi - torch.acos(gram)) / (2 * np.pi)
        diag = H.diagonal(dim1=-2, dim2=-1)  # writable view into H
        diag.fill_(0.5)
        diag.add_(1e-6)
        # (H^-1 y) . y  ==  y^T H^-1 y, broadcast over an optional batch axis.
        return torch.matmul(torch.matmul(torch.inverse(H), y), y)

    # Nothing here needs gradients; disabling autograd for the whole
    # computation avoids building graphs when A or X require grad.
    with torch.no_grad():
        # Loop-invariant work: aggregated features and label grouping do not
        # depend on the repetition, only the random sampling does.
        AX = torch.matmul(A, X)
        norm_AX = unit_rows(AX)

        # Group node ids by label, keeping first-appearance order of labels so
        # the random draws below happen in the same sequence as before.
        members = defaultdict(list)
        for node, lbl in enumerate(labels.detach().cpu()):
            members[lbl.item()].append(node)
        data_idx = {
            lbl: torch.tensor(nodes, dtype=torch.long, device=device)
            for lbl, nodes in members.items()
        }

        # Negative pool per label = all nodes carrying any other label.  With
        # a single class the pool is empty instead of crashing torch.cat.
        no_nodes = torch.empty(0, dtype=torch.long, device=device)
        neg_pool = {
            lbl: torch.cat(
                [data_idx[other] for other in data_idx if other != lbl]
                or [no_nodes]
            )
            for lbl in data_idx
        }

        for _ in range(rep_num):
            for curr_label, curr_indices in tqdm(data_idx.items(), desc="Label groups"):
                curr_num = len(curr_indices)

                # All positives are used, in a random order.
                order = np.random.choice(curr_num, curr_num, replace=False)
                chosen_curr_indices = curr_indices[
                    torch.as_tensor(order, dtype=torch.long, device=device)
                ]

                pool = neg_pool[curr_label]
                neg_num = min(int(curr_num * unbalance_ratio), len(pool))
                # randperm is drawn on the CPU generator, then moved.
                picked = torch.randperm(len(pool))[:neg_num].to(device)
                chosen_neg_indices = pool[picked]

                sample_idx = torch.cat([chosen_curr_indices, chosen_neg_indices])
                base = norm_AX[sample_idx]
                # Targets share dtype/device with the samples so the matrix
                # products work for float64 inputs as well.
                y = torch.cat([base.new_ones(curr_num), -base.new_ones(neg_num)])
                original_error = quadratic_error(base, y)

                # Node id -> row inside `base`, -1 for nodes that were not
                # sampled (their features never enter the kernel matrix).
                row_of = torch.full((N,), -1, dtype=torch.long, device=device)
                row_of[sample_idx] = torch.arange(len(sample_idx), device=device)

                # All edges (i, j) with i in this group and j > i.  nonzero()
                # returns indices in row-major order, i.e. node by node with
                # ascending j.
                hits = torch.nonzero(A[chosen_curr_indices])
                src_all = chosen_curr_indices[hits[:, 0]]
                dst_all = hits[:, 1]
                upper = dst_all > src_all
                src_all, dst_all = src_all[upper], dst_all[upper]

                for start in tqdm(
                    range(0, src_all.numel(), batch_size),
                    desc="Edge batches",
                    leave=False,
                ):
                    src = src_all[start:start + batch_size]
                    dst = dst_all[start:start + batch_size]
                    B = src.numel()

                    # Aggregated features of both endpoints without the
                    # edge's contribution, for the whole batch at once.
                    new_src = unit_rows(AX[src] - A[src, dst].unsqueeze(1) * X[dst])
                    new_dst = unit_rows(AX[dst] - A[dst, src].unsqueeze(1) * X[src])

                    # One copy of the sampled rows per edge (B x S x F); the
                    # full N x F matrix is never replicated.
                    perturbed = base.unsqueeze(0).repeat(B, 1, 1)
                    slot = torch.arange(B, device=device)
                    # The source node always belongs to the positive samples.
                    perturbed[slot, row_of[src]] = new_src
                    # The other endpoint only matters if it was sampled too.
                    dst_row = row_of[dst]
                    present = dst_row >= 0
                    perturbed[slot[present], dst_row[present]] = new_dst[present]

                    error_A1 = quadratic_error(perturbed, y)
                    deltas = (original_error - error_A1).tolist()

                    for i, j, score in zip(src.tolist(), dst.tolist(), deltas):
                        cg_scores["vi"][i, j] += score
                        cg_scores["vi"][j, i] = cg_scores["vi"][i, j]
                        cg_scores["times"][i, j] += 1
                        cg_scores["times"][j, i] += 1

    # Average accumulated values over the number of visits; unvisited entries
    # are divided by 1 and therefore stay 0.
    visits = np.where(cg_scores["times"] > 0, cg_scores["times"], 1)
    for key in cg_scores:
        if key != "times":
            cg_scores[key] = cg_scores[key] / visits

    return cg_scores if sub_term else cg_scores["vi"]
|
|
| def calc_cg_score_gnn_with_sampling( |
| A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False |
| ): |
| """ |
| Calculate CG-score for each edge in a graph with node labels and random sampling. |
| |
| Args: |
| A: torch.Tensor |
| Adjacency matrix of the graph (size: N x N). |
| X: torch.Tensor |
| Node features matrix (size: N x F). |
| labels: torch.Tensor |
| Node labels (size: N). |
| device: torch.device |
| Device to perform calculations. |
| rep_num: int |
| Number of repetitions for Monte Carlo sampling. |
| unbalance_ratio: float |
| Ratio of unbalanced data (1:unbalance_ratio). |
| sub_term: bool |
| If True, calculate and return sub-terms. |
| |
| Returns: |
| cg_scores: dict |
| Dictionary containing CG-scores for edges and optionally sub-terms. |
| """ |
| N = A.shape[0] |
| cg_scores = { |
| "vi": np.zeros((N, N)), |
| "ab": np.zeros((N, N)), |
| "a2": np.zeros((N, N)), |
| "b2": np.zeros((N, N)), |
| "times": np.zeros((N, N)), |
| } |
|
|
| with torch.no_grad(): |
| for _ in range(rep_num): |
| |
| AX = torch.matmul(A, X).to(device) |
| norm_AX = AX / torch.norm(AX, dim=1, keepdim=True) |
|
|
| |
| dataset = defaultdict(list) |
| data_idx = defaultdict(list) |
| for i, label in enumerate(labels): |
| dataset[label.item()].append(norm_AX[i].unsqueeze(0)) |
| data_idx[label.item()].append(i) |
|
|
| |
| for label, data_list in dataset.items(): |
| dataset[label] = torch.cat(data_list, dim=0) |
| data_idx[label] = torch.tensor(data_idx[label], dtype=torch.long, device=device) |
|
|
| |
| for curr_label, curr_samples in dataset.items(): |
| curr_indices = data_idx[curr_label] |
| curr_num = len(curr_samples) |
|
|
| |
| chosen_curr_idx = np.random.choice(range(curr_num), curr_num, replace=False) |
| chosen_curr_samples = curr_samples[chosen_curr_idx] |
| chosen_curr_indices = curr_indices[chosen_curr_idx] |
|
|
| |
| neg_samples = torch.cat( |
| [dataset[l] for l in dataset if l != curr_label], dim=0 |
| ) |
| neg_indices = torch.cat( |
| [data_idx[l] for l in data_idx if l != curr_label], dim=0 |
| ) |
| neg_num = min(int(curr_num * unbalance_ratio), len(neg_samples)) |
| chosen_neg_samples = neg_samples[ |
| torch.randperm(len(neg_samples))[:neg_num] |
| ] |
|
|
| |
| combined_samples = torch.cat([chosen_curr_samples, chosen_neg_samples], dim=0) |
| y = torch.cat( |
| [torch.ones(len(chosen_curr_samples)), -torch.ones(neg_num)], dim=0 |
| ).to(device) |
|
|
| |
| H_inner = torch.matmul(combined_samples, combined_samples.T) |
| del combined_samples |
| |
| H_inner = torch.clamp(H_inner, min=-1.0, max=1.0) |
| |
| H = H_inner * (np.pi - torch.acos(H_inner)) / (2 * np.pi) |
| del H_inner |
|
|
| H.fill_diagonal_(0.5) |
| |
| epsilon = 1e-6 |
| H = H + epsilon * torch.eye(H.size(0), device=H.device) |
| |
| invH = torch.inverse(H) |
| del H |
| original_error = y @ (invH @ y) |
|
|
| |
| for i in chosen_curr_indices: |
| print("the node index:", i) |
| for j in range(i + 1, N): |
| |
| if A[i, j] == 0: |
| continue |
|
|
| |
| A1 = A.clone() |
| A1[i, j] = A1[j, i] = 0 |
|
|
| |
| AX1 = torch.matmul(A1, X).to(device) |
| norm_AX1 = AX1 / torch.norm(AX1, dim=1, keepdim=True) |
|
|
| |
| curr_samples_A1 = norm_AX1[chosen_curr_indices] |
| neg_samples_A1 = norm_AX1[neg_indices] |
| chosen_neg_samples_A1 = neg_samples_A1[ |
| torch.randperm(len(neg_samples_A1))[:neg_num] |
| ] |
| combined_samples_A1 = torch.cat( |
| [curr_samples_A1, chosen_neg_samples_A1], dim=0 |
| ) |
| H_inner_A1 = torch.matmul(combined_samples_A1, combined_samples_A1.T) |
|
|
| del combined_samples_A1 |
| |
| |
| H_inner_A1 = torch.clamp(H_inner_A1, min=-1.0, max=1.0) |
| |
|
|
| H_A1 = H_inner_A1 * (np.pi - torch.acos(H_inner_A1)) / (2 * np.pi) |
| del H_inner_A1 |
| H_A1.fill_diagonal_(0.5) |
|
|
| |
| epsilon = 1e-6 |
| H_A1= H_A1 + epsilon * torch.eye(H_A1.size(0), device=H_A1.device) |
| |
| invH_A1 = torch.inverse(H_A1) |
| del H_A1 |
|
|
| error_A1 = y @ (invH_A1 @ y) |
| |
| print("i:", i) |
| print("j:", j) |
| print("current score:", (original_error - error_A1).item()) |
| |
| cg_scores["vi"][i, j] += (original_error - error_A1).item() |
| cg_scores["vi"][j, i] = cg_scores["vi"][i, j] |
| cg_scores["times"][i, j] += 1 |
| cg_scores["times"][j, i] += 1 |
|
|
| |
| for key, values in cg_scores.items(): |
| if key == "times": |
| continue |
| cg_scores[key] = values / np.where(cg_scores["times"] > 0, cg_scores["times"], 1) |
|
|