Spaces:
Sleeping
Sleeping
Upload simoprl/collect_data.py with huggingface_hub
Browse files- simoprl/collect_data.py +87 -0
simoprl/collect_data.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pickle
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
# Length of one observation vector. Presumably matches the 4-element
# CartPole-v1 observation — TODO confirm; not referenced in the visible code.
STATE_DIM = 4
|
| 8 |
+
# Number of discrete actions: the exclusive upper bound for random action
# sampling and the width of the one-hot action encoding.
ACTION_DIM = 2
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _heuristic_action(state, noise: float) -> int:
    """Choose a discrete action from a noisy pole-angle rule.

    Args:
        state: Indexable observation; only element 2 is read (presumably the
            pole angle of a CartPole observation — confirm against the env).
        noise: Probability of ignoring the heuristic and sampling an action
            index uniformly from ``range(ACTION_DIM)`` instead.

    Returns:
        A random action index when the noise branch fires; otherwise 1 if
        ``state[2]`` is strictly positive and 0 if it is not.
    """
    take_random = np.random.random() < noise
    if take_random:
        return np.random.randint(ACTION_DIM)

    # Push in the direction the pole is leaning.
    if state[2] > 0:
        return 1
    return 0
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def collect_trajectory(env, noise: float, max_steps: int = 500) -> list:
    """Roll out one episode and record its transitions.

    Rewards returned by the environment are discarded, so the result holds
    only (state, action, next_state) triples — consistent with the offline
    RLHF setup, which stores no reward labels.

    Args:
        env: Environment whose ``reset()`` returns ``(observation, info)`` and
            whose ``step(action)`` returns a 5-tuple
            ``(observation, reward, terminated, truncated, info)``.
            Observations must provide a ``.copy()`` method.
        noise: Probability of a random action, forwarded to
            ``_heuristic_action``.
        max_steps: Upper bound on the number of recorded transitions.
            Defaults to 500, the cap that was previously hard-coded.

    Returns:
        List of ``(state, action, next_state)`` tuples. Collection stops
        early as soon as the environment reports termination or truncation.
    """
    state, _ = env.reset()
    trajectory = []
    for _ in range(max_steps):
        action = _heuristic_action(state, noise)
        next_state, _reward, terminated, truncated, _info = env.step(action)
        # Store copies so the saved transitions do not share memory with
        # observation objects the caller or environment may still hold.
        trajectory.append((state.copy(), int(action), next_state.copy()))
        state = next_state
        if terminated or truncated:
            break
    return trajectory
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def collect_offline_dataset(n_trajectories=800, save_path="data/offline_dataset.pkl") -> list:
    """Collect a mixed-quality offline dataset from CartPole-v1 and pickle it.

    Quality levels:
      - noise=1.0  → fully random (low quality)
      - noise=0.7  → mostly random (medium-low)
      - noise=0.3  → mostly heuristic (medium-high)
      - noise=0.05 → near-expert (high quality)

    Trajectories are split as evenly as possible across the four levels.
    When ``n_trajectories`` is not divisible by four, the leftover
    trajectories go to the earliest (noisiest) levels so that exactly
    ``n_trajectories`` are collected.

    Args:
        n_trajectories: Total number of trajectories to collect.
        save_path: Destination of the pickle file; missing parent
            directories are created.

    Returns:
        List of trajectories, each a list of (s, a, s') tuples.
    """
    noise_levels = [1.0, 0.7, 0.3, 0.05]
    per_level, remainder = divmod(n_trajectories, len(noise_levels))
    trajectories = []

    env = gym.make("CartPole-v1")
    try:
        for index, noise in enumerate(noise_levels):
            # The first `remainder` levels absorb one extra trajectory each.
            count = per_level + (1 if index < remainder else 0)
            for _ in tqdm(range(count), desc=f"Collecting (noise={noise:.2f})", leave=False):
                trajectories.append(collect_trajectory(env, noise))
    finally:
        # Release the environment even if collection fails midway.
        env.close()

    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    with open(save_path, "wb") as f:
        pickle.dump(trajectories, f)

    lengths = [len(t) for t in trajectories]
    if lengths:
        print(f"Collected {len(trajectories)} trajectories | "
              f"len: min={min(lengths)} mean={np.mean(lengths):.1f} max={max(lengths)}")
    else:
        # min()/max() raise on an empty sequence, so skip the length stats.
        print(f"Collected {len(trajectories)} trajectories")
    return trajectories
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def load_dataset(path: str = "data/offline_dataset.pkl") -> list:
    """Read a pickled dataset back from disk.

    Args:
        path: Location of the pickle file to read.

    Returns:
        Whatever object was pickled into ``path``.

    Note:
        Unpickling can execute arbitrary code; only load files you trust.
    """
    with open(path, mode="rb") as handle:
        dataset = pickle.load(handle)
    return dataset
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def dataset_to_arrays(dataset: list) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Flatten dataset into (states, actions_onehot, next_states) arrays.

    Args:
        dataset: Iterable of trajectories, each an iterable of
            ``(state, action, next_state)`` triples. Actions must be integer
            indices valid for ``ACTION_DIM`` rows.

    Returns:
        Three ``float32`` arrays with one row per transition, in dataset
        order: the states, the one-hot encoded actions (``ACTION_DIM``
        columns), and the next states. For an empty dataset the one-hot
        array has shape ``(0, ACTION_DIM)``.

    Raises:
        IndexError: If an action index is out of range for ``ACTION_DIM``.
    """
    states, actions, next_states = [], [], []
    for traj in dataset:
        for s, a, ns in traj:
            states.append(s)
            actions.append(a)
            next_states.append(ns)

    # Build the identity matrix once and select rows in one vectorised
    # lookup, rather than allocating a new np.eye for every transition.
    action_idx = np.asarray(actions, dtype=np.intp)
    actions_onehot = np.eye(ACTION_DIM, dtype=np.float32)[action_idx]

    return (
        np.array(states, dtype=np.float32),
        actions_onehot,
        np.array(next_states, dtype=np.float32),
    )
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def extract_sa_trajectories(dataset: list) -> list[list[tuple]]:
    """Drop the next-state from every transition.

    Converts (s, a, s') trajectories into (s, a) trajectories for the reward
    model, keeping the order of trajectories and of steps within them.

    Args:
        dataset: Iterable of trajectories, each an iterable of 3-tuples.

    Returns:
        One list of ``(state, action)`` tuples per input trajectory.
    """
    sa_trajectories = []
    for transitions in dataset:
        pairs = []
        for state, action, _next_state in transitions:
            pairs.append((state, action))
        sa_trajectories.append(pairs)
    return sa_trajectories
|