Spaces:
Running
Running
| import os | |
| import gymnasium as gym | |
| import torch | |
| import numpy as np | |
| from minigrid.wrappers import FlatObsWrapper | |
| from stable_baselines3 import PPO | |
| from tqdm import tqdm | |
class PPOHarvester:
    """
    Utility to run a 'Teacher' PPO agent to collect high-quality state-action-reward triplets.

    The environment is ``gym.make(env_id, render_mode="rgb_array")`` wrapped in
    minigrid's ``FlatObsWrapper``. If a saved PPO checkpoint exists for ``model_path``
    it is loaded; otherwise a fresh ``"MlpPolicy"`` PPO agent is trained for
    ``train_timesteps`` steps and, when ``model_path`` is given, saved there.

    Args:
        env_id: Gymnasium environment id passed to ``gym.make``.
        model_path: Optional stable-baselines3 checkpoint path. SB3 appends a ``.zip``
            suffix on save when the path has none, so both ``model_path`` and
            ``model_path + ".zip"`` are checked when looking for an existing model.
        train_timesteps: Timesteps to train for when no checkpoint is found.
    """

    def __init__(self, env_id="MiniGrid-Empty-8x8-v0", model_path=None, train_timesteps=20000):
        self.env_id = env_id
        self.env = FlatObsWrapper(gym.make(env_id, render_mode="rgb_array"))
        existing_path = self._resolve_model_path(model_path)
        if existing_path is not None:
            self.model = PPO.load(existing_path, env=self.env)
        else:
            print(f"No model found at {model_path}. Training a new one for collection...")
            self.model = PPO("MlpPolicy", self.env, verbose=1)
            self.model.learn(total_timesteps=train_timesteps)
            if model_path:
                self.model.save(model_path)

    @staticmethod
    def _resolve_model_path(model_path):
        """Return the existing checkpoint file for ``model_path``, or None if there is none.

        Checks ``model_path`` itself and then ``model_path + ".zip"``, because SB3's
        ``save`` adds the ``.zip`` suffix when the given path has no suffix. Only regular
        files count; a directory with that name is not a checkpoint.
        """
        if not model_path:
            return None
        for candidate in (model_path, f"{model_path}.zip"):
            if os.path.isfile(candidate):
                return candidate
        return None

    def collect_trajectories(self, num_episodes=100, seed=42, deterministic=False):
        """
        Roll out the teacher policy and record one dict of numpy arrays per episode.

        Episode ``i`` is reset with ``seed + i``, so a given ``seed`` always yields the
        same sequence of episode start states. Each step appends one entry to every key:
            "observations": the observation the action was chosen from,
            "actions":      the action returned by ``self.model.predict``,
            "rewards":      the reward returned by ``env.step``,
            "dones":        the ``terminated`` flag returned by ``env.step``,
            "truncateds":   the ``truncated`` flag returned by ``env.step``.
        An episode ends when either flag is set. "dones" is kept as the terminated-only
        flag for compatibility; use "truncateds" to tell time-limit ends apart.

        Args:
            num_episodes: Number of episodes to collect.
            seed: Base reset seed; episode ``i`` is reset with ``seed + i``.
            deterministic: Forwarded to ``model.predict``; False samples actions.

        Returns:
            List of ``num_episodes`` episode dicts.
        """
        trajectories = []
        for i in tqdm(range(num_episodes), desc="Collecting trajectories"):
            obs, _ = self.env.reset(seed=seed + i)
            terminated = truncated = False
            episode = {
                "observations": [],
                "actions": [],
                "rewards": [],
                "dones": [],
                "truncateds": [],
            }
            while not (terminated or truncated):
                action, _ = self.model.predict(obs, deterministic=deterministic)
                next_obs, reward, terminated, truncated, _ = self.env.step(action)
                episode["observations"].append(obs)
                episode["actions"].append(action)
                episode["rewards"].append(reward)
                episode["dones"].append(terminated)
                episode["truncateds"].append(truncated)
                obs = next_obs
            # Convert each per-step list into a numpy array (one entry per step).
            trajectories.append({key: np.array(values) for key, values in episode.items()})
        return trajectories

    def save_trajectories(self, trajectories, file_path):
        """
        Serialize ``trajectories`` to ``file_path`` with ``torch.save``.

        Parent directories are created as needed. A bare filename (no directory part)
        is written to the current working directory.

        NOTE(review): the payload contains numpy arrays. Readers on PyTorch versions
        whose ``torch.load`` defaults to ``weights_only=True`` may need to pass
        ``weights_only=False`` to load it — confirm against the consumer's torch version.
        """
        directory = os.path.dirname(file_path)
        # dirname is "" for a bare filename, and os.makedirs("") raises FileNotFoundError.
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(trajectories, file_path)
        print(f"Saved {len(trajectories)} trajectories to {file_path}")

    def close(self):
        """Close the wrapped environment."""
        self.env.close()
if __name__ == "__main__":
    # Script entry: load (or train and cache) the teacher, harvest 50 episodes, save them.
    teacher_checkpoint = "ppo_minigrid_teacher.zip"
    output_file = "data/trajectories.pt"

    collector = PPOHarvester(model_path=teacher_checkpoint)
    collector.save_trajectories(
        collector.collect_trajectories(num_episodes=50),
        output_file,
    )