rice-seed-classifier / COLAB_GUIDE.md
chaosbee997's picture
Upload COLAB_GUIDE.md
c8dcab6 verified

🚀 Google Colab 训练指南

本指南教你如何在 免费的 Google Colab T4 GPU 上训练籽粒分类模型。


第一步:打开 Colab

  1. 访问 colab.research.google.com
  2. 点击 文件 → 新建笔记本

第二步:切换到 GPU 环境

点击菜单栏:

Runtime → Change runtime type → Hardware accelerator → T4 GPU

然后点击 Save,并 重新连接(点击左上角重新连接按钮)。


第三步:逐 Cell 运行代码

将下面的代码按顺序复制到 Colab 的每个 Cell 中运行:

Cell 1:检查 GPU

!nvidia-smi

应该能看到 Tesla T4 信息。如果看不到,回到第二步确认 GPU 已启用。


Cell 2:安装依赖

!pip install -q transformers datasets accelerate evaluate pillow huggingface_hub
print("✅ 依赖安装完成")

Cell 3:导入库

import os
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    TrainingArguments,
    Trainer,
    DefaultDataCollator,
)
from PIL import Image
import evaluate
from huggingface_hub import notebook_login

Cell 4:登录 Hugging Face(必须!)

notebook_login()

运行后会弹出一个输入框,要求输入 Access Token

🔑 如何获取 Token?

  1. 打开 huggingface.co/settings/tokens
  2. 点击 New token
  3. 选择 Write 权限
  4. 复制 token 粘贴到 Colab 弹窗中

Cell 5:加载数据集

DATASET_NAME = "nateraw/rice-image-dataset"
print(f"Loading dataset: {DATASET_NAME}...")
ds = load_dataset(DATASET_NAME)
print(ds)

labels = ds["train"].features["label"].names
num_labels = len(labels)
print(f"\n📋 类别数量: {num_labels}")
print(f"📋 类别列表: {labels}")

Cell 6:划分训练/验证集

split_ds = ds["train"].train_test_split(
    test_size=0.15,
    stratify_by_column="label",
    seed=42
)
train_ds = split_ds["train"]
val_ds = split_ds["test"]

print(f"训练集: {len(train_ds)} 张")
print(f"验证集: {len(val_ds)} 张")

Cell 7:加载模型

MODEL_NAME = "microsoft/resnet-18"

processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for i, label in enumerate(labels)}

model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

print(f"✅ 模型加载完成")
print(f"参数总量: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")

Cell 8:数据预处理

def transform(example_batch):
    images = []
    for img in example_batch["image"]:
        if isinstance(img, Image.Image):
            if img.mode != "RGB":
                img = img.convert("RGB")
            images.append(img)
        else:
            img = Image.fromarray(np.array(img)).convert("RGB")
            images.append(img)
    inputs = processor(images, return_tensors="pt")
    inputs["labels"] = example_batch["label"]
    return inputs

train_ds.set_transform(transform)
val_ds.set_transform(transform)
print("✅ 数据预处理设置完成")

Cell 9:评估指标

accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    preds = np.argmax(predictions, axis=1)
    acc = accuracy.compute(predictions=preds, references=labels)
    f1_score = f1.compute(predictions=preds, references=labels, average="weighted")
    return {"accuracy": acc["accuracy"], "f1": f1_score["f1"]}

print("✅ 评估指标定义完成")

Cell 10:训练参数配置

OUTPUT_REPO = "chaosbee997/rice-seed-classifier"  # ← 改成你的仓库名

args = TrainingArguments(
    output_dir="/content/rice-seed-classifier",
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=5,
    warmup_ratio=0.1,
    logging_strategy="steps",
    logging_steps=50,
    logging_first_step=True,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    seed=42,
    push_to_hub=True,
    hub_model_id=OUTPUT_REPO,
    report_to="none",
)

print(f"训练配置:")
print(f"  输出仓库: {OUTPUT_REPO}")
print(f"  学习率: {args.learning_rate}")
print(f"  Batch size: {args.per_device_train_batch_size}")
print(f"  Epochs: {args.num_train_epochs}")

Cell 11:开始训练 🚀

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    data_collator=DefaultDataCollator(),
    tokenizer=processor,
)

print("开始训练...\n" + "="*50)
trainer.train()

⏱️ 训练时间:T4 GPU 上约 15-30 分钟(5 epochs,63K训练样本)

📊 预期结果:准确率 > 95%


Cell 12:评估并上传模型

metrics = trainer.evaluate()
print("\n最终评估结果:")
print("="*50)
for k, v in metrics.items():
    if isinstance(v, float):
        print(f"  {k}: {v:.4f}")
    else:
        print(f"  {k}: {v}")

print("\n正在上传到 Hugging Face Hub...")
trainer.push_to_hub()
print(f"\n✅ 上传成功!")
print(f"🔗 模型地址: https://huggingface.co/{OUTPUT_REPO}")

Cell 13:推理测试

from transformers import pipeline
import matplotlib.pyplot as plt

classifier = pipeline("image-classification", model=OUTPUT_REPO)

# 从数据集取一张图测试
test_image = ds["train"][100]["image"]

plt.figure(figsize=(5, 5))
plt.imshow(test_image)
plt.axis("off")
plt.title("测试图像")
plt.show()

results = classifier(test_image)
print("\n🔮 预测结果:")
for r in results:
    print(f"  {r['label']}: {r['score']*100:.2f}%")

⚡ 训练完成后

你的模型会自动上传到:

任何人都可以用一行代码加载你的模型:

from transformers import pipeline
classifier = pipeline("image-classification", model="chaosbee997/rice-seed-classifier")

🌽 扩展到花生/玉米/小麦等其他作物

如果你想训练包含 花生、玉米、小麦、水稻 的通用籽粒分类模型:

  1. 收集图像:每种作物建立一个文件夹,例如:

    crop_seeds/
    ├── peanut/
    │   ├── img1.jpg
    │   └── img2.jpg
    ├── corn/
    │   ├── img1.jpg
    │   └── img2.jpg
    ├── wheat/
    │   ├── img1.jpg
    │   └── img2.jpg
    └── rice/
        ├── img1.jpg
        └── img2.jpg
    
  2. 创建数据集

    from datasets import load_dataset
    ds = load_dataset("imagefolder", data_dir="/path/to/crop_seeds")
    ds.push_to_hub("yourname/crop-seeds")
    
  3. 修改 Cell 5:把 DATASET_NAME 改成你的数据集名称

  4. 重新运行 Cell 5-12,其余代码完全不用改!


🆘 常见问题

Q: 训练时报 CUDA out of memory
A: 减小 batch size:把 per_device_train_batch_size=64 改成 3216

Q: 没有 Hugging Face 账号?
A: 免费注册:huggingface.co/join

Q: 训练时间太长?
A: 可以减少 epoch 数(num_train_epochs=3),或换更小的模型(microsoft/resnet-34 更大,google/mobilenet_v2_1.0_224 更轻)。