File size: 3,914 Bytes
"""
Fine-tune a lightweight image classifier on rice grain (seed) images.

Dataset: nateraw/rice-image-dataset (5 rice varieties, 75K images)
Model: microsoft/resnet-18 (lightweight, ~11M params)

Usage (with Hugging Face Jobs):
    HF_MODEL_REPO=chaosbee997/rice-seed-classifier python train.py

Alternatively, submit the script via hf_jobs on an a10g-large or t4-small GPU.
"""
import os
import numpy as np
from datasets import load_dataset
from transformers import (
AutoImageProcessor,
AutoModelForImageClassification,
TrainingArguments,
Trainer,
)
from transformers import DefaultDataCollator
from PIL import Image
import evaluate
# ============= CONFIG =============
# Hub dataset to train on and the pretrained checkpoint to fine-tune.
DATASET_NAME = "nateraw/rice-image-dataset"
MODEL_NAME = "microsoft/resnet-18"
# Destination model repo; the HF_MODEL_REPO environment variable, when set,
# overrides the default below.
OUTPUT_REPO = os.getenv("HF_MODEL_REPO", "chaosbee997/rice-seed-classifier")
# ============= LOAD DATASET =============
print(f"Loading dataset: {DATASET_NAME}")
ds = load_dataset(DATASET_NAME)
print(ds)

# Carve a 15% validation set out of the "train" split, stratified on the
# "label" column and using a fixed seed.
split_ds = ds["train"].train_test_split(
    test_size=0.15,
    stratify_by_column="label",
    seed=42,
)
train_ds, val_ds = split_ds["train"], split_ds["test"]

# Class names come from the "label" feature; build both lookup directions
# (index -> name and name -> index) for the model config.
labels = ds["train"].features["label"].names
num_labels = len(labels)
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}

print(f"Classes ({num_labels}): {labels}")
print(f"Train: {len(train_ds)} | Val: {len(val_ds)}")
# ============= LOAD PROCESSOR & MODEL =============
print(f"Loading model: {MODEL_NAME}")
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

# Settings for the classification head: the number of rice classes and the
# label maps built from the dataset.
# NOTE(review): ignore_mismatched_sizes=True is presumably needed because the
# checkpoint's head was trained with a different number of classes — confirm.
classifier_head_kwargs = {
    "num_labels": num_labels,
    "id2label": id2label,
    "label2id": label2id,
    "ignore_mismatched_sizes": True,
}
model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME, **classifier_head_kwargs
)
# ============= PREPROCESS =============
# ============= PREPROCESS =============
def transform(example_batch):
    """Turn a batch of raw examples into model-ready inputs.

    Args:
        example_batch: Mapping with an ``"image"`` sequence and a parallel
            ``"label"`` sequence.

    Returns:
        The output of ``processor`` (called with ``return_tensors="pt"``)
        with the batch's labels stored under the ``"labels"`` key.
    """

    def _as_rgb(item):
        # Non-PIL inputs are routed through a numpy array to build a PIL image.
        if not isinstance(item, Image.Image):
            return Image.fromarray(np.array(item)).convert("RGB")
        # PIL images are converted only when they are not already RGB.
        return item if item.mode == "RGB" else item.convert("RGB")

    rgb_images = [_as_rgb(item) for item in example_batch["image"]]
    batch = processor(rgb_images, return_tensors="pt")
    batch["labels"] = example_batch["label"]
    return batch
# Register `transform` on both splits so examples are preprocessed when accessed.
for _split in (train_ds, val_ds):
    _split.set_transform(transform)

# ============= METRICS =============
# Metric objects are loaded once here and reused by compute_metrics().
accuracy, f1 = (evaluate.load(metric_name) for metric_name in ("accuracy", "f1"))
def compute_metrics(eval_pred):
    """Compute accuracy and weighted F1 for one evaluation pass.

    Args:
        eval_pred: Pair ``(predictions, labels)``. The predictions are
            arg-maxed over axis 1 — presumably per-class scores of shape
            ``(n_examples, n_classes)``; confirm against the Trainer output.

    Returns:
        Dict with the keys ``"accuracy"`` and ``"f1"``.
    """
    scores, references = eval_pred
    predicted_ids = np.argmax(scores, axis=1)
    return {
        "accuracy": accuracy.compute(
            predictions=predicted_ids, references=references
        )["accuracy"],
        "f1": f1.compute(
            predictions=predicted_ids, references=references, average="weighted"
        )["f1"],
    }
# ============= TRAINING ARGS =============
# Hyper-parameters and reporting/hub settings for the fine-tuning run.
args = TrainingArguments(
    output_dir="/tmp/rice-seed-classifier",
    # Keep the raw "image"/"label" columns: `transform` reads them per batch.
    remove_unused_columns=False,
    # `eval_strategy` is the current name of this option. The legacy
    # `evaluation_strategy` spelling is rejected by the transformers releases
    # that accept the trackio options used below.
    eval_strategy="epoch",
    # Must match the evaluation schedule for load_best_model_at_end to work.
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=5,
    warmup_ratio=0.1,
    logging_strategy="steps",
    logging_steps=50,
    logging_first_step=True,
    disable_tqdm=True,
    # Reload the checkpoint with the best validation accuracy after training.
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    seed=42,
    push_to_hub=True,
    hub_model_id=OUTPUT_REPO,
    report_to="trackio",
    run_name="rice_resnet18_lr5e-5_bs64",
    project="grain-classification",
    # The TRACKIO_SPACE_ID environment variable overrides the default Space.
    trackio_space_id=os.environ.get("TRACKIO_SPACE_ID", "chaosbee997/mlintern-grain"),
)
# ============= TRAINER =============
# Wire the model, data splits, metrics and collator together, then run
# training, a final evaluation, and the upload to the Hub.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    data_collator=DefaultDataCollator(),
    # `processing_class` replaces the deprecated `tokenizer` argument; the
    # object passed here is an image processor, not a tokenizer.
    processing_class=processor,
)

print("Starting training...")
trainer.train()

print("Evaluating...")
metrics = trainer.evaluate()
print(metrics)

print("Pushing to hub...")
trainer.push_to_hub()
print("Done!")
|