prashantmatlani commited on
Commit
b03ffc7
·
1 Parent(s): 5683339

modified task grader

Browse files
Files changed (1) hide show
  1. inference.py +45 -38
inference.py CHANGED
@@ -29,32 +29,30 @@ def format_action(action: dict) -> str:
29
 
30
  return str(action)
31
 
32
- def main():
 
 
 
 
33
 
34
- env = CustomerSupportEnv()
35
- obs = env.reset()
 
 
 
36
 
37
- model_name = os.getenv("MODEL_NAME", "unknown-model")
38
- #model_name="llama-3.1-8b-instant"
39
 
40
- api_base_url = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
41
 
42
- #api_base_url = os.getenv("API_BASE_URL")
43
-
44
- print(f"[CONFIG] api_base_url={api_base_url}")
45
-
46
- task_name = "customer-support"
47
- benchmark = "openenv"
48
 
49
  step_count = 0
50
  rewards = []
51
  success = False
52
 
53
- # =========================
54
- # START
55
- # =========================
56
- print(f"[START] task={task_name} env={benchmark} model={model_name}")
57
-
58
  try:
59
  done = False
60
 
@@ -76,11 +74,8 @@ def main():
76
  step_count += 1
77
  rewards.append(reward)
78
 
79
- # =========================
80
- # STEP
81
- # =========================
82
  print(
83
- f"[STEP] step={step_count} "
84
  f"action={format_action(action)} "
85
  f"reward={reward:.2f} "
86
  f"done={'true' if done else 'false'} "
@@ -89,35 +84,47 @@ def main():
89
 
90
  obs = next_obs
91
 
92
- # success from env
93
  success = info.get("task_success", False)
94
 
95
  except Exception as e:
96
- # still must print END
97
  print(
98
- f"[STEP] step={step_count+1} "
99
  f"action=null reward=0.00 done=true error={str(e)}"
100
  )
101
 
102
- finally:
103
- # =========================
104
- # END
105
- # =========================
106
- rewards_str = ",".join(f"{r:.2f}" for r in rewards)
107
 
108
- score = 1.0 if success else 0.0
109
 
110
- #print(
111
- # f"[END] success={'true' if success else 'false'} "
112
- # f"steps={step_count} "
113
- # f"rewards={rewards_str}"
114
- #)
115
- print(
116
- f"[END] success={'true' if success else 'false'} "
117
  f"steps={step_count} "
118
  f"score={score:.2f} "
119
  f"rewards={rewards_str}"
120
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  if __name__ == "__main__":
123
  main()
 
29
 
30
  return str(action)
31
 
32
+ def compute_score(success, steps, rewards):
33
+ """
34
+ Continuous score in (0,1)
35
+ """
36
+ avg_reward = sum(rewards) / max(1, len(rewards))
37
 
38
+ score = (
39
+ 0.5 * (1.0 if success else 0.0) +
40
+ 0.3 * (1 / (1 + steps)) +
41
+ 0.2 * max(0, min(1, avg_reward))
42
+ )
43
 
44
+ # Clamp to (0,1) but not exact
45
+ return max(0.01, min(0.99, score))
46
 
 
47
 
48
+ def run_single_task(task_id):
49
+ env = CustomerSupportEnv()
50
+ obs = env.reset()
 
 
 
51
 
52
  step_count = 0
53
  rewards = []
54
  success = False
55
 
 
 
 
 
 
56
  try:
57
  done = False
58
 
 
74
  step_count += 1
75
  rewards.append(reward)
76
 
 
 
 
77
  print(
78
+ f"[STEP] task={task_id} step={step_count} "
79
  f"action={format_action(action)} "
80
  f"reward={reward:.2f} "
81
  f"done={'true' if done else 'false'} "
 
84
 
85
  obs = next_obs
86
 
 
87
  success = info.get("task_success", False)
88
 
89
  except Exception as e:
 
90
  print(
91
+ f"[STEP] task={task_id} step={step_count+1} "
92
  f"action=null reward=0.00 done=true error={str(e)}"
93
  )
94
 
95
+ score = compute_score(success, step_count, rewards)
 
 
 
 
96
 
97
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
98
 
99
+ print(
100
+ f"[END] task={task_id} "
101
+ f"success={'true' if success else 'false'} "
 
 
 
 
102
  f"steps={step_count} "
103
  f"score={score:.2f} "
104
  f"rewards={rewards_str}"
105
+ )
106
+
107
+
108
+ def main():
109
+
110
+ model_name = os.getenv("MODEL_NAME", "unknown-model")
111
+ api_base_url = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
112
+
113
+ print(f"[CONFIG] api_base_url={api_base_url}")
114
+
115
+ task_name = "customer-support"
116
+ benchmark = "openenv"
117
+
118
+ print(f"[START] task={task_name} env={benchmark} model={model_name}")
119
+
120
+ # =========================
121
+ # RUN MULTIPLE TASKS (IMPORTANT)
122
+ # =========================
123
+ NUM_TASKS = 3
124
+
125
+ for i in range(NUM_TASKS):
126
+ run_single_task(task_id=i + 1)
127
+
128
 
129
  if __name__ == "__main__":
130
  main()