File size: 2,897 Bytes
f83e60c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
from dotenv import load_dotenv

from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.prompts import PromptTemplate
from langchain_groq import ChatGroq
from sentence_transformers import CrossEncoder

# Setup Configuration
# Directory of the persisted Chroma vector store (a relative path, so it is
# resolved against the current working directory).
CHROMA_DB_DIR = "vectorstore"

# Hugging Face model ids: a bi-encoder that embeds the query for vector
# search, and a cross-encoder that re-scores the retrieved chunks.
# NOTE(review): EMBEDDING_MODEL presumably must match the model used to build
# the vector store — verify against the ingestion script.
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"

# Groq-hosted chat model id; keep this pointed at a currently active Groq model.
LLM_MODEL = "llama-3.1-8b-instant"

def _rerank(query, docs, cross_encoder, top_n: int = 3):
    """Re-score *docs* against *query* with a cross-encoder and keep the best.

    Each document's ``metadata["relevance_score"]`` is set in place to its
    cross-encoder score, converted to a plain ``float``.

    Args:
        query: The user question.
        docs: Retrieved documents exposing ``page_content`` and a mutable
            ``metadata`` dict.
        cross_encoder: Object whose ``predict(pairs)`` returns one score per
            ``[query, text]`` pair (here a ``sentence_transformers.CrossEncoder``).
        top_n: Maximum number of documents to return.

    Returns:
        Up to ``top_n`` documents ordered by descending score; ties keep their
        retrieval order. An empty list when *docs* is empty.
    """
    # Nothing to score: skip predict(), which may not accept an empty batch.
    if not docs:
        return []

    pairs = [[query, doc.page_content] for doc in docs]
    scores = cross_encoder.predict(pairs)
    for doc, score in zip(docs, scores):
        # float() turns numpy scalars into plain Python numbers in metadata.
        doc.metadata["relevance_score"] = float(score)

    # sorted() is stable even with reverse=True, so equal scores keep order.
    ranked = sorted(docs, key=lambda d: d.metadata["relevance_score"], reverse=True)
    return ranked[:top_n]


def main(
    query: str = "What is the company policy for remote work?",
    *,
    fetch_k: int = 5,
    top_n: int = 3,
) -> None:
    """Run one retrieve -> re-rank -> generate pass and print the result.

    Reloads the persisted Chroma store, fetches ``fetch_k`` candidate chunks
    for *query*, re-ranks them with a cross-encoder, keeps the best ``top_n``
    as context, and asks the Groq-hosted LLM to answer strictly from that
    context. Prints the answer followed by the scored source snippets.

    Args:
        query: Question to answer. Defaults to the original demo question.
        fetch_k: Number of chunks pulled from the vector store before re-ranking.
        top_n: Number of re-ranked chunks passed to the LLM as context.
    """
    load_dotenv()

    # Fail fast: check the API key before spending time loading local models.
    if not os.environ.get("GROQ_API_KEY"):
        print("ERROR: GROQ_API_KEY not found in environment!")
        return

    # 1. Initialize embeddings and reload the vector store
    print("Loading vector store & embedding model...")
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
    vectorstore = Chroma(persist_directory=CHROMA_DB_DIR, embedding_function=embeddings)

    # 2. Setup the base retriever to get the top fetch_k chunks
    base_retriever = vectorstore.as_retriever(search_kwargs={"k": fetch_k})

    # 3. Setup ReRanker for relevance ordering
    print("Initializing CrossEncoder ReRanker...")
    cross_encoder = CrossEncoder(RERANKER_MODEL)

    # 4. Craft strict RAG prompt
    template = """You are a factual assistant. Answer ONLY using the context below.
If the answer isn't in the context, say "I don't know."
Context: {context}
Question: {question}"""
    prompt = PromptTemplate.from_template(template)

    # 5. Initialize the Groq LLM (temperature 0 for deterministic answers)
    print("Initializing LLM via Groq...")
    llm = ChatGroq(model_name=LLM_MODEL, temperature=0)

    # The query workflow
    print(f"\nQUERY: {query}\n")

    print("Retrieving and re-ranking documents...")
    initial_docs = base_retriever.invoke(query)
    top_docs = _rerank(query, initial_docs, cross_encoder, top_n=top_n)
    if not top_docs:
        # e.g. an empty vector store: don't query the LLM with blank context.
        print("No documents retrieved; check that the vector store has been built.")
        return

    # Format the context text from the re-ranked docs
    context_text = "\n\n".join(doc.page_content for doc in top_docs)

    print("Generating response...")
    chain = prompt | llm
    response = chain.invoke({"context": context_text, "question": query})

    print("\n--- FINAL ANSWER ---")
    print(response.content)
    print("\n--- SOURCES ---")
    for idx, doc in enumerate(top_docs, start=1):
        print(f"\n[Source {idx}] Score: {doc.metadata['relevance_score']:.4f}")
        snippet = doc.page_content[:150]
        # Only mark the snippet as truncated when text was actually cut off.
        print((snippet + "...") if len(doc.page_content) > 150 else snippet)


if __name__ == "__main__":
    main()