Spaces:
Sleeping
Sleeping
Update fastapi_app/main.py
Browse files- fastapi_app/main.py +91 -91
fastapi_app/main.py
CHANGED
|
@@ -1,91 +1,91 @@
|
|
| 1 |
-
from fastapi import FastAPI, Request
|
| 2 |
-
from fastapi.responses import HTMLResponse
|
| 3 |
-
from fastapi.staticfiles import StaticFiles
|
| 4 |
-
from pydantic import BaseModel
|
| 5 |
-
import uvicorn
|
| 6 |
-
import os, sys
|
| 7 |
-
|
| 8 |
-
# Add the root directory to sys.path
|
| 9 |
-
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 10 |
-
from model_pipeline.model_predict import load_model, predict as initial_predict
|
| 11 |
-
from llama_pipeline.llama_predict import predict as llama_predict
|
| 12 |
-
from db_connection import insert_db
|
| 13 |
-
from logging_config.logger_config import get_logger
|
| 14 |
-
|
| 15 |
-
# Initialize the FastAPI app
|
| 16 |
-
app = FastAPI()
|
| 17 |
-
|
| 18 |
-
# Initialize the logger
|
| 19 |
-
logger = get_logger(__name__)
|
| 20 |
-
|
| 21 |
-
# Load the latest model at startup
|
| 22 |
-
model = load_model()
|
| 23 |
-
|
| 24 |
-
# Mount the static files directory
|
| 25 |
-
app.mount("/static", StaticFiles(directory="fastapi_app/static"), name="static")
|
| 26 |
-
|
| 27 |
-
@app.get("/", response_class=HTMLResponse)
|
| 28 |
-
def read_root():
|
| 29 |
-
with open("fastapi_app/static/index.html") as f:
|
| 30 |
-
html_content = f.read()
|
| 31 |
-
return HTMLResponse(content=html_content, status_code=200)
|
| 32 |
-
|
| 33 |
-
@app.get("/health")
|
| 34 |
-
def health_check():
|
| 35 |
-
logger.info("Health check endpoint accessed.")
|
| 36 |
-
return {"status": "ok"}
|
| 37 |
-
|
| 38 |
-
class TextInput(BaseModel):
|
| 39 |
-
text: str
|
| 40 |
-
|
| 41 |
-
class PredictionInput(BaseModel):
|
| 42 |
-
text: str
|
| 43 |
-
initial_prediction: str
|
| 44 |
-
llama_category: str
|
| 45 |
-
llama_explanation: str
|
| 46 |
-
user_rating: int
|
| 47 |
-
|
| 48 |
-
@app.post("/predict_sentiment")
|
| 49 |
-
def predict_sentiment(input_data: TextInput):
|
| 50 |
-
logger.info(f"Prediction request received with text: {input_data.text}")
|
| 51 |
-
|
| 52 |
-
# Initial model prediction
|
| 53 |
-
initial_prediction = initial_predict(input_data.text, model = model)
|
| 54 |
-
|
| 55 |
-
# LLaMA 3 prediction
|
| 56 |
-
llama_prediction = llama_predict(input_data.text)
|
| 57 |
-
|
| 58 |
-
# Prepare response
|
| 59 |
-
response = {
|
| 60 |
-
"text": input_data.text,
|
| 61 |
-
"initial_prediction": initial_prediction,
|
| 62 |
-
"llama_category": llama_prediction['Category'],
|
| 63 |
-
"llama_explanation": llama_prediction['Explanation']
|
| 64 |
-
}
|
| 65 |
-
|
| 66 |
-
logger.info(f"Prediction response: {response}")
|
| 67 |
-
return response
|
| 68 |
-
|
| 69 |
-
@app.post("/submit_interaction")
|
| 70 |
-
def submit_interaction(data: PredictionInput):
|
| 71 |
-
logger.info(f"Received interaction data: {data}")
|
| 72 |
-
logger.info(f"Received text: {data.text}")
|
| 73 |
-
logger.info(f"Received initial_prediction: {data.initial_prediction}")
|
| 74 |
-
logger.info(f"Received llama_category: {data.llama_category}")
|
| 75 |
-
logger.info(f"Received llama_explanation: {data.llama_explanation}")
|
| 76 |
-
logger.info(f"Received user_rating: {data.user_rating}")
|
| 77 |
-
|
| 78 |
-
interaction_data = {
|
| 79 |
-
"Input_text": data.text,
|
| 80 |
-
"Model_prediction": data.initial_prediction,
|
| 81 |
-
"Llama_3_Prediction": data.llama_category,
|
| 82 |
-
"Llama_3_Explanation": data.llama_explanation,
|
| 83 |
-
"User Rating": data.user_rating,
|
| 84 |
-
}
|
| 85 |
-
|
| 86 |
-
response = insert_db(interaction_data)
|
| 87 |
-
logger.info(f"Database response: {response}")
|
| 88 |
-
return {"status": "success", "response": response}
|
| 89 |
-
|
| 90 |
-
if __name__ == "__main__":
|
| 91 |
-
uvicorn.run(app, host="0.0.0.0", port=
|
|
|
|
| 1 |
+
from fastapi import FastAPI, Request
|
| 2 |
+
from fastapi.responses import HTMLResponse
|
| 3 |
+
from fastapi.staticfiles import StaticFiles
|
| 4 |
+
from pydantic import BaseModel
|
| 5 |
+
import uvicorn
|
| 6 |
+
import os, sys
|
| 7 |
+
|
| 8 |
+
# Add the root directory to sys.path
|
| 9 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 10 |
+
from model_pipeline.model_predict import load_model, predict as initial_predict
|
| 11 |
+
from llama_pipeline.llama_predict import predict as llama_predict
|
| 12 |
+
from db_connection import insert_db
|
| 13 |
+
from logging_config.logger_config import get_logger
|
| 14 |
+
|
| 15 |
+
# Initialize the FastAPI app
|
| 16 |
+
app = FastAPI()
|
| 17 |
+
|
| 18 |
+
# Initialize the logger
|
| 19 |
+
logger = get_logger(__name__)
|
| 20 |
+
|
| 21 |
+
# Load the latest model at startup
|
| 22 |
+
model = load_model()
|
| 23 |
+
|
| 24 |
+
# Mount the static files directory
|
| 25 |
+
app.mount("/static", StaticFiles(directory="fastapi_app/static"), name="static")
|
| 26 |
+
|
| 27 |
+
@app.get("/", response_class=HTMLResponse)
|
| 28 |
+
def read_root():
|
| 29 |
+
with open("fastapi_app/static/index.html") as f:
|
| 30 |
+
html_content = f.read()
|
| 31 |
+
return HTMLResponse(content=html_content, status_code=200)
|
| 32 |
+
|
| 33 |
+
@app.get("/health")
|
| 34 |
+
def health_check():
|
| 35 |
+
logger.info("Health check endpoint accessed.")
|
| 36 |
+
return {"status": "ok"}
|
| 37 |
+
|
| 38 |
+
class TextInput(BaseModel):
|
| 39 |
+
text: str
|
| 40 |
+
|
| 41 |
+
class PredictionInput(BaseModel):
|
| 42 |
+
text: str
|
| 43 |
+
initial_prediction: str
|
| 44 |
+
llama_category: str
|
| 45 |
+
llama_explanation: str
|
| 46 |
+
user_rating: int
|
| 47 |
+
|
| 48 |
+
@app.post("/predict_sentiment")
|
| 49 |
+
def predict_sentiment(input_data: TextInput):
|
| 50 |
+
logger.info(f"Prediction request received with text: {input_data.text}")
|
| 51 |
+
|
| 52 |
+
# Initial model prediction
|
| 53 |
+
initial_prediction = initial_predict(input_data.text, model = model)
|
| 54 |
+
|
| 55 |
+
# LLaMA 3 prediction
|
| 56 |
+
llama_prediction = llama_predict(input_data.text)
|
| 57 |
+
|
| 58 |
+
# Prepare response
|
| 59 |
+
response = {
|
| 60 |
+
"text": input_data.text,
|
| 61 |
+
"initial_prediction": initial_prediction,
|
| 62 |
+
"llama_category": llama_prediction['Category'],
|
| 63 |
+
"llama_explanation": llama_prediction['Explanation']
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
logger.info(f"Prediction response: {response}")
|
| 67 |
+
return response
|
| 68 |
+
|
| 69 |
+
@app.post("/submit_interaction")
|
| 70 |
+
def submit_interaction(data: PredictionInput):
|
| 71 |
+
logger.info(f"Received interaction data: {data}")
|
| 72 |
+
logger.info(f"Received text: {data.text}")
|
| 73 |
+
logger.info(f"Received initial_prediction: {data.initial_prediction}")
|
| 74 |
+
logger.info(f"Received llama_category: {data.llama_category}")
|
| 75 |
+
logger.info(f"Received llama_explanation: {data.llama_explanation}")
|
| 76 |
+
logger.info(f"Received user_rating: {data.user_rating}")
|
| 77 |
+
|
| 78 |
+
interaction_data = {
|
| 79 |
+
"Input_text": data.text,
|
| 80 |
+
"Model_prediction": data.initial_prediction,
|
| 81 |
+
"Llama_3_Prediction": data.llama_category,
|
| 82 |
+
"Llama_3_Explanation": data.llama_explanation,
|
| 83 |
+
"User Rating": data.user_rating,
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
response = insert_db(interaction_data)
|
| 87 |
+
logger.info(f"Database response: {response}")
|
| 88 |
+
return {"status": "success", "response": response}
|
| 89 |
+
|
| 90 |
+
if __name__ == "__main__":
|
| 91 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|