ticketguy commited on
Commit
ae7a539
Β·
verified Β·
1 Parent(s): 246f26e

GPU memory experiment + CogMemBench on real model

Browse files
Files changed (1) hide show
  1. memory_and_cogmem.py +211 -0
memory_and_cogmem.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Two experiments:
4
+ 1. Research GPU memory reduction for FigQuant (figcache mode on GPU)
5
+ 2. Run CogMemBench on TinyLlama
6
+ """
7
+ import os, sys, subprocess, time, gc, json
8
+ import numpy as np
9
+
10
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "-q",
11
+ "transformers", "accelerate", "datasets", "sentencepiece", "protobuf", "psutil", "numpy"])
12
+ subprocess.check_call(["git", "clone", "https://github.com/ticketguy/littlefig.git", "/app/littlefig"])
13
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "-e", "/app/littlefig[train]"])
14
+ sys.path.insert(0, "/app/littlefig/src")
15
+ sys.path.insert(0, "/app/littlefig")
16
+
17
+ import torch
18
+
19
+ def log(msg): print(f"[EXP] {msg}", flush=True)
20
+
21
+ log(f"PyTorch {torch.__version__}, CUDA={torch.cuda.is_available()}")
22
+ if torch.cuda.is_available():
23
+ log(f"GPU: {torch.cuda.get_device_name()} ({torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB)")
24
+
25
+ # ═══════════════════════════════════════════════════════════════════════════════
26
+ # EXPERIMENT 1: GPU Memory Profiling β€” what eats the VRAM?
27
+ # ═══════════════════════════════════════════════════════════════════════════════
28
+ log("\n" + "="*60)
29
+ log(" EXPERIMENT 1: GPU Memory Profiling")
30
+ log("="*60)
31
+
32
+ from little_fig.engine import FigModel
33
+ from little_fig.engine.tier import TrainingTier
34
+
35
+ MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
36
+ gc.collect(); torch.cuda.empty_cache(); torch.cuda.reset_peak_memory_stats()
37
+
38
+ # Profile: what's the memory at each stage?
39
+ log("\n Memory at each stage (lowram mode):")
40
+
41
+ # Stage 1: load model on CPU
42
+ model = FigModel.from_pretrained(MODEL, lora_r=16, lora_alpha=32,
43
+ tier=TrainingTier.STREAMING_LORA, target_modules=["q_proj","k_proj","v_proj","o_proj"],
44
+ fast=False)
45
+
46
+ log(f" After load (CPU): GPU={torch.cuda.memory_allocated()/1e6:.0f}MB")
47
+
48
+ # Stage 2: move to GPU
49
+ dev = torch.device("cuda")
50
+ model = model.to(dev)
51
+ torch.cuda.synchronize()
52
+ after_move = torch.cuda.memory_allocated()/1e6
53
+ log(f" After .to(cuda): GPU={after_move:.0f}MB")
54
+
55
+ # Stage 3: single forward pass
56
+ tok = model.tokenizer
57
+ enc = tok("Hello world", return_tensors="pt", max_length=64, truncation=True, padding="max_length")
58
+ enc = {k: v.to(dev) for k, v in enc.items()}
59
+
60
+ torch.cuda.reset_peak_memory_stats()
61
+ with torch.autocast("cuda", dtype=torch.float16):
62
+ out = model(input_ids=enc["input_ids"], labels=enc["input_ids"])
63
+ after_fwd = torch.cuda.max_memory_allocated()/1e6
64
+ log(f" After forward: GPU={after_fwd:.0f}MB (peak)")
65
+
66
+ # Stage 4: backward pass
67
+ torch.cuda.reset_peak_memory_stats()
68
+ out.loss.backward()
69
+ after_bwd = torch.cuda.max_memory_allocated()/1e6
70
+ log(f" After backward: GPU={after_bwd:.0f}MB (peak)")
71
+
72
+ log(f"\n ANALYSIS:")
73
+ log(f" Model on GPU: {after_move:.0f}MB")
74
+ log(f" Forward peak: {after_fwd:.0f}MB (+{after_fwd-after_move:.0f}MB activations)")
75
+ log(f" Backward peak: {after_bwd:.0f}MB (+{after_bwd-after_fwd:.0f}MB gradients)")
76
+ log(f" Total training: {after_bwd:.0f}MB")
77
+
78
+ # What's eating memory? The INT4 weights are tiny, but they get dequantized to FP32 in forward
79
+ # In lowram mode: each forward dequants to fp32 temporarily β†’ that's where the spike is
80
+ # With autocast(fp16): the dequant goes to fp16 (our dtype fix) β†’ should be 2Γ— less
81
+
82
+ # Count parameters by type
83
+ int4_bytes = 0
84
+ fp32_bytes = 0
85
+ for name, param in model.named_parameters():
86
+ if param.requires_grad:
87
+ fp32_bytes += param.numel() * param.element_size()
88
+ for name, buf in model.named_buffers():
89
+ if buf is not None:
90
+ if buf.dtype == torch.uint8:
91
+ int4_bytes += buf.numel()
92
+ else:
93
+ fp32_bytes += buf.numel() * buf.element_size()
94
+
95
+ log(f"\n Weight breakdown:")
96
+ log(f" INT4 packed indices: {int4_bytes/1e6:.1f}MB")
97
+ log(f" FP32 params/buffers: {fp32_bytes/1e6:.1f}MB")
98
+ log(f" LoRA trainable: {sum(p.numel()*4 for p in model.parameters() if p.requires_grad)/1e6:.1f}MB")
99
+
100
+ # FINDING: The issue is that dequant creates full fp32/fp16 weight tensors per layer per forward
101
+ # For 88 quantized layers at ~4MB each = ~350MB of temporary dequantized weights
102
+ # Plus activations + gradients for a 1.1B model = total ~10GB
103
+
104
+ log(f"\n ROOT CAUSE: Each forward dequantizes 88 layers Γ— ~4MB each = ~350MB temp tensors")
105
+ log(f" Plus activations for 1.1B model at seq_len=512 = ~several GB")
106
+ log(f" SOLUTIONS:")
107
+ log(f" 1. Gradient checkpointing (already used β€” recompute activations)")
108
+ log(f" 2. Smaller batch size (reduce activation memory)")
109
+ log(f" 3. Shorter sequence length")
110
+ log(f" 4. FP16 dequant instead of FP32 (our dtype fix helps)")
111
+ log(f" 5. Layer-wise gradient accumulation (dequant only active layer)")
112
+
113
+ del model; gc.collect(); torch.cuda.empty_cache()
114
+
115
+ # ═══════════════════════════════════════════════════════════════════════════════
116
+ # EXPERIMENT 2: Can we reduce memory by using smaller batch + shorter seq?
117
+ # ═══════════════════════════════════════════════════════════════════════════════
118
+ log("\n" + "="*60)
119
+ log(" EXPERIMENT 2: Memory vs Batch Size/Seq Length")
120
+ log("="*60)
121
+
122
+ configs = [
123
+ (1, 128, "batch=1, seq=128"),
124
+ (1, 256, "batch=1, seq=256"),
125
+ (2, 256, "batch=2, seq=256"),
126
+ (4, 256, "batch=4, seq=256"),
127
+ (4, 512, "batch=4, seq=512"),
128
+ ]
129
+
130
+ results_mem = []
131
+ for batch_sz, seq_len, label in configs:
132
+ gc.collect(); torch.cuda.empty_cache(); torch.cuda.reset_peak_memory_stats()
133
+
134
+ model = FigModel.from_pretrained(MODEL, lora_r=16, lora_alpha=32,
135
+ tier=TrainingTier.STREAMING_LORA, target_modules=["q_proj","k_proj","v_proj","o_proj"],
136
+ fast=False)
137
+ model = model.to(dev)
138
+
139
+ ids = torch.randint(0, 32000, (batch_sz, seq_len), device=dev)
140
+
141
+ try:
142
+ torch.cuda.reset_peak_memory_stats()
143
+ with torch.autocast("cuda", dtype=torch.float16):
144
+ out = model(input_ids=ids, labels=ids)
145
+ out.loss.backward()
146
+ peak = torch.cuda.max_memory_allocated()/1e6
147
+ results_mem.append((label, peak, "βœ“"))
148
+ log(f" {label:>20}: {peak:.0f}MB βœ“")
149
+ except torch.cuda.OutOfMemoryError:
150
+ results_mem.append((label, 0, "OOM"))
151
+ log(f" {label:>20}: OOM βœ—")
152
+
153
+ del model; gc.collect(); torch.cuda.empty_cache()
154
+
155
+ log(f"\n FINDING: Memory scales with batch Γ— seq_len")
156
+ log(f" For T4 (16GB): batch=2, seq=256 is the sweet spot for FigQuant lowram")
157
+
158
+ # ═══════════════════════════════════════════════════════════════════════════════
159
+ # EXPERIMENT 3: Run CogMemBench on TinyLlama
160
+ # ═══════════════════════════════════════════════════════════════════════════════
161
+ log("\n" + "="*60)
162
+ log(" EXPERIMENT 3: CogMemBench on TinyLlama")
163
+ log("="*60)
164
+
165
+ from cogmembench import CogMemGenerator, CogMemScorer, CogMemRunner
166
+ from transformers import AutoModelForCausalLM, AutoTokenizer
167
+
168
+ gc.collect(); torch.cuda.empty_cache()
169
+
170
+ log("Loading TinyLlama for benchmark...")
171
+ tokenizer = AutoTokenizer.from_pretrained(MODEL)
172
+ model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.float16, device_map="auto")
173
+ tokenizer.pad_token = tokenizer.eos_token
174
+
175
+ def generate_response(prompt):
176
+ """Generate a response from TinyLlama given a CogMemBench prompt."""
177
+ messages = [{"role": "user", "content": prompt}]
178
+ try:
179
+ text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
180
+ except:
181
+ text = f"<|user|>\n{prompt}\n<|assistant|>\n"
182
+
183
+ enc = tokenizer(text, return_tensors="pt", max_length=1024, truncation=True).to("cuda")
184
+ with torch.no_grad():
185
+ out = model.generate(**enc, max_new_tokens=150, do_sample=False,
186
+ pad_token_id=tokenizer.eos_token_id)
187
+ response = tokenizer.decode(out[0][enc["input_ids"].shape[1]:], skip_special_tokens=True)
188
+ return response
189
+
190
+ # Run on a subset (full 1000 would take too long)
191
+ runner = CogMemRunner(seed=42, per_axis=20) # 100 total cases
192
+ log("Running CogMemBench (100 cases, 5 axes)...")
193
+
194
+ results = runner.run(
195
+ model_fn=generate_response,
196
+ max_cases=100,
197
+ verbose=True,
198
+ )
199
+
200
+ log(f"\n CogMem Score: {results['cogmem_score']}/100")
201
+ log(f" Per-axis:")
202
+ for ax, acc in results['axis_accuracy'].items():
203
+ log(f" {ax:>15}: {acc*100:.1f}%")
204
+
205
+ # Save results
206
+ with open("/app/cogmem_results.json", "w") as f:
207
+ json.dump({k: v for k, v in results.items() if k != 'details'}, f, indent=2)
208
+
209
+ log("\n" + "="*60)
210
+ log(" ALL EXPERIMENTS COMPLETE")
211
+ log("="*60)