| """ |
| Load a trained age regression model and run a prediction on a single image. |
| |
| Usage: python predict.py --model_path saved_model_age_regressor --image_path some_image.jpg |
| """ |
| import argparse |
| from pathlib import Path |
|
|
| import numpy as np |
| from PIL import Image |
| import tensorflow as tf |
|
|
|
|
def parse_args():
    """Parse the command-line options for single-image age prediction.

    Returns:
        argparse.Namespace with ``model_path``, ``image_path``, ``img_size``
        and ``output_key`` attributes.
    """
    # (flag, add_argument keyword arguments) in the order they are registered.
    arg_specs = [
        ('--model_path', {'type': str, 'default': 'saved_model_age_regressor'}),
        ('--image_path', {'type': str, 'required': True}),
        ('--img_size', {'type': int, 'default': 224}),
        ('--output_key', {
            'type': str,
            'default': None,
            'help': ('If the model returns a dict, select this key for the '
                     'numeric prediction. If omitted the first numeric '
                     'output will be used.'),
        }),
    ]
    cli = argparse.ArgumentParser()
    for flag, spec in arg_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
|
|
|
|
def load_image(path, img_size):
    """Load an image file as a normalized RGB float32 array.

    Args:
        path: Path (str or ``pathlib.Path``) of the image to read.
        img_size: Side length in pixels; the image is resized to a square
            ``img_size`` x ``img_size`` (aspect ratio is not preserved).

    Returns:
        np.ndarray of shape (img_size, img_size, 3), dtype float32, with
        values scaled from 0-255 to [0, 1].
    """
    # Open via a context manager so the underlying file handle is released
    # deterministically instead of whenever Pillow/GC decides to close it.
    # convert() returns a new, fully loaded image, so it outlives the `with`.
    with Image.open(path) as src:
        img = src.convert('RGB').resize((img_size, img_size))
    return np.array(img, dtype=np.float32) / 255.0
|
|
|
|
def _load_model(model_path, img_size):
    """Load a model from ``model_path`` (Keras file, Keras dir, or TF SavedModel dir).

    Args:
        model_path: ``pathlib.Path`` of the model file or directory.
        img_size: Square input side length; used only when a SavedModel has to
            be wrapped in a ``TFSMLayer`` and needs an explicit input shape.

    Returns:
        A Keras model exposing ``predict``.

    Raises:
        FileNotFoundError: ``model_path`` does not exist.
        RuntimeError: ``model_path`` is a directory that neither Keras nor
            ``TFSMLayer`` could load (both underlying errors are reported).
    """
    if model_path.is_file() and model_path.suffix.lower() in ('.h5', '.keras'):
        model = tf.keras.models.load_model(str(model_path), compile=False)
        print(f"Loaded Keras model file: {model_path}")
        return model

    if model_path.is_dir():
        try:
            model = tf.keras.models.load_model(str(model_path), compile=False)
        except Exception as keras_err:
            # Fall back to treating the directory as a plain TF SavedModel
            # (e.g. one that load_model cannot read) and wrap its signature.
            try:
                tf_layer = tf.keras.layers.TFSMLayer(str(model_path), call_endpoint='serving_default')
                model = tf.keras.Sequential([
                    tf.keras.Input(shape=(img_size, img_size, 3)),
                    tf_layer,
                ])
            except Exception as e:
                # Report both failures; previously the Keras error was discarded.
                raise RuntimeError(
                    f"Failed to load or wrap SavedModel directory '{model_path}': {e} "
                    f"(Keras load_model error: {keras_err})"
                ) from e
            print(f"Wrapped TensorFlow SavedModel at {model_path} with TFSMLayer (serving_default).")
            return model
        print(f"Loaded Keras-compatible model from directory: {model_path}")
        return model

    # Neither a recognised model file nor a directory: fail clearly if missing
    # instead of surfacing an opaque loader error.
    if not model_path.exists():
        raise FileNotFoundError(f"Model not found: {model_path}")
    model = tf.keras.models.load_model(str(model_path), compile=False)
    print(f"Loaded model from path: {model_path}")
    return model


def _extract_age(pred, output_key=None):
    """Reduce the raw output of ``model.predict`` to a single float.

    Args:
        pred: Value returned by ``model.predict``: a dict of named outputs,
            a list/tuple of outputs, or a single array-like.
        output_key: Key to read when ``pred`` is a dict. If falsy, the first
            dict output whose dtype is numeric is used.

    Returns:
        The first element of the selected output, as a Python float.

    Raises:
        KeyError: ``output_key`` is given but is not a key of ``pred``.
        ValueError: a dict ``pred`` has no numeric output, or the selected
            output is empty.
    """
    if isinstance(pred, dict):
        if output_key:
            if output_key not in pred:
                raise KeyError(f"Requested output key '{output_key}' not found. Available keys: {list(pred.keys())}")
            chosen = pred[output_key]
        else:
            # Match the --output_key help text ("first numeric output") and
            # avoid the bare StopIteration that next(iter({})) would raise.
            first_key = next(
                (k for k, v in pred.items()
                 if np.issubdtype(np.asarray(v).dtype, np.number)),
                None,
            )
            if first_key is None:
                raise ValueError(f"Model returned no numeric outputs. Available keys: {list(pred.keys())}")
            print(f"No --output_key provided; using first numeric output key: '{first_key}'")
            chosen = pred[first_key]
    elif isinstance(pred, (list, tuple)) and pred:
        # Multi-output models return a sequence of arrays; np.asarray fails on
        # ragged shapes, so select the first output explicitly.
        chosen = pred[0]
    else:
        chosen = pred

    arr = np.asarray(chosen)
    if arr.size == 0:
        raise ValueError("Model returned an empty prediction.")
    return float(arr.ravel()[0])


def main():
    """CLI entry point: load the model, predict on one image, and print the age.

    Raises:
        FileNotFoundError: the image (or a non-directory model path) is missing.
        RuntimeError: a model directory could not be loaded (see ``_load_model``).
        KeyError, ValueError: the prediction could not be reduced to a number
            (see ``_extract_age``).
    """
    args = parse_args()

    # Validate the cheap input first so a typo doesn't cost a full model load.
    image_path = Path(args.image_path)
    if not image_path.exists():
        raise FileNotFoundError(f"Image not found: {image_path}")

    model = _load_model(Path(args.model_path), args.img_size)

    # Add a leading batch dimension of 1 for predict().
    x = np.expand_dims(load_image(image_path, args.img_size), axis=0)
    pred = model.predict(x)

    age_pred = _extract_age(pred, args.output_key)
    print(f"Predicted age: {age_pred:.2f} years")
|
|
|
|
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
|
|