# Extracted from a Spaces file view (status: Running; file size: 2,458 bytes; revision 848238a).
import os
import gymnasium as gym
import torch
import numpy as np
from minigrid.wrappers import FlatObsWrapper
from stable_baselines3 import PPO
from tqdm import tqdm
class PPOHarvester:
    """Run a 'Teacher' PPO agent to collect state-action-reward trajectories.

    On construction the harvester builds a flat-observation environment and
    either loads an existing PPO model from disk or trains a fresh one.
    Episodes can then be rolled out with ``collect_trajectories`` and written
    to disk with ``save_trajectories``.
    """

    def __init__(self, env_id="MiniGrid-Empty-8x8-v0", model_path=None,
                 train_timesteps=20000):
        """Create the environment and load or train the teacher model.

        Args:
            env_id: Environment id passed to ``gym.make``.
            model_path: Optional path of a saved PPO model. When it exists it
                is loaded; otherwise a new model is trained and, if a path
                was given, saved there.
            train_timesteps: Training budget used only when a new model has
                to be trained. Defaults to the previous hard-coded 20000.
        """
        self.env_id = env_id
        self.env = FlatObsWrapper(gym.make(env_id, render_mode="rgb_array"))

        if model_path and os.path.exists(model_path):
            self.model = PPO.load(model_path, env=self.env)
        else:
            print(f"No model found at {model_path}. Training a new one for collection...")
            self.model = PPO("MlpPolicy", self.env, verbose=1)
            self.model.learn(total_timesteps=train_timesteps)
            if model_path:
                self.model.save(model_path)

    def collect_trajectories(self, num_episodes=100, deterministic=False,
                             base_seed=42):
        """Roll out the teacher and record one dict of arrays per episode.

        Args:
            num_episodes: Number of episodes to run.
            deterministic: Forwarded to ``model.predict``; the default
                ``False`` keeps the original sampling behavior.
            base_seed: Episode ``i`` resets the environment with seed
                ``base_seed + i``.

        Returns:
            A list with one dict per episode, holding numpy arrays under the
            keys ``"observations"``, ``"actions"``, ``"rewards"`` and
            ``"dones"``. Each observation is the one the action was chosen
            from (recorded before the step). ``"dones"`` stores only the
            termination flag returned by ``env.step``; an episode that ends
            by truncation therefore has no ``True`` entry.
        """
        trajectories = []
        for episode_idx in tqdm(range(num_episodes), desc="Collecting trajectories"):
            obs, _ = self.env.reset(seed=base_seed + episode_idx)
            terminated = False
            truncated = False
            steps = {
                "observations": [],
                "actions": [],
                "rewards": [],
                "dones": [],
            }

            while not (terminated or truncated):
                action, _ = self.model.predict(obs, deterministic=deterministic)
                next_obs, reward, terminated, truncated, _ = self.env.step(action)

                steps["observations"].append(obs)
                steps["actions"].append(action)
                steps["rewards"].append(reward)
                steps["dones"].append(terminated)

                obs = next_obs

            # Convert the per-step lists to numpy arrays.
            trajectories.append(
                {key: np.array(values) for key, values in steps.items()}
            )

        return trajectories

    def save_trajectories(self, trajectories, file_path):
        """Serialize ``trajectories`` to ``file_path`` with ``torch.save``.

        Missing parent directories are created first. A bare filename (no
        directory component) is written to the current working directory.

        Args:
            trajectories: Object to serialize, typically the list returned by
                ``collect_trajectories``.
            file_path: Destination file path.
        """
        parent_dir = os.path.dirname(file_path)
        # dirname() is "" for a bare filename, and os.makedirs("") raises
        # FileNotFoundError, so only create a directory when one is named.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        torch.save(trajectories, file_path)
        print(f"Saved {len(trajectories)} trajectories to {file_path}")
if __name__ == "__main__":
    # Script entry point: load (or train) the teacher, roll out 50 episodes,
    # and write the collected trajectories to disk.
    teacher = PPOHarvester(model_path="ppo_minigrid_teacher.zip")
    harvested = teacher.collect_trajectories(num_episodes=50)
    teacher.save_trajectories(
        trajectories=harvested,
        file_path="data/trajectories.pt",
    )
|