import torch
import torch.nn as nn
from typing import Any, Dict, List


def inspect_model_devices(model, prefix="", max_depth=3, current_depth=0) -> Dict[str, List[Dict[str, Any]]]:
    """
    Recursively inspect the attributes of a PyTorch (Lightning) model to find
    which tensors live on CPU vs. CUDA devices and to check sparsity and
    contiguity.

    Args:
        model: The model/object to inspect
        prefix: String prefix for nested attribute names
        max_depth: Maximum recursion depth to prevent infinite recursion
        current_depth: Current recursion depth

    Returns:
        Dict mapping category names to lists of per-tensor/per-module info dicts
    """
    results = {
        'cuda_tensors': [],
        'cpu_tensors': [],
        'mixed_tensors': [],
        'sparse_tensors': [],
        'non_contiguous_tensors': [],
        'problematic_tensors': [],
        'non_tensor_attrs': [],
        'errors': []
    }

    if current_depth >= max_depth:
        return results

    for attr_name in dir(model):
        # Skip private/dunder attributes
        if attr_name.startswith('_'):
            continue

        # Build the qualified name before getattr() so the except block below
        # can always reference it, even if attribute access itself fails.
        full_name = f"{prefix}.{attr_name}" if prefix else attr_name
        try:
            attr_value = getattr(model, attr_name)

            # nn.Parameter subclasses torch.Tensor, so this branch covers both;
            # record which one it was for the report.
            if isinstance(attr_value, torch.Tensor):
                # is_sparse_csr may be absent on older PyTorch builds, hence getattr
                is_sparse = attr_value.is_sparse or getattr(attr_value, 'is_sparse_csr', False)
                # Contiguity/stride/storage offset are only defined for strided
                # (dense) tensors; calling them on sparse tensors raises.
                is_contiguous = attr_value.is_contiguous() if not is_sparse else True
                is_cuda = attr_value.device.type == 'cuda'

                device_info = {
                    'name': full_name,
                    'type': 'Parameter' if isinstance(attr_value, nn.Parameter) else 'Tensor',
                    'shape': tuple(attr_value.shape),
                    'dtype': str(attr_value.dtype),
                    'device': str(attr_value.device),
                    'requires_grad': attr_value.requires_grad,
                    'is_sparse': is_sparse,
                    'is_contiguous': is_contiguous,
                    'stride': tuple(attr_value.stride()) if not is_sparse else 'N/A (sparse)',
                    'storage_offset': attr_value.storage_offset() if not is_sparse else 'N/A (sparse)',
                    'numel': attr_value.numel()
                }

                has_issues = []
                if not is_cuda:
                    has_issues.append('CPU')
                if is_sparse:
                    has_issues.append('SPARSE')
                    results['sparse_tensors'].append(device_info)
                if not is_contiguous:
                    has_issues.append('NON_CONTIGUOUS')
                    results['non_contiguous_tensors'].append(device_info)

                if has_issues:
                    device_info['issues'] = has_issues
                    results['problematic_tensors'].append(device_info)

                if is_cuda:
                    results['cuda_tensors'].append(device_info)
                else:
                    results['cpu_tensors'].append(device_info)

            elif isinstance(attr_value, nn.Module):
                # Summarize which device types the submodule's parameters and
                # buffers live on.
                module_devices = set()
                for param in attr_value.parameters():
                    module_devices.add(param.device.type)
                for buffer in attr_value.buffers():
                    module_devices.add(buffer.device.type)

                if len(module_devices) > 1:
                    results['mixed_tensors'].append({
                        'name': full_name,
                        'type': type(attr_value).__name__,
                        'devices': list(module_devices)
                    })
                elif len(module_devices) == 1:
                    device_type = list(module_devices)[0]
                    module_info = {
                        'name': full_name,
                        'type': type(attr_value).__name__,
                        'device': device_type
                    }

                    if device_type == 'cuda':
                        results['cuda_tensors'].append(module_info)
                    else:
                        results['cpu_tensors'].append(module_info)

                # Recurse into the submodule's own attributes
                if current_depth < max_depth - 1:
                    sub_results = inspect_model_devices(attr_value, full_name, max_depth, current_depth + 1)
                    for key in results:
                        results[key].extend(sub_results[key])

            elif hasattr(attr_value, '__dict__') and not callable(attr_value):
                # Plain container objects: recurse one level deeper
                if current_depth < max_depth - 1:
                    sub_results = inspect_model_devices(attr_value, full_name, max_depth, current_depth + 1)
                    for key in results:
                        results[key].extend(sub_results[key])
            else:
                # Anything else that is not callable: record it as a plain attribute
                if not callable(attr_value):
                    results['non_tensor_attrs'].append({
                        'name': full_name,
                        'type': type(attr_value).__name__,
                        'value': str(attr_value)[:100]
                    })

        except Exception as e:
            results['errors'].append({
                'name': full_name,
                'error': str(e)
            })
            continue

    return results
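

# A small convenience helper, not part of the original utility: a hedged sketch
# showing how the results dict from inspect_model_devices() can be consumed
# programmatically (e.g. in a unit test or a quick REPL check) instead of
# going through the printed report.
def find_problematic_tensor_names(model) -> List[str]:
    """Return just the qualified names of tensors flagged as problematic."""
    report = inspect_model_devices(model)
    return [entry['name'] for entry in report['problematic_tensors']]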


def print_device_report(model, detailed=False):
    """
    Print a formatted report of device allocation for all model components.

    Args:
        model: PyTorch Lightning model to inspect
        detailed: If True, show detailed information for each tensor
    """
    print("=" * 80)
    print("COMPREHENSIVE DEVICE & TENSOR DENSITY REPORT")
    print("=" * 80)

    results = inspect_model_devices(model)

    if results['problematic_tensors']:
        print(f"\n🚨 PROBLEMATIC TENSORS ({len(results['problematic_tensors'])}) - LIKELY CAUSING ISSUES!")
        print("-" * 70)
        for item in results['problematic_tensors']:
            issues_str = " | ".join(item['issues'])
            print(f"  ❌ {item['name']}: {issues_str}")
            if detailed:
                print(f"     Shape: {item['shape']} | Device: {item['device']}")
                print(f"     Contiguous: {item['is_contiguous']} | Sparse: {item['is_sparse']}")
                if item['stride'] != 'N/A (sparse)':
                    print(f"     Stride: {item['stride']} | Storage offset: {item['storage_offset']}")
                print()

| print(f"\n📍 CUDA TENSORS/MODULES ({len(results['cuda_tensors'])})") |
| print("-" * 40) |
| for item in results['cuda_tensors']: |
| status_indicators = [] |
| if not item.get('is_contiguous', True): |
| status_indicators.append('NON-CONTIGUOUS') |
| if item.get('is_sparse', False): |
| status_indicators.append('SPARSE') |
| |
| status_str = f" [{', '.join(status_indicators)}]" if status_indicators else "" |
| |
| if detailed: |
| if 'shape' in item: |
| print(f" {item['name']}: {item['shape']} | {item['dtype']} | {item['device']}{status_str}") |
| if 'stride' in item and item['stride'] != 'N/A (sparse)': |
| print(f" ↳ Contiguous: {item.get('is_contiguous', 'N/A')} | Stride: {item['stride']}") |
| else: |
| print(f" {item['name']}: {item['type']} | {item['device']}{status_str}") |
| else: |
| print(f" ✅ {item['name']}{status_str}") |
| |
| print(f"\n💻 CPU TENSORS/MODULES ({len(results['cpu_tensors'])})") |
| print("-" * 40) |
| for item in results['cpu_tensors']: |
| status_indicators = [] |
| if not item.get('is_contiguous', True): |
| status_indicators.append('NON-CONTIGUOUS') |
| if item.get('is_sparse', False): |
| status_indicators.append('SPARSE') |
| |
| status_str = f" [{', '.join(status_indicators)}]" if status_indicators else "" |
| |
| if detailed: |
| if 'shape' in item: |
| print(f" {item['name']}: {item['shape']} | {item['dtype']} | {item['device']}{status_str}") |
| if 'stride' in item and item['stride'] != 'N/A (sparse)': |
| print(f" ↳ Contiguous: {item.get('is_contiguous', 'N/A')} | Stride: {item['stride']}") |
| else: |
| print(f" {item['name']}: {item['type']} | {item['device']}{status_str}") |
| else: |
| print(f" ❌ {item['name']}{status_str}") |
| |
    if results['sparse_tensors']:
        print(f"\n🕳️ SPARSE TENSORS ({len(results['sparse_tensors'])})")
        print("-" * 40)
        for item in results['sparse_tensors']:
            print(f"  {item['name']}: {item['shape']} | {item['device']} | SPARSE")

    if results['non_contiguous_tensors']:
        print(f"\n📐 NON-CONTIGUOUS TENSORS ({len(results['non_contiguous_tensors'])})")
        print("-" * 40)
        for item in results['non_contiguous_tensors']:
            print(f"  {item['name']}: {item['shape']} | Stride: {item['stride']}")
            if detailed:
                print(f"    ↳ Storage offset: {item['storage_offset']}")

    if results['mixed_tensors']:
        print(f"\n⚠️ MIXED DEVICE MODULES ({len(results['mixed_tensors'])})")
        print("-" * 40)
        for item in results['mixed_tensors']:
            print(f"  {item['name']}: {item['type']} | Devices: {item['devices']}")

    if results['errors']:
        print(f"\n❗ ERRORS ({len(results['errors'])})")
        print("-" * 40)
        for item in results['errors']:
            print(f"  {item['name']}: {item['error']}")

| print(f"\n📊 DETAILED SUMMARY") |
| print("-" * 40) |
| print(f" CUDA components: {len(results['cuda_tensors'])}") |
| print(f" CPU components: {len(results['cpu_tensors'])}") |
| print(f" Sparse tensors: {len(results['sparse_tensors'])}") |
| print(f" Non-contiguous: {len(results['non_contiguous_tensors'])}") |
| print(f" Mixed components: {len(results['mixed_tensors'])}") |
| print(f" Errors: {len(results['errors'])}") |
| print(f" 🚨 TOTAL PROBLEMATIC: {len(results['problematic_tensors'])}") |
| |
| if results['problematic_tensors']: |
| print(f"\n⚠️ CRITICAL: Found {len(results['problematic_tensors'])} problematic tensors!") |
| print(" These are likely causing the 'Tensors must be CUDA and dense' error.") |
| print(" Focus on fixing the tensors marked with 🚨 above.") |