DT-Explorer / src /data /harvester.py
sadhumitha-s's picture
docs: update readme references and add modular trajectory harvester
848238a
import os
import gymnasium as gym
import torch
import numpy as np
from minigrid.wrappers import FlatObsWrapper
from stable_baselines3 import PPO
from tqdm import tqdm
class PPOHarvester:
"""
Utility to run a 'Teacher' PPO agent to collect high-quality state-action-reward triplets.
"""
def __init__(self, env_id="MiniGrid-Empty-8x8-v0", model_path=None):
self.env_id = env_id
self.env = FlatObsWrapper(gym.make(env_id, render_mode="rgb_array"))
if model_path and os.path.exists(model_path):
self.model = PPO.load(model_path, env=self.env)
else:
print(f"No model found at {model_path}. Training a new one for collection...")
self.model = PPO("MlpPolicy", self.env, verbose=1)
self.model.learn(total_timesteps=20000)
if model_path:
self.model.save(model_path)
def collect_trajectories(self, num_episodes=100):
trajectories = []
for i in tqdm(range(num_episodes), desc="Collecting trajectories"):
obs, _ = self.env.reset(seed=42 + i)
done = False
truncated = False
episode = {
"observations": [],
"actions": [],
"rewards": [],
"dones": []
}
while not (done or truncated):
action, _states = self.model.predict(obs, deterministic=False)
next_obs, reward, done, truncated, info = self.env.step(action)
episode["observations"].append(obs)
episode["actions"].append(action)
episode["rewards"].append(reward)
episode["dones"].append(done)
obs = next_obs
# Convert to numpy arrays
for key in episode:
episode[key] = np.array(episode[key])
trajectories.append(episode)
return trajectories
def save_trajectories(self, trajectories, file_path):
os.makedirs(os.path.dirname(file_path), exist_ok=True)
torch.save(trajectories, file_path)
print(f"Saved {len(trajectories)} trajectories to {file_path}")
if __name__ == "__main__":
harvester = PPOHarvester(model_path="ppo_minigrid_teacher.zip")
trajs = harvester.collect_trajectories(num_episodes=50)
harvester.save_trajectories(trajs, "data/trajectories.pt")