"""
MicroForge End-to-End Demo Script
Runs the full notebook content as pure Python (no Jupyter magic).
"""

import os
import sys
import time

import torch
import torch.nn as nn

# Prepend the project root so the `microforge` imports further down resolve.
# NOTE(review): '/app' is hard-coded — presumably the container layout; confirm.
sys.path.insert(0, '/app')

print("=" * 70)
print("🔨 MicroForge: End-to-End Architecture Demo")
print("=" * 70)

# Detected accelerator, reported for information.
# NOTE(review): the pipeline later in this script is built with device='cpu'
# regardless of this value — confirm that is intended.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Device: {device}')
# Project modules. These must come after the sys.path tweak at the top of the
# script, otherwise the `microforge` package may not be importable.
from microforge.vae import MicroForgeVAE
from microforge.backbone import MicroForgeBackbone
from microforge.planner import RecurrentLatentPlanner
from microforge.pipeline import MicroForgePipeline, SimpleTextEncoder
from microforge.training import MicroForgeTrainer, FlowMatchingScheduler, MicroForgeLoss

print("✓ All modules imported")
# --- VAE configurations: parameter count and latent shape per preset ---
print("\n── VAE Configurations ──")
for config in ['tiny', 'small', 'base']:
    vae = MicroForgeVAE(config=config)
    params = sum(p.numel() for p in vae.parameters())
    x = torch.randn(1, 3, 256, 256)
    # Shape probe only: skip autograd bookkeeping so the larger presets do
    # not hold activations for a backward pass that never happens.
    with torch.no_grad():
        x_recon, mu, logvar = vae(x)
    # params * 2 bytes = footprint of the weights stored as fp16.
    print(f" {config:>5}: {params:>12,} params | latent {mu.shape} | {params*2/1e6:.0f} MB fp16")
    # Release the model before the next (larger) preset is instantiated.
    del vae
# --- Backbone configurations: parameter count and single-forward latency ---
print("\n── Backbone Configurations ──")
for config_name in ['tiny', 'small', 'base']:
    # The 'tiny' preset is paired with 16 latent channels, the others with 32.
    lc = 16 if config_name == 'tiny' else 32
    bb = MicroForgeBackbone(latent_channels=lc, config=config_name)
    params = sum(p.numel() for p in bb.parameters())
    z = torch.randn(1, lc, 8, 8)
    # Inference-only timing: no autograd graph, and perf_counter() is the
    # monotonic high-resolution clock intended for measuring intervals.
    with torch.no_grad():
        t0 = time.perf_counter()
        v = bb(z, torch.rand(1), torch.randn(1, 10, 768), torch.randn(1, 768))
        ms = (time.perf_counter() - t0) * 1000
    print(f" {config_name:>5}: {params:>12,} params | {ms:.0f}ms | {params*2/1e6:.0f} MB fp16")
    # Release the model before the next preset is instantiated.
    del bb
# --- Recurrent latent planner: size report and three update steps ---
print("\n── Recurrent Latent Planner ──")
planner = RecurrentLatentPlanner(num_plan_tokens=32, dim=384, text_dim=768, latent_channels=32)
params = sum(p.numel() for p in planner.parameters())
print(f" Params: {params:,} | Plan state: {planner.get_plan_size_bytes()} bytes")

text_pooled = torch.randn(1, 768)
plan = planner.initialize_plan(text_pooled, 1)
# Inference-only walk-through, so no autograd graph is accumulated across the
# steps through the carried-over `plan`.
with torch.no_grad():
    for step in range(3):
        img = torch.randn(1, 64, 32)
        t_emb = torch.randn(1, 384)
        plan, out = planner(img, plan, t_emb)
        # NOTE(review): the plan is re-initialised from the one just produced;
        # presumably this demonstrates carrying state between steps — confirm.
        plan = planner.initialize_plan(text_pooled, 1, prev_plan=plan)
        print(f" Step {step}: plan_norm={plan.norm():.2f}, out_norm={out.norm():.2f}")
del planner
# --- Full pipeline assembly from the 'tiny' components ---
print("\n── Full Pipeline Assembly ──")
vae = MicroForgeVAE(config='tiny')
backbone = MicroForgeBackbone(latent_channels=16, config='tiny')
planner = RecurrentLatentPlanner(num_plan_tokens=16, dim=256, text_dim=768, latent_channels=16)
text_enc = SimpleTextEncoder(vocab_size=8192, embed_dim=768, num_layers=2)

# NOTE(review): device is fixed to 'cpu' here even when CUDA was detected
# earlier in the script — confirm this is deliberate for the demo.
pipeline = MicroForgePipeline(vae, backbone, text_enc, planner, device='cpu')
params = pipeline.count_parameters()
print(f" Total params: {params['total']:,}")
# Per-component breakdown; the aggregate entry was already printed above.
for component, count in params.items():
    if component != 'total':
        print(f" {component}: {count:,}")

mem = pipeline.get_memory_estimate(512, 512)
print(f" Est. inference @512px: {mem['estimated_inference_mb']:.0f} MB")
# --- Text-to-image generation with random token ids ---
print("\n── Text-to-Image Generation ──")
# Token ids are drawn below the text encoder's vocab_size of 8192.
tokens = torch.randint(0, 8192, (1, 10))
# perf_counter() is the monotonic clock intended for interval measurement.
t0 = time.perf_counter()
images = pipeline.text2img(tokens, height=128, width=128, num_steps=4, cfg_scale=1.0, seed=42)
ms = (time.perf_counter() - t0) * 1000
print(f" Generated {images.shape} in {ms:.0f}ms")
print(f" Range: [{images.min():.2f}, {images.max():.2f}]")
# --- Training pipeline demo: fresh models, then Stage 1 (VAE only) ---
print("\n── Training Pipeline Demo ──")
vae_train = MicroForgeVAE(config='tiny')
bb_train = MicroForgeBackbone(latent_channels=16, config='tiny')
pl_train = RecurrentLatentPlanner(num_plan_tokens=16, dim=256, text_dim=768, latent_channels=16)

print(" Stage 1: VAE Training")
vae_train.train()
vae_opt = torch.optim.AdamW(vae_train.parameters(), lr=1e-4)
loss_fn = MicroForgeLoss(lambda_kl=1e-6)
for i in range(20):
    # Synthetic batch: Gaussian noise scaled by 0.5 stands in for real images.
    imgs = torch.randn(4, 3, 128, 128) * 0.5
    x_recon, mu, logvar = vae_train(imgs)
    losses = loss_fn.vae_loss(x_recon, imgs, mu, logvar)
    vae_opt.zero_grad()
    losses['total'].backward()
    # Cap the gradient norm at 2.0 before the optimizer step.
    torch.nn.utils.clip_grad_norm_(vae_train.parameters(), 2.0)
    vae_opt.step()
    if i % 5 == 0:
        print(f" Step {i:3d}: recon={losses['recon'].item():.4f}")
# --- Stage 2: flow-matching steps driven by MicroForgeTrainer ---
print(" Stage 2: Backbone Flow Matching")
# The VAE is switched to eval mode before being handed to the trainer.
vae_train.eval()
trainer = MicroForgeTrainer(vae_train, bb_train, pl_train, lr=1e-4, use_ema=True)
batch = 2
for it in range(20):
    # Synthetic inputs; drawn in the order images, text tokens, pooled text.
    fake_images = 0.5 * torch.randn(batch, 3, 128, 128)
    fake_text_emb = torch.randn(batch, 10, 768)
    fake_text_pooled = torch.randn(batch, 768)
    step_losses = trainer.train_step(fake_images, fake_text_emb, fake_text_pooled)
    # Report every fifth step only.
    if it % 5:
        continue
    print(f" Step {it:3d}: flow={step_losses['flow']:.2f}")
# --- Editing pathway: same backbone on two latent widths ---
print("\n── Editing Pathway Test ──")
bb = MicroForgeBackbone(latent_channels=16, config='tiny')
z_gen = torch.randn(1, 16, 8, 8)
# Twice the width of z_gen.
# NOTE(review): presumably source and target latents side by side — confirm.
z_edit = torch.randn(1, 16, 8, 16)
t = torch.rand(1)
te = torch.randn(1, 5, 768)
tp = torch.randn(1, 768)

# Shape check only, so run without building an autograd graph.
with torch.no_grad():
    v_gen = bb(z_gen, t, te, tp)
    v_edit = bb(z_edit, t, te, tp)
print(f" Generation: {z_gen.shape} -> {v_gen.shape}")
print(f" Editing: {z_edit.shape} -> {v_edit.shape}")
# --- Staged training config: fresh 'tiny' models to freeze/unfreeze ---
print("\n── Staged Training Config ──")
vae_s = MicroForgeVAE(config='tiny')
bb_s = MicroForgeBackbone(latent_channels=16, config='tiny')
pl_s = RecurrentLatentPlanner(num_plan_tokens=16, dim=256, text_dim=768, latent_channels=16)
def count_t(m: nn.Module) -> int:
    """Return the number of scalar parameters of *m* that require gradients."""
    return sum(p.numel() for p in m.parameters() if p.requires_grad)


def freeze(m: nn.Module) -> None:
    """Disable gradients for every parameter of *m* (in place)."""
    for p in m.parameters():
        p.requires_grad_(False)


def unfreeze(m: nn.Module) -> None:
    """Enable gradients for every parameter of *m* (in place)."""
    for p in m.parameters():
        p.requires_grad_(True)
# Stage 1: only the VAE receives gradients.
for frozen_module in (bb_s, pl_s):
    freeze(frozen_module)
unfreeze(vae_s)
stage1_trainable = count_t(vae_s)
print(f" Stage 1 (VAE only): {stage1_trainable:,} trainable")

# Stage 2: VAE frozen, backbone and planner trained together.
freeze(vae_s)
for active_module in (bb_s, pl_s):
    unfreeze(active_module)
stage2_trainable = count_t(bb_s) + count_t(pl_s)
print(f" Stage 2 (Backbone+Plan): {stage2_trainable:,} trainable")

# Joint stage: everything trainable (backbone and planner are still unfrozen).
# NOTE(review): the label jumps from Stage 2 to Stage 5 — presumably the
# intermediate stages are not demonstrated here; confirm.
unfreeze(vae_s)
joint_trainable = count_t(vae_s) + count_t(bb_s) + count_t(pl_s)
print(f" Stage 5 (Joint): {joint_trainable:,} trainable")
# --- Architecture comparison table (static figures) ---
print("\n── Architecture Comparison ──")
# Each row: (model, parameter count, VRAM, attention complexity).
comparison = [
    ('SD-v1.5', '860M', '~3.4 GB', 'O(N²)'),
    ('SDXL', '2.6B', '~6.5 GB', 'O(N²)'),
    ('SANA-Sprint', '600M+2B', '~5.5 GB', 'O(N)'),
    ('SnapGen', '380M+2B', '~4 GB', 'O(N²)'),
    ('DreamLite', '389M+2B', '~4 GB', 'O(N²)'),
    ('MicroForge-tiny', '28M+text', '~0.2 GB', 'O(N)'),
    ('MicroForge-small', '114M+text', '~0.6 GB', 'O(N)'),
]
print(f" {'Model':>18} | {'Params':>12} | {'VRAM':>10} | {'Complexity':>10}")
print(" " + "-" * 60)
for model, n_params, vram, complexity in comparison:
    print(f" {model:>18} | {n_params:>12} | {vram:>10} | {complexity:>10}")
# --- Save the demo-trained weights as a single checkpoint file ---
print("\n── Save Checkpoint ──")
# One definition of the target location, so makedirs/save/getsize cannot drift.
ckpt_dir = '/app/checkpoints'
ckpt_path = f'{ckpt_dir}/microforge_tiny_demo.pt'
os.makedirs(ckpt_dir, exist_ok=True)
ckpt = {
    'vae': vae_train.state_dict(),
    'backbone': bb_train.state_dict(),
    'planner': pl_train.state_dict(),
    # Constructor settings used for the three models above.
    'config': {
        'vae_config': 'tiny', 'backbone_config': 'tiny',
        'latent_channels': 16, 'plan_tokens': 16, 'plan_dim': 256,
    },
    'version': '0.1.0',
}
torch.save(ckpt, ckpt_path)
# File size on disk in megabytes (10^6 bytes).
size = os.path.getsize(ckpt_path) / 1e6
print(f" Saved: {size:.1f} MB")
# --- Closing banner ---
# Kept as a single-line literal: the original string was split across two
# physical lines, which is not valid inside a regular quoted string.
closing_message = "✅ MicroForge End-to-End Demo Complete — All Tests Passed"
print("\n" + "=" * 70)
print(closing_message)
print("=" * 70)