| print("Importing standard...") |
| import subprocess |
| import shutil |
| from pathlib import Path |
|
|
| print("Importing external...") |
| import torch |
| import numpy as np |
| from PIL import Image |
|
|
# Dimensionality-reduction backend used by log_input_output below.
# Only the selected backend is imported, so the other libraries stay
# optional dependencies.
REDUCTION = "pca"
if REDUCTION == "umap":
    from umap import UMAP
elif REDUCTION == "tsne":
    from sklearn.manifold import TSNE
elif REDUCTION == "pca":
    from sklearn.decomposition import PCA
|
|
|
|
| def symlog(x): |
| return torch.sign(x) * torch.log(torch.abs(x) + 1) |
|
|
|
|
def preprocess_masks_features(masks, features):
    """Flatten the spatial dimensions of masks and features for broadcasting.

    Args:
        masks: (B, M, H, W) tensor — M masks per batch element.
        features: (B, F, Hf, Wf) feature maps. The reshape below uses the
            masks' H*W, so this assumes Hf*Wf == H*W — TODO confirm callers
            always pass matching spatial sizes.

    Returns:
        Tuple of (masks, features, M, B, H, W, F) where masks is
        (B, M, 1, H*W) and features is (B, 1, F, H*W), ready to be
        broadcast against each other over the pixel axis.
    """
    B, M, H, W = masks.shape
    # Only the channel count is needed from the feature shape.
    _, F, _, _ = features.shape
    masks = masks.reshape(B, M, 1, H * W)
    features = features.reshape(B, 1, F, H * W)
    # NOTE: a dead `mask_areas = masks.sum(dim=3)` computation was removed;
    # its result was never used or returned.
    return masks, features, M, B, H, W, F
|
|
|
|
def get_row_col(H, W, device):
    """Return 1-D row/col coordinate vectors evenly spaced over [0, 1].

    Args:
        H: number of rows; W: number of columns; device: torch device
            for the returned tensors.

    Returns:
        (row, col) tensors of lengths H and W respectively.
    """
    row_coords, col_coords = (
        torch.linspace(0, 1, n, device=device) for n in (H, W)
    )
    return row_coords, col_coords
|
|
|
|
def get_current_git_commit():
    """Return the current git HEAD commit hash as a string, or None.

    Returns None (after printing a message) when the `git` command fails
    (e.g. not inside a repository) or — previously uncaught — when the
    git binary is not installed at all (FileNotFoundError).
    """
    try:
        commit_hash = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip()
        return commit_hash.decode("utf-8")
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("An error occurred while trying to retrieve the git commit hash.")
        return None
|
|
|
|
def clean_dir(dirname):
    """Removes all directories in dirname that don't have a done.txt file"""
    root = Path(dirname)
    root.mkdir(exist_ok=True, parents=True)
    # Collect unfinished run directories first, then delete them.
    unfinished = [
        entry
        for entry in root.iterdir()
        if entry.is_dir() and not (entry / "done.txt").exists()
    ]
    for entry in unfinished:
        shutil.rmtree(entry)
|
|
|
|
def save_tensor_as_image(tensor, dstfile, global_step):
    """Save *tensor* as '<stem>_<global_step>.jpg' in dstfile's directory."""
    target = Path(dstfile)
    stamped = target.parent / f"{target.stem}_{global_step}"
    save(tensor, str(stamped.with_suffix(".jpg")))
|
|
|
|
def minmaxnorm(x):
    """Linearly rescale *x* so its values span [0, 1].

    A constant tensor previously produced NaNs (0/0 division), which then
    corrupted the uint8 image conversion downstream; it now maps to all
    zeros instead.
    """
    lo, hi = x.min(), x.max()
    if hi == lo:
        return torch.zeros_like(x)
    return (x - lo) / (hi - lo)
|
|
|
|
def save(tensor, name, channel_offset=0):
    """Convert *tensor* to a uint8 image array and write it to *name*."""
    img_array = to_img(tensor, channel_offset=channel_offset)
    pil_image = Image.fromarray(img_array)
    pil_image.save(name)
|
|
|
|
def to_img(tensor, channel_offset=0):
    """Convert a (C, H, W) tensor into a uint8 numpy image array.

    Values are min-max normalized to [0, 255]. Channel handling:
      * C == 1: returns a (H, W) grayscale array.
      * C == 2: the two channels become red and blue with a zeroed green
        channel -> (H, W, 3).
      * C >= 3: channels [channel_offset, channel_offset + 3) -> (H, W, 3).
    """
    tensor = minmaxnorm(tensor)
    tensor = (tensor * 255).to(torch.uint8)
    # Unpacking also validates the rank-3 assumption; only C is needed
    # (H and W were previously bound but never used).
    C, _, _ = tensor.shape
    if C == 1:
        tensor = tensor[0]
    elif C == 2:
        tensor = torch.stack([tensor[0], torch.zeros_like(tensor[0]), tensor[1]], dim=0)
        tensor = tensor.permute(1, 2, 0)
    elif C >= 3:
        tensor = tensor[channel_offset : channel_offset + 3]
        tensor = tensor.permute(1, 2, 0)
    tensor = tensor.cpu().numpy()
    return tensor
|
|
|
|
def log_input_output(
    name,
    x,
    y_hat,
    global_step,
    img_dstdir,
    out_dstdir,
    reduce_dim=True,
    reduction=REDUCTION,
    resample_size=20000,
):
    """Save visualizations for a batch: inputs, per-channel predictions,
    the first three raw channels, and (optionally) a 3-channel
    dimensionality-reduced view of the prediction features.

    Args:
        name: tag embedded in every output file name.
        x: input batch, indexable per sample; at most 8 samples are saved.
        y_hat: 5-D prediction tensor; dim 1 is squeezed away to give
            (B, F, H, W).  # NOTE(review): assumes y_hat.shape[1] == 1 — confirm
        global_step: step id appended to the saved file names.
        img_dstdir: destination Path for input images.
        out_dstdir: destination Path for prediction images.
        reduce_dim: when True and F >= 3, also fit and save the reduced view.
        reduction: "umap" | "tsne" | "pca" — reducer used for the 3-D view.
        resample_size: approximate number of pixels subsampled to fit it.
    """
    # Squeeze out dim 1: (B, ?, F, H, W) -> (B, F, H, W).
    y_hat = y_hat.reshape(
        y_hat.shape[0], y_hat.shape[2], y_hat.shape[3], y_hat.shape[4]
    )
    if reduce_dim and y_hat.shape[1] >= 3:
        reducer = (
            UMAP(n_components=3)
            if (reduction == "umap")
            else (
                TSNE(n_components=3)
                if reduction == "tsne"
                else PCA(n_components=3)
                if reduction == "pca"
                else None
            )
        )
        # Flatten to one row per pixel: (B*H*W, F).
        np_y_hat = y_hat.detach().cpu().permute(1, 0, 2, 3).numpy()
        np_y_hat = np_y_hat.reshape(np_y_hat.shape[0], -1)
        np_y_hat = np_y_hat.T
        # BUGFIX: the subsampling stride was shape[0] // resample_size,
        # which is 0 (an invalid slice step -> ValueError) whenever the
        # batch has fewer pixels than resample_size; clamp it to >= 1.
        step = max(1, np_y_hat.shape[0] // resample_size)
        sampled_pixels = np_y_hat[::step]
        print("dim reduction fit..." + " " * 30, end="\r")
        reducer = reducer.fit(sampled_pixels)
        print("dim reduction transform..." + " " * 30, end="\r")
        # NOTE(review): tiny throwaway transform — looks like a warm-up /
        # sanity check; its result is unused. Confirm before removing.
        reducer.transform(np_y_hat[:10])
        np_y_hat = reducer.transform(np_y_hat)
        # Back to a (B, 3, H, W) tensor on the original device.
        y_hat2 = (
            torch.from_numpy(
                np_y_hat.T.reshape(3, y_hat.shape[0], y_hat.shape[2], y_hat.shape[3])
            )
            .to(y_hat.device)
            .permute(1, 0, 2, 3)
        )
        print("done" + " " * 30, end="\r")
    else:
        y_hat2 = y_hat

    # Save at most the first 8 samples of the batch.
    for i in range(min(len(x), 8)):
        save_tensor_as_image(
            x[i],
            img_dstdir / f"input_{name}_{str(i).zfill(2)}",
            global_step=global_step,
        )
        # One grayscale image per prediction channel.
        for c in range(y_hat.shape[1]):
            save_tensor_as_image(
                y_hat[i, c : c + 1],
                out_dstdir / f"pred_channel_{name}_{str(i).zfill(2)}_{c}",
                global_step=global_step,
            )

        assert len(y_hat2.shape) == 4, "should be B, F, H, W"
        if reduce_dim:
            save_tensor_as_image(
                y_hat2[i][:3],
                out_dstdir / f"pred_reduced_{name}_{str(i).zfill(2)}",
                global_step=global_step,
            )
        # Raw first-three-channels view, regardless of reduction.
        save_tensor_as_image(
            y_hat[i][:3],
            out_dstdir / f"pred_colorchs_{name}_{str(i).zfill(2)}",
            global_step=global_step,
        )
|
|
|
|
def check_for_nan(loss, model, batch):
    """Raise AssertionError if *loss* contains NaN, printing diagnostics first.

    On a NaN loss, reports whether the image batch, the mask batch, any
    model parameter, or a fresh forward pass contains NaNs, then raises.

    Args:
        loss: loss tensor (scalar or any shape; any NaN element triggers).
        model: module whose named parameters are scanned and which is
            re-run on batch[0] for the output check.
        batch: (images, masks) pair, as fed to the model.

    Raises:
        AssertionError: when the loss contains NaN.
    """
    # Explicit check instead of `assert`: assert statements are stripped
    # under `python -O`, and .any() also handles non-scalar losses.
    if not torch.isnan(loss).any():
        return
    print("img batch contains nan?", torch.isnan(batch[0]).any())
    print("mask batch contains nan?", torch.isnan(batch[1]).any())
    for name, param in model.named_parameters():
        if torch.isnan(param).any():
            print(name, "contains nan")
    print("output contains nan?", torch.isnan(model(batch[0])).any())
    raise AssertionError("loss contains nan")
|
|
|
|
def calculate_iou(pred, label):
    """Intersection-over-union of two binary masks (values compared to 1).

    Returns 0 when the union is empty, otherwise intersection/union
    as a float.
    """
    pred_fg = pred == 1
    label_fg = label == 1
    intersection = (pred_fg & label_fg).sum()
    union = (pred_fg | label_fg).sum()
    return intersection.item() / union.item() if union else 0
|
|
|
|
def load_from_ckpt(net, ckpt_path, strict=True):
    """Load network weights"""
    # No-op when no path is given or the file does not exist.
    if not (ckpt_path and Path(ckpt_path).exists()):
        return net
    state = torch.load(ckpt_path, map_location="cpu")
    # Unwrap common checkpoint containers, first match wins.
    for key in ("MODEL_STATE", "state_dict"):
        if key in state:
            state = state[key]
            break
    net.load_state_dict(state, strict=strict)
    print("Loaded checkpoint from", ckpt_path)
    return net
|
|