import asyncio
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)
from openai import OpenAI, Stream
from openai.types.chat import ChatCompletion
from typing import Optional, List
from litellm import token_counter, cost_per_token

from ..core.registry import register_model
from .model_configs import OpenAILLMConfig
from .base_model import BaseLLM
from .model_utils import Cost, cost_manager, get_openai_model_cost


@register_model(config_cls=OpenAILLMConfig, alias=["openai_llm"])
class OpenAILLM(BaseLLM):
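    """OpenAI chat-completion backend, registered under the ``openai_llm`` alias.

    Wraps the synchronous ``openai`` client, supports streaming and non-streaming
    generation, and records per-call token cost through ``litellm`` and the shared
    ``cost_manager``.

    Minimal usage sketch (the constructor call shapes below are assumptions for
    illustration; only the config fields ``model``, ``openai_key``, ``stream`` and
    ``output_response`` are actually read in this module)::

        config = OpenAILLMConfig(model="gpt-4o-mini", openai_key="sk-...", stream=False, output_response=False)
        llm = OpenAILLM(config=config)
        messages = llm.formulate_messages(prompts=["Say hello."])
        print(llm.single_generate(messages=messages[0]))
    """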

    def init_model(self):
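        """Create the OpenAI client, record config fields that are not completion parameters, and validate the model name."""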
        config: OpenAILLMConfig = self.config
        self._client = self._init_client(config)
        self._default_ignore_fields = [
            "llm_type", "output_response", "openai_key", "deepseek_key", "anthropic_key",
            "gemini_key", "meta_llama_key", "openrouter_key", "openrouter_base", "perplexity_key",
            "groq_key"
        ]
        if self.config.model not in get_openai_model_cost():
            raise KeyError(f"'{self.config.model}' is not a valid OpenAI model name!")

    def _init_client(self, config: OpenAILLMConfig):
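        """Build a synchronous OpenAI client from the configured API key."""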
        client = OpenAI(api_key=config.openai_key)
        return client

    def formulate_messages(self, prompts: List[str], system_messages: Optional[List[str]] = None) -> List[List[dict]]:
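        """
        Convert raw prompts (and optional per-prompt system messages) into OpenAI chat message lists.

        Args:
            prompts: One user prompt per request.
            system_messages: Optional system messages, one per prompt; must match ``prompts`` in length.

        Returns:
            List[List[dict]]: One ``messages`` list per prompt.
        """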
        if system_messages:
            assert len(prompts) == len(system_messages), f"the number of prompts ({len(prompts)}) is different from the number of system_messages ({len(system_messages)})"
        else:
            system_messages = [None] * len(prompts)

        messages_list = []
        for prompt, system_message in zip(prompts, system_messages):
            messages = []
            if system_message:
                messages.append({"role": "system", "content": system_message})
            messages.append({"role": "user", "content": prompt})
            messages_list.append(messages)
        return messages_list

    def update_completion_params(self, params1: dict, params2: dict) -> dict:
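        """Copy values from ``params2`` into ``params1``, keeping only keys that are known config params and not in the ignore list."""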
        config_params: list = self.config.get_config_params()
        for key, value in params2.items():
            if key in self._default_ignore_fields:
                continue
            if key not in config_params:
                continue
            params1[key] = value
        return params1

    def get_completion_params(self, **kwargs):
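        """Merge the config's set parameters with per-call ``kwargs`` overrides into the final completion parameters."""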
        completion_params = self.config.get_set_params(ignore=self._default_ignore_fields)
        completion_params = self.update_completion_params(completion_params, kwargs)
        return completion_params

    def get_stream_output(self, response: Stream, output_response: bool = True) -> str:
        """
        Process stream response and return the complete output.

        Args:
            response: The stream response from OpenAI
            output_response: Whether to print the response in real-time

        Returns:
            str: The complete output text
        """
        output = ""
        for chunk in response:
            content = chunk.choices[0].delta.content
            if content:
                if output_response:
                    print(content, end="", flush=True)
                output += content
        if output_response:
            print("")
        return output

    async def get_stream_output_async(self, response, output_response: bool = False) -> str:
        """
        Process async stream response and return the complete output.

        Args:
            response (AsyncIterator[ChatCompletionChunk]): The async stream response from OpenAI
            output_response (bool): Whether to print the response in real-time

        Returns:
            str: The complete output text
        """
        output = ""
        async for chunk in response:
            content = chunk.choices[0].delta.content
            if content:
                if output_response:
                    print(content, end="", flush=True)
                output += content
        if output_response:
            print("")
        return output

    def get_completion_output(self, response: ChatCompletion, output_response: bool = True) -> str:
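        """Extract the message content from a non-streaming completion, optionally printing it."""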
        output = response.choices[0].message.content
        if output_response:
            print(output)
        return output

    @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(5))
    def single_generate(self, messages: List[dict], **kwargs) -> str:
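        """Run one chat completion (streaming or not), record its cost, and return the generated text.

        Retries with random exponential backoff for up to 5 attempts on failure.
        """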
        stream = kwargs.get("stream", self.config.stream)
        output_response = kwargs.get("output_response", self.config.output_response)

        try:
            completion_params = self.get_completion_params(**kwargs)
            response = self._client.chat.completions.create(messages=messages, **completion_params)
            if stream:
                output = self.get_stream_output(response, output_response=output_response)
                cost = self._stream_cost(messages=messages, output=output)
            else:
                output: str = self.get_completion_output(response=response, output_response=output_response)
                cost = self._completion_cost(response)
            self._update_cost(cost=cost)
        except Exception as e:
            raise RuntimeError(f"Error during single_generate of OpenAILLM: {str(e)}") from e

        return output

    def batch_generate(self, batch_messages: List[List[dict]], **kwargs) -> List[str]:
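        """Generate sequentially for a batch of message lists by calling ``single_generate`` on each."""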
        return [self.single_generate(messages=one_messages, **kwargs) for one_messages in batch_messages]

    async def single_generate_async(self, messages: List[dict], **kwargs) -> str:
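        """Async variant of ``single_generate``: runs the blocking OpenAI call in a thread-pool executor and records cost."""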
        stream = kwargs.get("stream", self.config.stream)
        output_response = kwargs.get("output_response", self.config.output_response)

        try:
            # Use a dedicated client for this call so concurrent coroutines do not share client state.
            isolated_client = self._init_client(self.config)
            completion_params = self.get_completion_params(**kwargs)

            # The OpenAI client call is blocking, so run it in the default thread-pool executor.
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None,
                lambda: isolated_client.chat.completions.create(
                    messages=messages,
                    **completion_params
                )
            )

            if stream:
                if hasattr(response, "__aiter__"):
                    output = await self.get_stream_output_async(response, output_response=output_response)
                else:
                    output = self.get_stream_output(response, output_response=output_response)
                cost = self._stream_cost(messages=messages, output=output)
            else:
                output: str = self.get_completion_output(response=response, output_response=output_response)
                cost = self._completion_cost(response)
            self._update_cost(cost=cost)

        except Exception as e:
            raise RuntimeError(f"Error during single_generate_async of OpenAILLM: {str(e)}") from e

        return output

    def _completion_cost(self, response: ChatCompletion) -> Cost:
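        """Compute cost from the token usage reported on a non-streaming response."""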
        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        return self._compute_cost(input_tokens=input_tokens, output_tokens=output_tokens)

    def _stream_cost(self, messages: List[dict], output: str) -> Cost:
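        """Compute cost for a streamed response by counting tokens locally with ``litellm.token_counter``."""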
        model: str = self.config.model
        input_tokens = token_counter(model=model, messages=messages)
        output_tokens = token_counter(model=model, text=output)
        return self._compute_cost(input_tokens=input_tokens, output_tokens=output_tokens)

    def _compute_cost(self, input_tokens: int, output_tokens: int) -> Cost:
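        """Convert token counts into a ``Cost`` using ``litellm.cost_per_token`` pricing for the configured model."""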
        input_cost, output_cost = cost_per_token(
            model=self.config.model,
            prompt_tokens=input_tokens,
            completion_tokens=output_tokens,
        )
        cost = Cost(input_tokens=input_tokens, output_tokens=output_tokens, input_cost=input_cost, output_cost=output_cost)
        return cost

    def _update_cost(self, cost: Cost):
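        """Record the cost with the global ``cost_manager`` under the configured model name."""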
        cost_manager.update_cost(cost=cost, model=self.config.model)