File size: 999 Bytes
from collections.abc import Iterator
from contextlib import contextmanager
from typing import TypeVar
import torch
from ltx_pipelines.utils.helpers import cleanup_memory
_M = TypeVar("_M", bound=torch.nn.Module)
@contextmanager
def gpu_model(model: _M) -> Iterator[_M]:
    """Yield *model* and release its parameter/buffer storage on exit.

    When the ``with`` block ends (normally or through an exception), all
    parameters and buffers are moved to the ``meta`` device, which releases
    the underlying storage on **both** GPU and CPU, and ``cleanup_memory()``
    is then run to reclaim fragmented CUDA memory.

    The model must not be used for computation after the block: ``meta``
    tensors carry shape and dtype but no data.

    Usage::

        with gpu_model(build_encoder()) as encoder:
            ...  # use encoder — typed as the concrete class
        # GPU + CPU memory freed automatically

    Args:
        model: The module handed to the ``with`` body and released on exit.

    Yields:
        The same ``model`` object, typed as its concrete class.
    """
    try:
        yield model
    finally:
        try:
            # Wait for in-flight kernels before dropping the storage they may
            # still be reading. Skipped on hosts without CUDA, where
            # torch.cuda.synchronize() would raise and mask any exception
            # coming out of the ``with`` body.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
        finally:
            # Release the memory even if the synchronize above surfaced an
            # error. .to("meta") releases storage for all parameters/buffers
            # regardless of their original device (CUDA or CPU).
            model.to("meta")
            cleanup_memory()
|