vectorplasticity commited on
Commit
032cfc1
·
verified ·
1 Parent(s): 6566fa8

Update app/services/training_service.py

Browse files
Files changed (1) hide show
  1. app/services/training_service.py +2 -2
app/services/training_service.py CHANGED
@@ -26,7 +26,7 @@ from transformers import (
26
  )
27
  from datasets import load_dataset, Dataset
28
  from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
29
-
30
  from app.config import settings
31
  from app.database import AsyncSessionLocal, TrainingJob, TrainingLog, JobStatus
32
  import importlib
@@ -476,7 +476,7 @@ class TrainingService:
476
  progress_callback = ProgressCallback(total_steps)
477
 
478
  # Trainer
479
- trainer = Trainer(
480
  model=model,
481
  args=training_arguments,
482
  train_dataset=tokenized_train,
 
26
  )
27
  from datasets import load_dataset, Dataset
28
  from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
29
+ from trl import SFTTrainer # instead of from transformers import Trainer
30
  from app.config import settings
31
  from app.database import AsyncSessionLocal, TrainingJob, TrainingLog, JobStatus
32
  import importlib
 
476
  progress_callback = ProgressCallback(total_steps)
477
 
478
  # Trainer
479
+ trainer = SFTTrainer(
480
  model=model,
481
  args=training_arguments,
482
  train_dataset=tokenized_train,