File size: 7,799 Bytes
c8dcab6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
# 🚀 Google Colab 训练指南

本指南教你如何在 **免费的 Google Colab T4 GPU** 上训练籽粒分类模型。

---

## 第一步:打开 Colab

1. 访问 [colab.research.google.com](https://colab.research.google.com)
2. 点击 **文件 → 新建笔记本**

---

## 第二步:切换到 GPU 环境

点击菜单栏:
```
Runtime → Change runtime type → Hardware accelerator → T4 GPU
```
然后点击 **Save**,并 **重新连接**(点击左上角重新连接按钮)。

---

## 第三步:逐 Cell 运行代码

将下面的代码按顺序复制到 Colab 的每个 Cell 中运行:

### Cell 1:检查 GPU
```python
!nvidia-smi
```
> 应该能看到 `Tesla T4` 信息。如果看不到,回到第二步确认 GPU 已启用。

---

### Cell 2:安装依赖
```python
!pip install -q transformers datasets accelerate evaluate pillow huggingface_hub
print("✅ 依赖安装完成")
```

---

### Cell 3:导入库
```python
import os
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    TrainingArguments,
    Trainer,
    DefaultDataCollator,
)
from PIL import Image
import evaluate
from huggingface_hub import notebook_login
```

---

### Cell 4:登录 Hugging Face(必须!)
```python
notebook_login()
```
> 运行后会弹出一个输入框,要求输入 **Access Token**
>
> 🔑 **如何获取 Token?**
> 1. 打开 [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
> 2. 点击 `New token`
> 3. 选择 **Write** 权限
> 4. 复制 token 粘贴到 Colab 弹窗中

---

### Cell 5:加载数据集
```python
DATASET_NAME = "nateraw/rice-image-dataset"
print(f"Loading dataset: {DATASET_NAME}...")
ds = load_dataset(DATASET_NAME)
print(ds)

labels = ds["train"].features["label"].names
num_labels = len(labels)
print(f"\n📋 类别数量: {num_labels}")
print(f"📋 类别列表: {labels}")
```

---

### Cell 6:划分训练/验证集
```python
split_ds = ds["train"].train_test_split(
    test_size=0.15,
    stratify_by_column="label",
    seed=42
)
train_ds = split_ds["train"]
val_ds = split_ds["test"]

print(f"训练集: {len(train_ds)} 张")
print(f"验证集: {len(val_ds)} 张")
```

---

### Cell 7:加载模型
```python
MODEL_NAME = "microsoft/resnet-18"

processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for i, label in enumerate(labels)}

model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

print(f"✅ 模型加载完成")
print(f"参数总量: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
```

---

### Cell 8:数据预处理
```python
def transform(example_batch):
    images = []
    for img in example_batch["image"]:
        if isinstance(img, Image.Image):
            if img.mode != "RGB":
                img = img.convert("RGB")
            images.append(img)
        else:
            img = Image.fromarray(np.array(img)).convert("RGB")
            images.append(img)
    inputs = processor(images, return_tensors="pt")
    inputs["labels"] = example_batch["label"]
    return inputs

train_ds.set_transform(transform)
val_ds.set_transform(transform)
print("✅ 数据预处理设置完成")
```

---

### Cell 9:评估指标
```python
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    preds = np.argmax(predictions, axis=1)
    acc = accuracy.compute(predictions=preds, references=labels)
    f1_score = f1.compute(predictions=preds, references=labels, average="weighted")
    return {"accuracy": acc["accuracy"], "f1": f1_score["f1"]}

print("✅ 评估指标定义完成")
```

---

### Cell 10:训练参数配置
```python
OUTPUT_REPO = "chaosbee997/rice-seed-classifier"  # ← 改成你的仓库名

args = TrainingArguments(
    output_dir="/content/rice-seed-classifier",
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=5,
    warmup_ratio=0.1,
    logging_strategy="steps",
    logging_steps=50,
    logging_first_step=True,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    seed=42,
    push_to_hub=True,
    hub_model_id=OUTPUT_REPO,
    report_to="none",
)

print(f"训练配置:")
print(f"  输出仓库: {OUTPUT_REPO}")
print(f"  学习率: {args.learning_rate}")
print(f"  Batch size: {args.per_device_train_batch_size}")
print(f"  Epochs: {args.num_train_epochs}")
```

---

### Cell 11:开始训练 🚀
```python
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    data_collator=DefaultDataCollator(),
    tokenizer=processor,
)

print("开始训练...\n" + "="*50)
trainer.train()
```
> ⏱️ 训练时间:T4 GPU 上约 **15-30 分钟**(5 epochs,63K训练样本)
>
> 📊 预期结果:准确率 > **95%**

---

### Cell 12:评估并上传模型
```python
metrics = trainer.evaluate()
print("\n最终评估结果:")
print("="*50)
for k, v in metrics.items():
    if isinstance(v, float):
        print(f"  {k}: {v:.4f}")
    else:
        print(f"  {k}: {v}")

print("\n正在上传到 Hugging Face Hub...")
trainer.push_to_hub()
print(f"\n✅ 上传成功!")
print(f"🔗 模型地址: https://huggingface.co/{OUTPUT_REPO}")
```

---

### Cell 13:推理测试
```python
from transformers import pipeline
import matplotlib.pyplot as plt

classifier = pipeline("image-classification", model=OUTPUT_REPO)

# 从数据集取一张图测试
test_image = ds["train"][100]["image"]

plt.figure(figsize=(5, 5))
plt.imshow(test_image)
plt.axis("off")
plt.title("测试图像")
plt.show()

results = classifier(test_image)
print("\n🔮 预测结果:")
for r in results:
    print(f"  {r['label']}: {r['score']*100:.2f}%")
```

---

## ⚡ 训练完成后

你的模型会自动上传到:
- **https://huggingface.co/chaosbee997/rice-seed-classifier**

任何人都可以用一行代码加载你的模型:
```python
from transformers import pipeline
classifier = pipeline("image-classification", model="chaosbee997/rice-seed-classifier")
```

---

## 🌽 扩展到花生/玉米/小麦等其他作物

如果你想训练包含 **花生、玉米、小麦、水稻** 的通用籽粒分类模型:

1. **收集图像**:每种作物建立一个文件夹,例如:
   ```
   crop_seeds/
   ├── peanut/
   │   ├── img1.jpg
   │   └── img2.jpg
   ├── corn/
   │   ├── img1.jpg
   │   └── img2.jpg
   ├── wheat/
   │   ├── img1.jpg
   │   └── img2.jpg
   └── rice/
       ├── img1.jpg
       └── img2.jpg
   ```

2. **创建数据集**```python
   from datasets import load_dataset
   ds = load_dataset("imagefolder", data_dir="/path/to/crop_seeds")
   ds.push_to_hub("yourname/crop-seeds")
   ```

3. **修改 Cell 5**:把 `DATASET_NAME` 改成你的数据集名称

4. **重新运行 Cell 5-12**,其余代码完全不用改!

---

## 🆘 常见问题

**Q: 训练时报 `CUDA out of memory`?**  
A: 减小 batch size:把 `per_device_train_batch_size=64` 改成 `32``16`**Q: 没有 Hugging Face 账号?**  
A: 免费注册:[huggingface.co/join](https://huggingface.co/join)

**Q: 训练时间太长?**  
A: 可以减少 epoch 数(`num_train_epochs=3`),或换更小的模型(`microsoft/resnet-34` 更大,`google/mobilenet_v2_1.0_224` 更轻)。