adithya9903's picture
fix: monorepo
e543908
"""Trivial random baseline agent for PolypharmacyEnv."""
from __future__ import annotations
import random
from typing import List, Tuple
from ..env_core import PolypharmacyEnv
from ..models import PolypharmacyAction, PolypharmacyObservation
def run_random_episode(
env: PolypharmacyEnv,
task_id: str = "budgeted_screening",
seed: int | None = None,
) -> Tuple[float, float, int]:
rng = random.Random(seed)
obs = env.reset(task_id=task_id, seed=seed)
total_reward = 0.0
grader_score = 0.0
steps = 0
while not obs.done:
med_ids = [m.drug_id for m in obs.current_medications]
choice = rng.choice(["query_ddi", "propose_intervention", "finish_review"])
if choice == "query_ddi" and len(med_ids) >= 2 and obs.remaining_query_budget > 0:
pair = rng.sample(med_ids, 2)
action = PolypharmacyAction(
action_type="query_ddi",
drug_id_1=pair[0],
drug_id_2=pair[1],
)
elif choice == "propose_intervention" and med_ids and obs.remaining_intervention_budget > 0:
target = rng.choice(med_ids)
itype = rng.choice(["stop", "dose_reduce", "substitute", "add_monitoring"])
action = PolypharmacyAction(
action_type="propose_intervention",
target_drug_id=target,
intervention_type=itype,
rationale="random",
)
else:
action = PolypharmacyAction(action_type="finish_review")
obs = env.step(action)
total_reward += obs.reward or 0.0
steps += 1
if obs.done:
grader_score = obs.metadata.get("grader_score", 0.0)
break
return total_reward, grader_score, steps