prashantmatlani committed on
Commit
79c8057
·
1 Parent(s): b03ffc7

modified task graders to include task name

Browse files
Files changed (1) hide show
  1. inference.py +7 -5
inference.py CHANGED
@@ -45,7 +45,7 @@ def compute_score(success, steps, rewards):
45
  return max(0.01, min(0.99, score))
46
 
47
 
48
- def run_single_task(task_id):
49
  env = CustomerSupportEnv()
50
  obs = env.reset()
51
 
@@ -75,7 +75,7 @@ def run_single_task(task_id):
75
  rewards.append(reward)
76
 
77
  print(
78
- f"[STEP] task={task_id} step={step_count} "
79
  f"action={format_action(action)} "
80
  f"reward={reward:.2f} "
81
  f"done={'true' if done else 'false'} "
@@ -88,7 +88,7 @@ def run_single_task(task_id):
88
 
89
  except Exception as e:
90
  print(
91
- f"[STEP] task={task_id} step={step_count+1} "
92
  f"action=null reward=0.00 done=true error={str(e)}"
93
  )
94
 
@@ -97,7 +97,7 @@ def run_single_task(task_id):
97
  rewards_str = ",".join(f"{r:.2f}" for r in rewards)
98
 
99
  print(
100
- f"[END] task={task_id} "
101
  f"success={'true' if success else 'false'} "
102
  f"steps={step_count} "
103
  f"score={score:.2f} "
@@ -123,7 +123,9 @@ def main():
123
  NUM_TASKS = 3
124
 
125
  for i in range(NUM_TASKS):
126
- run_single_task(task_id=i + 1)
 
 
127
 
128
 
129
  if __name__ == "__main__":
 
45
  return max(0.01, min(0.99, score))
46
 
47
 
48
+ def run_single_task(task_name):
49
  env = CustomerSupportEnv()
50
  obs = env.reset()
51
 
 
75
  rewards.append(reward)
76
 
77
  print(
78
+ f"[STEP] task={task_name} step={step_count} "
79
  f"action={format_action(action)} "
80
  f"reward={reward:.2f} "
81
  f"done={'true' if done else 'false'} "
 
88
 
89
  except Exception as e:
90
  print(
91
+ f"[STEP] task={task_name} step={step_count+1} "
92
  f"action=null reward=0.00 done=true error={str(e)}"
93
  )
94
 
 
97
  rewards_str = ",".join(f"{r:.2f}" for r in rewards)
98
 
99
  print(
100
+ f"[END] task={task_name} "
101
  f"success={'true' if success else 'false'} "
102
  f"steps={step_count} "
103
  f"score={score:.2f} "
 
123
  NUM_TASKS = 3
124
 
125
  for i in range(NUM_TASKS):
126
+ #run_single_task(task_id=i + 1)
127
+ task_name = f"customer-support-{i+1}"
128
+ run_single_task(task_name)
129
 
130
 
131
  if __name__ == "__main__":