chaosbee997 commited on
Commit
c8d1295
·
verified ·
1 Parent(s): 463ffec

Upload train.py

Browse files
Files changed (1) hide show
  1. train.py +132 -0
train.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fine-tune a lightweight image classifier on rice grain (seed) images.
3
+ Dataset: nateraw/rice-image-dataset (5 rice varieties, 75K images)
4
+ Model: microsoft/resnet-18 (lightweight, ~11M params)
5
+
6
+ Usage (with Hugging Face Jobs):
7
+ HF_MODEL_REPO=chaosbee997/rice-seed-classifier python train.py
8
+
9
+ Or submit via hf_jobs with a10g-large / t4-small GPU.
10
+ """
11
+ import os
12
+ import numpy as np
13
+ from datasets import load_dataset
14
+ from transformers import (
15
+ AutoImageProcessor,
16
+ AutoModelForImageClassification,
17
+ TrainingArguments,
18
+ Trainer,
19
+ )
20
+ from transformers import DefaultDataCollator
21
+ from PIL import Image
22
+ import evaluate
23
+
24
+ # ============= CONFIG =============
25
+ DATASET_NAME = "nateraw/rice-image-dataset"
26
+ MODEL_NAME = "microsoft/resnet-18"
27
+ OUTPUT_REPO = os.environ.get("HF_MODEL_REPO", "chaosbee997/rice-seed-classifier")
28
+
29
+ # ============= LOAD DATASET =============
30
+ print(f"Loading dataset: {DATASET_NAME}")
31
+ ds = load_dataset(DATASET_NAME)
32
+ print(ds)
33
+
34
+ split_ds = ds["train"].train_test_split(test_size=0.15, stratify_by_column="label", seed=42)
35
+ train_ds = split_ds["train"]
36
+ val_ds = split_ds["test"]
37
+
38
+ labels = ds["train"].features["label"].names
39
+ num_labels = len(labels)
40
+ label2id = {label: i for i, label in enumerate(labels)}
41
+ id2label = {i: label for i, label in enumerate(labels)}
42
+
43
+ print(f"Classes ({num_labels}): {labels}")
44
+ print(f"Train: {len(train_ds)} | Val: {len(val_ds)}")
45
+
46
+ # ============= LOAD PROCESSOR & MODEL =============
47
+ print(f"Loading model: {MODEL_NAME}")
48
+ processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
49
+ model = AutoModelForImageClassification.from_pretrained(
50
+ MODEL_NAME,
51
+ num_labels=num_labels,
52
+ id2label=id2label,
53
+ label2id=label2id,
54
+ ignore_mismatched_sizes=True,
55
+ )
56
+
57
+ # ============= PREPROCESS =============
58
+ def transform(example_batch):
59
+ images = []
60
+ for img in example_batch["image"]:
61
+ if isinstance(img, Image.Image):
62
+ if img.mode != "RGB":
63
+ img = img.convert("RGB")
64
+ images.append(img)
65
+ else:
66
+ img = Image.fromarray(np.array(img)).convert("RGB")
67
+ images.append(img)
68
+ inputs = processor(images, return_tensors="pt")
69
+ inputs["labels"] = example_batch["label"]
70
+ return inputs
71
+
72
+ train_ds.set_transform(transform)
73
+ val_ds.set_transform(transform)
74
+
75
+ # ============= METRICS =============
76
+ accuracy = evaluate.load("accuracy")
77
+ f1 = evaluate.load("f1")
78
+
79
+ def compute_metrics(eval_pred):
80
+ predictions, labels = eval_pred
81
+ preds = np.argmax(predictions, axis=1)
82
+ acc = accuracy.compute(predictions=preds, references=labels)
83
+ f1_score = f1.compute(predictions=preds, references=labels, average="weighted")
84
+ return {"accuracy": acc["accuracy"], "f1": f1_score["f1"]}
85
+
86
+ # ============= TRAINING ARGS =============
87
+ args = TrainingArguments(
88
+ output_dir="/tmp/rice-seed-classifier",
89
+ remove_unused_columns=False,
90
+ evaluation_strategy="epoch",
91
+ save_strategy="epoch",
92
+ learning_rate=5e-5,
93
+ per_device_train_batch_size=64,
94
+ per_device_eval_batch_size=64,
95
+ num_train_epochs=5,
96
+ warmup_ratio=0.1,
97
+ logging_strategy="steps",
98
+ logging_steps=50,
99
+ logging_first_step=True,
100
+ disable_tqdm=True,
101
+ load_best_model_at_end=True,
102
+ metric_for_best_model="accuracy",
103
+ seed=42,
104
+ push_to_hub=True,
105
+ hub_model_id=OUTPUT_REPO,
106
+ report_to="trackio",
107
+ run_name="rice_resnet18_lr5e-5_bs64",
108
+ project="grain-classification",
109
+ trackio_space_id=os.environ.get("TRACKIO_SPACE_ID", "chaosbee997/mlintern-grain"),
110
+ )
111
+
112
+ # ============= TRAINER =============
113
+ trainer = Trainer(
114
+ model=model,
115
+ args=args,
116
+ train_dataset=train_ds,
117
+ eval_dataset=val_ds,
118
+ compute_metrics=compute_metrics,
119
+ data_collator=DefaultDataCollator(),
120
+ tokenizer=processor,
121
+ )
122
+
123
+ print("Starting training...")
124
+ trainer.train()
125
+
126
+ print("Evaluating...")
127
+ metrics = trainer.evaluate()
128
+ print(metrics)
129
+
130
+ print("Pushing to hub...")
131
+ trainer.push_to_hub()
132
+ print("Done!")