| from llama_index.llms.huggingface import HuggingFaceLLM, HuggingFaceInferenceAPI |
| from llama_index.llms.openai import OpenAI |
| from llama_index.llms.replicate import Replicate |
|
|
| from dotenv import load_dotenv |
| import os |
| import streamlit as st |
|
|
# Load environment variables (e.g. provider API keys) from a local .env file.
load_dotenv()
|
|
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
# Allow-list of supported models: maps each model identifier to the provider
# ("source") that serves it. load_llm() returns None for any model_name that
# is not a key of this dict.
integrated_llms = {
    "gpt-3.5-turbo-0125": "openai",
    "meta/llama-2-13b-chat": "replicate",
    "mistralai/Mistral-7B-Instruct-v0.2": "huggingface",
}
|
|
|
|
def load_llm(model_name: str, source: str = "huggingface"):
    """Instantiate an LLM client for one of the integrated models.

    Args:
        model_name: Model identifier; must be a key of ``integrated_llms``.
        source: Provider name — "openai", "replicate", or "huggingface".

    Returns:
        A configured LLM instance, or ``None`` when the model is not in the
        allow-list or instantiation raises.
    """
    print("model_name: ", model_name, "source: ", source)
    if integrated_llms.get(model_name) is None:
        # Unknown model — refuse rather than forwarding an arbitrary name.
        return None
    try:
        if source.startswith("openai"):
            return OpenAI(
                model=model_name,
                api_key=st.session_state.openai_api_key,
            )
        elif source.startswith("replicate"):
            # BUG FIX: the API token was previously passed as ``prompt_key``,
            # which names the model's prompt input field, not credentials —
            # so the token never authenticated and prompts were sent under a
            # bogus key. The Replicate client reads credentials from the
            # REPLICATE_API_TOKEN environment variable, so export it there.
            os.environ["REPLICATE_API_TOKEN"] = st.session_state.replicate_api_token
            return Replicate(
                model=model_name,
                is_chat_model=True,
                additional_kwargs={"max_new_tokens": 250},
            )
        elif source.startswith("huggingface"):
            return HuggingFaceInferenceAPI(
                model_name=model_name,
                token=st.session_state.hf_token,
            )
    except Exception as e:
        # Best-effort: log and fall through to None so the Streamlit UI can
        # surface a friendly failure instead of crashing the app.
        print(e)
    # Unrecognized source or failed instantiation.
    return None
|
|