"""Colab-ready training script for rice seed/grain classification.

Quick start in Google Colab:
  1. Go to https://colab.research.google.com
  2. Create a new notebook
  3. Copy each "cell" below into a Colab code cell and run sequentially.

Hardware: Runtime → Change runtime type → T4 GPU (recommended)
Expected time: ~15-30 minutes for 5 epochs on T4.
"""

# ===================== CELL 1: Check the GPU =====================
# !nvidia-smi

# ===================== CELL 2: Install dependencies =====================
# !pip install -q transformers datasets accelerate evaluate pillow huggingface_hub
# import os
# os.kill(os.getpid(), 9)  # Restart the runtime after installing (optional; recommended for transformers)

# ===================== CELL 3: Imports =====================
import inspect
import os

import numpy as np
from datasets import load_dataset
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    TrainingArguments,
    Trainer,
    DefaultDataCollator,
)
from PIL import Image
import evaluate
from huggingface_hub import notebook_login

# ===================== CELL 4: Log in to Hugging Face (token needs Write permission) =====================
# notebook_login()  # interactive prompt for entering the token

# ===================== CELL 5: Load the dataset =====================
DATASET_NAME = "nateraw/rice-image-dataset"

print(f"Loading dataset: {DATASET_NAME}...")
ds = load_dataset(DATASET_NAME)
print(ds)

# Class names come from the "label" feature of the train split.
labels = ds["train"].features["label"].names
num_labels = len(labels)
print(f"Classes ({num_labels}): {labels}")

# ===================== CELL 6: Train/validation split =====================
# Hold out 15% of the train split, stratified on the label column.
split_ds = ds["train"].train_test_split(
    test_size=0.15,
    stratify_by_column="label",
    seed=42,
)
train_ds = split_ds["train"]
val_ds = split_ds["test"]
print(f"Train: {len(train_ds)} | Val: {len(val_ds)}")

# ===================== CELL 7: Load the model =====================
MODEL_NAME = "microsoft/resnet-18"

processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for i, label in enumerate(labels)}

model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    # The pretrained head has a different number of classes than this task,
    # so allow the classifier weights to be re-initialised.
    ignore_mismatched_sizes=True,
)
print(f"Model loaded: {MODEL_NAME} ({sum(p.numel() for p in model.parameters())/1e6:.1f}M params)")


# ===================== CELL 8: Preprocessing =====================
def _to_rgb(img):
    """Return *img* as an RGB PIL image, converting from other types/modes if needed."""
    if not isinstance(img, Image.Image):
        # Non-PIL inputs are routed through a NumPy array first.
        return Image.fromarray(np.array(img)).convert("RGB")
    return img if img.mode == "RGB" else img.convert("RGB")


def transform(example_batch):
    """Turn a batch with "image" and "label" columns into model inputs.

    Every image is coerced to RGB, run through the image processor
    (``return_tensors="pt"``), and the batch's labels are attached under
    the ``"labels"`` key.
    """
    images = [_to_rgb(img) for img in example_batch["image"]]
    inputs = processor(images, return_tensors="pt")
    inputs["labels"] = example_batch["label"]
    return inputs


# Apply the transform lazily whenever rows are accessed.
train_ds.set_transform(transform)
val_ds.set_transform(transform)

# ===================== CELL 9: Evaluation metrics =====================
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")


def compute_metrics(eval_pred):
    """Compute accuracy and weighted F1 from a ``(predictions, labels)`` pair."""
    predictions, references = eval_pred
    preds = np.argmax(predictions, axis=1)
    acc = accuracy.compute(predictions=preds, references=references)
    f1_score = f1.compute(predictions=preds, references=references, average="weighted")
    return {"accuracy": acc["accuracy"], "f1": f1_score["f1"]}


# ===================== CELL 10: Training arguments =====================
OUTPUT_REPO = "chaosbee997/rice-seed-classifier"  # change to your own repo name

# Newer transformers releases renamed `evaluation_strategy` to `eval_strategy`.
# Cell 2 installs an unpinned version, so use whichever name this install accepts.
_eval_strategy_key = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments.__init__).parameters
    else "evaluation_strategy"
)

args = TrainingArguments(
    output_dir="/content/rice-seed-classifier",
    # Keep the raw "image"/"label" columns; `transform` needs them.
    remove_unused_columns=False,
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=5,
    warmup_ratio=0.1,
    logging_strategy="steps",
    logging_steps=50,
    logging_first_step=True,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    seed=42,
    push_to_hub=True,
    hub_model_id=OUTPUT_REPO,
    report_to="none",
    **{_eval_strategy_key: "epoch"},
)

# ===================== CELL 11: Train =====================
# Same story for Trainer: `tokenizer` was superseded by `processing_class`.
_processor_key = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    data_collator=DefaultDataCollator(),
    **{_processor_key: processor},
)

print("Starting training...")
trainer.train()

# ===================== CELL 12: Evaluate + upload =====================
metrics = trainer.evaluate()
print("Evaluation:", metrics)

trainer.push_to_hub()
print(f"Model uploaded to: https://huggingface.co/{OUTPUT_REPO}")

# ===================== CELL 13: Inference test =====================
from transformers import pipeline
import matplotlib.pyplot as plt

classifier = pipeline("image-classification", model=OUTPUT_REPO)

# `ds["train"]` has no transform attached, so this yields the raw image.
test_image = ds["train"][100]["image"]

plt.figure(figsize=(5,5))
plt.imshow(test_image)
plt.axis("off")
plt.show()

results = classifier(test_image)
for r in results:
    print(f" {r['label']}: {r['score']*100:.2f}%")