yusuf-tiryaki commited on
Commit
fd1bd17
·
1 Parent(s): 3e83674

fix torch cluster

Browse files
Files changed (3) hide show
  1. app.py +42 -34
  2. model.py +0 -6
  3. requirements.txt +0 -3
app.py CHANGED
@@ -3,12 +3,40 @@ import torch
3
  import numpy as np
4
  import trimesh
5
  import plotly.graph_objects as go
6
- from torch_geometric.data import Data
7
- from torch_cluster import fps, radius_graph, knn
8
  from torch_geometric.utils import to_dense_batch
9
-
10
  from model import HierarchicalFPSCliffordNet
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  # --- Kategori Sözlükleri ---
13
  CATEGORIES = {
14
  'Airplane': 0, 'Bag': 1, 'Cap': 2, 'Car': 3, 'Chair': 4, 'Earphone': 5, 'Guitar': 6,
@@ -24,7 +52,7 @@ SEG_CLASSES = {
24
  }
25
 
26
  # --- Modeli Yükleme ---
27
- device = torch.device("cpu") # HF Free Tier her zaman CPU'dur
28
  model = HierarchicalFPSCliffordNet(base_channels=12).to(device)
29
 
30
  try:
@@ -38,11 +66,9 @@ except Exception as e:
38
 
39
  # --- Veri Hazırlama Fonksiyonu ---
40
  def process_point_cloud(file_path, category_name):
41
- # Dosyayı oku ve noktalara çevir
42
  mesh = trimesh.load(file_path, force='mesh')
43
  points = np.array(mesh.vertices)
44
 
45
- # 1024 noktaya örnekle (Eğitimdeki gibi)
46
  if len(points) > 1024:
47
  idx = np.random.choice(len(points), 1024, replace=False)
48
  points = points[idx]
@@ -51,8 +77,6 @@ def process_point_cloud(file_path, category_name):
51
  points = points[idx]
52
 
53
  pos = torch.tensor(points, dtype=torch.float32)
54
-
55
- # Birim Küreye Normalize Et
56
  pos = pos - pos.mean(dim=0, keepdim=True)
57
  scale = pos.norm(dim=1).max().clamp(min=1e-8)
58
  pos = pos / scale
@@ -60,68 +84,52 @@ def process_point_cloud(file_path, category_name):
60
  batch = torch.zeros(pos.size(0), dtype=torch.long)
61
  cat_idx = torch.tensor([CATEGORIES[category_name]], dtype=torch.long)
62
 
63
- # Graf İndekslerini Hesapla (Offline Preprocessing simülasyonu)
64
- fps_idx_1 = fps(pos, batch, ratio=0.5).long()
65
  pos2 = pos[fps_idx_1]
66
  batch2 = torch.zeros(pos2.size(0), dtype=torch.long)
67
 
68
- fps_idx_2 = fps(pos2, batch2, ratio=0.25).long()
69
  pos3 = pos2[fps_idx_2]
70
  batch3 = torch.zeros(pos3.size(0), dtype=torch.long)
71
 
72
- edge_index_1 = radius_graph(pos, r=0.15, batch=batch, max_num_neighbors=16, loop=True)
73
- edge_index_2 = radius_graph(pos2, r=0.30, batch=batch2, max_num_neighbors=32, loop=True)
74
 
75
- assign_index_32 = knn(pos3, pos2, k=3, batch_x=batch3, batch_y=batch2)
76
- assign_index_21 = knn(pos2, pos, k=3, batch_x=batch2, batch_y=batch)
77
 
78
  dummy_x = torch.zeros(pos3.size(0), 1)
79
  _, x_dense_mask = to_dense_batch(dummy_x, batch3, max_num_nodes=int(torch.bincount(batch3).max().item()))
80
 
81
  return pos, batch, cat_idx, fps_idx_1, fps_idx_2, pos2, batch2, pos3, batch3, edge_index_1, edge_index_2, assign_index_32, assign_index_21, x_dense_mask
82
 
83
- # --- Çıkarım (Inference) Fonksiyonu ---
84
  def predict(file, category_name):
85
  if file is None:
86
  return None
87
 
88
- # Veriyi modele uygun formata getir
89
  (pos, batch, cat_idx, fps_idx_1, fps_idx_2, pos2, batch2, pos3, batch3,
90
  edge_index_1, edge_index_2, assign_index_32, assign_index_21, x_dense_mask) = process_point_cloud(file.name, category_name)
91
 
92
- # Modelden Geçir
93
  with torch.no_grad():
94
  logits = model(pos, batch, cat_idx, fps_idx_1, fps_idx_2, pos2, batch2, pos3, batch3,
95
  edge_index_1, edge_index_2, assign_index_32, assign_index_21, x_dense_mask)
96
 
97
- # O kategoriye ait olmayan parçaları maskele
98
  valid_parts = SEG_CLASSES[category_name]
99
  mask = torch.zeros_like(logits, dtype=torch.bool)
100
  mask[:, valid_parts] = True
101
  logits = logits.masked_fill(~mask, -10000.0)
102
-
103
- # Sınıfı Tahmin Et
104
  preds = logits.argmax(dim=-1).numpy()
105
 
106
- # Plotly ile Görselleştir
107
  pos_np = pos.numpy()
108
  fig = go.Figure(data=[go.Scatter3d(
109
  x=pos_np[:, 0], y=pos_np[:, 1], z=pos_np[:, 2],
110
  mode='markers',
111
- marker=dict(
112
- size=4,
113
- color=preds, # Her parçaya farklı renk atar
114
- colorscale='Viridis',
115
- opacity=0.9,
116
- line=dict(width=0.5, color='black') # Noktalara derinlik katar
117
- )
118
  )])
119
 
120
- fig.update_layout(
121
- margin=dict(l=0, r=0, b=0, t=0),
122
- scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False))
123
- )
124
-
125
  return fig
126
 
127
  # --- Gradio Arayüzü ---
 
3
  import numpy as np
4
  import trimesh
5
  import plotly.graph_objects as go
 
 
6
  from torch_geometric.utils import to_dense_batch
 
7
  from model import HierarchicalFPSCliffordNet
8
 
9
+ # --- SAF PYTORCH C++ YEDEKLERİ (Hugging Face için Özel Çözüm) ---
10
+ def knn_pure(x, y, k):
11
+ dist = torch.cdist(x, y)
12
+ _, topk_idx = torch.topk(dist, k, dim=1, largest=False)
13
+ col = torch.arange(x.size(0)).view(-1, 1).expand(-1, k).reshape(-1)
14
+ row = topk_idx.reshape(-1)
15
+ return torch.stack([row, col], dim=0)
16
+
17
+ def radius_graph_pure(pos, r, max_num_neighbors=32):
18
+ dist = torch.cdist(pos, pos)
19
+ n = pos.size(0)
20
+ k = min(max_num_neighbors, n)
21
+ _, target = torch.topk(dist, k, dim=1, largest=False)
22
+ source = torch.arange(n).view(-1, 1).expand(-1, k)
23
+ valid = dist[source, target] < r
24
+ return torch.stack([source[valid], target[valid]], dim=0)
25
+
26
+ def fps_pure(pos, ratio=0.5):
27
+ n = pos.size(0)
28
+ k = max(1, int(n * ratio))
29
+ idx = torch.zeros(k, dtype=torch.long, device=pos.device)
30
+ dist = torch.full((n,), 1e10, device=pos.device)
31
+ farthest = 0
32
+ for i in range(k):
33
+ idx[i] = farthest
34
+ farthest_point = pos[farthest].view(1, 3)
35
+ dist_to_farthest = torch.cdist(pos, farthest_point).squeeze(-1)
36
+ dist = torch.min(dist, dist_to_farthest)
37
+ farthest = torch.argmax(dist).item()
38
+ return idx
39
+
40
  # --- Kategori Sözlükleri ---
41
  CATEGORIES = {
42
  'Airplane': 0, 'Bag': 1, 'Cap': 2, 'Car': 3, 'Chair': 4, 'Earphone': 5, 'Guitar': 6,
 
52
  }
53
 
54
  # --- Modeli Yükleme ---
55
+ device = torch.device("cpu")
56
  model = HierarchicalFPSCliffordNet(base_channels=12).to(device)
57
 
58
  try:
 
66
 
67
  # --- Veri Hazırlama Fonksiyonu ---
68
  def process_point_cloud(file_path, category_name):
 
69
  mesh = trimesh.load(file_path, force='mesh')
70
  points = np.array(mesh.vertices)
71
 
 
72
  if len(points) > 1024:
73
  idx = np.random.choice(len(points), 1024, replace=False)
74
  points = points[idx]
 
77
  points = points[idx]
78
 
79
  pos = torch.tensor(points, dtype=torch.float32)
 
 
80
  pos = pos - pos.mean(dim=0, keepdim=True)
81
  scale = pos.norm(dim=1).max().clamp(min=1e-8)
82
  pos = pos / scale
 
84
  batch = torch.zeros(pos.size(0), dtype=torch.long)
85
  cat_idx = torch.tensor([CATEGORIES[category_name]], dtype=torch.long)
86
 
87
+ # Saf PyTorch ile Hızlı Hesaplamalar (C++ Yok)
88
+ fps_idx_1 = fps_pure(pos, ratio=0.5)
89
  pos2 = pos[fps_idx_1]
90
  batch2 = torch.zeros(pos2.size(0), dtype=torch.long)
91
 
92
+ fps_idx_2 = fps_pure(pos2, ratio=0.25)
93
  pos3 = pos2[fps_idx_2]
94
  batch3 = torch.zeros(pos3.size(0), dtype=torch.long)
95
 
96
+ edge_index_1 = radius_graph_pure(pos, r=0.15, max_num_neighbors=16)
97
+ edge_index_2 = radius_graph_pure(pos2, r=0.30, max_num_neighbors=32)
98
 
99
+ assign_index_32 = knn_pure(pos3, pos2, k=3)
100
+ assign_index_21 = knn_pure(pos2, pos, k=3)
101
 
102
  dummy_x = torch.zeros(pos3.size(0), 1)
103
  _, x_dense_mask = to_dense_batch(dummy_x, batch3, max_num_nodes=int(torch.bincount(batch3).max().item()))
104
 
105
  return pos, batch, cat_idx, fps_idx_1, fps_idx_2, pos2, batch2, pos3, batch3, edge_index_1, edge_index_2, assign_index_32, assign_index_21, x_dense_mask
106
 
107
+ # --- Çıkarım Fonksiyonu ---
108
  def predict(file, category_name):
109
  if file is None:
110
  return None
111
 
 
112
  (pos, batch, cat_idx, fps_idx_1, fps_idx_2, pos2, batch2, pos3, batch3,
113
  edge_index_1, edge_index_2, assign_index_32, assign_index_21, x_dense_mask) = process_point_cloud(file.name, category_name)
114
 
 
115
  with torch.no_grad():
116
  logits = model(pos, batch, cat_idx, fps_idx_1, fps_idx_2, pos2, batch2, pos3, batch3,
117
  edge_index_1, edge_index_2, assign_index_32, assign_index_21, x_dense_mask)
118
 
 
119
  valid_parts = SEG_CLASSES[category_name]
120
  mask = torch.zeros_like(logits, dtype=torch.bool)
121
  mask[:, valid_parts] = True
122
  logits = logits.masked_fill(~mask, -10000.0)
 
 
123
  preds = logits.argmax(dim=-1).numpy()
124
 
 
125
  pos_np = pos.numpy()
126
  fig = go.Figure(data=[go.Scatter3d(
127
  x=pos_np[:, 0], y=pos_np[:, 1], z=pos_np[:, 2],
128
  mode='markers',
129
+ marker=dict(size=4, color=preds, colorscale='Viridis', opacity=0.9, line=dict(width=0.5, color='black'))
 
 
 
 
 
 
130
  )])
131
 
132
+ fig.update_layout(margin=dict(l=0, r=0, b=0, t=0), scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False)))
 
 
 
 
133
  return fig
134
 
135
  # --- Gradio Arayüzü ---
model.py CHANGED
@@ -3,10 +3,8 @@ import torch.nn as nn
3
  import torch.nn.functional as F
4
  from torch_geometric.nn import MessagePassing
5
  from torch_geometric.utils import to_dense_batch
6
- from torch_cluster import fps, radius_graph, knn
7
  import math
8
 
9
- # --- TRITON & CPU FALLBACK KONTROLÜ ---
10
  try:
11
  import triton
12
  import triton.language as tl
@@ -15,7 +13,6 @@ except ImportError:
15
  HAS_TRITON = False
16
 
17
  def pytorch_clifford_product(a, b):
18
- """Hugging Face Free Tier (CPU) için saf PyTorch 8D Geometrik Çarpım"""
19
  res = torch.zeros_like(a)
20
  res[..., 0] = a[...,0]*b[...,0] + a[...,1]*b[...,1] + a[...,2]*b[...,2] + a[...,3]*b[...,3] - a[...,4]*b[...,4] - a[...,5]*b[...,5] - a[...,6]*b[...,6] - a[...,7]*b[...,7]
21
  res[..., 1] = a[...,0]*b[...,1] + a[...,1]*b[...,0] - a[...,2]*b[...,4] + a[...,3]*b[...,6] + a[...,4]*b[...,2] - a[...,5]*b[...,7] - a[...,6]*b[...,3] - a[...,7]*b[...,5]
@@ -28,13 +25,10 @@ def pytorch_clifford_product(a, b):
28
  return res
29
 
30
  def smart_clifford_product(a, b):
31
- # Eğer ortamda GPU ve Triton yoksa, CPU versiyonunu kullan.
32
  if HAS_TRITON:
33
- # Eğitilmiş Triton kodlarınız buraya gelir (Çıkarım için genellikle CPU yeterlidir)
34
  pass
35
  return pytorch_clifford_product(a, b)
36
 
37
-
38
  class CliffordDiracLayer(MessagePassing):
39
  def __init__(self, channels: int):
40
  super().__init__(aggr="add", node_dim=0)
 
3
  import torch.nn.functional as F
4
  from torch_geometric.nn import MessagePassing
5
  from torch_geometric.utils import to_dense_batch
 
6
  import math
7
 
 
8
  try:
9
  import triton
10
  import triton.language as tl
 
13
  HAS_TRITON = False
14
 
15
  def pytorch_clifford_product(a, b):
 
16
  res = torch.zeros_like(a)
17
  res[..., 0] = a[...,0]*b[...,0] + a[...,1]*b[...,1] + a[...,2]*b[...,2] + a[...,3]*b[...,3] - a[...,4]*b[...,4] - a[...,5]*b[...,5] - a[...,6]*b[...,6] - a[...,7]*b[...,7]
18
  res[..., 1] = a[...,0]*b[...,1] + a[...,1]*b[...,0] - a[...,2]*b[...,4] + a[...,3]*b[...,6] + a[...,4]*b[...,2] - a[...,5]*b[...,7] - a[...,6]*b[...,3] - a[...,7]*b[...,5]
 
25
  return res
26
 
27
  def smart_clifford_product(a, b):
 
28
  if HAS_TRITON:
 
29
  pass
30
  return pytorch_clifford_product(a, b)
31
 
 
32
  class CliffordDiracLayer(MessagePassing):
33
  def __init__(self, channels: int):
34
  super().__init__(aggr="add", node_dim=0)
requirements.txt CHANGED
@@ -1,8 +1,5 @@
1
  torch
2
  torch-geometric
3
- torch-scatter
4
- torch-sparse
5
- torch-cluster
6
  gradio
7
  plotly
8
  numpy
 
1
  torch
2
  torch-geometric
 
 
 
3
  gradio
4
  plotly
5
  numpy