import gymnasium as gym
import numpy as np
import pickle
from pathlib import Path
from tqdm import tqdm
# Size of a CartPole-v1 observation vector (4 floats; pole angle is index 2,
# as used by the heuristic policy below).
STATE_DIM = 4
# Number of discrete actions in CartPole-v1 (presumably 0 = push left,
# 1 = push right — confirm against the gymnasium docs).
ACTION_DIM = 2
def _heuristic_action(state, noise: float) -> int:
"""Simple pole-angle heuristic with configurable noise."""
if np.random.random() < noise:
return np.random.randint(ACTION_DIM)
# Push in the direction the pole is leaning
return 1 if state[2] > 0 else 0
def collect_trajectory(env, noise: float) -> list:
    """
    Roll out one episode and return its (state, action, next_state) steps.

    Reward labels are intentionally not stored — consistent with the
    offline RLHF setup. The rollout is capped at 500 steps or ends early
    when the environment reports termination/truncation.
    """
    obs, _ = env.reset()
    steps = []
    for _ in range(500):
        act = _heuristic_action(obs, noise)
        next_obs, _, terminated, truncated, _ = env.step(act)
        # Copy observations so later env steps cannot mutate stored data.
        steps.append((obs.copy(), int(act), next_obs.copy()))
        if terminated or truncated:
            break
        obs = next_obs
    return steps
def collect_offline_dataset(n_trajectories=800, save_path="data/offline_dataset.pkl") -> list:
    """
    Build a mixed-quality offline dataset from CartPole-v1 and pickle it.

    Trajectories are split evenly across four behavior-noise tiers:
      - noise=1.0  -> fully random (low quality)
      - noise=0.7  -> mostly random (medium-low)
      - noise=0.3  -> mostly heuristic (medium-high)
      - noise=0.05 -> near-expert (high quality)

    Returns the list of trajectories, each a list of (s, a, s') tuples.
    Note: n_trajectories is floor-divided by the number of tiers, so a
    value not divisible by 4 yields slightly fewer trajectories.
    """
    env = gym.make("CartPole-v1")
    noise_levels = [1.0, 0.7, 0.3, 0.05]
    per_level = n_trajectories // len(noise_levels)

    trajectories = []
    for noise in noise_levels:
        level_trajs = [
            collect_trajectory(env, noise)
            for _ in tqdm(range(per_level), desc=f"Collecting (noise={noise:.2f})", leave=False)
        ]
        trajectories.extend(level_trajs)
    env.close()

    # Persist the dataset, creating the parent directory if needed.
    out_path = Path(save_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("wb") as fh:
        pickle.dump(trajectories, fh)

    lengths = [len(t) for t in trajectories]
    print(f"Collected {len(trajectories)} trajectories | "
          f"len: min={min(lengths)} mean={np.mean(lengths):.1f} max={max(lengths)}")
    return trajectories
def load_dataset(path: str = "data/offline_dataset.pkl") -> list:
    """Load a pickled trajectory dataset from ``path``.

    NOTE(review): pickle deserialization can execute arbitrary code —
    only load files produced by this project, never untrusted input.
    """
    with Path(path).open("rb") as fh:
        return pickle.load(fh)
def dataset_to_arrays(dataset: list) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Flatten a trajectory dataset into three float32 arrays.

    Returns (states, actions_onehot, next_states), where actions are
    one-hot encoded over ACTION_DIM classes.
    """
    # Hoist the identity matrix out of the loop; row `a` is the one-hot for action a.
    eye = np.eye(ACTION_DIM)
    flat = [step for traj in dataset for step in traj]
    states = np.array([s for s, _, _ in flat], dtype=np.float32)
    actions = np.array([eye[a] for _, a, _ in flat], dtype=np.float32)
    next_states = np.array([ns for _, _, ns in flat], dtype=np.float32)
    return states, actions, next_states
def extract_sa_trajectories(dataset: list) -> list[list[tuple]]:
    """Drop the next-state element: (s, a, s') trajectories -> (s, a) pairs
    for reward-model consumption."""
    pairs = []
    for traj in dataset:
        pairs.append([(step[0], step[1]) for step in traj])
    return pairs