Upload layers.py
Browse files- pgm/layers.py +2 -2
pgm/layers.py
CHANGED
|
@@ -173,8 +173,8 @@ class ArgMaxGumbelMax(Transform):
|
|
| 173 |
|
| 174 |
def sample_gumbel(shape, eps=1e-20):
|
| 175 |
U = torch.rand(shape)
|
| 176 |
-
|
| 177 |
-
U = U.to(torch.device('cuda:1'))
|
| 178 |
return -torch.log(-torch.log(U + eps) + eps)
|
| 179 |
def gumbel_softmax_sample(logits, temperature):
|
| 180 |
y = logits + sample_gumbel(logits.shape)
|
|
|
|
| 173 |
|
| 174 |
def sample_gumbel(shape, eps=1e-20):
|
| 175 |
U = torch.rand(shape)
|
| 176 |
+
U = U.to(torch.device('cpu'))
|
| 177 |
+
# U = U.to(torch.device('cuda:1'))
|
| 178 |
return -torch.log(-torch.log(U + eps) + eps)
|
| 179 |
def gumbel_softmax_sample(logits, temperature):
|
| 180 |
y = logits + sample_gumbel(logits.shape)
|