# rag_core.py
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
from langchain_chroma import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_groq import ChatGroq
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

import os


# Configuration
persist_directory = "./chroma_storage"
embedding_model = FastEmbedEmbeddings(model_name="BAAI/bge-base-en-v1.5")
vectorstore = Chroma(
    embedding_function=embedding_model,
    persist_directory=persist_directory
)
retriever = vectorstore.as_retriever(search_kwargs={"k": 2})

chat_model = ChatGroq(
    temperature=0.3,
    model_name="llama-3.1-8b-instant",
    api_key=os.getenv("groqapi_key"),  # reads the key from the groqapi_key env var
)

# RAG prompt
rag_template = """\
Use the following context to answer the user's query. If you cannot answer, please respond with 'I don't know'.

User's Query:
{question}

Context:
{context}
"""
rag_prompt = ChatPromptTemplate.from_template(rag_template)

# SentenceTransformer for similarity scoring (used on demand)
similarity_model = SentenceTransformer("all-MiniLM-L6-v2")

def calculate_similarity(question, document):
    """Return the cosine similarity between a question and a document."""
    # encode() returns NumPy arrays by default, so no tensor round-trip is needed
    q_emb = similarity_model.encode(question)
    d_emb = similarity_model.encode(document)
    return cosine_similarity([q_emb], [d_emb])[0][0]
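
# A hypothetical helper sketching how calculate_similarity could gate weakly
# related chunks before they reach the prompt; the 0.3 threshold is an
# assumption, not a tuned value.
def filter_relevant_chunks(question, chunks, threshold=0.3):
    return [c for c in chunks if calculate_similarity(question, c) >= threshold]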

# Sub-query generation (query expansion)
def generate_queries(query: str, llm, num_queries: int = 4):
    """Expand a single user query into several related search queries."""
    query_gen_str = """\
You are a helpful assistant that generates multiple search queries based on a \
single input query. Generate {num_queries} search queries, one on each line, \
related to the following input query:
Query: {query}
Queries:
"""
    query_prompt = ChatPromptTemplate.from_template(query_gen_str)
    formatted_prompt = query_prompt.format(num_queries=num_queries, query=query)
    # invoke() supersedes the deprecated predict(); chat models return a message
    response = llm.invoke(formatted_prompt)
    return [line.strip() for line in response.content.splitlines() if line.strip()]

# Enriched context retrieval: gather chunks for every generated sub-query
def get_context(query):
    sub_queries = generate_queries(query, chat_model)
    # Flatten the per-query result lists and deduplicate by page content
    docs = [doc for q in sub_queries for doc in retriever.invoke(q)]
    unique_contents = list(dict.fromkeys(doc.page_content for doc in docs))
    return "\n\n".join(unique_contents)

# The complete RAG chain
rag_chain = (
    {"context": get_context, "question": RunnablePassthrough()}
    | rag_prompt
    | chat_model
    | StrOutputParser()
)