# Repository page header (not code): 3dpartsegmentation / model.py
# yusuf-tiryaki - commit "fix model" (0fb3da6)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import to_dense_batch
import math
# Triton is optional. HAS_TRITON is False when the import fails; otherwise it
# mirrors whether torch reports an available CUDA device.
try:
    import triton
    import triton.language as tl
except ImportError:
    HAS_TRITON = False
else:
    HAS_TRITON = torch.cuda.is_available()
def pytorch_clifford_product(a, b):
    """Return the Cl(3,0) geometric product ``a * b`` of two multivectors.

    The last axis of each operand holds the 8 multivector components in the
    order ``[1, e1, e2, e3, e12, e23, e31, e123]``.

    Args:
        a: Left operand; its last dimension must have size 8.
        b: Right operand; its last dimension must have size 8. The two
            operands are broadcast to a common shape before multiplying.

    Returns:
        A new tensor with the broadcast shape of ``a`` and ``b``.
    """
    # Expand both operands to a common shape before multiplying (broadcast).
    a, b = torch.broadcast_tensors(a, b)
    a0, a1, a2, a3, a4, a5, a6, a7 = a.unbind(dim=-1)
    b0, b1, b2, b3, b4, b5, b6, b7 = b.unbind(dim=-1)

    # The pseudoscalar e123 commutes with every basis element, so each
    # a_k*b7 term must carry the same sign as its a7*b_k counterpart.
    # Mismatched signs there make the product non-associative.
    scalar = a0*b0 + a1*b1 + a2*b2 + a3*b3 - a4*b4 - a5*b5 - a6*b6 - a7*b7
    e1 = a0*b1 + a1*b0 - a2*b4 + a3*b6 + a4*b2 - a5*b7 - a6*b3 - a7*b5
    e2 = a0*b2 + a1*b4 + a2*b0 - a3*b5 - a4*b1 + a5*b3 - a6*b7 - a7*b6
    e3 = a0*b3 - a1*b6 + a2*b5 + a3*b0 - a4*b7 - a5*b2 + a6*b1 - a7*b4
    e12 = a0*b4 + a1*b2 - a2*b1 + a3*b7 + a4*b0 - a5*b6 + a6*b5 + a7*b3
    e23 = a0*b5 + a1*b7 + a2*b3 - a3*b2 + a4*b6 + a5*b0 - a6*b4 + a7*b1
    e31 = a0*b6 - a1*b3 + a2*b7 + a3*b1 - a4*b5 + a5*b4 + a6*b0 + a7*b2
    e123 = a0*b7 + a1*b5 + a2*b6 + a3*b4 + a4*b3 + a5*b1 + a6*b2 + a7*b0

    # Stacking avoids in-place writes into a pre-allocated result tensor.
    return torch.stack([scalar, e1, e2, e3, e12, e23, e31, e123], dim=-1)
def smart_clifford_product(a, b):
    """Compute the Clifford product of ``a`` and ``b``.

    This is the single entry point the layers call. No Triton kernel is
    implemented, so the call always goes to ``pytorch_clifford_product``
    regardless of ``HAS_TRITON``.

    Args:
        a: Left operand, forwarded unchanged.
        b: Right operand, forwarded unchanged.

    Returns:
        Whatever ``pytorch_clifford_product(a, b)`` returns.
    """
    # The former `if HAS_TRITON: pass` branch did nothing and was dropped.
    return pytorch_clifford_product(a, b)
class CliffordDiracLayer(MessagePassing):
    """Message-passing layer built on the Clifford product.

    Each message is the Clifford product of the edge direction (embedded in
    components 1..3 of an 8-component vector) with the neighbour features.
    Messages are scaled by a learned function of edge length, sum-aggregated,
    and finally multiplied by a learned sigmoid gate.
    """

    def __init__(self, channels: int):
        """Create the channel mixer, the distance MLP and the gate MLP."""
        super().__init__(aggr="add", node_dim=0)
        self.channels = channels
        # Bias-free mixing along the channel axis.
        self.weight = nn.Linear(channels, channels, bias=False)
        # Maps one value per edge to one weight per channel.
        self.distance_mlp = nn.Sequential(
            nn.Linear(1, channels),
            nn.SiLU(),
            nn.Linear(channels, channels),
        )
        # Maps the 7 features built in resonance_gate to one gate logit.
        self.resonance_mlp = nn.Sequential(
            nn.Linear(7, 16),
            nn.GELU(),
            nn.Linear(16, 1),
        )

    def forward(self, x, edge_index, v_ij, dist, edge_mask):
        """Return the gated, aggregated messages for every node.

        Args:
            x: Node features; assumes (num_nodes, channels, 8) - TODO confirm.
            edge_index: Edge list handed to ``propagate``.
            v_ij: Per-edge direction; presumably (num_edges, 3) - confirm.
            dist: Per-edge input to ``distance_mlp`` (last dimension 1).
            edge_mask: Per-edge mask; messages are multiplied by it.
        """
        # nn.Linear acts on the last axis, so swap axes 1 and 2 around it.
        projected = self.weight(x.transpose(1, 2)).transpose(1, 2)
        aggregated = self.propagate(
            edge_index,
            x=projected,
            v_ij=v_ij,
            dist_weight=self.distance_mlp(dist),
            edge_mask=edge_mask,
        )
        return self.resonance_gate(projected, aggregated) * aggregated

    def message(self, x_j, v_ij, dist_weight, edge_mask):
        """Build one message per edge from neighbour features and edge data."""
        # One zero in front and four behind: v_ij lands in components 1..3.
        direction = F.pad(v_ij, (1, 4)).unsqueeze(1).expand_as(x_j)
        weighted = smart_clifford_product(direction, x_j) * dist_weight.unsqueeze(-1)
        keep = edge_mask.view(-1, 1, 1).float()
        return weighted * keep

    def resonance_gate(self, x_state, msg_state):
        """Return a sigmoid gate computed from seven alignment/magnitude features.

        The features are the scalar, vector, bivector and pseudoscalar
        parts of the product of the normalized inputs (components 0, 1..3,
        4..6 and 7), followed by the norms of ``x_state``, ``msg_state`` and
        their difference.
        """
        product = smart_clifford_product(
            F.normalize(x_state, dim=-1),
            F.normalize(msg_state, dim=-1),
        )
        features = [
            product[..., 0:1],
            torch.norm(product[..., 1:4], dim=-1, keepdim=True),
            torch.norm(product[..., 4:7], dim=-1, keepdim=True),
            torch.abs(product[..., 7:8]),
            torch.norm(x_state, dim=-1, keepdim=True),
            torch.norm(msg_state, dim=-1, keepdim=True),
            torch.norm(msg_state - x_state, dim=-1, keepdim=True),
        ]
        return torch.sigmoid(self.resonance_mlp(torch.cat(features, dim=-1)))
class CliffordSelfAttention(nn.Module):
    """Dense self-attention whose scores come from pairwise Clifford products."""

    def __init__(self, channels: int):
        """Create the bias-free query/key/value channel mixers and the score MLP."""
        super().__init__()
        self.channels = channels
        self.W_q = nn.Linear(channels, channels, bias=False)
        self.W_k = nn.Linear(channels, channels, bias=False)
        self.W_v = nn.Linear(channels, channels, bias=False)
        # Maps an 8-component product (averaged over channels) to one score.
        self.score_net = nn.Sequential(
            nn.Linear(8, 16),
            nn.GELU(),
            nn.Linear(16, 1),
        )

    def forward(self, x_dense, mask):
        """Attend over the node axis and add the input back as a residual.

        Args:
            x_dense: 4-D tensor unpacked as (B, N, C, last); the last axis
                must have size 8 for ``score_net``.
            mask: Boolean tensor; assumes (B, N) marking valid nodes - TODO
                confirm against the caller.

        Returns:
            Tensor with the same shape as ``x_dense``.
        """
        batch_size, num_nodes, num_channels, _ = x_dense.shape

        # nn.Linear mixes the last axis, so move channels there and back.
        channels_last = x_dense.transpose(2, 3)
        queries = self.W_q(channels_last).transpose(2, 3)
        keys = self.W_k(channels_last).transpose(2, 3)
        values = self.W_v(channels_last).transpose(2, 3)

        # Broadcasting query axis 2 against key axis 1 yields every pair.
        pairwise = smart_clifford_product(queries.unsqueeze(2), keys.unsqueeze(1))
        logits = self.score_net(pairwise.mean(dim=3)).squeeze(-1)

        # A pair is valid only when both of its nodes are valid.
        valid_pairs = mask.unsqueeze(1) & mask.unsqueeze(2)
        logits = logits.masked_fill(~valid_pairs, -10000.0)
        weights = F.softmax(logits / math.sqrt(num_channels), dim=-1)

        weights = weights.view(batch_size, num_nodes, num_nodes, 1, 1)
        attended = (weights * values.unsqueeze(1)).sum(dim=2)
        return attended + x_dense
class HierarchicalFPSCliffordNet(nn.Module):
    """Three-level network producing per-point part-class logits.

    Two ``CliffordDiracLayer`` stages run on successively subsampled point
    sets, and a ``CliffordSelfAttention`` stage runs on the coarsest set.
    Coarse features are scattered back to the finer levels, concatenated with
    a category embedding, and passed to an MLP head.
    """

    def __init__(self, base_channels: int = 12, num_part_classes: int = 50, num_categories: int = 16):
        """Build the three stages, the category embedding and the head."""
        super().__init__()
        self.C = base_channels
        width1 = base_channels
        width2 = base_channels * 2
        width3 = base_channels * 4

        self.layer1 = CliffordDiracLayer(width1)
        self.layer2 = CliffordDiracLayer(width2)
        # Linear maps between levels act on flattened (channels * 8) features.
        self.lin1 = nn.Linear(width1 * 8, width2 * 8)
        self.lin2 = nn.Linear(width2 * 8, width3 * 8)
        self.manager = CliffordSelfAttention(width3)
        self.cat_emb = nn.Embedding(num_categories, 64)

        # Head input: all three levels flattened plus the 64-d embedding.
        combined_dim = (width3 + width2 + width1) * 8 + 64
        self.head = nn.Sequential(
            nn.Linear(combined_dim, 64),
            nn.BatchNorm1d(64),
            nn.GELU(),
            nn.Dropout(0.4),
            nn.Linear(64, 64),
            nn.BatchNorm1d(64),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(64, num_part_classes),
        )

    @staticmethod
    def _edge_geometry(points, edge_index, device):
        """Return unit offsets, clamped lengths and an all-True edge mask."""
        offsets = points[edge_index[1]] - points[edge_index[0]]
        # Clamp so zero-length edges do not divide by zero.
        lengths = offsets.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        all_edges = torch.ones(edge_index.size(1), dtype=torch.bool, device=device)
        return offsets / lengths, lengths, all_edges

    @staticmethod
    def _flat_normalized(x):
        """Divide by the per-node mean component norm, then flatten per node."""
        scale = x.norm(dim=-1, keepdim=True).mean(dim=1, keepdim=True) + 1e-6
        return (x / scale).reshape(x.size(0), -1)

    @staticmethod
    def _scatter_third(features, assign_index, num_targets, device):
        """Sum ``features[assign_index[1]]`` into rows ``assign_index[0]``, then divide by 3."""
        targets, sources = assign_index[0], assign_index[1]
        width = features.size(1)
        out = torch.zeros(num_targets, width, device=device)
        out.scatter_add_(0, targets.unsqueeze(1).expand(-1, width), features[sources])
        # NOTE(review): the fixed 3.0 presumably assumes three sources per
        # target row - confirm against how the assign indices are built.
        return out / 3.0

    def forward(self, pos, batch, category, fps_idx_1, fps_idx_2, pos2, batch2, pos3, batch3,
                edge_index_1, edge_index_2, assign_index_32, assign_index_21, x_dense_mask):
        """Compute logits for every row of ``pos``.

        ``batch2`` and ``pos3`` are accepted but not read by this method.

        Returns:
            Output of ``self.head``; one row per row of ``pos``.
        """
        device = pos.device

        # Level 1: positions are written into components 1..3 of every channel.
        x0 = torch.zeros(pos.size(0), self.C, 8, device=device)
        x0[..., 1:4] = pos.unsqueeze(1).expand(-1, self.C, -1)
        dir1, len1, mask1 = self._edge_geometry(pos, edge_index_1, device)
        x1 = x0 + self.layer1(x0, edge_index_1, dir1, len1, mask1)
        f1 = self._flat_normalized(x1)

        # Level 2: rows picked by fps_idx_1, widened to 2*C channels.
        picked1 = x1[fps_idx_1].reshape(fps_idx_1.numel(), -1)
        x2_in = self.lin1(picked1).reshape(-1, self.C * 2, 8)
        dir2, len2, mask2 = self._edge_geometry(pos2, edge_index_2, device)
        x2 = x2_in + self.layer2(x2_in, edge_index_2, dir2, len2, mask2)
        f2 = self._flat_normalized(x2)

        # Level 3: rows picked by fps_idx_2, widened to 4*C channels, then
        # attended in dense (batched) layout.
        picked2 = x2[fps_idx_2].reshape(fps_idx_2.numel(), -1)
        x3_in = self.lin2(picked2).reshape(-1, self.C * 4, 8)
        x_dense, _ = to_dense_batch(
            x3_in.reshape(x3_in.size(0), -1),
            batch3,
            max_num_nodes=x_dense_mask.size(1),
        )
        x_dense = x_dense.view(x_dense.size(0), x_dense.size(1), self.C * 4, 8)
        x3 = self.manager(x_dense, x_dense_mask)
        # Keep only the entries selected by the supplied mask.
        f3 = x3[x_dense_mask].reshape(-1, self.C * 4 * 8)

        # Carry features back up: level 3 -> level 2 -> level 1.
        f3_on_level2 = self._scatter_third(f3, assign_index_32, pos2.size(0), device)
        f2_combined = torch.cat([f3_on_level2, f2], dim=-1)
        f2_up = self._scatter_third(f2_combined, assign_index_21, pos.size(0), device)

        # Look up each row's category embedding through ``batch``.
        cat_features = self.cat_emb(category)[batch]
        return self.head(torch.cat([f2_up, f1, cat_features], dim=-1))