File size: 4,903 Bytes
0a55f0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from __future__ import annotations

import json
import os
import re
from dataclasses import dataclass
from typing import Any, Type

import litellm
from litellm import completion
from pydantic import BaseModel, ValidationError


@dataclass
class ModelConfig:
    """Resolved settings for a single LLM call.

    `model_name` produces the litellm-style "provider/model" identifier
    expected by `litellm.completion`.
    """

    provider: str
    model: str
    temperature: float = 0.2
    max_tokens: int = 12000

    @property
    def model_name(self) -> str:
        """Return the fully-qualified litellm model string."""
        # Already qualified as "provider/model" -- pass through untouched.
        if "/" in self.model:
            return self.model
        prefix = self.provider.lower()
        # Only the known providers get a prefix; anything else is verbatim.
        if prefix in ("openai", "gemini"):
            return f"{prefix}/{self.model}"
        return self.model


class MultiProviderLLMClient:
    """Routes LLM calls through litellm, with optional per-stage model overrides.

    The default provider/model comes from `default_config`; `stage_models`
    maps a stage name to either a bare model name (same provider) or a
    "provider/model" string (which switches provider as well).
    """

    def __init__(self, default_config: ModelConfig, stage_models: dict[str, str] | None = None):
        """Store configs, enable litellm param dropping, and check API keys.

        Raises:
            ValueError: if the default provider's API key env var is unset.
        """
        self.default_config = default_config
        self.stage_models = stage_models or {}
        # Some providers reject params such as response_format; tell litellm
        # to silently drop anything the target model does not support.
        litellm.drop_params = True
        self._validate_env(default_config.provider)

    def _validate_env(self, provider: str) -> None:
        """Raise ValueError when the API key for `provider` is not in the env.

        Unknown providers pass silently (no key requirement is enforced).
        """
        provider = provider.lower()
        if provider == "openai" and not os.getenv("OPENAI_API_KEY"):
            raise ValueError("OPENAI_API_KEY is required for provider=openai")
        if provider == "gemini" and not os.getenv("GEMINI_API_KEY"):
            raise ValueError("GEMINI_API_KEY is required for provider=gemini")

    def config_for_stage(self, stage_name: str) -> ModelConfig:
        """Return the effective ModelConfig for `stage_name`.

        Falls back to `default_config` when there is no override. Overrides
        inherit the default temperature and max_tokens.
        """
        model_override = self.stage_models.get(stage_name)
        if not model_override:
            return self.default_config
        provider = self.default_config.provider
        model = model_override
        if "/" in model_override:
            # A "provider/model" override switches provider too.
            provider, model = model_override.split("/", 1)
        self._validate_env(provider)
        return ModelConfig(
            provider=provider,
            model=model,
            temperature=self.default_config.temperature,
            max_tokens=self.default_config.max_tokens,
        )

    def _completion_kwargs(
        self,
        config: ModelConfig,
        system_prompt: str,
        user_prompt: str,
    ) -> dict[str, Any]:
        """Build the `completion` kwargs shared by structured and free-text calls.

        Omits `temperature` entirely when `_temperature_for_model` returns None.
        """
        kwargs: dict[str, Any] = {
            "model": config.model_name,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "max_tokens": config.max_tokens,
        }
        temperature = self._temperature_for_model(config)
        if temperature is not None:
            kwargs["temperature"] = temperature
        return kwargs

    def generate_structured(
        self,
        *,
        stage_name: str,
        system_prompt: str,
        user_prompt: str,
        response_model: Type[BaseModel],
    ) -> BaseModel:
        """Call the stage's model in JSON mode and validate into `response_model`.

        Raises:
            ValueError: when the model's output fails pydantic validation.
        """
        config = self.config_for_stage(stage_name)
        completion_kwargs = self._completion_kwargs(config, system_prompt, user_prompt)
        completion_kwargs["response_format"] = {"type": "json_object"}

        response = completion(**completion_kwargs)
        content = response.choices[0].message.content or ""
        payload = self._parse_json(content)
        try:
            return response_model.model_validate(payload)
        except ValidationError as exc:
            # Some models wrap the object in a one-element array; unwrap once.
            if isinstance(payload, list) and len(payload) == 1 and isinstance(payload[0], dict):
                try:
                    return response_model.model_validate(payload[0])
                except ValidationError:
                    pass
            raise ValueError(
                f"Stage {stage_name} returned invalid JSON for {response_model.__name__}: {exc}\nRaw content:\n{content}"
            ) from exc

    def generate_text(
        self,
        *,
        stage_name: str,
        system_prompt: str,
        user_prompt: str,
    ) -> str:
        """Call the stage's model and return the whitespace-stripped text reply."""
        config = self.config_for_stage(stage_name)
        completion_kwargs = self._completion_kwargs(config, system_prompt, user_prompt)
        response = completion(**completion_kwargs)
        return (response.choices[0].message.content or "").strip()

    @staticmethod
    def _parse_json(text: str) -> Any:
        """Parse model output into Python data, tolerating fences and noise.

        Tries, in order: the contents of a ```json fenced block, the raw
        text, and finally the first {...} or [...] span. Re-raises the
        JSONDecodeError if every attempt fails.
        """
        text = text.strip()
        if text.startswith("```"):
            match = re.search(r"```(?:json)?\s*(.*?)```", text, flags=re.S)
            if match:
                text = match.group(1).strip()
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            match = re.search(r"(\{.*\}|\[.*\])", text, flags=re.S)
            if match:
                return json.loads(match.group(1))
            raise

    @staticmethod
    def _temperature_for_model(config: ModelConfig) -> float | None:
        """Return the temperature to send, or None to omit the param.

        gpt-5 family models are special-cased to receive no temperature
        (NOTE(review): presumably because they reject non-default values --
        confirm against the provider docs).
        """
        model_name = config.model_name.lower()
        if "gpt-5" in model_name:
            return None
        return config.temperature