NeuroScan / train.py
Shoaib-33's picture
Deployment Added
f16e7e0
# train.py — training loop with MLflow + DagsHub tracking
import os
import numpy as np
import tensorflow as tf
import mlflow
import mlflow.keras
from mlflow.models.signature import infer_signature
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras import optimizers
from src.utils import get_logger, load_config, plot_history, plot_comparison
from src.data_loader import get_data_generators
from src.models import (build_baseline_cnn, build_mobilenet_tl, build_mobilenet_finetuned,
build_efficientnet, build_optuna_cnn, compile_model)
import optuna
logger = get_logger("train")
# ---------------------------------------------------------------------------
# MLflow setup
# ---------------------------------------------------------------------------
def setup_mlflow(cfg: dict):
ml = cfg["mlflow"]
os.environ["MLFLOW_TRACKING_USERNAME"] = ml["dagshub_username"]
os.environ["MLFLOW_TRACKING_PASSWORD"] = ml["dagshub_token"]
uri = f"https://dagshub.com/{ml['dagshub_username']}/{ml['dagshub_repo']}.mlflow"
mlflow.set_tracking_uri(uri)
mlflow.set_experiment(ml["experiment_name"])
logger.info(f"MLflow → DagsHub: {uri}")
def log_epoch_metrics(history):
"""Log per-epoch metrics to the active MLflow run."""
for epoch, (ta, va, tl, vl) in enumerate(zip(
history.history["accuracy"], history.history["val_accuracy"],
history.history["loss"], history.history["val_loss"],
)):
mlflow.log_metrics({
"train_accuracy" : float(ta),
"val_accuracy" : float(va),
"train_loss" : float(tl),
"val_loss" : float(vl),
}, step=epoch)
def get_callbacks(cfg: dict, checkpoint_path: str):
"""
Standard callbacks for all models.
save_format='h5' is required to avoid EagerTensor JSON
serialization crash with EfficientNet in TF 2.10.
"""
t = cfg["training"]
return [
EarlyStopping(
monitor = "val_accuracy",
patience = t["early_stopping_patience"],
restore_best_weights= True,
verbose = 1,
),
ReduceLROnPlateau(
monitor = "val_loss",
factor = t["reduce_lr_factor"],
patience = t["reduce_lr_patience"],
min_lr = t["min_lr"],
verbose = 1,
),
ModelCheckpoint(
filepath = checkpoint_path,
monitor = "val_accuracy",
save_best_only = True,
save_format = "h5", # ← fixes EagerTensor crash
verbose = 0,
),
]
# ---------------------------------------------------------------------------
# Training functions
# ---------------------------------------------------------------------------
def train_baseline(cfg, train_data, val_data):
image_size = tuple(cfg["data"]["image_size"])
epochs = cfg["training"]["epochs"]
save_dir = cfg["models"]["save_dir"]
os.makedirs(save_dir, exist_ok=True)
model = compile_model(build_baseline_cnn(image_size=image_size))
with mlflow.start_run(run_name="Baseline_CNN"):
mlflow.log_params({
"model_type" : "Baseline CNN",
"filters" : "32-64-128",
"optimizer" : "adam",
"lr" : 0.001,
"epochs" : epochs,
"batch_size" : cfg["data"]["batch_size"],
})
history = model.fit(
train_data, epochs=epochs, validation_data=val_data,
callbacks=get_callbacks(cfg, f"{save_dir}/baseline_best.h5")
)
log_epoch_metrics(history)
mlflow.log_metrics({
"best_val_accuracy" : float(max(history.history["val_accuracy"])),
"best_val_loss" : float(min(history.history["val_loss"])),
})
mlflow.keras.log_model(model, "baseline_cnn")
logger.info("Baseline CNN training complete.")
return model, history
def train_transfer_learning(cfg, train_data, val_data):
image_size = tuple(cfg["data"]["image_size"])
epochs = cfg["training"]["epochs"]
save_dir = cfg["models"]["save_dir"]
model = compile_model(build_mobilenet_tl(image_size=image_size))
with mlflow.start_run(run_name="Transfer_Learning_MobileNetV2"):
mlflow.log_params({
"model_type" : "MobileNetV2 TL",
"base_frozen" : True,
"optimizer" : "adam",
"lr" : 0.001,
"epochs" : epochs,
})
history = model.fit(
train_data, epochs=epochs, validation_data=val_data,
callbacks=get_callbacks(cfg, f"{save_dir}/tl_best.h5")
)
log_epoch_metrics(history)
mlflow.log_metrics({
"best_val_accuracy" : float(max(history.history["val_accuracy"])),
"best_val_loss" : float(min(history.history["val_loss"])),
})
mlflow.keras.log_model(model, "transfer_learning")
logger.info("Transfer Learning training complete.")
return model, history
def train_finetuned(cfg, tl_model, train_data, val_data):
epochs = cfg["training"]["epochs"]
save_dir = cfg["models"]["save_dir"]
build_mobilenet_finetuned(tl_model.layers[0], unfreeze_last=20)
compile_model(tl_model, lr=1e-5)
with mlflow.start_run(run_name="Fine_Tuned_MobileNetV2"):
mlflow.log_params({
"model_type" : "MobileNetV2 Fine-Tuned",
"unfrozen_layers" : 20,
"lr" : 1e-5,
"epochs" : epochs,
})
history = tl_model.fit(
train_data, epochs=epochs, validation_data=val_data,
callbacks=get_callbacks(cfg, f"{save_dir}/ft_best.h5")
)
log_epoch_metrics(history)
mlflow.log_metrics({
"best_val_accuracy" : float(max(history.history["val_accuracy"])),
"best_val_loss" : float(min(history.history["val_loss"])),
})
mlflow.keras.log_model(tl_model, "fine_tuned")
logger.info("Fine-Tuned training complete.")
return tl_model, history
def train_efficientnet(cfg, train_data, val_data):
"""
EfficientNetB0 in TF 2.10 crashes ModelCheckpoint because its internal
rescaling layer stores weights as EagerTensors which cannot be JSON-serialized
during checkpoint saving.
Fix: use a custom callback that calls model.save_weights() instead of
model.save() — weights-only saving never touches the model config JSON,
so EagerTensors are never serialized.
"""
image_size = tuple(cfg["data"]["image_size"])
epochs = cfg["training"]["epochs"]
save_dir = cfg["models"]["save_dir"]
t = cfg["training"]
model = build_efficientnet(image_size=image_size)
model.compile(
optimizer = optimizers.Adam(learning_rate=0.001),
loss = "categorical_crossentropy",
metrics = ["accuracy"],
)
# ── Custom checkpoint: saves weights only (no JSON config serialization) ──
weights_path = os.path.join(save_dir, "effnet_best_weights.h5")
best_val_acc = [0.0] # mutable container so inner class can write to it
class WeightsCheckpoint(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
va = float(logs.get("val_accuracy", 0.0))
if va > best_val_acc[0]:
best_val_acc[0] = va
self.model.save_weights(weights_path)
logger.info(f" EfficientNet weights saved (val_acc={va:.4f})")
with mlflow.start_run(run_name="EfficientNetB0_TL"):
mlflow.log_params({
"model_type" : "EfficientNetB0",
"base_frozen" : True,
"lr" : 0.001,
"epochs" : epochs,
})
history = model.fit(
train_data,
epochs = epochs,
validation_data = val_data,
callbacks = [
EarlyStopping(
monitor = "val_accuracy",
patience = t["early_stopping_patience"],
restore_best_weights = True,
verbose = 1,
),
ReduceLROnPlateau(
monitor = "val_loss",
factor = t["reduce_lr_factor"],
patience = t["reduce_lr_patience"],
min_lr = t["min_lr"],
verbose = 1,
),
WeightsCheckpoint(), # ← replaces ModelCheckpoint entirely
],
)
log_epoch_metrics(history)
mlflow.log_metrics({
"best_val_accuracy" : float(max(history.history["val_accuracy"])),
"best_val_loss" : float(min(history.history["val_loss"])),
})
# Restore best weights
if os.path.exists(weights_path):
model.load_weights(weights_path)
logger.info("Best EfficientNet weights restored.")
# mlflow.keras.log_model crashes for EfficientNetB0 in TF 2.10 because
# mlflow internally calls model.save() which hits the same EagerTensor
# JSON bug. Workaround: save weights manually and log as artifact.
final_weights_path = os.path.join(save_dir, "effnet_final_weights.h5")
model.save_weights(final_weights_path)
mlflow.log_artifact(final_weights_path, artifact_path="efficientnet_weights")
mlflow.set_tag("efficientnet_note",
"Logged as weights-only artifact due to TF2.10 EagerTensor bug")
logger.info(f"EfficientNet weights logged to MLflow → {final_weights_path}")
logger.info("EfficientNetB0 training complete.")
return model, history
# ---------------------------------------------------------------------------
# Optuna hyperparameter search
# ---------------------------------------------------------------------------
def run_optuna(cfg, train_data, val_data):
image_size = tuple(cfg["data"]["image_size"])
n_trials = cfg["optuna"]["n_trials"]
def objective(trial):
params = {
"filters_1" : trial.suggest_categorical("filters_1", [32, 64]),
"filters_2" : trial.suggest_categorical("filters_2", [64, 128]),
"filters_3" : trial.suggest_categorical("filters_3", [128, 256]),
"dense_units" : trial.suggest_categorical("dense_units", [64, 128, 256]),
"dropout" : trial.suggest_float("dropout", 0.2, 0.5),
"lr" : trial.suggest_float("lr", 1e-4, 1e-2, log=True),
}
with mlflow.start_run(run_name=f"trial_{trial.number:02d}", nested=True):
mlflow.log_params(params)
m = compile_model(
build_optuna_cnn(params, image_size=image_size),
lr=params["lr"]
)
h = m.fit(
train_data,
epochs = 10,
validation_data = val_data,
callbacks = [EarlyStopping(monitor="val_accuracy",
patience=3,
restore_best_weights=True)],
verbose = 0,
)
best_val = float(max(h.history["val_accuracy"]))
mlflow.log_metric("best_val_accuracy", best_val)
for ep, (ta, va, tl, vl) in enumerate(zip(
h.history["accuracy"], h.history["val_accuracy"],
h.history["loss"], h.history["val_loss"],
)):
mlflow.log_metrics({
"train_accuracy" : float(ta),
"val_accuracy" : float(va),
"train_loss" : float(tl),
"val_loss" : float(vl),
}, step=ep)
return best_val
optuna.logging.set_verbosity(optuna.logging.WARNING)
study = optuna.create_study(direction="maximize")
with mlflow.start_run(run_name="Optuna_Search_Parent"):
mlflow.log_params({
"n_trials" : n_trials,
"direction" : "maximize",
"objective" : "val_accuracy",
})
study.optimize(objective, n_trials=n_trials)
mlflow.log_metric("best_val_accuracy", float(study.best_value))
mlflow.log_metric("best_trial_number", study.best_trial.number)
mlflow.log_params({f"best_{k}": v for k, v in study.best_params.items()})
logger.info(f"Optuna best val_accuracy : {study.best_value:.4f}")
logger.info(f"Optuna best params : {study.best_params}")
return study
def train_optuna_best(cfg, study, train_data, val_data):
image_size = tuple(cfg["data"]["image_size"])
epochs = cfg["training"]["epochs"]
save_dir = cfg["models"]["save_dir"]
p = study.best_params
model = compile_model(
build_optuna_cnn(p, image_size=image_size),
lr=p["lr"]
)
with mlflow.start_run(run_name="Optuna_Best_CNN_Final") as run:
mlflow.log_params({
**p,
"model_type" : "Optuna Best CNN",
"epochs" : epochs,
})
history = model.fit(
train_data, epochs=epochs, validation_data=val_data,
callbacks=get_callbacks(cfg, f"{save_dir}/optuna_best.h5")
)
log_epoch_metrics(history)
mlflow.log_metrics({
"best_val_accuracy" : float(max(history.history["val_accuracy"])),
"best_val_loss" : float(min(history.history["val_loss"])),
})
sample = train_data[0][0][:1]
signature = infer_signature(sample, model.predict(sample))
mlflow.keras.log_model(model, "optuna_best_cnn", signature=signature)
best_run_id = run.info.run_id
logger.info("Optuna Best CNN training complete.")
return model, history, best_run_id
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
if __name__ == "__main__":
cfg = load_config("config.yaml")
setup_mlflow(cfg)
train_data, val_data, test_data = get_data_generators(cfg)
logger.info("=" * 50)
logger.info("Starting full training pipeline")
logger.info("=" * 50)
baseline_model, history_base = train_baseline(cfg, train_data, val_data)
tl_model, history_tl = train_transfer_learning(cfg, train_data, val_data)
ft_model, history_ft = train_finetuned(cfg, tl_model, train_data, val_data)
eff_model, history_eff = train_efficientnet(cfg, train_data, val_data)
study = run_optuna(cfg, train_data, val_data)
optuna_model, history_optuna, best_run_id = train_optuna_best(
cfg, study, train_data, val_data
)
results = {
"Baseline CNN" : float(max(history_base.history["val_accuracy"])),
"Transfer Learning" : float(max(history_tl.history["val_accuracy"])),
"Fine-Tuned" : float(max(history_ft.history["val_accuracy"])),
"EfficientNetB0" : float(max(history_eff.history["val_accuracy"])),
"Optuna Best CNN" : float(max(history_optuna.history["val_accuracy"])),
}
logger.info("\n" + "=" * 45)
logger.info(f"{'Model':<25} {'Val Acc':>10}")
logger.info("=" * 45)
for name, acc in results.items():
mark = " << best" if acc == max(results.values()) else ""
logger.info(f"{name:<25} {acc:>10.4f}{mark}")
logger.info("=" * 45)
os.makedirs("./logs", exist_ok=True)
plot_comparison(results, save_path="./logs/model_comparison.png")
logger.info("Training pipeline complete.")