| from typing import List |
| from tenacity import retry, stop_after_attempt, wait_random_exponential |
| import requests |
| from models.Base import BaseModel |
|
|
|
|
class GeminiModel(BaseModel):
    """Chat-completion client for Gemini models served behind the AMD LLM proxy.

    Sends OpenAI-style message lists to the proxy's Vertex/Gemini endpoint and
    returns the first candidate's text. Transient failures are retried by
    tenacity with randomized exponential backoff (up to 5 attempts).
    """

    def __init__(self,
                 model_id="gemini-2.5-pro-preview-05-06",
                 api_key=None):
        """Configure the client.

        Args:
            model_id: Gemini model identifier appended to the endpoint URL.
            api_key: Subscription key for the proxy. Required.

        Raises:
            ValueError: if no API key is supplied.
        """
        # `assert` is stripped under `python -O`; validate explicitly instead.
        if api_key is None:
            raise ValueError("no api key is provided.")
        self.model_id = model_id

        self.SERVER = "https://llm-api.amd.com/vertex/gemini"
        self.HEADERS = {"Ocp-Apim-Subscription-Key": api_key}

    @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(5))
    def generate(self,
                 messages: List,
                 temperature=1.0,
                 presence_penalty=0,
                 frequency_penalty=0,
                 max_tokens=30000,
                 timeout=600) -> str:
        """Request a chat completion and return the first candidate's text.

        Args:
            messages: Chat history as a list of message dicts.
            temperature: Sampling temperature forwarded to the API.
            presence_penalty: Presence-penalty value forwarded to the API.
            frequency_penalty: Frequency-penalty value forwarded to the API.
            max_tokens: Maximum number of tokens to generate.
            timeout: Per-request timeout in seconds (new, backward-compatible;
                without it a stalled connection would hang forever and the
                retry decorator would never fire).

        Returns:
            The generated text of the first candidate.

        Raises:
            requests.HTTPError: on a non-2xx response (triggers tenacity retry).
        """
        # NOTE(review): the mixed-case keys ("top_P", "presence_Penalty",
        # "frequency_Penalty") look inconsistent but are kept byte-identical —
        # presumably this is the exact schema the proxy expects; confirm
        # against the API docs before normalizing.
        body = {
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_P": 0.95,
            "presence_Penalty": presence_penalty,
            "frequency_Penalty": frequency_penalty,
        }
        response = requests.post(url=f"{self.SERVER}/{self.model_id}/chat",
                                 json=body,
                                 headers=self.HEADERS,
                                 timeout=timeout)
        # raise_for_status() replaces the original `assert status_code == 200`:
        # it is not stripped under -O and raises requests.HTTPError, which the
        # @retry decorator will catch and retry.
        response.raise_for_status()
        code_chat_completion_result = response.json()

        return code_chat_completion_result['candidates'][0]['content']['parts'][0]['text']