| import joblib
|
| from flask import Flask, request, jsonify
|
| import pandas as pd
|
| import numpy as np
|
|
|
|
|
app = Flask(__name__)


def _load_artifacts():
    """Load the trained model and its one-hot encoder.

    Both paths are relative, so they resolve against the current working
    directory of the process.

    NOTE: ``joblib.load`` unpickles its input; only point it at trusted
    artifact files.

    Returns:
        ``(model, encoder)`` when both files load, or ``(None, None)`` if
        either file is missing. In the latter case ``predict`` answers every
        request with HTTP 500.
    """
    try:
        loaded_model = joblib.load("xgboost_model.joblib")
        loaded_encoder = joblib.load("onehot_encoder.joblib")
    except FileNotFoundError:
        print("Error: Model or encoder file not found.")
        return None, None
    print("Model and encoder loaded successfully.")
    return loaded_model, loaded_encoder


# Loaded once at import time and shared by every request.
model, onehot_encoder = _load_artifacts()
|
|
|
|
|
# String-valued request features, fed to ``onehot_encoder`` in this order.
# NOTE(review): order presumably must match the order the encoder was fitted
# on — confirm against the training code.
categorical_cols = [
    "Product_Sugar_Content",
    "Product_Type",
    "Store_Size",
    "Store_Location_City_Type",
    "Store_Type",
]

# Numeric request features; predict() places these columns ahead of the
# one-hot encoded columns when building the model's input frame.
# NOTE(review): that column order presumably has to match training — confirm.
numerical_cols = [
    "Product_Weight",
    "Product_Allocated_Area",
    "Product_MRP",
    "Store_Establishment_Year",
]
|
|
|
@app.route('/predict', methods=['POST'])
def predict():
    """Score one record with the loaded model.

    The request body must be a JSON object containing every feature below:

    - Product_Weight (float)
    - Product_Sugar_Content (string)
    - Product_Allocated_Area (float)
    - Product_Type (string)
    - Product_MRP (float)
    - Store_Establishment_Year (int)
    - Store_Size (string)
    - Store_Location_City_Type (string)
    - Store_Type (string)

    Numeric features are coerced with ``pd.to_numeric``, so numeric strings
    such as ``"12.5"`` are accepted and ``null`` becomes NaN. Categorical
    features are one-hot encoded with ``onehot_encoder`` and appended after
    the numeric columns.

    Returns:
        ``({"prediction": float}, 200)`` on success.
        ``({"error": str}, 400)`` for an empty or non-object body, missing
        features, non-numeric values in numeric fields, or categorical values
        the encoder rejects.
        ``({"error": str}, 500)`` if the artifacts are not loaded or scoring
        fails. Failure details are logged server-side, not sent to the client.
    """
    if model is None or onehot_encoder is None:
        return jsonify({"error": "Model not loaded. Check server logs."}), 500

    data = request.get_json(silent=True)
    if not data:
        return jsonify({"error": "No data provided or invalid JSON format."}), 400
    # A JSON array or scalar would otherwise become a frame with the wrong
    # columns and surface as a confusing "missing feature" error.
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    # Validate up front, so a KeyError raised deep inside the encoder or model
    # is never misreported to the client as a missing request field.
    missing = [col for col in numerical_cols + categorical_cols if col not in data]
    if missing:
        return jsonify({"error": f"Missing feature in request: {missing}"}), 400

    input_df = pd.DataFrame([data])

    # JSON strings or nulls would otherwise leave object-dtype columns.
    # The model is presumably XGBoost (per its file name), which rejects those.
    # Coerce explicitly and report bad values as a client error.
    try:
        numeric_df = input_df[numerical_cols].apply(pd.to_numeric)
    except (TypeError, ValueError) as e:
        return jsonify({"error": f"Invalid numeric feature value: {e}"}), 400

    try:
        encoded = onehot_encoder.transform(input_df[categorical_cols])
    except ValueError as e:
        # e.g. an unseen category when the encoder was fitted with
        # handle_unknown="error": a problem with the request, not the server.
        return jsonify({"error": f"Invalid categorical feature value: {e}"}), 400
    except Exception:
        app.logger.exception("One-hot encoding failed")
        return jsonify({"error": "Prediction failed. Check server logs."}), 500

    try:
        # A sparse result (sklearn's default) must be densified. An encoder
        # built with sparse_output=False already returns an ndarray, which has
        # no .toarray().
        if hasattr(encoded, "toarray"):
            encoded = encoded.toarray()
        encoded_df = pd.DataFrame(
            encoded,
            columns=onehot_encoder.get_feature_names_out(categorical_cols),
            index=numeric_df.index,  # keep rows aligned for the concat below
        )
        final_df = pd.concat([numeric_df, encoded_df], axis=1)
        prediction = float(model.predict(final_df)[0])
    except Exception:
        # Keep the full traceback in the server log; don't leak internals to clients.
        app.logger.exception("Prediction failed")
        return jsonify({"error": "Prediction failed. Check server logs."}), 500

    return jsonify({"prediction": prediction}), 200
|
|
|
if __name__ == '__main__':
    # Runs Flask's built-in server when executed as a script: all interfaces,
    # port 5000, debug mode off.
    # NOTE(review): for production traffic, serve ``app`` through a WSGI server.
    app.run(debug=False, port=5000, host='0.0.0.0')