# 🚀 Google Colab Training Guide

This guide shows you how to train a seed classification model on a **free Google Colab T4 GPU**.

---
## Step 1: Open Colab

1. Go to [colab.research.google.com](https://colab.research.google.com)
2. Click **File → New notebook**

---
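## Step 2: Switch to a GPU runtime

In the menu bar, choose:
```
Runtime → Change runtime type → Hardware accelerator → T4 GPU
```
Then click **Save** and make sure the notebook is **connected** to the new runtime (if it is not, click the **Connect** button in the top-right corner).

---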
## Step 3: Run the code cell by cell

Copy each code block below into its own Colab cell and run the cells in order.

### Cell 1: Check the GPU
```python
!nvidia-smi
```
> You should see `Tesla T4` in the output. If not, go back to Step 2 and confirm that the GPU is enabled.

---
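### Cell 2: Install dependencies
```python
# The cells below use eval_strategy and processing_class, which need a recent
# transformers release (4.46 or later); if Colab's preinstalled version is older, add -U
!pip install -q transformers datasets accelerate evaluate pillow huggingface_hub
print("✅ Dependencies installed")
```

---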
### Cell 3: Import libraries
```python
import os
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    TrainingArguments,
    Trainer,
    DefaultDataCollator,
)
from PIL import Image
import evaluate
from huggingface_hub import notebook_login
```

---
### Cell 4: Log in to Hugging Face (required!)
```python
notebook_login()
```
> Running this cell shows a prompt asking for an **Access Token**.
>
> 🔑 **How do I get a token?**
> 1. Open [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
> 2. Click `New token`
> 3. Choose the **Write** permission (needed to push the trained model to the Hub)
> 4. Copy the token and paste it into the Colab prompt

---
### Cell 5: Load the dataset
```python
DATASET_NAME = "nateraw/rice-image-dataset"
print(f"Loading dataset: {DATASET_NAME}...")
ds = load_dataset(DATASET_NAME)
print(ds)

# Class names come from the dataset's ClassLabel feature
labels = ds["train"].features["label"].names
num_labels = len(labels)
print(f"\n📋 Number of classes: {num_labels}")
print(f"📋 Classes: {labels}")
```
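Before splitting, it can be worth checking how many images each class has, because accuracy alone can look good on a heavily imbalanced dataset. The helper below is a hypothetical sketch (`class_distribution` is not a `datasets` function) that turns integer label ids into per-class counts using the names from `features["label"].names`; in this notebook it could be called as `class_distribution(ds["train"]["label"], labels)`.

```python
from collections import Counter

def class_distribution(label_ids, class_names):
    """Hypothetical helper: map each class name to its number of examples, in id order."""
    counts = Counter(label_ids)  # label id -> number of examples
    return {name: counts.get(i, 0) for i, name in enumerate(class_names)}

# Toy stand-ins for ds["train"]["label"] and `labels`
toy_label_ids = [0, 1, 1, 2, 2, 2]
toy_class_names = ["class_a", "class_b", "class_c"]
print(class_distribution(toy_label_ids, toy_class_names))
# → {'class_a': 1, 'class_b': 2, 'class_c': 3}
```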
---

### Cell 6: Split into training and validation sets
```python
# Hold out 15% for validation; stratifying keeps each class's proportion the same in both splits
split_ds = ds["train"].train_test_split(
    test_size=0.15,
    stratify_by_column="label",
    seed=42,
)
train_ds = split_ds["train"]
val_ds = split_ds["test"]

print(f"Training set: {len(train_ds)} images")
print(f"Validation set: {len(val_ds)} images")
```
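`stratify_by_column="label"` asks `train_test_split` to keep each class's share the same in the training and validation sets. The pure-Python sketch below illustrates the idea only (it is not how `datasets` implements it): group example indices by class, shuffle each group with a fixed seed, and send `test_size` of every class to validation. `stratified_split` is a hypothetical name.

```python
import random
from collections import defaultdict

def stratified_split(label_ids, test_size=0.15, seed=42):
    """Hypothetical helper: split example indices so each class keeps the same
    proportion in both parts (illustration only, not the datasets implementation)."""
    by_class = defaultdict(list)  # class id -> indices of its examples
    for idx, label in enumerate(label_ids):
        by_class[label].append(idx)

    rng = random.Random(seed)  # fixed seed -> reproducible split
    train_idx, test_idx = [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        n_test = round(len(indices) * test_size)  # this class's share of the validation set
        test_idx.extend(indices[:n_test])
        train_idx.extend(indices[n_test:])
    return sorted(train_idx), sorted(test_idx)

# Toy imbalanced data: 100 examples of class 0, 20 of class 1
toy_label_ids = [0] * 100 + [1] * 20
train_idx, test_idx = stratified_split(toy_label_ids)
print(len(train_idx), len(test_idx))  # → 102 18
```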
---

### Cell 7: Load the model
```python
MODEL_NAME = "microsoft/resnet-18"

processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

# Mappings between class names and ids, stored in the model config
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for i, label in enumerate(labels)}

model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,  # replace the pretrained 1000-class head with a num_labels-class head
)

print("✅ Model loaded")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
```
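`ignore_mismatched_sizes=True` is needed because the `microsoft/resnet-18` checkpoint comes with a 1000-class ImageNet classification head, while this task needs `num_labels` outputs: that head is dropped and a new, randomly initialized one is created (transformers logs a warning listing the newly initialized weights). Assuming the head is a single linear layer on top of ResNet-18's 512-dimensional pooled features (check `model.config.hidden_sizes[-1]`), the arithmetic below shows how few parameters actually start from scratch. `linear_head_params` is a hypothetical helper and the 5-class count is only an example.

```python
def linear_head_params(in_features: int, num_classes: int) -> int:
    """Parameters of a Linear(in_features, num_classes) layer: weights plus biases."""
    return in_features * num_classes + num_classes

FEATURE_DIM = 512  # assumption: ResNet-18's final feature size (model.config.hidden_sizes[-1])

imagenet_head = linear_head_params(FEATURE_DIM, 1000)  # head shipped with the checkpoint
new_head = linear_head_params(FEATURE_DIM, 5)          # example: a 5-class seed dataset

print(imagenet_head)  # → 513000
print(new_head)       # → 2565
```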
---

### Cell 8: Data preprocessing
```python
def transform(example_batch):
    # Make sure every image is an RGB PIL image (grayscale/RGBA images and arrays are converted)
    images = []
    for img in example_batch["image"]:
        if not isinstance(img, Image.Image):
            img = Image.fromarray(np.array(img))
        if img.mode != "RGB":
            img = img.convert("RGB")
        images.append(img)
    # The processor resizes and normalizes the images into a `pixel_values` tensor
    inputs = processor(images, return_tensors="pt")
    inputs["labels"] = example_batch["label"]
    return inputs

# set_transform applies `transform` on the fly whenever examples are read
train_ds.set_transform(transform)
val_ds.set_transform(transform)
print("✅ Preprocessing configured")
```
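Inside `transform`, the `processor` call resizes each image to the checkpoint's input size (cropping may also apply, depending on the processor configuration), rescales pixel values from 0-255 to 0-1, and normalizes each channel by a mean and standard deviation. The sketch below covers only the last two steps for a single RGB pixel, assuming the common ImageNet statistics; the values this checkpoint really uses are in `processor.image_mean` and `processor.image_std`. `normalize_pixel` is a hypothetical helper.

```python
# Assumed ImageNet statistics; read processor.image_mean / processor.image_std for the real ones
IMAGE_MEAN = (0.485, 0.456, 0.406)
IMAGE_STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb, mean=IMAGE_MEAN, std=IMAGE_STD):
    """Hypothetical helper: rescale one 0-255 RGB pixel to [0, 1], then standardize each channel."""
    return tuple((value / 255.0 - m) / s for value, m, s in zip(rgb, mean, std))

print(normalize_pixel((255, 255, 255)))  # a white pixel maps to values above 2
print(normalize_pixel((0, 0, 0)))        # a black pixel maps to negative values
```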
---

### Cell 9: Evaluation metrics
```python
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")

def compute_metrics(eval_pred):
    logits, label_ids = eval_pred
    preds = np.argmax(logits, axis=1)  # predicted class = index of the largest logit
    acc = accuracy.compute(predictions=preds, references=label_ids)
    # "weighted": per-class F1 averaged by each class's number of true examples
    f1_score = f1.compute(predictions=preds, references=label_ids, average="weighted")
    return {"accuracy": acc["accuracy"], "f1": f1_score["f1"]}

print("✅ Metrics defined")
```
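To make the two metrics concrete: `np.argmax(..., axis=1)` picks the class with the largest logit for each image, accuracy is the fraction of correct picks, and `average="weighted"` computes F1 for each class and averages the results with weights equal to each class's number of true examples (its support), following scikit-learn's definition. The plain-Python sketch below reproduces that arithmetic on made-up logits; the helper names are hypothetical and this is an illustration of the formulas, not the `evaluate` implementation.

```python
from collections import Counter

def argmax_rows(logits):
    """Index of the largest value in each row, like np.argmax(logits, axis=1)."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]

def simple_accuracy(preds, refs):
    """Fraction of predictions that match the reference labels."""
    return sum(p == r for p, r in zip(preds, refs)) / len(refs)

def simple_weighted_f1(preds, refs):
    """Per-class F1, averaged with weights = number of true examples of each class."""
    support = Counter(refs)
    weighted_sum = 0.0
    for cls, n_true in support.items():
        tp = sum(p == cls and r == cls for p, r in zip(preds, refs))
        n_pred = sum(p == cls for p in preds)
        precision = tp / n_pred if n_pred else 0.0
        recall = tp / n_true
        f1_cls = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        weighted_sum += n_true * f1_cls
    return weighted_sum / len(refs)

# Made-up logits for 4 images and 2 classes
demo_logits = [[2.0, 0.1], [0.3, 1.5], [1.2, 0.4], [0.2, 0.9]]
demo_refs = [0, 1, 1, 1]
demo_preds = argmax_rows(demo_logits)
print(demo_preds)                                           # → [0, 1, 0, 1]
print(simple_accuracy(demo_preds, demo_refs))               # → 0.75
print(round(simple_weighted_f1(demo_preds, demo_refs), 4))  # → 0.7667
```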
---

### Cell 10: Training configuration
```python
OUTPUT_REPO = "chaosbee997/rice-seed-classifier"  # ← change to your own repo name

args = TrainingArguments(
    output_dir="/content/rice-seed-classifier",
    remove_unused_columns=False,  # keep the "image" column; the transform needs it
    eval_strategy="epoch",  # named evaluation_strategy in older transformers versions
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=5,
    warmup_ratio=0.1,
    logging_strategy="steps",
    logging_steps=50,
    logging_first_step=True,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    seed=42,
    push_to_hub=True,
    hub_model_id=OUTPUT_REPO,
    report_to="none",
)

print("Training configuration:")
print(f"  Output repo: {OUTPUT_REPO}")
print(f"  Learning rate: {args.learning_rate}")
print(f"  Batch size: {args.per_device_train_batch_size}")
print(f"  Epochs: {args.num_train_epochs}")
```
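How long warmup lasts and how many log lines appear both follow from the number of optimizer steps. The hypothetical `training_schedule` helper below estimates those counts for a single GPU, assuming about 63,000 training images (the figure quoted for Cell 11); `Trainer` turns `warmup_ratio` into roughly `ceil(total_steps * warmup_ratio)` warmup steps. The second call shows how the real `TrainingArguments` parameter `gradient_accumulation_steps` can keep the effective batch size at 64 if you halve `per_device_train_batch_size` to avoid running out of memory. Treat the numbers as estimates: `Trainer`'s own bookkeeping can differ by a step when batches do not divide evenly.

```python
import math

def training_schedule(num_train, batch_size, epochs, warmup_ratio, grad_accum=1):
    """Hypothetical helper: rough step counts for a single-GPU Trainer run."""
    batches_per_epoch = math.ceil(num_train / batch_size)        # the last, partial batch is kept
    steps_per_epoch = math.ceil(batches_per_epoch / grad_accum)  # optimizer updates per epoch
    total_steps = steps_per_epoch * epochs
    return {
        "effective_batch": batch_size * grad_accum,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": math.ceil(total_steps * warmup_ratio),
    }

# Cell 10 settings, assuming ~63,000 training images
print(training_schedule(63_000, batch_size=64, epochs=5, warmup_ratio=0.1))
# → {'effective_batch': 64, 'steps_per_epoch': 985, 'total_steps': 4925, 'warmup_steps': 493}

# Half the per-device batch (less GPU memory), accumulating 2 batches per optimizer update
print(training_schedule(63_000, batch_size=32, epochs=5, warmup_ratio=0.1, grad_accum=2))
# → {'effective_batch': 64, 'steps_per_epoch': 985, 'total_steps': 4925, 'warmup_steps': 493}
```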
---

### Cell 11: Start training 🚀
```python
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    data_collator=DefaultDataCollator(),
    processing_class=processor,  # replaces the deprecated tokenizer= argument; saved with the model
)

print("Starting training...\n" + "=" * 50)
trainer.train()
```
> ⏱️ Training time: about **15-30 minutes** on a T4 GPU (5 epochs, ~63K training images)
>
> 📊 Expected result: accuracy above **95%**

---
### Cell 12: Evaluate and upload the model
```python
metrics = trainer.evaluate()
print("\nFinal evaluation results:")
print("=" * 50)
for k, v in metrics.items():
    if isinstance(v, float):
        print(f"  {k}: {v:.4f}")
    else:
        print(f"  {k}: {v}")

print("\nUploading to the Hugging Face Hub...")
trainer.push_to_hub()
print("\n✅ Upload complete!")
print(f"🔗 Model page: https://huggingface.co/{OUTPUT_REPO}")
```

---

### Cell 13: Test inference
```python
from transformers import pipeline
import matplotlib.pyplot as plt

classifier = pipeline("image-classification", model=OUTPUT_REPO)

# Take one image from the original dataset. It may belong to the training split,
# so this is a quick sanity check, not an unbiased evaluation.
test_image = ds["train"][100]["image"]

plt.figure(figsize=(5, 5))
plt.imshow(test_image)
plt.axis("off")
plt.title("Test image")
plt.show()

results = classifier(test_image)
print("\n🔮 Predictions:")
for r in results:
    print(f"  {r['label']}: {r['score']*100:.2f}%")
```
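The pipeline's `score` values are softmax probabilities over the model's logits, and each `label` is looked up in the `id2label` mapping set in Cell 7; by default the `image-classification` pipeline returns at most the top 5 classes (its `top_k` argument changes that). The plain-Python sketch below mimics that post-processing for one image with made-up logits and hypothetical class names, to show what the percentages mean; `top_k_predictions` is a hypothetical helper and does not replace the pipeline.

```python
import math

def top_k_predictions(logits, id2label, k=5):
    """Hypothetical helper: softmax over one image's logits, then the k most likely labels."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]  # subtracting the max avoids overflow
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [{"label": id2label[i], "score": probs[i]} for i in ranked]

# Made-up logits and hypothetical class names for a 3-class model
demo_id2label = {0: "class_a", 1: "class_b", 2: "class_c"}
for r in top_k_predictions([1.0, 3.0, 0.5], demo_id2label, k=2):
    print(f"  {r['label']}: {r['score']*100:.2f}%")
```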
---

## ⚡ After training

Your model is uploaded automatically to the repository set in `OUTPUT_REPO`, in this example:
- **https://huggingface.co/chaosbee997/rice-seed-classifier**

If the repository is public, anyone can load your model with a single line:
```python
from transformers import pipeline
classifier = pipeline("image-classification", model="chaosbee997/rice-seed-classifier")
```

---
## 🌽 Extending to peanut, corn, wheat, and other crops

To train a general seed classifier covering **peanut, corn, wheat, and rice**:

1. **Collect images**: create one folder per crop, for example:
   ```
   crop_seeds/
   ├── peanut/
   │   ├── img1.jpg
   │   └── img2.jpg
   ├── corn/
   │   ├── img1.jpg
   │   └── img2.jpg
   ├── wheat/
   │   ├── img1.jpg
   │   └── img2.jpg
   └── rice/
       ├── img1.jpg
       └── img2.jpg
   ```

2. **Create the dataset** (the `imagefolder` loader uses each subfolder name as the class label):
   ```python
   from datasets import load_dataset
   ds = load_dataset("imagefolder", data_dir="/path/to/crop_seeds")
   ds.push_to_hub("yourname/crop-seeds")
   ```

3. **Edit Cell 5**: set `DATASET_NAME` to your dataset's name.

4. **Re-run Cells 5-12**. No other code changes are required, though you will probably want a different `OUTPUT_REPO` in Cell 10 so the rice model is not overwritten.

---
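## 🆘 FAQ

**Q: Training fails with `CUDA out of memory`?**
A: Reduce the batch size: change `per_device_train_batch_size=64` to `32` or `16` (lower `per_device_eval_batch_size` too if evaluation runs out of memory).

**Q: Don't have a Hugging Face account?**
A: Sign up for free at [huggingface.co/join](https://huggingface.co/join).

**Q: Training takes too long?**
A: Reduce the number of epochs (`num_train_epochs=3`), or set `MODEL_NAME` in Cell 7 to a lighter model such as `google/mobilenet_v2_1.0_224`. Avoid `microsoft/resnet-34`: it is larger than ResNet-18 and would train more slowly.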