| |
| import modules.app_constants as app_constants |
| from langchain_openai import ChatOpenAI |
| from langchain.chains import RetrievalQAWithSourcesChain |
| from openai import OpenAI |
| from modules import app_logger, common_utils, app_st_session_utils |
|
|
| |
# The imported `app_logger` module exposes a logger object under the same
# name; rebind the module-level name to that object so the code below can
# call app_logger.info(...) / .debug(...) / .error(...) directly.
app_logger = getattr(app_logger, "app_logger")
|
|
| |
| def query_llm(prompt, page="nav_private_ai", retriever=None, message_store=None, use_retrieval_chain=False, last_page=None, username=""): |
| try: |
| |
| if use_retrieval_chain: |
| app_logger.info("Using ChatOpenAI with RetrievalQAWithSourcesChain") |
| llm = ChatOpenAI( |
| model_name=app_constants.MODEL_NAME, |
| openai_api_key=app_constants.openai_api_key, |
| base_url=app_constants.local_model_uri, |
| streaming=True |
| ) |
| qa = RetrievalQAWithSourcesChain.from_chain_type( |
| llm=llm, |
| chain_type=app_constants.RAG_TECHNIQUE, |
| retriever=retriever, |
| return_source_documents=False |
| ) |
| else: |
| app_logger.info("Using direct OpenAI API call") |
| llm = OpenAI( |
| base_url=app_constants.local_model_uri, |
| api_key=app_constants.openai_api_key |
| ) |
|
|
| |
| if last_page != page: |
| app_logger.info(f"Updating messages for new page: {page}") |
| common_utils.get_system_role(page, message_store) |
|
|
| |
| messages_to_send = common_utils.construct_messages_to_send(page, message_store, prompt) |
| app_logger.debug(messages_to_send) |
| |
| response = None |
| if use_retrieval_chain: |
| response = qa.invoke(prompt) |
| else: |
| response = llm.chat.completions.create( |
| model=app_constants.MODEL_NAME, |
| messages=messages_to_send |
| ) |
|
|
| |
| raw_msg = response.get('answer') if use_retrieval_chain else response.choices[0].message.content |
| source_info = response.get('sources', '').strip() if use_retrieval_chain else '' |
| formatted_msg = app_st_session_utils.format_response(raw_msg + "Source: " + source_info if source_info else raw_msg) |
|
|
| return formatted_msg |
|
|
| except Exception as e: |
| error_message = f"An error occurred while querying the language model: {e}" |
| app_logger.error(error_message) |
| return error_message |
|
|