Update app/services/training_service.py
Browse files
app/services/training_service.py
CHANGED
|
@@ -26,7 +26,7 @@ from transformers import (
|
|
| 26 |
)
|
| 27 |
from datasets import load_dataset, Dataset
|
| 28 |
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
|
| 29 |
-
|
| 30 |
from app.config import settings
|
| 31 |
from app.database import AsyncSessionLocal, TrainingJob, TrainingLog, JobStatus
|
| 32 |
import importlib
|
|
@@ -476,7 +476,7 @@ class TrainingService:
|
|
| 476 |
progress_callback = ProgressCallback(total_steps)
|
| 477 |
|
| 478 |
# Trainer
|
| 479 |
-
trainer =
|
| 480 |
model=model,
|
| 481 |
args=training_arguments,
|
| 482 |
train_dataset=tokenized_train,
|
|
|
|
| 26 |
)
|
| 27 |
from datasets import load_dataset, Dataset
|
| 28 |
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
|
| 29 |
+
from trl import SFTTrainer # instead of from transformers import Trainer
|
| 30 |
from app.config import settings
|
| 31 |
from app.database import AsyncSessionLocal, TrainingJob, TrainingLog, JobStatus
|
| 32 |
import importlib
|
|
|
|
| 476 |
progress_callback = ProgressCallback(total_steps)
|
| 477 |
|
| 478 |
# Trainer
|
| 479 |
+
trainer = SFTTrainer(
|
| 480 |
model=model,
|
| 481 |
args=training_arguments,
|
| 482 |
train_dataset=tokenized_train,
|