vectorplasticity committed on
Commit
57a645e
·
verified ·
1 Parent(s): f4f6fd8

Add database module with models

Browse files
Files changed (1) hide show
  1. app/database.py +241 -0
app/database.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Database models and initialization for Universal Model Trainer
3
+ """
4
+
5
+ from sqlalchemy import (
6
+ Column, Integer, String, Text, Float, Boolean, DateTime,
7
+ ForeignKey, JSON, Enum, create_engine
8
+ )
9
+ from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
10
+ from sqlalchemy.orm import sessionmaker, relationship, declarative_base
11
+ from datetime import datetime
12
+ import enum
13
+ import os
14
+
15
+ from app.config import settings
16
+
17
# --- Database engine / session setup --------------------------------------
# Derive the on-disk file path from the configured URL so the parent
# directory can be created before SQLite tries to open the database file.
# NOTE(review): assumes a URL shaped like "sqlite:///./<path>" or
# "sqlite://<path>"; a bare in-memory URL ("sqlite://") yields "" and the
# fallback creates "." — confirm DATABASE_URL's exact format in app.config.
DATABASE_PATH = settings.DATABASE_URL.replace("sqlite:///./", "").replace("sqlite://", "")
os.makedirs(os.path.dirname(DATABASE_PATH) if os.path.dirname(DATABASE_PATH) else ".", exist_ok=True)

# Use async engine for SQLite: swap in the aiosqlite driver so the same
# configured URL works with SQLAlchemy's asyncio extension.
ASYNC_DB_URL = settings.DATABASE_URL.replace("sqlite://", "sqlite+aiosqlite://")
engine = create_async_engine(ASYNC_DB_URL, echo=settings.DEBUG)

# Session factory producing AsyncSession instances. expire_on_commit=False
# keeps ORM objects readable after commit without an implicit refresh
# round-trip (important because sessions here are used across awaits).
AsyncSessionLocal = sessionmaker(
    engine, class_=AsyncSession, expire_on_commit=False
)

# Declarative base shared by every ORM model defined below.
Base = declarative_base()
30
+
31
+
32
class JobStatus(str, enum.Enum):
    """Lifecycle states for a training job.

    Subclasses ``str`` so members compare equal to — and serialize as —
    their plain string values, matching the VARCHAR ``status`` column on
    ``TrainingJob`` (which stores ``JobStatus.PENDING.value`` by default).
    """
    PENDING = "pending"
    QUEUED = "queued"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
    PAUSED = "paused"
41
+
42
+
43
class TaskType(str, enum.Enum):
    """Supported training task types (text, vision, and audio).

    Subclasses ``str`` so members compare equal to their string values;
    ``TrainingJob.task_type`` stores these as plain VARCHAR(50) values.
    """
    CAUSAL_LM = "causal-lm"
    SEQ2SEQ = "seq2seq"
    TOKEN_CLASSIFICATION = "token-classification"
    SEQUENCE_CLASSIFICATION = "sequence-classification"
    QUESTION_ANSWERING = "question-answering"
    SUMMARIZATION = "summarization"
    TRANSLATION = "translation"
    TEXT_CLASSIFICATION = "text-classification"
    MASKED_LM = "masked-lm"
    VISION_CLASSIFICATION = "vision-classification"
    VISION_SEGMENTATION = "vision-segmentation"
    AUDIO_CLASSIFICATION = "audio-classification"
    AUDIO_TRANSCRIPTION = "audio-transcription"
58
+
59
+
60
class TrainingJob(Base):
    """ORM model for a single training job.

    Tracks the full lifecycle of a fine-tuning run: task/dataset
    configuration, live progress and metrics, output artifacts, error
    state, and timestamps. ``checkpoints`` and ``logs`` cascade-delete
    with the job.
    """
    __tablename__ = "training_jobs"

    id = Column(Integer, primary_key=True, index=True)
    # External identifier (UUID-sized string) used by API clients.
    job_id = Column(String(36), unique=True, index=True, nullable=False)
    name = Column(String(255), nullable=False)
    description = Column(Text, nullable=True)

    # Task configuration (stored as strings; see TaskType for known values)
    task_type = Column(String(50), nullable=False)
    base_model = Column(String(255), nullable=False)
    output_model_name = Column(String(255), nullable=True)

    # Dataset configuration
    dataset_source = Column(String(50), default="huggingface")
    dataset_name = Column(String(255), nullable=True)
    dataset_config = Column(String(100), nullable=True)
    dataset_split = Column(String(50), default="train")
    custom_dataset_path = Column(String(512), nullable=True)

    # Training arguments (free-form JSON blobs; `dict` is a per-row
    # default factory, not a shared mutable default)
    training_args = Column(JSON, default=dict)
    peft_config = Column(JSON, nullable=True)
    deepspeed_config = Column(JSON, nullable=True)

    # Status and progress
    status = Column(String(20), default=JobStatus.PENDING.value)
    progress = Column(Float, default=0.0)
    current_epoch = Column(Integer, default=0)
    total_epochs = Column(Integer, default=3)
    current_step = Column(Integer, default=0)
    total_steps = Column(Integer, default=0)

    # Metrics
    train_loss = Column(Float, nullable=True)
    eval_loss = Column(Float, nullable=True)
    learning_rate = Column(Float, nullable=True)
    metrics = Column(JSON, default=dict)

    # Output
    output_path = Column(String(512), nullable=True)
    hub_model_id = Column(String(255), nullable=True)
    model_card = Column(Text, nullable=True)

    # Error handling
    error_message = Column(Text, nullable=True)
    traceback = Column(Text, nullable=True)
    retry_count = Column(Integer, default=0)
    max_retries = Column(Integer, default=3)

    # Timestamps
    # NOTE(review): datetime.utcnow is deprecated since Python 3.12; a
    # migration to timezone-aware datetime.now(timezone.utc) would change
    # stored values from naive to aware — defer until all models move.
    created_at = Column(DateTime, default=datetime.utcnow)
    started_at = Column(DateTime, nullable=True)
    completed_at = Column(DateTime, nullable=True)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # User info
    created_by = Column(String(100), nullable=True)
    tags = Column(JSON, default=list)

    # Relationships (delete-orphan: children die with the job)
    checkpoints = relationship("Checkpoint", back_populates="job", cascade="all, delete-orphan")
    logs = relationship("TrainingLog", back_populates="job", cascade="all, delete-orphan")

    @staticmethod
    def _iso(dt):
        """Render a nullable DateTime column as ISO-8601 or None."""
        return dt.isoformat() if dt else None

    def to_dict(self):
        """Serialize the job to a JSON-friendly dict for API responses.

        Fix: previously omitted several persisted fields (dataset_source,
        dataset_config, dataset_split, learning_rate, retry_count,
        max_retries, created_by, updated_at); adding keys is
        backward-compatible for existing consumers.
        """
        return {
            "id": self.id,
            "job_id": self.job_id,
            "name": self.name,
            "description": self.description,
            "task_type": self.task_type,
            "base_model": self.base_model,
            "output_model_name": self.output_model_name,
            "dataset_source": self.dataset_source,
            "dataset_name": self.dataset_name,
            "dataset_config": self.dataset_config,
            "dataset_split": self.dataset_split,
            "status": self.status,
            "progress": self.progress,
            "current_epoch": self.current_epoch,
            "total_epochs": self.total_epochs,
            "current_step": self.current_step,
            "total_steps": self.total_steps,
            "train_loss": self.train_loss,
            "eval_loss": self.eval_loss,
            "learning_rate": self.learning_rate,
            "metrics": self.metrics,
            "output_path": self.output_path,
            "hub_model_id": self.hub_model_id,
            "error_message": self.error_message,
            "retry_count": self.retry_count,
            "max_retries": self.max_retries,
            "created_by": self.created_by,
            "created_at": self._iso(self.created_at),
            "started_at": self._iso(self.started_at),
            "completed_at": self._iso(self.completed_at),
            "updated_at": self._iso(self.updated_at),
            "tags": self.tags
        }
152
+
153
+
154
class Checkpoint(Base):
    """ORM model for a saved training checkpoint.

    Each row records where a checkpoint lives on disk and the training
    position (step/epoch) and quality (loss/metrics) at save time. Rows
    are cascade-deleted with their parent TrainingJob.
    """
    __tablename__ = "checkpoints"

    id = Column(Integer, primary_key=True, index=True)
    # FK to training_jobs.id (the integer PK, not the string job_id).
    job_id = Column(Integer, ForeignKey("training_jobs.id"), nullable=False)
    checkpoint_name = Column(String(255), nullable=False)
    checkpoint_path = Column(String(512), nullable=False)
    step = Column(Integer, nullable=False)
    # Float: allows fractional epochs mid-way through a pass.
    epoch = Column(Float, nullable=False)
    loss = Column(Float, nullable=True)
    metrics = Column(JSON, default=dict)
    # Marks the best-so-far checkpoint for this job.
    is_best = Column(Boolean, default=False)
    created_at = Column(DateTime, default=datetime.utcnow)
    size_mb = Column(Float, nullable=True)

    # Relationship
    job = relationship("TrainingJob", back_populates="checkpoints")
172
+
173
+
174
class TrainingLog(Base):
    """ORM model for per-job training log entries.

    A log line with optional training context (step/epoch/loss/lr and a
    free-form metrics blob). Rows are cascade-deleted with the parent
    TrainingJob.
    """
    __tablename__ = "training_logs"

    id = Column(Integer, primary_key=True, index=True)
    # FK to training_jobs.id (the integer PK, not the string job_id).
    job_id = Column(Integer, ForeignKey("training_jobs.id"), nullable=False)
    # Severity label, e.g. "INFO" (VARCHAR(10), not an enum).
    level = Column(String(10), default="INFO")
    message = Column(Text, nullable=False)
    step = Column(Integer, nullable=True)
    epoch = Column(Float, nullable=True)
    loss = Column(Float, nullable=True)
    learning_rate = Column(Float, nullable=True)
    metrics = Column(JSON, nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow)

    # Relationship
    job = relationship("TrainingJob", back_populates="logs")
191
+
192
+
193
class ModelRegistry(Base):
    """Registry of trained and available models.

    Covers both hub models (hub_url) and locally stored ones
    (is_local/local_path); models produced by a training run link back via
    training_job_id.
    """
    __tablename__ = "model_registry"

    id = Column(Integer, primary_key=True, index=True)
    # Display/registry name — unique across the registry.
    name = Column(String(255), unique=True, nullable=False)
    # Upstream identifier (e.g. a hub repo id) — not unique.
    model_id = Column(String(255), nullable=False)
    task_type = Column(String(50), nullable=False)
    description = Column(Text, nullable=True)
    tags = Column(JSON, default=list)
    # Parameter count stored as a display string (e.g. "7B") — presumably;
    # confirm against whatever writes this column.
    parameters = Column(String(20), nullable=True)
    is_local = Column(Boolean, default=False)
    local_path = Column(String(512), nullable=True)
    hub_url = Column(String(512), nullable=True)
    is_trained = Column(Boolean, default=False)
    # NOTE(review): FK without a relationship() — TrainingJob has no
    # back-reference to registry entries; join manually if needed.
    training_job_id = Column(Integer, ForeignKey("training_jobs.id"), nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow)
    last_used = Column(DateTime, nullable=True)
211
+
212
+
213
class DatasetCache(Base):
    """Cache index for downloaded datasets.

    Maps a dataset name (plus optional config/split) to its local path and
    basic stats so repeated jobs can skip re-downloading.
    """
    __tablename__ = "dataset_cache"

    id = Column(Integer, primary_key=True, index=True)
    # NOTE(review): uniqueness is on name ALONE — the same dataset with a
    # different config or split cannot be cached side by side. If that is
    # unintended, the fix is a composite unique constraint on
    # (name, config, split), which requires a schema migration.
    name = Column(String(255), unique=True, nullable=False)
    config = Column(String(100), nullable=True)
    split = Column(String(50), nullable=True)
    local_path = Column(String(512), nullable=False)
    size_mb = Column(Float, nullable=True)
    num_samples = Column(Integer, nullable=True)
    # Dataset feature/schema description as free-form JSON.
    features = Column(JSON, nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow)
    last_accessed = Column(DateTime, default=datetime.utcnow)
227
+
228
+
229
async def init_db():
    """Create all tables declared on ``Base`` if they do not exist.

    Uses ``engine.begin()`` so DDL runs inside a transaction, and
    ``run_sync`` to invoke the synchronous ``create_all`` from the async
    engine. Idempotent: existing tables are left untouched (create_all
    does not migrate changed schemas).
    """
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
233
+
234
+
235
async def get_db():
    """Async-generator dependency that yields a database session.

    Yields:
        AsyncSession: a session from ``AsyncSessionLocal``, valid for the
        duration of the request/caller scope.

    Fix: dropped the redundant ``try/finally: await session.close()`` —
    the ``async with`` context manager already closes the session on exit
    (including when the consumer raises into the generator), so the
    explicit close was a no-op second close.
    """
    async with AsyncSessionLocal() as session:
        yield session