| """ |
| Fine-tune a lightweight image classifier on rice grain (seed) images. |
| Dataset: nateraw/rice-image-dataset (5 rice varieties, 75K images) |
| Model: microsoft/resnet-18 (lightweight, ~11M params) |
| |
| Usage (with Hugging Face Jobs): |
| HF_MODEL_REPO=chaosbee997/rice-seed-classifier python train.py |
| |
| Or submit via hf_jobs with a10g-large / t4-small GPU. |
| """ |
| import os |
| import numpy as np |
| from datasets import load_dataset |
| from transformers import ( |
| AutoImageProcessor, |
| AutoModelForImageClassification, |
| TrainingArguments, |
| Trainer, |
| ) |
| from transformers import DefaultDataCollator |
| from PIL import Image |
| import evaluate |
|
|
| |
# Pinned dataset / base-model identifiers. The destination Hub repo can be
# overridden at launch time through the HF_MODEL_REPO environment variable.
DATASET_NAME = "nateraw/rice-image-dataset"
MODEL_NAME = "microsoft/resnet-18"
OUTPUT_REPO = os.getenv("HF_MODEL_REPO", "chaosbee997/rice-seed-classifier")
|
|
| |
print(f"Loading dataset: {DATASET_NAME}")
ds = load_dataset(DATASET_NAME)
print(ds)

# Hold out 15% of the training images for validation; stratified sampling on
# the label column keeps the per-class proportions identical in both splits.
split_ds = ds["train"].train_test_split(test_size=0.15, stratify_by_column="label", seed=42)
train_ds = split_ds["train"]
val_ds = split_ds["test"]

# Label <-> id mappings derived from the dataset's ClassLabel feature;
# these get baked into the model config below.
labels = ds["train"].features["label"].names
num_labels = len(labels)
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}

print(f"Classes ({num_labels}): {labels}")
print(f"Train: {len(train_ds)} | Val: {len(val_ds)}")
|
|
| |
print(f"Loading model: {MODEL_NAME}")
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

# Replace the 1000-class ImageNet head with a fresh head for our classes;
# ignore_mismatched_sizes lets the new classifier weights be randomly
# initialized instead of erroring on the shape mismatch.
head_kwargs = dict(
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)
model = AutoModelForImageClassification.from_pretrained(MODEL_NAME, **head_kwargs)
|
|
| |
def transform(example_batch):
    """Batch preprocessor applied lazily via `set_transform`.

    Normalizes every entry of ``example_batch["image"]`` to an RGB PIL
    image, runs the model's image processor, and attaches the integer
    class labels under the ``"labels"`` key expected by the Trainer.
    """
    pil_images = []
    for raw in example_batch["image"]:
        if not isinstance(raw, Image.Image):
            # Non-PIL entries (e.g. arrays) are routed through numpy first.
            raw = Image.fromarray(np.array(raw)).convert("RGB")
        elif raw.mode != "RGB":
            # Covers grayscale / palette / RGBA images.
            raw = raw.convert("RGB")
        pil_images.append(raw)
    batch = processor(pil_images, return_tensors="pt")
    batch["labels"] = example_batch["label"]
    return batch
|
|
# Attach the on-the-fly preprocessing to both splits; it runs per batch at
# access time, so no decoded tensors are materialized up front.
for split in (train_ds, val_ds):
    split.set_transform(transform)
|
|
| |
# Hub-hosted metric implementations, evaluated once per epoch below.
accuracy, f1 = evaluate.load("accuracy"), evaluate.load("f1")
|
|
def compute_metrics(eval_pred):
    """Compute accuracy and weighted F1 for a Trainer evaluation pass.

    Args:
        eval_pred: ``(predictions, references)`` pair supplied by the
            Trainer. Predictions are per-class logits; some model heads
            return a tuple whose first element holds the logits, which is
            handled here.

    Returns:
        dict with float ``"accuracy"`` and weighted ``"f1"`` entries.
    """
    predictions, references = eval_pred  # renamed: avoid shadowing module-level `labels`
    if isinstance(predictions, tuple):
        # Some models emit (logits, *extras); keep only the logits.
        predictions = predictions[0]
    preds = np.argmax(predictions, axis=1)
    acc = accuracy.compute(predictions=preds, references=references)
    f1_score = f1.compute(predictions=preds, references=references, average="weighted")
    return {"accuracy": acc["accuracy"], "f1": f1_score["f1"]}
|
|
| |
# Trainer configuration.
# NOTE(review): `evaluation_strategy` was renamed to `eval_strategy` in newer
# transformers releases, and `project` / `trackio_space_id` only exist on
# versions with trackio integration — confirm against the pinned transformers
# version before launching the job.
args = TrainingArguments(
    output_dir="/tmp/rice-seed-classifier",
    remove_unused_columns=False,  # keep the raw "image" column for set_transform
    evaluation_strategy="epoch",
    save_strategy="epoch",  # must match eval strategy for load_best_model_at_end
    learning_rate=5e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=5,
    warmup_ratio=0.1,  # linear LR warmup over the first 10% of steps
    logging_strategy="steps",
    logging_steps=50,
    logging_first_step=True,
    disable_tqdm=True,  # cleaner logs on non-interactive (Jobs) runners
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",  # key produced by compute_metrics
    seed=42,
    push_to_hub=True,
    hub_model_id=OUTPUT_REPO,
    report_to="trackio",
    run_name="rice_resnet18_lr5e-5_bs64",
    project="grain-classification",
    trackio_space_id=os.environ.get("TRACKIO_SPACE_ID", "chaosbee997/mlintern-grain"),
)
|
|
| |
# Assemble the Trainer. The image processor is passed so it gets saved and
# pushed alongside the model weights.
# NOTE(review): `tokenizer=` is deprecated in recent transformers in favor of
# `processing_class=` — confirm against the pinned version.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    data_collator=DefaultDataCollator(),  # batches are already tensors from `transform`
    tokenizer=processor,
)
|
|
print("Starting training...")
trainer.train()


print("Evaluating...")
# Final pass runs with the best checkpoint restored (load_best_model_at_end).
metrics = trainer.evaluate()
print(metrics)


print("Pushing to hub...")
# Uploads weights, processor config, and an auto-generated model card to OUTPUT_REPO.
trainer.push_to_hub()
print("Done!")
|
|