"""Clifford-algebra (Cl(3,0)) graph network for hierarchical point-cloud part segmentation.

Multivectors have 8 components along the last tensor dimension; the apparent
layout (from how positions are embedded at ``[..., 1:4]`` and how the gate
splits grades) is ``[scalar, e1, e2, e3, bivector x3, pseudoscalar]`` --
TODO(review): confirm the exact basis ordering against the sign table in
``pytorch_clifford_product``.

NOTE(review): this module was reconstructed from a whitespace-mangled source
(all code collapsed onto a few physical lines). Logic and all literals are
preserved verbatim; only formatting, docstrings, and comment translation
(one comment was in Turkish) were changed.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import to_dense_batch

try:
    import triton
    import triton.language as tl
    # Triton kernels only make sense with a CUDA device present.
    HAS_TRITON = torch.cuda.is_available()
except ImportError:
    HAS_TRITON = False


def pytorch_clifford_product(a, b):
    """Geometric (Clifford) product of two 8-component multivectors.

    Args:
        a, b: broadcast-compatible tensors whose last dimension is 8.

    Returns:
        Tensor of the broadcast shape, last dimension 8, holding ``a * b``
        under the hard-coded Cl(3,0) multiplication/sign table below.
    """
    # Expand a and b to a common shape before multiplying (translated from
    # the original Turkish note: "FIX HERE: broadcast a and b to a common
    # size before the product").
    a, b = torch.broadcast_tensors(a, b)
    res = torch.zeros_like(a)
    # Sign table transcribed verbatim from the original source.
    res[..., 0] = (a[..., 0]*b[..., 0] + a[..., 1]*b[..., 1] + a[..., 2]*b[..., 2] + a[..., 3]*b[..., 3]
                   - a[..., 4]*b[..., 4] - a[..., 5]*b[..., 5] - a[..., 6]*b[..., 6] - a[..., 7]*b[..., 7])
    res[..., 1] = (a[..., 0]*b[..., 1] + a[..., 1]*b[..., 0] - a[..., 2]*b[..., 4] + a[..., 3]*b[..., 6]
                   + a[..., 4]*b[..., 2] - a[..., 5]*b[..., 7] - a[..., 6]*b[..., 3] - a[..., 7]*b[..., 5])
    res[..., 2] = (a[..., 0]*b[..., 2] + a[..., 1]*b[..., 4] + a[..., 2]*b[..., 0] - a[..., 3]*b[..., 5]
                   - a[..., 4]*b[..., 1] + a[..., 5]*b[..., 3] - a[..., 6]*b[..., 7] - a[..., 7]*b[..., 6])
    res[..., 3] = (a[..., 0]*b[..., 3] - a[..., 1]*b[..., 6] + a[..., 2]*b[..., 5] + a[..., 3]*b[..., 0]
                   + a[..., 4]*b[..., 7] - a[..., 5]*b[..., 2] + a[..., 6]*b[..., 1] - a[..., 7]*b[..., 4])
    res[..., 4] = (a[..., 0]*b[..., 4] + a[..., 1]*b[..., 2] - a[..., 2]*b[..., 1] + a[..., 3]*b[..., 7]
                   + a[..., 4]*b[..., 0] - a[..., 5]*b[..., 6] + a[..., 6]*b[..., 5] - a[..., 7]*b[..., 3])
    res[..., 5] = (a[..., 0]*b[..., 5] + a[..., 1]*b[..., 7] + a[..., 2]*b[..., 3] - a[..., 3]*b[..., 2]
                   + a[..., 4]*b[..., 6] + a[..., 5]*b[..., 0] - a[..., 6]*b[..., 4] - a[..., 7]*b[..., 1])
    res[..., 6] = (a[..., 0]*b[..., 6] - a[..., 1]*b[..., 3] + a[..., 2]*b[..., 7] + a[..., 3]*b[..., 1]
                   - a[..., 4]*b[..., 5] + a[..., 5]*b[..., 4] + a[..., 6]*b[..., 0] - a[..., 7]*b[..., 2])
    res[..., 7] = (a[..., 0]*b[..., 7] + a[..., 1]*b[..., 5] + a[..., 2]*b[..., 6] + a[..., 3]*b[..., 4]
                   + a[..., 4]*b[..., 3] + a[..., 5]*b[..., 1] + a[..., 6]*b[..., 2] + a[..., 7]*b[..., 0])
    return res


def smart_clifford_product(a, b):
    """Dispatch to the fastest available Clifford-product implementation.

    The Triton fast path is a placeholder: even when ``HAS_TRITON`` is true
    this currently falls through to the pure-PyTorch implementation.
    """
    if HAS_TRITON:
        pass  # TODO: Triton kernel not implemented yet.
    return pytorch_clifford_product(a, b)


class CliffordDiracLayer(MessagePassing):
    """Message-passing layer on per-node multivector features of shape (N, channels, 8).

    Messages are the geometric product of the (8-padded) edge direction with
    the neighbor state, scaled by a learned function of edge length, summed
    per node (``aggr="add"``), and finally gated by a "resonance" score
    derived from grade-wise alignment between state and aggregated message.
    """

    def __init__(self, channels: int):
        super().__init__(aggr="add", node_dim=0)
        self.channels = channels
        # Mixes the channel dimension; applied with the 8-component axis
        # swapped out of the way (see forward()).
        self.weight = nn.Linear(channels, channels, bias=False)
        # Per-edge scalar distance -> per-channel weight.
        self.distance_mlp = nn.Sequential(
            nn.Linear(1, channels), nn.SiLU(), nn.Linear(channels, channels)
        )
        # 7 alignment/magnitude features -> 1 gate logit.
        self.resonance_mlp = nn.Sequential(
            nn.Linear(7, 16), nn.GELU(), nn.Linear(16, 1)
        )

    def forward(self, x, edge_index, v_ij, dist, edge_mask):
        """x: (N, channels, 8); v_ij: (E, 3) unit directions; dist: (E, 1)."""
        # Linear layer acts on channels, so swap channel/component axes
        # around it.
        x_proj = self.weight(x.transpose(1, 2)).transpose(1, 2)
        dist_weight = self.distance_mlp(dist)
        msg = self.propagate(edge_index, x=x_proj, v_ij=v_ij,
                             dist_weight=dist_weight, edge_mask=edge_mask)
        gate = self.resonance_gate(x_proj, msg)
        return gate * msg

    def message(self, x_j, v_ij, dist_weight, edge_mask):
        # Embed the 3-d edge direction into the vector part (components 1:4)
        # of an 8-component multivector: pad 1 zero left, 4 zeros right.
        v_8d = F.pad(v_ij, (1, 4))
        v_8d_exp = v_8d.unsqueeze(1).expand_as(x_j)
        geom_msg = smart_clifford_product(v_8d_exp, x_j)
        msg = geom_msg * dist_weight.unsqueeze(-1)
        # Zero out messages on masked edges.
        return msg * edge_mask.view(-1, 1, 1).float()

    def resonance_gate(self, x_state, msg_state):
        """Gate in (0, 1) computed from grade-wise alignment of state and message.

        The geometric product of the two L2-normalized multivectors is split
        into scalar, vector (1:4), bivector (4:7) and pseudoscalar (7:8)
        parts; together with raw magnitudes and the update norm these form
        the 7 features fed to ``resonance_mlp``.
        """
        x_norm = F.normalize(x_state, dim=-1)
        msg_norm = F.normalize(msg_state, dim=-1)
        gp = smart_clifford_product(x_norm, msg_norm)
        scalar_align = gp[..., 0:1]
        vector_align = torch.norm(gp[..., 1:4], dim=-1, keepdim=True)
        bivector_align = torch.norm(gp[..., 4:7], dim=-1, keepdim=True)
        pseudoscalar_align = torch.abs(gp[..., 7:8])
        state_norm = torch.norm(x_state, dim=-1, keepdim=True)
        msg_norm_mag = torch.norm(msg_state, dim=-1, keepdim=True)
        delta_norm = torch.norm(msg_state - x_state, dim=-1, keepdim=True)
        feats = torch.cat([scalar_align, vector_align, bivector_align,
                           pseudoscalar_align, state_norm, msg_norm_mag,
                           delta_norm], dim=-1)
        return torch.sigmoid(self.resonance_mlp(feats))


class CliffordSelfAttention(nn.Module):
    """Dense self-attention over multivector node states (B, N, channels, 8).

    Attention scores come from a small MLP on the channel-averaged geometric
    product of queries and keys, masked outside valid nodes, with a residual
    connection on the output.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.channels = channels
        self.W_q = nn.Linear(channels, channels, bias=False)
        self.W_k = nn.Linear(channels, channels, bias=False)
        self.W_v = nn.Linear(channels, channels, bias=False)
        # 8 multivector components -> 1 attention logit per node pair.
        self.score_net = nn.Sequential(
            nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 1)
        )

    def forward(self, x_dense, mask):
        """x_dense: (B, N, C, 8); mask: (B, N) bool of valid nodes."""
        B, N, C, _ = x_dense.shape
        # Projections act on the channel axis; swap it with the component
        # axis around each linear layer.
        q = self.W_q(x_dense.transpose(2, 3)).transpose(2, 3)
        k = self.W_k(x_dense.transpose(2, 3)).transpose(2, 3)
        v = self.W_v(x_dense.transpose(2, 3)).transpose(2, 3)
        # Broadcast to all (i, j) node pairs: (B, N, N, C, 8).
        q_expanded = q.unsqueeze(2)
        k_expanded = k.unsqueeze(1)
        geom_prod = smart_clifford_product(q_expanded, k_expanded)
        # Average over channels, then score each pair: (B, N, N).
        scores = self.score_net(geom_prod.mean(dim=3)).squeeze(-1)
        # Large negative (not -inf) keeps softmax finite on fully-masked rows.
        scores = scores.masked_fill(~(mask.unsqueeze(1) & mask.unsqueeze(2)), -10000.0)
        attn = F.softmax(scores / math.sqrt(C), dim=-1)
        return (attn.view(B, N, N, 1, 1) * v.unsqueeze(1)).sum(dim=2) + x_dense


class HierarchicalFPSCliffordNet(nn.Module):
    """Three-level FPS hierarchy of Clifford layers with attention at the coarsest level.

    Level 1 and 2 use :class:`CliffordDiracLayer` message passing; level 3 is
    densified per batch and processed by :class:`CliffordSelfAttention`.
    Features are propagated back down via precomputed 3-NN assignment index
    pairs, concatenated with a category embedding, and classified per point.
    """

    def __init__(self, base_channels: int = 12, num_part_classes: int = 50,
                 num_categories: int = 16):
        super().__init__()
        self.C = base_channels
        self.layer1 = CliffordDiracLayer(base_channels)
        self.layer2 = CliffordDiracLayer(base_channels * 2)
        # Channel-lifting projections between levels, applied on flattened
        # (channels * 8) features.
        self.lin1 = nn.Linear(base_channels * 8, (base_channels * 2) * 8)
        self.lin2 = nn.Linear((base_channels * 2) * 8, (base_channels * 4) * 8)
        self.manager = CliffordSelfAttention(base_channels * 4)
        self.cat_emb = nn.Embedding(num_categories, 64)
        # Concatenation of upsampled level-3 + level-2 features, level-1
        # features, and the 64-d category embedding.
        combined_dim = (base_channels * 4 + base_channels * 2 + base_channels) * 8 + 64
        self.head = nn.Sequential(
            nn.Linear(combined_dim, 64), nn.BatchNorm1d(64), nn.GELU(), nn.Dropout(0.4),
            nn.Linear(64, 64), nn.BatchNorm1d(64), nn.GELU(), nn.Dropout(0.2),
            nn.Linear(64, num_part_classes),
        )

    def forward(self, pos, batch, category, fps_idx_1, fps_idx_2, pos2, batch2,
                pos3, batch3, edge_index_1, edge_index_2, assign_index_32,
                assign_index_21, x_dense_mask):
        """Per-point part logits of shape (num_points, num_part_classes).

        All index tensors (fps_idx_*, edge_index_*, assign_index_*) are
        assumed precomputed by the caller; assign_index_* appear to be 3-NN
        assignments (hence the fixed /3.0 averaging) -- TODO confirm.
        """
        # Initial state: point coordinates embedded in the vector part
        # (components 1:4), replicated across channels.
        x0 = torch.zeros(pos.size(0), self.C, 8, device=pos.device)
        x0[..., 1:4] = pos.unsqueeze(1).expand(-1, self.C, -1)

        # ---- Level 1: message passing on the full point set (residual). ----
        row1, col1 = edge_index_1[1], edge_index_1[0]
        diff1 = pos[row1] - pos[col1]
        d1 = diff1.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        dummy_mask1 = torch.ones(edge_index_1.size(1), dtype=torch.bool, device=pos.device)
        x1 = x0 + self.layer1(x0, edge_index_1, diff1 / d1, d1, dummy_mask1)
        # Normalize by mean per-channel multivector norm, then flatten.
        f1 = (x1 / (x1.norm(dim=-1, keepdim=True).mean(dim=1, keepdim=True) + 1e-6)).reshape(x1.size(0), -1)

        # ---- Level 2: FPS subset, channels doubled via lin1 (residual). ----
        x2_in = self.lin1(x1[fps_idx_1].reshape(fps_idx_1.numel(), -1)).reshape(-1, self.C * 2, 8)
        row2, col2 = edge_index_2[1], edge_index_2[0]
        diff2 = pos2[row2] - pos2[col2]
        d2 = diff2.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        dummy_mask2 = torch.ones(edge_index_2.size(1), dtype=torch.bool, device=pos.device)
        x2 = x2_in + self.layer2(x2_in, edge_index_2, diff2 / d2, d2, dummy_mask2)
        f2 = (x2 / (x2.norm(dim=-1, keepdim=True).mean(dim=1, keepdim=True) + 1e-6)).reshape(x2.size(0), -1)

        # ---- Level 3: densify per batch and run self-attention. ----
        x3_in = self.lin2(x2[fps_idx_2].reshape(fps_idx_2.numel(), -1)).reshape(-1, self.C * 4, 8)
        x_dense, _ = to_dense_batch(x3_in.reshape(x3_in.size(0), -1), batch3,
                                    max_num_nodes=x_dense_mask.size(1))
        x3 = self.manager(x_dense.view(x_dense.size(0), x_dense.size(1), self.C * 4, 8),
                          x_dense_mask)
        f3 = x3[x_dense_mask].reshape(-1, self.C * 4 * 8)

        # ---- Upsample level 3 -> level 2 (mean over assigned neighbors). ----
        row_32, col_32 = assign_index_32[0], assign_index_32[1]
        out_f3_to_pos2 = torch.zeros(pos2.size(0), f3.size(1), device=pos.device)
        out_f3_to_pos2.scatter_add_(0, row_32.unsqueeze(1).expand(-1, f3.size(1)), f3[col_32])
        out_f3_to_pos2 = out_f3_to_pos2 / 3.0  # fixed 3-neighbor average
        f2_combined = torch.cat([out_f3_to_pos2, f2], dim=-1)

        # ---- Upsample level 2 -> level 1. ----
        row_21, col_21 = assign_index_21[0], assign_index_21[1]
        f2_up = torch.zeros(pos.size(0), f2_combined.size(1), device=pos.device)
        f2_up.scatter_add_(0, row_21.unsqueeze(1).expand(-1, f2_combined.size(1)), f2_combined[col_21])
        f2_up = f2_up / 3.0  # fixed 3-neighbor average

        # Per-point category embedding, then classify.
        cat_features = self.cat_emb(category)[batch]
        f1_final = torch.cat([f2_up, f1, cat_features], dim=-1)
        return self.head(f1_final)