yusuf-tiryaki commited on
Commit
0fb3da6
·
1 Parent(s): f7fc68e

fix model

Browse files
Files changed (1) hide show
  1. model.py +6 -0
model.py CHANGED
@@ -12,7 +12,11 @@ try:
12
  except ImportError:
13
  HAS_TRITON = False
14
 
 
15
  def pytorch_clifford_product(a, b):
 
 
 
16
  res = torch.zeros_like(a)
17
  res[..., 0] = a[...,0]*b[...,0] + a[...,1]*b[...,1] + a[...,2]*b[...,2] + a[...,3]*b[...,3] - a[...,4]*b[...,4] - a[...,5]*b[...,5] - a[...,6]*b[...,6] - a[...,7]*b[...,7]
18
  res[..., 1] = a[...,0]*b[...,1] + a[...,1]*b[...,0] - a[...,2]*b[...,4] + a[...,3]*b[...,6] + a[...,4]*b[...,2] - a[...,5]*b[...,7] - a[...,6]*b[...,3] - a[...,7]*b[...,5]
@@ -24,11 +28,13 @@ def pytorch_clifford_product(a, b):
24
  res[..., 7] = a[...,0]*b[...,7] + a[...,1]*b[...,5] + a[...,2]*b[...,6] + a[...,3]*b[...,4] + a[...,4]*b[...,3] + a[...,5]*b[...,1] + a[...,6]*b[...,2] + a[...,7]*b[...,0]
25
  return res
26
 
 
27
  def smart_clifford_product(a, b):
28
  if HAS_TRITON:
29
  pass
30
  return pytorch_clifford_product(a, b)
31
 
 
32
  class CliffordDiracLayer(MessagePassing):
33
  def __init__(self, channels: int):
34
  super().__init__(aggr="add", node_dim=0)
 
12
  except ImportError:
13
  HAS_TRITON = False
14
 
15
+
16
  def pytorch_clifford_product(a, b):
17
+ # DÜZELTME BURADA: a ve b matrislerini çarpmadan önce ortak boyuta genişlet (Broadcast)
18
+ a, b = torch.broadcast_tensors(a, b)
19
+
20
  res = torch.zeros_like(a)
21
  res[..., 0] = a[...,0]*b[...,0] + a[...,1]*b[...,1] + a[...,2]*b[...,2] + a[...,3]*b[...,3] - a[...,4]*b[...,4] - a[...,5]*b[...,5] - a[...,6]*b[...,6] - a[...,7]*b[...,7]
22
  res[..., 1] = a[...,0]*b[...,1] + a[...,1]*b[...,0] - a[...,2]*b[...,4] + a[...,3]*b[...,6] + a[...,4]*b[...,2] - a[...,5]*b[...,7] - a[...,6]*b[...,3] - a[...,7]*b[...,5]
 
28
  res[..., 7] = a[...,0]*b[...,7] + a[...,1]*b[...,5] + a[...,2]*b[...,6] + a[...,3]*b[...,4] + a[...,4]*b[...,3] + a[...,5]*b[...,1] + a[...,6]*b[...,2] + a[...,7]*b[...,0]
29
  return res
30
 
31
+
32
  def smart_clifford_product(a, b):
33
  if HAS_TRITON:
34
  pass
35
  return pytorch_clifford_product(a, b)
36
 
37
+
38
  class CliffordDiracLayer(MessagePassing):
39
  def __init__(self, channels: int):
40
  super().__init__(aggr="add", node_dim=0)