| import torch |
| import numpy as np |
| import gradio as gr |
| import matplotlib.pylab as plt |
| import torch.nn.functional as F |
|
|
| from vae import HVAE |
| from datasets import morphomnist, ukbb, mimic, get_attr_max_min |
| from pgm.flow_pgm import MorphoMNISTPGM, FlowPGM, ChestPGM |
| from app_utils import ( |
| mnist_graph, |
| brain_graph, |
| chest_graph, |
| vae_preprocess, |
| normalize, |
| preprocess_brain, |
| get_fig_arr, |
| postprocess, |
| MidpointNormalize, |
| ) |
|
|
# Per-dataset registries keyed by dataset/tab name.
# DATA[name]["test"] receives the test DataLoader (see get_dataloader);
# MODELS[name] receives "pgm", "pgm_args", "vae" and "vae_args".
DATA = {}
MODELS = {}
for k in ("Morpho-MNIST", "Brain MRI", "Chest X-ray"):
    DATA[k] = {}
    MODELS[k] = {}
|
|
| |
# UI label sets. List positions matter: these lists are indexed with the
# integer / argmax encodings read from dataset items and model outputs.
DIGITS = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]

MRISEQ_CAT = ["T1", "T2-FLAIR"]
SEX_CAT = ["female", "male"]

HEIGHT, WIDTH = 500, 500

SEX_CAT_CHEST = ["female", "male"]
RACE_CAT = ["white", "black", "asian"]

FIND_CAT = ["no disease", "effusion", "pneumonia"]

# Prefer the second GPU ("cuda:1") when several are visible, but fall back to
# the first one on single-GPU hosts: "cuda:1" there fails with an
# "invalid device ordinal" error as soon as anything is moved to it.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    DEVICE = torch.device("cpu")
|
|
|
|
class Hparams:
    """Mutable attribute bag holding checkpoint hyper-parameters.

    Instances start empty. ``update`` copies the entries of a mapping (e.g. a
    checkpoint's ``"hparams"`` dict) onto the instance as attributes. Callers
    may then set further attributes, such as ``device``, directly.
    """

    def update(self, dict):
        """Set one attribute per key/value pair of ``dict``.

        The parameter name shadows the ``dict`` builtin. It is kept unchanged
        so that existing keyword callers keep working.
        """
        for key, value in dict.items():
            setattr(self, key, value)

    def __repr__(self):
        # List the stored hyper-parameters, so that printing an instance
        # (load_vae prints its hparams) shows its contents rather than an
        # object address.
        fields = ", ".join(f"{key}={value!r}" for key, value in vars(self).items())
        return f"{type(self).__name__}({fields})"
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def get_paths(dataset_id):
    """Return the data directory and checkpoint path(s) for a dataset.

    Args:
        dataset_id: Dataset/tab name. It is matched by substring ("MNIST",
            "Brain" or "Chest"), checked in that order.

    Returns:
        ``(data_path, vae_path, pgm_path)``. For the chest X-ray dataset,
        ``vae_path`` is a two-element list: first the VAE checkpoint (which
        load_vae reads for its hparams), then the DSCM checkpoint (which
        load_vae reads for the ``vae.*`` weights).

    Raises:
        ValueError: if ``dataset_id`` matches none of the known datasets.
    """
    if "MNIST" in dataset_id:
        data_path = "./data/morphomnist"
        pgm_path = "./checkpoints/t_i_d/sup_pgm/checkpoint.pt"
        vae_path = "./checkpoints/t_i_d/dgauss_cond_big_beta1_dropexo/checkpoint.pt"
    elif "Brain" in dataset_id:
        data_path = "./data/ukbb_subset"
        pgm_path = "./checkpoints/m_b_v_s/sup_pgm/checkpoint.pt"
        vae_path = "./checkpoints/m_b_v_s/ukbb192_beta5_dgauss_b33/checkpoint.pt"
    elif "Chest" in dataset_id:
        data_path = "./data/mimic_subset"
        pgm_path = "./checkpoints/a_r_s_f/sup_pgm_mimic/60k_checkpoint.pt"
        vae_path = [
            "./checkpoints/a_r_s_f/mimic_beta9_gelu_dgauss_1_lr3/60k_checkpoint.pt",
            "./checkpoints/a_r_s_f/mimic_dscm_lr_1e5_lagrange_lr_1_damping_10/60k_checkpoint.pt",
        ]
    else:
        # Previously this fell through to an UnboundLocalError on return.
        raise ValueError(f"Unknown dataset_id: {dataset_id!r}")
    return data_path, vae_path, pgm_path
|
|
|
|
|
|
def load_pgm(dataset_id, pgm_path):
    """Build the dataset's PGM from a checkpoint and register it in MODELS.

    Stores the model under MODELS[dataset_id]["pgm"] and its hparams (with
    ``device`` set to DEVICE) under MODELS[dataset_id]["pgm_args"].

    Args:
        dataset_id: Dataset/tab name, matched by substring like get_paths.
        pgm_path: Path to a checkpoint with "hparams" and
            "ema_model_state_dict" entries.

    Raises:
        ValueError: if ``dataset_id`` matches none of the known datasets
            (checked before any checkpoint is read).
    """
    if "MNIST" in dataset_id:
        pgm_cls = MorphoMNISTPGM
    elif "Brain" in dataset_id:
        pgm_cls = FlowPGM
    elif "Chest" in dataset_id:
        pgm_cls = ChestPGM
    else:
        raise ValueError(f"Unknown dataset_id: {dataset_id!r}")

    checkpoint = torch.load(pgm_path, map_location=DEVICE)
    args = Hparams()
    args.update(checkpoint["hparams"])
    args.device = DEVICE
    pgm = pgm_cls(args).to(args.device)
    # Weights come from the "ema_model_state_dict" entry (presumably the
    # EMA-averaged parameters).
    pgm.load_state_dict(checkpoint["ema_model_state_dict"])
    MODELS[dataset_id]["pgm"] = pgm
    MODELS[dataset_id]["pgm_args"] = args
|
|
|
|
def load_vae(dataset_id, vae_path):
    """Build the dataset's HVAE from checkpoint(s) and register it in MODELS.

    Args:
        dataset_id: Dataset/tab name.
        vae_path: Checkpoint path. For chest X-ray this is the
            ``[vae_checkpoint, dscm_checkpoint]`` pair returned by get_paths:
            the first supplies the hparams, the second the weights (its
            ``vae.``-prefixed entries).

    Stores the model under MODELS[dataset_id]["vae"] and its hparams under
    MODELS[dataset_id]["vae_args"], then prints the hparams.
    """
    dscm_path = None
    if "Chest" in dataset_id:
        vae_path, dscm_path = vae_path[0], vae_path[1]
    checkpoint = torch.load(vae_path, map_location=DEVICE)
    args = Hparams()
    args.update(checkpoint["hparams"])

    # Backfill options that (presumably older) checkpoints may not carry.
    if not hasattr(args, "vae"):
        args.vae = "hierarchical"
    if not hasattr(args, "cond_prior"):
        args.cond_prior = False
    if hasattr(args, "free_bits"):
        # Copy the value over to the attribute name `kl_free_bits`.
        args.kl_free_bits = args.free_bits
    args.device = DEVICE
    # NOTE(review): neither this VAE nor the PGM is switched to .eval(); confirm
    # they contain no dropout/batch-norm layers whose behaviour would differ.
    vae = HVAE(args).to(args.device)

    if dscm_path is not None:
        dscm_ckpt = torch.load(dscm_path, map_location=DEVICE)
        prefix = "vae."
        # Keep only the VAE sub-module's weights and strip their prefix. Match
        # the prefix at the start of the key, so a key that merely contains
        # "vae." elsewhere is not selected and then mis-sliced.
        vae_state = {
            key[len(prefix):]: value
            for key, value in dscm_ckpt["ema_model_state_dict"].items()
            if key.startswith(prefix)
        }
        vae.load_state_dict(vae_state)
    else:
        vae.load_state_dict(checkpoint["ema_model_state_dict"])
    MODELS[dataset_id]["vae"] = vae
    MODELS[dataset_id]["vae_args"] = args
    print(MODELS[dataset_id]["vae_args"])
|
|
|
|
def get_dataloader(dataset_id, data_path):
    """Build the test DataLoader for ``dataset_id`` and store it in DATA.

    Expects MODELS[dataset_id]["pgm_args"] to be populated (load_dataset does
    this). That hparams object gets ``data_dir = data_path`` and is passed to
    the dataset builder; its ``bs`` attribute sets the batch size.

    Raises:
        ValueError: if ``dataset_id`` matches none of the known datasets.
    """
    if "MNIST" in dataset_id:
        build_datasets = morphomnist
    elif "Brain" in dataset_id:
        build_datasets = ukbb
    elif "Chest" in dataset_id:
        build_datasets = mimic
    else:
        # Previously this fell through to an UnboundLocalError.
        raise ValueError(f"Unknown dataset_id: {dataset_id!r}")

    args = MODELS[dataset_id]["pgm_args"]
    args.data_dir = data_path
    datasets = build_datasets(args)
    DATA[dataset_id]["test"] = torch.utils.data.DataLoader(
        datasets["test"], shuffle=False, batch_size=args.bs, num_workers=4
    )
|
|
|
|
def load_dataset(dataset_id):
    """Prepare the test DataLoader for ``dataset_id``.

    Only the hparams are read from the PGM checkpoint; no model is built.
    They are stored, with ``device`` set to DEVICE, as
    MODELS[dataset_id]["pgm_args"], and get_dataloader then builds the loader
    from them.
    """
    data_path, _unused_vae_path, pgm_path = get_paths(dataset_id)
    hparams = Hparams()
    hparams.update(torch.load(pgm_path, map_location=DEVICE)["hparams"])
    hparams.device = DEVICE
    MODELS[dataset_id]["pgm_args"] = hparams
    get_dataloader(dataset_id, data_path)
|
|
|
|
def load_model(dataset_id):
    """Load the PGM and VAE for ``dataset_id`` into MODELS, at most once.

    The UI registers this with ``demo.load``, so it runs on every page visit.
    When both models are already registered, loading is skipped. This avoids
    re-reading every checkpoint per visitor and replacing the model objects
    that other sessions are currently using.
    """
    registry = MODELS[dataset_id]
    if "pgm" in registry and "vae" in registry:
        return
    _, vae_path, pgm_path = get_paths(dataset_id)
    load_pgm(dataset_id, pgm_path)
    load_vae(dataset_id, vae_path)
|
|
|
|
@torch.no_grad()
def counterfactual_inference(dataset_id, obs, do_pa):
    """Compute a counterfactual image for ``obs`` under intervention ``do_pa``.

    Steps (all without gradient tracking):
      1. The PGM infers counterfactual parents from the observed parents and
         ``do_pa``.
      2. The VAE abducts latents from the observed image and parents.
      3. The latents are decoded under the observed parents (reconstruction)
         and under the counterfactual parents.
      4. The elementwise standardized residual of the observation w.r.t. the
         reconstruction is re-applied, temperature-scaled, to the
         counterfactual decoder output. The result is clamped to [-1, 1].

    Args:
        dataset_id: Key into MODELS (needs "pgm", "vae" and "vae_args").
        obs: Dict holding the image under "x" plus one tensor per parent.
        do_pa: Dict of intervened parent tensors (may be empty).

    Returns:
        Dict with "cf_x" (clamped counterfactual), "rec_x" (decoder location
        under the observed parents) and "cf_pa" (counterfactual parents).
    """
    models = MODELS[dataset_id]
    observed_pa = {name: value.clone() for name, value in obs.items() if name != "x"}
    cf_pa = models["pgm"].counterfactual(
        obs=observed_pa, intervention=do_pa, num_particles=1
    )

    vae_args = models["vae_args"]
    vae = models["vae"]
    vae_pa = vae_preprocess(
        vae_args, {name: value.clone() for name, value in observed_pa.items()}
    )
    vae_cf_pa = vae_preprocess(
        vae_args, {name: value.clone() for name, value in cf_pa.items()}
    )

    # A single temperature is used for both latent abduction and the
    # re-applied noise: 0.1 when the hparams' `hps` mentions "mnist", else 1.0.
    temperature = 0.1 if "mnist" in vae_args.hps else 1.0

    latents = vae.abduct(x=obs["x"], parents=vae_pa, t=temperature)
    if vae.cond_prior:
        # With a conditional prior, abduct yields per-layer dicts; keep "z".
        latents = [layer["z"] for layer in latents]

    rec_loc, rec_scale = vae.forward_latents(latents=latents, parents=vae_pa)
    cf_loc, cf_scale = vae.forward_latents(latents=latents, parents=vae_cf_pa)
    noise = (obs["x"] - rec_loc) / rec_scale.clamp(min=1e-12)
    cf_x = torch.clamp(cf_loc + cf_scale * temperature * noise, min=-1, max=1)
    return {"cf_x": cf_x, "rec_x": rec_loc, "cf_pa": cf_pa}
|
|
|
|
def get_obs_item(dataset_id, idx=None):
    """Fetch one item from the test set of ``dataset_id``.

    Args:
        dataset_id: Key into DATA; DATA[dataset_id]["test"] must be a DataLoader.
        idx: Item index (anything ``int()`` accepts, e.g. a float from a UI
            number field). If None, an index is drawn uniformly at random.

    Returns:
        ``(idx, item)``, with ``idx`` as a Python int.

    Raises:
        ValueError: if a random item is requested from an empty test set.
    """
    dataset = DATA[dataset_id]["test"].dataset
    if idx is None:
        n_test = len(dataset)
        if n_test == 0:
            raise ValueError(f"Test set for {dataset_id!r} is empty")
        # Draw one index directly instead of materialising a full permutation.
        idx = torch.randint(n_test, (1,)).item()
    idx = int(idx)
    return idx, dataset[idx]
|
|
|
|
def get_mnist_obs(idx=None):
    """Return a Morpho-MNIST test observation formatted for the UI.

    Loads the test set on first use. With ``idx=None`` a random item is drawn.

    Returns:
        ``(idx, image, thickness, intensity, digit)``. The image is a
        get_fig_arr array. Thickness and intensity are mapped back from what
        appears to be a [-1, 1] encoding to their original ranges and rounded
        to 2 decimals. The digit is a DIGITS label.
    """
    dataset_id = "Morpho-MNIST"
    if not DATA[dataset_id]:
        load_dataset(dataset_id)
    idx, sample = get_obs_item(dataset_id, idx)

    def rescale(value, lo, hi):
        # Inverse of the min-max mapping [lo, hi] -> [-1, 1].
        return (value + 1) / 2 * (hi - lo) + lo

    image = get_fig_arr(sample["x"].clone().squeeze().numpy())
    thickness = rescale(sample["thickness"].clone(), 0.87598526, 6.255515)
    intensity = rescale(sample["intensity"].clone(), 66.601204, 254.90317)
    digit = DIGITS[sample["digit"].clone().argmax(-1)]
    return (
        idx,
        image,
        float(np.round(thickness, 2)),
        float(np.round(intensity, 2)),
        digit,
    )
|
|
|
|
def get_brain_obs(idx=None):
    """Return a brain MRI test observation formatted for the UI.

    Loads the test set on first use. With ``idx=None`` a random item is drawn.

    Returns:
        ``(idx, image, mri_seq, sex, age, brain_volume, ventricle_volume)``.
        mri_seq and sex are labels from MRISEQ_CAT / SEX_CAT. Both volumes
        are divided by 1000 and rounded to 2 decimals for display.
    """
    dataset_id = "Brain MRI"
    if not DATA[dataset_id]:
        load_dataset(dataset_id)
    idx, sample = get_obs_item(dataset_id, idx)

    def scalar(key):
        # Python number held by a single-element tensor.
        return sample[key].clone().item()

    image = get_fig_arr(sample["x"].clone().squeeze().numpy())
    mri_seq = MRISEQ_CAT[int(scalar("mri_seq"))]
    sex = SEX_CAT[int(scalar("sex"))]
    age = scalar("age")
    brain_vol = scalar("brain_volume") / 1000
    vent_vol = scalar("ventricle_volume") / 1000
    return (
        idx,
        image,
        mri_seq,
        sex,
        age,
        float(np.round(brain_vol, 2)),
        float(np.round(vent_vol, 2)),
    )
|
|
|
|
def get_chest_obs(idx=None):
    """Return a chest X-ray test observation formatted for the UI.

    Loads the test set on first use. With ``idx=None`` a random item is drawn.

    Returns:
        ``(idx, image, race, sex, finding, age)``. race/sex/finding are labels
        from RACE_CAT / SEX_CAT_CHEST / FIND_CAT. Age is decoded as
        ``(encoded + 1) * 50`` and rounded to 1 decimal.
    """
    dataset_id = "Chest X-ray"
    if not DATA[dataset_id]:
        load_dataset(dataset_id)
    idx, sample = get_obs_item(dataset_id, idx)

    def as_array(key):
        # Squeezed NumPy copy of one field of the sample.
        return sample[key].clone().squeeze().numpy()

    image = get_fig_arr(postprocess(sample["x"].clone()))
    sex = SEX_CAT_CHEST[int(as_array("sex"))]
    finding = FIND_CAT[as_array("finding").argmax(-1)]
    race = RACE_CAT[as_array("race").argmax(-1)]
    age = (as_array("age") + 1) * 50
    return (idx, image, race, sex, finding, float(np.round(age, 1)))
|
|
|
|
def infer_mnist_cf(*args):
    """Answer a Morpho-MNIST counterfactual query coming from the UI.

    Positional args (UI order): ``idx, image, thickness, intensity, digit,
    do_thickness, do_intensity, do_digit``. The image argument is ignored;
    the sample is re-read from the test set by ``idx``. Each ``do_*`` flag
    selects whether the matching value is applied as an intervention.

    Returns:
        ``(cf_image, cf_uncertainty, effect, cf_thickness, cf_intensity,
        cf_digit)``. Over 32 particles, the counterfactual image is the mean
        and the uncertainty map is the standard deviation. The effect is the
        counterfactual minus the reconstruction, shown on a fixed [-255, 255]
        diverging scale. Thickness/intensity are rounded to 2 decimals.
    """
    dataset_id = "Morpho-MNIST"
    idx, _, t, i, y, do_t, do_i, do_y = args
    n_particles = 32

    obs = DATA[dataset_id]["test"].dataset.__getitem__(int(idx))
    # Rescale pixels, presumably from [0, 255] to [-1, 1].
    obs["x"] = (obs["x"] - 127.5) / 127.5
    device = MODELS[dataset_id]["vae_args"].device
    # Give every field a batch dimension, move it to the model device and
    # replicate it once per particle (updated in place in `obs`).
    for key in list(obs):
        value = obs[key]
        value = value.unsqueeze(0) if len(value.shape) >= 1 else value.view(1, 1)
        value = value.to(device).float()
        if n_particles > 1:
            reps = (1, 1, 1) if key == "x" else (1,)
            value = value.repeat(n_particles, *reps)
        obs[key] = value

    do_pa = {}
    if do_t:
        do_pa["thickness"] = torch.tensor(
            normalize(t, x_max=6.255515, x_min=0.87598526)
        ).view(1, 1)
    if do_i:
        do_pa["intensity"] = torch.tensor(
            normalize(i, x_max=254.90317, x_min=66.601204)
        ).view(1, 1)
    if do_y:
        do_pa["digit"] = F.one_hot(
            torch.tensor(DIGITS.index(y)), num_classes=10
        ).view(1, 10)
    do_pa = {
        key: value.to(device).float().repeat(n_particles, 1)
        for key, value in do_pa.items()
    }

    out = counterfactual_inference(dataset_id, obs, do_pa)
    cf_pa = out["cf_pa"]

    cf_x = postprocess(out["cf_x"].mean(0))
    cf_x_std = out["cf_x"].std(0).squeeze().detach().cpu().numpy()
    rec_x = postprocess(out["rec_x"].mean(0))

    def to_display(value, lo, hi):
        # [-1, 1] model encoding -> [lo, hi], rounded to 2 decimals.
        return np.round((value.item() + 1) / 2 * (hi - lo) + lo, 2)

    cf_t = to_display(cf_pa["thickness"].mean(0), 0.87598526, 6.255515)
    cf_i = to_display(cf_pa["intensity"].mean(0), 66.601204, 254.90317)
    cf_y = DIGITS[cf_pa["digit"].mean(0).argmax(-1)]

    # Direct causal effect: counterfactual minus reconstruction.
    effect_fig = get_fig_arr(
        cf_x - rec_x,
        cmap="RdBu_r",
        norm=MidpointNormalize(vmin=-255, midpoint=0, vmax=255),
    )
    cf_fig = get_fig_arr(cf_x)
    std_fig = get_fig_arr(cf_x_std, cmap="jet")
    return (cf_fig, std_fig, effect_fig, cf_t, cf_i, cf_y)
|
|
|
|
def infer_brain_cf(*args):
    """Answer a brain MRI counterfactual query coming from the UI.

    Positional args (UI order): ``idx, image, mri_seq, sex, age,
    brain_volume, ventricle_volume`` followed by the flags ``do_mri_seq,
    do_sex, do_age, do_brain_volume, do_ventricle_volume``. The image argument
    is ignored. Volumes arrive at the UI's /1000 scale and are multiplied back
    by 1000 before normalisation.

    Returns:
        ``(cf_image, cf_uncertainty, effect, cf_mri_seq, cf_sex, cf_age,
        cf_brain_volume, cf_ventricle_volume)``. Images are summarised over 16
        particles (mean / std). The effect is counterfactual minus
        reconstruction, on a diverging scale centred at 0. Age is rounded to
        1 decimal; volumes are divided by 1000 and rounded to 2.
    """
    dataset_id = "Brain MRI"
    idx, _, m, s, a, b, v = args[:7]
    do_m, do_s, do_a, do_b, do_v = args[7:]
    n_particles = 16

    obs = DATA[dataset_id]["test"].dataset.__getitem__(int(idx))
    obs = preprocess_brain(MODELS[dataset_id]["vae_args"], obs)
    if n_particles > 1:
        # Replicate every field once per particle (updated in place).
        for key in list(obs):
            reps = (1, 1, 1) if key == "x" else (1,)
            obs[key] = obs[key].repeat(n_particles, *reps)

    do_pa = {}
    if do_m:
        do_pa["mri_seq"] = torch.tensor(MRISEQ_CAT.index(m)).view(1, 1)
    if do_s:
        do_pa["sex"] = torch.tensor(SEX_CAT.index(s)).view(1, 1)
    if do_a:
        do_pa["age"] = torch.tensor(a).view(1, 1)
    if do_b:
        do_pa["brain_volume"] = torch.tensor(b * 1000).view(1, 1)
    if do_v:
        do_pa["ventricle_volume"] = torch.tensor(v * 1000).view(1, 1)

    continuous = ("age", "brain_volume", "ventricle_volume")
    for key in continuous:
        if key in do_pa:
            # Min-max scale to [0, 1], then map to [-1, 1].
            k_max, k_min = get_attr_max_min(key)
            do_pa[key] = 2 * ((do_pa[key] - k_min) / (k_max - k_min)) - 1

    device = MODELS[dataset_id]["vae_args"].device
    do_pa = {
        key: value.to(device).float().repeat(n_particles, 1)
        for key, value in do_pa.items()
    }

    out = counterfactual_inference(dataset_id, obs, do_pa)
    cf_pa = out["cf_pa"]

    cf_x = postprocess(out["cf_x"].mean(0))
    cf_x_std = out["cf_x"].std(0).squeeze().detach().cpu().numpy()
    rec_x = postprocess(out["rec_x"].mean(0))
    cf_m = MRISEQ_CAT[int(cf_pa["mri_seq"].mean(0).item())]
    cf_s = SEX_CAT[int(cf_pa["sex"].mean(0).item())]
    cf_values = {}
    for key in continuous:
        # Undo the [-1, 1] normalisation applied above.
        k_max, k_min = get_attr_max_min(key)
        cf_values[key] = (cf_pa[key].mean(0).item() + 1) / 2 * (k_max - k_min) + k_min

    # Direct causal effect: counterfactual minus reconstruction.
    effect = cf_x - rec_x
    effect_fig = get_fig_arr(
        effect,
        cmap="RdBu_r",
        norm=MidpointNormalize(vmin=effect.min(), midpoint=0, vmax=effect.max()),
    )
    cf_fig = get_fig_arr(cf_x)
    std_fig = get_fig_arr(cf_x_std, cmap="jet")
    return (
        cf_fig,
        std_fig,
        effect_fig,
        cf_m,
        cf_s,
        np.round(cf_values["age"], 1),
        np.round(cf_values["brain_volume"] / 1000, 2),
        np.round(cf_values["ventricle_volume"] / 1000, 2),
    )
|
|
|
|
def infer_chest_cf(*args):
    """Answer a chest X-ray counterfactual query coming from the UI.

    Positional args (UI order): ``idx, image, race, sex, finding, age``
    followed by the flags ``do_race, do_sex, do_finding, do_age``. The image
    argument is ignored; the sample is re-read from the test set by ``idx``.
    An age intervention is encoded as ``age / 100 * 2 - 1``.

    Returns:
        ``(cf_image, cf_uncertainty, effect, cf_race, cf_sex, cf_finding,
        cf_age)``. Images are summarised over 16 particles (mean / std), and
        cf_age is rounded to 1 decimal.
    """
    dataset_id = "Chest X-ray"
    idx, _, r, s, f, a = args[:6]
    do_r, do_s, do_f, do_a = args[6:]
    n_particles = 16

    obs = DATA[dataset_id]["test"].dataset.__getitem__(int(idx))
    # Keep the observation as returned by the dataset (before device transfer
    # and particle replication) for the effect map below.
    observation = obs["x"]
    device = MODELS[dataset_id]["vae_args"].device
    for k, v in obs.items():
        obs[k] = v.to(device).float()
        if n_particles > 1:
            ndims = (1,) * 3 if k == "x" else (1,)
            obs[k] = obs[k].repeat(n_particles, *ndims)

    # Plain tensor construction: no torch.no_grad() needed here, and
    # counterfactual_inference itself runs without gradient tracking.
    do_pa = {}
    if do_s:
        do_pa["sex"] = torch.tensor(SEX_CAT_CHEST.index(s)).view(1, 1)
    if do_f:
        do_pa["finding"] = F.one_hot(
            torch.tensor(FIND_CAT.index(f)), num_classes=3
        ).view(1, 3)
    if do_r:
        do_pa["race"] = F.one_hot(
            torch.tensor(RACE_CAT.index(r)), num_classes=3
        ).view(1, 3)
    if do_a:
        do_pa["age"] = torch.tensor(a / 100 * 2 - 1).view(1, 1)
    for k, v in do_pa.items():
        do_pa[k] = v.to(device).float().repeat(n_particles, 1)

    out = counterfactual_inference(dataset_id, obs, do_pa)

    cf_x = out["cf_x"].mean(0)
    cf_x_std = out["cf_x"].std(0)
    cf_r = out["cf_pa"]["race"].mean(0)
    cf_s = out["cf_pa"]["sex"].mean(0)
    cf_f = out["cf_pa"]["finding"].mean(0)
    cf_a = out["cf_pa"]["age"].mean(0)

    cf_x = postprocess(cf_x)
    cf_x_std = cf_x_std.squeeze().detach().cpu().numpy()
    cf_r = RACE_CAT[cf_r.argmax(-1)]
    cf_s = SEX_CAT_CHEST[int(cf_s.item())]
    cf_f = FIND_CAT[cf_f.argmax(-1)]
    cf_a = (cf_a.item() + 1) * 50

    # NOTE: unlike the MNIST/brain handlers, which subtract the reconstruction
    # (out["rec_x"]), this effect map is taken relative to the observation
    # itself. The colour range is left to MidpointNormalize / get_fig_arr.
    effect = cf_x - postprocess(observation)
    effect = get_fig_arr(
        effect,
        cmap="RdBu_r",
        norm=MidpointNormalize(midpoint=0),
    )
    cf_x = get_fig_arr(cf_x)
    cf_x_std = get_fig_arr(cf_x_std, cmap="jet")
    return (cf_x, cf_x_std, effect, cf_r, cf_s, cf_f, np.round(cf_a, 1))
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI. Only the "Chest X-ray" tab is built in this script; the
# MNIST/brain handlers defined above are not wired to any component here.
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Default()) as demo:
    with gr.Tabs():
        with gr.TabItem("Chest X-ray") as chest_tab:
            # Hidden textbox holding the tab label ("Chest X-ray"); it is
            # passed to load_model as its dataset_id.
            chest_id = gr.Textbox(value=chest_tab.label, visible=False)

            # Top row: observation, counterfactual, uncertainty and effect.
            with gr.Row():
                # Hidden test-set index of the current observation: written by
                # get_chest_obs, read back by Reset and Submit.
                idx_chest = gr.Number(value=0, visible=False)
                with gr.Column(scale=1, min_width=200):
                    x_chest = gr.Image(label="Observation", interactive=False, height=HEIGHT)
                with gr.Column(scale=1, min_width=200):
                    cf_x_chest = gr.Image(
                        label="Counterfactual", interactive=False, height=HEIGHT)
                with gr.Column(scale=1, min_width=200):
                    cf_x_std_chest = gr.Image(
                        label="Counterfactual Uncertainty", interactive=False, height=HEIGHT)
                with gr.Column(scale=1, min_width=200):
                    effect_chest = gr.Image(
                        label="Direct Causal Effect", interactive=False, height=HEIGHT)

            # Bottom row: intervention controls (left) and causal graph (right).
            with gr.Row():
                # NOTE(review): some Gradio versions expect an integer `scale`;
                # confirm 2.55 is accepted by the pinned version.
                with gr.Column(scale=2.55):
                    # Each value widget starts non-interactive; ticking its
                    # do(...) checkbox enables it (see the .change handlers).
                    with gr.Row(equal_height=True):
                        with gr.Column(min_width=200):
                            do_f_chest = gr.Checkbox(label="do(disease)", value=False)
                            f_chest = gr.Radio(FIND_CAT, label="", interactive=False)
                        with gr.Column(min_width=200):
                            do_s_chest = gr.Checkbox(label="do(sex)", value=False)
                            s_chest = gr.Radio(
                                SEX_CAT_CHEST, label="", interactive=False
                            )

                    with gr.Row():
                        with gr.Column(min_width=200):
                            do_r_chest = gr.Checkbox(label="do(race)", value=False)
                            r_chest = gr.Radio(RACE_CAT, label="", interactive=False)
                        with gr.Column(min_width=200):
                            do_a_chest = gr.Checkbox(label="do(age)", value=False)
                            # "\u00A0" (non-breaking space) used as the label.
                            a_chest = gr.Slider(
                                label="\u00A0", minimum=18, maximum=98, step=1
                            )

                    with gr.Row():
                        new_chest = gr.Button("New Observation")
                        reset_chest = gr.Button("Reset", variant="stop")
                        submit_chest = gr.Button("Submit", variant="primary")
                with gr.Column(scale=1):
                    causal_graph_chest = gr.Image(
                        label="Causal Graph", interactive=False, height=345)

            # Component groups. Their order is significant:
            #  - do_chest matches infer_chest_cf's (do_r, do_s, do_f, do_a) and
            #    is zipped with the value widgets further below;
            #  - obs_chest matches get_chest_obs's return tuple
            #    (idx, x, race, sex, finding, age) and infer_chest_cf's first
            #    six inputs;
            #  - cf_out_chest are the three image outputs of infer_chest_cf.
            do_chest = [do_r_chest, do_s_chest, do_f_chest, do_a_chest]
            obs_chest = [idx_chest, x_chest, r_chest, s_chest, f_chest, a_chest]
            cf_out_chest = [cf_x_chest, cf_x_std_chest, effect_chest]

            # On app load: show a random observation (idx=None), load the
            # models for this tab and render the causal graph image.
            demo.load(fn=get_chest_obs, inputs=None, outputs=obs_chest)
            demo.load(fn=load_model, inputs=chest_id)
            demo.load(fn=chest_graph, inputs=do_chest, outputs=causal_graph_chest)

            # "New Observation": draw a random observation, re-render the graph
            # and clear the three counterfactual image panels.
            new_chest.click(fn=get_chest_obs, inputs=None, outputs=obs_chest)
            new_chest.click(fn=chest_graph, inputs=do_chest, outputs=causal_graph_chest)
            new_chest.click(fn=lambda:(gr.update(value=None),) * 3, inputs=None, outputs=cf_out_chest)

            # "Reset": reload the observation at the current index, untick all
            # do(...) checkboxes, close all matplotlib figures and clear the
            # counterfactual image panels.
            reset_chest.click(fn=get_chest_obs, inputs=idx_chest, outputs=obs_chest)
            reset_chest.click(
                fn=lambda: (gr.update(value=False),) * len(do_chest),
                inputs=None,
                outputs=do_chest,
            )
            reset_chest.click(fn=lambda: plt.close("all"), inputs=None, outputs=None)
            reset_chest.click(fn=lambda: (gr.update(value=None),) * 3, inputs=None, outputs=cf_out_chest)

            # Toggling do(X) makes X's value widget (non-)interactive and
            # re-renders the causal graph from the current do(...) flags.
            for _k, _v in zip(do_chest, [r_chest, s_chest, f_chest, a_chest]):
                _k.change(fn=lambda x: gr.update(interactive=x), inputs=_k, outputs=_v)
                _k.change(chest_graph, inputs=do_chest, outputs=causal_graph_chest)

            # "Submit": run the counterfactual query. The outputs are the three
            # images plus the counterfactual race/sex/finding/age, which are
            # written back into the value widgets.
            submit_chest.click(
                fn=infer_chest_cf,
                inputs=obs_chest + do_chest,
                outputs=cf_out_chest + [r_chest, s_chest, f_chest, a_chest],
            )


if __name__ == "__main__":
    # Enable Gradio's request queue, then serve without a public share link.
    demo.queue()
    demo.launch(share = False)