# Provenance: Hugging Face commit 57a645e ("Add database module with models", vectorplasticity)
"""
Database models and initialization for Universal Model Trainer
"""
from sqlalchemy import (
Column, Integer, String, Text, Float, Boolean, DateTime,
ForeignKey, JSON, Enum, create_engine
)
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker, relationship, declarative_base
from datetime import datetime
import enum
import os
from app.config import settings
# --- Async database engine / session setup ------------------------------
# Derive the on-disk SQLite path from the configured URL so its parent
# directory can be created before the database file is first opened.
DATABASE_PATH = settings.DATABASE_URL.replace("sqlite:///./", "").replace("sqlite://", "")
_db_dir = os.path.dirname(DATABASE_PATH) or "."
os.makedirs(_db_dir, exist_ok=True)

# Swap the sync sqlite driver for aiosqlite so the engine is awaitable.
ASYNC_DB_URL = settings.DATABASE_URL.replace("sqlite://", "sqlite+aiosqlite://")
engine = create_async_engine(ASYNC_DB_URL, echo=settings.DEBUG)

# Session factory producing AsyncSession objects. expire_on_commit=False
# keeps ORM attributes readable after commit without a refresh round-trip.
AsyncSessionLocal = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

# Declarative base that every ORM model below inherits from.
Base = declarative_base()
class JobStatus(str, enum.Enum):
    """Lifecycle states for a training job.

    The ``str`` mixin makes members compare equal to their plain string
    values — the ``TrainingJob.status`` column stores the raw string
    (see ``default=JobStatus.PENDING.value`` below), not the enum.
    """
    PENDING = "pending"        # created, not yet picked up
    QUEUED = "queued"          # submitted, awaiting execution
    RUNNING = "running"        # training in progress
    COMPLETED = "completed"    # finished successfully
    FAILED = "failed"          # terminated with an error
    CANCELLED = "cancelled"    # stopped before completion
    PAUSED = "paused"          # suspended, may resume
class TaskType(str, enum.Enum):
    """Supported task types.

    ``str`` mixin so members serialize as (and compare equal to) their
    plain string values; ``task_type`` columns store the raw string.
    """
    # Text / language-modeling tasks
    CAUSAL_LM = "causal-lm"
    SEQ2SEQ = "seq2seq"
    TOKEN_CLASSIFICATION = "token-classification"
    SEQUENCE_CLASSIFICATION = "sequence-classification"
    QUESTION_ANSWERING = "question-answering"
    SUMMARIZATION = "summarization"
    TRANSLATION = "translation"
    TEXT_CLASSIFICATION = "text-classification"
    MASKED_LM = "masked-lm"
    # Vision tasks
    VISION_CLASSIFICATION = "vision-classification"
    VISION_SEGMENTATION = "vision-segmentation"
    # Audio tasks
    AUDIO_CLASSIFICATION = "audio-classification"
    AUDIO_TRANSCRIPTION = "audio-transcription"
class TrainingJob(Base):
    """ORM model for a single training job with its live progress and metrics.

    One row per submitted job. Child rows (checkpoints, log lines) are
    cascade-deleted when the job row is deleted.
    """
    __tablename__ = "training_jobs"

    id = Column(Integer, primary_key=True, index=True)
    # External identifier; 36 chars suggests a UUID string — TODO confirm
    # against the code that creates jobs.
    job_id = Column(String(36), unique=True, index=True, nullable=False)
    name = Column(String(255), nullable=False)
    description = Column(Text, nullable=True)

    # Task configuration
    task_type = Column(String(50), nullable=False)    # expected to hold a TaskType value
    base_model = Column(String(255), nullable=False)  # model identifier to start training from
    output_model_name = Column(String(255), nullable=True)

    # Dataset configuration
    dataset_source = Column(String(50), default="huggingface")
    dataset_name = Column(String(255), nullable=True)
    dataset_config = Column(String(100), nullable=True)
    dataset_split = Column(String(50), default="train")
    custom_dataset_path = Column(String(512), nullable=True)

    # Training arguments — opaque JSON blobs passed through to the trainer
    training_args = Column(JSON, default=dict)
    peft_config = Column(JSON, nullable=True)
    deepspeed_config = Column(JSON, nullable=True)

    # Status and progress
    status = Column(String(20), default=JobStatus.PENDING.value)  # raw string, not the enum
    progress = Column(Float, default=0.0)
    current_epoch = Column(Integer, default=0)
    total_epochs = Column(Integer, default=3)
    current_step = Column(Integer, default=0)
    total_steps = Column(Integer, default=0)

    # Metrics — latest observed values; per-step history lives in TrainingLog
    train_loss = Column(Float, nullable=True)
    eval_loss = Column(Float, nullable=True)
    learning_rate = Column(Float, nullable=True)
    metrics = Column(JSON, default=dict)

    # Output artifacts
    output_path = Column(String(512), nullable=True)
    hub_model_id = Column(String(255), nullable=True)
    model_card = Column(Text, nullable=True)

    # Error handling / retry bookkeeping
    error_message = Column(Text, nullable=True)
    traceback = Column(Text, nullable=True)
    retry_count = Column(Integer, default=0)
    max_retries = Column(Integer, default=3)

    # Timestamps — naive UTC via datetime.utcnow
    created_at = Column(DateTime, default=datetime.utcnow)
    started_at = Column(DateTime, nullable=True)
    completed_at = Column(DateTime, nullable=True)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # User info
    created_by = Column(String(100), nullable=True)
    tags = Column(JSON, default=list)

    # Relationships — children are deleted along with the job
    checkpoints = relationship("Checkpoint", back_populates="job", cascade="all, delete-orphan")
    logs = relationship("TrainingLog", back_populates="job", cascade="all, delete-orphan")

    def to_dict(self):
        """Serialize the job to a JSON-friendly dict (datetimes → ISO 8601).

        Intentionally omits some columns (configs, traceback, model_card,
        retry counters, relationships).
        """
        return {
            "id": self.id,
            "job_id": self.job_id,
            "name": self.name,
            "description": self.description,
            "task_type": self.task_type,
            "base_model": self.base_model,
            "output_model_name": self.output_model_name,
            "dataset_name": self.dataset_name,
            "status": self.status,
            "progress": self.progress,
            "current_epoch": self.current_epoch,
            "total_epochs": self.total_epochs,
            "current_step": self.current_step,
            "total_steps": self.total_steps,
            "train_loss": self.train_loss,
            "eval_loss": self.eval_loss,
            "metrics": self.metrics,
            "output_path": self.output_path,
            "hub_model_id": self.hub_model_id,
            "error_message": self.error_message,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
            "tags": self.tags
        }
class Checkpoint(Base):
    """ORM model for a saved training checkpoint belonging to one job."""
    __tablename__ = "checkpoints"

    id = Column(Integer, primary_key=True, index=True)
    # FK to training_jobs.id (the integer PK, not the string job_id field).
    job_id = Column(Integer, ForeignKey("training_jobs.id"), nullable=False)
    checkpoint_name = Column(String(255), nullable=False)
    checkpoint_path = Column(String(512), nullable=False)
    step = Column(Integer, nullable=False)
    epoch = Column(Float, nullable=False)      # Float — fractional epochs allowed
    loss = Column(Float, nullable=True)
    metrics = Column(JSON, default=dict)
    is_best = Column(Boolean, default=False)   # marks the best checkpoint so far
    created_at = Column(DateTime, default=datetime.utcnow)
    size_mb = Column(Float, nullable=True)

    # Parent job — back-populated by TrainingJob.checkpoints
    job = relationship("TrainingJob", back_populates="checkpoints")
class TrainingLog(Base):
    """ORM model for per-step training log entries belonging to one job."""
    __tablename__ = "training_logs"

    id = Column(Integer, primary_key=True, index=True)
    # FK to training_jobs.id (the integer PK, not the string job_id field).
    job_id = Column(Integer, ForeignKey("training_jobs.id"), nullable=False)
    level = Column(String(10), default="INFO")   # log severity, e.g. "INFO"
    message = Column(Text, nullable=False)
    # Optional training context at the time of the log entry
    step = Column(Integer, nullable=True)
    epoch = Column(Float, nullable=True)
    loss = Column(Float, nullable=True)
    learning_rate = Column(Float, nullable=True)
    metrics = Column(JSON, nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow)

    # Parent job — back-populated by TrainingJob.logs
    job = relationship("TrainingJob", back_populates="logs")
class ModelRegistry(Base):
    """Registry of trained and available models (local and/or on the Hub)."""
    __tablename__ = "model_registry"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(255), unique=True, nullable=False)  # unique display name
    model_id = Column(String(255), nullable=False)           # e.g. a Hub repo id — TODO confirm
    task_type = Column(String(50), nullable=False)           # expected to hold a TaskType value
    description = Column(Text, nullable=True)
    tags = Column(JSON, default=list)
    parameters = Column(String(20), nullable=True)           # parameter count as text, e.g. "7B" — TODO confirm
    # Local availability
    is_local = Column(Boolean, default=False)
    local_path = Column(String(512), nullable=True)
    hub_url = Column(String(512), nullable=True)
    # Provenance: set when the model came from a training job in this app
    is_trained = Column(Boolean, default=False)
    training_job_id = Column(Integer, ForeignKey("training_jobs.id"), nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow)
    last_used = Column(DateTime, nullable=True)
class DatasetCache(Base):
    """Cache metadata for datasets already downloaded to local storage."""
    __tablename__ = "dataset_cache"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(255), unique=True, nullable=False)  # dataset identifier; unique per cache entry
    config = Column(String(100), nullable=True)              # dataset configuration/subset name
    split = Column(String(50), nullable=True)
    local_path = Column(String(512), nullable=False)
    size_mb = Column(Float, nullable=True)
    num_samples = Column(Integer, nullable=True)
    features = Column(JSON, nullable=True)                   # feature schema snapshot — format TODO confirm
    created_at = Column(DateTime, default=datetime.utcnow)
    # Updated on reads by application code, presumably for LRU eviction — TODO confirm
    last_accessed = Column(DateTime, default=datetime.utcnow)
async def init_db():
    """Create all ORM-declared tables if they do not already exist.

    Opens one transactional connection on the async engine and runs the
    synchronous ``Base.metadata.create_all`` through ``run_sync``.
    """
    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)
async def get_db():
    """Yield an ``AsyncSession`` scoped to the ``async with`` block.

    Async generator dependency (FastAPI-style): the session factory's
    context manager opens the session and guarantees it is closed on
    exit — including on error — so the original explicit
    ``try/finally: await session.close()`` was redundant and has been
    removed.
    """
    async with AsyncSessionLocal() as session:
        yield session