---
license: apache-2.0
datasets:
- Parallax-labs-1/dataset_VIDEO-Boxes
language:
- en
library_name: pytorch
tags:
- video-generation
- autoencoder
- latent-variable-models
- rgba
pipeline_tag: unconditional-image-generation
---
# Parallax-VIDEO-Boxes
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/)
[![Download Weights](https://img.shields.io/badge/Download-Weights-blue)](https://huggingface.co/Parallax-labs-1/parallax_VIDEO-Boxes/tree/main)

A high-performance temporal latent system designed for RGBA video synthesis. This repository hosts a dual-model architecture comprising the **GlobalHCA Autoencoder** for spatial compression and a **Latent Predictor** for state-space transitions.
## 🚀 Overview
This project compresses 4-channel (RGBA) synthetic environments into a 1012-dimensional latent space while preserving structural fidelity and alpha-channel transparency for 45x45 frames.
### Features
* **Dual-Model Pipeline:** Separates spatial understanding (AE) from temporal prediction (Predictor).
* **High-Dimensional Latents:** Uses a 1012-unit bottleneck for rich feature representation.
* **Robustness-Tested:** Maintains stable performance across varying signal-to-noise ratios.
* **RGBA Native:** Built specifically to handle Alpha channel transparency.
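To put the bottleneck in concrete terms: a 45x45 RGBA frame holds 4 × 45 × 45 = 8,100 values, so the 1012-unit latent is roughly an 8:1 compression. A quick check:

```python
# Compression ratio of the 1012-unit bottleneck relative to a raw RGBA frame.
channels, height, width = 4, 45, 45
latent_dim = 1012

raw_values = channels * height * width  # 8100 values per frame
ratio = raw_values / latent_dim         # ~8x compression

print(f"{raw_values} pixel values -> {latent_dim} latents ({ratio:.2f}x)")
```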
## 🛠️ Quick Start (Google Colab / Python)
The following code automatically fetches the necessary weights from the **Parallax-Labs** Hugging Face repositories and initializes the architectures.
```python
import os

import requests
import torch
import torch.nn as nn

# 1. DOWNLOAD MODELS
def download_file(url, filename):
    if not os.path.exists(filename):
        print(f"Downloading {filename}...")
        r = requests.get(url, allow_redirects=True)
        r.raise_for_status()
        with open(filename, "wb") as f:
            f.write(r.content)

urls = {
    "ae_global.pt": "https://huggingface.co/Parallax-labs-1/parallax_VIDEO-Boxes/resolve/main/ae_global.pt",
    "predictor.pt": "https://huggingface.co/Parallax-labs-1/parallax_VIDEO-Boxes/resolve/main/predictor.pt",
    "vision_base.pt": "https://huggingface.co/Parallax-labs-1/parallax_VISION-boxes-RGBA/resolve/main/model.pt",
}
for name, url in urls.items():
    download_file(url, name)

# 2. DEFINE ARCHITECTURES
class GlobalHCA_AE(nn.Module):
    """Spatial autoencoder: 45x45 RGBA frame <-> 1012-dim latent."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(4, 16, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * 12 * 12, 2048), nn.ReLU(),
            nn.Linear(2048, 1012 * 2),  # two halves: mean and variance terms
        )
        self.decoder = nn.Sequential(
            nn.Linear(1012, 2048), nn.ReLU(),
            nn.Linear(2048, 32 * 12 * 12), nn.ReLU(),
            nn.Unflatten(1, (32, 12, 12)),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(16, 4, 3, stride=2, padding=1), nn.Sigmoid(),
        )

    def forward(self, x):
        h = self.encoder(x)
        mu, _ = h.chunk(2, dim=-1)  # keep the mean half of the bottleneck
        return self.decoder(mu), mu

class LatentPredictor(nn.Module):
    """Maps a 1012-dim latent directly to the next 4x45x45 RGBA frame."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1012, 2048), nn.ReLU(),
            nn.Linear(2048, 4096), nn.ReLU(),
            nn.Linear(4096, 8100),  # 4 * 45 * 45 output values
            nn.Unflatten(1, (4, 45, 45)),
        )

    def forward(self, z):
        return self.net(z)

class AlphaAutoencoder(nn.Module):
    """Vision base: fully convolutional RGBA autoencoder with PixelShuffle upsampling."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(4, 32, 3, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.Conv2d(256, 4, 1),
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(4, 256, 3, padding=1), nn.PixelShuffle(2), nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, padding=1), nn.PixelShuffle(2), nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, 3, padding=1), nn.PixelShuffle(2), nn.LeakyReLU(0.2),
            nn.Conv2d(16, 16, 3, padding=1), nn.PixelShuffle(2), nn.Sigmoid(),
        )

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z), z

# 3. LOAD WEIGHTS
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize models
ae_global = GlobalHCA_AE().to(device)
predictor = LatentPredictor().to(device)
vision_base = AlphaAutoencoder().to(device)

# Load weights
ae_global.load_state_dict(torch.load("ae_global.pt", map_location=device))
predictor.load_state_dict(torch.load("predictor.pt", map_location=device))
vision_base.load_state_dict(torch.load("vision_base.pt", map_location=device))

print("All Parallax-Labs models (Global AE, Predictor, and Vision Base) loaded successfully.")
```
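Once the weights are loaded, a single prediction step encodes the current frame and feeds the latent mean to the predictor. The sketch below re-declares the two networks as plain `nn.Sequential` stacks with fresh random weights, purely to illustrate the tensor shapes involved; with the downloaded `state_dict`s loaded into the classes above, the same calls produce real predictions.

```python
import torch
import torch.nn as nn

# Same layer stacks as the GlobalHCA encoder and the LatentPredictor above,
# randomly initialized here only to demonstrate the shape flow.
encoder = nn.Sequential(
    nn.Conv2d(4, 16, 3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.ReLU(),
    nn.Flatten(),
    nn.Linear(32 * 12 * 12, 2048), nn.ReLU(),
    nn.Linear(2048, 1012 * 2),
)
predictor = nn.Sequential(
    nn.Linear(1012, 2048), nn.ReLU(),
    nn.Linear(2048, 4096), nn.ReLU(),
    nn.Linear(4096, 8100),
    nn.Unflatten(1, (4, 45, 45)),
)

with torch.no_grad():
    frame = torch.rand(1, 4, 45, 45)         # one RGBA frame in [0, 1]
    mu, _ = encoder(frame).chunk(2, dim=-1)  # (1, 1012) latent mean
    next_frame = predictor(mu)               # (1, 4, 45, 45) predicted frame

print(mu.shape, next_frame.shape)
```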
## 📊 Performance Metrics
| Phase | Metric | Value |
|---|---|---|
| **Autoencoder** | Reconstruction MSE | 0.001722 |
| **Predictor** | Latent-to-Pixel MSE | 0.000150 |
| **Robustness** | Max Noise Stability | 0.30 |
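The reconstruction and prediction figures above are per-pixel mean squared error; the robustness row reports the largest noise level at which performance stayed stable. The exact evaluation protocol is not published here, so the snippet below is only a sketch of how such numbers are typically computed, assuming additive Gaussian input noise:

```python
import torch

def mse(a, b):
    # Per-pixel mean squared error, the metric reported in the table above.
    return torch.mean((a - b) ** 2).item()

torch.manual_seed(0)
clean = torch.rand(1, 4, 45, 45)                # stand-in RGBA frame
recon = clean + 0.01 * torch.randn_like(clean)  # stand-in reconstruction

print(f"reconstruction MSE: {mse(clean, recon):.6f}")

# Robustness-style sweep: corrupt the input at increasing noise levels
# (0.30 is the maximum level the table reports as stable).
for sigma in (0.10, 0.20, 0.30):
    noisy = (clean + sigma * torch.randn_like(clean)).clamp(0, 1)
    print(f"sigma={sigma:.2f}  input-corruption MSE: {mse(clean, noisy):.4f}")
```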
## 📂 Repository Links
* [Main Video Repository](https://huggingface.co/Parallax-labs-1/parallax_VIDEO-Boxes)
* [Vision Base Repository (RGBA Boxes)](https://huggingface.co/Parallax-labs-1/parallax_VISION-boxes-RGBA)
**Developed by Parallax-Labs**