🚀 Google Colab 训练指南
本指南教你如何在 免费的 Google Colab T4 GPU 上训练籽粒分类模型。
第一步:打开 Colab
- 访问 colab.research.google.com
- 点击 文件 → 新建笔记本
第二步:切换到 GPU 环境
点击菜单栏:
Runtime → Change runtime type → Hardware accelerator → T4 GPU
然后点击 Save,并 重新连接(点击左上角重新连接按钮)。
第三步:逐 Cell 运行代码
将下面的代码按顺序复制到 Colab 的每个 Cell 中运行:
Cell 1:检查 GPU
!nvidia-smi
应该能看到
Tesla T4信息。如果看不到,回到第二步确认 GPU 已启用。
Cell 2:安装依赖
!pip install -q transformers datasets accelerate evaluate pillow huggingface_hub
print("✅ 依赖安装完成")
Cell 3:导入库
import os
import numpy as np
from datasets import load_dataset
from transformers import (
AutoImageProcessor,
AutoModelForImageClassification,
TrainingArguments,
Trainer,
DefaultDataCollator,
)
from PIL import Image
import evaluate
from huggingface_hub import notebook_login
Cell 4:登录 Hugging Face(必须!)
notebook_login()
运行后会弹出一个输入框,要求输入 Access Token。
🔑 如何获取 Token?
- 打开 huggingface.co/settings/tokens
- 点击
New token- 选择 Write 权限
- 复制 token 粘贴到 Colab 弹窗中
Cell 5:加载数据集
DATASET_NAME = "nateraw/rice-image-dataset"
print(f"Loading dataset: {DATASET_NAME}...")
ds = load_dataset(DATASET_NAME)
print(ds)
labels = ds["train"].features["label"].names
num_labels = len(labels)
print(f"\n📋 类别数量: {num_labels}")
print(f"📋 类别列表: {labels}")
Cell 6:划分训练/验证集
split_ds = ds["train"].train_test_split(
test_size=0.15,
stratify_by_column="label",
seed=42
)
train_ds = split_ds["train"]
val_ds = split_ds["test"]
print(f"训练集: {len(train_ds)} 张")
print(f"验证集: {len(val_ds)} 张")
Cell 7:加载模型
MODEL_NAME = "microsoft/resnet-18"
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for i, label in enumerate(labels)}
model = AutoModelForImageClassification.from_pretrained(
MODEL_NAME,
num_labels=num_labels,
id2label=id2label,
label2id=label2id,
ignore_mismatched_sizes=True,
)
print(f"✅ 模型加载完成")
print(f"参数总量: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
Cell 8:数据预处理
def transform(example_batch):
images = []
for img in example_batch["image"]:
if isinstance(img, Image.Image):
if img.mode != "RGB":
img = img.convert("RGB")
images.append(img)
else:
img = Image.fromarray(np.array(img)).convert("RGB")
images.append(img)
inputs = processor(images, return_tensors="pt")
inputs["labels"] = example_batch["label"]
return inputs
train_ds.set_transform(transform)
val_ds.set_transform(transform)
print("✅ 数据预处理设置完成")
Cell 9:评估指标
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")
def compute_metrics(eval_pred):
predictions, labels = eval_pred
preds = np.argmax(predictions, axis=1)
acc = accuracy.compute(predictions=preds, references=labels)
f1_score = f1.compute(predictions=preds, references=labels, average="weighted")
return {"accuracy": acc["accuracy"], "f1": f1_score["f1"]}
print("✅ 评估指标定义完成")
Cell 10:训练参数配置
OUTPUT_REPO = "chaosbee997/rice-seed-classifier" # ← 改成你的仓库名
args = TrainingArguments(
output_dir="/content/rice-seed-classifier",
remove_unused_columns=False,
evaluation_strategy="epoch",
save_strategy="epoch",
learning_rate=5e-5,
per_device_train_batch_size=64,
per_device_eval_batch_size=64,
num_train_epochs=5,
warmup_ratio=0.1,
logging_strategy="steps",
logging_steps=50,
logging_first_step=True,
load_best_model_at_end=True,
metric_for_best_model="accuracy",
seed=42,
push_to_hub=True,
hub_model_id=OUTPUT_REPO,
report_to="none",
)
print(f"训练配置:")
print(f" 输出仓库: {OUTPUT_REPO}")
print(f" 学习率: {args.learning_rate}")
print(f" Batch size: {args.per_device_train_batch_size}")
print(f" Epochs: {args.num_train_epochs}")
Cell 11:开始训练 🚀
trainer = Trainer(
model=model,
args=args,
train_dataset=train_ds,
eval_dataset=val_ds,
compute_metrics=compute_metrics,
data_collator=DefaultDataCollator(),
tokenizer=processor,
)
print("开始训练...\n" + "="*50)
trainer.train()
⏱️ 训练时间:T4 GPU 上约 15-30 分钟(5 epochs,63K训练样本)
📊 预期结果:准确率 > 95%
Cell 12:评估并上传模型
metrics = trainer.evaluate()
print("\n最终评估结果:")
print("="*50)
for k, v in metrics.items():
if isinstance(v, float):
print(f" {k}: {v:.4f}")
else:
print(f" {k}: {v}")
print("\n正在上传到 Hugging Face Hub...")
trainer.push_to_hub()
print(f"\n✅ 上传成功!")
print(f"🔗 模型地址: https://huggingface.co/{OUTPUT_REPO}")
Cell 13:推理测试
from transformers import pipeline
import matplotlib.pyplot as plt
classifier = pipeline("image-classification", model=OUTPUT_REPO)
# 从数据集取一张图测试
test_image = ds["train"][100]["image"]
plt.figure(figsize=(5, 5))
plt.imshow(test_image)
plt.axis("off")
plt.title("测试图像")
plt.show()
results = classifier(test_image)
print("\n🔮 预测结果:")
for r in results:
print(f" {r['label']}: {r['score']*100:.2f}%")
⚡ 训练完成后
你的模型会自动上传到:
任何人都可以用一行代码加载你的模型:
from transformers import pipeline
classifier = pipeline("image-classification", model="chaosbee997/rice-seed-classifier")
🌽 扩展到花生/玉米/小麦等其他作物
如果你想训练包含 花生、玉米、小麦、水稻 的通用籽粒分类模型:
收集图像:每种作物建立一个文件夹,例如:
crop_seeds/ ├── peanut/ │ ├── img1.jpg │ └── img2.jpg ├── corn/ │ ├── img1.jpg │ └── img2.jpg ├── wheat/ │ ├── img1.jpg │ └── img2.jpg └── rice/ ├── img1.jpg └── img2.jpg创建数据集:
from datasets import load_dataset ds = load_dataset("imagefolder", data_dir="/path/to/crop_seeds") ds.push_to_hub("yourname/crop-seeds")修改 Cell 5:把
DATASET_NAME改成你的数据集名称重新运行 Cell 5-12,其余代码完全不用改!
🆘 常见问题
Q: 训练时报 CUDA out of memory?
A: 减小 batch size:把 per_device_train_batch_size=64 改成 32 或 16。
Q: 没有 Hugging Face 账号?
A: 免费注册:huggingface.co/join
Q: 训练时间太长?
A: 可以减少 epoch 数(num_train_epochs=3),或换更小的模型(microsoft/resnet-34 更大,google/mobilenet_v2_1.0_224 更轻)。