import json
import logging
import os
import re
from datetime import datetime
from typing import Any, Dict, List, Optional

import torch
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel

from chatbot import MentalHealthChatbot
|
|
| |
# Process-wide logging: INFO and above, each record prefixed with the time,
# logger name and level. basicConfig is a no-op if the root logger already
# has handlers (e.g. configured by an imported module).
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
|
|
| |
# Copy variables from a local .env file (if any) into os.environ so the
# configuration read below (e.g. PORT) can come from it.
# NOTE(review): this runs after `from chatbot import MentalHealthChatbot`;
# if that module reads environment variables at import time it will not see
# .env values — confirm.
load_dotenv()

# The FastAPI application; title/description/version are what FastAPI
# publishes in the generated OpenAPI docs (/docs, /redoc).
app = FastAPI(
    version="1.0.0",
    title="Mental Health Chatbot",
    description="mental health support chatbot",
)
|
|
| |
# Cross-origin (CORS) policy.
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable (e.g. "https://app.example.com,https://admin.example.com"). When it
# is unset the default is "*", i.e. any origin, which matches the previous
# hard-coded behaviour.
# NOTE: with allow_credentials=True, recent Starlette versions answer a "*"
# policy by echoing the caller's Origin header back. Any site can then make
# credentialed requests, so set CORS_ALLOW_ORIGINS explicitly in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        origin.strip()
        for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
        if origin.strip()
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
| |
# Single chatbot instance, created once at import time and shared by every
# endpoint in this module. It runs on CUDA when a GPU is visible, else CPU.
# NOTE(review): presumably the constructor loads the base Llama model and the
# PEFT adapter (4-bit when use_4bit=True), which makes import slow and
# memory-heavy — confirm in chatbot.py. "guidelines.txt" is a relative path;
# presumably it is resolved against the current working directory — confirm.
chatbot = MentalHealthChatbot(
    model_name="meta-llama/Llama-3.2-3B-Instruct",
    peft_model_path="nada013/mental-health-chatbot",
    therapy_guidelines_path="guidelines.txt",
    device=("cuda" if torch.cuda.is_available() else "cpu"),
    use_4bit=True,
)
|
|
| |
# When CUDA is available (the case in which the chatbot was created with
# device="cuda"), log which GPU device 0 is and how much memory it has.
# ``total_memory`` is the card's total capacity, not the amount currently
# free, so the message reports it as "Total" memory.
if torch.cuda.is_available():
    logger.info("GPU Device: %s", torch.cuda.get_device_name(0))
    logger.info(
        "Total GPU Memory: %.1fGB",
        torch.cuda.get_device_properties(0).total_memory / 1024**3,
    )
|
|
| |
class MessageRequest(BaseModel):
    """JSON request body accepted by ``POST /send_message``."""

    user_id: str  # key into chatbot.conversations; identifies whose conversation this is
    message: str  # user text passed to chatbot.process_message
|
|
class MessageResponse(BaseModel):
    """Response of ``POST /start_session`` and ``POST /send_message``."""

    response: str  # chatbot text: the session's opening message or the reply to the user
    session_id: str  # id of the session the text belongs to
|
|
class SessionSummary(BaseModel):
    """Session summary returned by ``POST /end_session`` and
    ``GET /session_summary/{session_id}``.

    Every field is required by this model. ``/session_summary`` substitutes
    neutral values (``0.0``, ``"unknown"``, empty string/list/dict) for
    optional fields that the chatbot omits or the caller filters out.
    """

    session_id: str
    user_id: str
    # NOTE(review): timestamp format is whatever the chatbot produces — confirm.
    start_time: str
    # NOTE(review): typed as a required str; if the chatbot reports None for a
    # still-active session, response validation will fail — confirm.
    end_time: str
    duration_minutes: float
    current_phase: str  # "unknown" when unavailable or filtered out
    primary_emotions: List[str]
    emotion_progression: List[str]
    summary: str  # free-text summary; "" when unavailable or filtered out
    recommendations: List[str]
    session_characteristics: Dict[str, Any]  # free-form; keys are defined by the chatbot
|
|
class UserReply(BaseModel):
    """A stored user reply: its ``text``, a ``timestamp`` and the owning ``session_id``.

    NOTE(review): not referenced by the endpoints in this file; presumably it
    mirrors the records returned by ``chatbot.get_user_replies`` — confirm.
    """

    text: str
    timestamp: str  # NOTE(review): format not established here — confirm
    session_id: str
|
|
class Message(BaseModel):
    """Generic chat message: ``text`` plus a ``role`` that defaults to ``"user"``.

    NOTE(review): not referenced by the endpoints in this file — confirm
    whether it is still needed.
    """

    text: str
    role: str = "user"  # speaker role; other accepted values are not constrained here
|
|
| |
@app.get("/")
async def root():
    """Root endpoint with API information.

    Returns the API name, version and description, plus a map from every
    route this module registers to a one-line description of it.
    """
    return {
        "name": "Mental Health Chatbot API",
        "version": "1.0.0",
        "description": "API for mental health support chatbot",
        "endpoints": {
            "POST /start_session": "Start a new chat session",
            "POST /send_message": "Send a message to the chatbot",
            "POST /end_session": "End the current session",
            "GET /session_summary/{session_id}": "Get a session summary (optionally filtered)",
            "GET /user_replies/{user_id}": "Download a user's replies as a JSON file",
            "GET /health": "Health check endpoint",
            "GET /docs": "API documentation (Swagger UI)",
            "GET /redoc": "API documentation (ReDoc)",
            # The WebSocket handshake is an HTTP GET upgrade request.
            "GET /ws": "WebSocket endpoint",
        },
    }
|
|
@app.post("/start_session", response_model=MessageResponse)
async def start_session(user_id: str):
    """Start a new chat session for ``user_id`` (a query-string parameter).

    Returns:
        MessageResponse with the chatbot's opening message and the new
        session id.

    Raises:
        HTTPException: 500 if the chatbot fails to start the session. The
            error is logged with its traceback server-side.
    """
    try:
        session_id, initial_message = chatbot.start_session(user_id)
        return MessageResponse(response=initial_message, session_id=session_id)
    except Exception as e:
        # Previously the failure reached the client only; keep a server-side
        # record with the traceback, consistent with /send_message.
        logger.exception("Error starting session for user %s: %s", user_id, e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
@app.post("/send_message", response_model=MessageResponse)
async def send_message(request: MessageRequest):
    """Send ``request.message`` to the chatbot on behalf of ``request.user_id``.

    If the user has no conversation, or their conversation is no longer
    active, a new session is started first, so the message always lands in
    an active session.

    NOTE(review): the chatbot calls below are synchronous and run directly
    inside this ``async def``. They therefore block the event loop (and every
    other request) while a reply is produced, presumably by LLM generation.
    Consider offloading them if the chatbot is safe to call from a worker
    thread.

    Returns:
        MessageResponse with the chatbot's reply and the id of the session
        it belongs to.

    Raises:
        HTTPException: 500 if starting the session or processing the message
            fails.
    """
    user_id = request.user_id
    try:
        if user_id not in chatbot.conversations or not chatbot.conversations[user_id].is_active:
            session_id, _ = chatbot.start_session(user_id)
            logger.info("Started new session %s for user %s during message send", session_id, user_id)

        response = chatbot.process_message(user_id, request.message)
        session = chatbot.conversations[user_id]
        return MessageResponse(response=response, session_id=session.session_id)
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception("Error processing message for user %s: %s", user_id, e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
@app.post("/end_session", response_model=SessionSummary)
async def end_session(user_id: str):
    """End ``user_id``'s current session (query-string parameter) and return its summary.

    Raises:
        HTTPException: 404 if the chatbot returns no summary (no active
            session); 500 if ending the session fails.
    """
    try:
        summary = chatbot.end_session(user_id)
    except Exception as e:
        logger.exception("Error ending session for user %s: %s", user_id, e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Checked outside the try block so the 404 reaches the client instead of
    # being caught by the generic handler and re-reported as a 500.
    if not summary:
        raise HTTPException(status_code=404, detail="No active session found")
    return summary
|
|
@app.get("/health")
async def health_check():
    """Liveness probe.

    Always reports healthy while the process is serving requests; it does not
    probe the chatbot or the model.
    """
    return dict(status="healthy")
|
|
@app.get("/session_summary/{session_id}", response_model=SessionSummary)
async def get_session_summary(
    session_id: str,
    include_summary: bool = True,
    include_recommendations: bool = True,
    include_emotions: bool = True,
    include_characteristics: bool = True,
    include_duration: bool = True,
    include_phase: bool = True
):
    """Return the stored summary for ``session_id``, optionally trimmed.

    Each ``include_*`` query flag, when false, blanks the matching field(s)
    instead of removing them, so the response always fits ``SessionSummary``:
    ``summary`` -> "", ``recommendations`` / emotions -> [],
    ``session_characteristics`` -> {}, ``duration_minutes`` -> 0.0 and
    ``current_phase`` -> "unknown". The same neutral values are used when the
    chatbot's summary lacks those fields.

    Raises:
        HTTPException: 404 if no summary exists for ``session_id``; 500 for
            any other failure (including a summary missing a required key).
    """
    try:
        summary = chatbot.get_session_summary(session_id)
        if not summary:
            raise HTTPException(status_code=404, detail="Session summary not found")

        filtered_summary = {
            "session_id": summary["session_id"],
            "user_id": summary["user_id"],
            "start_time": summary["start_time"],
            "end_time": summary["end_time"],
            "duration_minutes": summary.get("duration_minutes", 0.0),
            "current_phase": summary.get("current_phase", "unknown"),
            "primary_emotions": summary.get("primary_emotions", []),
            "emotion_progression": summary.get("emotion_progression", []),
            "summary": summary.get("summary", ""),
            "recommendations": summary.get("recommendations", []),
            "session_characteristics": summary.get("session_characteristics", {})
        }

        if not include_summary:
            filtered_summary["summary"] = ""
        if not include_recommendations:
            filtered_summary["recommendations"] = []
        if not include_emotions:
            filtered_summary["primary_emotions"] = []
            filtered_summary["emotion_progression"] = []
        if not include_characteristics:
            filtered_summary["session_characteristics"] = {}
        if not include_duration:
            filtered_summary["duration_minutes"] = 0.0
        if not include_phase:
            filtered_summary["current_phase"] = "unknown"

        return filtered_summary
    except HTTPException:
        # Re-raise deliberate HTTP errors (the 404 above) untouched; the
        # generic handler below would otherwise turn them into a 500.
        raise
    except Exception as e:
        logger.exception("Error building summary for session %s: %s", session_id, e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
@app.get("/user_replies/{user_id}")
async def get_user_replies(user_id: str):
    """Export ``user_id``'s replies as a downloadable JSON file.

    The export ``{"user_id", "timestamp", "replies"}`` is written to
    ``user_replies/user_replies_<user_id>_<YYYYmmdd_HHMMSS>.json`` and that
    file is returned as an attachment. Characters in ``user_id`` outside
    ``[A-Za-z0-9_.-]`` are replaced by ``_`` in the filename only; the JSON
    keeps the original id.

    NOTE(review): every call leaves a new file containing the user's replies
    on the server's disk and nothing here removes it. Consider a retention
    policy.

    Raises:
        HTTPException: 500 on any failure (fetching, serializing or writing).
    """
    try:
        replies = chatbot.get_user_replies(user_id)

        # One clock read so the filename and the payload timestamp agree.
        now = datetime.now()
        # user_id comes straight from the URL; keep it from injecting path
        # separators or other unsafe characters into the filename.
        safe_user_id = re.sub(r"[^A-Za-z0-9_.-]", "_", user_id)
        filename = f"user_replies_{safe_user_id}_{now.strftime('%Y%m%d_%H%M%S')}.json"
        filepath = os.path.join("user_replies", filename)

        os.makedirs("user_replies", exist_ok=True)

        # Serialize before opening the file so a non-serializable reply raises
        # without leaving a truncated file behind.
        payload = json.dumps(
            {
                "user_id": user_id,
                "timestamp": now.isoformat(),
                "replies": replies,
            },
            indent=2,
        )
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(payload)

        return FileResponse(
            path=filepath,
            filename=filename,
            media_type="application/json"
        )
    except Exception as e:
        logger.exception("Error exporting replies for user %s: %s", user_id, e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Chat with the bot over a WebSocket.

    Each incoming frame must be a JSON object with ``user_id`` and
    ``message``. The reply is sent back as ``{"response", "session_id"}``.
    Frames that are not objects, or that lack either field, get an
    ``{"error": ...}`` reply and the connection stays open. As in
    ``/send_message``, a new session is started when the user has no active
    one. Any other failure is reported as ``{"error": ...}`` and the
    connection is closed. A client disconnect simply ends the handler.

    NOTE(review): chatbot calls are synchronous and block the event loop
    while a reply is produced — see the same note on ``/send_message``.
    """
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_json()
            if not isinstance(data, dict):
                await websocket.send_json({"error": "Expected a JSON object with user_id and message"})
                continue

            user_id = data.get("user_id")
            message = data.get("message")

            if not user_id or not message:
                await websocket.send_json({"error": "Missing user_id or message"})
                continue

            # Mirror /send_message: make sure the user has an active session
            # before processing, rather than failing on a missing conversation.
            if user_id not in chatbot.conversations or not chatbot.conversations[user_id].is_active:
                session_id, _ = chatbot.start_session(user_id)
                logger.info("Started new session %s for user %s over WebSocket", session_id, user_id)

            response = chatbot.process_message(user_id, message)
            session_id = chatbot.conversations[user_id].session_id

            await websocket.send_json({
                "response": response,
                "session_id": session_id
            })
    except WebSocketDisconnect:
        # The client closed the connection; there is nothing left to send or close.
        logger.info("WebSocket client disconnected")
    except Exception as e:
        logger.exception("WebSocket error: %s", e)
        # Best effort: the connection may already be unusable.
        try:
            await websocket.send_json({"error": str(e)})
            await websocket.close()
        except Exception:
            logger.debug("Could not report error to WebSocket client", exc_info=True)
|
|
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on every interface.
    # The PORT environment variable overrides the default of 7860.
    listen_port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)