# github-sync-test / smolagents / codeagent_litellm_template.py
# Synced from GitHub via hub-sync (commit 135756a).
import argparse
import os
from dataclasses import dataclass
from smolagents import CodeAgent, LiteLLMModel
@dataclass(frozen=True)
class ProviderPreset:
    """Environment-variable wiring and defaults for one LLM provider.

    Each preset names the env vars that supply the API key, base URL, and
    model id, plus the fallback values used when the base/model vars are
    unset. The API key has no fallback: it must be present at runtime.
    """

    name: str              # provider key as selected on the CLI (--provider)
    api_key_env: str       # env var that must hold the API key (required)
    api_base_env: str      # env var that may override the API base URL
    model_env: str         # env var that may override the model id
    default_api_base: str  # fallback API base when api_base_env is unset
    default_model: str     # fallback model id when model_env is unset


# All known presets. Kept as a flat tuple so the PROVIDERS dict below is
# derived from each preset's own `name` — the key can never drift out of
# sync with the preset it maps to.
_PRESETS: tuple[ProviderPreset, ...] = (
    # GitHub Models (OpenAI-compatible endpoint)
    ProviderPreset(
        name="github",
        api_key_env="GITHUB_TOKEN",
        api_base_env="GITHUB_API_BASE",
        model_env="GITHUB_MODEL_ID",
        default_api_base="https://models.inference.ai.azure.com",
        default_model="openai/gpt-4.1-mini",
    ),
    # Gemini via Google OpenAI-compatible endpoint
    ProviderPreset(
        name="gemini",
        api_key_env="GEMINI_API_KEY",
        api_base_env="GEMINI_API_BASE",
        model_env="GEMINI_MODEL_ID",
        default_api_base="https://generativelanguage.googleapis.com/v1beta/openai/",
        default_model="gemini-2.0-flash",
    ),
    # Z.ai / GLM (BigModel) - set ZAI_API_BASE to your account endpoint if needed
    ProviderPreset(
        name="zai",
        api_key_env="ZAI_API_KEY",
        api_base_env="ZAI_API_BASE",
        model_env="ZAI_MODEL_ID",
        default_api_base="https://open.bigmodel.cn/api/paas/v4/",
        default_model="zhipuai/glm-4.5",
    ),
    # Hugging Face Inference Router (OpenAI-compatible endpoint)
    ProviderPreset(
        name="huggingface",
        api_key_env="HF_TOKEN",
        api_base_env="HF_API_BASE",
        model_env="HF_MODEL_ID",
        default_api_base="https://router.huggingface.co/v1",
        default_model="moonshotai/Kimi-K2.5",
    ),
    # Fireworks (OpenAI-compatible endpoint)
    ProviderPreset(
        name="fireworks",
        api_key_env="FIREWORKS_API_KEY",
        api_base_env="FIREWORKS_API_BASE",
        model_env="FIREWORKS_MODEL_ID",
        default_api_base="https://api.fireworks.ai/inference/v1",
        default_model="accounts/fireworks/models/llama-v3p1-8b-instruct",
    ),
    # Minimax (OpenAI-compatible endpoint)
    ProviderPreset(
        name="minimax",
        api_key_env="MINIMAX_API_KEY",
        api_base_env="MINIMAX_API_BASE",
        model_env="MINIMAX_MODEL_ID",
        default_api_base="https://api.minimax.chat/v1",
        default_model="MiniMax-Text-01",
    ),
)

# Provider presets keyed by preset.name (the value accepted by --provider).
PROVIDERS: dict[str, ProviderPreset] = {p.name: p for p in _PRESETS}
def build_model(provider: str, model_override: str | None, temperature: float) -> LiteLLMModel:
    """Construct a LiteLLMModel for *provider*, resolving config from env vars.

    Resolution order for the model id: explicit override, then the preset's
    model env var, then the preset default. The API base follows the same
    env-then-default pattern.

    Raises:
        ValueError: if the preset's API key env var is unset or empty.
    """
    preset = PROVIDERS[provider]

    # Guard clause first: fail fast when no credential is available.
    key = os.getenv(preset.api_key_env)
    if not key:
        raise ValueError(
            f"Missing API key. Set environment variable {preset.api_key_env} for provider '{provider}'."
        )

    base_url = os.getenv(preset.api_base_env, preset.default_api_base)
    chosen_model = model_override or os.getenv(preset.model_env, preset.default_model)
    return LiteLLMModel(
        model_id=chosen_model,
        api_base=base_url,
        api_key=key,
        temperature=temperature,
    )
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line arguments for the template CLI.

    Args:
        argv: Optional explicit argument list. When ``None`` (the default,
            and the previous behavior), argparse reads ``sys.argv[1:]``.
            Passing a list makes this function usable from tests and from
            other programs without touching global state.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="smolagents CodeAgent + LiteLLMModel template for multi-provider usage"
    )
    parser.add_argument(
        "--provider",
        choices=sorted(PROVIDERS.keys()),
        default="gemini",
        help="Which provider preset to use",
    )
    parser.add_argument(
        "--model",
        default=None,
        help="Optional model override (otherwise uses provider env/default)",
    )
    parser.add_argument(
        "--prompt",
        default="Explain how AI agents work in 3 concise bullet points.",
        help="Prompt to run",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.2,
        help="Sampling temperature",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=5,
        help="Maximum CodeAgent steps",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print resolved config without calling the model",
    )
    return parser.parse_args(argv)
def main() -> None:
    """Entry point: resolve provider config, optionally dry-run, then run the agent."""
    args = parse_args()
    preset = PROVIDERS[args.provider]

    api_base = os.getenv(preset.api_base_env, preset.default_api_base)
    model_id = args.model or os.getenv(preset.model_env, preset.default_model)

    if args.dry_run:
        # Report what would be used without needing a valid API key.
        for label, value in (
            ("Provider:", args.provider),
            ("Model:", model_id),
            ("API base:", api_base),
            ("Required API key env:", preset.api_key_env),
        ):
            print(label, value)
        return

    llm = build_model(args.provider, args.model, args.temperature)
    # CodeAgent works best when models follow strict structured output behavior.
    # 'markdown' tags and structured outputs improve cross-provider reliability.
    agent = CodeAgent(
        tools=[],
        model=llm,
        max_steps=args.max_steps,
        code_block_tags="markdown",
        use_structured_outputs_internally=True,
    )
    print(agent.run(args.prompt))


if __name__ == "__main__":
    main()