| from tensorflow import keras |
| import numpy as np |
|
|
|
|
| # import tensorflow as tf |
|
|
| # loaded = tf.saved_model.load('rwthmaterials_dp800_network1_inclusion') |
| # print("Available endpoints:", list(loaded.signatures.keys())) |
|
|
# Load the trained network from its Keras archive. The path is relative,
# so it is resolved against the current working directory.
model = keras.models.load_model('rwthmaterials_dp800_network1_inclusion.keras')

# Report the architecture, then the name/shape of every input and output
# tensor, so callers know what the prediction wrapper expects.
print("Model Summary:")
model.summary()

print("Inputs:")
for position, tensor in enumerate(model.inputs, start=1):
    print(f"Input {position}: name={tensor.name}, shape={tensor.shape}")

print("Outputs:")
for position, tensor in enumerate(model.outputs, start=1):
    print(f"Output {position}: name={tensor.name}, shape={tensor.shape}")
|
|
| # Generate a wrapper function based on input count |
| def generate_wrapper(model): |
| def wrapper(*args): |
| # Convert inputs to numpy arrays and reshape if needed |
| processed_inputs = [] |
| for i, input_tensor in enumerate(model.inputs): |
| shape = input_tensor.shape |
| # Replace None with 1 for batch dimension |
| input_shape = [dim if dim is not None else 1 for dim in shape] |
| arr = np.array(args[i]).reshape(input_shape) |
| processed_inputs.append(arr) |
| # Predict |
| prediction = model.predict(processed_inputs) |
| return prediction.tolist() |
| return wrapper |
|
|
# Ready-to-use prediction callable built from the loaded model.
predict_fn = generate_wrapper(model=model)


# Example usage with dummy data: pass one array per model input, in the order
# of model.inputs. Swap in real data when wiring this into a Gradio interface.
# dummy_input1 = np.random.rand(1, 6, 6, 2048)
# dummy_input2 = np.random.rand(1, 6, 6, 2048)
# print(predict_fn(dummy_input1, dummy_input2))
|
|