import os
from time import sleep

import openai
|
|
|
|
| class OpenAIModel: |
| def __init__(self, model_name, api_key): |
| self.model_name = model_name |
| openai.api_key = api_key |
|
|
| def completions(self, |
| prompt: str, |
| n: int, |
| temperature: float, |
| top_p: float, |
| max_tokens: int, |
| **kwargs): |
| for i in range(100): |
| try: |
| sleep(1. / 5) |
| return self.sample_openai_call(prompt=prompt, |
| num_return_sequences=n, |
| temperature=temperature, top_p=top_p, |
| max_gen_length=max_tokens) |
| except Exception as e: |
| print(e) |
| print(f'{i} SLEEP') |
| sleep(60) |
|
|
|
|
| def sample_openai_call(self, |
| prompt: str, |
| num_return_sequences: int, |
| temperature: float, top_p: float, |
| max_gen_length: int): |
| for i in range(100): |
| try: |
| sleep(3) |
| return self.sample_call_with_len(prompt=prompt, |
| num_return_sequences=num_return_sequences, |
| temperature=temperature, top_p=top_p, |
| max_gen_length=max_gen_length) |
| except Exception as e: |
| if 'maximum context length' in str(e): |
| max_gen_length -= 32 |
| print(f'{i} REDUCING max_gen_length TO {max_gen_length}') |
| else: |
| raise e |
|
|
|
|
| def sample_call_with_len( |
| self, |
| prompt, |
| num_return_sequences=1, |
| temperature=0.2, |
| top_p=0.95, |
| max_gen_length=128, |
| return_raw=False |
| ): |
| response = openai.ChatCompletion.create( |
| model=self.model_name, |
| messages=[ |
| {"role": "system", "content": "You are a helpful assistant."}, |
| {"role": "user", "content": prompt}, |
| ], |
| n=num_return_sequences, |
| temperature=temperature, |
| max_tokens=max_gen_length, |
| top_p=top_p, |
| frequency_penalty=0, |
| presence_penalty=0 |
| ) |
|
|
| if return_raw: |
| return response |
| result = [] |
| for completion in response['choices']: |
| result.append(completion["message"]["content"]) |
| return result |
|
|
|
|
# Read the key from the environment instead of hard-coding a secret in source.
# Falls back to "" (the previous value) when OPENAI_API_KEY is unset.
api_key = os.environ.get("OPENAI_API_KEY", "")

if __name__ == '__main__':
    # Demo: a single direct request with no retry/throttling layer.
    model = OpenAIModel(model_name="gpt-3.5-turbo", api_key=api_key)

    prompt = "Given two phrases \"access information\" and \"information access\", predict their semantic relatedness within the range [0, 1]"
    result = model.sample_call_with_len(prompt=prompt)
    print(result)
|
|