openenv-workflow-agent / inference.py
Imsachin010's picture
update graders to ensure strictly fractional scores and run all 3 tasks
6ca88b7
import os
from openai import OpenAI
from app.env import WorkflowEnv
from app.actions import Action
from tasks.easy import create_easy_task
from tasks.medium import create_medium_task
from tasks.hard import create_hard_task
from graders.easy_grader import EasyGrader
from graders.medium_grader import MediumGrader
from graders.hard_grader import HardGrader
# ---------------- ENV CONFIG (CRITICAL) ----------------
API_BASE_URL = os.environ["API_BASE_URL"]
API_KEY = os.environ["API_KEY"]
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
client = OpenAI(
base_url=API_BASE_URL,
api_key=API_KEY
)
# ---------------- LOGGING ----------------
def log_start(task, env, model):
print(f"[START] task={task} env={env} model={model}", flush=True)
def log_step(step, action, reward, done, error):
print(
f"[STEP] step={step} action={action} reward={reward:.2f} done={str(done).lower()} error={error or 'null'}",
flush=True,
)
def log_end(success, steps, score, rewards):
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
print(
f"[END] success={str(success).lower()} steps={steps} score={score:.2f} rewards={rewards_str}",
flush=True,
)
# ---------------- LLM ACTION (MANDATORY) ----------------
def llm_decide_action(email):
prompt = f"""
Email:
Subject: {email.subject}
Body: {email.body}
Choose ONE action:
classify, request_info, archive
Output only the action name.
"""
try:
completion = client.chat.completions.create(
model=MODEL_NAME,
messages=[{"role": "user", "content": prompt}],
max_tokens=10,
temperature=0.0,
)
action_text = completion.choices[0].message.content.strip().lower()
return action_text
except Exception as e:
print(f"LLM API error: {e}")
return "classify"
# ---------------- POLICY ----------------
def get_action(obs):
if not obs.emails:
return None
email = obs.emails[0]
# 🔥 ALWAYS CALL LLM (important for validator)
action_text = llm_decide_action(email)
# 🔥 Guardrail (to avoid looping)
already_asked = any(
h["action"]["type"] == "request_info"
for h in obs.history
)
if already_asked:
return Action(
type="classify",
target_id=email.id,
payload={"label": "meeting_request"}
)
if "request" in action_text:
return Action(type="request_info", target_id=email.id)
elif "classify" in action_text:
return Action(
type="classify",
target_id=email.id,
payload={"label": "meeting_request"}
)
return Action(type="archive", target_id=email.id)
# ---------------- MAIN ----------------
def main():
tasks = [
("easy", create_easy_task, EasyGrader),
("medium", create_medium_task, MediumGrader),
("hard", create_hard_task, HardGrader),
]
for task_name, create_func, GraderClass in tasks:
state, gt = create_func()
env = WorkflowEnv(state)
grader = GraderClass()
obs = env.reset()
rewards = []
steps = 0
log_start(task_name, "workflow-env", MODEL_NAME)
try:
done = False
while not done and steps < 10:
action = get_action(obs)
if action is None:
break
obs, reward, done, _ = env.step(action)
rewards.append(reward)
steps += 1
log_step(steps, action.type, reward, done, None)
# stop after meaningful action
if action.type == "classify":
break
trajectory = env.state().history
score = grader.grade(trajectory, gt)
score = max(0.01, min(0.99, float(score))) # Strictly between 0 and 1
success = score > 0.3
finally:
log_end(success, steps, score, rewards)
if __name__ == "__main__":
main()