File size: 2,957 Bytes
ee05825
 
 
 
 
 
 
 
0b96936
ee05825
 
 
0b96936
 
 
 
 
ee05825
 
 
 
 
 
 
 
 
 
 
 
 
0b96936
 
ee05825
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b41664
 
 
 
ee05825
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102


from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import requests as req
import pandas as pd
import chromadb
from sentence_transformers import SentenceTransformer
from huggingface_hub import snapshot_download, login
import os
from dotenv import load_dotenv

# Load variables from a local .env file (if present) BEFORE reading them.
# Bug fix: load_dotenv was imported but never called, so HF_TOKEN and
# GROQ_API_KEY defined in .env were silently read back as None.
load_dotenv()

# Hugging Face token: needed both for `login` and for downloading the
# private/gated dataset snapshot below.
HF_TOKEN = os.getenv("HF_TOKEN")
login(token=HF_TOKEN)

# Groq chat-completions endpoint configuration.
GROQ_API_KEY  = os.getenv("GROQ_API_KEY")
GROQ_URL      = "https://api.groq.com/openai/v1/chat/completions"
GROQ_MODEL    = "llama-3.1-8b-instant"
# Nested path: snapshot_download(local_dir="chroma_db") places the DB one
# level down, hence "chroma_db/chroma_db".
CHROMA_DIR    = "chroma_db/chroma_db"
HF_DATASET_ID = "Mohammedelhakim/parenting-qa-vectordb"
NUM_RESULTS   = 3  # number of similar Q&As retrieved per query


# Fetch the prebuilt Chroma vector DB from the Hugging Face Hub on first
# run; subsequent starts reuse the local copy.
if not os.path.exists(CHROMA_DIR):
    print("Downloading vector database from Hugging Face Hub...")
    snapshot_download(
        # Fix: use the HF_DATASET_ID constant instead of a hard-coded
        # duplicate of the same string (single source of truth).
        repo_id=HF_DATASET_ID,
        repo_type="dataset",
        local_dir="chroma_db",
        token=HF_TOKEN,
    )
    print("Vector database downloaded!")
else:
    print("Vector database found locally, loading...")


# The embedding model must match the one used to build the vector DB
# (all-MiniLM-L6-v2), otherwise query vectors live in a different space.
print("Loading embedding model...")
embedder = SentenceTransformer("all-MiniLM-L6-v2")

chroma_client = chromadb.PersistentClient(path=CHROMA_DIR)
print(chroma_client.list_collections())
collection = chroma_client.get_collection("parenting_qa")
print("Ready!")


# FastAPI application instance; endpoints are registered via decorators.
app = FastAPI()


@app.get("/")
def root():
    """Health-check endpoint: reports that the service is up."""
    return dict(status="ok")

class ChatRequest(BaseModel):
    """Request body for POST /chat."""

    message: str  # the parent's question, free-form text

@app.post("/chat")
def chat(request: ChatRequest):
    """Answer a parenting question via RAG.

    Embeds the question, retrieves the NUM_RESULTS most similar stored
    Q&A pairs from Chroma, and asks the Groq chat API with those pairs
    as few-shot context.

    Returns:
        dict with a single "response" key containing the model's answer.

    Raises:
        HTTPException(502): the Groq API is unreachable, times out,
            returns a non-2xx status, or returns a non-JSON body.
        HTTPException(500): the Groq payload parses but lacks "choices"
            (kept at 500 for backward compatibility with the original).
    """
    # 1. Embed the user question with the sentence-transformer model.
    query_embedding = embedder.encode(request.message).tolist()

    # 2. Retrieve the most similar stored Q&As.
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=NUM_RESULTS,
    )

    # 3. Build the few-shot context: each document is a stored question,
    #    and its answer lives in the metadata under "answer".
    context = "".join(
        f"Q: {doc}\nA: {meta['answer']}\n\n"
        for doc, meta in zip(results["documents"][0], results["metadatas"][0])
    )

    # 4. Call Groq with the retrieved examples folded into the system prompt.
    messages = [
        {"role": "system", "content": (
            "You are a helpful and caring parenting assistant. "
            "You answer questions from parents about their babies and children, "
            "such as health, development, feeding, sleep, and behavior. "
            "Give clear, reassuring, and practical answers. "
            "If something sounds medically serious, always advise the parent to consult a doctor. "
            "Use the following similar Q&A examples to guide your answer:\n\n"
            + context
        )},
        {"role": "user", "content": request.message},
    ]

    try:
        response = req.post(
            GROQ_URL,
            headers={
                "Authorization": f"Bearer {GROQ_API_KEY}",
                "Content-Type": "application/json",
            },
            json={"model": GROQ_MODEL, "messages": messages},
            timeout=30,
        )
        # Surface HTTP-level failures (4xx/5xx) instead of trying to parse
        # an error body as a chat completion.
        response.raise_for_status()
        result = response.json()
    except (req.RequestException, ValueError) as exc:
        # Network failure, timeout, bad status, or non-JSON body.
        # (requests' JSONDecodeError is a ValueError subclass.)
        raise HTTPException(status_code=502, detail=f"Groq API error: {exc}") from exc

    if "choices" not in result:
        raise HTTPException(status_code=500, detail=f"Groq API error: {result}")

    return {"response": result["choices"][0]["message"]["content"]}