# 🚀 Google Colab Training Guide

This guide shows how to train a seed classification model on a **free Google Colab T4 GPU**.

---

## Step 1: Open Colab

1. Go to [colab.research.google.com](https://colab.research.google.com)
2. Click **File → New notebook**

---

## Step 2: Switch to a GPU runtime

In the menu bar, click:

```
Runtime → Change runtime type → Hardware accelerator → T4 GPU
```

Then click **Save** and **reconnect** (use the Connect/Reconnect button in the top-right corner).

---

## Step 3: Run the code cell by cell

Copy each block below into its own Colab cell and run the cells in order:

### Cell 1: Check the GPU

```python
!nvidia-smi
```

> The output should list a `Tesla T4`. If it does not, go back to Step 2 and confirm that the GPU is enabled.

---

### Cell 2: Install dependencies

```python
!pip install -q transformers datasets accelerate evaluate pillow huggingface_hub
print("✅ Dependencies installed")
```

---

### Cell 3: Import libraries

```python
import os
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    TrainingArguments,
    Trainer,
    DefaultDataCollator,
)
from PIL import Image
import evaluate
from huggingface_hub import notebook_login
```

---

### Cell 4: Log in to Hugging Face (required!)

```python
notebook_login()
```

> Running this cell opens a prompt asking for an **Access Token**. The login is required because later cells push the model to the Hub.
>
> 🔑 **How do I get a token?**
> 1. Open [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
> 2. Click `New token`
> 3. Choose the **Write** permission
> 4. Copy the token and paste it into the Colab prompt

---

### Cell 5: Load the dataset

```python
DATASET_NAME = "nateraw/rice-image-dataset"

print(f"Loading dataset: {DATASET_NAME}...")
ds = load_dataset(DATASET_NAME)
print(ds)

# Class names come from the ClassLabel feature of the "label" column
labels = ds["train"].features["label"].names
num_labels = len(labels)
print(f"\n📋 Number of classes: {num_labels}")
print(f"📋 Classes: {labels}")
```

---

### Cell 6: Split into training and validation sets

```python
# Hold out 15% for validation, keeping class proportions equal in both splits
split_ds = ds["train"].train_test_split(
    test_size=0.15,
    stratify_by_column="label",
    seed=42
)
train_ds = split_ds["train"]
val_ds = split_ds["test"]

print(f"Training set: {len(train_ds)} images")
print(f"Validation set: {len(val_ds)} images")
```

---

### Cell 7: Load the model

```python
MODEL_NAME = "microsoft/resnet-18"

processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for i, label in enumerate(labels)}

# ignore_mismatched_sizes lets the pretrained classification head be
# replaced by a new head with num_labels outputs
model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

print("✅ Model loaded")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
```

---

### Cell 8: Data preprocessing

```python
def transform(example_batch):
    images = []
    for img in example_batch["image"]:
        # Wrap anything that is not already a PIL image
        if not isinstance(img, Image.Image):
            img = Image.fromarray(np.array(img))
        # The processor expects 3-channel RGB input
        images.append(img.convert("RGB"))
    inputs = processor(images, return_tensors="pt")
    inputs["labels"] = example_batch["label"]
    return inputs

# Applied lazily, each time a batch is read
train_ds.set_transform(transform)
val_ds.set_transform(transform)
print("✅ Preprocessing configured")
```

---

### Cell 9: Evaluation metrics

```python
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    preds = np.argmax(predictions, axis=1)
    acc = accuracy.compute(predictions=preds, references=labels)
    f1_score = f1.compute(predictions=preds, references=labels, average="weighted")
    return {"accuracy": acc["accuracy"], "f1": f1_score["f1"]}

print("✅ Metrics defined")
```

---

### Cell 10: Training configuration

```python
OUTPUT_REPO = "chaosbee997/rice-seed-classifier"  # ← change this to your own repository name

args = TrainingArguments(
    output_dir="/content/rice-seed-classifier",
    remove_unused_columns=False,  # keep the "image" column so transform() can read it
    eval_strategy="epoch",  # named evaluation_strategy before transformers 4.41
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=5,
    warmup_ratio=0.1,
    logging_strategy="steps",
    logging_steps=50,
    logging_first_step=True,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    seed=42,
    push_to_hub=True,
    hub_model_id=OUTPUT_REPO,
    report_to="none",
)

print("Training configuration:")
print(f"  Output repository: {OUTPUT_REPO}")
print(f"  Learning rate: {args.learning_rate}")
print(f"  Batch size: {args.per_device_train_batch_size}")
print(f"  Epochs: {args.num_train_epochs}")
```

---

### Cell 11: Start training 🚀

```python
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    data_collator=DefaultDataCollator(),
    processing_class=processor,  # use tokenizer=processor on transformers < 4.46
)

print("Starting training...\n" + "="*50)
trainer.train()
```

> ⏱️ Training time: roughly **15-30 minutes** on a T4 GPU (5 epochs, about 63K training samples)
>
> 📊 Expected result: accuracy above **95%**

---

### Cell 12: Evaluate and upload the model

```python
metrics = trainer.evaluate()

print("\nFinal evaluation results:")
print("="*50)
for k, v in metrics.items():
    if isinstance(v, float):
        print(f"  {k}: {v:.4f}")
    else:
        print(f"  {k}: {v}")

print("\nUploading to the Hugging Face Hub...")
trainer.push_to_hub()
print("\n✅ Upload complete!")
print(f"🔗 Model URL: https://huggingface.co/{OUTPUT_REPO}")
```

---

### Cell 13: Inference test

```python
from transformers import pipeline
import matplotlib.pyplot as plt

classifier = pipeline("image-classification", model=OUTPUT_REPO)

# Take one image from the dataset as a test input
test_image = ds["train"][100]["image"]

plt.figure(figsize=(5, 5))
plt.imshow(test_image)
plt.axis("off")
plt.title("Test image")
plt.show()

results = classifier(test_image)
print("\n🔮 Predictions:")
for r in results:
    print(f"  {r['label']}: {r['score']*100:.2f}%")
```

---

## ⚡ After training

Your model is uploaded automatically to:

- **https://huggingface.co/chaosbee997/rice-seed-classifier**

If the repository is public, anyone can load your model in two lines:

```python
from transformers import pipeline
classifier = pipeline("image-classification", model="chaosbee997/rice-seed-classifier")
```

---

## 🌽 Extending to peanut, corn, wheat, and other crops

To train a general seed classifier that covers **peanut, corn, wheat, and rice**:

1. **Collect images**: create one folder per crop, for example:

   ```
   crop_seeds/
   ├── peanut/
   │   ├── img1.jpg
   │   └── img2.jpg
   ├── corn/
   │   ├── img1.jpg
   │   └── img2.jpg
   ├── wheat/
   │   ├── img1.jpg
   │   └── img2.jpg
   └── rice/
       ├── img1.jpg
       └── img2.jpg
   ```

2. **Create the dataset**:

   ```python
   from datasets import load_dataset

   # Folder names become the class labels
   ds = load_dataset("imagefolder", data_dir="/path/to/crop_seeds")
   ds.push_to_hub("yourname/crop-seeds")
   ```

3. **Edit Cell 5**: set `DATASET_NAME` to the name of your dataset.
4. **Re-run Cells 5-12**. The rest of the code works unchanged, although you will probably want a new `OUTPUT_REPO` in Cell 10 so the new model does not overwrite the rice classifier.

---

## 🆘 FAQ

**Q: Training fails with `CUDA out of memory`?**
A: Reduce the batch size: change `per_device_train_batch_size=64` to `32` or `16`.

**Q: I don't have a Hugging Face account?**
A: Sign up for free at [huggingface.co/join](https://huggingface.co/join).

**Q: Training takes too long?**
A: Reduce the number of epochs (`num_train_epochs=3`) or switch to a lighter model such as `google/mobilenet_v2_1.0_224`. Note that `microsoft/resnet-34` is larger than `microsoft/resnet-18`, so it would train more slowly, not faster.
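On the `CUDA out of memory` question: lowering `per_device_train_batch_size` also shrinks the number of samples behind each optimizer update. One common way to compensate is gradient accumulation, which `TrainingArguments` exposes as `gradient_accumulation_steps`. The sketch below only does the arithmetic; the helper name `accumulation_steps` is hypothetical and not part of any library, and keeping the effective batch size at 64 is an assumption taken from Cell 10, not a requirement.

```python
import math

def accumulation_steps(target_batch_size: int, per_device_batch_size: int) -> int:
    """Number of micro-batches to accumulate so that one optimizer update
    sees at least `target_batch_size` samples (hypothetical helper)."""
    if target_batch_size <= 0 or per_device_batch_size <= 0:
        raise ValueError("batch sizes must be positive")
    return max(1, math.ceil(target_batch_size / per_device_batch_size))

TARGET_BATCH_SIZE = 64  # the batch size used in Cell 10

for small_batch in (32, 16):
    accum = accumulation_steps(TARGET_BATCH_SIZE, small_batch)
    print(
        f"per_device_train_batch_size={small_batch}, "
        f"gradient_accumulation_steps={accum}"
    )
```

To apply the result, pass both values to `TrainingArguments` in Cell 10, for example `per_device_train_batch_size=16, gradient_accumulation_steps=4`. Memory use follows the smaller per-device batch, while each optimizer update still covers 64 samples.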