| |
|
|
| from __future__ import annotations |
|
|
| import os |
| import yaml |
| import openai |
| import ast |
| import pdb |
| import asyncio |
| from typing import Any, List |
| import os |
| import pathlib |
| import openai |
|
|
|
|
| |
|
|
| |
| |
|
|
class OpenAIChat():
    """Asynchronous batch client for the OpenAI ChatCompletion API.

    Several conversations are sent concurrently; each reply is parsed as a
    Python literal and validated against an expected type. Conversations
    whose reply cannot be obtained or parsed can be re-sent.
    """

    def __init__(
        self,
        model_name='gpt-3.5-turbo',
        max_tokens=2500,
        temperature=0,
        top_p=1,
        request_timeout=60,
    ):
        """Configure the client from the environment and the given settings.

        Note that this sets the module-global ``openai.api_key`` (and, for
        model names that do not contain ``'gpt'``, ``openai.api_base``).

        Args:
            model_name: Model identifier passed to the ChatCompletion API.
            max_tokens: ``max_tokens`` value sent with every request.
            temperature: ``temperature`` value sent with every request.
            top_p: ``top_p`` value sent with every request.
            request_timeout: ``request_timeout`` value sent with every request.

        Raises:
            AssertionError: If ``OPENAI_API_KEY`` is not set.
        """
        openai.api_key = os.environ.get("OPENAI_API_KEY", None)
        # Raised explicitly instead of via `assert` so the check still runs
        # under `python -O`; the exception type is kept for compatibility.
        if openai.api_key is None:
            raise AssertionError("Please set the OPENAI_API_KEY environment variable.")
        if 'gpt' not in model_name:
            # Model names without 'gpt' are sent to a local endpoint.
            openai.api_base = "http://localhost:8000/v1"
        self.config = {
            'model_name': model_name,
            'max_tokens': max_tokens,
            'temperature': temperature,
            'top_p': top_p,
            'request_timeout': request_timeout,
        }

    def _boolean_fix(self, output):
        """Rewrite JSON-style ``true``/``false`` as Python ``True``/``False``.

        Only whole words are rewritten, so words that merely contain
        "true" or "false" (e.g. "untrue", "falsehood") are left intact.

        Args:
            output: Raw text to normalise.

        Returns:
            The text with stand-alone ``true``/``false`` capitalised.
        """
        # Imported locally so this method does not rely on the module's
        # top-level import list.
        import re

        return re.sub(
            r"\b(true|false)\b",
            lambda match: match.group(1).capitalize(),
            output,
        )

    def _type_check(self, output, expected_type):
        """Parse ``output`` as a Python literal and validate its type.

        Args:
            output: Text expected to contain a Python literal.
            expected_type: Type (or tuple of types) the literal must be an
                instance of.

        Returns:
            The parsed value, or None if ``output`` is not a valid literal
            or is not an instance of ``expected_type``.

        Raises:
            TypeError: If ``expected_type`` is not usable with ``isinstance``.
        """
        try:
            output_eval = ast.literal_eval(output)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # These are the errors literal_eval raises on malformed input.
            return None
        if not isinstance(output_eval, expected_type):
            return None
        return output_eval

    def _parse_prediction(self, prediction, expected_type):
        """Extract and validate the reply text of one API response.

        Args:
            prediction: A ChatCompletion response, or None if the request
                failed.
            expected_type: Type the parsed reply must be an instance of.

        Returns:
            The parsed reply, or None if the response is missing, lacks the
            ``choices[0].message.content`` text, or fails the type check.
        """
        if prediction is None:
            return None
        try:
            content = prediction['choices'][0]['message']['content']
        except (KeyError, IndexError, TypeError):
            return None
        if not isinstance(content, str):
            return None
        return self._type_check(self._boolean_fix(content), expected_type)

    async def dispatch_openai_requests(
        self,
        messages_list,
    ) -> list[Any]:
        """Dispatches requests to OpenAI API asynchronously.

        Args:
            messages_list: List of messages to be sent to OpenAI ChatCompletion API.
        Returns:
            List of responses from OpenAI API, one per entry of
            ``messages_list`` and in the same order. An entry is None when
            every attempt for that request failed.
        """
        async def _request_with_retry(messages, retry=3):
            # Try up to `retry` times, pausing after each handled error.
            for _ in range(retry):
                try:
                    response = await openai.ChatCompletion.acreate(
                        model=self.config['model_name'],
                        messages=messages,
                        max_tokens=self.config['max_tokens'],
                        temperature=self.config['temperature'],
                        top_p=self.config['top_p'],
                        request_timeout=self.config['request_timeout'],
                    )
                    return response
                except openai.error.RateLimitError:
                    print('Rate limit error, waiting for 40 second...')
                    await asyncio.sleep(40)
                except openai.error.APIError:
                    print('API error, waiting for 1 second...')
                    await asyncio.sleep(1)
                except openai.error.Timeout:
                    print('Timeout error, waiting for 1 second...')
                    await asyncio.sleep(1)
                except openai.error.ServiceUnavailableError:
                    print('Service unavailable error, waiting for 3 second...')
                    await asyncio.sleep(3)
                except openai.error.APIConnectionError:
                    print('API Connection error, waiting for 3 second...')
                    await asyncio.sleep(3)
            return None

        async_responses = [
            _request_with_retry(messages)
            for messages in messages_list
        ]
        return await asyncio.gather(*async_responses)

    async def async_run(self, messages_list, expected_type, retry=1):
        """Send all conversations and return their parsed replies.

        Args:
            messages_list: List of conversations (each a list of messages).
            expected_type: Type each parsed reply must be an instance of.
            retry: Number of rounds to run. Each round re-sends only the
                conversations that have no valid reply yet.

        Returns:
            A list aligned with ``messages_list`` holding the parsed reply
            for each conversation, or None where no valid reply was obtained.
        """
        responses = [None] * len(messages_list)
        pending = list(range(len(messages_list)))

        while retry > 0 and pending:
            print(f'{retry} retry left...')
            predictions = await self.dispatch_openai_requests(
                messages_list=[messages_list[i] for i in pending],
            )

            # Keep only the indices whose reply could not be parsed.
            still_pending = []
            for index, prediction in zip(pending, predictions):
                parsed = self._parse_prediction(prediction, expected_type)
                if parsed is None:
                    still_pending.append(index)
                else:
                    responses[index] = parsed
            pending = still_pending

            retry -= 1

        return responses
|
|
class OpenAIEmbed():
    """Asynchronous client for the OpenAI Embedding API."""

    def __init__(self, model_name="text-embedding-ada-002"):
        """Read the API key from the environment and pick the model.

        Note that this sets the module-global ``openai.api_key``.

        Args:
            model_name: Embedding model passed to the Embedding API.

        Raises:
            AssertionError: If ``OPENAI_API_KEY`` is not set.
        """
        openai.api_key = os.environ.get("OPENAI_API_KEY", None)
        # Raised explicitly instead of via `assert` so the check still runs
        # under `python -O`; the exception type is kept for compatibility.
        if openai.api_key is None:
            raise AssertionError("Please set the OPENAI_API_KEY environment variable.")
        self.model_name = model_name

    async def create_embedding(self, text, retry=3):
        """Request the embedding of one text, retrying on handled API errors.

        Args:
            text: Input passed to the Embedding API.
            retry: Maximum number of attempts.

        Returns:
            The API response, or None if every attempt failed.
        """
        for _ in range(retry):
            try:
                response = await openai.Embedding.acreate(input=text, model=self.model_name)
                return response
            except openai.error.RateLimitError:
                print('Rate limit error, waiting for 1 second...')
                await asyncio.sleep(1)
            except openai.error.APIError:
                print('API error, waiting for 1 second...')
                await asyncio.sleep(1)
            except openai.error.Timeout:
                print('Timeout error, waiting for 1 second...')
                await asyncio.sleep(1)
            # The two handlers below mirror the chat client's retry loop so a
            # dropped connection does not abort the whole batch.
            except openai.error.ServiceUnavailableError:
                print('Service unavailable error, waiting for 3 second...')
                await asyncio.sleep(3)
            except openai.error.APIConnectionError:
                print('API Connection error, waiting for 3 second...')
                await asyncio.sleep(3)
        return None

    async def process_batch(self, batch, retry=3):
        """Embed every text of ``batch`` concurrently.

        Args:
            batch: Iterable of texts to embed.
            retry: Maximum number of attempts per text.

        Returns:
            A list with one entry per text, in input order; an entry is None
            when all attempts for that text failed.
        """
        tasks = [self.create_embedding(text, retry=retry) for text in batch]
        return await asyncio.gather(*tasks)
|
|
if __name__ == "__main__":
    # Manual smoke test: needs OPENAI_API_KEY and network access.
    chat = OpenAIChat()

    # async_run is a coroutine function; it has to be driven by an event
    # loop, otherwise no request is ever sent.
    predictions = asyncio.run(
        chat.async_run(
            messages_list=[
                [{"role": "user", "content": "show either 'ab' or '['a']'. Do not do anything else."}],
            ] * 20,
            expected_type=list,
        )
    )
    for prediction in predictions:
        print(prediction)

    embed = OpenAIEmbed()
    batch = ["string1", "string2", "string3", "string4", "string5", "string6", "string7", "string8", "string9", "string10"]
    embeddings = asyncio.run(embed.process_batch(batch, retry=3))
    for embedding in embeddings:
        # A None entry means every attempt for that text failed.
        if embedding is None:
            print("Embedding request failed after all retries.")
            continue
        print(embedding["data"][0]["embedding"])