shubham7080 commited on
Commit
df3cb9f
·
verified ·
1 Parent(s): 344fc3a

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +93 -0
model.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any, Callable
3
+
4
+ from smolagents import HfApiModel, InferenceClientModel, LiteLLMModel, OpenAIServerModel
5
+
6
+
7
+ def get_huggingface_api_model(model_id: str, **kwargs) -> HfApiModel:
8
+ """
9
+ Returns a Hugging Face API model instance.
10
+ Args:
11
+ model_id (str): The model identifier.
12
+ **kwargs: Additional keyword arguments for the model.
13
+ Returns:
14
+ HfApiModel: Hugging Face API model instance.
15
+ """
16
+ api_key = os.getenv("HUGGINGFACEHUB_API_TOKEN")
17
+ if not api_key:
18
+ raise ValueError("HUGGINGFACEHUB_API_TOKEN is not set")
19
+
20
+ return HfApiModel(model_id=model_id, token=api_key, **kwargs)
21
+
22
+
23
+ def get_inference_client_model(model_id: str, **kwargs) -> InferenceClientModel:
24
+ """
25
+ Returns an Inference Client model instance.
26
+ Args:
27
+ model_id (str): The model identifier.
28
+ **kwargs: Additional keyword arguments for the model.
29
+ Returns:
30
+ InferenceClientModel: Inference client model instance.
31
+ """
32
+ api_key = os.getenv("HUGGINGFACEHUB_API_TOKEN")
33
+ if not api_key:
34
+ raise ValueError("HUGGINGFACEHUB_API_TOKEN is not set")
35
+
36
+ return InferenceClientModel(model_id=model_id, token=api_key, **kwargs)
37
+
38
+
39
+ def get_openai_server_model(model_id: str, **kwargs) -> OpenAIServerModel:
40
+ """
41
+ Returns an OpenAI server model instance.
42
+ Args:
43
+ model_id (str): The model identifier.
44
+ **kwargs: Additional keyword arguments for the model.
45
+ Returns:
46
+ OpenAIServerModel: OpenAI server model instance.
47
+ """
48
+ api_key = os.getenv("OPENAI_API_KEY")
49
+ if not api_key:
50
+ raise ValueError("OPENAI_API_KEY is not set")
51
+
52
+ api_base = os.getenv("OPENAI_API_BASE")
53
+ if not api_base:
54
+ raise ValueError("OPENAI_API_BASE is not set")
55
+
56
+ return OpenAIServerModel(
57
+ model_id=model_id, api_key=api_key, api_base=api_base, **kwargs
58
+ )
59
+
60
+
61
+ def get_lite_llm_model(model_id: str, **kwargs) -> LiteLLMModel:
62
+ """
63
+ Returns a LiteLLM model instance.
64
+ Args:
65
+ model_id (str): The model identifier.
66
+ **kwargs: Additional keyword arguments for the model.
67
+ Returns:
68
+ LiteLLMModel: LiteLLM model instance.
69
+ """
70
+ return LiteLLMModel(model_id=model_id, **kwargs)
71
+
72
+
73
+ def get_model(model_type: str, model_id: str, **kwargs) -> Any:
74
+ """
75
+ Returns a model instance based on the specified type.
76
+ Args:
77
+ model_type (str): The type of the model (e.g., 'HfApiModel').
78
+ model_id (str): The model identifier.
79
+ **kwargs: Additional keyword arguments for the model.
80
+ Returns:
81
+ Any: Model instance of the specified type.
82
+ """
83
+ models: dict[str, Callable[..., Any]] = {
84
+ "HfApiModel": get_huggingface_api_model,
85
+ "InferenceClientModel": get_inference_client_model,
86
+ "OpenAIServerModel": get_openai_server_model,
87
+ "LiteLLMModel": get_lite_llm_model,
88
+ }
89
+
90
+ if model_type not in models:
91
+ raise ValueError(f"Unknown model type: {model_type}")
92
+
93
+ return models[model_type](model_id, **kwargs)