"""
Database models and initialization for Universal Model Trainer
"""
from sqlalchemy import (
    Column, Integer, String, Text, Float, Boolean, DateTime,
    ForeignKey, JSON, Enum, create_engine
)
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker, relationship, declarative_base
from datetime import datetime
import enum
import os

from app.config import settings


# Parse the configured URL with SQLAlchemy's own parser instead of raw string
# replacement. Plain ``str.replace`` mis-handled several valid URL forms:
#   * "sqlite:///data/app.db" (relative) turned into the absolute "/data/app.db";
#   * an already-async "sqlite+aiosqlite://..." URL was rewritten to
#     "sqlite+aiosqlite+aiosqlite://..." because "aiosqlite://" contains
#     "sqlite://";
#   * a non-SQLite URL made os.makedirs create a directory named after the URL.
_db_url = make_url(settings.DATABASE_URL)
_is_sqlite = _db_url.get_backend_name() == "sqlite"

# Filesystem path of the SQLite database file. Empty string for non-SQLite
# backends and for in-memory SQLite ("sqlite://"); ":memory:" when spelled out.
_raw_db_path = _db_url.database if _is_sqlite else None
if _raw_db_path and _raw_db_path != ":memory:":
    DATABASE_PATH = os.path.normpath(_raw_db_path)
    # Ensure the directory that will hold the SQLite file exists.
    os.makedirs(os.path.dirname(DATABASE_PATH) or ".", exist_ok=True)
else:
    DATABASE_PATH = _raw_db_path or ""

# Use the aiosqlite async driver for plain SQLite URLs. URLs that already name
# a driver (e.g. "sqlite+aiosqlite") or use another backend are passed through
# unchanged; non-SQLite URLs must already name an async driver for
# create_async_engine to accept them.
_async_db_url = (
    _db_url.set(drivername="sqlite+aiosqlite")
    if _db_url.drivername == "sqlite"
    else _db_url
)
# Kept as a plain string for backward compatibility. Rendered explicitly so a
# password (if any) is not masked, as ``str(URL)`` does on SQLAlchemy 2.x.
ASYNC_DB_URL = _async_db_url.render_as_string(hide_password=False)

engine = create_async_engine(ASYNC_DB_URL, echo=settings.DEBUG)

AsyncSessionLocal = sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False
)

Base = declarative_base()


class JobStatus(str, enum.Enum):
    """Training job status enum."""
    PENDING = "pending"
    QUEUED = "queued"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
    PAUSED = "paused"


class TaskType(str, enum.Enum):
    """Supported task types."""
    CAUSAL_LM = "causal-lm"
    SEQ2SEQ = "seq2seq"
    TOKEN_CLASSIFICATION = "token-classification"
    SEQUENCE_CLASSIFICATION = "sequence-classification"
    QUESTION_ANSWERING = "question-answering"
    SUMMARIZATION = "summarization"
    TRANSLATION = "translation"
    TEXT_CLASSIFICATION = "text-classification"
    MASKED_LM = "masked-lm"
    VISION_CLASSIFICATION = "vision-classification"
    VISION_SEGMENTATION = "vision-segmentation"
    AUDIO_CLASSIFICATION = "audio-classification"
    AUDIO_TRANSCRIPTION = "audio-transcription"


class TrainingJob(Base):
    """Model for training jobs."""
    __tablename__ = "training_jobs"

    id = Column(Integer, primary_key=True, index=True)
    # Public job identifier; String(36) presumably holds a UUID - confirm with callers.
    job_id = Column(String(36), unique=True, index=True, nullable=False)
    name = Column(String(255), nullable=False)
    description = Column(Text, nullable=True)

    # Task configuration
    task_type = Column(String(50), nullable=False)
    base_model = Column(String(255), nullable=False)
    output_model_name = Column(String(255), nullable=True)

    # Dataset configuration
    dataset_source = Column(String(50), default="huggingface")
    dataset_name = Column(String(255), nullable=True)
    dataset_config = Column(String(100), nullable=True)
    dataset_split = Column(String(50), default="train")
    custom_dataset_path = Column(String(512), nullable=True)

    # Training arguments (callable defaults give each row a fresh dict)
    training_args = Column(JSON, default=dict)
    peft_config = Column(JSON, nullable=True)
    deepspeed_config = Column(JSON, nullable=True)

    # Status and progress; status stores a JobStatus *value* string
    status = Column(String(20), default=JobStatus.PENDING.value)
    progress = Column(Float, default=0.0)
    current_epoch = Column(Integer, default=0)
    total_epochs = Column(Integer, default=3)
    current_step = Column(Integer, default=0)
    total_steps = Column(Integer, default=0)

    # Metrics
    train_loss = Column(Float, nullable=True)
    eval_loss = Column(Float, nullable=True)
    learning_rate = Column(Float, nullable=True)
    metrics = Column(JSON, default=dict)

    # Output
    output_path = Column(String(512), nullable=True)
    hub_model_id = Column(String(255), nullable=True)
    model_card = Column(Text, nullable=True)

    # Error handling
    error_message = Column(Text, nullable=True)
    traceback = Column(Text, nullable=True)
    retry_count = Column(Integer, default=0)
    max_retries = Column(Integer, default=3)

    # Timestamps (naive UTC via datetime.utcnow)
    created_at = Column(DateTime, default=datetime.utcnow)
    started_at = Column(DateTime, nullable=True)
    completed_at = Column(DateTime, nullable=True)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # User info
    created_by = Column(String(100), nullable=True)
    tags = Column(JSON, default=list)

    # Relationships: deleting a job through the ORM also deletes its
    # checkpoints and logs (cascade "all, delete-orphan").
    checkpoints = relationship("Checkpoint", back_populates="job", cascade="all, delete-orphan")
    logs = relationship("TrainingLog", back_populates="job", cascade="all, delete-orphan")

    def to_dict(self):
        """Return a summary dict of this job.

        Datetime fields are rendered with ``isoformat()`` (``None`` when unset).
        Configuration columns such as ``training_args``, ``peft_config``,
        ``deepspeed_config``, ``traceback`` and ``updated_at`` are not included.
        """
        return {
            "id": self.id,
            "job_id": self.job_id,
            "name": self.name,
            "description": self.description,
            "task_type": self.task_type,
            "base_model": self.base_model,
            "output_model_name": self.output_model_name,
            "dataset_name": self.dataset_name,
            "status": self.status,
            "progress": self.progress,
            "current_epoch": self.current_epoch,
            "total_epochs": self.total_epochs,
            "current_step": self.current_step,
            "total_steps": self.total_steps,
            "train_loss": self.train_loss,
            "eval_loss": self.eval_loss,
            "metrics": self.metrics,
            "output_path": self.output_path,
            "hub_model_id": self.hub_model_id,
            "error_message": self.error_message,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
            "tags": self.tags
        }


class Checkpoint(Base):
    """Model for training checkpoints."""
    __tablename__ = "checkpoints"

    id = Column(Integer, primary_key=True, index=True)
    # References TrainingJob.id (the integer primary key), not TrainingJob.job_id.
    job_id = Column(Integer, ForeignKey("training_jobs.id"), nullable=False)
    checkpoint_name = Column(String(255), nullable=False)
    checkpoint_path = Column(String(512), nullable=False)
    step = Column(Integer, nullable=False)
    epoch = Column(Float, nullable=False)
    loss = Column(Float, nullable=True)
    metrics = Column(JSON, default=dict)
    is_best = Column(Boolean, default=False)
    created_at = Column(DateTime, default=datetime.utcnow)
    size_mb = Column(Float, nullable=True)

    # Relationship
    job = relationship("TrainingJob", back_populates="checkpoints")


class TrainingLog(Base):
    """Model for training logs."""
    __tablename__ = "training_logs"

    id = Column(Integer, primary_key=True, index=True)
    # References TrainingJob.id (the integer primary key), not TrainingJob.job_id.
    job_id = Column(Integer, ForeignKey("training_jobs.id"), nullable=False)
    level = Column(String(10), default="INFO")
    message = Column(Text, nullable=False)
    step = Column(Integer, nullable=True)
    epoch = Column(Float, nullable=True)
    loss = Column(Float, nullable=True)
    learning_rate = Column(Float, nullable=True)
    metrics = Column(JSON, nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow)

    # Relationship
    job = relationship("TrainingJob", back_populates="logs")


class ModelRegistry(Base):
    """Registry of trained and available models."""
    __tablename__ = "model_registry"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(255), unique=True, nullable=False)
    model_id = Column(String(255), nullable=False)
    task_type = Column(String(50), nullable=False)
    description = Column(Text, nullable=True)
    tags = Column(JSON, default=list)
    parameters = Column(String(20), nullable=True)
    is_local = Column(Boolean, default=False)
    local_path = Column(String(512), nullable=True)
    hub_url = Column(String(512), nullable=True)
    is_trained = Column(Boolean, default=False)
    # Optional link to the job that produced this model (no ORM relationship defined).
    training_job_id = Column(Integer, ForeignKey("training_jobs.id"), nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow)
    last_used = Column(DateTime, nullable=True)


class DatasetCache(Base):
    """Cache for downloaded datasets."""
    __tablename__ = "dataset_cache"

    id = Column(Integer, primary_key=True, index=True)
    # NOTE(review): `name` alone is unique, so the same dataset cannot be cached
    # under two different config/split combinations - confirm this is intended.
    name = Column(String(255), unique=True, nullable=False)
    config = Column(String(100), nullable=True)
    split = Column(String(50), nullable=True)
    local_path = Column(String(512), nullable=False)
    size_mb = Column(Float, nullable=True)
    num_samples = Column(Integer, nullable=True)
    features = Column(JSON, nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow)
    last_accessed = Column(DateTime, default=datetime.utcnow)


async def init_db():
    """Initialize database tables.

    Creates every table registered on ``Base.metadata`` that does not exist
    yet. ``create_all`` never alters existing tables, so schema changes to
    already-created tables need a separate migration.
    """
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)


async def get_db():
    """Yield a database session for the duration of the caller's use.

    The ``async with`` block closes the session when the generator finishes,
    including when the consumer raises, so no explicit ``close()`` is needed.
    """
    async with AsyncSessionLocal() as session:
        yield session