| |
|
|
| import os |
| import numpy as np |
| import tensorflow as tf |
| import mlflow |
| import mlflow.keras |
| from mlflow.models.signature import infer_signature |
| from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint |
| from tensorflow.keras import optimizers |
|
|
| from src.utils import get_logger, load_config, plot_history, plot_comparison |
| from src.data_loader import get_data_generators |
| from src.models import (build_baseline_cnn, build_mobilenet_tl, build_mobilenet_finetuned, |
| build_efficientnet, build_optuna_cnn, compile_model) |
| import optuna |
|
|
# Module-level logger shared by every training routine in this file.
logger = get_logger("train")
|
|
|
|
| |
| |
| |
|
|
def setup_mlflow(cfg: dict):
    """Point MLflow at the DagsHub tracking server described in ``cfg['mlflow']``."""
    ml_cfg = cfg["mlflow"]
    # DagsHub authenticates MLflow REST calls via basic-auth environment variables.
    os.environ["MLFLOW_TRACKING_USERNAME"] = ml_cfg["dagshub_username"]
    os.environ["MLFLOW_TRACKING_PASSWORD"] = ml_cfg["dagshub_token"]
    tracking_uri = (
        f"https://dagshub.com/{ml_cfg['dagshub_username']}/"
        f"{ml_cfg['dagshub_repo']}.mlflow"
    )
    mlflow.set_tracking_uri(tracking_uri)
    mlflow.set_experiment(ml_cfg["experiment_name"])
    logger.info(f"MLflow → DagsHub: {tracking_uri}")
|
|
|
|
def log_epoch_metrics(history):
    """Log per-epoch train/val accuracy and loss to the active MLflow run."""
    hist = history.history
    per_epoch = zip(
        hist["accuracy"], hist["val_accuracy"], hist["loss"], hist["val_loss"]
    )
    for step, (train_acc, val_acc, train_loss, val_loss) in enumerate(per_epoch):
        mlflow.log_metrics(
            {
                "train_accuracy": float(train_acc),
                "val_accuracy": float(val_acc),
                "train_loss": float(train_loss),
                "val_loss": float(val_loss),
            },
            step=step,
        )
|
|
|
|
def get_callbacks(cfg: dict, checkpoint_path: str):
    """
    Build the standard callback list shared by all models.

    save_format='h5' is required to avoid EagerTensor JSON
    serialization crash with EfficientNet in TF 2.10.
    """
    train_cfg = cfg["training"]

    # Stop once validation accuracy plateaus, keeping the best weights.
    early_stop = EarlyStopping(
        monitor="val_accuracy",
        patience=train_cfg["early_stopping_patience"],
        restore_best_weights=True,
        verbose=1,
    )
    # Decay the learning rate when validation loss stalls.
    lr_schedule = ReduceLROnPlateau(
        monitor="val_loss",
        factor=train_cfg["reduce_lr_factor"],
        patience=train_cfg["reduce_lr_patience"],
        min_lr=train_cfg["min_lr"],
        verbose=1,
    )
    # Persist only the best-so-far model (by validation accuracy) to disk.
    checkpoint = ModelCheckpoint(
        filepath=checkpoint_path,
        monitor="val_accuracy",
        save_best_only=True,
        save_format="h5",
        verbose=0,
    )
    return [early_stop, lr_schedule, checkpoint]
|
|
|
|
| |
| |
| |
|
|
def train_baseline(cfg, train_data, val_data):
    """Train the baseline CNN, logging params/metrics/model to MLflow.

    Returns the trained model and its Keras ``History`` object.
    """
    img_size = tuple(cfg["data"]["image_size"])
    n_epochs = cfg["training"]["epochs"]
    out_dir = cfg["models"]["save_dir"]
    os.makedirs(out_dir, exist_ok=True)

    model = compile_model(build_baseline_cnn(image_size=img_size))

    with mlflow.start_run(run_name="Baseline_CNN"):
        mlflow.log_params(
            {
                "model_type": "Baseline CNN",
                "filters": "32-64-128",
                "optimizer": "adam",
                "lr": 0.001,
                "epochs": n_epochs,
                "batch_size": cfg["data"]["batch_size"],
            }
        )
        history = model.fit(
            train_data,
            epochs=n_epochs,
            validation_data=val_data,
            callbacks=get_callbacks(cfg, f"{out_dir}/baseline_best.h5"),
        )
        log_epoch_metrics(history)
        # Summary metrics for the whole run.
        mlflow.log_metrics(
            {
                "best_val_accuracy": float(max(history.history["val_accuracy"])),
                "best_val_loss": float(min(history.history["val_loss"])),
            }
        )
        mlflow.keras.log_model(model, "baseline_cnn")
        logger.info("Baseline CNN training complete.")
    return model, history
|
|
|
|
def train_transfer_learning(cfg, train_data, val_data):
    """Train a frozen-base MobileNetV2 transfer-learning model.

    Logs hyper-parameters, per-epoch and summary metrics, and the final
    model to an MLflow run named ``Transfer_Learning_MobileNetV2``.

    Returns the trained model and its Keras ``History`` object.
    """
    image_size = tuple(cfg["data"]["image_size"])
    epochs = cfg["training"]["epochs"]
    save_dir = cfg["models"]["save_dir"]
    # FIX: ensure the checkpoint directory exists (train_baseline does this,
    # but this function previously relied on it having run first — the
    # ModelCheckpoint callback fails if save_dir is missing).
    os.makedirs(save_dir, exist_ok=True)

    model = compile_model(build_mobilenet_tl(image_size=image_size))

    with mlflow.start_run(run_name="Transfer_Learning_MobileNetV2"):
        mlflow.log_params({
            "model_type": "MobileNetV2 TL",
            "base_frozen": True,
            "optimizer": "adam",
            "lr": 0.001,
            "epochs": epochs,
        })
        history = model.fit(
            train_data,
            epochs=epochs,
            validation_data=val_data,
            callbacks=get_callbacks(cfg, f"{save_dir}/tl_best.h5"),
        )
        log_epoch_metrics(history)
        # Run-level summary metrics.
        mlflow.log_metrics({
            "best_val_accuracy": float(max(history.history["val_accuracy"])),
            "best_val_loss": float(min(history.history["val_loss"])),
        })
        mlflow.keras.log_model(model, "transfer_learning")
        logger.info("Transfer Learning training complete.")
    return model, history
|
|
|
|
def train_finetuned(cfg, tl_model, train_data, val_data):
    """Fine-tune the transfer-learning model by unfreezing base layers.

    NOTE(review): this mutates ``tl_model`` in place — the return value of
    ``build_mobilenet_finetuned`` is discarded, so the call presumably
    unfreezes layers on the passed base model as a side effect; confirm
    against ``src.models``.

    Returns the (mutated) model and its Keras ``History`` object.
    """
    epochs = cfg["training"]["epochs"]
    save_dir = cfg["models"]["save_dir"]
    # FIX: ensure the checkpoint directory exists before ModelCheckpoint writes.
    os.makedirs(save_dir, exist_ok=True)

    # Unfreeze the last 20 layers of the base network, then recompile with a
    # much lower learning rate so fine-tuning does not destroy pretrained weights.
    build_mobilenet_finetuned(tl_model.layers[0], unfreeze_last=20)
    compile_model(tl_model, lr=1e-5)

    with mlflow.start_run(run_name="Fine_Tuned_MobileNetV2"):
        mlflow.log_params({
            "model_type": "MobileNetV2 Fine-Tuned",
            "unfrozen_layers": 20,
            "lr": 1e-5,
            "epochs": epochs,
        })
        history = tl_model.fit(
            train_data,
            epochs=epochs,
            validation_data=val_data,
            callbacks=get_callbacks(cfg, f"{save_dir}/ft_best.h5"),
        )
        log_epoch_metrics(history)
        # Run-level summary metrics.
        mlflow.log_metrics({
            "best_val_accuracy": float(max(history.history["val_accuracy"])),
            "best_val_loss": float(min(history.history["val_loss"])),
        })
        mlflow.keras.log_model(tl_model, "fine_tuned")
        logger.info("Fine-Tuned training complete.")
    return tl_model, history
|
|
|
|
def train_efficientnet(cfg, train_data, val_data):
    """Train a frozen-base EfficientNetB0 with weights-only checkpointing.

    EfficientNetB0 in TF 2.10 crashes ModelCheckpoint because its internal
    rescaling layer stores weights as EagerTensors which cannot be
    JSON-serialized during checkpoint saving.

    Fix: use a custom callback that calls ``model.save_weights()`` instead of
    ``model.save()`` — weights-only saving never touches the model config
    JSON, so EagerTensors are never serialized. For the same reason the model
    is logged to MLflow as a weights artifact rather than via
    ``mlflow.keras.log_model``.

    Returns the trained model (best weights restored) and its ``History``.
    """
    image_size = tuple(cfg["data"]["image_size"])
    epochs = cfg["training"]["epochs"]
    save_dir = cfg["models"]["save_dir"]
    t = cfg["training"]
    # FIX: ensure save_dir exists — this function writes weights files into it
    # directly and previously assumed an earlier routine had created it.
    os.makedirs(save_dir, exist_ok=True)

    model = build_efficientnet(image_size=image_size)
    model.compile(
        optimizer=optimizers.Adam(learning_rate=0.001),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )

    weights_path = os.path.join(save_dir, "effnet_best_weights.h5")
    # Single-element list so the nested callback class can mutate it (closure).
    best_val_acc = [0.0]

    class WeightsCheckpoint(tf.keras.callbacks.Callback):
        """Save weights only (never config JSON) when val_accuracy improves."""

        def on_epoch_end(self, epoch, logs=None):
            va = float(logs.get("val_accuracy", 0.0))
            if va > best_val_acc[0]:
                best_val_acc[0] = va
                self.model.save_weights(weights_path)
                logger.info(f" EfficientNet weights saved (val_acc={va:.4f})")

    with mlflow.start_run(run_name="EfficientNetB0_TL"):
        mlflow.log_params({
            "model_type": "EfficientNetB0",
            "base_frozen": True,
            "lr": 0.001,
            "epochs": epochs,
        })

        history = model.fit(
            train_data,
            epochs=epochs,
            validation_data=val_data,
            callbacks=[
                EarlyStopping(
                    monitor="val_accuracy",
                    patience=t["early_stopping_patience"],
                    restore_best_weights=True,
                    verbose=1,
                ),
                ReduceLROnPlateau(
                    monitor="val_loss",
                    factor=t["reduce_lr_factor"],
                    patience=t["reduce_lr_patience"],
                    min_lr=t["min_lr"],
                    verbose=1,
                ),
                # Replaces ModelCheckpoint — see docstring for why.
                WeightsCheckpoint(),
            ],
        )

        log_epoch_metrics(history)
        mlflow.log_metrics({
            "best_val_accuracy": float(max(history.history["val_accuracy"])),
            "best_val_loss": float(min(history.history["val_loss"])),
        })

        # Restore the best weights seen during training before exporting.
        if os.path.exists(weights_path):
            model.load_weights(weights_path)
            logger.info("Best EfficientNet weights restored.")

        # Log weights as a plain artifact instead of mlflow.keras.log_model,
        # which would hit the same EagerTensor serialization bug.
        final_weights_path = os.path.join(save_dir, "effnet_final_weights.h5")
        model.save_weights(final_weights_path)
        mlflow.log_artifact(final_weights_path, artifact_path="efficientnet_weights")
        mlflow.set_tag("efficientnet_note",
                       "Logged as weights-only artifact due to TF2.10 EagerTensor bug")
        logger.info(f"EfficientNet weights logged to MLflow → {final_weights_path}")
        logger.info("EfficientNetB0 training complete.")
    return model, history
|
|
|
|
| |
| |
| |
|
|
def run_optuna(cfg, train_data, val_data):
    """Run an Optuna hyper-parameter search over the custom CNN.

    Each trial trains a short (10-epoch) model and is logged as a nested
    MLflow run under a single parent run; the study maximizes the trial's
    best validation accuracy. Returns the completed ``optuna.Study``.
    """
    image_size = tuple(cfg["data"]["image_size"])
    n_trials = cfg["optuna"]["n_trials"]

    def objective(trial):
        # Search space: conv filter widths, dense size, dropout, learning rate.
        params = {
            "filters_1": trial.suggest_categorical("filters_1", [32, 64]),
            "filters_2": trial.suggest_categorical("filters_2", [64, 128]),
            "filters_3": trial.suggest_categorical("filters_3", [128, 256]),
            "dense_units": trial.suggest_categorical("dense_units", [64, 128, 256]),
            "dropout": trial.suggest_float("dropout", 0.2, 0.5),
            "lr": trial.suggest_float("lr", 1e-4, 1e-2, log=True),
        }

        # nested=True keeps each trial under the parent run started below.
        with mlflow.start_run(run_name=f"trial_{trial.number:02d}", nested=True):
            mlflow.log_params(params)

            m = compile_model(
                build_optuna_cnn(params, image_size=image_size),
                lr=params["lr"]
            )
            # Short budget per trial: 10 epochs, early-stopped on val_accuracy.
            h = m.fit(
                train_data,
                epochs=10,
                validation_data=val_data,
                callbacks=[EarlyStopping(monitor="val_accuracy",
                                         patience=3,
                                         restore_best_weights=True)],
                verbose=0,
            )

            best_val = float(max(h.history["val_accuracy"]))
            mlflow.log_metric("best_val_accuracy", best_val)

            # Per-epoch curves for the trial (same shape as log_epoch_metrics,
            # but logged to this nested run).
            for ep, (ta, va, tl, vl) in enumerate(zip(
                h.history["accuracy"], h.history["val_accuracy"],
                h.history["loss"], h.history["val_loss"],
            )):
                mlflow.log_metrics({
                    "train_accuracy": float(ta),
                    "val_accuracy": float(va),
                    "train_loss": float(tl),
                    "val_loss": float(vl),
                }, step=ep)

            return best_val

    # Quieten Optuna's per-trial INFO spam; MLflow carries the details.
    optuna.logging.set_verbosity(optuna.logging.WARNING)
    study = optuna.create_study(direction="maximize")

    with mlflow.start_run(run_name="Optuna_Search_Parent"):
        mlflow.log_params({
            "n_trials": n_trials,
            "direction": "maximize",
            "objective": "val_accuracy",
        })
        study.optimize(objective, n_trials=n_trials)
        # Summary of the winning trial on the parent run.
        mlflow.log_metric("best_val_accuracy", float(study.best_value))
        mlflow.log_metric("best_trial_number", study.best_trial.number)
        mlflow.log_params({f"best_{k}": v for k, v in study.best_params.items()})

    logger.info(f"Optuna best val_accuracy : {study.best_value:.4f}")
    logger.info(f"Optuna best params : {study.best_params}")
    return study
|
|
|
|
def train_optuna_best(cfg, study, train_data, val_data):
    """Retrain the best Optuna configuration for the full epoch budget.

    Logs the final model with an inferred input/output signature and returns
    ``(model, history, run_id)`` so callers can reference the MLflow run.
    """
    image_size = tuple(cfg["data"]["image_size"])
    epochs = cfg["training"]["epochs"]
    save_dir = cfg["models"]["save_dir"]
    p = study.best_params
    # FIX: ensure the checkpoint directory exists (previously assumed an
    # earlier training routine had created it).
    os.makedirs(save_dir, exist_ok=True)

    model = compile_model(
        build_optuna_cnn(p, image_size=image_size),
        lr=p["lr"]
    )

    with mlflow.start_run(run_name="Optuna_Best_CNN_Final") as run:
        mlflow.log_params({
            **p,
            "model_type": "Optuna Best CNN",
            "epochs": epochs,
        })
        history = model.fit(
            train_data,
            epochs=epochs,
            validation_data=val_data,
            callbacks=get_callbacks(cfg, f"{save_dir}/optuna_best.h5"),
        )
        log_epoch_metrics(history)
        mlflow.log_metrics({
            "best_val_accuracy": float(max(history.history["val_accuracy"])),
            "best_val_loss": float(min(history.history["val_loss"])),
        })
        # Infer the model signature from one sample image.
        # NOTE(review): assumes train_data is index-able like a Keras Sequence
        # (train_data[0] -> (x_batch, y_batch)) — confirm against data_loader.
        sample = train_data[0][0][:1]
        signature = infer_signature(sample, model.predict(sample))
        mlflow.keras.log_model(model, "optuna_best_cnn", signature=signature)
        best_run_id = run.info.run_id

    logger.info("Optuna Best CNN training complete.")
    return model, history, best_run_id
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Full pipeline: configure MLflow, load data, train all five model
    # variants, then log a side-by-side validation-accuracy comparison.
    cfg = load_config("config.yaml")
    setup_mlflow(cfg)

    train_data, val_data, test_data = get_data_generators(cfg)

    logger.info("=" * 50)
    logger.info("Starting full training pipeline")
    logger.info("=" * 50)

    baseline_model, history_base = train_baseline(cfg, train_data, val_data)
    tl_model, history_tl = train_transfer_learning(cfg, train_data, val_data)
    # train_finetuned mutates tl_model in place; ft_model is the same object.
    ft_model, history_ft = train_finetuned(cfg, tl_model, train_data, val_data)
    eff_model, history_eff = train_efficientnet(cfg, train_data, val_data)
    study = run_optuna(cfg, train_data, val_data)
    optuna_model, history_optuna, best_run_id = train_optuna_best(
        cfg, study, train_data, val_data
    )

    # Best validation accuracy achieved by each model.
    results = {
        "Baseline CNN": float(max(history_base.history["val_accuracy"])),
        "Transfer Learning": float(max(history_tl.history["val_accuracy"])),
        "Fine-Tuned": float(max(history_ft.history["val_accuracy"])),
        "EfficientNetB0": float(max(history_eff.history["val_accuracy"])),
        "Optuna Best CNN": float(max(history_optuna.history["val_accuracy"])),
    }

    logger.info("\n" + "=" * 45)
    logger.info(f"{'Model':<25} {'Val Acc':>10}")
    logger.info("=" * 45)
    # FIX: compute the best score once instead of re-running max() on every
    # loop iteration.
    best_acc = max(results.values())
    for name, acc in results.items():
        mark = " << best" if acc == best_acc else ""
        logger.info(f"{name:<25} {acc:>10.4f}{mark}")
    logger.info("=" * 45)

    os.makedirs("./logs", exist_ok=True)
    plot_comparison(results, save_path="./logs/model_comparison.png")
    logger.info("Training pipeline complete.")