chaosbee997 committed on
Commit
c8dcab6
·
verified ·
1 Parent(s): 2807a2d

Upload COLAB_GUIDE.md

Files changed (1)
  1. COLAB_GUIDE.md +325 -0
COLAB_GUIDE.md ADDED
@@ -0,0 +1,325 @@
# 🚀 Google Colab Training Guide

This guide shows how to train a seed classification model on a **free Google Colab T4 GPU**.

---

## Step 1: Open Colab

1. Go to [colab.research.google.com](https://colab.research.google.com)
2. Click **File → New notebook**

---

## Step 2: Switch to a GPU runtime

In the menu bar, choose:
```
Runtime → Change runtime type → Hardware accelerator → T4 GPU
```
Then click **Save** and **reconnect** (use the Connect button in the top-right corner).

---

## Step 3: Run the code cell by cell

Copy each block below into its own Colab cell and run the cells in order:

### Cell 1: Check the GPU
```python
!nvidia-smi
```
> The output should list a `Tesla T4`. If it does not, go back to Step 2 and confirm that the GPU is enabled.

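
If you would rather check for the GPU from Python than read the `nvidia-smi` table by eye, something like the following sketch may help. It assumes only that `nvidia-smi` is on the `PATH` when a GPU runtime is attached; the helper name `gpu_summary` is made up for this example.

```python
import shutil
import subprocess


def gpu_summary():
    """Return the GPU names reported by nvidia-smi, or None if no GPU is visible."""
    # nvidia-smi is only on PATH when an NVIDIA driver is installed.
    if shutil.which("nvidia-smi") is None:
        return None
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        return None
    names = [line.strip() for line in result.stdout.splitlines() if line.strip()]
    return names or None


gpus = gpu_summary()
print(gpus if gpus else "No GPU detected: go back to Step 2 and enable the T4 runtime.")
```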
---

### Cell 2: Install dependencies
```python
!pip install -q transformers datasets accelerate evaluate pillow huggingface_hub
print("✅ Dependencies installed")
```

---

### Cell 3: Import libraries
```python
import os
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    TrainingArguments,
    Trainer,
    DefaultDataCollator,
)
from PIL import Image
import evaluate
from huggingface_hub import notebook_login
```

---

### Cell 4: Log in to Hugging Face (required!)
```python
notebook_login()
```
> Running this cell shows an input box asking for an **Access Token**.
>
> 🔑 **How do I get a token?**
> 1. Open [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
> 2. Click `New token`
> 3. Select the **Write** permission
> 4. Copy the token and paste it into the Colab prompt

---

### Cell 5: Load the dataset
```python
DATASET_NAME = "nateraw/rice-image-dataset"
print(f"Loading dataset: {DATASET_NAME}...")
ds = load_dataset(DATASET_NAME)
print(ds)

# Class names come from the ClassLabel feature of the "label" column
labels = ds["train"].features["label"].names
num_labels = len(labels)
print(f"\n📋 Number of classes: {num_labels}")
print(f"📋 Classes: {labels}")
```

---

### Cell 6: Split into training and validation sets
```python
# Hold out 15% for validation, keeping class proportions the same in both splits
split_ds = ds["train"].train_test_split(
    test_size=0.15,
    stratify_by_column="label",
    seed=42,
)
train_ds = split_ds["train"]
val_ds = split_ds["test"]

print(f"Training set: {len(train_ds)} images")
print(f"Validation set: {len(val_ds)} images")
```

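To illustrate what `stratify_by_column="label"` is for, here is a minimal pure-Python sketch of a stratified split. It is not how `datasets` implements it; the function name `stratified_split` and the toy label counts are assumptions made for the example. Each class is shuffled and split separately, so a rare class keeps about the same share in both parts.

```python
import random
from collections import Counter, defaultdict


def stratified_split(labels, test_size=0.15, seed=42):
    """Split indices so each class keeps roughly the same share in both parts."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    train_idx, val_idx = [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        # Take the validation share from this class only.
        n_val = round(len(indices) * test_size)
        val_idx.extend(indices[:n_val])
        train_idx.extend(indices[n_val:])
    return train_idx, val_idx


# Toy labels: 100 samples of class 0 and 20 of class 1 (made-up numbers).
toy_labels = [0] * 100 + [1] * 20
train_idx, val_idx = stratified_split(toy_labels)
print("train:", Counter(toy_labels[i] for i in train_idx))
print("val:  ", Counter(toy_labels[i] for i in val_idx))
```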
---

### Cell 7: Load the model
```python
MODEL_NAME = "microsoft/resnet-18"

processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for i, label in enumerate(labels)}

# ignore_mismatched_sizes=True lets the pretrained classification head
# be replaced by a new one sized for num_labels classes
model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

print("✅ Model loaded")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
```

---

### Cell 8: Preprocess the data
```python
def transform(example_batch):
    """Convert a batch of images to RGB and run the image processor on it."""
    images = []
    for img in example_batch["image"]:
        if not isinstance(img, Image.Image):
            # Fall back to building a PIL image from array-like data
            img = Image.fromarray(np.array(img))
        if img.mode != "RGB":
            img = img.convert("RGB")
        images.append(img)
    inputs = processor(images, return_tensors="pt")
    inputs["labels"] = example_batch["label"]
    return inputs

# set_transform applies the function lazily, each time a batch is accessed
train_ds.set_transform(transform)
val_ds.set_transform(transform)
print("✅ Preprocessing configured")
```

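Roughly speaking, an image processor resizes each image, rescales pixel values to the range [0, 1], and standardizes each channel. The sketch below shows only the last two steps for a single pixel. It assumes the commonly used ImageNet mean and standard deviation; the values the loaded processor actually uses can be checked by printing `processor`. The helper name `normalize_pixel` is made up for this example.

```python
# Assumed values: the widely used ImageNet channel statistics (R, G, B).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Rescale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    return tuple((value / 255 - m) / s for value, m, s in zip(rgb, mean, std))


print("white:", normalize_pixel((255, 255, 255)))
print("black:", normalize_pixel((0, 0, 0)))
```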
---

### Cell 9: Define evaluation metrics
```python
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")

def compute_metrics(eval_pred):
    """Return accuracy and weighted F1 computed from the model's logits."""
    predictions, labels = eval_pred
    preds = np.argmax(predictions, axis=1)
    acc = accuracy.compute(predictions=preds, references=labels)
    f1_score = f1.compute(predictions=preds, references=labels, average="weighted")
    return {"accuracy": acc["accuracy"], "f1": f1_score["f1"]}

print("✅ Metrics defined")
```

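For intuition about the two numbers `compute_metrics` reports, the sketch below computes accuracy and support-weighted F1 by hand on a toy example. It is a plain-Python illustration, not the `evaluate` implementation; the function name and the toy labels are made up.

```python
from collections import Counter


def accuracy_and_weighted_f1(predictions, references):
    """Compute accuracy and support-weighted F1 for integer class labels."""
    correct = sum(p == r for p, r in zip(predictions, references))
    accuracy = correct / len(references)

    support = Counter(references)  # number of true samples per class
    weighted_f1 = 0.0
    for cls, n_true in support.items():
        tp = sum(p == cls and r == cls for p, r in zip(predictions, references))
        n_pred = sum(p == cls for p in predictions)
        precision = tp / n_pred if n_pred else 0.0
        recall = tp / n_true
        f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
        # Each class contributes in proportion to its share of the references.
        weighted_f1 += f1 * n_true / len(references)
    return accuracy, weighted_f1


refs = [0, 0, 0, 1, 1, 2]
preds = [0, 0, 1, 1, 1, 2]
acc, f1 = accuracy_and_weighted_f1(preds, refs)
print(f"accuracy={acc:.4f}, weighted_f1={f1:.4f}")
```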
---

### Cell 10: Configure training arguments
```python
OUTPUT_REPO = "chaosbee997/rice-seed-classifier"  # ← change this to your own repo name

args = TrainingArguments(
    output_dir="/content/rice-seed-classifier",
    remove_unused_columns=False,  # keep the "image" column for the transform
    eval_strategy="epoch",  # named evaluation_strategy in older transformers releases
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=5,
    warmup_ratio=0.1,
    logging_strategy="steps",
    logging_steps=50,
    logging_first_step=True,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    seed=42,
    push_to_hub=True,
    hub_model_id=OUTPUT_REPO,
    report_to="none",
)

print("Training configuration:")
print(f"  Output repo: {OUTPUT_REPO}")
print(f"  Learning rate: {args.learning_rate}")
print(f"  Batch size: {args.per_device_train_batch_size}")
print(f"  Epochs: {args.num_train_epochs}")
```

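As a rough sanity check on these settings, the number of optimizer steps and warmup steps can be estimated ahead of time. The sketch assumes single-GPU training without gradient accumulation and a training set of 63,750 images (an assumed figure, roughly the 63K mentioned in this guide); `schedule_summary` is a hypothetical helper, not part of `transformers`.

```python
import math


def schedule_summary(num_samples, batch_size, epochs, warmup_ratio):
    """Estimate steps per epoch, total optimizer steps, and warmup steps."""
    # The last, partially filled batch still counts as one step.
    steps_per_epoch = math.ceil(num_samples / batch_size)
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return steps_per_epoch, total_steps, warmup_steps


# 63_750 is an assumed training-set size, not a value read from the dataset.
steps_per_epoch, total_steps, warmup_steps = schedule_summary(
    num_samples=63_750, batch_size=64, epochs=5, warmup_ratio=0.1
)
print(f"steps/epoch={steps_per_epoch}, total={total_steps}, warmup={warmup_steps}")
```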
---

### Cell 11: Start training 🚀
```python
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    data_collator=DefaultDataCollator(),
    processing_class=processor,  # older transformers releases use tokenizer=processor
)

print("Starting training...\n" + "=" * 50)
trainer.train()
```
> ⏱️ Training time: roughly **15-30 minutes** on a T4 GPU (5 epochs, about 63K training samples)
>
> 📊 Expected result: accuracy > **95%**

---

### Cell 12: Evaluate and upload the model
```python
metrics = trainer.evaluate()
print("\nFinal evaluation results:")
print("=" * 50)
for k, v in metrics.items():
    if isinstance(v, float):
        print(f"  {k}: {v:.4f}")
    else:
        print(f"  {k}: {v}")

print("\nUploading to the Hugging Face Hub...")
trainer.push_to_hub()
print("\n✅ Upload complete!")
print(f"🔗 Model URL: https://huggingface.co/{OUTPUT_REPO}")
```

---

### Cell 13: Test inference
```python
from transformers import pipeline
import matplotlib.pyplot as plt

classifier = pipeline("image-classification", model=OUTPUT_REPO)

# Take one image from the dataset as a test input
test_image = ds["train"][100]["image"]

plt.figure(figsize=(5, 5))
plt.imshow(test_image)
plt.axis("off")
plt.title("Test image")
plt.show()

results = classifier(test_image)
print("\n🔮 Predictions:")
for r in results:
    print(f"  {r['label']}: {r['score']*100:.2f}%")
```

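The pipeline returns a list of dictionaries with `label` and `score` keys. If you want a single answer with a confidence cutoff, a small helper along these lines could work; `top_prediction`, the 0.5 threshold, and the sample scores are illustrative assumptions.

```python
def top_prediction(results, min_score=0.5):
    """Return (label, score) of the best prediction, or None if it is below min_score."""
    if not results:
        return None
    best = max(results, key=lambda r: r["score"])
    if best["score"] < min_score:
        return None
    return best["label"], best["score"]


# Made-up output in the same shape as the list returned by classifier(test_image).
sample_results = [
    {"label": "Basmati", "score": 0.91},
    {"label": "Jasmine", "score": 0.06},
    {"label": "Arborio", "score": 0.03},
]
print(top_prediction(sample_results))
```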
---

## ⚡ After training

Your model is uploaded automatically to:
- **https://huggingface.co/chaosbee997/rice-seed-classifier**

Anyone can then load it with just two lines of code:
```python
from transformers import pipeline
classifier = pipeline("image-classification", model="chaosbee997/rice-seed-classifier")
```

---

## 🌽 Extending to peanut, corn, wheat, and other crops

To train a general seed classifier covering **peanut, corn, wheat, and rice**:

1. **Collect images**: create one folder per crop, for example:
```
crop_seeds/
├── peanut/
│   ├── img1.jpg
│   └── img2.jpg
├── corn/
│   ├── img1.jpg
│   └── img2.jpg
├── wheat/
│   ├── img1.jpg
│   └── img2.jpg
└── rice/
    ├── img1.jpg
    └── img2.jpg
```

2. **Create the dataset**:
```python
from datasets import load_dataset
ds = load_dataset("imagefolder", data_dir="/path/to/crop_seeds")
ds.push_to_hub("yourname/crop-seeds")
```

3. **Edit Cell 5**: set `DATASET_NAME` to the name of your dataset

4. **Re-run Cells 5-12**. The only other change to consider is `OUTPUT_REPO` in Cell 10, so that the new model does not overwrite the rice model.

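Before uploading, it can be worth checking that every crop folder actually contains images and that the classes are not badly imbalanced. The following sketch counts image files per class folder in the layout from step 1; the helper name and the list of accepted file extensions are assumptions.

```python
from pathlib import Path

# Assumed set of image extensions; extend it if your files use other formats.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_images_per_class(data_dir):
    """Map each class folder name to the number of image files directly inside it."""
    counts = {}
    for class_dir in sorted(Path(data_dir).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1
                for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
            )
    return counts


data_dir = Path("/path/to/crop_seeds")  # placeholder path from the steps above
if data_dir.is_dir():
    for name, n in count_images_per_class(data_dir).items():
        print(f"{name}: {n} images")
```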
---

## 🆘 FAQ

**Q: Training fails with `CUDA out of memory`?**
A: Reduce the batch size: change `per_device_train_batch_size=64` to `32` or `16`.

**Q: No Hugging Face account?**
A: Sign up for free: [huggingface.co/join](https://huggingface.co/join)

**Q: Training takes too long?**
A: Reduce the number of epochs (`num_train_epochs=3`) or switch to a lighter model such as `google/mobilenet_v2_1.0_224`. Note that `microsoft/resnet-34` is larger than `microsoft/resnet-18`, so it is not a faster option.
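
For the `CUDA out of memory` case above, halving the batch size also halves the effective batch size. One option, sketched below, is to raise `gradient_accumulation_steps` (a `TrainingArguments` parameter) so that the product stays at 64. The helper `accumulation_for` is made up for illustration, and whether keeping the effective batch size fixed matters for this model is untested here.

```python
def accumulation_for(target_batch_size, per_device_batch_size):
    """Gradient-accumulation steps needed to keep the effective batch size."""
    if target_batch_size % per_device_batch_size != 0:
        raise ValueError("per-device batch size must divide the target batch size")
    return target_batch_size // per_device_batch_size


# Keep an effective batch size of 64 while lowering the per-device batch size.
for per_device in (32, 16):
    steps = accumulation_for(64, per_device)
    print(f"per_device_train_batch_size={per_device}, gradient_accumulation_steps={steps}")
```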