"""Fine-tune a lightweight image classifier on rice grain (seed) images.

Dataset: nateraw/rice-image-dataset (5 rice varieties, 75K images)
Model:   microsoft/resnet-18 (lightweight, ~11M params)

Usage (with Hugging Face Jobs):
    HF_MODEL_REPO=chaosbee997/rice-seed-classifier python train.py

Or submit via hf_jobs with a10g-large / t4-small GPU.
"""

import os

import evaluate
import numpy as np
from datasets import load_dataset
from PIL import Image
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    DefaultDataCollator,
    Trainer,
    TrainingArguments,
)

# ============= CONFIG =============
DATASET_NAME = "nateraw/rice-image-dataset"
MODEL_NAME = "microsoft/resnet-18"
OUTPUT_REPO = os.environ.get("HF_MODEL_REPO", "chaosbee997/rice-seed-classifier")


def main():
    """Run the full fine-tuning pipeline: load data, train, evaluate, push."""
    # ============= LOAD DATASET =============
    print(f"Loading dataset: {DATASET_NAME}")
    ds = load_dataset(DATASET_NAME)
    print(ds)

    # Stratified split keeps per-class proportions equal in train/val
    # (stratify_by_column requires "label" to be a ClassLabel feature,
    # which it is here — .names is read below).
    split_ds = ds["train"].train_test_split(
        test_size=0.15, stratify_by_column="label", seed=42
    )
    train_ds = split_ds["train"]
    val_ds = split_ds["test"]

    labels = ds["train"].features["label"].names
    num_labels = len(labels)
    label2id = {label: i for i, label in enumerate(labels)}
    id2label = {i: label for i, label in enumerate(labels)}
    print(f"Classes ({num_labels}): {labels}")
    print(f"Train: {len(train_ds)} | Val: {len(val_ds)}")

    # ============= LOAD PROCESSOR & MODEL =============
    print(f"Loading model: {MODEL_NAME}")
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(
        MODEL_NAME,
        num_labels=num_labels,
        id2label=id2label,
        label2id=label2id,
        # The hub checkpoint ships a 1000-class ImageNet head; swap in a
        # fresh num_labels-sized head instead of raising a size mismatch.
        ignore_mismatched_sizes=True,
    )

    # ============= PREPROCESS =============
    def transform(example_batch):
        """On-the-fly batch transform: PIL/array images -> RGB pixel tensors.

        Returns the processor output dict with a "labels" key added so the
        default collator can build training batches directly.
        """
        images = []
        for img in example_batch["image"]:
            # Some datasets yield raw arrays instead of PIL images.
            if not isinstance(img, Image.Image):
                img = Image.fromarray(np.array(img))
            if img.mode != "RGB":
                img = img.convert("RGB")
            images.append(img)
        inputs = processor(images, return_tensors="pt")
        inputs["labels"] = example_batch["label"]
        return inputs

    # set_transform applies lazily per batch — no full-dataset preprocessing.
    train_ds.set_transform(transform)
    val_ds.set_transform(transform)

    # ============= METRICS =============
    accuracy = evaluate.load("accuracy")
    f1 = evaluate.load("f1")

    def compute_metrics(eval_pred):
        """Compute accuracy and class-imbalance-aware weighted F1."""
        predictions, references = eval_pred
        preds = np.argmax(predictions, axis=1)
        acc = accuracy.compute(predictions=preds, references=references)
        f1_score = f1.compute(
            predictions=preds, references=references, average="weighted"
        )
        return {"accuracy": acc["accuracy"], "f1": f1_score["f1"]}

    # ============= TRAINING ARGS =============
    args = TrainingArguments(
        output_dir="/tmp/rice-seed-classifier",
        # Keep the raw "image"/"label" columns so set_transform sees them.
        remove_unused_columns=False,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=5e-5,
        per_device_train_batch_size=64,
        per_device_eval_batch_size=64,
        num_train_epochs=5,
        warmup_ratio=0.1,
        logging_strategy="steps",
        logging_steps=50,
        logging_first_step=True,
        disable_tqdm=True,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        seed=42,
        push_to_hub=True,
        hub_model_id=OUTPUT_REPO,
        report_to="trackio",
        run_name="rice_resnet18_lr5e-5_bs64",
        # NOTE(review): `project` and `trackio_space_id` are trackio-specific
        # TrainingArguments fields only present in recent transformers
        # releases — confirm the pinned transformers version supports them.
        project="grain-classification",
        trackio_space_id=os.environ.get(
            "TRACKIO_SPACE_ID", "chaosbee997/mlintern-grain"
        ),
    )

    # ============= TRAINER =============
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        compute_metrics=compute_metrics,
        data_collator=DefaultDataCollator(),
        # Passing the image processor here makes Trainer save/push it with
        # the model. (`tokenizer=` is the older spelling of
        # `processing_class=`; kept for compatibility with the version in use.)
        tokenizer=processor,
    )

    print("Starting training...")
    trainer.train()

    print("Evaluating...")
    metrics = trainer.evaluate()
    print(metrics)

    print("Pushing to hub...")
    trainer.push_to_hub()
    print("Done!")


if __name__ == "__main__":
    main()