prashantmatlani committed on
Commit
12a8a0f
·
1 Parent(s): d6a76d5

use MODEL_NAME from env instead of hardcoding

Browse files
Files changed (3) hide show
  1. agent.py +2 -2
  2. agent_llm.py +7 -6
  3. inference.py +2 -2
agent.py CHANGED
@@ -224,8 +224,8 @@ FORMAT:
224
 
225
  def call_llm(prompt):
226
  completion = client.chat.completions.create(
227
- #model=os.getenv("MODEL_NAME"),
228
- model="llama-3.1-8b-instant",
229
  messages=[{"role": "user", "content": prompt}],
230
  temperature=0.2,
231
  response_format={"type": "json_object"}
 
224
 
225
  def call_llm(prompt):
226
  completion = client.chat.completions.create(
227
+ model=os.getenv("MODEL_NAME"),
228
+ #model="llama-3.1-8b-instant",
229
  messages=[{"role": "user", "content": prompt}],
230
  temperature=0.2,
231
  response_format={"type": "json_object"}
agent_llm.py CHANGED
@@ -66,6 +66,7 @@ FORMAT:
66
  def call_llm(prompt):
67
  completion = client.chat.completions.create(
68
  model=os.getenv("MODEL_NAME"),
 
69
  messages=[{"role": "user", "content": prompt}],
70
  temperature=0.2,
71
  response_format={"type": "json_object"}
@@ -204,16 +205,16 @@ def run_agent():
204
  obs, reward, done, info = env.step(action)
205
 
206
 
207
- print(f"\nOBS: {obs}")
208
- print(f"\nACTION: {action}")
209
- print(f"\nREWARD: {reward}")
210
- print(f"\nDONE: {done}")
211
 
212
 
213
  #print("FINAL:", info)
214
- print(f"\nFINAL: {info if info else 'No info returned'}")
215
 
216
- print(f"\nMETRICS: {env.get_metrics()}")
217
 
218
 
219
  if __name__ == "__main__":
 
66
  def call_llm(prompt):
67
  completion = client.chat.completions.create(
68
  model=os.getenv("MODEL_NAME"),
69
+ #model="llama-3.1-8b-instant",
70
  messages=[{"role": "user", "content": prompt}],
71
  temperature=0.2,
72
  response_format={"type": "json_object"}
 
205
  obs, reward, done, info = env.step(action)
206
 
207
 
208
+ #print(f"\nOBS: {obs}")
209
+ #print(f"\nACTION: {action}")
210
+ #print(f"\nREWARD: {reward}")
211
+ #print(f"\nDONE: {done}")
212
 
213
 
214
  #print("FINAL:", info)
215
+ #print(f"\nFINAL: {info if info else 'No info returned'}")
216
 
217
+ #print(f"\nMETRICS: {env.get_metrics()}")
218
 
219
 
220
  if __name__ == "__main__":
inference.py CHANGED
@@ -18,8 +18,8 @@ def main():
18
  env = CustomerSupportEnv()
19
  obs = env.reset()
20
 
21
- #model_name = os.getenv("MODEL_NAME", "unknown-model")
22
- model_name="llama-3.1-8b-instant"
23
 
24
  task_name = "customer-support"
25
  benchmark = "openenv"
 
18
  env = CustomerSupportEnv()
19
  obs = env.reset()
20
 
21
+ model_name = os.getenv("MODEL_NAME", "unknown-model")
22
+ #model_name="llama-3.1-8b-instant"
23
 
24
  task_name = "customer-support"
25
  benchmark = "openenv"