| """ |
| Comprehensive error handling and recovery utilities for BitTransformerLM. |
| |
| Provides robust error recovery mechanisms, graceful degradation, and detailed |
| error logging for production deployments. |
| """ |
|
|
| import logging |
| import traceback |
| import functools |
| from typing import Dict, Any, Optional, Callable, Union, Type |
| from contextlib import contextmanager |
| import torch |
| import numpy as np |
|
|
| from .types import ErrorHandler, RecoveryStrategy, LogLevel, TensorLike |
|
|
|
|
| class BitTransformerError(Exception): |
| """Base exception class for BitTransformerLM errors.""" |
| |
| def __init__(self, message: str, error_code: str = "BTLM_ERROR", |
| context: Optional[Dict[str, Any]] = None): |
| self.message = message |
| self.error_code = error_code |
| self.context = context or {} |
| super().__init__(f"[{error_code}] {message}") |
|
|
|
|
| class ModelError(BitTransformerError): |
| """Errors related to model operations.""" |
| pass |
|
|
|
|
| class CompressionError(BitTransformerError): |
| """Errors related to compression/decompression.""" |
| pass |
|
|
|
|
| class SafetyError(BitTransformerError): |
| """Errors related to safety gates and telemetry.""" |
| pass |
|
|
|
|
| class DataError(BitTransformerError): |
| """Errors related to data processing.""" |
| pass |
|
|
|
|
| class DistributedError(BitTransformerError): |
| """Errors related to distributed training.""" |
| pass |
|
|
|
|
| class ErrorRecoveryManager: |
| """Manages error recovery strategies and fallback mechanisms.""" |
| |
| def __init__(self, logger: Optional[logging.Logger] = None): |
| self.logger = logger or logging.getLogger(__name__) |
| self.recovery_strategies: Dict[Type[Exception], RecoveryStrategy] = {} |
| self.error_counts: Dict[str, int] = {} |
| self.max_retries = 3 |
| |
| def register_recovery_strategy(self, |
| error_type: Type[Exception], |
| strategy: RecoveryStrategy) -> None: |
| """Register a recovery strategy for a specific error type.""" |
| self.recovery_strategies[error_type] = strategy |
| |
| def handle_error(self, |
| error: Exception, |
| context: Optional[Dict[str, Any]] = None, |
| allow_recovery: bool = True) -> Any: |
| """Handle an error with potential recovery.""" |
| error_key = f"{type(error).__name__}:{str(error)}" |
| self.error_counts[error_key] = self.error_counts.get(error_key, 0) + 1 |
| |
| self.logger.error( |
| f"Error occurred: {error}\n" |
| f"Context: {context}\n" |
| f"Traceback: {traceback.format_exc()}" |
| ) |
| |
| if allow_recovery and self.error_counts[error_key] <= self.max_retries: |
| |
| for error_type, strategy in self.recovery_strategies.items(): |
| if isinstance(error, error_type): |
| try: |
| self.logger.info(f"Attempting recovery for {type(error).__name__}") |
| return strategy() |
| except Exception as recovery_error: |
| self.logger.error(f"Recovery failed: {recovery_error}") |
| break |
| |
| |
| raise error |
|
|
|
|
| |
| error_manager = ErrorRecoveryManager() |
|
|
|
|
| def with_error_recovery(recovery_value: Any = None, |
| max_retries: int = 3, |
| error_types: Optional[tuple] = None): |
| """Decorator for adding error recovery to functions.""" |
| def decorator(func: Callable) -> Callable: |
| @functools.wraps(func) |
| def wrapper(*args, **kwargs): |
| last_error = None |
| |
| for attempt in range(max_retries + 1): |
| try: |
| return func(*args, **kwargs) |
| except Exception as e: |
| last_error = e |
| |
| |
| if error_types and not isinstance(e, error_types): |
| raise |
| |
| if attempt < max_retries: |
| error_manager.logger.warning( |
| f"Function {func.__name__} failed (attempt {attempt + 1}), retrying..." |
| ) |
| continue |
| |
| |
| error_manager.logger.error( |
| f"Function {func.__name__} failed after {max_retries + 1} attempts" |
| ) |
| break |
| |
| |
| if recovery_value is not None: |
| return recovery_value |
| raise last_error |
| |
| return wrapper |
| return decorator |
|
|
|
|
| @contextmanager |
| def safe_operation(operation_name: str, |
| context: Optional[Dict[str, Any]] = None, |
| recovery_value: Any = None): |
| """Context manager for safe operations with error handling.""" |
| try: |
| error_manager.logger.debug(f"Starting operation: {operation_name}") |
| yield |
| error_manager.logger.debug(f"Completed operation: {operation_name}") |
| except Exception as e: |
| error_context = {"operation": operation_name} |
| if context: |
| error_context.update(context) |
| |
| try: |
| return error_manager.handle_error(e, error_context) |
| except: |
| if recovery_value is not None: |
| error_manager.logger.warning( |
| f"Operation {operation_name} failed, using recovery value" |
| ) |
| return recovery_value |
| raise |
|
|
|
|
| def safe_tensor_operation(tensor_op: Callable[[torch.Tensor], torch.Tensor], |
| fallback_value: Optional[torch.Tensor] = None) -> Callable: |
| """Wrapper for tensor operations with safety checks.""" |
| def wrapper(tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor: |
| |
| if not isinstance(tensor, torch.Tensor): |
| raise DataError("Input must be a torch.Tensor") |
| |
| if tensor.numel() == 0: |
| if fallback_value is not None: |
| return fallback_value |
| raise DataError("Cannot operate on empty tensor") |
| |
| |
| if torch.isnan(tensor).any(): |
| error_manager.logger.warning("NaN values detected in tensor, attempting to clean") |
| tensor = torch.nan_to_num(tensor, nan=0.0) |
| |
| if torch.isinf(tensor).any(): |
| error_manager.logger.warning("Inf values detected in tensor, attempting to clean") |
| tensor = torch.nan_to_num(tensor, posinf=1e6, neginf=-1e6) |
| |
| try: |
| return tensor_op(tensor, *args, **kwargs) |
| except (RuntimeError, ValueError) as e: |
| if "out of memory" in str(e).lower(): |
| |
| error_manager.logger.warning("OOM detected, attempting chunked operation") |
| return _chunked_tensor_operation(tensor_op, tensor, *args, **kwargs) |
| elif "device" in str(e).lower(): |
| |
| error_manager.logger.warning("Device mismatch, attempting CPU fallback") |
| return tensor_op(tensor.cpu(), *args, **kwargs) |
| else: |
| raise |
| |
| return wrapper |
|
|
|
|
| def _chunked_tensor_operation(tensor_op: Callable, |
| tensor: torch.Tensor, |
| chunk_size: int = 1024, |
| *args, **kwargs) -> torch.Tensor: |
| """Execute tensor operation in chunks to avoid OOM.""" |
| if tensor.size(0) <= chunk_size: |
| return tensor_op(tensor, *args, **kwargs) |
| |
| results = [] |
| for i in range(0, tensor.size(0), chunk_size): |
| chunk = tensor[i:i + chunk_size] |
| chunk_result = tensor_op(chunk, *args, **kwargs) |
| results.append(chunk_result) |
| |
| return torch.cat(results, dim=0) |
|
|
|
|
| def validate_model_inputs(inputs: torch.Tensor, |
| max_seq_len: int = 8192, |
| expected_dtype: torch.dtype = torch.long) -> torch.Tensor: |
| """Validate and sanitize model inputs.""" |
| if not isinstance(inputs, torch.Tensor): |
| raise DataError("Model inputs must be torch.Tensor") |
| |
| |
| if inputs.dim() == 1: |
| inputs = inputs.unsqueeze(0) |
| elif inputs.dim() > 2: |
| raise DataError(f"Input tensor has too many dimensions: {inputs.dim()}") |
| |
| |
| if inputs.size(-1) > max_seq_len: |
| error_manager.logger.warning(f"Sequence length {inputs.size(-1)} exceeds max {max_seq_len}, truncating") |
| inputs = inputs[:, :max_seq_len] |
| |
| |
| if inputs.dtype != expected_dtype: |
| error_manager.logger.warning(f"Converting input dtype from {inputs.dtype} to {expected_dtype}") |
| inputs = inputs.to(expected_dtype) |
| |
| |
| if expected_dtype == torch.long: |
| invalid_values = (inputs < 0) | (inputs > 1) |
| if invalid_values.any(): |
| error_manager.logger.warning("Invalid bit values detected, clamping to [0, 1]") |
| inputs = torch.clamp(inputs, 0, 1) |
| |
| return inputs |
|
|
|
|
| def safe_model_forward(model: torch.nn.Module, |
| inputs: torch.Tensor, |
| **kwargs) -> torch.Tensor: |
| """Safely execute model forward pass with error recovery.""" |
| inputs = validate_model_inputs(inputs) |
| |
| try: |
| with safe_operation("model_forward"): |
| return model(inputs, **kwargs) |
| except RuntimeError as e: |
| if "out of memory" in str(e).lower(): |
| |
| error_manager.logger.warning("OOM in forward pass, enabling gradient checkpointing") |
| from torch.utils.checkpoint import checkpoint |
| return checkpoint(model, inputs, **kwargs) |
| elif "device" in str(e).lower(): |
| |
| device = next(model.parameters()).device |
| inputs = inputs.to(device) |
| return model(inputs, **kwargs) |
| else: |
| raise |
|
|
|
|
| def recovery_checkpoint_save(model: torch.nn.Module, |
| path: str, |
| additional_data: Optional[Dict[str, Any]] = None) -> bool: |
| """Save model checkpoint with error recovery.""" |
| try: |
| checkpoint_data = { |
| 'model_state_dict': model.state_dict(), |
| 'timestamp': torch.tensor(0), |
| } |
| if additional_data: |
| checkpoint_data.update(additional_data) |
| |
| torch.save(checkpoint_data, path, _use_new_zipfile_serialization=True) |
| error_manager.logger.info(f"Checkpoint saved successfully to {path}") |
| return True |
| |
| except Exception as e: |
| error_manager.logger.error(f"Failed to save checkpoint to {path}: {e}") |
| |
| |
| backup_path = path + ".backup" |
| try: |
| torch.save(checkpoint_data, backup_path) |
| error_manager.logger.info(f"Checkpoint saved to backup location: {backup_path}") |
| return True |
| except Exception as backup_e: |
| error_manager.logger.error(f"Backup save also failed: {backup_e}") |
| return False |
|
|
|
|
| def setup_error_logging(log_level: LogLevel = "INFO", |
| log_file: Optional[str] = None) -> logging.Logger: |
| """Set up comprehensive error logging.""" |
| logger = logging.getLogger("BitTransformerLM") |
| logger.setLevel(getattr(logging, log_level)) |
| |
| |
| console_handler = logging.StreamHandler() |
| console_formatter = logging.Formatter( |
| '%(asctime)s - %(name)s - %(levelname)s - %(message)s' |
| ) |
| console_handler.setFormatter(console_formatter) |
| logger.addHandler(console_handler) |
| |
| |
| if log_file: |
| file_handler = logging.FileHandler(log_file) |
| file_formatter = logging.Formatter( |
| '%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s' |
| ) |
| file_handler.setFormatter(file_formatter) |
| logger.addHandler(file_handler) |
| |
| return logger |
|
|
|
|
| |
| def default_tensor_recovery() -> torch.Tensor: |
| """Default recovery strategy for tensor operations.""" |
| return torch.zeros(1, dtype=torch.long) |
|
|
|
|
| def default_model_recovery() -> Dict[str, torch.Tensor]: |
| """Default recovery strategy for model operations.""" |
| return {"output": torch.zeros(1, dtype=torch.float32)} |
|
|
|
|
| |
| error_manager.register_recovery_strategy(RuntimeError, default_tensor_recovery) |
| error_manager.register_recovery_strategy(ModelError, default_model_recovery) |