# rice-seed-classifier / train_colab.py
# chaosbee997's picture
# Upload train_colab.py
# 2807a2d verified
"""
Colab-ready training script for rice seed/grain classification.
Quick start in Google Colab:
1. Go to https://colab.research.google.com
2. Create a new notebook
3. Copy each "cell" below into a Colab code cell and run sequentially.
Hardware: Runtime → Change runtime type → T4 GPU (recommended)
Expected time: ~15-30 minutes for 5 epochs on T4.
"""
# ===================== CELL 1: Check GPU =====================
# !nvidia-smi
# ===================== CELL 2: Install dependencies =====================
# !pip install -q transformers datasets accelerate evaluate pillow huggingface_hub
# import os
# os.kill(os.getpid(), 9) # Restart the runtime after installing (optional; recommended for transformers)
# ===================== CELL 3: Import libraries =====================
import os
import numpy as np
from datasets import load_dataset
from transformers import (
AutoImageProcessor,
AutoModelForImageClassification,
TrainingArguments,
Trainer,
DefaultDataCollator,
)
from PIL import Image
import evaluate
from huggingface_hub import notebook_login
# ===================== CELL 4: Log in to HF (requires a token with Write permission) =====================
# notebook_login() # Interactive prompt for entering the token
# ===================== CELL 5: Load dataset =====================
# Dataset identifier passed to `load_dataset`.
DATASET_NAME = "nateraw/rice-image-dataset"
print(f"Loading dataset: {DATASET_NAME}...")
ds = load_dataset(DATASET_NAME)
print(ds)
# Class names are read from the "label" feature of the "train" split.
# NOTE(review): this assumes that feature exposes `.names` (e.g. a ClassLabel)
# — confirm before pointing DATASET_NAME at a different dataset.
labels = ds["train"].features["label"].names
num_labels = len(labels)
print(f"Classes ({num_labels}): {labels}")
# ===================== CELL 6: Split into train / validation sets =====================
# Hold out 15% of the "train" split for validation, stratified on the
# "label" column, with a fixed seed.
split_ds = ds["train"].train_test_split(
    test_size=0.15,
    stratify_by_column="label",
    seed=42,
)
train_ds, val_ds = split_ds["train"], split_ds["test"]
print(f"Train: {len(train_ds)} | Val: {len(val_ds)}")
# ===================== CELL 7: Load model =====================
MODEL_NAME = "microsoft/resnet-18"
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

# Index <-> class-name lookups handed to the model configuration.
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}

# The model is loaded with `num_labels` outputs for this dataset;
# ignore_mismatched_sizes=True is passed so a size mismatch between the
# checkpoint and the requested configuration does not abort loading.
model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)
print(f"Model loaded: {MODEL_NAME} ({sum(p.numel() for p in model.parameters())/1e6:.1f}M params)")
# ===================== CELL 8: Preprocessing function =====================
def transform(example_batch):
    """Turn a batch of raw examples into model inputs.

    Args:
        example_batch: Mapping with an "image" sequence and a matching
            "label" sequence.

    Returns:
        The result of ``processor(images, return_tensors="pt")`` for the
        RGB-converted images, with the batch's labels stored under
        the "labels" key.
    """
    rgb_images = []
    for candidate in example_batch["image"]:
        if not isinstance(candidate, Image.Image):
            # Not a PIL image: build one from a NumPy array, then force RGB.
            rgb_images.append(Image.fromarray(np.array(candidate)).convert("RGB"))
            continue
        # PIL images are converted only when they are not already RGB.
        rgb_images.append(
            candidate if candidate.mode == "RGB" else candidate.convert("RGB")
        )

    model_inputs = processor(rgb_images, return_tensors="pt")
    model_inputs["labels"] = example_batch["label"]
    return model_inputs
# Register the preprocessing function on both splits.
for _split in (train_ds, val_ds):
    _split.set_transform(transform)
# ===================== CELL 9: Evaluation metrics =====================
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")


def compute_metrics(eval_pred):
    """Compute accuracy and weighted F1 for one evaluation pass.

    Args:
        eval_pred: A pair ``(predictions, labels)``. The predicted class of
            each row is the argmax of ``predictions`` along axis 1.

    Returns:
        A dict with the keys "accuracy" and "f1".
    """
    scores, references = eval_pred
    predicted_ids = np.argmax(scores, axis=1)
    return {
        "accuracy": accuracy.compute(
            predictions=predicted_ids, references=references
        )["accuracy"],
        "f1": f1.compute(
            predictions=predicted_ids, references=references, average="weighted"
        )["f1"],
    }
# ===================== CELL 10: Training arguments =====================
import inspect

OUTPUT_REPO = "chaosbee997/rice-seed-classifier"  # Change this to your own repository name

# The evaluation-schedule argument is called `eval_strategy` in newer
# transformers releases and `evaluation_strategy` in older ones; passing the
# wrong name raises a TypeError. Use whichever name the installed
# TrainingArguments accepts so this cell runs on either.
_eval_strategy_kwarg = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments.__init__).parameters
    else "evaluation_strategy"
)

args = TrainingArguments(
    output_dir="/content/rice-seed-classifier",
    # Keep the raw "image" / "label" columns: `transform` reads them.
    remove_unused_columns=False,
    # Checkpoint once per epoch, on the same schedule as evaluation (set
    # below), which load_best_model_at_end relies on.
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=5,
    warmup_ratio=0.1,
    logging_strategy="steps",
    logging_steps=50,
    logging_first_step=True,
    load_best_model_at_end=True,
    # Matches the "accuracy" key returned by compute_metrics.
    metric_for_best_model="accuracy",
    seed=42,
    push_to_hub=True,
    hub_model_id=OUTPUT_REPO,
    report_to="none",
    # Evaluate once per epoch, under the version-appropriate argument name.
    **{_eval_strategy_kwarg: "epoch"},
)
# ===================== CELL 11: Training =====================
import inspect

# Newer transformers releases take the image processor through
# `processing_class` and deprecate `tokenizer`; older releases only accept
# `tokenizer`. Use whichever name the installed Trainer accepts.
_processor_kwarg = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    data_collator=DefaultDataCollator(),
    # Hand the image processor to the Trainer under the accepted name.
    **{_processor_kwarg: processor},
)
print("Starting training...")
trainer.train()
# ===================== CELL 12: Evaluate + upload =====================
# Run an evaluation pass after training and print the resulting metrics.
metrics = trainer.evaluate()
print("Evaluation:", metrics)
# Upload to the Hub, then print the repository URL built from OUTPUT_REPO.
trainer.push_to_hub()
print(f"Model uploaded to: https://huggingface.co/{OUTPUT_REPO}")
# ===================== CELL 13: Inference test =====================
from transformers import pipeline
import matplotlib.pyplot as plt

# The pipeline is built from the Hub repo id (OUTPUT_REPO), not from the
# in-memory `model` object.
classifier = pipeline("image-classification", model=OUTPUT_REPO)

# Sample image taken from the original, unsplit "train" data (index 100).
test_image = ds["train"][100]["image"]

# Show the image without axes.
plt.figure(figsize=(5, 5))
plt.imshow(test_image)
plt.axis("off")
plt.show()

# Print each returned label with its score as a percentage.
results = classifier(test_image)
for r in results:
    print(f" {r['label']}: {r['score']*100:.2f}%")