| """ |
| Configuration Management Module |
| |
| This module provides secure, robust configuration management with: |
| - Environment variable handling with secure defaults |
| - Cache directory management with automatic fallbacks |
| - Comprehensive logging and error handling |
| - Security best practices for sensitive data |
| - Backward compatibility with existing code |
| |
| Environment Variables: |
| HF_TOKEN: HuggingFace API token (required for API access) |
| HF_HOME: Primary cache directory for HuggingFace models |
| TRANSFORMERS_CACHE: Alternative cache directory path |
| MAX_WORKERS: Maximum worker threads (default: 4) |
| CACHE_TTL: Cache time-to-live in seconds (default: 3600) |
| DB_PATH: Database file path (default: sessions.db) |
| LOG_LEVEL: Logging level (default: INFO) |
| LOG_FORMAT: Log format (default: json) |
| |
| Security Notes: |
| - Never commit .env files to version control |
| - Use environment variables for all sensitive data |
    - Cache directories are created automatically; permissions (0o755) are applied on a best-effort basis
| """ |
|
|
| import os |
| import logging |
| from pathlib import Path |
| from typing import Optional |
| from pydantic_settings import BaseSettings |
| from pydantic import Field, validator |
|
|
| |
| logger = logging.getLogger(__name__) |
|
|
|
|
class CacheDirectoryManager:
    """
    Manages cache directory with secure fallback mechanism.

    Implements:
    - Multi-level fallback strategy
    - Permission validation (best-effort chmod to 0o755)
    - Automatic directory creation
    - Security best practices
    """

    @staticmethod
    def get_cache_directory() -> str:
        """
        Get cache directory with secure fallback chain.

        Priority order:
        1. HF_HOME environment variable
        2. TRANSFORMERS_CACHE environment variable
        3. /tmp/huggingface_cache (when running inside Docker)
        4. User home directory (~/.cache/huggingface)
        5. User-specific fallback directory (~/.cache/huggingface_fallback)
        6. /tmp/huggingface (last resort)

        Each candidate is created if missing and verified writable with a
        throwaway write test before being returned.

        Returns:
            str: Path to a writable cache directory. Always returns a path;
            if even the emergency fallback cannot be created, the last-resort
            path is returned unverified.
        """
        # Local binding keeps this method usable even before the module-level
        # logger has been configured (and testable in isolation).
        log = logging.getLogger(__name__)

        # BUG FIX: the previous check also treated the mere existence of /tmp
        # as a Docker signal. /tmp exists on virtually every POSIX host, so
        # Docker was detected essentially always and the home-directory cache
        # candidates were unreachable. /.dockerenv is the reliable marker.
        is_docker = os.path.exists("/.dockerenv")

        home = os.path.expanduser("~")

        candidates = [
            os.getenv("HF_HOME"),
            os.getenv("TRANSFORMERS_CACHE"),
        ]
        if is_docker:
            # Inside a container the home dir may be absent or read-only.
            candidates.append("/tmp/huggingface_cache")
        elif home:
            candidates.append(os.path.join(home, ".cache", "huggingface"))
            candidates.append(os.path.join(home, ".cache", "huggingface_fallback"))
            candidates.append("/tmp/huggingface_cache")
        candidates.append("/tmp/huggingface")

        for cache_dir in candidates:
            if not cache_dir:
                continue

            try:
                cache_path = Path(cache_dir)
                cache_path.mkdir(parents=True, exist_ok=True)

                # Best-effort hardening; failure to chmod is not fatal
                # (e.g. on foreign-owned directories or odd filesystems).
                try:
                    os.chmod(cache_path, 0o755)
                except (OSError, PermissionError):
                    pass

                # Verify the directory is actually writable, not just present.
                test_file = cache_path / ".write_test"
                try:
                    test_file.write_text("test")
                    test_file.unlink()
                except (PermissionError, OSError) as e:
                    log.debug(f"Write test failed for {cache_dir}: {e}")
                    continue

                log.info(f"✓ Cache directory verified: {cache_dir}")
                return str(cache_path)

            except (PermissionError, OSError) as e:
                log.debug(f"Could not create/access {cache_dir}: {e}")
                continue

        # All candidates failed: try a dedicated emergency directory.
        fallback = "/tmp/huggingface_emergency"
        try:
            Path(fallback).mkdir(parents=True, exist_ok=True)
            log.warning(f"Using emergency fallback cache: {fallback}")
            return fallback
        except Exception as e:
            log.error(f"Emergency fallback also failed: {e}")

        # Unverified last resort; callers get a deterministic path regardless.
        return "/tmp/huggingface"
|
|
|
|
class Settings(BaseSettings):
    """
    Application settings with secure defaults and validation.

    Backward Compatibility:
        - All existing attributes are preserved
        - hf_token is accessible as string (via property)
        - hf_cache_dir is accessible as property (works like before)
        - All defaults match original implementation
    """

    # ------------------------------------------------------------------
    # Internal helpers shared by the validators below
    # ------------------------------------------------------------------

    @staticmethod
    def _is_docker() -> bool:
        """
        Return True when running inside a Docker container.

        BUG FIX: the previous checks also treated the existence of /tmp as a
        Docker signal; /tmp exists on virtually every POSIX host, so Docker
        was detected essentially always. Only /.dockerenv is reliable.
        """
        return os.path.exists("/.dockerenv")

    @staticmethod
    def _coerce_int(v, default: int, lo: int, hi: Optional[int] = None,
                    name: str = "value") -> int:
        """
        Coerce *v* to an int clamped to [lo, hi], falling back to *default*.

        Consolidates the parse/clamp logic previously duplicated across every
        integer validator. None, empty strings and unparseable values all
        yield *default* (some validators previously raised a ValidationError
        on unparseable strings; falling back keeps behavior consistent).
        """
        if v is None or v == "":
            return default
        try:
            val = int(v)
        except (ValueError, TypeError):
            logger.warning(f"Invalid {name}: {v!r}, using default {default}")
            return default
        val = max(lo, val)
        if hi is not None:
            val = min(hi, val)
        return val

    # ------------------------------------------------------------------
    # HuggingFace credentials and cache
    # ------------------------------------------------------------------

    hf_token: str = Field(
        default="",
        description="HuggingFace API token",
        env="HF_TOKEN"
    )

    @validator("hf_token", pre=True)
    def validate_hf_token(cls, v):
        """Validate HF token (backward compatible): always returns a str."""
        if v is None:
            return ""
        token = str(v) if v else ""
        if not token:
            logger.debug("HF_TOKEN not set")
        return token

    @property
    def hf_cache_dir(self) -> str:
        """
        Get cache directory with automatic fallback and validation.

        BACKWARD COMPAT: Works like the original hf_cache_dir field.
        The resolved path is computed once and cached on the instance.

        Returns:
            str: Path to writable cache directory
        """
        # getattr/object.__setattr__ are used instead of plain attribute
        # access because pydantic's __setattr__ rejects assignment of
        # undeclared (non-field) attributes on the model.
        cached = getattr(self, "_cached_cache_dir", None)
        if cached is None:
            try:
                cached = CacheDirectoryManager.get_cache_directory()
            except Exception as e:
                logger.error(f"Cache directory setup failed: {e}")
                # Last-ditch fallback mirrors the manager's own behavior.
                cached = os.getenv("HF_HOME", "/tmp/huggingface")
                Path(cached).mkdir(parents=True, exist_ok=True)
            object.__setattr__(self, "_cached_cache_dir", cached)
        return cached

    # ------------------------------------------------------------------
    # ZeroGPU Chat API
    # ------------------------------------------------------------------

    zerogpu_base_url: str = Field(
        default="http://your-pod-ip:8000",
        description="ZeroGPU Chat API base URL (RunPod endpoint)",
        env="ZEROGPU_BASE_URL"
    )

    zerogpu_email: str = Field(
        default="",
        description="ZeroGPU Chat API email for authentication (required)",
        env="ZEROGPU_EMAIL"
    )

    zerogpu_password: str = Field(
        default="",
        description="ZeroGPU Chat API password for authentication (required)",
        env="ZEROGPU_PASSWORD"
    )

    user_input_max_tokens: int = Field(
        default=32000,
        description="Maximum tokens dedicated for user input (prioritized over context)",
        env="USER_INPUT_MAX_TOKENS"
    )

    context_preparation_budget: int = Field(
        default=115000,
        description="Maximum tokens for context preparation (includes user input + context)",
        env="CONTEXT_PREPARATION_BUDGET"
    )

    context_pruning_threshold: int = Field(
        default=115000,
        description="Context pruning threshold (should match context_preparation_budget)",
        env="CONTEXT_PRUNING_THRESHOLD"
    )

    prioritize_user_input: bool = Field(
        default=True,
        description="Always prioritize user input over historical context",
        env="PRIORITIZE_USER_INPUT"
    )

    zerogpu_model_context_window: int = Field(
        default=8192,
        description="Maximum context window for ZeroGPU Chat API model (input + output tokens). Adjust based on your deployed model.",
        env="ZEROGPU_MODEL_CONTEXT_WINDOW"
    )

    @validator("zerogpu_base_url", pre=True)
    def validate_zerogpu_base_url(cls, v):
        """Validate ZeroGPU base URL: strip whitespace and a trailing slash."""
        if v is None:
            return "http://your-pod-ip:8000"
        url = str(v).strip()
        if url.endswith('/'):
            url = url[:-1]
        return url

    @validator("zerogpu_email", pre=True)
    def validate_zerogpu_email(cls, v):
        """Validate ZeroGPU email (warns, but does not reject, odd values)."""
        if v is None:
            return ""
        email = str(v).strip()
        if email and '@' not in email:
            logger.warning("ZEROGPU_EMAIL may not be a valid email address")
        return email

    @validator("zerogpu_password", pre=True)
    def validate_zerogpu_password(cls, v):
        """Validate ZeroGPU password: normalize to a stripped string."""
        if v is None:
            return ""
        return str(v).strip()

    @validator("user_input_max_tokens", pre=True)
    def validate_user_input_tokens(cls, v):
        """Validate user input token limit (clamped to 1000-50000)."""
        return cls._coerce_int(v, 32000, 1000, 50000, "USER_INPUT_MAX_TOKENS")

    @validator("context_preparation_budget", pre=True)
    def validate_context_budget(cls, v):
        """Validate context preparation budget (clamped to 4000-125000)."""
        return cls._coerce_int(v, 115000, 4000, 125000, "CONTEXT_PREPARATION_BUDGET")

    @validator("context_pruning_threshold", pre=True)
    def validate_pruning_threshold(cls, v):
        """Validate context pruning threshold (clamped to 4000-125000)."""
        return cls._coerce_int(v, 115000, 4000, 125000, "CONTEXT_PRUNING_THRESHOLD")

    @validator("zerogpu_model_context_window", pre=True)
    def validate_context_window(cls, v):
        """Validate context window size (clamped to 1000-200000)."""
        return cls._coerce_int(v, 8192, 1000, 200000, "ZEROGPU_MODEL_CONTEXT_WINDOW")

    # ------------------------------------------------------------------
    # Model selection
    # ------------------------------------------------------------------

    default_model: str = Field(
        default="meta-llama/Llama-3.1-8B-Instruct:cerebras",
        description="Primary model for reasoning tasks (Cerebras deployment with 4-bit quantization)"
    )

    embedding_model: str = Field(
        default="intfloat/e5-large-v2",
        description="Model for embeddings (upgraded: 1024-dim embeddings)"
    )

    classification_model: str = Field(
        default="meta-llama/Llama-3.1-8B-Instruct:cerebras",
        description="Model for classification tasks (Cerebras deployment)"
    )

    # ------------------------------------------------------------------
    # Runtime / caching
    # ------------------------------------------------------------------

    max_workers: int = Field(
        default=4,
        description="Maximum worker threads for parallel processing",
        env="MAX_WORKERS"
    )

    @validator("max_workers", pre=True)
    def validate_max_workers(cls, v):
        """Validate max_workers (clamped to 1-16, backward compatible)."""
        return cls._coerce_int(v, 4, 1, 16, "MAX_WORKERS")

    cache_ttl: int = Field(
        default=3600,
        description="Cache time-to-live in seconds",
        env="CACHE_TTL"
    )

    @validator("cache_ttl", pre=True)
    def validate_cache_ttl(cls, v):
        """Validate cache TTL (non-negative, backward compatible)."""
        return cls._coerce_int(v, 3600, 0, None, "CACHE_TTL")

    # ------------------------------------------------------------------
    # Storage paths
    # ------------------------------------------------------------------

    db_path: str = Field(
        default="sessions.db",
        description="Path to SQLite database file",
        env="DB_PATH"
    )

    @validator("db_path", pre=True)
    def validate_db_path(cls, v):
        """Validate db_path with Docker fallback (backward compatible)."""
        if v is None:
            # Inside a container the working directory may be read-only.
            return "/tmp/sessions.db" if cls._is_docker() else "sessions.db"
        return str(v)

    faiss_index_path: str = Field(
        default="embeddings.faiss",
        description="Path to FAISS index file",
        env="FAISS_INDEX_PATH"
    )

    @validator("faiss_index_path", pre=True)
    def validate_faiss_path(cls, v):
        """Validate faiss path with Docker fallback (backward compatible)."""
        if v is None:
            return "/tmp/embeddings.faiss" if cls._is_docker() else "embeddings.faiss"
        return str(v)

    # ------------------------------------------------------------------
    # Sessions
    # ------------------------------------------------------------------

    session_timeout: int = Field(
        default=3600,
        description="Session timeout in seconds",
        env="SESSION_TIMEOUT"
    )

    @validator("session_timeout", pre=True)
    def validate_session_timeout(cls, v):
        """Validate session timeout (at least 60s, backward compatible)."""
        return cls._coerce_int(v, 3600, 60, None, "SESSION_TIMEOUT")

    max_session_size_mb: int = Field(
        default=10,
        description="Maximum session size in megabytes",
        env="MAX_SESSION_SIZE_MB"
    )

    @validator("max_session_size_mb", pre=True)
    def validate_max_session_size(cls, v):
        """Validate max session size (clamped to 1-100 MB)."""
        return cls._coerce_int(v, 10, 1, 100, "MAX_SESSION_SIZE_MB")

    # ------------------------------------------------------------------
    # Mobile client tuning
    # ------------------------------------------------------------------

    mobile_max_tokens: int = Field(
        default=800,
        description="Maximum tokens for mobile responses",
        env="MOBILE_MAX_TOKENS"
    )

    @validator("mobile_max_tokens", pre=True)
    def validate_mobile_max_tokens(cls, v):
        """Validate mobile max tokens (clamped to 100-2000)."""
        return cls._coerce_int(v, 800, 100, 2000, "MOBILE_MAX_TOKENS")

    mobile_timeout: int = Field(
        default=15000,
        description="Mobile request timeout in milliseconds",
        env="MOBILE_TIMEOUT"
    )

    @validator("mobile_timeout", pre=True)
    def validate_mobile_timeout(cls, v):
        """Validate mobile timeout (clamped to 5000-60000 ms)."""
        return cls._coerce_int(v, 15000, 5000, 60000, "MOBILE_TIMEOUT")

    # ------------------------------------------------------------------
    # Gradio server
    # ------------------------------------------------------------------

    gradio_port: int = Field(
        default=7860,
        description="Gradio server port",
        env="GRADIO_PORT"
    )

    @validator("gradio_port", pre=True)
    def validate_gradio_port(cls, v):
        """Validate gradio port (clamped to the non-privileged range)."""
        return cls._coerce_int(v, 7860, 1024, 65535, "GRADIO_PORT")

    gradio_host: str = Field(
        default="0.0.0.0",
        description="Gradio server host",
        env="GRADIO_HOST"
    )

    # ------------------------------------------------------------------
    # Logging
    # ------------------------------------------------------------------

    log_level: str = Field(
        default="INFO",
        description="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
        env="LOG_LEVEL"
    )

    @validator("log_level")
    def validate_log_level(cls, v):
        """Validate log level (falls back to INFO, backward compatible)."""
        if not v:
            return "INFO"
        valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
        if v.upper() not in valid_levels:
            logger.warning(f"Invalid log level: {v}, using INFO")
            return "INFO"
        return v.upper()

    log_format: str = Field(
        default="json",
        description="Log format (json or text)",
        env="LOG_FORMAT"
    )

    @validator("log_format")
    def validate_log_format(cls, v):
        """Validate log format (falls back to json, backward compatible)."""
        if not v:
            return "json"
        if v.lower() not in ["json", "text"]:
            logger.warning(f"Invalid log format: {v}, using json")
            return "json"
        return v.lower()

    class Config:
        """Pydantic configuration"""
        env_file = ".env"
        env_file_encoding = "utf-8"
        case_sensitive = False
        validate_assignment = True
        # Ignore unrelated environment variables instead of erroring.
        extra = "ignore"

    def validate_configuration(self) -> bool:
        """
        Validate configuration and log status.

        Returns:
            bool: True if configuration is valid, False otherwise
        """
        try:
            # Resolving the cache dir exercises the fallback chain; failures
            # surface here rather than at first model download.
            cache_dir = self.hf_cache_dir
            if logger.isEnabledFor(logging.INFO):
                logger.info("Configuration validated:")
                logger.info(f"  - Cache directory: {cache_dir}")
                logger.info(f"  - Max workers: {self.max_workers}")
                logger.info(f"  - Log level: {self.log_level}")
                logger.info(f"  - HF token: {'Set' if self.hf_token else 'Not set'}")

            return True

        except Exception as e:
            logger.error(f"Configuration validation failed: {e}")
            return False
|
|
|
|
| |
|
|
def get_settings() -> Settings:
    """
    Get or create the global settings instance.

    Returns:
        Settings: Global settings instance

    Note:
        The instance is created and validated exactly once; subsequent calls
        return the cached instance without re-running validation. (Previously
        validate_configuration() — and its INFO logging — ran on every call,
        contradicting the load-once contract.)
    """
    if not hasattr(get_settings, '_instance'):
        instance = Settings()
        try:
            instance.validate_configuration()
        except Exception as e:
            # Validation is best-effort: a failure is logged, never raised,
            # so importing this module cannot crash the application.
            logger.warning(f"Configuration validation warning: {e}")
        get_settings._instance = instance
    return get_settings._instance
|
|
|
|
| |
# Eagerly build the module-level singleton so that importing this module
# yields a ready-to-use `settings` object (backward compatible).
settings = get_settings()


if logger.isEnabledFor(logging.INFO):
    try:
        banner = "=" * 60
        logger.info(banner)
        logger.info("Configuration Loaded")
        logger.info(banner)
        logger.info(f"Cache directory: {settings.hf_cache_dir}")
        logger.info(f"Max workers: {settings.max_workers}")
        logger.info(f"Log level: {settings.log_level}")
        logger.info(banner)
    except Exception as e:
        # Banner logging is purely cosmetic; never let it break import.
        logger.debug(f"Configuration logging skipped: {e}")
|
|