| import os
|
| import torch
|
| from datasets import load_dataset, ClassLabel, Image
|
| from transformers import (
|
| ViTImageProcessor,
|
| ViTForImageClassification,
|
| TrainingArguments,
|
| Trainer,
|
| DefaultDataCollator,
|
| )
|
| import evaluate
|
| from torchvision.transforms import (
|
| CenterCrop,
|
| Compose,
|
| Normalize,
|
| RandomRotation,
|
| RandomResizedCrop,
|
| RandomHorizontalFlip,
|
| RandomAdjustSharpness,
|
| Resize,
|
| ToTensor,
|
| )
|
| import numpy as np
|
|
|
|
|
# Pretrained ViT checkpoint to fine-tune (patch size 16, 224x224 input).
MODEL_NAME = "google/vit-base-patch16-224"

# Root folder of the imagefolder-style dataset ("train"/"test" subfolders,
# one subdirectory per class).
DATASET_DIR = "./dataset"

# Where the fine-tuned model and image processor are saved.
OUTPUT_DIR = "./model"

# Per-device batch size used for both training and evaluation.
BATCH_SIZE = 16

# Number of fine-tuning epochs.
NUM_EPOCHS = 3

# Peak learning rate (warmed up over the first 10% of steps; see warmup_ratio).
LEARNING_RATE = 2e-5
|
|
|
def main():
    """Fine-tune a ViT image classifier on an imagefolder-style dataset.

    Expects ``DATASET_DIR`` to contain a ``train`` subfolder (and optionally
    ``test``) with one subdirectory per class label, e.g.
    ``./dataset/train/REAL``. Saves the trained model and the image
    processor to ``OUTPUT_DIR``.
    """
    print("Loading dataset...")

    # The imagefolder builder needs at least a train split; bail out early
    # with a helpful message instead of failing deep inside load_dataset.
    # (The original existence check built a data_files dict that was never
    # used, and accepted a test-only layout that would later crash on
    # dataset["train"].)
    if not os.path.exists(os.path.join(DATASET_DIR, "train")):
        print(f"Error: No data found in {DATASET_DIR}. Please organize data in 'train' and 'test' folders.")
        print("Expected structure: ./dataset/train/REAL, ./dataset/train/FAKE, etc.")
        return

    dataset = load_dataset("imagefolder", data_dir=DATASET_DIR)
    has_eval = "test" in dataset

    # Class names are inferred from the subfolder names by the builder.
    labels = dataset["train"].features["label"].names
    id2label = {str(i): c for i, c in enumerate(labels)}
    label2id = {c: str(i) for i, c in enumerate(labels)}
    print(f"Labels found: {labels}")

    # Reuse the checkpoint's normalization statistics and input resolution so
    # our inputs match what the backbone was pretrained on.
    processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
    normalize = Normalize(mean=processor.image_mean, std=processor.image_std)
    size = processor.size["height"]

    # Training-time augmentation; evaluation uses a deterministic
    # resize + center-crop so metrics are reproducible.
    _train_transforms = Compose([
        RandomResizedCrop(size),
        RandomHorizontalFlip(),
        RandomAdjustSharpness(2),
        ToTensor(),
        normalize,
    ])
    _val_transforms = Compose([
        Resize(size),
        CenterCrop(size),
        ToTensor(),
        normalize,
    ])

    def train_transforms(examples):
        # Convert to RGB so grayscale/RGBA images fit the 3-channel model input.
        examples["pixel_values"] = [_train_transforms(img.convert("RGB")) for img in examples["image"]]
        return examples

    def val_transforms(examples):
        examples["pixel_values"] = [_val_transforms(img.convert("RGB")) for img in examples["image"]]
        return examples

    print("Applying transforms...")
    # set_transform applies the function lazily per accessed batch, so the
    # whole dataset is never materialized as tensors at once.
    dataset["train"].set_transform(train_transforms)
    if has_eval:
        dataset["test"].set_transform(val_transforms)

    print(f"Loading model {MODEL_NAME}...")
    model = ViTForImageClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(labels),
        id2label=id2label,
        label2id=label2id,
        # The pretrained head has a different class count; allow it to be
        # re-initialized for our label set.
        ignore_mismatched_sizes=True,
    )

    metric = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        # Logits -> predicted class index per example.
        predictions = np.argmax(eval_pred.predictions, axis=1)
        return metric.compute(predictions=predictions, references=eval_pred.label_ids)

    # Without a test split, disable per-epoch evaluation and best-model
    # tracking: Trainer raises if evaluation is requested with no
    # eval_dataset, so the previous unconditional settings crashed for
    # train-only datasets.
    args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        # Keep the raw "image" column so the on-the-fly transform can read it.
        remove_unused_columns=False,
        evaluation_strategy="epoch" if has_eval else "no",
        save_strategy="epoch",
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=NUM_EPOCHS,
        warmup_ratio=0.1,
        logging_steps=10,
        load_best_model_at_end=has_eval,
        metric_for_best_model="accuracy",
        push_to_hub=False,
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"] if has_eval else None,
        tokenizer=processor,  # saved alongside checkpoints for easy reload
        data_collator=DefaultDataCollator(),
        compute_metrics=compute_metrics,
    )

    print("Starting training...")
    trainer.train()

    print(f"Saving model to {OUTPUT_DIR}...")
    trainer.save_model(OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)
    print("Done!")
|
|
|
# Run training only when executed as a script, not on import.
if __name__ == "__main__":

    main()
|
|
|