rice-seed-classifier / COLAB_GUIDE.md
chaosbee997's picture
Upload COLAB_GUIDE.md
c8dcab6 verified
# 🚀 Google Colab 训练指南
本指南教你如何在 **免费的 Google Colab T4 GPU** 上训练籽粒分类模型。
---
## 第一步:打开 Colab
1. 访问 [colab.research.google.com](https://colab.research.google.com)
2. 点击 **文件 → 新建笔记本**
---
## 第二步:切换到 GPU 环境
点击菜单栏:
```
Runtime → Change runtime type → Hardware accelerator → T4 GPU
```
然后点击 **Save**,并 **重新连接**(点击左上角重新连接按钮)。
---
## 第三步:逐 Cell 运行代码
将下面的代码按顺序复制到 Colab 的每个 Cell 中运行:
### Cell 1:检查 GPU
```python
!nvidia-smi
```
> 应该能看到 `Tesla T4` 信息。如果看不到,回到第二步确认 GPU 已启用。
---
### Cell 2:安装依赖
```python
!pip install -q transformers datasets accelerate evaluate pillow huggingface_hub
print("✅ 依赖安装完成")
```
---
### Cell 3:导入库
```python
import os
import numpy as np
from datasets import load_dataset
from transformers import (
AutoImageProcessor,
AutoModelForImageClassification,
TrainingArguments,
Trainer,
DefaultDataCollator,
)
from PIL import Image
import evaluate
from huggingface_hub import notebook_login
```
---
### Cell 4:登录 Hugging Face(必须!)
```python
notebook_login()
```
> 运行后会弹出一个输入框,要求输入 **Access Token**
>
> 🔑 **如何获取 Token?**
> 1. 打开 [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
> 2. 点击 `New token`
> 3. 选择 **Write** 权限
> 4. 复制 token 粘贴到 Colab 弹窗中
---
### Cell 5:加载数据集
```python
DATASET_NAME = "nateraw/rice-image-dataset"
print(f"Loading dataset: {DATASET_NAME}...")
ds = load_dataset(DATASET_NAME)
print(ds)
labels = ds["train"].features["label"].names
num_labels = len(labels)
print(f"\n📋 类别数量: {num_labels}")
print(f"📋 类别列表: {labels}")
```
---
### Cell 6:划分训练/验证集
```python
split_ds = ds["train"].train_test_split(
test_size=0.15,
stratify_by_column="label",
seed=42
)
train_ds = split_ds["train"]
val_ds = split_ds["test"]
print(f"训练集: {len(train_ds)} 张")
print(f"验证集: {len(val_ds)} 张")
```
---
### Cell 7:加载模型
```python
MODEL_NAME = "microsoft/resnet-18"
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for i, label in enumerate(labels)}
model = AutoModelForImageClassification.from_pretrained(
MODEL_NAME,
num_labels=num_labels,
id2label=id2label,
label2id=label2id,
ignore_mismatched_sizes=True,
)
print(f"✅ 模型加载完成")
print(f"参数总量: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
```
---
### Cell 8:数据预处理
```python
def transform(example_batch):
images = []
for img in example_batch["image"]:
if isinstance(img, Image.Image):
if img.mode != "RGB":
img = img.convert("RGB")
images.append(img)
else:
img = Image.fromarray(np.array(img)).convert("RGB")
images.append(img)
inputs = processor(images, return_tensors="pt")
inputs["labels"] = example_batch["label"]
return inputs
train_ds.set_transform(transform)
val_ds.set_transform(transform)
print("✅ 数据预处理设置完成")
```
---
### Cell 9:评估指标
```python
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")
def compute_metrics(eval_pred):
predictions, labels = eval_pred
preds = np.argmax(predictions, axis=1)
acc = accuracy.compute(predictions=preds, references=labels)
f1_score = f1.compute(predictions=preds, references=labels, average="weighted")
return {"accuracy": acc["accuracy"], "f1": f1_score["f1"]}
print("✅ 评估指标定义完成")
```
---
### Cell 10:训练参数配置
```python
OUTPUT_REPO = "chaosbee997/rice-seed-classifier" # ← 改成你的仓库名
args = TrainingArguments(
output_dir="/content/rice-seed-classifier",
remove_unused_columns=False,
evaluation_strategy="epoch",
save_strategy="epoch",
learning_rate=5e-5,
per_device_train_batch_size=64,
per_device_eval_batch_size=64,
num_train_epochs=5,
warmup_ratio=0.1,
logging_strategy="steps",
logging_steps=50,
logging_first_step=True,
load_best_model_at_end=True,
metric_for_best_model="accuracy",
seed=42,
push_to_hub=True,
hub_model_id=OUTPUT_REPO,
report_to="none",
)
print(f"训练配置:")
print(f" 输出仓库: {OUTPUT_REPO}")
print(f" 学习率: {args.learning_rate}")
print(f" Batch size: {args.per_device_train_batch_size}")
print(f" Epochs: {args.num_train_epochs}")
```
---
### Cell 11:开始训练 🚀
```python
trainer = Trainer(
model=model,
args=args,
train_dataset=train_ds,
eval_dataset=val_ds,
compute_metrics=compute_metrics,
data_collator=DefaultDataCollator(),
tokenizer=processor,
)
print("开始训练...\n" + "="*50)
trainer.train()
```
> ⏱️ 训练时间:T4 GPU 上约 **15-30 分钟**(5 epochs,63K训练样本)
>
> 📊 预期结果:准确率 > **95%**
---
### Cell 12:评估并上传模型
```python
metrics = trainer.evaluate()
print("\n最终评估结果:")
print("="*50)
for k, v in metrics.items():
if isinstance(v, float):
print(f" {k}: {v:.4f}")
else:
print(f" {k}: {v}")
print("\n正在上传到 Hugging Face Hub...")
trainer.push_to_hub()
print(f"\n✅ 上传成功!")
print(f"🔗 模型地址: https://huggingface.co/{OUTPUT_REPO}")
```
---
### Cell 13:推理测试
```python
from transformers import pipeline
import matplotlib.pyplot as plt
classifier = pipeline("image-classification", model=OUTPUT_REPO)
# 从数据集取一张图测试
test_image = ds["train"][100]["image"]
plt.figure(figsize=(5, 5))
plt.imshow(test_image)
plt.axis("off")
plt.title("测试图像")
plt.show()
results = classifier(test_image)
print("\n🔮 预测结果:")
for r in results:
print(f" {r['label']}: {r['score']*100:.2f}%")
```
---
## ⚡ 训练完成后
你的模型会自动上传到:
- **https://huggingface.co/chaosbee997/rice-seed-classifier**
任何人都可以用一行代码加载你的模型:
```python
from transformers import pipeline
classifier = pipeline("image-classification", model="chaosbee997/rice-seed-classifier")
```
---
## 🌽 扩展到花生/玉米/小麦等其他作物
如果你想训练包含 **花生、玉米、小麦、水稻** 的通用籽粒分类模型:
1. **收集图像**:每种作物建立一个文件夹,例如:
```
crop_seeds/
├── peanut/
│ ├── img1.jpg
│ └── img2.jpg
├── corn/
│ ├── img1.jpg
│ └── img2.jpg
├── wheat/
│ ├── img1.jpg
│ └── img2.jpg
└── rice/
├── img1.jpg
└── img2.jpg
```
2. **创建数据集**
```python
from datasets import load_dataset
ds = load_dataset("imagefolder", data_dir="/path/to/crop_seeds")
ds.push_to_hub("yourname/crop-seeds")
```
3. **修改 Cell 5**:把 `DATASET_NAME` 改成你的数据集名称
4. **重新运行 Cell 5-12**,其余代码完全不用改!
---
## 🆘 常见问题
**Q: 训练时报 `CUDA out of memory`?**
A: 减小 batch size:把 `per_device_train_batch_size=64` 改成 `32``16`
**Q: 没有 Hugging Face 账号?**
A: 免费注册:[huggingface.co/join](https://huggingface.co/join)
**Q: 训练时间太长?**
A: 可以减少 epoch 数(`num_train_epochs=3`),或换更小的模型(`microsoft/resnet-34` 更大,`google/mobilenet_v2_1.0_224` 更轻)。