| from transformers import ( |
| ViTForImageClassification, |
| ViTImageProcessor, |
| TrainingArguments, |
| Trainer, |
| ) |
| from datasets import load_dataset |
| from .utils import ROOT_DIR |
|
|
|
|
def train(
    train_size: int = 2000,
    eval_size: int = 500,
    model_name: str = "google/vit-base-patch16-224",
    output_dir: str = "./results",
    num_epochs: int = 3,
):
    """Fine-tune a ViT image classifier on a subset of MNIST.

    The previously hard-coded constants are now keyword parameters with
    defaults matching the original values, so ``train()`` with no
    arguments behaves exactly as before.

    Args:
        train_size: Number of training examples to keep (subset for a
            quick demonstration run).
        eval_size: Number of test examples to keep for evaluation.
        model_name: Hugging Face checkpoint to fine-tune from.
        output_dir: Directory for Trainer checkpoints and logs.
        num_epochs: Number of training epochs.

    Side effects:
        Downloads the dataset and model, trains, writes checkpoints to
        ``output_dir``, and saves the final model and processor to
        ``ROOT_DIR``.
    """
    # MNIST ships with an integer "label" column; Trainer expects "labels".
    dataset = load_dataset("mnist")
    dataset = dataset.rename_column("label", "labels")

    # Subsample both splits so the demo finishes quickly.
    dataset["train"] = dataset["train"].select(range(train_size))
    dataset["test"] = dataset["test"].select(range(eval_size))

    processor = ViTImageProcessor.from_pretrained(model_name)

    def transform(examples):
        # ViT expects 3-channel input; MNIST images are grayscale PIL
        # images, so convert to RGB before resizing/normalizing.
        images = [img.convert("RGB") for img in examples["image"]]
        inputs = processor(images=images, return_tensors="pt")
        inputs["labels"] = examples["labels"]
        return inputs

    # Applied lazily on access, so raw images stay on disk until batched.
    dataset.set_transform(transform)

    model = ViTForImageClassification.from_pretrained(
        model_name,
        num_labels=10,  # MNIST digits 0-9
        id2label={str(i): str(i) for i in range(10)},
        label2id={str(i): i for i in range(10)},
        # The checkpoint has a 1000-class ImageNet head; re-initialize
        # the classifier for 10 classes instead of erroring on mismatch.
        ignore_mismatched_sizes=True,
    )

    training_args = TrainingArguments(
        output_dir=output_dir,
        # Keep the "image" column so the on-the-fly transform can read it;
        # the default would drop columns not in the model signature.
        remove_unused_columns=False,
        per_device_train_batch_size=16,
        eval_strategy="steps",
        num_train_epochs=num_epochs,
        fp16=False,
        save_steps=500,
        eval_steps=500,
        logging_steps=100,
        learning_rate=2e-4,
        push_to_hub=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
    )

    trainer.train()

    # Persist the fine-tuned weights together with the matching
    # preprocessor config so the pair can be reloaded from ROOT_DIR.
    model.save_pretrained(ROOT_DIR)
    processor.save_pretrained(ROOT_DIR)
|
|