John6666 commited on
Commit
b25b2b1
·
verified ·
1 Parent(s): d533e48

Upload 10 files

Browse files
Makefile ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ check_dirs := src
2
+
3
+ # this target runs checks on all files
4
+ quality:
5
+ black --required-version 23 --check $(check_dirs)
6
+ ruff $(check_dirs)
7
+
8
+ # Format source code automatically and check if there are any problems left that need manual fixing
9
+ style:
10
+ black --required-version 23 $(check_dirs)
11
+ ruff $(check_dirs) --fix
README.md CHANGED
@@ -1,12 +1,19 @@
1
  ---
2
- title: Model Memory Usage Mod
3
- emoji: 🌍
4
- colorFrom: indigo
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 6.9.0
8
- app_file: app.py
 
9
  pinned: false
 
 
 
 
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
1
  ---
2
+ title: Model Memory Utility
3
+ emoji: 🚀
4
+ colorFrom: pink
5
+ colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 5.49.1
8
+ python_version: "3.10"
9
+ app_file: src/app.py
10
  pinned: false
11
+ license: apache-2.0
12
+ hf_oauth: true
13
+ hf_oauth_scopes:
14
+ - gated-repos
15
+ - read-repos
16
  ---
17
 
18
+ This Space provides a static memory estimate for Hugging Face Hub models.
19
+ For gated models, users can either paste an API token or sign in with Hugging Face OAuth.
measure_model_size.png ADDED
pre-requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ pip>=24.2,<25
2
+ setuptools>=70
3
+ wheel
pyproject.toml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.black]
2
+ line-length = 119
3
+ target-version = ['py37']
4
+
5
+ [tool.ruff]
6
+ # Never enforce `E501` (line length violations).
7
+ ignore = ["E501", "E741", "W605"]
8
+ select = ["E", "F", "I", "W"]
9
+ line-length = 119
10
+
11
+ # Ignore import violations in all `__init__.py` files.
12
+ [tool.ruff.per-file-ignores]
13
+ "__init__.py" = ["E402", "F401", "F403", "F811"]
14
+
15
+ [tool.ruff.isort]
16
+ lines-after-imports = 2
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ accelerate>=1.13.0
2
+ transformers>=5.3.0
3
+ timm>=1.0.25
4
+ huggingface_hub>=1.7.1
5
+ tabulate>=0.9.0
6
+ einops>=0.8.1
7
+ gradio_huggingfacehub_search==0.0.12
src/__init__.py ADDED
File without changes
src/app.py ADDED
@@ -0,0 +1,633 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import copy
import hashlib
import json
import tempfile
import threading
import time
import traceback
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from urllib.parse import urlparse
from uuid import uuid4

import accelerate
import gradio as gr
import huggingface_hub

try:
    from gradio_huggingfacehub_search import HuggingfaceHubSearch

    HAS_HF_HUB_SEARCH = True
except Exception:
    HuggingfaceHubSearch = None
    HAS_HF_HUB_SEARCH = False
import pandas as pd
import timm
import transformers
from accelerate.utils import convert_bytes

from model_utils import (
    calculate_memory,
    get_model_normalized,
    normalize_model_name,
    preflight_model_access_normalized,
)
35
+
36
+
37
+ DEFAULT_MODEL = "bert-base-cased"
38
+ DEFAULT_LIBRARY = "auto"
39
+ DEFAULT_OPTIONS = ["float32"]
40
+ RESULTS_CACHE_SIZE = 128
41
+ DOWNLOAD_RETENTION_SECONDS = 60 * 60
42
+ DOWNLOAD_CLEANUP_MAX_FILES = 256
43
+
44
+
def log_startup_versions():
    """Print the versions of the key runtime dependencies once at startup."""
    versions = (
        ("gradio", gr.__version__),
        ("accelerate", accelerate.__version__),
        ("transformers", transformers.__version__),
        ("huggingface_hub", huggingface_hub.__version__),
        ("timm", timm.__version__),
    )
    joined = " ".join(f"{name}={value}" for name, value in versions)
    print(f"[startup] versions {joined}")
54
+
55
+
56
+ log_startup_versions()
57
+
58
+
59
+ @dataclass(frozen=True)
60
+ class EstimateRequest:
61
+ original_model_name: str
62
+ normalized_model_name: str
63
+ library: str
64
+ options: tuple[str, ...]
65
+ access_token: str | None
66
+ auth_mode: str
67
+
68
+ @property
69
+ def cache_key(self):
70
+ token_key = "anonymous"
71
+ if self.access_token is not None:
72
+ token_key = hashlib.sha256(self.access_token.encode("utf-8")).hexdigest()
73
+ return (
74
+ self.normalized_model_name,
75
+ self.library,
76
+ self.options,
77
+ token_key,
78
+ )
79
+
80
+
@dataclass
class EstimatePayload:
    """Formatted estimate data, ready for the view layer."""

    display_rows: list[dict]  # rows with a human-readable Adam peak string
    raw_rows: list[dict]  # untouched rows from calculate_memory (for JSON export)
    explanation: str  # markdown explanation of the Adam breakdown, or ""
    breakdown_df: pd.DataFrame  # per-stage Adam table (may be empty)
87
+
88
+
@dataclass
class EstimateViewModel:
    # View state for one completed (or failed) estimate run. `to_updates`
    # maps it onto the fixed list of Gradio output components.
    title: str
    auth_message: str
    summary_df: pd.DataFrame
    explanation: str
    breakdown_df: pd.DataFrame
    error_summary: str = ""
    error_details: str = ""
    summary_path: str | None = None
    breakdown_path: str | None = None
    json_path: str | None = None

    def to_updates(self):
        """Return component updates in the exact order the event handlers wire `outputs`.

        Order: title markdown, auth banner, summary table, explanation,
        Adam breakdown table, error summary, error details, then the three
        download files. Each optional component is hidden whenever its value
        is empty / None, so success and error runs reuse the same output list.
        """
        return [
            self.title,
            gr.update(value=self.auth_message, visible=True),
            gr.update(visible=not self.summary_df.empty, value=self.summary_df),
            gr.update(visible=self.explanation != "", value=self.explanation),
            gr.update(visible=not self.breakdown_df.empty, value=self.breakdown_df),
            gr.update(visible=self.error_summary != "", value=self.error_summary),
            gr.update(visible=self.error_details != "", value=self.error_details),
            gr.update(visible=self.summary_path is not None, value=self.summary_path),
            gr.update(visible=self.breakdown_path is not None, value=self.breakdown_path),
            gr.update(visible=self.json_path is not None, value=self.json_path),
        ]
115
+
116
+
@dataclass
class ResetViewModel:
    """Default UI state emitted by the Reset button.

    `to_updates` must produce values in the same order as the reset handler's
    `outputs` list: the four input components, then every result/download
    component (cleared and hidden).
    """

    model_name: str = DEFAULT_MODEL
    library: str = DEFAULT_LIBRARY
    # FIX: annotation previously omitted `None` even though the default is
    # None (a mutable list default is not allowed on a dataclass field, so
    # __post_init__ fills it in).
    options: list[str] | tuple[str, ...] | None = None
    access_token: str = ""
    title: str = ""

    def __post_init__(self):
        # Copy the module-level default so callers can mutate the list safely.
        if self.options is None:
            self.options = list(DEFAULT_OPTIONS)

    def to_updates(self):
        """Return reset values for the inputs plus hidden/cleared outputs."""
        return [
            self.model_name,
            self.library,
            list(self.options),
            self.access_token,
            self.title,
            gr.update(visible=False, value=""),
            gr.update(visible=False, value=pd.DataFrame()),
            gr.update(visible=False, value=""),
            gr.update(visible=False, value=pd.DataFrame()),
            gr.update(visible=False, value=""),
            gr.update(visible=False, value=""),
            gr.update(visible=False, value=None),
            gr.update(visible=False, value=None),
            gr.update(visible=False, value=None),
        ]
146
+
147
+
148
+ @dataclass
149
+ class _InflightEntry:
150
+ event: threading.Event
151
+ data: list[dict] | None = None
152
+ error: Exception | None = None
153
+
154
+
class ResultCache:
    """LRU cache of estimate results with single-flight de-duplication.

    Concurrent callers asking for the same key block on one in-flight
    computation instead of recomputing it. Results are deep-copied at every
    boundary so callers can never mutate cached state.
    """

    def __init__(self, max_size: int):
        self.max_size = max_size  # LRU capacity
        self._values = OrderedDict()  # cache_key -> result, in LRU order
        self._lock = threading.Lock()  # guards _values and _inflight
        self._inflight: dict[tuple, _InflightEntry] = {}  # keys being computed right now

    def get_or_compute(self, request: EstimateRequest, compute_fn):
        """Return the cached result for `request`, computing it at most once.

        Exactly one caller (the "owner") runs `compute_fn`; any caller that
        arrives while it runs waits on the entry's event, then either
        re-raises the owner's exception or returns a deep copy of the data.
        """
        cache_key = request.cache_key

        with self._lock:
            if cache_key in self._values:
                # Cache hit: refresh LRU position, hand back an isolated copy.
                self._values.move_to_end(cache_key)
                return copy.deepcopy(self._values[cache_key])

            entry = self._inflight.get(cache_key)
            if entry is None:
                # First caller for this key becomes the owner.
                entry = _InflightEntry(event=threading.Event())
                self._inflight[cache_key] = entry
                is_owner = True
            else:
                is_owner = False

        if not is_owner:
            # Wait outside the lock for the owner to finish.
            entry.event.wait()
            if entry.error is not None:
                # NOTE: re-raises the owner's exception object in this thread.
                raise entry.error
            return copy.deepcopy(entry.data)

        try:
            data = compute_fn()
            with self._lock:
                self._values[cache_key] = copy.deepcopy(data)
                if len(self._values) > self.max_size:
                    # Evict the least recently used entry.
                    self._values.popitem(last=False)
            entry.data = copy.deepcopy(data)
            return copy.deepcopy(data)
        except Exception as error:
            entry.error = error
            raise
        finally:
            # Wake waiters and drop the in-flight marker whether we
            # succeeded or failed.
            entry.event.set()
            with self._lock:
                self._inflight.pop(cache_key, None)
199
+
200
+
201
+ RESULT_CACHE = ResultCache(max_size=RESULTS_CACHE_SIZE)
202
+
203
+
def get_auth_status(oauth_profile: gr.OAuthProfile | None):
    """Describe the current sign-in state for the auth status banner."""
    if oauth_profile is None:
        return "Not signed in. You can still paste an API token for gated models."

    display_name = getattr(oauth_profile, "preferred_username", None) or getattr(oauth_profile, "name", None)
    if display_name is None:
        display_name = "Hugging Face user"

    return (
        f"Signed in as `{display_name}`. "
        "If the API Token field is blank, this session token will be used for gated models."
    )
216
+
217
+
218
+ def use_hub_search(repo_id: str | None):
219
+ return (repo_id or "").strip()
220
+
221
+
def get_hub_search_status():
    """Explain whether the optional Hub Search component could be loaded."""
    if not HAS_HF_HUB_SEARCH:
        return "Hub Search component is unavailable in this runtime. Manual model input still works."
    return "Search Hugging Face Hub to fill the model field automatically."
226
+
227
+
def validate_model_name(model_name: str):
    """Validate the raw model input and return it stripped.

    Raises gr.Error when the input is empty or is a URL pointing at a host
    other than huggingface.co. Unparseable input is passed through unchanged.
    """
    stripped_name = model_name.strip()
    if not stripped_name:
        raise gr.Error("Enter a model name or a Hugging Face model URL.")

    try:
        parsed = urlparse(stripped_name)
    except Exception:
        # Not URL-like at all; let downstream normalization deal with it.
        return stripped_name

    is_url = bool(parsed.scheme) and bool(parsed.netloc)
    if is_url and parsed.netloc not in {"huggingface.co", "www.huggingface.co"}:
        raise gr.Error("Only Hugging Face model URLs are supported here.")

    return stripped_name
245
+
246
+
def validate_options(options: list):
    """Require at least one precision to be selected."""
    if options:
        return
    raise gr.Error("Select at least one precision.")
250
+
251
+
def validate_access_token(access_token: str):
    """Reject tokens containing whitespace (almost always a paste error)."""
    if not access_token:
        return
    for char in access_token:
        if char.isspace():
            raise gr.Error("API tokens should not contain whitespace.")
255
+
256
+
def resolve_access_token(access_token: str, oauth_token: gr.OAuthToken | None):
    """Pick the effective credential for this run.

    Priority: manually pasted token, then the OAuth session token, then
    anonymous. Returns (token_or_None, auth_mode) where auth_mode is one of
    "manual" / "oauth" / "anonymous".
    """
    manual_token = access_token if access_token != "" else None
    if manual_token is not None:
        return manual_token, "manual"
    if oauth_token is not None:
        return oauth_token.token, "oauth"
    return None, "anonymous"
268
+
269
+
def build_estimate_request(
    model_name: str,
    library: str,
    options: list,
    access_token: str,
    oauth_token: gr.OAuthToken | None,
):
    """Validate raw UI inputs and bundle them into an immutable EstimateRequest.

    Raises gr.Error (via the validators) on empty or invalid input. The
    effective credential is resolved with priority manual token > OAuth
    session > anonymous.
    """
    stripped_name = validate_model_name(model_name)
    validate_options(options)
    validate_access_token(access_token)

    # Canonical repo id: URL stripped, Llama repos mapped to their -hf mirrors.
    normalized_name = normalize_model_name(stripped_name)
    resolved_token, auth_mode = resolve_access_token(access_token, oauth_token)

    return EstimateRequest(
        original_model_name=stripped_name,
        normalized_model_name=normalized_name,
        library=library,
        # Tuple so the request is hashable/frozen for caching.
        options=tuple(options),
        access_token=resolved_token,
        auth_mode=auth_mode,
    )
292
+
293
+
def get_auth_message(auth_mode: str):
    """Human-readable description of which credentials the estimate used."""
    messages = {
        "manual": "Using the manually provided API token for this estimate.",
        "oauth": "Using your Hugging Face OAuth session for this estimate.",
    }
    fallback = "Running anonymously. Gated models will require a token or a signed-in Hugging Face session."
    return messages.get(auth_mode, fallback)
300
+
301
+
def get_download_dir():
    """Return (creating it if needed) the shared temp directory for generated downloads."""
    download_dir = Path(tempfile.gettempdir()) / "model_memory_usage"
    download_dir.mkdir(parents=True, exist_ok=True)
    return download_dir
306
+
307
+
def cleanup_old_download_files(temp_dir: Path):
    """Best-effort pruning of previously generated download files.

    Two passes: first delete files older than DOWNLOAD_RETENTION_SECONDS,
    then cap the directory at the DOWNLOAD_CLEANUP_MAX_FILES newest files.
    All filesystem errors are swallowed so cleanup can never break a run.
    """
    cutoff = time.time() - DOWNLOAD_RETENTION_SECONDS

    try:
        entries = [path for path in temp_dir.iterdir() if path.is_file()]
    except FileNotFoundError:
        # Directory vanished (e.g. temp cleared externally); nothing to do.
        return

    for path in entries:
        try:
            if path.stat().st_mtime < cutoff:
                path.unlink(missing_ok=True)
        except OSError:
            # File disappeared or is locked — skip it.
            continue

    try:
        remaining_files = sorted(
            [path for path in temp_dir.iterdir() if path.is_file()],
            key=lambda path: path.stat().st_mtime,
            # Newest first, so the slice below deletes only the oldest.
            reverse=True,
        )
    except FileNotFoundError:
        return

    for stale_path in remaining_files[DOWNLOAD_CLEANUP_MAX_FILES:]:
        try:
            stale_path.unlink(missing_ok=True)
        except OSError:
            continue
337
+
338
+
def make_download_files(model_name: str, summary_df: pd.DataFrame, breakdown_df: pd.DataFrame, raw_data: list):
    """Write the summary CSV, optional Adam-breakdown CSV, and raw JSON to temp files.

    Returns (summary_path, breakdown_path_or_None, json_path) as strings for
    the gr.File components. File names embed a fresh uuid so concurrent runs
    never collide on disk.
    """
    # Repo ids contain '/', which is not filename-safe.
    safe_name = model_name.replace("/", "__") or "model"
    temp_dir = get_download_dir()
    cleanup_old_download_files(temp_dir)
    unique_id = uuid4().hex

    summary_path = temp_dir / f"{safe_name}_{unique_id}_summary.csv"
    summary_df.to_csv(summary_path, index=False)

    # The breakdown table is only produced when training numbers exist.
    breakdown_path = None
    if not breakdown_df.empty:
        breakdown_path = temp_dir / f"{safe_name}_{unique_id}_adam_breakdown.csv"
        breakdown_df.to_csv(breakdown_path, index=False)

    json_path = temp_dir / f"{safe_name}_{unique_id}_estimate.json"
    with json_path.open("w", encoding="utf-8") as handle:
        json.dump({"model_name": model_name, "estimates": raw_data}, handle, indent=2)

    return str(summary_path), str(breakdown_path) if breakdown_path is not None else None, str(json_path)
358
+
359
+
def fetch_raw_estimate_data(request: EstimateRequest):
    """Load (or reuse from cache) the raw per-dtype memory rows for `request`."""

    def _compute():
        # Access was already verified by the preflight step, so skip it here.
        skeleton = get_model_normalized(
            request.normalized_model_name,
            request.library,
            request.access_token,
            skip_auth_check=True,
        )
        return calculate_memory(skeleton, list(request.options))

    return RESULT_CACHE.get_or_compute(request, _compute)
371
+
372
+
def build_estimate_payload(raw_rows: list[dict], options: tuple[str, ...]):
    """Turn raw calculate_memory rows into display rows plus an Adam breakdown table.

    Each raw row's "Training using Adam (Peak vRAM)" value is a per-stage
    mapping with keys model/gradients/optimizer/step, where -1 marks "not
    available". The display copy replaces that mapping with a human-readable
    peak string, and the per-stage values feed the breakdown DataFrame.
    `raw_rows` itself is never mutated (deep copies only).
    """
    display_rows = copy.deepcopy(raw_rows)
    stages = {"model": [], "gradients": [], "optimizer": [], "step": []}

    for index, option in enumerate(display_rows):
        # Collect per-stage values for the breakdown table below.
        for stage in stages:
            stages[stage].append(option["Training using Adam (Peak vRAM)"][stage])

        # Replace the stage mapping with its formatted peak for display.
        peak_value = max(display_rows[index]["Training using Adam (Peak vRAM)"].values())
        display_rows[index]["Training using Adam (Peak vRAM)"] = "N/A" if peak_value == -1 else convert_bytes(peak_value)

    explanation = ""
    breakdown_df = pd.DataFrame(
        columns=["dtype", "Model", "Gradient calculation", "Backward pass", "Optimizer step"]
    )

    # Only explain/tabulate when at least one dtype has real training numbers.
    if any(value != -1 for value in stages["model"]):
        explanation = "## Training using Adam explained:\n"
        explanation += (
            "When training on a batch size of 1, each stage of the training process is expected "
            "to have near the following memory results for each precision you selected:\n"
        )

        for index, dtype in enumerate(options):
            if stages["model"][index] != -1:
                breakdown_df.loc[len(breakdown_df.index)] = [
                    dtype,
                    convert_bytes(stages["model"][index]),
                    convert_bytes(stages["gradients"][index]),
                    convert_bytes(stages["optimizer"][index]),
                    convert_bytes(stages["step"][index]),
                ]

    return EstimatePayload(
        display_rows=display_rows,
        raw_rows=copy.deepcopy(raw_rows),
        explanation=explanation,
        breakdown_df=breakdown_df,
    )
412
+
413
+
def build_success_view_model(request: EstimateRequest, payload: EstimatePayload):
    """Assemble the success view: summary table, explanation, breakdown, downloads."""
    summary_df = pd.DataFrame(payload.display_rows)
    summary_path, breakdown_path, json_path = make_download_files(
        request.normalized_model_name,
        summary_df,
        payload.breakdown_df,
        payload.raw_rows,
    )
    return EstimateViewModel(
        title=f"## Static memory estimate for `{request.normalized_model_name}`",
        auth_message=get_auth_message(request.auth_mode),
        summary_df=summary_df,
        explanation=payload.explanation,
        breakdown_df=payload.breakdown_df,
        summary_path=summary_path,
        breakdown_path=breakdown_path,
        json_path=json_path,
    )
433
+
434
+
def build_error_view_model(request: EstimateRequest, error: Exception):
    """Assemble the error view.

    Must be invoked from inside the `except` block so traceback.format_exc()
    still sees the active exception.
    """
    summary = str(error).strip() or error.__class__.__name__
    details = traceback.format_exc().strip()
    return EstimateViewModel(
        title=f"## Unable to estimate memory for `{request.normalized_model_name}`",
        auth_message=get_auth_message(request.auth_mode),
        summary_df=pd.DataFrame(),
        explanation="",
        breakdown_df=pd.DataFrame(),
        error_summary=(
            f"{summary}\n\n"
            "Check the **Details** section below for the full traceback."
        ),
        error_details=details,
    )
451
+
452
+
def reset_app():
    """Restore every input to its default and hide every output component."""
    defaults = ResetViewModel()
    return defaults.to_updates()
455
+
456
+
def get_results(
    model_name: str,
    library: str,
    options: list,
    access_token: str,
    oauth_token: gr.OAuthToken | None,
    progress=gr.Progress(track_tqdm=False),
):
    """Main click handler: validate inputs, preflight Hub access, estimate, format.

    Returns the update list matching the wired output components (see
    EstimateViewModel.to_updates). Input-validation errors propagate as
    gr.Error before the try block; everything after validation is caught and
    rendered through the error view model so the UI always gets a full update
    list. `oauth_token` and `progress` are injected by Gradio.
    """
    progress(0.05, desc="Checking inputs")
    request = build_estimate_request(model_name, library, options, access_token, oauth_token)

    try:
        progress(0.12, desc="Checking Hub access")
        # Fail fast on gated/missing repos before building the skeleton.
        preflight_model_access_normalized(request.normalized_model_name, request.access_token)

        progress(0.3, desc="Building model skeleton")
        raw_rows = fetch_raw_estimate_data(request)

        progress(0.75, desc="Formatting results")
        payload = build_estimate_payload(raw_rows, request.options)

        progress(0.95, desc="Writing downloads")
        view_model = build_success_view_model(request, payload)
        progress(1.0, desc="Done")
        return view_model.to_updates()
    except Exception as error:
        progress(1.0, desc="Failed")
        return build_error_view_model(request, error).to_updates()
485
+
486
+
487
+ with gr.Blocks(delete_cache=(3600, DOWNLOAD_RETENTION_SECONDS)) as demo:
488
+ with gr.Column():
489
+ gr.HTML(
490
+ """<img src="https://huggingface.co/spaces/hf-accelerate/model-memory-usage/resolve/main/measure_model_size.png" style="float: left;" width="250" height="250"><h1>🤗 Model Memory Calculator</h1>
491
+ <p>This tool provides a static memory estimate for the vRAM needed to load and train Hub models.</p>
492
+ <p>The minimum recommended vRAM needed to load a model is denoted as the size of the "largest layer", and training of a model is roughly 4x its size (for Adam).</p>
493
+ <p>These calculations are accurate within a few percent at most, such as <code>bert-base-cased</code> being 413.68 MB and the calculator estimating 413.18 MB.</p>
494
+ <p>When performing inference, expect to add up to an additional 20% to this as found by <a href="https://blog.eleuther.ai/transformer-math/" target="_blank">EleutherAI</a>.</p>
495
+ <p>More tests will be performed in the future to get a more accurate benchmark for each model.</p>
496
+ <p>Currently this tool supports all models hosted that use <code>transformers</code> and <code>timm</code>.</p>
497
+ <p>To use this tool pass in the URL or model name of the model you want to calculate the memory usage for, select which framework it originates from (<code>auto</code> will try and detect it from the model metadata), and what precisions you want to use.</p>"""
498
+ )
499
+
# Model selection inputs: free-text name/URL, optional Hub search, library
# and precision choices.
with gr.Group():
    with gr.Row(equal_height=True):
        inp = gr.Textbox(label="Model Name or URL", value=DEFAULT_MODEL)

        with gr.Column():
            if HAS_HF_HUB_SEARCH:
                hub_search = HuggingfaceHubSearch(
                    label="Search Hugging Face Hub",
                    placeholder="Search for models on Hugging Face",
                    search_type="model",
                    # BUG FIX: keyword was misspelled "sumbit_on_select",
                    # which raises TypeError when the component is built.
                    submit_on_select=True,
                )
            else:
                hub_search = None
            # Rendered in both branches; explains Hub-search availability.
            hub_search_status = gr.Markdown(get_hub_search_status())

    with gr.Row(equal_height=True):
        library = gr.Radio(["auto", "transformers", "timm"], label="Library", value=DEFAULT_LIBRARY)
        options = gr.CheckboxGroup(
            ["float32", "float16/bfloat16", "int8", "int4"],
            value=DEFAULT_OPTIONS,
            label="Model Precision",
        )
524
+
525
+ with gr.Column():
526
+ gr.LoginButton()
527
+ access_token = gr.Textbox(
528
+ label="API Token",
529
+ placeholder="Optional. If blank, your Sign in with HF session will be used for gated models.",
530
+ )
531
+ auth_status = gr.Markdown("Not signed in. You can still paste an API token for gated models.")
532
+ run_auth_status = gr.Markdown(visible=False)
533
+
534
+ with gr.Group():
535
+ with gr.Row(equal_height=True):
536
+ btn = gr.Button("Calculate Memory Usage")
537
+ reset_btn = gr.Button("Reset")
538
+
539
+ out_text = gr.Markdown()
540
+ error_text = gr.Markdown(visible=False)
541
+ out = gr.DataFrame(
542
+ headers=["dtype", "Largest Layer", "Total Size", "Training using Adam (Peak vRAM)"],
543
+ interactive=False,
544
+ visible=False,
545
+ )
546
+ out_explain = gr.Markdown(visible=False)
547
+ memory_values = gr.DataFrame(
548
+ headers=["dtype", "Model", "Gradient calculation", "Backward pass", "Optimizer step"],
549
+ interactive=False,
550
+ visible=False,
551
+ )
552
+
553
+ with gr.Accordion("Downloads", open=False):
554
+ summary_file = gr.File(label="Summary CSV", visible=False)
555
+ breakdown_file = gr.File(label="Adam Breakdown CSV", visible=False)
556
+ json_file = gr.File(label="Full JSON", visible=False)
557
+
558
+ with gr.Accordion("Details", open=False):
559
+ error_details = gr.Textbox(
560
+ label="Error Details",
561
+ lines=12,
562
+ interactive=False,
563
+ visible=False,
564
+ )
565
+
566
+ demo.load(
567
+ get_auth_status,
568
+ inputs=None,
569
+ outputs=auth_status,
570
+ api_name=False,
571
+ queue=False,
572
+ )
573
+
574
+ if HAS_HF_HUB_SEARCH:
575
+ gr.on(
576
+ triggers=[hub_search.submit],
577
+ fn=use_hub_search,
578
+ inputs=[hub_search],
579
+ outputs=[inp],
580
+ api_name=False,
581
+ show_progress="hidden",
582
+ queue=False,
583
+ )
584
+
585
+ gr.on(
586
+ triggers=[btn.click, inp.submit],
587
+ fn=get_results,
588
+ inputs=[inp, library, options, access_token],
589
+ outputs=[
590
+ out_text,
591
+ run_auth_status,
592
+ out,
593
+ out_explain,
594
+ memory_values,
595
+ error_text,
596
+ error_details,
597
+ summary_file,
598
+ breakdown_file,
599
+ json_file,
600
+ ],
601
+ show_api=False,
602
+ show_progress="minimal",
603
+ concurrency_limit=1,
604
+ concurrency_id="memory-estimate",
605
+ )
606
+
607
+ reset_btn.click(
608
+ reset_app,
609
+ inputs=None,
610
+ outputs=[
611
+ inp,
612
+ library,
613
+ options,
614
+ access_token,
615
+ out_text,
616
+ run_auth_status,
617
+ out,
618
+ out_explain,
619
+ memory_values,
620
+ error_text,
621
+ error_details,
622
+ summary_file,
623
+ breakdown_file,
624
+ json_file,
625
+ ],
626
+ api_name=False,
627
+ show_progress="hidden",
628
+ queue=False,
629
+ )
630
+
631
+
632
+ demo.queue(default_concurrency_limit=1, max_size=24)
633
+ demo.launch()
src/hub_utils.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Utilities related to searching and posting on the Hub
import os
import webbrowser

import pandas as pd
from accelerate.utils import convert_bytes
from huggingface_hub import HfApi

from model_utils import calculate_memory, extract_from_url, get_model
8
+
9
+
def check_for_discussion(model_name: str):
    """Check whether `model-sizer-bot` already opened a discussion on the model repo."""
    api = HfApi(token=os.environ.get("HUGGINGFACE_API_LOGIN", None))
    repo_id = extract_from_url(model_name)
    authors = [discussion.author for discussion in api.get_repo_discussions(repo_id)]
    return "model-sizer-bot" in authors
19
+
20
+
def report_results(model_name, library, access_token):
    """Post a memory-requirements summary to the model's discussion page and open it in a browser.

    BUG FIX: the post previously interpolated `data[1]` and `data[3]` directly,
    which are whole row dicts from `calculate_memory`, so the sentence rendered
    raw dict reprs. Use the float16 row's "Total Size" for inference and the
    peak of its Adam estimate for training instead.
    """
    model = get_model(model_name, library, access_token)
    data = calculate_memory(model, ["float32", "float16/bfloat16", "int8", "int4"])
    df = pd.DataFrame(data).to_markdown(index=False)

    # data[1] is the float16/bfloat16 row — the usual deployment precision.
    inference_vram = data[1]["Total Size"]
    adam_estimate = data[1]["Training using Adam (Peak vRAM)"]
    if isinstance(adam_estimate, dict):
        # Per-stage mapping (model/gradients/optimizer/step); -1 marks "N/A".
        peak = max(adam_estimate.values())
        training_vram = "N/A" if peak == -1 else convert_bytes(peak)
    else:
        # Older calculate_memory variants may already return a string here.
        training_vram = adam_estimate

    post = f"""# Model Memory Requirements\n

You will need about {inference_vram} VRAM to load this model for inference, and {training_vram} VRAM to train it using Adam.

These calculations were measured from the [Model Memory Utility Space](https://huggingface.co/spaces/hf-accelerate/model-memory-usage) on the Hub.

The minimum recommended vRAM needed for this model assumes using [Accelerate or `device_map="auto"`](https://huggingface.co/docs/accelerate/usage_guides/big_modeling) and is denoted by the size of the "largest layer".
When performing inference, expect to add up to an additional 20% to this, as found by [EleutherAI](https://blog.eleuther.ai/transformer-math/). More tests will be performed in the future to get a more accurate benchmark for each model.

When training with `Adam`, you can expect roughly 4x the reported results to be used. (1x for the model, 1x for the gradients, and 2x for the optimizer).

## Results:

{df}
"""
    api = HfApi(token=os.environ.get("HUGGINGFACE_API_LOGIN", None))
    discussion = api.create_discussion(model_name, "[AUTOMATED] Model Memory Requirements", description=post)
    webbrowser.open_new_tab(discussion.url)
src/model_utils.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Utilities related to loading in and working with models/specific models
2
+ from urllib.parse import unquote, urlparse
3
+
4
+ import gradio as gr
5
+ import torch
6
+ from accelerate.commands.estimate import check_has_model, create_empty_model, estimate_training_usage
7
+ from accelerate.utils import calculate_maximum_sizes, convert_bytes
8
+ from huggingface_hub import auth_check
9
+ from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
10
+
11
+
12
+ DTYPE_MODIFIER = {"float32": 1, "float16/bfloat16": 2, "int8": 4, "int4": 8}
13
+
14
+
def extract_from_url(name: str):
    """If `name` is a URL, reduce it to a `namespace/repo` model id; otherwise return it unchanged."""
    try:
        parsed = urlparse(name)
    except Exception:
        return name

    if not (parsed.scheme and parsed.netloc):
        # Plain model id, not a URL.
        return name

    path = unquote(parsed.path).strip("/")
    if not path:
        return name

    segments = [segment for segment in path.split("/") if segment]
    # Drop a leading repo-type prefix such as /models/org/name.
    if len(segments) >= 3 and segments[0] in {"models", "datasets", "spaces"}:
        segments = segments[1:]

    if len(segments) >= 2:
        return "/".join(segments[:2])
    return "/".join(segments)
38
+
39
+
def translate_llama(text: str):
    """Map a Llama-2 / CodeLlama repo id to its `-hf` converted counterpart."""
    return text if text.endswith("-hf") else f"{text}-hf"
45
+
46
+
def normalize_model_name(model_name: str):
    """Canonicalize user input: strip, resolve Hub URLs, map Llama-2/CodeLlama repos to `-hf`, drop trailing slashes."""
    name = extract_from_url(model_name.strip())
    is_llama_repo = "meta-llama/Llama-2-" in name or "meta-llama/CodeLlama-" in name
    if is_llama_repo:
        name = translate_llama(name)
    return name.rstrip("/")
52
+
53
+
def classify_loader_error(model_name: str, error: Exception):
    """Map a raw loader/Hub exception onto a user-facing gr.Error.

    Matching is on lowercase substrings of the error text, ordered from most
    to least specific. The gr.Error is returned (not raised) so callers can
    decide whether to raise it or fall through.
    """
    message = str(error)
    lowered = message.lower()

    # Network / Hub timeouts.
    if "timed out" in lowered or "timeout" in lowered:
        return gr.Error(
            f"Model `{model_name}` timed out during the Hub access or static initialization step. "
            "Please try again, try a narrower model repo, or select the library manually."
        )

    # Authentication / permission problems.
    if (
        "401" in lowered
        or "403" in lowered
        or "unauthorized" in lowered
        or "forbidden" in lowered
        or "permission" in lowered
    ):
        return gr.Error(
            f"Model `{model_name}` could not be accessed with the current credentials. "
            "Please sign in with Hugging Face or paste a token that has access to this repo."
        )

    # Transient connectivity problems.
    if "connection" in lowered or "temporarily unavailable" in lowered or "service unavailable" in lowered:
        return gr.Error(
            f"Model `{model_name}` could not be reached from this Space right now. "
            "Please retry in a moment."
        )

    # Missing Python dependencies required by the repo's custom code.
    if "no module named" in lowered or "cannot import name" in lowered:
        return gr.Error(
            f"Model `{model_name}` requires custom code or extra dependencies that are not available in this Space. "
            f"This often means the repository depends on a package that is not installed here. Error: `{error}`"
        )

    # Repos whose remote code failed to initialize.
    if "trust_remote_code" in lowered or "remote code" in lowered:
        return gr.Error(
            f"Model `{model_name}` uses custom code from the Hub and could not be initialized in this Space. "
            f"Please inspect the repository code and make sure it is trusted and compatible with the current runtime. Error: `{error}`"
        )

    # Library auto-detection from the repo config failed.
    if "config" in lowered and "auto" in lowered:
        return gr.Error(
            f"Model `{model_name}` could not be resolved through the current library auto-detection path. "
            f"Please try selecting `transformers` or `timm` manually. Error: `{error}`"
        )

    # Catch-all for anything unrecognized.
    return gr.Error(
        f"Model `{model_name}` had an error during static initialization in this Space. "
        f"Please open a discussion on the model page and include this message: `{error}`"
    )
104
+
105
+
def raise_model_error(model_name: str, error: Exception):
    """Translate `error` into a user-facing gr.Error and raise it."""
    classified = classify_loader_error(model_name, error)
    raise classified
108
+
109
+
def preflight_model_access_normalized(normalized_name: str, access_token: str | None):
    """Check repo existence/access up front so users get a clear error early.

    Gated and missing repos become specific gr.Error messages. Other failures
    are classified; only timeout / credential / connectivity errors abort
    here — anything else falls through so the actual model loader can retry
    and produce its own (usually more precise) error.
    """
    try:
        auth_check(normalized_name, token=access_token)
    except GatedRepoError:
        raise gr.Error(
            f"Model `{normalized_name}` is a gated model. Please sign in with Hugging Face or pass an access token that already has access."
        )
    except RepositoryNotFoundError:
        raise gr.Error(f"Model `{normalized_name}` was not found on the Hub. Please try another model name.")
    except gr.Error:
        # Already user-facing; propagate unchanged.
        raise
    except Exception as error:
        classified_error = classify_loader_error(normalized_name, error)
        # Matching is on the classified message text; only hard failures
        # abort the run here.
        if "timed out" in str(classified_error).lower():
            raise classified_error
        if "could not be accessed" in str(classified_error).lower():
            raise classified_error
        if "could not be reached" in str(classified_error).lower():
            raise classified_error
        # Fallback to the loader path for transient Hub metadata issues.
        pass

    return normalized_name
133
+
134
+
def preflight_model_access(model_name: str, access_token: str | None):
    """Normalize `model_name` and run the shared access preflight on it."""
    normalized = normalize_model_name(model_name)
    return preflight_model_access_normalized(normalized, access_token)
137
+
138
+
def get_model_normalized(model_name: str, library: str, access_token: str | None, skip_auth_check: bool = False):
    """Find the model on the Hub and initialize an empty skeleton on the `meta` device.

    `model_name` must already be normalized. Every failure mode is converted
    to a user-facing gr.Error (directly or via raise_model_error). The
    exception ladder order matters: specific Hub errors first, then
    library-detection, then the ImportError retry, then the catch-all.
    """
    # `auto` means "let accelerate detect the library from repo metadata".
    if library == "auto":
        library = None

    if not skip_auth_check:
        # Callers that already ran the preflight pass skip_auth_check=True.
        preflight_model_access_normalized(model_name, access_token)

    try:
        model = create_empty_model(model_name, library_name=library, trust_remote_code=True, access_token=access_token)
    except GatedRepoError:
        raise gr.Error(
            f"Model `{model_name}` is a gated model, please ensure to pass in your access token or sign in with Hugging Face and try again if you have access."
        )
    except RepositoryNotFoundError:
        raise gr.Error(f"Model `{model_name}` was not found on the Hub, please try another model name.")
    except ValueError:
        # create_empty_model raises ValueError when the repo carries no
        # library metadata to auto-detect from.
        raise gr.Error(
            f"Model `{model_name}` does not have any library metadata on the Hub, please manually select a library_name to use (such as `transformers`)"
        )
    except (RuntimeError, OSError) as error:
        # check_has_model inspects the error to tell whether a known library
        # was detected but no loadable model was found inside the repo.
        library_name = check_has_model(error)
        if library_name != "unknown":
            raise gr.Error(
                f"Tried to load `{model_name}` with `{library_name}` but a possible model to load was not found inside the repo."
            )
        raise_model_error(model_name, error)
    except ImportError as error:
        # The repo's remote code may import packages missing in this Space;
        # retry once without trusting remote code before giving up.
        try:
            model = create_empty_model(
                model_name, library_name=library, trust_remote_code=False, access_token=access_token
            )
        except Exception:
            raise_model_error(model_name, error)
    except Exception as error:
        raise_model_error(model_name, error)
    return model
176
+
177
+
def get_model(model_name: str, library: str, access_token: str | None, skip_auth_check: bool = False):
    """Convenience wrapper: normalize the user-supplied name, then build the meta-device skeleton."""
    normalized = normalize_model_name(model_name)
    return get_model_normalized(normalized, library, access_token, skip_auth_check=skip_auth_check)
185
+
186
+
def calculate_memory(model: torch.nn.Module, options: list):
    """Calculate static memory usage for a model initialized on the `meta` device.

    Returns one dict per requested dtype with human-readable sizes plus the
    raw training estimate from accelerate.
    """
    # Full-precision totals in bytes; largest_layer is a (size, name) pair.
    total_size, largest_layer = calculate_maximum_sizes(model)

    data = []
    for dtype in options:
        dtype_total_size = total_size
        dtype_largest_layer = largest_layer[0]

        # How much smaller than float32 this dtype is (e.g. int8 -> 4x).
        modifier = DTYPE_MODIFIER[dtype]
        # Training usage is estimated from the float32 byte count plus the
        # dtype name ("float16/bfloat16" is passed to accelerate as "float16").
        dtype_training_size = estimate_training_usage(
            dtype_total_size, dtype if dtype != "float16/bfloat16" else "float16"
        )
        dtype_total_size /= modifier
        dtype_largest_layer /= modifier

        dtype_total_size = convert_bytes(dtype_total_size)
        dtype_largest_layer = convert_bytes(dtype_largest_layer)
        data.append(
            {
                "dtype": dtype,
                "Largest Layer or Residual Group": dtype_largest_layer,
                "Total Size": dtype_total_size,
                # NOTE(review): consumers (app.build_estimate_payload) treat
                # this as a per-stage mapping — confirm against the installed
                # accelerate version's estimate_training_usage return type.
                "Training using Adam (Peak vRAM)": dtype_training_size,
            }
        )
    return data