# rag_core.py
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
from langchain_chroma import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_groq import ChatGroq
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import os
# Configuration
persist_directory = "./chroma_storage"
embedding_model = FastEmbedEmbeddings(model_name="BAAI/bge-base-en-v1.5")
vectorstore = Chroma(
    embedding_function=embedding_model,
    persist_directory=persist_directory,
)
retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
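# NOTE: this module assumes ./chroma_storage was already populated by a separate
# ingestion step (for example Chroma.from_documents(docs, embedding_model,
# persist_directory=persist_directory)); nothing here writes to the vector store.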
chat_model = ChatGroq(
    temperature=0.3,
    model_name="llama-3.1-8b-instant",
    api_key=os.getenv("groqapi_key"),
)
# RAG prompt
rag_template = """\
Use the following context to answer the user's query. If you cannot answer, please respond with 'I don't know'.
User's Query:
{question}
Context:
{context}
"""
rag_prompt = ChatPromptTemplate.from_template(rag_template)
# SentenceTransformer for similarity scoring (optional)
similarity_model = SentenceTransformer("all-MiniLM-L6-v2")
def calculate_similarity(question, document):
    """Return the cosine similarity between a question and a document string."""
    q_emb = similarity_model.encode(question, convert_to_tensor=True).cpu().detach().numpy()
    d_emb = similarity_model.encode(document, convert_to_tensor=True).cpu().detach().numpy()
    return cosine_similarity([q_emb], [d_emb])[0][0]
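# Illustrative usage (not called anywhere in this module; `retrieved_doc` is a
# hypothetical Document returned by the retriever):
#   score = calculate_similarity("How do I activate roaming?", retrieved_doc.page_content)
#   A score close to 1.0 means the document is highly relevant to the question.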
# Sub-query generation (multi-query retrieval)
def generate_queries(query: str, llm, num_queries: int = 4):
    """Ask the LLM to expand a single query into several related search queries."""
    query_gen_str = """\
You are a helpful assistant that generates multiple search queries based on a \
single input query. Generate {num_queries} search queries, one on each line, \
related to the following input query:
Query: {query}
Queries:
"""
    query_prompt = ChatPromptTemplate.from_template(query_gen_str)
    formatted_prompt = query_prompt.format(num_queries=num_queries, query=query)
    # ChatGroq is a chat model: invoke() returns an AIMessage, so read .content
    # (the older .predict() helper is deprecated in recent LangChain releases).
    response = llm.invoke(formatted_prompt).content
    # Drop blank lines the model may emit around the list of queries.
    return [line.strip() for line in response.strip().splitlines() if line.strip()]
# Enriched context retrieval
def get_context(query):
    """Retrieve documents for the query's generated sub-queries and merge their text."""
    sub_queries = generate_queries(query, chat_model)
    # retriever.invoke() returns a list of Documents per sub-query; keep only the text.
    docs = [doc for q in sub_queries for doc in retriever.invoke(q)]
    return "\n".join(doc.page_content for doc in docs)
# The full RAG chain
rag_chain = (
    {"context": get_context, "question": RunnablePassthrough()}
    | rag_prompt
    | chat_model
    | StrOutputParser()
)
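# Minimal usage sketch (illustrative; the question below is only an example, and
# the GROQ API key must be available in the "groqapi_key" environment variable):
if __name__ == "__main__":
    answer = rag_chain.invoke("How do I check my remaining data balance?")
    print(answer)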