makhtar7186 committed on
Commit
c16f1b2
·
verified ·
1 Parent(s): 28f00ea

Update rag_core.py

Browse files
Files changed (1) hide show
  1. rag_core.py +75 -76
rag_core.py CHANGED
@@ -1,76 +1,75 @@
1
- # rag_core.py
2
- from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
3
- from langchain_chroma import Chroma
4
- from langchain_core.prompts import ChatPromptTemplate
5
- from langchain_core.runnables import RunnablePassthrough
6
- from langchain_core.output_parsers import StrOutputParser
7
- from langchain_groq import ChatGroq
8
- from sentence_transformers import SentenceTransformer
9
- from sklearn.metrics.pairwise import cosine_similarity
10
- from dotenv import load_dotenv
11
- import os
12
-
13
- load_dotenv()
14
-
15
- # Configuration
16
- persist_directory = "./chroma_storage"
17
- embedding_model = FastEmbedEmbeddings(model_name="BAAI/bge-base-en-v1.5")
18
- vectorstore = Chroma(
19
- embedding_function=embedding_model,
20
- persist_directory=persist_directory
21
- )
22
- retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
23
-
24
- chat_model = ChatGroq(
25
- temperature=0.3,
26
- model_name="llama-3.1-8b-instant",
27
- api_key=os.getenv("groqapi_key"),
28
- )
29
-
30
- # Prompt RAG
31
- rag_template = """\
32
- Use the following context to answer the user's query. If you cannot answer, please respond with 'I don't know'.
33
-
34
- User's Query:
35
- {question}
36
-
37
- Context:
38
- {context}
39
- """
40
- rag_prompt = ChatPromptTemplate.from_template(rag_template)
41
-
42
- # SentenceTransformer pour la similarité (si besoin)
43
- similarity_model = SentenceTransformer("all-MiniLM-L6-v2")
44
-
45
- def calculate_similarity(question, document):
46
- q_emb = similarity_model.encode(question, convert_to_tensor=True).cpu().detach().numpy()
47
- d_emb = similarity_model.encode(document, convert_to_tensor=True).cpu().detach().numpy()
48
- return cosine_similarity([q_emb], [d_emb])[0][0]
49
-
50
- # Génération de sous-requêtes
51
- def generate_queries(query: str, llm, num_queries: int = 4):
52
- query_gen_str = """\
53
- You are a helpful assistant that generates multiple search queries based on a \
54
- single input query. Generate {num_queries} search queries, one on each line, \
55
- related to the following input query:
56
- Query: {query}
57
- Queries:
58
- """
59
- query_prompt = ChatPromptTemplate.from_template(query_gen_str)
60
- formatted_prompt = query_prompt.format(num_queries=num_queries, query=query)
61
- response = llm.predict(formatted_prompt)
62
- return response.strip().splitlines()
63
-
64
- # Récupération de contexte enrichi
65
- def get_context(query):
66
- sub_queries = generate_queries(query, chat_model)
67
- chunks = [retriever.invoke(q) for q in sub_queries]
68
- return "\n".join(map(str, chunks))
69
-
70
- # La chaîne complète
71
- rag_chain = (
72
- {"context": get_context, "question": RunnablePassthrough()}
73
- | rag_prompt
74
- | chat_model
75
- | StrOutputParser()
76
- )
 
1
+ # rag_core.py
2
+ from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
3
+ from langchain_chroma import Chroma
4
+ from langchain_core.prompts import ChatPromptTemplate
5
+ from langchain_core.runnables import RunnablePassthrough
6
+ from langchain_core.output_parsers import StrOutputParser
7
+ from langchain_groq import ChatGroq
8
+ from sentence_transformers import SentenceTransformer
9
+ from sklearn.metrics.pairwise import cosine_similarity
10
+
11
+ import os
12
+
13
+
14
# Configuration: on-disk Chroma collection queried through a FastEmbed model.
persist_directory = "./chroma_storage"
embedding_model = FastEmbedEmbeddings(model_name="BAAI/bge-base-en-v1.5")
vectorstore = Chroma(
    persist_directory=persist_directory,
    embedding_function=embedding_model,
)
# Each retrieval call asks the vector store for the top 2 matches.
retriever = vectorstore.as_retriever(search_kwargs={"k": 2})

# Groq-hosted chat model; the API key is read from the "groqapi_key"
# environment variable (None is passed through if it is not set).
chat_model = ChatGroq(
    model_name="llama-3.1-8b-instant",
    temperature=0.3,
    api_key=os.environ.get("groqapi_key"),
)

# RAG prompt: the user's question followed by the retrieved context.
rag_template = (
    "Use the following context to answer the user's query. "
    "If you cannot answer, please respond with 'I don't know'.\n"
    "\n"
    "User's Query:\n"
    "{question}\n"
    "\n"
    "Context:\n"
    "{context}\n"
)
rag_prompt = ChatPromptTemplate.from_template(rag_template)

# SentenceTransformer used for similarity scoring (if needed).
similarity_model = SentenceTransformer("all-MiniLM-L6-v2")
43
+
44
def calculate_similarity(question, document):
    """Return the cosine similarity between two embedded texts.

    Both inputs are encoded with the module-level ``similarity_model``,
    moved to CPU, converted to NumPy arrays, and compared with
    ``cosine_similarity``.

    Args:
        question: Text to embed as the first operand.
        document: Text to embed as the second operand.

    Returns:
        The single entry ``[0][0]`` of the 1x1 similarity matrix.
    """

    def _embed(text):
        # Encode to a tensor first, then bring it back to a NumPy array.
        tensor = similarity_model.encode(text, convert_to_tensor=True)
        return tensor.cpu().detach().numpy()

    question_vec = _embed(question)
    document_vec = _embed(document)
    # cosine_similarity expects 2-D inputs, hence the one-row wrappers.
    score_matrix = cosine_similarity([question_vec], [document_vec])
    return score_matrix[0][0]
48
+
49
+ # Génération de sous-requêtes
50
def generate_queries(query: str, llm, num_queries: int = 4):
    """Ask *llm* for several search queries related to *query*.

    Args:
        query: The user's original question.
        llm: A LangChain model. ``invoke`` is used when available (chat
            models return a message with ``.content``, plain LLMs return a
            string); objects exposing only the legacy ``predict`` method
            are still supported.
        num_queries: How many queries the prompt asks the model to produce.

    Returns:
        A list of non-empty, whitespace-stripped lines taken from the
        model's reply. The model is only *asked* for ``num_queries`` lines,
        so the list may be shorter or longer than that.
    """
    query_gen_str = """\
You are a helpful assistant that generates multiple search queries based on a \
single input query. Generate {num_queries} search queries, one on each line, \
related to the following input query:
Query: {query}
Queries:
"""
    query_prompt = ChatPromptTemplate.from_template(query_gen_str)

    if hasattr(llm, "invoke"):
        # `predict` is deprecated in LangChain; `invoke` is the supported
        # entry point and accepts the formatted messages directly, so the
        # prompt is not flattened into a "Human: ..." string first.
        messages = query_prompt.format_messages(
            num_queries=num_queries, query=query
        )
        response = llm.invoke(messages)
    else:
        # Legacy path kept for callers passing an object with only `predict`.
        formatted_prompt = query_prompt.format(
            num_queries=num_queries, query=query
        )
        response = llm.predict(formatted_prompt)

    # Chat models return a message object; plain LLMs return a str.
    text = getattr(response, "content", response)
    if not isinstance(text, str):
        text = str(text)

    # Drop blank lines so callers never search with an empty query.
    return [line.strip() for line in text.splitlines() if line.strip()]
62
+
63
+ # Récupération de contexte enrichi
64
def get_context(query):
    """Build the context string for *query* via multi-query retrieval.

    Sub-queries are generated with ``generate_queries`` and each one is run
    through the module-level ``retriever``. The text of every retrieved
    document is collected once (duplicates across sub-queries are dropped,
    first occurrence wins) and joined into a single string.

    Args:
        query: The user's original question.

    Returns:
        The retrieved passages separated by blank lines, or an empty string
        when nothing was retrieved.
    """
    # Fall back to the original question if the model produced no queries.
    sub_queries = generate_queries(query, chat_model) or [query]

    seen = set()
    passages = []
    for sub_query in sub_queries:
        if not sub_query.strip():
            # A blank line from the model is not a usable search query.
            continue
        for doc in retriever.invoke(sub_query):
            # Use the document text rather than the object's repr, which
            # would leak "Document(metadata=...)" noise into the prompt.
            # Assumes LangChain Document objects; anything else is str()'d.
            text = getattr(doc, "page_content", None)
            if text is None:
                text = str(doc)
            if text not in seen:
                seen.add(text)
                passages.append(text)
    return "\n\n".join(passages)
68
+
69
+ # La chaîne complète
70
# The chain's input mapping: "context" is computed from the incoming
# question by get_context, while "question" is passed through unchanged.
_chain_inputs = {"context": get_context, "question": RunnablePassthrough()}

# Full pipeline: build inputs -> fill the RAG prompt -> call the chat
# model -> parse the model output into a plain string.
rag_chain = _chain_inputs | rag_prompt | chat_model | StrOutputParser()