| """ |
| Train a TensorFlow regression model to predict age from face images (UTKFace dataset). |
| |
| Usage: |
| - Put UTKFace images into a folder, e.g. data/UTKFace/ |
| - python train.py --dataset_dir data/UTKFace --epochs 30 --batch_size 32 |
| |
| The script extracts the age from the filename (before the first underscore). |
| """ |
|
|
| import os |
| import argparse |
| import random |
| import math |
| import zipfile |
| from pathlib import Path |
|
|
| import numpy as np |
| from tqdm import tqdm |
| import requests |
|
|
| import tensorflow as tf |
| from tensorflow import keras |
|
|
|
|
def parse_args():
    """Parse command-line options for the age-regression training run."""
    parser = argparse.ArgumentParser(description="Train an age regression model on UTKFace images")

    def _bool(value):
        # Accept "true"/"1"/"yes" (any case) as True; anything else is False.
        return str(value).lower() in ("true", "1", "yes")

    parser.add_argument("--dataset_dir", type=str, default="data/UTKFace",
                        help="Path to folder containing UTKFace images")
    parser.add_argument("--img_size", type=int, default=224, help="Image size (square)")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--val_split", type=float, default=0.12,
                        help="Fraction to reserve for validation")
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--auto_download", type=_bool, default=False,
                        help="Whether to attempt to download UTKFace archive automatically if dataset folder is missing")
    parser.add_argument("--fine_tune", type=_bool, default=False,
                        help="Whether to unfreeze part of the backbone for fine-tuning")
    return parser.parse_args()
|
|
|
|
def attempt_download_utkface(dest_dir: Path):
    """Best-effort download and extraction of the UTKFace GitHub archive.

    Downloads the repository ZIP into ``dest_dir``, extracts it, moves any
    ``.jpg``/``.png`` images found in the extracted tree up into ``dest_dir``
    (flat layout expected by the rest of the pipeline), and removes the
    temporary ZIP. Failures are printed, never raised, so the caller can fall
    back to asking for a manual download.

    NOTE(review): the GitHub mirror may not contain the full image set (the
    official dataset is hosted elsewhere) — verify the download actually
    yields images.

    Args:
        dest_dir: Target directory; created (with parents) if missing.
    """
    dest_dir.mkdir(parents=True, exist_ok=True)
    github_zip = "https://github.com/susanqq/UTKFace/archive/refs/heads/master.zip"
    tmp_zip = dest_dir / "utkface_master.zip"
    print(f"Attempting to download UTKFace from {github_zip} ...")

    try:
        with requests.get(github_zip, stream=True, timeout=30) as r:
            r.raise_for_status()
            # Stream to disk with a tqdm progress bar; when the server sends
            # no content-length, tqdm falls back to a plain byte counter.
            total = int(r.headers.get('content-length', 0))
            with open(tmp_zip, 'wb') as f, \
                    tqdm(total=total or None, unit='B', unit_scale=True, desc='UTKFace.zip') as bar:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
                        bar.update(len(chunk))

        print("Download complete. Extracting archive...")
        with zipfile.ZipFile(tmp_zip, 'r') as z:
            z.extractall(dest_dir)

        # GitHub archives extract into a "UTKFace-<branch>" subfolder; locate
        # it and flatten its images into dest_dir.
        extracted_root = None
        for name in os.listdir(dest_dir):
            if name.lower().startswith('utkface') and os.path.isdir(dest_dir / name):
                extracted_root = dest_dir / name
                break
        if extracted_root:
            images = list(extracted_root.rglob('*.jpg')) + list(extracted_root.rglob('*.png'))
            for p in images:
                target = dest_dir / p.name
                try:
                    os.replace(p, target)
                except OSError:
                    # Best-effort: skip images that cannot be moved.
                    pass

        try:
            os.remove(tmp_zip)
        except OSError:
            # Leaving the temporary ZIP behind is harmless.
            pass
        print("UTKFace images should now be in:", dest_dir)
    except (requests.RequestException, zipfile.BadZipFile, OSError) as e:
        # Deliberately swallowed: the download is opportunistic; the caller
        # raises a clear error later if the folder is still empty.
        print("Automatic download failed:", e)
        print("Please download the UTKFace archive manually and place images in the dataset directory.")
|
|
|
|
def collect_image_paths_and_labels(dataset_dir: Path):
    """Scan ``dataset_dir`` for UTKFace images and extract age labels.

    UTKFace filenames encode the age before the first underscore, e.g.
    ``25_0_0_20170116.jpg``. Files whose prefix is not an integer (and
    non-image files) are skipped silently.

    Args:
        dataset_dir: Directory whose top level is scanned (non-recursive).

    Returns:
        Tuple ``(img_paths, labels)`` of parallel lists: image path strings
        and integer ages.
    """
    img_paths = []
    labels = []
    supported_ext = ('.jpg', '.jpeg', '.png')
    for p in dataset_dir.iterdir():
        if not (p.is_file() and p.suffix.lower() in supported_ext):
            continue
        try:
            # Age is the token before the first underscore in the filename.
            age = int(p.name.split('_')[0])
        except ValueError:
            # Not a UTKFace-style name; skip rather than abort the scan.
            continue
        img_paths.append(str(p))
        labels.append(age)
    return img_paths, labels
|
|
|
|
def make_dataset(paths, labels, img_size, batch_size, is_training=True):
    """Build a batched ``tf.data`` pipeline yielding (image, age) pairs.

    Args:
        paths: Sequence of image file paths.
        labels: Sequence of ages (cast to float32 for regression).
        img_size: Side length images are resized to (square).
        batch_size: Batch size.
        is_training: When True, shuffle and apply data augmentation.

    Returns:
        A ``tf.data.Dataset`` of batches: float32 images of shape
        (img_size, img_size, 3) scaled to [0, 1], and float32 ages.
    """
    paths = tf.convert_to_tensor(paths)
    labels = tf.convert_to_tensor(labels, dtype=tf.float32)

    ds = tf.data.Dataset.from_tensor_slices((paths, labels))
    if is_training:
        ds = ds.shuffle(10000, reshuffle_each_iteration=True)

    def _load_image(path, label):
        img = tf.io.read_file(path)
        # BUGFIX: the collector accepts .png as well as .jpg, but decode_jpeg
        # fails on PNG data. decode_image handles both formats;
        # expand_animations=False guarantees a static 3-D tensor so resize()
        # receives a known rank.
        img = tf.io.decode_image(img, channels=3, expand_animations=False)
        img = tf.image.resize(img, [img_size, img_size])
        img = img / 255.0
        if is_training:
            img = data_augmentation(img)
        return img, label

    ds = ds.map(_load_image, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return ds
|
|
|
|
def data_augmentation(image):
    """Apply random photometric and geometric augmentation to one image.

    Expects a float32 image scaled to [0, 1] (see ``make_dataset``). Applies
    a random horizontal flip and small brightness/contrast jitter, then —
    roughly 40% of the time — a random crop of 80–100% of the frame resized
    back to the original size.

    Returns:
        The augmented image, clipped back into [0, 1].
    """
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.08)
    image = tf.image.random_contrast(image, 0.9, 1.1)
    # BUGFIX: brightness/contrast jitter can push values outside [0, 1];
    # clip so the model always sees inputs in the range the pipeline promises.
    image = tf.clip_by_value(image, 0.0, 1.0)

    if tf.random.uniform(()) > 0.6:
        crop_frac = tf.random.uniform((), 0.8, 1.0)
        shape = tf.shape(image)
        crop_h = tf.cast(tf.cast(shape[0], tf.float32) * crop_frac, tf.int32)
        crop_w = tf.cast(tf.cast(shape[1], tf.float32) * crop_frac, tf.int32)
        image = tf.image.random_crop(image, size=[crop_h, crop_w, 3])
        # Resize back so downstream batching sees a constant spatial size.
        image = tf.image.resize(image, [shape[0], shape[1]])
    return image
|
|
|
|
def build_model(img_size, fine_tune=False):
    """Assemble a MobileNetV2-based age regressor.

    A frozen ImageNet MobileNetV2 backbone feeds a small dense head ending in
    a single linear unit (predicted age). With ``fine_tune=True``, the last
    30 backbone layers are unfrozen for end-to-end training.

    Args:
        img_size: Input side length; input shape is (img_size, img_size, 3).
        fine_tune: Unfreeze the tail of the backbone when True.

    Returns:
        An uncompiled ``keras.Model``.
    """
    inputs = keras.Input(shape=(img_size, img_size, 3))
    backbone = keras.applications.MobileNetV2(include_top=False, input_tensor=inputs, weights='imagenet')
    backbone.trainable = False

    # Regression head: pooled features -> dropout -> two ReLU layers -> age.
    features = keras.layers.GlobalAveragePooling2D()(backbone.output)
    features = keras.layers.Dropout(0.2)(features)
    features = keras.layers.Dense(128, activation='relu')(features)
    features = keras.layers.Dense(64, activation='relu')(features)
    age = keras.layers.Dense(1, name='age')(features)

    model = keras.Model(inputs=inputs, outputs=age)

    if fine_tune:
        # Unfreeze the whole backbone, then re-freeze all but its last 30 layers.
        backbone.trainable = True
        for layer in backbone.layers[:-30]:
            layer.trainable = False

    return model
|
|
|
|
def main():
    """Train, evaluate, export, and visually spot-check the age regressor."""
    args = parse_args()
    dataset_dir = Path(args.dataset_dir)

    # Best-effort download when the folder is missing/empty and the user
    # opted in via --auto_download.
    if (not dataset_dir.exists() or not any(dataset_dir.iterdir())) and args.auto_download:
        attempt_download_utkface(dataset_dir)

    if not dataset_dir.exists() or not any(dataset_dir.iterdir()):
        raise RuntimeError(f"No images found in {dataset_dir}. Place UTKFace images there or use --auto_download True to attempt download.")

    paths, labels = collect_image_paths_and_labels(dataset_dir)
    if len(paths) == 0:
        raise RuntimeError("No valid UTKFace images found in dataset directory. Ensure the files follow the naming convention '<age>_...'.")

    paths = np.array(paths)
    labels = np.array(labels, dtype=np.float32)

    # Shuffle once, then carve the validation set off the front.
    # NOTE(review): np.random is not seeded, so the split differs every run.
    indices = np.arange(len(paths))
    np.random.shuffle(indices)
    paths = paths[indices]
    labels = labels[indices]

    # At least one validation sample even for tiny datasets.
    n_val = max(1, int(len(paths) * args.val_split))
    val_paths = paths[:n_val].tolist()
    val_labels = labels[:n_val].tolist()
    train_paths = paths[n_val:].tolist()
    train_labels = labels[n_val:].tolist()

    print(f"Found {len(train_paths)} training images and {len(val_paths)} validation images.")

    train_ds = make_dataset(train_paths, train_labels, args.img_size, args.batch_size, is_training=True)
    val_ds = make_dataset(val_paths, val_labels, args.img_size, args.batch_size, is_training=False)

    # MSE loss for regression; MAE reported (in years) for interpretability.
    model = build_model(args.img_size, fine_tune=args.fine_tune)
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=args.learning_rate),
                  loss='mse',
                  metrics=[keras.metrics.MeanAbsoluteError(name='mae')])

    model.summary()

    # Keep the best weights on disk, stop early when val_loss stalls, and
    # halve the learning rate on plateaus.
    callbacks = [
        keras.callbacks.ModelCheckpoint('best_model.h5', save_best_only=True, monitor='val_loss'),
        keras.callbacks.EarlyStopping(monitor='val_loss', patience=8, restore_best_weights=True),
        keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=4, min_lr=1e-7)
    ]

    # history retained for potential inspection; not used below.
    history = model.fit(train_ds, validation_data=val_ds, epochs=args.epochs, callbacks=callbacks)

    print("Evaluating on validation set:")
    eval_res = model.evaluate(val_ds)
    print(dict(zip(model.metrics_names, eval_res)))

    # Export fallbacks: SavedModel -> native .keras -> legacy HDF5, since
    # availability of each format depends on the installed TF/Keras version.
    try:
        model.export('saved_model_age_regressor')
        print('Exported SavedModel to ./saved_model_age_regressor')
    except Exception as e:
        print('SavedModel export failed:', e)

        try:
            model.save('saved_model_age_regressor.keras')
            print('Saved Keras model to ./saved_model_age_regressor.keras')
        except Exception as e2:
            print('Keras native save failed:', e2)

            try:
                model.save('final_model.h5')
                print('Saved HDF5 model to ./final_model.h5')
            except Exception as e3:
                print('HDF5 save failed:', e3)

    # Qualitative sanity check: predict on a dozen validation images.
    sample_paths = val_paths[:12]
    sample_labels = val_labels[:12]

    sample_ds = make_dataset(sample_paths, sample_labels, args.img_size, batch_size=12, is_training=False)
    imgs, labs = next(iter(sample_ds))
    preds = model.predict(imgs).flatten()

    # Matplotlib is optional; skip plotting on headless machines.
    try:
        import matplotlib.pyplot as plt
        n = len(preds)
        cols = 4
        rows = math.ceil(n / cols)
        plt.figure(figsize=(cols * 3, rows * 3))
        for i in range(n):
            ax = plt.subplot(rows, cols, i + 1)
            img = imgs[i].numpy()
            plt.imshow(img)
            plt.axis('off')
            plt.title(f"True: {int(labs[i])}\nPred: {preds[i]:.1f}")
        plt.tight_layout()
        plt.show()
    except Exception:
        print("Matplotlib not available or running headless; skipping sample visualization.")
|
|