Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch_geometric.nn import MessagePassing | |
| from torch_geometric.utils import to_dense_batch | |
| import math | |
| try: | |
| import triton | |
| import triton.language as tl | |
| HAS_TRITON = torch.cuda.is_available() | |
| except ImportError: | |
| HAS_TRITON = False | |
def pytorch_clifford_product(a, b):
    """Return the geometric product ``a * b`` in the Clifford algebra Cl(3,0).

    Multivectors are stored along the last axis with 8 components:

        0: scalar   1: e1   2: e2   3: e3
        4: e12      5: e23  6: e31  7: e123 (pseudoscalar)

    Basis vectors square to +1; bivectors and the pseudoscalar square to -1.
    The pseudoscalar e123 commutes with every element of Cl(3,0). An earlier
    version of this table had four sign errors in e123 terms that broke this
    property and made the product non-associative.

    Args:
        a: tensor whose last dimension is 8 (left operand).
        b: tensor whose last dimension is 8 (right operand); broadcast
            against ``a`` over all dimensions.

    Returns:
        Tensor of the broadcast shape holding the product coefficients.

    Raises:
        ValueError: if the broadcast last dimension is not 8.
    """
    # Broadcast both operands to a common shape before multiplying, so that e.g.
    # an (E, 1, 8) operand can multiply an (E, C, 8) one.
    a, b = torch.broadcast_tensors(a, b)
    if a.size(-1) != 8:
        raise ValueError(
            f"expected Cl(3,0) multivectors with 8 components, got last dim {a.size(-1)}"
        )

    a0, a1, a2, a3, a4, a5, a6, a7 = a.unbind(-1)
    b0, b1, b2, b3, b4, b5, b6, b7 = b.unbind(-1)

    scalar = a0*b0 + a1*b1 + a2*b2 + a3*b3 - a4*b4 - a5*b5 - a6*b6 - a7*b7
    e1 = a0*b1 + a1*b0 - a2*b4 + a3*b6 + a4*b2 - a5*b7 - a6*b3 - a7*b5
    e2 = a0*b2 + a1*b4 + a2*b0 - a3*b5 - a4*b1 + a5*b3 - a6*b7 - a7*b6
    # e12 * e123 = -e3 (was wrongly +).
    e3 = a0*b3 - a1*b6 + a2*b5 + a3*b0 - a4*b7 - a5*b2 + a6*b1 - a7*b4
    # e123 * e3 = +e12 (was wrongly -).
    e12 = a0*b4 + a1*b2 - a2*b1 + a3*b7 + a4*b0 - a5*b6 + a6*b5 + a7*b3
    # e123 * e1 = +e23 (was wrongly -).
    e23 = a0*b5 + a1*b7 + a2*b3 - a3*b2 + a4*b6 + a5*b0 - a6*b4 + a7*b1
    # e123 * e2 = +e31 (was wrongly -).
    e31 = a0*b6 - a1*b3 + a2*b7 + a3*b1 - a4*b5 + a5*b4 + a6*b0 + a7*b2
    e123 = a0*b7 + a1*b5 + a2*b6 + a3*b4 + a4*b3 + a5*b1 + a6*b2 + a7*b0

    return torch.stack([scalar, e1, e2, e3, e12, e23, e31, e123], dim=-1)
def smart_clifford_product(a, b):
    """Compute the Cl(3,0) geometric product ``a * b`` with the available backend.

    Only the pure-PyTorch implementation exists, so every call is routed to
    ``pytorch_clifford_product``. The module-level ``HAS_TRITON`` flag is not
    consulted here.

    Args:
        a: left operand, last dimension 8.
        b: right operand, last dimension 8 (broadcast against ``a``).

    Returns:
        The product tensor returned by ``pytorch_clifford_product``.
    """
    # TODO: dispatch to a fused Triton kernel when HAS_TRITON is True once one
    # is implemented; the previous `if HAS_TRITON: pass` branch was a no-op.
    return pytorch_clifford_product(a, b)
class CliffordDiracLayer(MessagePassing):
    """Message-passing layer on Cl(3,0) multivector node features.

    Node features have shape ``(num_nodes, channels, 8)``. Nodes live on
    dim 0 (``node_dim=0``) and the last axis holds the 8 multivector
    components. Each edge message is the geometric product of the
    edge direction (embedded as a grade-1 multivector) with the projected
    source-node features. It is scaled per channel by an MLP of the edge
    length and summed at the target node. The aggregated message is then
    multiplied by a learned sigmoid "resonance" gate.
    """

    def __init__(self, channels: int):
        super().__init__(aggr="add", node_dim=0)
        self.channels = channels
        # Channel-mixing projection applied independently to each of the 8 components.
        self.weight = nn.Linear(channels, channels, bias=False)
        # Maps an edge length (one value per edge) to one weight per channel.
        self.distance_mlp = nn.Sequential(nn.Linear(1, channels), nn.SiLU(), nn.Linear(channels, channels))
        # Maps the 7 alignment features built in `resonance_gate` to one gate logit.
        self.resonance_mlp = nn.Sequential(nn.Linear(7, 16), nn.GELU(), nn.Linear(16, 1))

    def forward(self, x, edge_index, v_ij, dist, edge_mask):
        """Run one round of gated geometric message passing.

        Args:
            x: node features, ``(num_nodes, channels, 8)``.
            edge_index: PyG edge index; messages flow from ``edge_index[0]``
                (``x_j``) to ``edge_index[1]``.
            v_ij: per-edge direction; padded to 8 components in `message`,
                so it is expected to have 3 components (one per e1/e2/e3).
            dist: per-edge length, last dimension 1 (input of `distance_mlp`).
            edge_mask: one value per edge; edges with a zero/False mask send
                no message.

        Returns:
            Gated aggregated messages, ``(num_nodes, channels, 8)``.
        """
        # Linear acts on the last axis, so move channels there and back.
        x_proj = self.weight(x.transpose(1, 2)).transpose(1, 2)
        dist_weight = self.distance_mlp(dist)
        msg = self.propagate(edge_index, x=x_proj, v_ij=v_ij, dist_weight=dist_weight, edge_mask=edge_mask)
        gate = self.resonance_gate(x_proj, msg)
        return gate * msg

    def message(self, x_j, v_ij, dist_weight, edge_mask):
        """Build one message per edge: ``(v_ij as vector) * x_j``, weighted by distance and mask."""
        # Pad (E, 3) -> (E, 8): one zero in front (scalar slot), four behind
        # (bivector + pseudoscalar slots), so v_ij lands in components 1..3.
        v_8d = F.pad(v_ij, (1, 4))
        v_8d_exp = v_8d.unsqueeze(1).expand_as(x_j)
        geom_msg = smart_clifford_product(v_8d_exp, x_j)
        msg = geom_msg * dist_weight.unsqueeze(-1)
        # Cast the mask to the message dtype (not hard-coded float32) so
        # half/bfloat16 messages are not silently promoted to float32.
        return msg * edge_mask.view(-1, 1, 1).to(msg.dtype)

    def resonance_gate(self, x_state, msg_state):
        """Return a per-(node, channel) sigmoid gate of shape ``(..., 1)``.

        The gate is computed from the geometric product of the unit-normalized
        state and message. Its features are the scalar part, the norms of the
        vector (1..3) and bivector (4..6) parts, and the absolute pseudoscalar
        part. These are combined with the raw norms of the state, the message
        and their difference.
        """
        x_unit = F.normalize(x_state, dim=-1)
        msg_unit = F.normalize(msg_state, dim=-1)
        gp = smart_clifford_product(x_unit, msg_unit)
        scalar_align = gp[..., 0:1]
        vector_align = torch.norm(gp[..., 1:4], dim=-1, keepdim=True)
        bivector_align = torch.norm(gp[..., 4:7], dim=-1, keepdim=True)
        pseudoscalar_align = torch.abs(gp[..., 7:8])
        state_norm = torch.norm(x_state, dim=-1, keepdim=True)
        msg_norm_mag = torch.norm(msg_state, dim=-1, keepdim=True)
        delta_norm = torch.norm(msg_state - x_state, dim=-1, keepdim=True)
        feats = torch.cat(
            [scalar_align, vector_align, bivector_align, pseudoscalar_align, state_norm, msg_norm_mag, delta_norm],
            dim=-1,
        )
        return torch.sigmoid(self.resonance_mlp(feats))
class CliffordSelfAttention(nn.Module):
    """Dense self-attention over padded sets of Cl(3,0) multivector features.

    For every query/key pair, the geometric product of their projected
    features is averaged over channels into a single 8-component multivector,
    which a small MLP turns into a scalar attention logit. Values are mixed
    with the resulting softmax weights and added back to the input (residual).
    """

    def __init__(self, channels: int):
        super().__init__()
        self.channels = channels
        self.W_q = nn.Linear(channels, channels, bias=False)
        self.W_k = nn.Linear(channels, channels, bias=False)
        self.W_v = nn.Linear(channels, channels, bias=False)
        # 8 multivector components -> one attention logit.
        self.score_net = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 1))

    @staticmethod
    def _mix_channels(proj, feats):
        # Apply `proj` over the channel axis (dim 2) of a 4-D (B, N, C, 8) tensor.
        return proj(feats.transpose(2, 3)).transpose(2, 3)

    def forward(self, x_dense, mask):
        """Attend within each padded set.

        Args:
            x_dense: padded features, 4-D ``(B, N, C, 8)``.
            mask: boolean ``(B, N)``; True marks real entries. Any pair with
                a padded query or key gets logit -10000 before the softmax.

        Returns:
            ``(B, N, C, 8)``: attention output plus the residual ``x_dense``.
        """
        _, _, num_channels, _ = x_dense.shape
        query, key, value = (
            self._mix_channels(proj, x_dense) for proj in (self.W_q, self.W_k, self.W_v)
        )

        # All query/key geometric products: (B, N, 1, C, 8) x (B, 1, N, C, 8) -> (B, N, N, C, 8).
        pair_prod = smart_clifford_product(query[:, :, None], key[:, None, :])
        logits = self.score_net(pair_prod.mean(dim=3)).squeeze(-1)

        valid_pair = mask[:, None, :] & mask[:, :, None]
        logits = logits.masked_fill(~valid_pair, -10000.0)
        weights = F.softmax(logits / math.sqrt(num_channels), dim=-1)

        # Weighted sum of values over keys (dim 2), then residual connection.
        attended = (weights[..., None, None] * value[:, None]).sum(dim=2)
        return attended + x_dense
class HierarchicalFPSCliffordNet(nn.Module):
    """Three-level Cl(3,0) point network producing per-point class logits.

    Level 1 runs a `CliffordDiracLayer` on all points. Level 2 runs one on
    the subset ``fps_idx_1`` (positions ``pos2``). Level 3 applies dense
    `CliffordSelfAttention` on the subset ``fps_idx_2``. Level-3 features
    are interpolated to level 2 and the result to level 1, where they are
    concatenated with level-1 skip features and a per-category embedding.
    An MLP head then outputs ``num_part_classes`` logits per point.
    """

    def __init__(self, base_channels: int = 12, num_part_classes: int = 50, num_categories: int = 16):
        super().__init__()
        self.C = base_channels
        self.layer1 = CliffordDiracLayer(base_channels)
        self.layer2 = CliffordDiracLayer(base_channels * 2)
        self.lin1 = nn.Linear(base_channels * 8, (base_channels * 2) * 8)
        self.lin2 = nn.Linear((base_channels * 2) * 8, (base_channels * 4) * 8)
        self.manager = CliffordSelfAttention(base_channels * 4)
        self.cat_emb = nn.Embedding(num_categories, 64)
        # Level-3 (4C) + level-2 (2C) + level-1 (C) channels x 8 components, plus the 64-d category embedding.
        combined_dim = (base_channels * 4 + base_channels * 2 + base_channels) * 8 + 64
        self.head = nn.Sequential(
            nn.Linear(combined_dim, 64), nn.BatchNorm1d(64), nn.GELU(), nn.Dropout(0.4),
            nn.Linear(64, 64), nn.BatchNorm1d(64), nn.GELU(), nn.Dropout(0.2),
            nn.Linear(64, num_part_classes)
        )

    @staticmethod
    def _edge_geometry(pos, edge_index):
        """Return (unit direction, clamped length, all-True mask) for every edge.

        Directions point from the source ``edge_index[0]`` to the target
        ``edge_index[1]``; lengths are clamped to 1e-8 to avoid dividing by zero.
        """
        target, source = edge_index[1], edge_index[0]
        diff = pos[target] - pos[source]
        dist = diff.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        mask = torch.ones(edge_index.size(1), dtype=torch.bool, device=pos.device)
        return diff / dist, dist, mask

    @staticmethod
    def _flatten_normalized(x):
        """Scale (N, K, 8) features by each node's mean component norm (+1e-6), then flatten to (N, K*8)."""
        scale = x.norm(dim=-1, keepdim=True).mean(dim=1, keepdim=True) + 1e-6
        return (x / scale).reshape(x.size(0), -1)

    @staticmethod
    def _interpolate(src, assign_index, num_out):
        """Sum ``src[assign_index[1]]`` into rows ``assign_index[0]`` of a ``(num_out, F)`` buffer, divided by 3.

        The buffer is created with ``src``'s dtype and device, so it also
        works for non-float32 models.
        """
        out = src.new_zeros(num_out, src.size(1))
        out.index_add_(0, assign_index[0], src[assign_index[1]])
        # NOTE(review): fixed divisor 3 presumably assumes every target row has
        # exactly 3 assigned sources (k=3 interpolation); rows with a different
        # count are not true means. TODO confirm against the data pipeline.
        return out / 3.0

    def forward(self, pos, batch, category, fps_idx_1, fps_idx_2, pos2, batch2, pos3, batch3,
                edge_index_1, edge_index_2, assign_index_32, assign_index_21, x_dense_mask):
        """Compute per-point logits of shape ``(pos.size(0), num_part_classes)``.

        ``batch2`` and ``pos3`` are not used by this implementation; they are
        kept for interface compatibility. ``x_dense_mask`` is the boolean
        ``(B, N_max)`` validity mask for the dense level-3 layout.
        """
        # ---- Level 1: embed positions into the vector components (1..3) of every channel.
        x0 = pos.new_zeros(pos.size(0), self.C, 8)
        x0[..., 1:4] = pos.unsqueeze(1).expand(-1, self.C, -1)
        dir1, d1, mask1 = self._edge_geometry(pos, edge_index_1)
        x1 = x0 + self.layer1(x0, edge_index_1, dir1, d1, mask1)
        f1 = self._flatten_normalized(x1)

        # ---- Level 2: project the sampled level-1 features to 2C channels.
        x2_in = self.lin1(x1[fps_idx_1].reshape(fps_idx_1.numel(), -1)).reshape(-1, self.C * 2, 8)
        dir2, d2, mask2 = self._edge_geometry(pos2, edge_index_2)
        x2 = x2_in + self.layer2(x2_in, edge_index_2, dir2, d2, mask2)
        f2 = self._flatten_normalized(x2)

        # ---- Level 3: project to 4C channels and attend within each padded batch element.
        x3_in = self.lin2(x2[fps_idx_2].reshape(fps_idx_2.numel(), -1)).reshape(-1, self.C * 4, 8)
        x_dense, _ = to_dense_batch(
            x3_in.reshape(x3_in.size(0), -1), batch3, max_num_nodes=x_dense_mask.size(1)
        )
        x3 = self.manager(x_dense.view(x_dense.size(0), x_dense.size(1), self.C * 4, 8), x_dense_mask)
        f3 = x3[x_dense_mask].reshape(-1, self.C * 4 * 8)

        # ---- Decoder: interpolate level 3 -> 2 -> 1, concatenating skip features.
        f2_combined = torch.cat([self._interpolate(f3, assign_index_32, pos2.size(0)), f2], dim=-1)
        f2_up = self._interpolate(f2_combined, assign_index_21, pos.size(0))
        cat_features = self.cat_emb(category)[batch]
        f1_final = torch.cat([f2_up, f1, cat_features], dim=-1)
        return self.head(f1_final)