Upload 2 files
Browse files
app.py
CHANGED
|
@@ -221,7 +221,7 @@ def get_chest_obs(idx=None):
|
|
| 221 |
idx, obs = get_obs_item(dataset_id, idx)
|
| 222 |
x = get_fig_arr(postprocess(obs["x"].clone()))
|
| 223 |
s = SEX_CAT_CHEST[int(obs["sex"].clone().squeeze().numpy())]
|
| 224 |
-
f = FIND_CAT[
|
| 225 |
r = RACE_CAT[obs["race"].clone().squeeze().numpy().argmax(-1)]
|
| 226 |
a = (obs["age"].clone().squeeze().numpy() + 1) * 50
|
| 227 |
return (idx, x, r, s, f, float(np.round(a, 1)))
|
|
|
|
| 221 |
idx, obs = get_obs_item(dataset_id, idx)
|
| 222 |
x = get_fig_arr(postprocess(obs["x"].clone()))
|
| 223 |
s = SEX_CAT_CHEST[int(obs["sex"].clone().squeeze().numpy())]
|
| 224 |
+
f = FIND_CAT[obs["finding"].clone().squeeze().numpy().argmax(-1)]
|
| 225 |
r = RACE_CAT[obs["race"].clone().squeeze().numpy().argmax(-1)]
|
| 226 |
a = (obs["age"].clone().squeeze().numpy() + 1) * 50
|
| 227 |
return (idx, x, r, s, f, float(np.round(a, 1)))
|
vae.py
CHANGED
|
@@ -6,7 +6,7 @@ import torch.distributions as dist
|
|
| 6 |
import torch.nn.functional as F
|
| 7 |
from torch import Tensor, nn
|
| 8 |
|
| 9 |
-
from hps import Hparams
|
| 10 |
|
| 11 |
EPS = -9 # minimum logscale
|
| 12 |
|
|
@@ -85,7 +85,7 @@ class Block(nn.Module):
|
|
| 85 |
|
| 86 |
|
| 87 |
class Encoder(nn.Module):
|
| 88 |
-
def __init__(self, args
|
| 89 |
super().__init__()
|
| 90 |
# parse architecture
|
| 91 |
stages = []
|
|
@@ -135,7 +135,7 @@ class Encoder(nn.Module):
|
|
| 135 |
|
| 136 |
|
| 137 |
class DecoderBlock(nn.Module):
|
| 138 |
-
def __init__(self, args
|
| 139 |
super().__init__()
|
| 140 |
bottleneck = int(in_width / args.bottleneck)
|
| 141 |
self.res = resolution
|
|
@@ -207,7 +207,7 @@ class DecoderBlock(nn.Module):
|
|
| 207 |
|
| 208 |
|
| 209 |
class Decoder(nn.Module):
|
| 210 |
-
def __init__(self, args
|
| 211 |
super().__init__()
|
| 212 |
# parse architecture
|
| 213 |
stages = []
|
|
@@ -335,7 +335,7 @@ class Decoder(nn.Module):
|
|
| 335 |
|
| 336 |
|
| 337 |
class DGaussNet(nn.Module):
|
| 338 |
-
def __init__(self, args
|
| 339 |
super(DGaussNet, self).__init__()
|
| 340 |
self.x_loc = nn.Conv2d(
|
| 341 |
args.widths[0], args.input_channels, kernel_size=1, stride=1
|
|
@@ -438,7 +438,7 @@ class DGaussNet(nn.Module):
|
|
| 438 |
|
| 439 |
|
| 440 |
class HVAE(nn.Module):
|
| 441 |
-
def __init__(self, args
|
| 442 |
super().__init__()
|
| 443 |
args.vr = "light" if "ukbb" in args.hps else None # hacky
|
| 444 |
self.encoder = Encoder(args)
|
|
|
|
| 6 |
import torch.nn.functional as F
|
| 7 |
from torch import Tensor, nn
|
| 8 |
|
| 9 |
+
# from hps import Hparams
|
| 10 |
|
| 11 |
EPS = -9 # minimum logscale
|
| 12 |
|
|
|
|
| 85 |
|
| 86 |
|
| 87 |
class Encoder(nn.Module):
|
| 88 |
+
def __init__(self, args):
|
| 89 |
super().__init__()
|
| 90 |
# parse architecture
|
| 91 |
stages = []
|
|
|
|
| 135 |
|
| 136 |
|
| 137 |
class DecoderBlock(nn.Module):
|
| 138 |
+
def __init__(self, args, in_width, out_width, resolution):
|
| 139 |
super().__init__()
|
| 140 |
bottleneck = int(in_width / args.bottleneck)
|
| 141 |
self.res = resolution
|
|
|
|
| 207 |
|
| 208 |
|
| 209 |
class Decoder(nn.Module):
|
| 210 |
+
def __init__(self, args):
|
| 211 |
super().__init__()
|
| 212 |
# parse architecture
|
| 213 |
stages = []
|
|
|
|
| 335 |
|
| 336 |
|
| 337 |
class DGaussNet(nn.Module):
|
| 338 |
+
def __init__(self, args):
|
| 339 |
super(DGaussNet, self).__init__()
|
| 340 |
self.x_loc = nn.Conv2d(
|
| 341 |
args.widths[0], args.input_channels, kernel_size=1, stride=1
|
|
|
|
| 438 |
|
| 439 |
|
| 440 |
class HVAE(nn.Module):
|
| 441 |
+
def __init__(self, args):
|
| 442 |
super().__init__()
|
| 443 |
args.vr = "light" if "ukbb" in args.hps else None # hacky
|
| 444 |
self.encoder = Encoder(args)
|