# rag_core.py
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
from langchain_chroma import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_groq import ChatGroq
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import os

# Configuration
persist_directory = "./chroma_storage"

embedding_model = FastEmbedEmbeddings(model_name="BAAI/bge-base-en-v1.5")
vectorstore = Chroma(
    embedding_function=embedding_model,
    persist_directory=persist_directory,
)
retriever = vectorstore.as_retriever(search_kwargs={"k": 2})

chat_model = ChatGroq(
    temperature=0.3,
    model_name="llama-3.1-8b-instant",
    api_key=os.getenv("GROQ_API_KEY"),
)

# RAG prompt
rag_template = """\
Use the following context to answer the user's query.
If you cannot answer, please respond with 'I don't know'.

User's Query:
{question}

Context:
{context}
"""
rag_prompt = ChatPromptTemplate.from_template(rag_template)

# SentenceTransformer for similarity scoring (if needed)
similarity_model = SentenceTransformer("all-MiniLM-L6-v2")


def calculate_similarity(question: str, document: str) -> float:
    """Return the cosine similarity between a question and a document."""
    q_emb = similarity_model.encode(question)
    d_emb = similarity_model.encode(document)
    return float(cosine_similarity([q_emb], [d_emb])[0][0])


# Sub-query generation
def generate_queries(query: str, llm, num_queries: int = 4) -> list[str]:
    """Ask the LLM to expand one input query into several search queries."""
    query_gen_str = """\
You are a helpful assistant that generates multiple search queries based on a \
single input query. Generate {num_queries} search queries, one on each line, \
related to the following input query:
Query: {query}
Queries:
"""
    query_prompt = ChatPromptTemplate.from_template(query_gen_str)
    formatted_prompt = query_prompt.format(num_queries=num_queries, query=query)
    # `invoke` returns an AIMessage; the generated text lives in `.content`
    # (the deprecated `predict` method returned a bare string).
    response = llm.invoke(formatted_prompt)
    return [line for line in response.content.strip().splitlines() if line.strip()]


# Enriched context retrieval
def get_context(query: str) -> str:
    """Retrieve documents for each sub-query and merge them into one context."""
    sub_queries = generate_queries(query, chat_model)
    seen, parts = set(), []
    for sub_query in sub_queries:
        for doc in retriever.invoke(sub_query):
            # Deduplicate chunks returned by overlapping sub-queries and keep
            # only the text, not the Document repr.
            if doc.page_content not in seen:
                seen.add(doc.page_content)
                parts.append(doc.page_content)
    return "\n".join(parts)


# The complete chain
rag_chain = (
    {"context": get_context, "question": RunnablePassthrough()}
    | rag_prompt
    | chat_model
    | StrOutputParser()
)
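

# Example usage -- a minimal sketch, assuming GROQ_API_KEY is set in the
# environment and that documents were already ingested into ./chroma_storage
# (ingestion is not part of this module). The question below is illustrative.
if __name__ == "__main__":
    question = "What does this knowledge base say about vector stores?"
    answer = rag_chain.invoke(question)
    print(answer)
    # Optional: score how close the answer is to the question using the
    # standalone similarity helper defined above.
    print(f"similarity: {calculate_similarity(question, answer):.3f}")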