# Source: microforge/run_demo.py (uploaded by asdf98, commit 4f8b3e5)
#!/usr/bin/env python3
"""
MicroForge End-to-End Demo Script
Runs the full notebook content as pure Python (no Jupyter magic).
"""
import torch
import torch.nn as nn
import time
import os
import sys
# Make the local microforge package importable when the script runs in-container.
sys.path.insert(0, '/app')

banner = "=" * 70
print(banner)
print("πŸ”¨ MicroForge: End-to-End Architecture Demo")
print(banner)

# Prefer GPU when present; every section below also runs on CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Device: {device}')
# ── 1. Import all modules ──
# Project-local imports; these resolve via the sys.path entry added above.
from microforge.vae import MicroForgeVAE
from microforge.backbone import MicroForgeBackbone
from microforge.planner import RecurrentLatentPlanner
from microforge.pipeline import MicroForgePipeline, SimpleTextEncoder
from microforge.training import MicroForgeTrainer, FlowMatchingScheduler, MicroForgeLoss
print("βœ“ All modules imported")
# ── 2. Test VAE configs ──
# For each preset, run one 256x256 reconstruction and report the parameter
# count plus an fp16 size estimate (2 bytes per parameter).
# NOTE: the scraped source lost its indentation; loop structure reconstructed.
print("\n── VAE Configurations ──")
for config in ['tiny', 'small', 'base']:
    vae = MicroForgeVAE(config=config)
    params = sum(p.numel() for p in vae.parameters())
    x = torch.randn(1, 3, 256, 256)
    # VAE forward returns (reconstruction, mu, logvar).
    x_recon, mu, logvar = vae(x)
    print(f" {config:>5}: {params:>12,} params | latent {mu.shape} | {params*2/1e6:.0f} MB fp16")
    # Free the model before instantiating the next (larger) preset.
    del vae
# ── 3. Test Backbone configs ──
# Time a single forward pass per preset. 'tiny' uses 16 latent channels,
# the larger presets use 32.
# NOTE: the scraped source lost its indentation; loop structure reconstructed.
print("\n── Backbone Configurations ──")
for config_name in ['tiny', 'small', 'base']:
    lc = 16 if config_name == 'tiny' else 32
    bb = MicroForgeBackbone(latent_channels=lc, config=config_name)
    params = sum(p.numel() for p in bb.parameters())
    z = torch.randn(1, lc, 8, 8)
    t0 = time.time()
    # Args: latent, timestep, per-token text embeddings, pooled text embedding.
    v = bb(z, torch.rand(1), torch.randn(1, 10, 768), torch.randn(1, 768))
    ms = (time.time() - t0) * 1000
    print(f" {config_name:>5}: {params:>12,} params | {ms:.0f}ms | {params*2/1e6:.0f} MB fp16")
    # Free the model before instantiating the next (larger) preset.
    del bb
# ── 4. Planner test ──
# Three recurrent refinement steps: each step feeds the previous plan back
# through the planner, then re-seeds the next plan from it via prev_plan.
# NOTE: the scraped source lost its indentation; loop structure reconstructed.
print("\n── Recurrent Latent Planner ──")
planner = RecurrentLatentPlanner(num_plan_tokens=32, dim=384, text_dim=768, latent_channels=32)
params = sum(p.numel() for p in planner.parameters())
print(f" Params: {params:,} | Plan state: {planner.get_plan_size_bytes()} bytes")
text_pooled = torch.randn(1, 768)
plan = planner.initialize_plan(text_pooled, 1)
for step in range(3):
    img = torch.randn(1, 64, 32)   # stand-in image tokens
    t_emb = torch.randn(1, 384)    # stand-in timestep embedding
    plan, out = planner(img, plan, t_emb)
    plan = planner.initialize_plan(text_pooled, 1, prev_plan=plan)
    print(f" Step {step}: plan_norm={plan.norm():.2f}, out_norm={out.norm():.2f}")
del planner
# ── 5. Full Pipeline ──
# Assemble the smallest configuration end-to-end on CPU and report the
# per-component parameter breakdown plus a 512px inference memory estimate.
# NOTE: the scraped source lost its indentation; loop structure reconstructed.
print("\n── Full Pipeline Assembly ──")
vae = MicroForgeVAE(config='tiny')
backbone = MicroForgeBackbone(latent_channels=16, config='tiny')
planner = RecurrentLatentPlanner(num_plan_tokens=16, dim=256, text_dim=768, latent_channels=16)
text_enc = SimpleTextEncoder(vocab_size=8192, embed_dim=768, num_layers=2)
pipeline = MicroForgePipeline(vae, backbone, text_enc, planner, device='cpu')
# count_parameters() returns a dict with a 'total' key plus per-component counts.
params = pipeline.count_parameters()
print(f" Total params: {params['total']:,}")
for k, v in params.items():
    if k != 'total':
        print(f" {k}: {v:,}")
mem = pipeline.get_memory_estimate(512, 512)
print(f" Est. inference @512px: {mem['estimated_inference_mb']:.0f} MB")
# ── 6. Text2Img ──
print("\n── Text-to-Image Generation ──")
# Random token ids stand in for a real tokenizer; the seed fixes the noise.
tokens = torch.randint(0, 8192, (1, 10))
start = time.time()
images = pipeline.text2img(tokens, height=128, width=128, num_steps=4, cfg_scale=1.0, seed=42)
ms = (time.time() - start) * 1000
print(f" Generated {images.shape} in {ms:.0f}ms")
print(f" Range: [{images.min():.2f}, {images.max():.2f}]")
# ── 7. Training Demo ──
# NOTE: the scraped source lost its indentation; loop structure reconstructed.
print("\n── Training Pipeline Demo ──")
vae_train = MicroForgeVAE(config='tiny')
bb_train = MicroForgeBackbone(latent_channels=16, config='tiny')
pl_train = RecurrentLatentPlanner(num_plan_tokens=16, dim=256, text_dim=768, latent_channels=16)

# VAE training
# Stage 1: reconstruct random images with a tiny KL weight (lambda_kl=1e-6);
# gradients are clipped at norm 2.0 to keep this short run stable.
print(" Stage 1: VAE Training")
vae_train.train()
vae_opt = torch.optim.AdamW(vae_train.parameters(), lr=1e-4)
loss_fn = MicroForgeLoss(lambda_kl=1e-6)
for i in range(20):
    imgs = torch.randn(4, 3, 128, 128) * 0.5
    x_recon, mu, logvar = vae_train(imgs)
    losses = loss_fn.vae_loss(x_recon, imgs, mu, logvar)
    vae_opt.zero_grad()
    losses['total'].backward()
    torch.nn.utils.clip_grad_norm_(vae_train.parameters(), 2.0)
    vae_opt.step()
    if i % 5 == 0:
        print(f" Step {i:3d}: recon={losses['recon'].item():.4f}")

# Backbone training
# Stage 2: switch the VAE to eval mode and train backbone + planner through
# the flow-matching trainer (use_ema=True presumably keeps an EMA copy of
# the weights β€” confirm in MicroForgeTrainer).
print(" Stage 2: Backbone Flow Matching")
vae_train.eval()
trainer = MicroForgeTrainer(vae_train, bb_train, pl_train, lr=1e-4, use_ema=True)
for i in range(20):
    imgs = torch.randn(2, 3, 128, 128) * 0.5
    text_emb = torch.randn(2, 10, 768)
    text_pooled = torch.randn(2, 768)
    losses = trainer.train_step(imgs, text_emb, text_pooled)
    if i % 5 == 0:
        print(f" Step {i:3d}: flow={losses['flow']:.2f}")
# ── 8. Editing pathway ──
# The same backbone weights are run on a square latent (generation) and a
# wider latent (editing) to show it handles both spatial sizes.
print("\n── Editing Pathway Test ──")
bb = MicroForgeBackbone(latent_channels=16, config='tiny')
z_gen = torch.randn(1, 16, 8, 8)
z_edit = torch.randn(1, 16, 8, 16)
t = torch.rand(1)
te = torch.randn(1, 5, 768)
tp = torch.randn(1, 768)
for label, z_in in (("Generation", z_gen), ("Editing", z_edit)):
    v_out = bb(z_in, t, te, tp)
    print(f" {label}: {z_in.shape} -> {v_out.shape}")
# ── 9. Staged freeze/thaw ──
# Simulate the staged training schedule by toggling requires_grad and
# counting the trainable parameters at each stage.
# NOTE: the scraped source lost its indentation; def bodies reconstructed.
print("\n── Staged Training Config ──")
vae_s = MicroForgeVAE(config='tiny')
bb_s = MicroForgeBackbone(latent_channels=16, config='tiny')
pl_s = RecurrentLatentPlanner(num_plan_tokens=16, dim=256, text_dim=768, latent_channels=16)


def count_t(m):
    """Return the number of trainable (requires_grad) parameters in m."""
    return sum(p.numel() for p in m.parameters() if p.requires_grad)


def freeze(m):
    """Disable gradients for every parameter of m."""
    for p in m.parameters():
        p.requires_grad_(False)


def unfreeze(m):
    """Enable gradients for every parameter of m."""
    for p in m.parameters():
        p.requires_grad_(True)


freeze(bb_s); freeze(pl_s); unfreeze(vae_s)
print(f" Stage 1 (VAE only): {count_t(vae_s):,} trainable")
freeze(vae_s); unfreeze(bb_s); unfreeze(pl_s)
print(f" Stage 2 (Backbone+Plan): {count_t(bb_s)+count_t(pl_s):,} trainable")
unfreeze(vae_s)
print(f" Stage 5 (Joint): {count_t(vae_s)+count_t(bb_s)+count_t(pl_s):,} trainable")
# ── 10. Architecture comparison ──
# Static reference table (model, params, VRAM, attention complexity); these
# figures are illustrative, not measured by this script.
# NOTE: the scraped source lost its indentation; loop body reconstructed.
print("\n── Architecture Comparison ──")
comparison = [
    ('SD-v1.5', '860M', '~3.4 GB', 'O(NΒ²)'),
    ('SDXL', '2.6B', '~6.5 GB', 'O(NΒ²)'),
    ('SANA-Sprint', '600M+2B', '~5.5 GB', 'O(N)'),
    ('SnapGen', '380M+2B', '~4 GB', 'O(NΒ²)'),
    ('DreamLite', '389M+2B', '~4 GB', 'O(NΒ²)'),
    ('MicroForge-tiny', '28M+text', '~0.2 GB', 'O(N)'),
    ('MicroForge-small', '114M+text', '~0.6 GB', 'O(N)'),
]
print(f" {'Model':>18} | {'Params':>12} | {'VRAM':>10} | {'Complexity':>10}")
print(" " + "-" * 60)
# Unpack each row instead of indexing row[0]..row[3] β€” same output.
for model, n_params, vram, complexity in comparison:
    print(f" {model:>18} | {n_params:>12} | {vram:>10} | {complexity:>10}")
# ── 11. Save checkpoint ──
# Bundle the three trained sub-modules plus the config needed to rebuild them.
print("\n── Save Checkpoint ──")
ckpt_path = '/app/checkpoints/microforge_tiny_demo.pt'
os.makedirs('/app/checkpoints', exist_ok=True)
ckpt = {
    'vae': vae_train.state_dict(),
    'backbone': bb_train.state_dict(),
    'planner': pl_train.state_dict(),
    'config': {
        'vae_config': 'tiny', 'backbone_config': 'tiny',
        'latent_channels': 16, 'plan_tokens': 16, 'plan_dim': 256,
    },
    'version': '0.1.0',
}
torch.save(ckpt, ckpt_path)
size = os.path.getsize(ckpt_path) / 1e6
print(f" Saved: {size:.1f} MB")

print("\n" + "=" * 70)
print("βœ… MicroForge End-to-End Demo Complete β€” All Tests Passed")
print("=" * 70)