# rag_core.py
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
from langchain_chroma import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_groq import ChatGroq
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import os
# Configuration
persist_directory = "./chroma_storage"
embedding_model = FastEmbedEmbeddings(model_name="BAAI/bge-base-en-v1.5")
vectorstore = Chroma(
    embedding_function=embedding_model,
    persist_directory=persist_directory,
)
retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
chat_model = ChatGroq(
    temperature=0.3,
    model="llama-3.1-8b-instant",
    api_key=os.getenv("GROQ_API_KEY"),  # standard Groq environment variable
)
# RAG prompt
rag_template = """\
Use the following context to answer the user's query. If you cannot answer, please respond with 'I don't know'.
User's Query:
{question}
Context:
{context}
"""
rag_prompt = ChatPromptTemplate.from_template(rag_template)
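# Example (hypothetical values): rag_prompt.format_messages(question="What is RAG?",
# context="...") yields the chat messages that get sent to the model.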
# SentenceTransformer for similarity scoring (optional)
similarity_model = SentenceTransformer("all-MiniLM-L6-v2")

def calculate_similarity(question, document):
    # encode() returns NumPy arrays by default, which is what cosine_similarity expects
    q_emb = similarity_model.encode(question)
    d_emb = similarity_model.encode(document)
    return cosine_similarity([q_emb], [d_emb])[0][0]
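# Example (hypothetical values): a quick relevance check on a retrieved chunk.
#   calculate_similarity("What is RAG?", "RAG combines retrieval with generation.")
# returns a float in [-1.0, 1.0]; higher means a closer semantic match.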
# Sub-query generation
def generate_queries(query: str, llm, num_queries: int = 4):
    query_gen_str = """\
You are a helpful assistant that generates multiple search queries based on a \
single input query. Generate {num_queries} search queries, one on each line, \
related to the following input query:
Query: {query}
Queries:
"""
    query_prompt = ChatPromptTemplate.from_template(query_gen_str)
    formatted_prompt = query_prompt.format(num_queries=num_queries, query=query)
    # invoke() returns an AIMessage; its .content attribute is the generated text
    response = llm.invoke(formatted_prompt).content
    # keep only non-empty lines, one sub-query per line
    return [line for line in response.strip().splitlines() if line.strip()]
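# Example (hypothetical output): generate_queries("benefits of RAG", chat_model)
# might return something like:
#   ["1. advantages of retrieval-augmented generation",
#    "2. RAG vs. fine-tuning for domain knowledge", ...]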
# Enriched context retrieval
def get_context(query):
    sub_queries = generate_queries(query, chat_model)
    docs = []
    for q in sub_queries:
        docs.extend(retriever.invoke(q))
    # join the raw text of each retrieved chunk rather than str()-ing Document lists
    return "\n".join(doc.page_content for doc in docs)
# The full chain
rag_chain = (
    {"context": get_context, "question": RunnablePassthrough()}
    | rag_prompt
    | chat_model
    | StrOutputParser()
)
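# Minimal smoke test, assuming GROQ_API_KEY is set and ./chroma_storage already
# holds an indexed collection (neither step is shown in this file).
if __name__ == "__main__":
    # the chain takes a plain query string; the question below is illustrative
    answer = rag_chain.invoke("What topics does the indexed corpus cover?")
    print(answer)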