---
license: apache-2.0
library_name: pytorch
tags:
- autoencoder
- vision
- image-reconstruction
- heavy-ae
metrics:
- accuracy
- mse
---

# HeavyAE (Heavy AutoEncoder)

HeavyAE is a high-resolution symmetric convolutional autoencoder optimized for reconstructing mobile interface screenshots at a native aspect ratio.

## Model Details

- **Model Type:** Convolutional Autoencoder (Non-Variational)
- **Parameters:** ~12.6 Million
- **Input Resolution:** 1600 x 720 (RGB)
- **Latent Bottleneck:** 8 Channels
- **Activation:** LeakyReLU (Encoder/Decoder) | Sigmoid (Output)

## Metadata

| Feature | Value |
|---|---|
| Layers | 10 (5 Encoder, 5 Decoder) |
| Target Size | 1600x720 |
| Latent Space | 8x45x100 |
| Format | PyTorch (`model.pt`) |

## Benchmarks

The following results were obtained from internal testing samples (e.g., UI reconstruction tasks):

* **Average Reconstruction Accuracy:** 85.41%
* **Average MSE:** 0.0058

*[Insert Benchmark]: formal benchmark results will be added once available.*

### Model Performance Benchmark: AE (10M Parameters)

**Architecture:** 1600×720 Autoencoder | **Bottleneck:** 96:1 | **Optimization:** Secret

| Test Case | Input Resolution | Aspect Ratio | Original Accuracy | Low-Noise Acc | High-Noise Acc | Primary Challenge |
|---|---|---|---|---|---|---|
| **IRL Faces (Forest)** | ~4k×3k (HQ) | 4:3 | **94.45%** | 94.44% | 94.07% | Complex gradients & textures |
| **AI Generated Art** | 128×128 (LQ) | 1:1 | **96.68%** | 96.61% | 95.85% | Upscaling/interpolation noise |
| **Digital Doodle** | 720×720 (MD) | 1:1 | **95.91%** | 95.90% | 95.71% | Sharp high-contrast edges |

## Inference & Stress Test (Google Colab)

```python
import torch
import torch.nn as nn
import numpy as np
import requests
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt


class HeavyAE(nn.Module):
    def __init__(self):
        super(HeavyAE, self).__init__()
        # Encoder: four stride-2 convolutions (16x spatial downsampling),
        # then a stride-1 projection to the 8-channel latent space.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 128, 3, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, 3, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.Conv2d(512, 1024, 3, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 8, 3, stride=1, padding=1),
        )
        # Decoder mirrors the encoder with transposed convolutions;
        # the final Sigmoid keeps the reconstruction in [0, 1].
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 1024, 3, stride=2, padding=1, output_padding=1), nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(1024, 512, 3, stride=2, padding=1, output_padding=1), nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1), nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1), nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 3, 3, stride=1, padding=1), nn.Sigmoid(),
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))


# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HeavyAE().to(device)
model_url = "https://huggingface.co/Parallax-labs-1/parallax_VISION-ValidPhone/resolve/main/model.pt"

# Download weights
response = requests.get(model_url)
response.raise_for_status()
with open("model.pt", "wb") as f:
    f.write(response.content)

model.load_state_dict(torch.load("model.pt", map_location=device))
model.eval()


def test_model(img_path):
    orig = Image.open(img_path).convert("RGB")
    w, h = orig.size
    preprocess = transforms.Compose([transforms.Resize((720, 1600)), transforms.ToTensor()])
    input_t = preprocess(orig).unsqueeze(0).to(device)

    with torch.no_grad():
        recon = model(input_t)
        # Stress tests: additive Gaussian noise (may push values slightly outside [0, 1])
        noise_l = model(input_t + torch.randn_like(input_t) * 0.05)
        noise_h = model(input_t + torch.randn_like(input_t) * 0.2)

    # Metric: mean-absolute-error "accuracy", as used in the benchmark table
    def acc(a, b):
        return (1 - torch.mean(torch.abs(a - b)).item()) * 100

    print(f"--- Log ---\nOriginal Accuracy: {acc(input_t, recon):.2f}%")
    print(f"Low-Noise Accuracy: {acc(input_t, noise_l):.2f}%")
    print(f"High-Noise Accuracy: {acc(input_t, noise_h):.2f}%")

    # Output images
    res = transforms.ToPILImage()(recon.squeeze().cpu()).resize((w, h))
    diff = np.abs(np.array(orig).astype(float) - np.array(res).astype(float)).astype(np.uint8)

    fig, ax = plt.subplots(1, 3, figsize=(18, 6))
    ax[0].imshow(orig); ax[0].set_title("Input")
    ax[1].imshow(res); ax[1].set_title("Reconstruction")
    ax[2].imshow(diff); ax[2].set_title("Error Map")
    for a in ax:
        a.axis("off")
    plt.show()


# Usage (example path): test_model("screenshot.png")
```
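The bottleneck ratio and parameter count quoted above can be cross-checked with plain arithmetic. This is an illustrative sanity check, not part of the original card: it takes the shapes from the Metadata table and the layer list from the snippet above, and lands at ≈12.5M parameters, close to the stated ~12.6M figure.

```python
# Illustrative sanity checks derived from the Metadata table and the
# layer list in the inference snippet above.

# 1) Bottleneck compression ratio: a 3x720x1600 RGB input is squeezed
#    into an 8x45x100 latent tensor.
input_elems = 3 * 720 * 1600   # 3,456,000 values
latent_elems = 8 * 45 * 100    # 36,000 values
print(f"Compression ratio: {input_elems // latent_elems}:1")  # 96:1

# 2) Back-of-the-envelope parameter count: every layer is a 3x3
#    (transposed) convolution with bias, so params = in*out*3*3 + out.
encoder = [(3, 128), (128, 256), (256, 512), (512, 1024), (1024, 8)]
decoder = [(8, 1024), (1024, 512), (512, 256), (256, 128), (128, 3)]
total = sum(i * o * 9 + o for i, o in encoder + decoder)
print(f"Total parameters: {total:,}")  # 12,544,523 (~12.5M)
```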