Spaces:
Sleeping
Sleeping
| import gymnasium as gym | |
| import numpy as np | |
| import pickle | |
| from pathlib import Path | |
| from tqdm import tqdm | |
| STATE_DIM = 4 | |
| ACTION_DIM = 2 | |
| def _heuristic_action(state, noise: float) -> int: | |
| """Simple pole-angle heuristic with configurable noise.""" | |
| if np.random.random() < noise: | |
| return np.random.randint(ACTION_DIM) | |
| # Push in the direction the pole is leaning | |
| return 1 if state[2] > 0 else 0 | |
| def collect_trajectory(env, noise: float) -> list: | |
| """ | |
| Collect one (state, action, next_state) trajectory. | |
| No reward labels stored β consistent with offline RLHF setup. | |
| """ | |
| state, _ = env.reset() | |
| trajectory = [] | |
| for _ in range(500): | |
| action = _heuristic_action(state, noise) | |
| next_state, _, terminated, truncated, _ = env.step(action) | |
| trajectory.append((state.copy(), int(action), next_state.copy())) | |
| state = next_state | |
| if terminated or truncated: | |
| break | |
| return trajectory | |
| def collect_offline_dataset(n_trajectories=800, save_path="data/offline_dataset.pkl") -> list: | |
| """ | |
| Collect a mixed-quality offline dataset from CartPole-v1. | |
| Quality levels: | |
| - noise=1.0 β fully random (low quality) | |
| - noise=0.7 β mostly random (medium-low) | |
| - noise=0.3 β mostly heuristic (medium-high) | |
| - noise=0.05 β near-expert (high quality) | |
| Returns list of trajectories, each a list of (s, a, s') tuples. | |
| """ | |
| env = gym.make("CartPole-v1") | |
| noise_levels = [1.0, 0.7, 0.3, 0.05] | |
| per_level = n_trajectories // len(noise_levels) | |
| trajectories = [] | |
| for noise in noise_levels: | |
| for _ in tqdm(range(per_level), desc=f"Collecting (noise={noise:.2f})", leave=False): | |
| trajectories.append(collect_trajectory(env, noise)) | |
| env.close() | |
| Path(save_path).parent.mkdir(parents=True, exist_ok=True) | |
| with open(save_path, "wb") as f: | |
| pickle.dump(trajectories, f) | |
| lengths = [len(t) for t in trajectories] | |
| print(f"Collected {len(trajectories)} trajectories | " | |
| f"len: min={min(lengths)} mean={np.mean(lengths):.1f} max={max(lengths)}") | |
| return trajectories | |
| def load_dataset(path: str = "data/offline_dataset.pkl") -> list: | |
| with open(path, "rb") as f: | |
| return pickle.load(f) | |
| def dataset_to_arrays(dataset: list) -> tuple[np.ndarray, np.ndarray, np.ndarray]: | |
| """Flatten dataset into (states, actions_onehot, next_states) arrays.""" | |
| states, actions, next_states = [], [], [] | |
| for traj in dataset: | |
| for s, a, ns in traj: | |
| states.append(s) | |
| actions.append(np.eye(ACTION_DIM)[a]) | |
| next_states.append(ns) | |
| return np.array(states, dtype=np.float32), np.array(actions, dtype=np.float32), np.array(next_states, dtype=np.float32) | |
| def extract_sa_trajectories(dataset: list) -> list[list[tuple]]: | |
| """Convert (s, a, s') trajectories to (s, a) trajectories for reward model.""" | |
| return [[(s, a) for s, a, ns in traj] for traj in dataset] | |