# Uploaded by chaosbee997 — "Upload train.py" (commit c8d1295, verified)
"""
Fine-tune a lightweight image classifier on rice grain (seed) images.

Dataset: nateraw/rice-image-dataset (5 rice varieties, 75K images)
Model:   microsoft/resnet-18 (lightweight, ~11M params)

Usage (with Hugging Face Jobs):
    HF_MODEL_REPO=chaosbee997/rice-seed-classifier python train.py

Alternatively, submit it via hf_jobs on an a10g-large or t4-small GPU.
Set TRACKIO_SPACE_ID to log metrics to a different Trackio Space.
"""
import os
import numpy as np
from datasets import load_dataset
from transformers import (
AutoImageProcessor,
AutoModelForImageClassification,
TrainingArguments,
Trainer,
)
from transformers import DefaultDataCollator
from PIL import Image
import evaluate
# ============= CONFIG =============
# Hugging Face Hub identifiers: the training data and the pretrained base checkpoint.
DATASET_NAME = "nateraw/rice-image-dataset"
MODEL_NAME = "microsoft/resnet-18"
# Hub repo that receives the fine-tuned model; override with the HF_MODEL_REPO env var.
OUTPUT_REPO = os.getenv("HF_MODEL_REPO", "chaosbee997/rice-seed-classifier")
# ============= LOAD DATASET =============
print(f"Loading dataset: {DATASET_NAME}")
ds = load_dataset(DATASET_NAME)
print(ds)

# Carve a validation set (15%) out of the "train" split. Stratifying on "label"
# keeps class proportions the same in both parts; the fixed seed makes it reproducible.
split_ds = ds["train"].train_test_split(
    test_size=0.15,
    stratify_by_column="label",
    seed=42,
)
train_ds, val_ds = split_ds["train"], split_ds["test"]

# Class names come from the label feature's `.names`; a name's position is its integer id.
labels = ds["train"].features["label"].names
num_labels = len(labels)
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}
print(f"Classes ({num_labels}): {labels}")
print(f"Train: {len(train_ds)} | Val: {len(val_ds)}")
# ============= LOAD PROCESSOR & MODEL =============
print(f"Loading model: {MODEL_NAME}")
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

# Classification-head settings: one output per rice class plus readable id<->name maps.
_head_kwargs = {
    "num_labels": num_labels,
    "id2label": id2label,
    "label2id": label2id,
}
# The pretrained head presumably has a different output size than num_labels;
# ignore_mismatched_sizes lets from_pretrained re-initialise mismatched weights
# instead of raising.
model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME,
    ignore_mismatched_sizes=True,
    **_head_kwargs,
)
# ============= PREPROCESS =============
def _as_rgb(img):
    """Return *img* as an RGB PIL image.

    PIL images already in RGB mode are returned as-is (same object); other PIL
    images are converted. Any non-PIL value is first turned into an array with
    ``np.array`` and wrapped with ``Image.fromarray``, then converted to RGB.
    """
    if not isinstance(img, Image.Image):
        return Image.fromarray(np.array(img)).convert("RGB")
    return img if img.mode == "RGB" else img.convert("RGB")


def transform(example_batch):
    """Convert a batch of raw examples into model inputs.

    Args:
        example_batch: batch dict with an ``"image"`` column and a ``"label"`` column.

    Returns:
        The image processor's output (PyTorch tensors, ``return_tensors="pt"``)
        with an extra ``"labels"`` key holding the batch's labels unchanged.
    """
    rgb_images = [_as_rgb(img) for img in example_batch["image"]]
    inputs = processor(rgb_images, return_tensors="pt")
    inputs["labels"] = example_batch["label"]
    return inputs


# set_transform applies `transform` on the fly whenever rows are accessed,
# so the processed tensors are never materialised for the whole dataset.
train_ds.set_transform(transform)
val_ds.set_transform(transform)
# ============= METRICS =============
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")


def compute_metrics(eval_pred):
    """Compute accuracy and weighted F1 for the Trainer's evaluation loop.

    Args:
        eval_pred: ``(predictions, label_ids)`` pair supplied by ``Trainer``.
            ``predictions`` holds per-class scores; the predicted class is the
            argmax over axis 1. When the model emits several outputs,
            ``predictions`` may be a tuple, and only its first element is used.

    Returns:
        dict with ``"accuracy"`` and ``"f1"``. F1 is averaged over classes,
        weighted by each class's support.
    """
    # Named label_ids so it does not shadow the module-level `labels` (class names).
    predictions, label_ids = eval_pred
    # Trainer hands over a tuple when the model returns more than one output array;
    # presumably the logits come first — confirm if the model/outputs change.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    preds = np.argmax(predictions, axis=1)
    acc = accuracy.compute(predictions=preds, references=label_ids)
    f1_score = f1.compute(predictions=preds, references=label_ids, average="weighted")
    return {"accuracy": acc["accuracy"], "f1": f1_score["f1"]}
# ============= TRAINING ARGS =============
args = TrainingArguments(
    output_dir="/tmp/rice-seed-classifier",
    # Keep the raw "image"/"label" columns: the on-the-fly transform reads them,
    # and they are not arguments of the model's forward(), so the Trainer would
    # otherwise drop them.
    remove_unused_columns=False,
    # `evaluation_strategy` was renamed `eval_strategy` and the old name was removed
    # in transformers 4.46; passing it raises TypeError on the versions that support
    # the trackio arguments below.
    eval_strategy="epoch",
    # Must match eval_strategy for load_best_model_at_end to work.
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=5,
    warmup_ratio=0.1,
    logging_strategy="steps",
    logging_steps=50,
    logging_first_step=True,
    disable_tqdm=True,
    # Reload the checkpoint with the best validation accuracy when training ends.
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    seed=42,
    push_to_hub=True,
    hub_model_id=OUTPUT_REPO,
    # Experiment tracking with Trackio; the Space can be overridden via TRACKIO_SPACE_ID.
    report_to="trackio",
    run_name="rice_resnet18_lr5e-5_bs64",
    project="grain-classification",
    trackio_space_id=os.environ.get("TRACKIO_SPACE_ID", "chaosbee997/mlintern-grain"),
)
# ============= TRAINER =============
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    data_collator=DefaultDataCollator(),
    # `tokenizer=` is deprecated on Trainer (since transformers 4.46) in favour of
    # `processing_class=`. The image processor passed here is saved with the model,
    # so the pushed repo can preprocess inputs on its own.
    processing_class=processor,
)
print("Starting training...")
trainer.train()
print("Evaluating...")
# args sets load_best_model_at_end=True, so this scores the reloaded best checkpoint.
metrics = trainer.evaluate()
print(metrics)
print("Pushing to hub...")
# Final explicit push (model card + latest weights) to args.hub_model_id.
trainer.push_to_hub()
print("Done!")