| import torch |
| import psutil |
| import os |
| import logging |
| from typing import Optional |
|
|
logger = logging.getLogger(__name__)


class GPUMemoryMonitor:
    """Periodically compare allocated CUDA memory on one device to a threshold.

    The monitor is step-driven: callers invoke :meth:`check_memory` once per
    step and the actual query only happens every ``check_interval`` calls.
    When CUDA is not available the monitor disables itself and every check
    reports success.

    Attributes:
        memory_threshold: Fraction of total device memory (allocated / total)
            above which :meth:`check_memory` returns ``False``.
        check_interval: Number of :meth:`check_memory` calls between real
            queries. ``0`` means "query on every call".
        gpu_id: Index of the CUDA device being monitored (defaults to 0).
        step_count: Number of :meth:`check_memory` calls made while enabled.
        enabled: ``True`` when CUDA was available at construction time.
        device: ``torch.device`` for ``gpu_id``, or ``None`` when disabled.
    """

    def __init__(self,
                 memory_threshold: float = 0.9,
                 check_interval: int = 100,
                 gpu_id: Optional[int] = None):
        """Configure the monitor and detect whether CUDA can be used.

        Args:
            memory_threshold: Allowed allocated/total ratio before a check
                fails.
            check_interval: Query the device every this many calls to
                :meth:`check_memory`; ``0`` queries on every call.
            gpu_id: CUDA device index; ``None`` selects device 0.
        """
        self.memory_threshold = memory_threshold
        self.check_interval = check_interval
        self.gpu_id = gpu_id if gpu_id is not None else 0
        self.step_count = 0

        self.enabled = torch.cuda.is_available()
        # Always define `device` so the attribute exists even when monitoring
        # is disabled (previously it was only set when CUDA was available).
        self.device = torch.device(f"cuda:{self.gpu_id}") if self.enabled else None
        if not self.enabled:
            logger.warning("CUDA is not available. GPU monitoring will be disabled.")

    def _allocated_and_total(self) -> tuple:
        """Return ``(allocated_bytes, total_bytes)`` for the monitored device."""
        allocated = torch.cuda.memory_allocated(self.device)
        total = torch.cuda.get_device_properties(self.device).total_memory
        return allocated, total

    def check_memory(self) -> bool:
        """Check if GPU memory usage is below threshold.

        Returns:
            ``False`` only when a query was performed and the allocated/total
            ratio is strictly greater than ``memory_threshold``. Returns
            ``True`` when the monitor is disabled, when this call falls
            between check intervals, when usage is within the threshold, or
            when the query itself fails (the failure is logged).
        """
        if not self.enabled:
            return True

        self.step_count += 1
        # A zero interval would make the modulo raise ZeroDivisionError;
        # interpret it as "check on every call" instead.
        if self.check_interval and self.step_count % self.check_interval != 0:
            return True

        try:
            memory_allocated, memory_total = self._allocated_and_total()
            # Only allocated memory is compared; reserved (cached) memory is
            # not part of the ratio.
            memory_ratio = memory_allocated / memory_total
        except Exception as e:
            # Best effort: a failed query must not interrupt the caller.
            logger.error(f"Error checking GPU memory: {e}")
            return True

        if memory_ratio > self.memory_threshold:
            logger.warning(f"GPU memory usage ({memory_ratio:.2%}) exceeds threshold ({self.memory_threshold:.2%})")
            return False

        return True

    def clear_memory(self):
        """Clear GPU memory cache (no-op when the monitor is disabled)."""
        if self.enabled:
            torch.cuda.empty_cache()

    def get_memory_stats(self) -> dict:
        """Get current GPU memory statistics.

        Returns:
            ``{"enabled": False}`` when the monitor is disabled. Otherwise a
            dict with ``enabled`` (``True``), ``allocated_gb``,
            ``reserved_gb`` and ``total_gb`` (byte counts divided by
            ``1024**3``) and ``usage_ratio`` (allocated / total). If the
            query fails, ``{"enabled": False, "error": <message>}`` is
            returned and the failure is logged.
        """
        if not self.enabled:
            return {"enabled": False}

        try:
            memory_allocated, memory_total = self._allocated_and_total()
            memory_reserved = torch.cuda.memory_reserved(self.device)

            return {
                "enabled": True,
                "allocated_gb": memory_allocated / 1024**3,
                "reserved_gb": memory_reserved / 1024**3,
                "total_gb": memory_total / 1024**3,
                "usage_ratio": memory_allocated / memory_total,
            }
        except Exception as e:
            logger.error(f"Error getting GPU memory stats: {e}")
            return {"enabled": False, "error": str(e)}