Spaces:
Running
Running
| from dataclasses import dataclass, field | |
| from enum import Enum | |
| import logging | |
| import os | |
| from typing import ( | |
| Any, | |
| Awaitable, | |
| Callable, | |
| List, | |
| Optional, | |
| Iterator, | |
| AsyncIterator, | |
| Tuple, | |
| TypedDict, | |
| ) | |
| from litellm import completion, acompletion, embedding | |
| import litellm | |
| from python.helpers import dotenv | |
| from python.helpers.dotenv import load_dotenv | |
| from python.helpers.providers import get_provider_config | |
| from python.helpers.rate_limiter import RateLimiter | |
| from python.helpers.tokens import approximate_tokens | |
| from langchain_core.language_models.chat_models import SimpleChatModel | |
| from langchain_core.outputs.chat_generation import ChatGenerationChunk | |
| from langchain_core.callbacks.manager import ( | |
| CallbackManagerForLLMRun, | |
| AsyncCallbackManagerForLLMRun, | |
| ) | |
| from langchain_core.messages import ( | |
| BaseMessage, | |
| AIMessageChunk, | |
| HumanMessage, | |
| SystemMessage, | |
| ) | |
| from langchain.embeddings.base import Embeddings | |
| from sentence_transformers import SentenceTransformer | |
# LiteLLM logging has to be silenced repeatedly, otherwise browser-use will
# turn it back on for some reason.
def turn_off_logging():
    """Restrict LiteLLM output to errors only.

    Enables LiteLLM's ``suppress_debug_info`` flag, sets the ``LITELLM_LOG``
    environment variable to ``"ERROR"`` and raises every currently registered
    logger whose name starts with "litellm" (case-insensitive) to ``ERROR``.
    """
    litellm.suppress_debug_info = True
    os.environ["LITELLM_LOG"] = "ERROR"  # only errors

    # Collect **all** LiteLLM sub-loggers (utils, cost_calculator, ...) first,
    # then silence them one by one.
    litellm_logger_names = [
        logger_name
        for logger_name in logging.Logger.manager.loggerDict
        if logger_name.lower().startswith("litellm")
    ]
    for logger_name in litellm_logger_names:
        logging.getLogger(logger_name).setLevel(logging.ERROR)
# Module initialisation, executed once at import time: run the project's
# dotenv loader first, then quiet LiteLLM logging.
load_dotenv()
turn_off_logging()
# NOTE(review): this writes to stdout on every import and looks like leftover
# debug output - confirm whether it is still wanted.
print("DEBUG: models.py loaded")
class ModelType(Enum):
    """Kinds of model a :class:`ModelConfig` can describe."""

    CHAT = "Chat"
    EMBEDDING = "Embedding"


@dataclass
class ModelConfig:
    """Settings describing one configured model.

    The class must be a dataclass: ``kwargs`` is declared with
    ``field(default_factory=dict)``, which only yields a fresh per-instance
    dict (and a keyword-argument constructor) when ``@dataclass`` processes
    the class. Without the decorator ``kwargs`` would be a raw ``Field``
    object and ``build_kwargs`` would fail.
    """

    type: ModelType
    provider: str
    name: str
    api_base: str = ""
    ctx_length: int = 0
    # The three limits below are handed to the rate limiter as its
    # "requests", "input" and "output" limits; 0 is the default.
    limit_requests: int = 0
    limit_input: int = 0
    limit_output: int = 0
    vision: bool = False
    # Extra keyword arguments stored with the model configuration.
    kwargs: dict = field(default_factory=dict)

    def build_kwargs(self):
        """Return a copy of ``kwargs`` with ``api_base`` filled in.

        ``api_base`` is only added when it is non-empty and the stored
        kwargs do not already contain an ``"api_base"`` entry. The stored
        dict itself is never modified.

        Returns:
            dict: a new dictionary safe for the caller to mutate.
        """
        merged = dict(self.kwargs or {})
        if self.api_base and "api_base" not in merged:
            merged["api_base"] = self.api_base
        return merged
class ChatChunk(TypedDict):
    """Simplified response chunk for chat models.

    Holds the two text fragments extracted from one model response chunk.
    """

    # Fragment of the visible answer (the chunk's ``content``).
    response_delta: str
    # Fragment of the model's reasoning (the chunk's ``reasoning_content``).
    reasoning_delta: str
# Shared RateLimiter instances, one per model, keyed by "<provider>\<name>".
rate_limiters: dict[str, RateLimiter] = {}
# Per-service counter used to rotate through comma-separated API keys.
api_keys_round_robin: dict[str, int] = {}
def get_api_key(service: str) -> str:
    """Return the API key configured for ``service``.

    The upper-cased service name is looked up under ``API_KEY_<SERVICE>``,
    ``<SERVICE>_API_KEY`` and ``<SERVICE>_API_TOKEN`` (in that order). When
    no value is found the string placeholder ``"None"`` is returned.

    A comma-separated value is treated as a pool of keys: each call advances
    a per-service counter and returns the next key in round-robin order.

    Args:
        service: provider/service identifier, e.g. ``"openai"``.

    Returns:
        The selected key, or ``"None"`` when nothing usable is configured.
    """
    env_name = service.upper()
    key = (
        dotenv.get_dotenv_value(f"API_KEY_{env_name}")
        or dotenv.get_dotenv_value(f"{env_name}_API_KEY")
        or dotenv.get_dotenv_value(f"{env_name}_API_TOKEN")
        or "None"
    )
    # if the key contains a comma, use round-robin
    if "," in key:
        api_keys = [k.strip() for k in key.split(",") if k.strip()]
        if not api_keys:
            # The value consisted only of separators (e.g. ","): there is no
            # key to rotate, and the modulo below would divide by zero.
            return "None"
        api_keys_round_robin[service] = api_keys_round_robin.get(service, -1) + 1
        key = api_keys[api_keys_round_robin[service] % len(api_keys)]
    return key
def get_rate_limiter(
    provider: str, name: str, requests: int, input: int, output: int
) -> RateLimiter:
    """Return the shared rate limiter for a model, updating its limits.

    Limiters live in the module-level ``rate_limiters`` registry under the
    key ``"<provider>\\<name>"``. A limiter is created (with ``seconds=60``)
    only the first time a model is seen; later calls reuse it instead of
    building a throwaway instance on every lookup.

    Args:
        provider: provider identifier of the model.
        name: model name.
        requests: value for the limiter's ``"requests"`` limit.
        input: value for the limiter's ``"input"`` limit.
        output: value for the limiter's ``"output"`` limit.

    Returns:
        The limiter registered for this provider/model pair.
    """
    key = f"{provider}\\{name}"
    limiter = rate_limiters.get(key)
    if limiter is None:
        limiter = rate_limiters[key] = RateLimiter(seconds=60)
    # Falsy values (None, 0) are normalised to 0.
    limiter.limits["requests"] = requests or 0
    limiter.limits["input"] = input or 0
    limiter.limits["output"] = output or 0
    return limiter
async def apply_rate_limiter(
    model_config: ModelConfig | None,
    input_text: str,
    rate_limiter_callback: Callable[[str, str, int, int], Awaitable[bool]] | None = None,
):
    """Register one request with the model's rate limiter and wait on it.

    Does nothing and returns ``None`` when ``model_config`` is falsy.
    Otherwise the limiter for the configured provider/model is fetched, the
    approximate token count of ``input_text`` and one request are added to
    it, and ``limiter.wait`` is awaited with ``rate_limiter_callback``.

    Returns:
        The limiter that was used, or ``None`` when no config was given.
    """
    if not model_config:
        return

    limiter = get_rate_limiter(
        provider=model_config.provider,
        name=model_config.name,
        requests=model_config.limit_requests,
        input=model_config.limit_input,
        output=model_config.limit_output,
    )
    input_tokens = approximate_tokens(input_text)
    limiter.add(input=input_tokens)
    limiter.add(requests=1)
    await limiter.wait(rate_limiter_callback)
    return limiter
def apply_rate_limiter_sync(
    model_config: ModelConfig | None,
    input_text: str,
    rate_limiter_callback: Callable[[str, str, int, int], Awaitable[bool]] | None = None,
):
    """Blocking counterpart of :func:`apply_rate_limiter`.

    Returns ``None`` immediately when ``model_config`` is falsy. Otherwise
    ``nest_asyncio`` is applied and the async implementation is driven to
    completion with ``asyncio.run``; its result is returned.
    """
    if not model_config:
        return

    # Imported lazily so the no-config path never needs nest_asyncio.
    import asyncio
    import nest_asyncio

    nest_asyncio.apply()
    pending = apply_rate_limiter(model_config, input_text, rate_limiter_callback)
    return asyncio.run(pending)
class LiteLLMChatWrapper(SimpleChatModel):
    """LangChain chat model that sends its requests through LiteLLM.

    ``model_name`` is stored as ``"<provider>/<model>"``. Keyword arguments
    given at construction are kept in ``kwargs`` and merged into every
    LiteLLM call (per-call keyword arguments win on conflict). The optional
    model configuration is stored as ``a0_model_conf`` and drives rate
    limiting.
    """

    model_name: str
    provider: str
    kwargs: dict = {}

    class Config:
        arbitrary_types_allowed = True
        extra = "allow"  # Allow extra attributes
        validate_assignment = False  # Don't validate on assignment

    def __init__(self, model: str, provider: str, model_config: Optional[ModelConfig] = None, **kwargs: Any):
        model_value = f"{provider}/{model}"
        super().__init__(model_name=model_value, provider=provider, kwargs=kwargs)  # type: ignore
        # Set A0 model config as instance attribute after parent init
        self.a0_model_conf = model_config

    @property
    def _llm_type(self) -> str:
        """Identifier of this model type.

        Declared as a property so that attribute access (``self._llm_type``)
        yields the string itself rather than a bound method.
        """
        return "litellm-chat"

    def _convert_messages(self, messages: List[BaseMessage]) -> List[dict]:
        """Convert LangChain messages into LiteLLM message dictionaries.

        Each result has ``role`` and ``content``; ``tool_calls`` and
        ``tool_call_id`` are added when the source message carries them.
        """
        import json

        result = []
        # Map LangChain message types to LiteLLM roles
        role_mapping = {
            "human": "user",
            "ai": "assistant",
            "system": "system",
            "tool": "tool",
        }
        for m in messages:
            # Unknown message types are passed through unchanged as the role.
            role = role_mapping.get(m.type, m.type)
            message_dict = {"role": role, "content": m.content}

            # Handle tool calls for AI messages
            tool_calls = getattr(m, "tool_calls", None)
            if tool_calls:
                # Convert LangChain tool calls to LiteLLM format
                new_tool_calls = []
                for tool_call in tool_calls:
                    # Ensure arguments is a JSON string
                    args = tool_call["args"]
                    if isinstance(args, dict):
                        args_str = json.dumps(args)
                    else:
                        args_str = str(args)
                    new_tool_calls.append(
                        {
                            "id": tool_call.get("id", ""),
                            "type": "function",
                            "function": {
                                "name": tool_call["name"],
                                "arguments": args_str,
                            },
                        }
                    )
                message_dict["tool_calls"] = new_tool_calls

            # Handle tool call ID for ToolMessage
            tool_call_id = getattr(m, "tool_call_id", None)
            if tool_call_id:
                message_dict["tool_call_id"] = tool_call_id

            result.append(message_dict)
        return result

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Run one blocking, non-streaming completion and return its text."""
        msgs = self._convert_messages(messages)

        # Apply rate limiting if configured
        apply_rate_limiter_sync(self.a0_model_conf, str(msgs))

        # Call the model
        resp = completion(
            model=self.model_name, messages=msgs, stop=stop, **{**self.kwargs, **kwargs}
        )

        # Parse output
        parsed = _parse_chunk(resp)
        return parsed["response_delta"]

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream a completion synchronously, yielding non-empty text chunks."""
        msgs = self._convert_messages(messages)

        # Apply rate limiting if configured
        apply_rate_limiter_sync(self.a0_model_conf, str(msgs))

        for chunk in completion(
            model=self.model_name,
            messages=msgs,
            stream=True,
            stop=stop,
            **{**self.kwargs, **kwargs},
        ):
            parsed = _parse_chunk(chunk)
            # Only yield chunks with non-empty content
            if parsed["response_delta"]:
                yield ChatGenerationChunk(
                    message=AIMessageChunk(content=parsed["response_delta"])
                )

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Stream a completion asynchronously, yielding non-empty text chunks."""
        msgs = self._convert_messages(messages)

        # Apply rate limiting if configured
        await apply_rate_limiter(self.a0_model_conf, str(msgs))

        response = await acompletion(
            model=self.model_name,
            messages=msgs,
            stream=True,
            stop=stop,
            **{**self.kwargs, **kwargs},
        )
        async for chunk in response:  # type: ignore
            parsed = _parse_chunk(chunk)
            # Only yield chunks with non-empty content
            if parsed["response_delta"]:
                yield ChatGenerationChunk(
                    message=AIMessageChunk(content=parsed["response_delta"])
                )

    async def unified_call(
        self,
        system_message="",
        user_message="",
        messages: List[BaseMessage] | None = None,
        response_callback: Callable[[str, str], Awaitable[None]] | None = None,
        reasoning_callback: Callable[[str, str], Awaitable[None]] | None = None,
        tokens_callback: Callable[[str, int], Awaitable[None]] | None = None,
        rate_limiter_callback: Callable[[str, str, int, int], Awaitable[bool]] | None = None,
        **kwargs: Any,
    ) -> Tuple[str, str]:
        """Stream one completion and collect its response and reasoning text.

        Args:
            system_message: if non-empty, placed first as a system message.
            user_message: if non-empty, appended last as a human message.
            messages: conversation so far; the list passed in is not modified.
            response_callback: awaited with ``(delta, response_so_far)`` for
                every response fragment.
            reasoning_callback: awaited with ``(delta, reasoning_so_far)``
                for every reasoning fragment.
            tokens_callback: awaited with ``(delta, approximate_tokens)`` for
                every fragment of either kind.
            rate_limiter_callback: forwarded to the rate limiter wait.
            **kwargs: per-call LiteLLM arguments, overriding ``self.kwargs``.

        Returns:
            ``(response, reasoning)`` - the full accumulated texts.
        """
        turn_off_logging()

        # Work on a copy: inserting/appending below must not leak into the
        # caller's list (repeated calls would otherwise keep stacking the
        # system and user messages onto it).
        messages = list(messages) if messages else []

        # construct messages
        if system_message:
            messages.insert(0, SystemMessage(content=system_message))
        if user_message:
            messages.append(HumanMessage(content=user_message))

        # convert to litellm format
        msgs_conv = self._convert_messages(messages)

        # Apply rate limiting if configured
        limiter = await apply_rate_limiter(self.a0_model_conf, str(msgs_conv), rate_limiter_callback)

        # call model
        print(f"DEBUG: calling acompletion with model={self.model_name}")
        _completion = await acompletion(
            model=self.model_name,
            messages=msgs_conv,
            stream=True,
            **{**self.kwargs, **kwargs},
        )

        # results
        reasoning = ""
        response = ""

        # iterate over chunks
        async for chunk in _completion:  # type: ignore
            parsed = _parse_chunk(chunk)

            # collect reasoning delta and call callbacks
            if parsed["reasoning_delta"]:
                reasoning += parsed["reasoning_delta"]
                if reasoning_callback:
                    await reasoning_callback(parsed["reasoning_delta"], reasoning)
                if tokens_callback:
                    await tokens_callback(
                        parsed["reasoning_delta"],
                        approximate_tokens(parsed["reasoning_delta"]),
                    )
                # Add output tokens to rate limiter if configured
                if limiter:
                    limiter.add(output=approximate_tokens(parsed["reasoning_delta"]))

            # collect response delta and call callbacks
            if parsed["response_delta"]:
                response += parsed["response_delta"]
                if response_callback:
                    await response_callback(parsed["response_delta"], response)
                if tokens_callback:
                    await tokens_callback(
                        parsed["response_delta"],
                        approximate_tokens(parsed["response_delta"]),
                    )
                # Add output tokens to rate limiter if configured
                if limiter:
                    limiter.add(output=approximate_tokens(parsed["response_delta"]))

        # return complete results
        return response, reasoning
class BrowserCompatibleChatWrapper(LiteLLMChatWrapper):
    """
    Chat wrapper variant intended for the browser agent.

    Behaves like the parent wrapper, but re-silences LiteLLM logging before
    construction and before each call, and exposes the model name under the
    additional attribute ``model``.
    """

    def __init__(self, *args, **kwargs):
        turn_off_logging()
        super().__init__(*args, **kwargs)
        # Browser-use may expect a 'model' attribute
        self.model = self.model_name

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Silence logging, then delegate to the parent's blocking call."""
        turn_off_logging()
        return super()._call(messages, stop, run_manager, **kwargs)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Silence logging, then relay every chunk of the parent's stream."""
        turn_off_logging()
        parent_stream = super()._astream(messages, stop, run_manager, **kwargs)
        async for piece in parent_stream:
            yield piece
class LiteLLMEmbeddingWrapper(Embeddings):
    """Embeddings implementation backed by LiteLLM's ``embedding`` call."""

    model_name: str
    kwargs: dict = {}
    a0_model_conf: Optional[ModelConfig] = None

    def __init__(self, model: str, provider: str, model_config: Optional[ModelConfig] = None, **kwargs: Any):
        # OpenAI models are addressed by bare name; every other provider is
        # prefixed as "<provider>/<model>".
        if provider != "openai":
            self.model_name = f"{provider}/{model}"
        else:
            self.model_name = model
        self.kwargs = kwargs
        self.a0_model_conf = model_config

    @staticmethod
    def _vector_of(item):
        """Pull the embedding out of one response entry (dict or object)."""
        if isinstance(item, dict):
            return item.get("embedding")
        return item.embedding  # type: ignore

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed several texts with a single LiteLLM request."""
        # Apply rate limiting if configured
        apply_rate_limiter_sync(self.a0_model_conf, " ".join(texts))

        resp = embedding(model=self.model_name, input=texts, **self.kwargs)
        return [self._vector_of(entry) for entry in resp.data]  # type: ignore

    def embed_query(self, text: str) -> List[float]:
        """Embed a single text and return its vector."""
        # Apply rate limiting if configured
        apply_rate_limiter_sync(self.a0_model_conf, text)

        resp = embedding(model=self.model_name, input=[text], **self.kwargs)
        first_entry = resp.data[0]  # type: ignore
        return self._vector_of(first_entry)
class LocalSentenceTransformerWrapper(Embeddings):
    """Local wrapper for sentence-transformers models to avoid HuggingFace API calls"""

    def __init__(self, provider: str, model: str, model_config: Optional[ModelConfig] = None, **kwargs: Any):
        # Clean common user-input mistakes: surrounding whitespace and quotes.
        cleaned = model.strip().strip('"').strip("'")
        # Drop a leading "sentence-transformers/" prefix if present.
        cleaned = cleaned.removeprefix("sentence-transformers/")

        # NOTE(review): kwargs are forwarded to SentenceTransformer unfiltered;
        # confirm callers never pass LiteLLM-only options here.
        self.model = SentenceTransformer(cleaned, **kwargs)
        self.model_name = cleaned
        self.a0_model_conf = model_config

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Encode several texts locally and return their vectors."""
        # Apply rate limiting if configured
        apply_rate_limiter_sync(self.a0_model_conf, " ".join(texts))

        encoded = self.model.encode(texts, convert_to_tensor=False)  # type: ignore
        if hasattr(encoded, "tolist"):
            return encoded.tolist()  # type: ignore
        return encoded  # type: ignore

    def embed_query(self, text: str) -> List[float]:
        """Encode one text locally and return its vector."""
        # Apply rate limiting if configured
        apply_rate_limiter_sync(self.a0_model_conf, text)

        encoded = self.model.encode([text], convert_to_tensor=False)  # type: ignore
        first = encoded[0]
        if hasattr(first, "tolist"):
            return first.tolist()  # type: ignore
        return first  # type: ignore
def _get_litellm_chat(
    cls: type = LiteLLMChatWrapper,
    model_name: str = "",
    provider_name: str = "",
    model_config: Optional[ModelConfig] = None,
    **kwargs: Any,
):
    """Instantiate ``cls`` for the given provider/model.

    An ``api_key`` passed in ``kwargs`` takes precedence over the configured
    one; the placeholder values "None" and "NA" are never forwarded. Provider
    name, model name and kwargs are normalised by ``_adjust_call_args`` before
    the wrapper is constructed.
    """
    # Prefer an explicitly supplied key, otherwise look one up for the provider.
    resolved_key = kwargs.pop("api_key", None) or get_api_key(provider_name)

    # Placeholder keys are dropped rather than passed on.
    is_placeholder = resolved_key in ("None", "NA")
    if resolved_key and not is_placeholder:
        kwargs["api_key"] = resolved_key

    provider_name, model_name, kwargs = _adjust_call_args(
        provider_name, model_name, kwargs
    )
    print(f"DEBUG: Creating {cls.__name__} with provider={provider_name}, model={model_name}, api_base={kwargs.get('api_base')}")
    return cls(
        provider=provider_name,
        model=model_name,
        model_config=model_config,
        **kwargs,
    )
def _get_litellm_embedding(
    model_name: str,
    provider_name: str,
    model_config: Optional[ModelConfig] = None,
    **kwargs: Any,
):
    """Build the embedding wrapper for the given provider/model.

    HuggingFace models whose name starts with "sentence-transformers/" are
    served by the local sentence-transformers wrapper; no API key is resolved
    for them here. Everything else goes through the LiteLLM wrapper, with an
    explicit ``api_key`` in ``kwargs`` taking precedence over the configured
    one and the placeholders "None"/"NA" never forwarded.
    """
    runs_locally = provider_name == "huggingface" and model_name.startswith(
        "sentence-transformers/"
    )

    if not runs_locally:
        # use api key from kwargs or env
        resolved_key = kwargs.pop("api_key", None) or get_api_key(provider_name)
        # Only pass API key if key is not a placeholder
        if resolved_key and resolved_key not in ("None", "NA"):
            kwargs["api_key"] = resolved_key

    provider_name, model_name, kwargs = _adjust_call_args(
        provider_name, model_name, kwargs
    )

    if runs_locally:
        # Use local sentence-transformers instead of LiteLLM for local models
        return LocalSentenceTransformerWrapper(
            provider=provider_name, model=model_name, model_config=model_config, **kwargs
        )
    return LiteLLMEmbeddingWrapper(
        model=model_name, provider=provider_name, model_config=model_config, **kwargs
    )
def _parse_chunk(chunk: Any) -> ChatChunk:
    """Extract response and reasoning text from one LiteLLM chunk/response.

    The first entry of ``chunk["choices"]`` is inspected. Response text is
    taken from the choice's ``delta`` content, falling back to the content of
    its ``message`` (or of ``model_extra["message"]``). Reasoning text is
    taken from the delta's ``reasoning_content``.

    Both fields are always strings: missing or ``None`` values become ``""``
    so the result matches the ``ChatChunk`` declaration, and a chunk with an
    empty ``choices`` list yields two empty strings instead of raising
    ``IndexError``.

    Args:
        chunk: a streaming chunk or a complete response.

    Returns:
        A ``ChatChunk`` holding ``response_delta`` and ``reasoning_delta``.
    """

    def read(source: Any, key: str) -> Any:
        # Payloads may be plain dicts or attribute-style objects.
        if isinstance(source, dict):
            return source.get(key, "")
        return getattr(source, key, "")

    choices = chunk["choices"]
    if not choices:
        return ChatChunk(reasoning_delta="", response_delta="")

    choice = choices[0]
    delta = choice.get("delta", {})
    # ``model_extra`` may be present but None, hence the ``or {}``.
    message = choice.get("message", {}) or (
        choice.get("model_extra") or {}
    ).get("message", {})

    response_delta = read(delta, "content") or read(message, "content") or ""
    reasoning_delta = read(delta, "reasoning_content") or ""

    return ChatChunk(reasoning_delta=reasoning_delta, response_delta=response_delta)
| def _adjust_call_args(provider_name: str, model_name: str, kwargs: dict): | |
| # Robustly handle provider name if it's the label instead of ID | |
| label_to_id = { | |
| "other openai compatible": "other", | |
| "openai": "openai", | |
| "anthropic": "anthropic", | |
| "google": "google", | |
| "deepseek": "deepseek", | |
| "groq": "groq", | |
| "huggingface": "huggingface", | |
| "lm studio": "lm_studio", | |
| "mistral ai": "mistral", | |
| "ollama": "ollama", | |
| "openrouter": "openrouter", | |
| "sambanova": "sambanova", | |
| "venice": "venice" | |
| } | |
| provider_name_low = str(provider_name).lower() | |
| if provider_name_low in label_to_id: | |
| provider_name = label_to_id[provider_name_low] | |
| # for openrouter add app reference | |
| if provider_name == "openrouter": | |
| kwargs["extra_headers"] = { | |
| "HTTP-Referer": "https://agent-zero.ai", | |
| "X-Title": "Agent Zero", | |
| } | |
| # remap other to openai for litellm | |
| if provider_name == "other": | |
| provider_name = "openai" | |
| return provider_name, model_name, kwargs | |
def _merge_provider_defaults(
    provider_type: str, original_provider: str, kwargs: dict
) -> tuple[str, dict]:
    """Resolve the LiteLLM provider name and fill in provider defaults.

    The provider is first mapped from a known UI label/id to its provider id.
    If a provider configuration exists for ``provider_type``, its
    ``litellm_provider`` (lower-cased) becomes the returned provider name and
    the entries of its ``kwargs`` mapping are added to ``kwargs`` without
    overriding values already present. Finally, when ``kwargs`` has no
    ``api_key``, one is looked up for the original provider id; the
    placeholders "None" and "NA" are not injected.

    Returns:
        ``(provider_name, kwargs)`` - ``kwargs`` is modified in place.
    """
    # Robustly handle provider name if it's the label instead of ID
    label_to_id = {
        "other openai compatible": "other",
        "openai": "openai",
        "anthropic": "anthropic",
        "google": "google",
        "deepseek": "deepseek",
        "groq": "groq",
        "huggingface": "huggingface",
        "lm studio": "lm_studio",
        "mistral ai": "mistral",
        "ollama": "ollama",
        "openrouter": "openrouter",
        "sambanova": "sambanova",
        "venice": "venice",
    }
    original_provider = label_to_id.get(
        str(original_provider).lower(), original_provider
    )
    provider_name = original_provider  # default: unchanged

    cfg = get_provider_config(provider_type, original_provider)
    if cfg:
        provider_name = cfg.get("litellm_provider", original_provider).lower()

        # Extra arguments nested under `kwargs` for readability
        provider_defaults = cfg.get("kwargs") if isinstance(cfg, dict) else None  # type: ignore[arg-type]
        if isinstance(provider_defaults, dict):
            for option, value in provider_defaults.items():
                kwargs.setdefault(option, value)

    # Inject API key based on the *original* provider id if still missing
    if "api_key" not in kwargs:
        resolved_key = get_api_key(original_provider)
        if resolved_key and resolved_key not in ("None", "NA"):
            kwargs["api_key"] = resolved_key

    return provider_name, kwargs
def get_chat_model(
    provider: str, name: str, model_config: Optional[ModelConfig] = None, **kwargs: Any
) -> LiteLLMChatWrapper:
    """Create a chat model wrapper for ``provider``/``name``.

    The provider is lower-cased, merged with the "chat" provider defaults and
    then used to build a ``LiteLLMChatWrapper``.
    """
    litellm_provider, merged_kwargs = _merge_provider_defaults(
        "chat", str(provider).lower(), kwargs
    )
    return _get_litellm_chat(
        LiteLLMChatWrapper, name, litellm_provider, model_config, **merged_kwargs
    )
def get_browser_model(
    provider: str, name: str, model_config: Optional[ModelConfig] = None, **kwargs: Any
) -> BrowserCompatibleChatWrapper:
    """Create a browser-agent chat wrapper for ``provider``/``name``.

    The provider is lower-cased, merged with the "chat" provider defaults and
    then used to build a ``BrowserCompatibleChatWrapper``.
    """
    litellm_provider, merged_kwargs = _merge_provider_defaults(
        "chat", str(provider).lower(), kwargs
    )
    return _get_litellm_chat(
        BrowserCompatibleChatWrapper, name, litellm_provider, model_config, **merged_kwargs
    )
def get_embedding_model(
    provider: str, name: str, model_config: Optional[ModelConfig] = None, **kwargs: Any
) -> LiteLLMEmbeddingWrapper | LocalSentenceTransformerWrapper:
    """Create an embedding wrapper for ``provider``/``name``.

    The provider is lower-cased and merged with the "embedding" provider
    defaults; the concrete wrapper type is chosen by
    ``_get_litellm_embedding``.
    """
    litellm_provider, merged_kwargs = _merge_provider_defaults(
        "embedding", str(provider).lower(), kwargs
    )
    return _get_litellm_embedding(name, litellm_provider, model_config, **merged_kwargs)