# 🚀 Google Colab Training Guide
This guide shows how to train a seed classification model on a **free Google Colab T4 GPU**.
---
## Step 1: Open Colab
1. Go to [colab.research.google.com](https://colab.research.google.com)
2. Click **File → New notebook**
---
## Step 2: Switch to a GPU runtime
In the menu bar, select:
```
Runtime → Change runtime type → Hardware accelerator → T4 GPU
```
Then click **Save** and **reconnect** (use the reconnect button in the top-right corner).
---
## Step 3: Run the code cell by cell
Copy the code below into separate Colab cells, in order, and run each one:
### Cell 1: Check the GPU
```python
!nvidia-smi
```
> The output should list a `Tesla T4`. If it does not, go back to Step 2 and confirm that the GPU is enabled.
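
Instead of reading the `nvidia-smi` table by eye, the same check can be scripted. The following is a minimal sketch using only the standard library; `gpu_name` is a hypothetical helper name introduced here, not part of Colab or any library.

```python
import shutil
import subprocess


def gpu_name():
    """Return the first GPU name reported by nvidia-smi, or None if unavailable."""
    if shutil.which("nvidia-smi") is None:
        return None  # nvidia-smi is not installed, so no NVIDIA GPU runtime
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        return None  # the tool exists but could not talk to a GPU
    lines = result.stdout.strip().splitlines()
    return lines[0].strip() if lines else None


name = gpu_name()
print(f"GPU: {name}" if name else "No GPU detected; revisit Step 2.")
```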
---
### Cell 2: Install dependencies
```python
!pip install -q transformers datasets accelerate evaluate pillow huggingface_hub
print("✅ Dependencies installed")
```
---
### Cell 3: Import libraries
```python
import os
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    TrainingArguments,
    Trainer,
    DefaultDataCollator,
)
from PIL import Image
import evaluate
from huggingface_hub import notebook_login
```
---
### Cell 4: Log in to Hugging Face (required!)
```python
notebook_login()
```
> Running this cell opens an input box that asks for an **Access Token**.
>
> 🔑 **How do I get a token?**
> 1. Open [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
> 2. Click `New token`
> 3. Select **Write** permission
> 4. Copy the token and paste it into the Colab prompt
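
If you would rather not paste the token on every run, one possible approach is to read it from an environment variable and fall back to the prompt. This is a sketch under assumptions: the variable name `HF_TOKEN` and the helpers `resolve_hf_token` and `login_from_env` are illustrative choices, while `huggingface_hub.login(token=...)` is the library's non-interactive login call.

```python
import os


def resolve_hf_token(env_var="HF_TOKEN"):
    """Return the token stored in env_var, or None if it is unset or blank."""
    token = os.environ.get(env_var, "").strip()
    return token or None


def login_from_env():
    """Log in with the environment token; return False if no token is set."""
    token = resolve_hf_token()
    if token is None:
        return False
    from huggingface_hub import login  # imported lazily, only when needed

    login(token=token)
    return True


# In the notebook (not executed here):
# if not login_from_env():
#     notebook_login()
```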
---
### Cell 5: Load the dataset
```python
DATASET_NAME = "nateraw/rice-image-dataset"
print(f"Loading dataset: {DATASET_NAME}...")
ds = load_dataset(DATASET_NAME)
print(ds)
# Class names come from the ClassLabel feature of the "label" column
labels = ds["train"].features["label"].names
num_labels = len(labels)
print(f"\n📋 Number of classes: {num_labels}")
print(f"📋 Classes: {labels}")
```
---
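### Cell 6: Split into training and validation sets
```python
# Hold out 15% for validation, keeping class proportions equal in both splits
split_ds = ds["train"].train_test_split(
    test_size=0.15,
    stratify_by_column="label",
    seed=42,
)
train_ds = split_ds["train"]
val_ds = split_ds["test"]
print(f"Training set: {len(train_ds)} images")
print(f"Validation set: {len(val_ds)} images")
```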
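To confirm that the stratified split kept the classes balanced, you can count examples per class. The sketch below uses toy labels so it runs on its own; `class_counts` is a hypothetical helper, and in the notebook you would pass `train_ds["label"]` (before Cell 8 attaches the transform) together with the `labels` list from Cell 5.

```python
from collections import Counter


def class_counts(label_ids, names):
    """Map each class name to the number of times its id appears in label_ids."""
    counts = Counter(label_ids)
    return {name: counts.get(i, 0) for i, name in enumerate(names)}


# Toy data for illustration only (not from the rice dataset)
toy_names = ["class_a", "class_b"]
toy_train = [0, 0, 0, 1, 1, 1]
toy_val = [0, 1]
print("train:", class_counts(toy_train, toy_names))
print("val:  ", class_counts(toy_val, toy_names))
```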
---
### Cell 7: Load the model
```python
MODEL_NAME = "microsoft/resnet-18"
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for i, label in enumerate(labels)}
model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,  # replace the pretrained classification head
)
print("✅ Model loaded")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
```
---
### Cell 8: Preprocess the data
```python
def transform(example_batch):
    """Convert a batch of images to RGB and run them through the image processor."""
    images = []
    for img in example_batch["image"]:
        if isinstance(img, Image.Image):
            if img.mode != "RGB":
                img = img.convert("RGB")
            images.append(img)
        else:
            # Array-like input: wrap it in a PIL image first
            img = Image.fromarray(np.array(img)).convert("RGB")
            images.append(img)
    inputs = processor(images, return_tensors="pt")
    inputs["labels"] = example_batch["label"]
    return inputs


# The transform runs lazily, each time a batch is read
train_ds.set_transform(transform)
val_ds.set_transform(transform)
print("✅ Preprocessing configured")
```
---
### Cell 9: Evaluation metrics
```python
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")


def compute_metrics(eval_pred):
    """Turn logits into class predictions and report accuracy and weighted F1."""
    predictions, labels = eval_pred
    preds = np.argmax(predictions, axis=1)
    acc = accuracy.compute(predictions=preds, references=labels)
    f1_score = f1.compute(predictions=preds, references=labels, average="weighted")
    return {"accuracy": acc["accuracy"], "f1": f1_score["f1"]}


print("✅ Evaluation metrics defined")
```
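To make `average="weighted"` concrete, here is a pure-Python sketch that computes the same quantity by hand on four made-up predictions: the per-class F1 scores are averaged with weights proportional to each class's number of true examples. The function `weighted_f1` and the `toy_*` variables are illustrative names, not part of the `evaluate` library.

```python
def weighted_f1(preds, refs):
    """Support-weighted mean of per-class F1 scores."""
    classes = sorted(set(refs) | set(preds))
    total = len(refs)
    score = 0.0
    for c in classes:
        tp = sum(1 for p, r in zip(preds, refs) if p == c and r == c)
        fp = sum(1 for p, r in zip(preds, refs) if p == c and r != c)
        fn = sum(1 for p, r in zip(preds, refs) if p != c and r == c)
        denom = 2 * tp + fp + fn
        f1_c = 2 * tp / denom if denom else 0.0
        support = tp + fn  # number of true examples of class c
        score += f1_c * support / total
    return score


# Made-up example: three true examples of class 0, one of class 1
toy_refs = [0, 0, 0, 1]
toy_preds = [0, 0, 1, 1]
toy_acc = sum(p == r for p, r in zip(toy_preds, toy_refs)) / len(toy_refs)
toy_f1 = weighted_f1(toy_preds, toy_refs)
# Class 0 has F1 = 0.8 (weight 3/4); class 1 has F1 = 2/3 (weight 1/4)
print(f"accuracy={toy_acc:.4f}  weighted_f1={toy_f1:.4f}")
```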
---
### Cell 10: Configure training arguments
```python
OUTPUT_REPO = "chaosbee997/rice-seed-classifier"  # ← change to your own repo name
args = TrainingArguments(
    output_dir="/content/rice-seed-classifier",
    remove_unused_columns=False,  # keep the "image" column for the transform
    eval_strategy="epoch",  # named `evaluation_strategy` in older transformers releases
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=5,
    warmup_ratio=0.1,
    logging_strategy="steps",
    logging_steps=50,
    logging_first_step=True,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    seed=42,
    push_to_hub=True,
    hub_model_id=OUTPUT_REPO,
    report_to="none",
)
print("Training configuration:")
print(f"  Output repo: {OUTPUT_REPO}")
print(f"  Learning rate: {args.learning_rate}")
print(f"  Batch size: {args.per_device_train_batch_size}")
print(f"  Epochs: {args.num_train_epochs}")
```
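As a rough worked example of what these settings imply, the arithmetic below estimates how many optimizer steps the run takes and how many of them are warmup. It assumes 63,750 training images (85% of a 75,000-image dataset, an assumption made here for illustration), a single GPU, and no gradient accumulation; the exact counts reported by `Trainer` may differ slightly.

```python
import math

train_samples = 63_750  # assumption: 85% of 75,000 images
batch_size = 64         # per_device_train_batch_size on one GPU
epochs = 5              # num_train_epochs
warmup_ratio = 0.1

steps_per_epoch = math.ceil(train_samples / batch_size)
total_steps = steps_per_epoch * epochs
warmup_steps = math.ceil(total_steps * warmup_ratio)

print(f"Steps per epoch: {steps_per_epoch}")  # 997
print(f"Total steps:     {total_steps}")      # 4985
print(f"Warmup steps:    {warmup_steps}")     # 499
```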
---
### Cell 11: Start training 🚀
```python
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    data_collator=DefaultDataCollator(),
    processing_class=processor,  # named `tokenizer` in older transformers releases
)
print("Starting training...\n" + "=" * 50)
trainer.train()
```
> ⏱️ Training time: about **15-30 minutes** on a T4 GPU (5 epochs, ~63K training samples)
>
> 📊 Expected result: accuracy > **95%**
---
### Cell 12: Evaluate and upload the model
```python
metrics = trainer.evaluate()
print("\nFinal evaluation results:")
print("=" * 50)
for k, v in metrics.items():
    if isinstance(v, float):
        print(f"  {k}: {v:.4f}")
    else:
        print(f"  {k}: {v}")
print("\nUploading to the Hugging Face Hub...")
trainer.push_to_hub()
print("\n✅ Upload complete!")
print(f"🔗 Model URL: https://huggingface.co/{OUTPUT_REPO}")
```
---
### Cell 13: Test inference
```python
from transformers import pipeline
import matplotlib.pyplot as plt

classifier = pipeline("image-classification", model=OUTPUT_REPO)
# Take one image from the dataset as a test input
test_image = ds["train"][100]["image"]
plt.figure(figsize=(5, 5))
plt.imshow(test_image)
plt.axis("off")
plt.title("Test image")
plt.show()
results = classifier(test_image)
print("\n🔮 Predictions:")
for r in results:
    print(f"  {r['label']}: {r['score']*100:.2f}%")
```
---
## ⚡ After training
Your model is uploaded automatically to:
- **https://huggingface.co/chaosbee997/rice-seed-classifier**
Anyone can then load it with a single `pipeline` call:
```python
from transformers import pipeline
classifier = pipeline("image-classification", model="chaosbee997/rice-seed-classifier")
```
---
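## 🌽 Extending to peanut, corn, wheat, and other crops
To train a general seed classifier that covers **peanut, corn, wheat, and rice**:
1. **Collect images**: create one folder per crop, for example: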
```
crop_seeds/
├── peanut/
│   ├── img1.jpg
│   └── img2.jpg
├── corn/
│   ├── img1.jpg
│   └── img2.jpg
├── wheat/
│   ├── img1.jpg
│   └── img2.jpg
└── rice/
    ├── img1.jpg
    └── img2.jpg
```
2. **Create the dataset**:
```python
from datasets import load_dataset
ds = load_dataset("imagefolder", data_dir="/path/to/crop_seeds")
ds.push_to_hub("yourname/crop-seeds")
```
3. **Edit Cell 5**: change `DATASET_NAME` to the name of your dataset
4. **Re-run Cells 5-12**; the rest of the code needs no changes!
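
Before creating the dataset in step 2, it can help to confirm that every crop folder actually contains images and that the classes are not badly imbalanced. The sketch below is one way to do that with the standard library; `count_images` and the `IMAGE_EXTS` set are assumptions made for illustration, not part of `datasets`.

```python
from pathlib import Path

# Assumed set of image extensions; adjust to match your files
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}


def count_images(root):
    """Return {class_folder_name: image_count} for each subfolder of root."""
    counts = {}
    for class_dir in sorted(Path(root).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1
                for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_EXTS
            )
    return counts


# In the notebook (not executed here):
# print(count_images("/path/to/crop_seeds"))
```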
---
## 🆘 FAQ
**Q: Training fails with `CUDA out of memory`?**
A: Reduce the batch size: change `per_device_train_batch_size=64` to `32` or `16`.
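
If you want to shrink the per-step batch but keep roughly the same number of samples per optimizer update, one option is gradient accumulation. The fragment below is a sketch of keyword arguments to swap into `TrainingArguments(...)` in Cell 10; `gradient_accumulation_steps` is a real `TrainingArguments` parameter, while the dictionary name `low_memory_overrides` is made up here. Because ResNet uses batch normalization, the result is not exactly equivalent to a true batch of 64.

```python
# Keyword arguments to swap into TrainingArguments(...) in Cell 10
low_memory_overrides = {
    "per_device_train_batch_size": 16,
    "per_device_eval_batch_size": 16,
    "gradient_accumulation_steps": 4,  # 16 * 4 = 64 samples per optimizer step
}

effective_batch = (
    low_memory_overrides["per_device_train_batch_size"]
    * low_memory_overrides["gradient_accumulation_steps"]
)
print(f"Effective train batch size: {effective_batch}")
```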
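**Q: No Hugging Face account?**
A: Sign up for free: [huggingface.co/join](https://huggingface.co/join)
**Q: Training takes too long?**
A: Reduce the number of epochs (`num_train_epochs=3`), or switch to a lighter model such as `google/mobilenet_v2_1.0_224`. Note that `microsoft/resnet-34` is larger than ResNet-18, so it will be slower, not faster.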