HeavyAE (Heavy AutoEncoder)

HeavyAE is a high-resolution symmetric convolutional autoencoder optimized for reconstructing mobile interface screenshots at a native aspect ratio.

Model Details

  • Model Type: Convolutional Autoencoder (Non-Variational)
  • Parameters: ~12.5 Million (12,544,523 for the layers defined in the inference script)
  • Input Resolution: 1600 x 720 (width x height, RGB)
  • Latent Bottleneck: 8 Channels
  • Activation: LeakyReLU, negative slope 0.2 (Encoder/Decoder) | Sigmoid (Output)
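
The parameter count can be checked by hand from the layer widths in the HeavyAE class in the inference script. A minimal sketch in plain Python, assuming 3x3 kernels with a bias on every layer (as in that script); the helper names `conv_params` and `count` are illustrative and not part of the model code:

```python
def conv_params(c_in, c_out, k=3):
    """Weights plus biases for one Conv2d/ConvTranspose2d layer with a k x k kernel."""
    return c_in * c_out * k * k + c_out

# Channel widths taken from the HeavyAE definition in the inference script
encoder_channels = [3, 128, 256, 512, 1024, 8]
decoder_channels = [8, 1024, 512, 256, 128, 3]

def count(channels):
    """Sum the parameters of consecutive layers described by a channel list."""
    return sum(conv_params(a, b) for a, b in zip(channels, channels[1:]))

encoder_total = count(encoder_channels)
decoder_total = count(decoder_channels)
total = encoder_total + decoder_total
print(f"encoder={encoder_total:,} decoder={decoder_total:,} total={total:,}")
```

Almost all of the weight sits in the 512-to-1024 and 1024-to-512 layers, which is why the encoder and decoder halves come out nearly equal.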

Metadata

| Feature | Value |
|---|---|
| Layers | 10 (5 Encoder, 5 Decoder) |
| Target Size | 1600x720 (width x height) |
| Latent Space | 8x45x100 (channels x height x width) |
| Format | PyTorch (model.pt) |
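
The 8x45x100 latent shape follows from four stride-2 convolutions applied to a 720 x 1600 (height x width) input, followed by one stride-1 convolution. A short sketch using the standard convolution output-size formula, assuming the kernel size 3 and padding 1 used in the inference script; `conv_out` is an illustrative helper:

```python
def conv_out(n, kernel=3, stride=2, padding=1):
    """Output length along one axis: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * padding - kernel) // stride + 1

height, width = 720, 1600
for _ in range(4):  # the four stride-2 encoder layers each halve both axes
    height, width = conv_out(height), conv_out(width)

# The final encoder layer has stride 1, so it keeps the spatial size
height, width = conv_out(height, stride=1), conv_out(width, stride=1)
latent_shape = (8, height, width)
print(latent_shape)
```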

Benchmarks

The following results were obtained from internal testing samples (e.g., UI reconstruction tasks):

  • Average Reconstruction Accuracy: 85.41%
  • Average MSE: 0.0058
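
The card does not say how these two figures were computed. The inference script further down defines accuracy as (1 - mean absolute error) x 100 on pixel values in [0, 1]. Assuming that convention, both metrics can be sketched in plain Python; the function names and the toy pixel values are illustrative only and are not taken from the test set:

```python
def reconstruction_accuracy(original, reconstructed):
    """(1 - mean absolute error) * 100 for pixel values scaled to [0, 1]."""
    mae = sum(abs(a - b) for a, b in zip(original, reconstructed)) / len(original)
    return (1 - mae) * 100

def mse(original, reconstructed):
    """Mean squared error over all pixel values."""
    return sum((a - b) ** 2 for a, b in zip(original, reconstructed)) / len(original)

# Toy data: four pixel values and a slightly-off reconstruction
original = [0.0, 0.5, 1.0, 0.25]
reconstructed = [0.1, 0.5, 0.9, 0.25]
print(f"accuracy={reconstruction_accuracy(original, reconstructed):.2f}%")
print(f"mse={mse(original, reconstructed):.4f}")
```

Because MSE squares each error, it penalizes a few large deviations more heavily than the absolute-error accuracy does, so the two numbers are not interchangeable.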

[Insert Benchmark] A fuller benchmark will be added here once one is available.

Model Performance Benchmark: AE (~12.5M Parameters)

Architecture: 1600×720 Autoencoder | Bottleneck: 96:1 | Optimization: Undisclosed
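
The 96:1 bottleneck figure matches the ratio of input values to latent values. A quick check, assuming the 3 x 720 x 1600 input and 8 x 45 x 100 latent shapes listed in the Metadata section:

```python
from math import prod

input_shape = (3, 720, 1600)   # channels, height, width
latent_shape = (8, 45, 100)    # channels, height, width

ratio = prod(input_shape) / prod(latent_shape)
print(f"{prod(input_shape):,} / {prod(latent_shape):,} = {ratio:.0f}:1")
```

This counts tensor elements only; it says nothing about bits per element or about how the latent would be stored on disk.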

| Test Case | Input Resolution | Aspect Ratio | Original Accuracy | Low-Noise Acc | High-Noise Acc | Primary Challenge |
|---|---|---|---|---|---|---|
| IRL Faces (Forest) | ~4k×3k (HQ) | 4:3 | 94.45% | 94.44% | 94.07% | Complex gradients & textures |
| AI Generated Art | 128×128 (LQ) | 1:1 | 96.68% | 96.61% | 95.85% | Upscaling/Interpolation noise |
| Digital Doodle | 720×720 (MD) | 1:1 | 95.91% | 95.90% | 95.71% | Sharp high-contrast edges |

Inference & Stress Test (Google Colab)

import torch
import torch.nn as nn
import numpy as np
import requests
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

class HeavyAE(nn.Module):
    def __init__(self):
        super(HeavyAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 128, 3, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, 3, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.Conv2d(512, 1024, 3, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 8, 3, stride=1, padding=1)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 1024, 3, stride=2, padding=1, output_padding=1), nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(1024, 512, 3, stride=2, padding=1, output_padding=1), nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1), nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1), nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 3, 3, stride=1, padding=1), nn.Sigmoid()
        )
    def forward(self, x): return self.decoder(self.encoder(x))

# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HeavyAE().to(device)
model_url = "https://huggingface.co/Parallax-labs-1/parallax_VISION-ValidPhone/resolve/main/model.pt"

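# Download weights (fail loudly on HTTP errors instead of saving an error page)
response = requests.get(model_url, timeout=60)
response.raise_for_status()
with open("model.pt", "wb") as f:
    f.write(response.content)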

model.load_state_dict(torch.load("model.pt", map_location=device))
model.eval()

def test_model(img_path):
    orig = Image.open(img_path).convert('RGB')
    w, h = orig.size

    # Resize takes (height, width): 720 rows by 1600 columns
    preprocess = transforms.Compose([transforms.Resize((720, 1600)), transforms.ToTensor()])
    input_t = preprocess(orig).unsqueeze(0).to(device)

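    with torch.no_grad():
        recon = model(input_t)
        # Stress tests: reconstruct from inputs corrupted with Gaussian noise
        # (std 0.05 and 0.2); the noisy inputs are not clamped to [0, 1]
        noise_l = model(input_t + torch.randn_like(input_t) * 0.05)
        noise_h = model(input_t + torch.randn_like(input_t) * 0.2)

    # Metrics: "accuracy" is (1 - mean absolute error) * 100, always measured
    # against the clean input, with pixel values in [0, 1]
    def acc(a, b): return (1 - torch.mean(torch.abs(a - b)).item()) * 100
    print(f"--- Log ---\nOriginal Accuracy: {acc(input_t, recon):.2f}%")
    print(f"Low-Noise Accuracy: {acc(input_t, noise_l):.2f}%")
    print(f"High-Noise Accuracy: {acc(input_t, noise_h):.2f}%")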

    # Output Images
    res = transforms.ToPILImage()(recon.squeeze().cpu()).resize((w, h))
    diff = np.abs(np.array(orig).astype(float) - np.array(res).astype(float)).astype(np.uint8)

    fig, ax = plt.subplots(1, 3, figsize=(18, 6))
    ax[0].imshow(orig); ax[0].set_title("Input")
    ax[1].imshow(res); ax[1].set_title("Reconstruction")
    ax[2].imshow(diff); ax[2].set_title("Error Map")
    for a in ax: a.axis('off')
    plt.show()