File size: 2,458 Bytes
848238a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import gymnasium as gym
import torch
import numpy as np
from minigrid.wrappers import FlatObsWrapper
from stable_baselines3 import PPO
from tqdm import tqdm

class PPOHarvester:
    """
    Utility to run a 'Teacher' PPO agent and collect state-action-reward trajectories.

    On construction, the harvester wraps a Gymnasium environment with MiniGrid's
    ``FlatObsWrapper``. It then either loads an existing PPO checkpoint or trains a
    fresh one (and saves it when a path is given). ``collect_trajectories`` rolls
    the policy out for a number of episodes. ``save_trajectories`` persists the
    result with ``torch.save``.
    """

    def __init__(self, env_id="MiniGrid-Empty-8x8-v0", model_path=None, train_timesteps=20000):
        """
        Build the environment and obtain a teacher policy.

        Args:
            env_id: Gymnasium environment id passed to ``gym.make``.
            model_path: Optional checkpoint path. If a checkpoint exists there (or at
                ``model_path + ".zip"`` when the path has no extension), it is loaded.
                Otherwise a new PPO model is trained and, if a path was given, saved
                to it.
            train_timesteps: Number of PPO training timesteps used when no
                checkpoint is found.
        """
        self.env_id = env_id
        self.env = FlatObsWrapper(gym.make(env_id, render_mode="rgb_array"))
        checkpoint = self._resolve_model_path(model_path)
        if checkpoint is not None:
            self.model = PPO.load(checkpoint, env=self.env)
        else:
            print(f"No model found at {model_path}. Training a new one for collection...")
            self.model = PPO("MlpPolicy", self.env, verbose=1)
            self.model.learn(total_timesteps=train_timesteps)
            if model_path:
                self.model.save(model_path)

    @staticmethod
    def _resolve_model_path(model_path):
        """
        Return the existing checkpoint file for ``model_path``, or ``None``.

        SB3's ``save`` appends ".zip" to a path without a suffix. So for a
        suffix-less path, the ".zip" variant is checked too. Without this,
        a model saved as "teacher" (written to "teacher.zip") would never be
        found again and would be retrained on every run.
        """
        if not model_path:
            return None
        candidates = [model_path]
        if not os.path.splitext(model_path)[1]:
            candidates.append(f"{model_path}.zip")
        for candidate in candidates:
            if os.path.isfile(candidate):
                return candidate
        return None

    def collect_trajectories(self, num_episodes=100, seed=42, deterministic=False):
        """
        Roll out the teacher policy and record one dict per episode.

        Args:
            num_episodes: Number of episodes to collect.
            seed: Base reset seed. Episode ``i`` is reset with ``seed + i`` so runs
                are reproducible. Pass ``None`` for unseeded resets.
            deterministic: Forwarded to ``model.predict``. The default ``False``
                samples actions from the policy.

        Returns:
            A list of dicts with keys "observations", "actions", "rewards" and
            "dones". Each value is a numpy array with one entry per step.
            "observations" holds the observation the action was taken FROM, not
            the resulting one. "dones" stores only the ``terminated`` flag, so a
            time-limit truncated episode ends with ``False``.
        """
        trajectories = []
        for i in tqdm(range(num_episodes), desc="Collecting trajectories"):
            obs, _ = self.env.reset(seed=None if seed is None else seed + i)
            terminated = truncated = False
            episode = {"observations": [], "actions": [], "rewards": [], "dones": []}
            while not (terminated or truncated):
                action, _ = self.model.predict(obs, deterministic=deterministic)
                next_obs, reward, terminated, truncated, _ = self.env.step(action)

                episode["observations"].append(obs)
                episode["actions"].append(action)
                episode["rewards"].append(reward)
                episode["dones"].append(terminated)

                obs = next_obs

            # Stack each per-step list into a numpy array for compact storage.
            trajectories.append({key: np.array(values) for key, values in episode.items()})

        return trajectories

    def save_trajectories(self, trajectories, file_path):
        """
        Persist ``trajectories`` to ``file_path`` with ``torch.save``.

        Any parent directory is created first. A bare filename (no directory
        component) is written to the current working directory.

        NOTE(review): the payload contains numpy arrays. Newer torch releases
        default ``torch.load`` to ``weights_only=True``, so consumers may need
        ``weights_only=False``. Confirm against the loading side.
        """
        directory = os.path.dirname(file_path)
        # os.makedirs("") raises FileNotFoundError, so only create a real directory.
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(trajectories, file_path)
        print(f"Saved {len(trajectories)} trajectories to {file_path}")

    def close(self):
        """Close the wrapped environment and release its resources."""
        self.env.close()

if __name__ == "__main__":
    # Load (or train) the teacher, collect 50 episodes and write them to disk.
    harvester = PPOHarvester(model_path="ppo_minigrid_teacher.zip")
    try:
        trajs = harvester.collect_trajectories(num_episodes=50)
        harvester.save_trajectories(trajs, "data/trajectories.pt")
    finally:
        # Release the environment even if collection or saving fails.
        harvester.env.close()