chaosbee997 committed on
Commit
2807a2d
·
verified ·
1 Parent(s): 959472a

Upload train_colab.py

Browse files
Files changed (1) hide show
  1. train_colab.py +158 -0
train_colab.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Colab-ready training script for rice seed/grain classification.
3
+
4
+ Quick start in Google Colab:
5
+ 1. Go to https://colab.research.google.com
6
+ 2. Create a new notebook
7
+ 3. Copy each "cell" below into a Colab code cell and run sequentially.
8
+
9
+ Hardware: Runtime → Change runtime type → T4 GPU (recommended)
10
+ Expected time: ~15-30 minutes for 5 epochs on T4.
11
+ """
12
+
13
+ # ===================== CELL 1: Check the GPU =====================
14
+ # !nvidia-smi
15
+
16
+ # ===================== CELL 2: Install dependencies =====================
17
+ # !pip install -q transformers datasets accelerate evaluate pillow huggingface_hub
18
+ # import os
19
+ # os.kill(os.getpid(), 9)  # restart the runtime after installing (optional; recommended for transformers)
20
+
21
+ # ===================== CELL 3: Imports =====================
22
+ import os
23
+ import numpy as np
24
+ from datasets import load_dataset
25
+ from transformers import (
26
+ AutoImageProcessor,
27
+ AutoModelForImageClassification,
28
+ TrainingArguments,
29
+ Trainer,
30
+ DefaultDataCollator,
31
+ )
32
+ from PIL import Image
33
+ import evaluate
34
+ from huggingface_hub import notebook_login
35
+
36
+ # ===================== CELL 4: Log in to HF (needs a token with write permission) =====================
37
+ # notebook_login()  # interactive prompt for entering the token
38
+
39
# ===================== CELL 5: Load the dataset =====================
DATASET_NAME = "nateraw/rice-image-dataset"
print(f"Loading dataset: {DATASET_NAME}...")
ds = load_dataset(DATASET_NAME)
print(ds)

# The class names are read from the "label" feature of the "train" split;
# their order defines the integer id used for each class further down.
labels = ds["train"].features["label"].names
num_labels = len(labels)
print(f"Classes ({num_labels}): {labels}")
48
+
49
# ===================== CELL 6: Train/validation split =====================
# Hold out 15% of the "train" split for validation. The split is stratified
# on the "label" column and uses a fixed seed so it is reproducible.
split_ds = ds["train"].train_test_split(
    test_size=0.15,
    stratify_by_column="label",
    seed=42,
)
train_ds, val_ds = split_ds["train"], split_ds["test"]
print(f"Train: {len(train_ds)} | Val: {len(val_ds)}")
54
+
55
# ===================== CELL 7: Load the model =====================
MODEL_NAME = "microsoft/resnet-18"
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

# Two-way mapping between class names and integer ids, in the order the
# dataset lists its classes.
id2label = dict(enumerate(labels))
label2id = {name: index for index, name in id2label.items()}

model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    # NOTE(review): presumably needed because the pretrained checkpoint's
    # classification head has a different number of outputs than
    # `num_labels` -- confirm against the checkpoint config.
    ignore_mismatched_sizes=True,
)
print(f"Model loaded: {MODEL_NAME} ({sum(p.numel() for p in model.parameters())/1e6:.1f}M params)")
70
+
71
# ===================== CELL 8: Preprocessing function =====================
def transform(example_batch):
    """Turn a raw batch into model inputs.

    Args:
        example_batch: Batch mapping with an "image" column (PIL images or
            objects that ``np.array`` can convert) and a "label" column.

    Returns:
        The output of ``processor`` (called with ``return_tensors="pt"``)
        with the batch's "label" values stored under the "labels" key.
    """

    def _as_rgb(item):
        # Anything that is not already a PIL image goes through NumPy first.
        if not isinstance(item, Image.Image):
            return Image.fromarray(np.array(item)).convert("RGB")
        # PIL images are only converted when they are not RGB already.
        return item if item.mode == "RGB" else item.convert("RGB")

    rgb_images = [_as_rgb(item) for item in example_batch["image"]]
    batch = processor(rgb_images, return_tensors="pt")
    batch["labels"] = example_batch["label"]
    return batch
85
+
86
# Register `transform` as the formatting hook on both splits.
for _split in (train_ds, val_ds):
    _split.set_transform(transform)
88
+
89
# ===================== CELL 9: Evaluation metrics =====================
# Load both metrics once at module level; `compute_metrics` reuses them on
# every evaluation pass.
accuracy, f1 = (evaluate.load(metric_name) for metric_name in ("accuracy", "f1"))
92
+
93
def compute_metrics(eval_pred):
    """Compute accuracy and weighted F1 for one evaluation pass.

    Args:
        eval_pred: Pair ``(predictions, labels)``. The predicted class of
            each sample is taken as the argmax of ``predictions`` along
            axis 1.

    Returns:
        Dict with the keys "accuracy" and "f1".
    """
    scores, references = eval_pred
    predicted_ids = np.argmax(scores, axis=1)
    return {
        "accuracy": accuracy.compute(
            predictions=predicted_ids, references=references
        )["accuracy"],
        "f1": f1.compute(
            predictions=predicted_ids, references=references, average="weighted"
        )["f1"],
    }
99
+
100
# ===================== CELL 10: Training arguments =====================
import inspect

OUTPUT_REPO = "chaosbee997/rice-seed-classifier"  # change this to your own repo name

# transformers renamed `evaluation_strategy` to `eval_strategy`, and newer
# releases no longer accept the old keyword. Use whichever name the installed
# version exposes so this cell runs on both old and new installs.
_eval_strategy_kwarg = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)

args = TrainingArguments(
    output_dir="/content/rice-seed-classifier",
    # The preprocessing hook reads the raw "image" column, so unused columns
    # must not be dropped before it runs.
    remove_unused_columns=False,
    # Evaluate and checkpoint on the same schedule (once per epoch), which
    # `load_best_model_at_end` relies on.
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=5,
    warmup_ratio=0.1,
    logging_strategy="steps",
    logging_steps=50,
    logging_first_step=True,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    seed=42,
    push_to_hub=True,
    hub_model_id=OUTPUT_REPO,
    report_to="none",
    **{_eval_strategy_kwarg: "epoch"},
)
123
+
124
# ===================== CELL 11: Train =====================
import inspect

# Trainer's `tokenizer` argument was renamed to `processing_class`; the old
# name is deprecated and scheduled for removal. Use whichever keyword the
# installed version exposes so the image processor is still passed through
# on both old and new installs.
_processor_kwarg = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    data_collator=DefaultDataCollator(),
    **{_processor_kwarg: processor},
)

print("Starting training...")
trainer.train()
137
+
138
# ===================== CELL 12: Evaluate + upload =====================
# Score the trained model on the validation split, then push it to the Hub
# repository named by OUTPUT_REPO.
metrics = trainer.evaluate()
print("Evaluation:", metrics)

trainer.push_to_hub()
print(f"Model uploaded to: https://huggingface.co/{OUTPUT_REPO}")
143
+
144
# ===================== CELL 13: Inference smoke test =====================
from transformers import pipeline
import matplotlib.pyplot as plt

# Load the model back from the Hub repo it was just pushed to.
classifier = pipeline("image-classification", model=OUTPUT_REPO)

# Taken from the untransformed `ds`, so this is still a raw image.
# NOTE(review): index 100 of the full "train" split may have been part of
# the training subset, so treat this as a sanity check, not an evaluation.
test_image = ds["train"][100]["image"]

plt.figure(figsize=(5,5))
plt.imshow(test_image)
plt.axis("off")
plt.show()

# Print every returned label with its score as a percentage.
results = classifier(test_image)
for r in results:
    print(f" {r['label']}: {r['score']*100:.2f}%")