Spaces:
Sleeping
Sleeping
Commit ·
12a8a0f
1
Parent(s): d6a76d5
use MODEL_NAME from env instead of hardcoding
Browse files- agent.py +2 -2
- agent_llm.py +7 -6
- inference.py +2 -2
agent.py
CHANGED
|
@@ -224,8 +224,8 @@ FORMAT:
|
|
| 224 |
|
| 225 |
def call_llm(prompt):
|
| 226 |
completion = client.chat.completions.create(
|
| 227 |
-
|
| 228 |
-
model="llama-3.1-8b-instant",
|
| 229 |
messages=[{"role": "user", "content": prompt}],
|
| 230 |
temperature=0.2,
|
| 231 |
response_format={"type": "json_object"}
|
|
|
|
| 224 |
|
| 225 |
def call_llm(prompt):
|
| 226 |
completion = client.chat.completions.create(
|
| 227 |
+
model=os.getenv("MODEL_NAME"),
|
| 228 |
+
#model="llama-3.1-8b-instant",
|
| 229 |
messages=[{"role": "user", "content": prompt}],
|
| 230 |
temperature=0.2,
|
| 231 |
response_format={"type": "json_object"}
|
agent_llm.py
CHANGED
|
@@ -66,6 +66,7 @@ FORMAT:
|
|
| 66 |
def call_llm(prompt):
|
| 67 |
completion = client.chat.completions.create(
|
| 68 |
model=os.getenv("MODEL_NAME"),
|
|
|
|
| 69 |
messages=[{"role": "user", "content": prompt}],
|
| 70 |
temperature=0.2,
|
| 71 |
response_format={"type": "json_object"}
|
|
@@ -204,16 +205,16 @@ def run_agent():
|
|
| 204 |
obs, reward, done, info = env.step(action)
|
| 205 |
|
| 206 |
|
| 207 |
-
print(f"\nOBS: {obs}")
|
| 208 |
-
print(f"\nACTION: {action}")
|
| 209 |
-
print(f"\nREWARD: {reward}")
|
| 210 |
-
print(f"\nDONE: {done}")
|
| 211 |
|
| 212 |
|
| 213 |
#print("FINAL:", info)
|
| 214 |
-
print(f"\nFINAL: {info if info else 'No info returned'}")
|
| 215 |
|
| 216 |
-
print(f"\nMETRICS: {env.get_metrics()}")
|
| 217 |
|
| 218 |
|
| 219 |
if __name__ == "__main__":
|
|
|
|
| 66 |
def call_llm(prompt):
|
| 67 |
completion = client.chat.completions.create(
|
| 68 |
model=os.getenv("MODEL_NAME"),
|
| 69 |
+
#model="llama-3.1-8b-instant",
|
| 70 |
messages=[{"role": "user", "content": prompt}],
|
| 71 |
temperature=0.2,
|
| 72 |
response_format={"type": "json_object"}
|
|
|
|
| 205 |
obs, reward, done, info = env.step(action)
|
| 206 |
|
| 207 |
|
| 208 |
+
#print(f"\nOBS: {obs}")
|
| 209 |
+
#print(f"\nACTION: {action}")
|
| 210 |
+
#print(f"\nREWARD: {reward}")
|
| 211 |
+
#print(f"\nDONE: {done}")
|
| 212 |
|
| 213 |
|
| 214 |
#print("FINAL:", info)
|
| 215 |
+
#print(f"\nFINAL: {info if info else 'No info returned'}")
|
| 216 |
|
| 217 |
+
#print(f"\nMETRICS: {env.get_metrics()}")
|
| 218 |
|
| 219 |
|
| 220 |
if __name__ == "__main__":
|
inference.py
CHANGED
|
@@ -18,8 +18,8 @@ def main():
|
|
| 18 |
env = CustomerSupportEnv()
|
| 19 |
obs = env.reset()
|
| 20 |
|
| 21 |
-
|
| 22 |
-
model_name="llama-3.1-8b-instant"
|
| 23 |
|
| 24 |
task_name = "customer-support"
|
| 25 |
benchmark = "openenv"
|
|
|
|
| 18 |
env = CustomerSupportEnv()
|
| 19 |
obs = env.reset()
|
| 20 |
|
| 21 |
+
model_name = os.getenv("MODEL_NAME", "unknown-model")
|
| 22 |
+
#model_name="llama-3.1-8b-instant"
|
| 23 |
|
| 24 |
task_name = "customer-support"
|
| 25 |
benchmark = "openenv"
|