| import json |
| import logging |
|
|
| from magic.conversational import question_handler |
| from schema import Answer |
|
|
# Module-scoped logger; its records propagate to the root handler that
# basicConfig installs below, so output still reaches the console.
logger = logging.getLogger(__name__)
# NOTE: basicConfig only uses ``encoding`` when ``filename`` is also given, so
# it has no effect on the default stream handler here. basicConfig is also a
# no-op if the root logger already has handlers.
logging.basicConfig(encoding='utf-8', level=logging.INFO)
|
|
| from fastapi import FastAPI, Request |
| from fastapi.responses import HTMLResponse, JSONResponse |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.templating import Jinja2Templates |
| from fastapi.staticfiles import StaticFiles |
| from starlette.routing import WebSocketRoute |
| import uvicorn |
|
|
| from schema import Question |
| from config import State |
| import arel |
|
|
| |
|
|
| |
async def reload_data() -> None:
    """Print a notice that server-side data is being reloaded.

    Not referenced elsewhere in this file; presumably intended as a reload
    callback for the arel hot-reloader. Verify before relying on it.
    """
    notice = "Reloading server data..."
    print(notice)
|
|
# Directories for static assets and Jinja templates; both are also handed to
# arel so it watches them for changes.
static_path = "static"
template_path = "templates"

hotreload = arel.HotReload(
    paths=[arel.Path(watched_dir) for watched_dir in (static_path, template_path)],
)

# Shared application state, passed to the question handler by /api/ask.
state = State()
|
|
# FastAPI application. The arel hot-reload object is mounted as a raw
# Starlette websocket route at /hot-reload, and its startup/shutdown hooks
# run with the application's lifecycle events.
app = FastAPI(
    on_startup=[hotreload.startup],
    on_shutdown=[hotreload.shutdown],
    routes=[
        WebSocketRoute("/hot-reload", hotreload, name="hot-reload"),
    ],
)
|
|
| |
# Serve static assets from the same directory arel watches (static_path),
# rather than a second hard-coded copy of the name that could drift.
app.mount("/static", StaticFiles(directory=static_path), name="static")
templates = Jinja2Templates(directory=template_path)
# Globals available in every template. DEBUG is hard-coded on; hotreload is
# exposed so templates can reference the arel object (presumably to inject
# its client script; confirm in the templates).
templates.env.globals["DEBUG"] = True
templates.env.globals["hotreload"] = hotreload
|
|
| |
# Browser origins allowed to call the API (local development only), with
# credentials and any method/header permitted.
origins = ["http://localhost", "http://localhost:5001"]

app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=origins,
)
|
|
| |
# Canned responses that /api/ask returns without calling the question handler.
# Presumably a mapping from question text to response payload (it is indexed
# by the question string); verify against the fixture.
# The path is relative to the current working directory. Decode explicitly as
# UTF-8 (the JSON standard encoding) so non-ASCII text does not depend on the
# platform's default locale encoding.
with open("test_response.json", encoding="utf-8") as f:
    test_questions = json.load(f)
|
|
| |
|
|
@app.get("/", response_class=HTMLResponse)
def index(request: Request):
    """Render ``index.html`` as the landing page.

    The request object is placed in the template context because Jinja2Templates
    requires it there.
    """
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
|
|
@app.post("/api/ask", response_class=JSONResponse)
async def ask(question: Question):
    """Answer a question posted to ``/api/ask``.

    If the question text is a key in ``test_questions``, the canned entry is
    returned directly. Otherwise the question is passed to
    ``question_handler`` together with the shared ``state``.

    Returns:
        The canned entry for known test questions; otherwise a dict with the
        answer's ``answer`` and ``urls`` fields. If the handler raises or
        returns a falsy value, a 404 JSONResponse with a generic message.
    """
    question_str = question.question
    if question_str in test_questions:
        return test_questions[question_str]

    answer: "Answer | None" = None
    try:
        answer = await question_handler(question, state)
    except Exception:
        # Request boundary: record the traceback with context instead of
        # logging the bare exception object (which may have an empty message),
        # then fall through to the generic error response below.
        logger.exception("question_handler failed for question %r", question_str)

    if not answer:
        # NOTE(review): 404 is used both for handler failures and for empty
        # answers; a 5xx may fit the failure case better. Kept as-is because
        # clients may depend on the status code.
        return JSONResponse(status_code=404, content={"answer": "Something went wrong."})
    return answer.dict(include={"answer", "urls"})
|
|
|
|
if __name__ == "__main__":
    # Development server with auto-reload on localhost:5001.
    # Watch the project directory (cwd) for code changes. Absolute paths such
    # as "/templates" point at the filesystem root. uvicorn silently discards
    # those as non-existent and falls back to the cwd with a warning.
    # Template/static edits are presumably surfaced to the browser through
    # the arel /hot-reload websocket instead.
    uvicorn.run(
        "kth_qa:app",
        host="localhost",
        port=5001,
        reload=True,
        reload_dirs=["."],
        reload_excludes=["files/", "logs/"],
    )