"""FastAPI service that runs a Keras super-resolution model on uploaded images."""

import base64
import io
import threading
from typing import Tuple

import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from tensorflow.keras.models import load_model

app = FastAPI()

# Input size passed to PIL's Image.resize, i.e. (width, height).
# Adjust to match the model's input layer.
MODEL_INPUT_SIZE = (64, 64)

# The model is only used for inference, so skip restoring the training
# configuration (loss / metrics / optimizer). This removes the need to
# resolve the "mse" loss, which the previous code worked around with a
# string-to-string custom_objects mapping.
model = load_model("super_resolution_model.h5", compile=False)

# Keras' predict() is not documented as thread-safe. Inference now runs in a
# worker thread, so this lock keeps it serialized, as it was when it ran on
# the event loop.
_inference_lock = threading.Lock()


def preprocess_image(
    image: Image.Image, size: Tuple[int, int] = MODEL_INPUT_SIZE
) -> np.ndarray:
    """Convert a PIL image into a batched float array for the model.

    Args:
        image: Source image (the /predict endpoint passes an RGB image).
        size: Target (width, height) for resizing; defaults to
            MODEL_INPUT_SIZE.

    Returns:
        Float32 array with a leading batch axis of length 1 and pixel
        values scaled from [0, 255] to [0, 1]. For an RGB image this is
        (1, height, width, 3).
    """
    resized = image.resize(size)
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    return np.expand_dims(pixels, axis=0)


def postprocess_image(output_array: np.ndarray) -> Image.Image:
    """Convert the model's batched output back into a PIL image.

    Takes the first item of the batch, rescales it from [0, 1] to [0, 255],
    clips, and converts it to uint8.

    Args:
        output_array: Model output with a leading batch axis. Assumed to be
            in [0, 1] (mirrors preprocess_image); values outside are clipped.

    Returns:
        A PIL image built from the first batch item.
    """
    first = np.asarray(output_array[0], dtype=np.float32)
    # Image.fromarray rejects (H, W, 1); drop a trailing single-channel axis
    # so grayscale outputs become a 2-D "L" image.
    if first.ndim == 3 and first.shape[-1] == 1:
        first = first[..., 0]
    # NaN survives np.clip and gives an undefined uint8 after the cast;
    # map it to 0 first.
    scaled = np.nan_to_num(first * 255.0, nan=0.0)
    return Image.fromarray(np.clip(scaled, 0, 255).astype(np.uint8))


def _run_model(batch: np.ndarray) -> np.ndarray:
    """Run blocking model inference under the inference lock (worker thread)."""
    with _inference_lock:
        # verbose=0 suppresses Keras' per-call progress bar in server logs.
        return model.predict(batch, verbose=0)


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Upscale an uploaded image with the super-resolution model.

    Returns:
        ``{"image": <str>}``, where the string is the base64-encoded PNG of
        the model output.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    data = await file.read()
    try:
        # convert() forces a full decode, so truncated or corrupt files fail
        # here. PIL.UnidentifiedImageError is a subclass of OSError.
        image = Image.open(io.BytesIO(data)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image"
        ) from exc

    input_array = preprocess_image(image)

    # Inference is blocking compute; run it off the event loop so other
    # requests are not stalled while the model runs.
    high_res = await run_in_threadpool(_run_model, input_array)

    high_res_image = postprocess_image(high_res)

    buffer = io.BytesIO()
    high_res_image.save(buffer, format="PNG")
    # Raw PNG bytes are not valid UTF-8 and cannot be serialized into a JSON
    # response; base64-encode them into an ASCII string instead.
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return {"image": encoded}


@app.get("/")
def root():
    """Return a short message identifying the service."""
    return {"message": "Super-resolution API"}