Spaces:
Sleeping
Sleeping
| """Trivial random baseline agent for PolypharmacyEnv.""" | |
| from __future__ import annotations | |
| import random | |
| from typing import List, Tuple | |
| from ..env_core import PolypharmacyEnv | |
| from ..models import PolypharmacyAction, PolypharmacyObservation | |
| def run_random_episode( | |
| env: PolypharmacyEnv, | |
| task_id: str = "budgeted_screening", | |
| seed: int | None = None, | |
| ) -> Tuple[float, float, int]: | |
| rng = random.Random(seed) | |
| obs = env.reset(task_id=task_id, seed=seed) | |
| total_reward = 0.0 | |
| grader_score = 0.0 | |
| steps = 0 | |
| while not obs.done: | |
| med_ids = [m.drug_id for m in obs.current_medications] | |
| choice = rng.choice(["query_ddi", "propose_intervention", "finish_review"]) | |
| if choice == "query_ddi" and len(med_ids) >= 2 and obs.remaining_query_budget > 0: | |
| pair = rng.sample(med_ids, 2) | |
| action = PolypharmacyAction( | |
| action_type="query_ddi", | |
| drug_id_1=pair[0], | |
| drug_id_2=pair[1], | |
| ) | |
| elif choice == "propose_intervention" and med_ids and obs.remaining_intervention_budget > 0: | |
| target = rng.choice(med_ids) | |
| itype = rng.choice(["stop", "dose_reduce", "substitute", "add_monitoring"]) | |
| action = PolypharmacyAction( | |
| action_type="propose_intervention", | |
| target_drug_id=target, | |
| intervention_type=itype, | |
| rationale="random", | |
| ) | |
| else: | |
| action = PolypharmacyAction(action_type="finish_review") | |
| obs = env.step(action) | |
| total_reward += obs.reward or 0.0 | |
| steps += 1 | |
| if obs.done: | |
| grader_score = obs.metadata.get("grader_score", 0.0) | |
| break | |
| return total_reward, grader_score, steps | |