# CSE445 / app.py
# MisbahKhan's picture
# Create app.py
# 18c0aec verified
import base64
import io

import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from tensorflow.keras.models import load_model
app = FastAPI()

# Load the network once at import time so every request reuses the same model.
# The service only runs inference, so compile=False skips restoring the
# training-time loss/optimizer. That removes the need to resolve the saved
# "mse" loss through custom_objects (which expects classes or functions,
# not a plain string).
model = load_model("super_resolution_model.h5", compile=False)
def preprocess_image(image: "Image.Image", target_size: tuple = (64, 64)) -> np.ndarray:
    """Turn a PIL image into a single-item batch scaled to [0, 1].

    Args:
        image: Source picture. Anything that is not already in RGB mode is
            converted so the result always has three channels.
        target_size: ``(width, height)`` passed to ``Image.resize``. The
            default of 64x64 is a placeholder -- adjust it to the input size
            the loaded model was trained with.

    Returns:
        Floating-point array of shape ``(1, height, width, 3)`` with values
        in ``[0, 1]``.
    """
    # Guard against grayscale / RGBA / palette inputs when this helper is
    # called without the caller having normalised the mode first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    resized = image.resize(target_size)
    scaled = np.asarray(resized) / 255.0
    # Add the leading batch axis expected by ``model.predict``.
    return np.expand_dims(scaled, axis=0)
def postprocess_image(output_array: np.ndarray):
    """Convert a batched model output back into a PIL image.

    Args:
        output_array: Model prediction whose first axis is the batch; only
            the first item is used. Values are assumed to be in ``[0, 1]``
            (adjust the scaling if the model emits a different range).

    Returns:
        ``PIL.Image.Image`` built from the 8-bit pixel data.
    """
    pixels = np.clip(output_array[0] * 255.0, 0, 255)
    # Round to nearest before the integer cast; a bare astype() truncates and
    # would bias every pixel downward (e.g. 254.7 -> 254).
    pixels = np.rint(pixels).astype("uint8")
    # Image.fromarray cannot handle an (H, W, 1) array, so drop a trailing
    # single-channel axis and let it build a grayscale image instead.
    if pixels.ndim == 3 and pixels.shape[-1] == 1:
        pixels = pixels[..., 0]
    return Image.fromarray(pixels)
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Upscale an uploaded image and return it as a base64-encoded PNG.

    Args:
        file: Multipart upload containing the low-resolution image.

    Returns:
        ``{"image": <base64 string>}`` where the string is the PNG-encoded
        model output.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    payload = await file.read()
    try:
        image = Image.open(io.BytesIO(payload)).convert("RGB")
    except OSError as exc:
        # Pillow's UnidentifiedImageError is an OSError subclass; report bad
        # uploads as a client error instead of an unhandled 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    input_array = preprocess_image(image)

    # NOTE(review): model.predict is a blocking call inside an async handler,
    # so it stalls the event loop while it runs -- consider a threadpool if
    # concurrent requests matter.
    high_res = model.predict(input_array)

    high_res_image = postprocess_image(high_res)

    output = io.BytesIO()
    high_res_image.save(output, format="PNG")

    # Raw PNG bytes are not valid UTF-8, so they cannot be placed in a JSON
    # response directly; base64 keeps the {"image": ...} response shape while
    # making the payload serialisable.
    encoded = base64.b64encode(output.getvalue()).decode("ascii")
    return {"image": encoded}
@app.get("/")
def root():
    """Return a short JSON banner identifying this service."""
    banner = {"message": "Super-resolution API"}
    return banner