import base64
import io
import os

import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from tensorflow.keras.models import load_model


app = FastAPI()

# Location of the saved Keras model. It can be overridden with the MODEL_PATH
# environment variable; the default is the file name this service always used.
MODEL_PATH = os.environ.get("MODEL_PATH", "super_resolution_model.h5")

# compile=False: this service only calls model.predict, so the saved loss and
# optimizer state are not needed. Skipping compilation also avoids
# deserialising the training loss, which the {"mse": "mse"} mapping was
# presumably a workaround for. The mapping is kept so that any other
# reference to "mse" in the saved model still resolves as before.
model = load_model(
    MODEL_PATH,
    custom_objects={"mse": "mse"},
    compile=False,
)


def preprocess_image(
    image: "Image.Image", *, size: "tuple[int, int]" = (64, 64)
) -> np.ndarray:
    """Resize an image and turn it into a one-element batch for the model.

    Args:
        image: Image to feed to the model. ``predict`` passes an RGB image.
        size: ``(width, height)`` passed to ``image.resize``. The default is
            ``(64, 64)``, the size this service has always used. It is
            presumably the model's expected input size (TODO confirm).

    Returns:
        A float32 array of shape ``(1, height, width[, channels])`` with
        pixel values scaled from 0..255 down to 0..1.
    """
    resized = image.resize(size)
    # Build float32 directly. Dividing a uint8 array by a Python float would
    # produce float64, which doubles the memory and is cast down again when
    # Keras runs in its default float32.
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    # Prepend the batch axis the model expects.
    return pixels[np.newaxis, ...]


def postprocess_image(output_array: np.ndarray) -> "Image.Image":
    """Convert the model's batched float prediction into a PIL image.

    Args:
        output_array: Model prediction whose first axis is the batch axis.
            Only the first sample is used. Values are presumably in 0..1
            (they are scaled by 255); anything outside 0..255 after scaling
            is clipped.

    Returns:
        An 8-bit PIL image built from the first sample.
    """
    pixels = output_array[0] * 255.0
    # Round to the nearest level rather than truncating. astype() alone
    # floors positive values, which biases every pixel down by up to one step.
    pixels = np.clip(np.rint(pixels), 0, 255).astype(np.uint8)
    # Pillow's fromarray rejects (H, W, 1) arrays. Drop a trailing singleton
    # channel so single-channel predictions become a greyscale image.
    if pixels.ndim == 3 and pixels.shape[-1] == 1:
        pixels = pixels[..., 0]
    return Image.fromarray(pixels)


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Upscale an uploaded image with the super-resolution model.

    The upload goes through these steps:
      1. Decode it with Pillow and force it to RGB.
      2. Preprocess it with ``preprocess_image``.
      3. Run it through ``model``.
      4. Convert the result back with ``postprocess_image``.
      5. Re-encode it as PNG.

    Returns:
        ``{"image": <str>}``, where the value is the PNG file encoded as
        base64 ASCII text so that the response body is valid JSON.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    payload = await file.read()
    try:
        # convert() makes Pillow actually decode the pixel data, so corrupt
        # or truncated uploads fail here rather than deeper in the pipeline.
        image = Image.open(io.BytesIO(payload)).convert("RGB")
    except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    input_array = preprocess_image(image)

    # model.predict is blocking compute. Running it in the threadpool keeps
    # the event loop free for other requests. verbose=0 suppresses the
    # per-call progress output.
    high_res = await run_in_threadpool(model.predict, input_array, verbose=0)

    high_res_image = postprocess_image(high_res)

    buffer = io.BytesIO()
    high_res_image.save(buffer, format="PNG")

    # Raw PNG bytes are not valid UTF-8, so FastAPI's JSON encoder cannot
    # serialise them (bytes are decoded as text). Base64 makes them plain ASCII.
    return {"image": base64.b64encode(buffer.getvalue()).decode("ascii")}


@app.get("/")
def root():
    """Answer GET / with a fixed JSON message naming the service."""
    service_name = "Super-resolution API"
    return dict(message=service_name)