| import os |
| import litellm |
| from tenacity import ( |
| retry, |
| stop_after_attempt, |
| wait_random_exponential, |
| ) |
| from litellm import completion, acompletion |
| from typing import List |
| from ..core.registry import register_model |
| from .model_configs import LiteLLMConfig |
| from .openai_model import OpenAILLM |
| from .model_utils import infer_litellm_company_from_model, Cost |
|
|
@register_model(config_cls=LiteLLMConfig, alias=["litellm"])
class LiteLLM(OpenAILLM):
    """
    Model wrapper that sends requests through LiteLLM's ``completion`` /
    ``acompletion`` functions.

    Credentials for hosted providers are read from ``LiteLLMConfig`` and
    exported as environment variables (e.g. ``OPENAI_API_KEY``). Local models
    are configured through ``litellm.api_base`` / ``litellm.api_key`` instead.
    """

    # Hosted providers that need exactly one credential:
    # company -> (config attribute, environment variable to export,
    #             ValueError message raised when the credential is missing).
    _SIMPLE_PROVIDER_KEYS = {
        "openai": (
            "openai_key",
            "OPENAI_API_KEY",
            "OpenAI API key is required for OpenAI models. You should set `openai_key` in LiteLLMConfig",
        ),
        "deepseek": (
            "deepseek_key",
            "DEEPSEEK_API_KEY",
            "DeepSeek API key is required for DeepSeek models. You should set `deepseek_key` in LiteLLMConfig",
        ),
        "anthropic": (
            "anthropic_key",
            "ANTHROPIC_API_KEY",
            "Anthropic API key is required for Anthropic models. You should set `anthropic_key` in LiteLLMConfig",
        ),
        "gemini": (
            "gemini_key",
            "GEMINI_API_KEY",
            "Gemini API key is required for Gemini models. You should set `gemini_key` in LiteLLMConfig",
        ),
        "meta_llama": (
            "meta_llama_key",
            "LLAMA_API_KEY",
            "Meta Llama API key is required for Meta Llama models. You should set `meta_llama_key` in LiteLLMConfig",
        ),
        "perplexity": (
            "perplexity_key",
            "PERPLEXITYAI_API_KEY",
            "Perplexity API key is required for Perplexity models. You should set `perplexity_key` in LiteLLMConfig",
        ),
        "groq": (
            "groq_key",
            "GROQ_API_KEY",
            "Groq API key is required for Groq models. You should set `groq_key` in LiteLLMConfig",
        ),
    }

    def init_model(self):
        """
        Initialize the model based on the configuration.

        Validates ``llm_type``, copies ``model`` / ``api_base`` / ``api_key``
        from the config, and configures credentials for the provider inferred
        from the model name.

        Raises:
            ValueError: If ``llm_type`` is not ``"LiteLLM"``, if a local model
                has no ``api_base``, if the inferred provider's credentials are
                missing, or if the provider is not supported.
        """
        if self.config.llm_type != "LiteLLM":
            raise ValueError("llm_type must be 'LiteLLM'")

        self.model = self.config.model
        self.api_base = self.config.api_base
        self.api_key = self.config.api_key

        company = infer_litellm_company_from_model(self.model)
        if self.config.is_local or company == "local":
            self._configure_local()
        else:
            self._configure_provider(company)

        # Config fields holding credentials / routing info rather than request
        # parameters; presumably skipped by the base class when it builds
        # completion params — TODO confirm against OpenAILLM.
        self._default_ignore_fields = [
            "llm_type", "output_response", "openai_key", "deepseek_key", "anthropic_key",
            "gemini_key", "meta_llama_key", "openrouter_key", "openrouter_base", "perplexity_key",
            "groq_key", "api_base", "is_local", "azure_endpoint", "azure_key", "api_version", "api_key",
        ]

    def _configure_local(self):
        """Point LiteLLM's module-level ``api_base`` / ``api_key`` at the local server."""
        if not self.api_base:
            raise ValueError("api_base is required for local models in LiteLLMConfig")
        litellm.api_base = self.api_base
        litellm.api_key = self.api_key

    def _configure_provider(self, company: str):
        """Validate a hosted provider's credentials and export them as environment variables."""
        config = self.config

        if company in self._SIMPLE_PROVIDER_KEYS:
            attr_name, env_var, missing_message = self._SIMPLE_PROVIDER_KEYS[company]
            key = getattr(config, attr_name)
            if not key:
                raise ValueError(missing_message)
            os.environ[env_var] = key
        elif company == "azure":
            if not config.azure_key or not config.azure_endpoint:
                raise ValueError("Azure OpenAI key and endpoint are required for Azure models. You should set `azure_key` and `azure_endpoint` in LiteLLMConfig")
            os.environ["AZURE_API_KEY"] = config.azure_key
            os.environ["AZURE_API_BASE"] = config.azure_endpoint
            if config.api_version:
                os.environ["AZURE_API_VERSION"] = config.api_version
        elif company == "openrouter":
            if not config.openrouter_key:
                raise ValueError("OpenRouter API key is required for OpenRouter models. You should set `openrouter_key` in LiteLLMConfig. You can also set `openrouter_base` in LiteLLMConfig to use a custom base URL [optional]")
            os.environ["OPENROUTER_API_KEY"] = config.openrouter_key
            # openrouter_base is optional. os.environ only accepts str values,
            # so assigning None would raise TypeError; export it only when set.
            if config.openrouter_base:
                os.environ["OPENROUTER_API_BASE"] = config.openrouter_base
        else:
            raise ValueError(f"Unsupported company: {company}")

    def _is_local_model(self) -> bool:
        """Return True when the model is served locally (explicit flag or inferred company)."""
        return bool(self.config.is_local) or infer_litellm_company_from_model(self.model) == "local"

    def _compute_cost(self, input_tokens: int, output_tokens: int) -> Cost:
        """
        Compute the cost of a call.

        Local models are reported with zero input/output cost. Any other model
        defers to the parent implementation.
        """
        if self._is_local_model():
            return Cost(input_tokens=input_tokens, output_tokens=output_tokens, input_cost=0.0, output_cost=0.0)
        return super()._compute_cost(input_tokens, output_tokens)

    def _prepare_completion_params(self, **kwargs) -> dict:
        """
        Build the keyword arguments for ``completion`` / ``acompletion``.

        Starts from ``get_completion_params(**kwargs)``. Local models get
        ``api_base`` set; Azure models get ``api_base``, ``api_version`` and
        ``api_key`` from the config.
        """
        completion_params = self.get_completion_params(**kwargs)
        company = infer_litellm_company_from_model(self.model)
        if self.config.is_local or company == "local":
            completion_params["api_base"] = self.api_base
        elif company == "azure":
            completion_params["api_base"] = self.config.azure_endpoint
            completion_params["api_version"] = self.config.api_version
            completion_params["api_key"] = self.config.azure_key
        return completion_params

    @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(5))
    def single_generate(self, messages: List[dict], **kwargs) -> str:
        """
        Generate a single response using the completion function.

        The call is retried with randomized exponential backoff (1-60 s) for up
        to 5 attempts. Every failure is re-raised as ``RuntimeError``, so every
        failure is retried. Once attempts are exhausted, tenacity raises its
        ``RetryError``.

        Args:
            messages (List[dict]): A list of dictionaries representing the conversation history.
            **kwargs (Any): Additional parameters to be passed to the `completion` function.
                ``stream`` and ``output_response`` override the config values.

        Returns:
            str: A string containing the model's response.
        """
        stream = kwargs.get("stream", self.config.stream)
        output_response = kwargs.get("output_response", self.config.output_response)

        try:
            completion_params = self._prepare_completion_params(**kwargs)
            response = completion(messages=messages, **completion_params)
            if stream:
                output = self.get_stream_output(response, output_response=output_response)
                cost = self._stream_cost(messages=messages, output=output)
            else:
                output = self.get_completion_output(response=response, output_response=output_response)
                cost = self._completion_cost(response=response)
            self._update_cost(cost=cost)
        except Exception as e:
            # Chain the original exception so its traceback is preserved.
            raise RuntimeError(f"Error during single_generate: {str(e)}") from e

        return output

    def batch_generate(self, batch_messages: List[List[dict]], **kwargs) -> List[str]:
        """
        Generate responses for a batch of messages, one conversation at a time.

        Args:
            batch_messages (List[List[dict]]): A list of message lists, where each sublist represents a conversation.
            **kwargs (Any): Additional parameters to be passed to the `completion` function.

        Returns:
            List[str]: A list of responses for each conversation, in input order.
        """
        return [self.single_generate(messages, **kwargs) for messages in batch_messages]

    async def single_generate_async(self, messages: List[dict], **kwargs) -> str:
        """
        Generate a single response using the async completion function.

        When streaming, an async-iterable response is consumed with
        ``get_stream_output_async``; any other response uses the synchronous
        ``get_stream_output``.

        Args:
            messages (List[dict]): A list of dictionaries representing the conversation history.
            **kwargs (Any): Additional parameters to be passed to the `acompletion` function.
                ``stream`` and ``output_response`` override the config values.

        Returns:
            str: A string containing the model's response.

        Raises:
            RuntimeError: Wrapping any error raised while generating.
        """
        stream = kwargs.get("stream", self.config.stream)
        output_response = kwargs.get("output_response", self.config.output_response)

        try:
            completion_params = self._prepare_completion_params(**kwargs)
            response = await acompletion(messages=messages, **completion_params)
            if stream:
                if hasattr(response, "__aiter__"):
                    output = await self.get_stream_output_async(response, output_response=output_response)
                else:
                    output = self.get_stream_output(response, output_response=output_response)
                cost = self._stream_cost(messages=messages, output=output)
            else:
                output = self.get_completion_output(response=response, output_response=output_response)
                cost = self._completion_cost(response=response)
            self._update_cost(cost=cost)
        except Exception as e:
            # Chain the original exception so its traceback is preserved.
            raise RuntimeError(f"Error during single_generate_async: {str(e)}") from e

        return output
|
|