File size: 9,308 Bytes
ae7a539 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 | #!/usr/bin/env python3
"""
Two experiments:
1. Research GPU memory reduction for FigQuant (figcache mode on GPU)
2. Run CogMemBench on TinyLlama
"""
import os, sys, subprocess, time, gc, json
import numpy as np
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q",
"transformers", "accelerate", "datasets", "sentencepiece", "protobuf", "psutil", "numpy"])
subprocess.check_call(["git", "clone", "https://github.com/ticketguy/littlefig.git", "/app/littlefig"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "-e", "/app/littlefig[train]"])
sys.path.insert(0, "/app/littlefig/src")
sys.path.insert(0, "/app/littlefig")
import torch
def log(msg): print(f"[EXP] {msg}", flush=True)
log(f"PyTorch {torch.__version__}, CUDA={torch.cuda.is_available()}")
if torch.cuda.is_available():
log(f"GPU: {torch.cuda.get_device_name()} ({torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB)")
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# EXPERIMENT 1: GPU Memory Profiling β what eats the VRAM?
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
log("\n" + "="*60)
log(" EXPERIMENT 1: GPU Memory Profiling")
log("="*60)
from little_fig.engine import FigModel
from little_fig.engine.tier import TrainingTier
MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
gc.collect(); torch.cuda.empty_cache(); torch.cuda.reset_peak_memory_stats()
# Profile: what's the memory at each stage?
log("\n Memory at each stage (lowram mode):")
# Stage 1: load model on CPU
model = FigModel.from_pretrained(MODEL, lora_r=16, lora_alpha=32,
tier=TrainingTier.STREAMING_LORA, target_modules=["q_proj","k_proj","v_proj","o_proj"],
fast=False)
log(f" After load (CPU): GPU={torch.cuda.memory_allocated()/1e6:.0f}MB")
# Stage 2: move to GPU
dev = torch.device("cuda")
model = model.to(dev)
torch.cuda.synchronize()
after_move = torch.cuda.memory_allocated()/1e6
log(f" After .to(cuda): GPU={after_move:.0f}MB")
# Stage 3: single forward pass
tok = model.tokenizer
enc = tok("Hello world", return_tensors="pt", max_length=64, truncation=True, padding="max_length")
enc = {k: v.to(dev) for k, v in enc.items()}
torch.cuda.reset_peak_memory_stats()
with torch.autocast("cuda", dtype=torch.float16):
out = model(input_ids=enc["input_ids"], labels=enc["input_ids"])
after_fwd = torch.cuda.max_memory_allocated()/1e6
log(f" After forward: GPU={after_fwd:.0f}MB (peak)")
# Stage 4: backward pass
torch.cuda.reset_peak_memory_stats()
out.loss.backward()
after_bwd = torch.cuda.max_memory_allocated()/1e6
log(f" After backward: GPU={after_bwd:.0f}MB (peak)")
log(f"\n ANALYSIS:")
log(f" Model on GPU: {after_move:.0f}MB")
log(f" Forward peak: {after_fwd:.0f}MB (+{after_fwd-after_move:.0f}MB activations)")
log(f" Backward peak: {after_bwd:.0f}MB (+{after_bwd-after_fwd:.0f}MB gradients)")
log(f" Total training: {after_bwd:.0f}MB")
# What's eating memory? The INT4 weights are tiny, but they get dequantized to FP32 in forward
# In lowram mode: each forward dequants to fp32 temporarily β that's where the spike is
# With autocast(fp16): the dequant goes to fp16 (our dtype fix) β should be 2Γ less
# Count parameters by type
int4_bytes = 0
fp32_bytes = 0
for name, param in model.named_parameters():
if param.requires_grad:
fp32_bytes += param.numel() * param.element_size()
for name, buf in model.named_buffers():
if buf is not None:
if buf.dtype == torch.uint8:
int4_bytes += buf.numel()
else:
fp32_bytes += buf.numel() * buf.element_size()
log(f"\n Weight breakdown:")
log(f" INT4 packed indices: {int4_bytes/1e6:.1f}MB")
log(f" FP32 params/buffers: {fp32_bytes/1e6:.1f}MB")
log(f" LoRA trainable: {sum(p.numel()*4 for p in model.parameters() if p.requires_grad)/1e6:.1f}MB")
# FINDING: The issue is that dequant creates full fp32/fp16 weight tensors per layer per forward
# For 88 quantized layers at ~4MB each = ~350MB of temporary dequantized weights
# Plus activations + gradients for a 1.1B model = total ~10GB
log(f"\n ROOT CAUSE: Each forward dequantizes 88 layers Γ ~4MB each = ~350MB temp tensors")
log(f" Plus activations for 1.1B model at seq_len=512 = ~several GB")
log(f" SOLUTIONS:")
log(f" 1. Gradient checkpointing (already used β recompute activations)")
log(f" 2. Smaller batch size (reduce activation memory)")
log(f" 3. Shorter sequence length")
log(f" 4. FP16 dequant instead of FP32 (our dtype fix helps)")
log(f" 5. Layer-wise gradient accumulation (dequant only active layer)")
del model; gc.collect(); torch.cuda.empty_cache()
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# EXPERIMENT 2: Can we reduce memory by using smaller batch + shorter seq?
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
log("\n" + "="*60)
log(" EXPERIMENT 2: Memory vs Batch Size/Seq Length")
log("="*60)
configs = [
(1, 128, "batch=1, seq=128"),
(1, 256, "batch=1, seq=256"),
(2, 256, "batch=2, seq=256"),
(4, 256, "batch=4, seq=256"),
(4, 512, "batch=4, seq=512"),
]
results_mem = []
for batch_sz, seq_len, label in configs:
gc.collect(); torch.cuda.empty_cache(); torch.cuda.reset_peak_memory_stats()
model = FigModel.from_pretrained(MODEL, lora_r=16, lora_alpha=32,
tier=TrainingTier.STREAMING_LORA, target_modules=["q_proj","k_proj","v_proj","o_proj"],
fast=False)
model = model.to(dev)
ids = torch.randint(0, 32000, (batch_sz, seq_len), device=dev)
try:
torch.cuda.reset_peak_memory_stats()
with torch.autocast("cuda", dtype=torch.float16):
out = model(input_ids=ids, labels=ids)
out.loss.backward()
peak = torch.cuda.max_memory_allocated()/1e6
results_mem.append((label, peak, "β"))
log(f" {label:>20}: {peak:.0f}MB β")
except torch.cuda.OutOfMemoryError:
results_mem.append((label, 0, "OOM"))
log(f" {label:>20}: OOM β")
del model; gc.collect(); torch.cuda.empty_cache()
log(f"\n FINDING: Memory scales with batch Γ seq_len")
log(f" For T4 (16GB): batch=2, seq=256 is the sweet spot for FigQuant lowram")
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# EXPERIMENT 3: Run CogMemBench on TinyLlama
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
log("\n" + "="*60)
log(" EXPERIMENT 3: CogMemBench on TinyLlama")
log("="*60)
from cogmembench import CogMemGenerator, CogMemScorer, CogMemRunner
from transformers import AutoModelForCausalLM, AutoTokenizer
gc.collect(); torch.cuda.empty_cache()
log("Loading TinyLlama for benchmark...")
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.float16, device_map="auto")
tokenizer.pad_token = tokenizer.eos_token
def generate_response(prompt):
"""Generate a response from TinyLlama given a CogMemBench prompt."""
messages = [{"role": "user", "content": prompt}]
try:
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
except:
text = f"<|user|>\n{prompt}\n<|assistant|>\n"
enc = tokenizer(text, return_tensors="pt", max_length=1024, truncation=True).to("cuda")
with torch.no_grad():
out = model.generate(**enc, max_new_tokens=150, do_sample=False,
pad_token_id=tokenizer.eos_token_id)
response = tokenizer.decode(out[0][enc["input_ids"].shape[1]:], skip_special_tokens=True)
return response
# Run on a subset (full 1000 would take too long)
runner = CogMemRunner(seed=42, per_axis=20) # 100 total cases
log("Running CogMemBench (100 cases, 5 axes)...")
results = runner.run(
model_fn=generate_response,
max_cases=100,
verbose=True,
)
log(f"\n CogMem Score: {results['cogmem_score']}/100")
log(f" Per-axis:")
for ax, acc in results['axis_accuracy'].items():
log(f" {ax:>15}: {acc*100:.1f}%")
# Save results
with open("/app/cogmem_results.json", "w") as f:
json.dump({k: v for k, v in results.items() if k != 'details'}, f, indent=2)
log("\n" + "="*60)
log(" ALL EXPERIMENTS COMPLETE")
log("="*60)
|