| import os |
| from operator import itemgetter |
|
|
| from langchain_chroma import Chroma |
| from langchain_text_splitters import RecursiveCharacterTextSplitter |
| from langchain_core.runnables import RunnablePassthrough, RunnableParallel |
| from langchain_core.output_parsers import JsonOutputParser |
| from langchain.prompts import PromptTemplate |
|
|
| from lib.models import MODELS_MAP |
| from lib.utils import format_docs, retrieve_answer, load_embeddings |
| from lib.entities import LLMEvalResult |
|
|
def create_retriever(llm_name, db_path, docs, collection_name="local-rag"):
    """Return a Chroma retriever, creating or loading the persisted store.

    If ``db_path`` does not exist yet, ``docs`` are chunked and embedded into
    a new persistent Chroma collection; otherwise the existing collection is
    loaded from disk and ``docs`` are ignored.

    Args:
        llm_name: Key passed to ``load_embeddings`` to pick the embedding model.
        db_path: Directory where the Chroma store is (or will be) persisted.
        docs: Documents to index; only used when the store is first created.
        collection_name: Chroma collection to create or open.

    Returns:
        A retriever over the vector store (``vectorstore.as_retriever()``).
    """
    embeddings = load_embeddings(llm_name)

    if not os.path.exists(db_path):
        # Split only when actually building the store: when an existing DB is
        # loaded, chunking the documents would be wasted work (and would fail
        # if the caller passes no docs for a load-only call).
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=60)
        splits = text_splitter.split_documents(docs)
        vectorstore = Chroma.from_documents(
            documents=splits,
            embedding=embeddings,
            persist_directory=db_path,
            collection_name=collection_name,
        )
    else:
        vectorstore = Chroma(
            persist_directory=db_path,
            embedding_function=embeddings,
            collection_name=collection_name,
        )

    return vectorstore.as_retriever()
|
|
def create_qa_chain(llm, retriever, prompts_text):
    """Build a RAG question-answering chain that also self-evaluates its answer.

    The chain retrieves context, answers the question, then asks the same LLM
    to evaluate the (question, answer) pair, parsing the evaluation as JSON
    into an ``LLMEvalResult``-shaped dict.

    Args:
        llm: A LangChain-compatible chat/LLM runnable used for both answering
            and evaluation.
        retriever: Retriever producing context documents for a question.
        prompts_text: Mapping with ``"initial_prompt"`` (expects ``question``
            and ``context`` variables) and ``"evaluation_prompt"`` (expects
            ``question``, ``answer`` and ``format_instructions`` variables).

    Returns:
        A runnable whose output is a dict with keys ``output`` (the answer),
        ``evaluation`` (parsed JSON evaluation) and ``context``.
    """
    initial_prompt = PromptTemplate(
        template=prompts_text["initial_prompt"],
        input_variables=["question", "context"],
    )

    json_parser = JsonOutputParser(pydantic_object=LLMEvalResult)
    qa_eval_prompt = PromptTemplate(
        template=prompts_text["evaluation_prompt"],
        input_variables=["question", "answer"],
        partial_variables={"format_instructions": json_parser.get_format_instructions()},
    )
    # NOTE: a second "with context" evaluation prompt used to be built here but
    # was never wired into the chain; it has been removed as dead code.

    chain = (
        # 1. Fetch and format context; pass the raw question through.
        RunnableParallel(context=retriever | format_docs, question=RunnablePassthrough())
        # 2. Generate the answer from question + context.
        | RunnableParallel(
            answer=initial_prompt | llm | retrieve_answer,
            question=itemgetter("question"),
            context=itemgetter("context"),
        )
        # 3. Render the evaluation prompt for the (question, answer) pair.
        | RunnableParallel(
            input=qa_eval_prompt,
            context=itemgetter("context"),
            answer=itemgetter("answer"),
        )
        # 4. Run the evaluation LLM call on the rendered prompt.
        | RunnableParallel(
            evaluation=itemgetter("input") | llm,
            context=itemgetter("context"),
            answer=itemgetter("answer"),
        )
        # 5. Parse the evaluation JSON and assemble the final payload.
        | RunnableParallel(
            output=itemgetter("answer"),
            evaluation=itemgetter("evaluation") | json_parser,
            context=itemgetter("context"),
        )
    )
    return chain