singhanshuman committed on
Commit
58b6f47
·
verified ·
1 Parent(s): c900716

Upload simoprl/collect_data.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. simoprl/collect_data.py +87 -0
simoprl/collect_data.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ import numpy as np
3
+ import pickle
4
+ from pathlib import Path
5
+ from tqdm import tqdm
6
+
7
+ STATE_DIM = 4  # NOTE(review): presumably the length of a CartPole observation; not referenced in the visible code — confirm
8
+ ACTION_DIM = 2  # number of discrete actions: bounds the random action draw and sets the one-hot width
9
+
10
+
11
+ def _heuristic_action(state, noise: float) -> int:
12
+ """Simple pole-angle heuristic with configurable noise."""
13
+ if np.random.random() < noise:
14
+ return np.random.randint(ACTION_DIM)
15
+ # Push in the direction the pole is leaning
16
+ return 1 if state[2] > 0 else 0
17
+
18
+
19
def collect_trajectory(env, noise: float, max_steps: int = 500) -> list:
    """
    Collect one (state, action, next_state) trajectory.
    No reward labels stored — consistent with offline RLHF setup.

    Args:
        env: Environment whose ``reset()`` returns ``(state, info)`` and whose
            ``step(action)`` returns a 5-tuple
            ``(next_state, reward, terminated, truncated, info)``.
        noise: Probability of taking a random action, forwarded to
            ``_heuristic_action``.
        max_steps: Upper bound on the number of recorded transitions.
            Defaults to 500, the cap that was previously hard-coded.

    Returns:
        List of ``(state, action, next_state)`` tuples, ending at the first
        terminated/truncated step or after ``max_steps`` transitions.
    """
    state, _ = env.reset()
    trajectory = []
    for _ in range(max_steps):
        action = _heuristic_action(state, noise)
        # The reward returned by the environment is deliberately discarded.
        next_state, _, terminated, truncated, _ = env.step(action)
        # Store copies so recorded transitions do not alias arrays that are
        # still in use by the rollout loop.
        trajectory.append((state.copy(), int(action), next_state.copy()))
        state = next_state
        if terminated or truncated:
            break
    return trajectory
34
+
35
+
36
def collect_offline_dataset(n_trajectories=800, save_path="data/offline_dataset.pkl") -> list:
    """
    Collect a mixed-quality offline dataset from CartPole-v1 and pickle it.

    Quality levels:
      - noise=1.0  → fully random     (low quality)
      - noise=0.7  → mostly random    (medium-low)
      - noise=0.3  → mostly heuristic (medium-high)
      - noise=0.05 → near-expert      (high quality)

    The requested count is split as evenly as possible across the four
    levels. When ``n_trajectories`` is not a multiple of four, the leftover
    trajectories are assigned one each to the first levels in the list, so
    exactly ``n_trajectories`` trajectories are collected.

    Args:
        n_trajectories: Total number of trajectories to collect.
        save_path: Destination of the pickle file; missing parent
            directories are created.

    Returns list of trajectories, each a list of (s, a, s') tuples.
    """
    noise_levels = [1.0, 0.7, 0.3, 0.05]
    # divmod keeps the remainder so it can be redistributed instead of dropped.
    base, extra = divmod(max(n_trajectories, 0), len(noise_levels))
    trajectories = []

    env = gym.make("CartPole-v1")
    try:
        for level, noise in enumerate(noise_levels):
            count = base + (1 if level < extra else 0)
            for _ in tqdm(range(count), desc=f"Collecting (noise={noise:.2f})", leave=False):
                trajectories.append(collect_trajectory(env, noise))
    finally:
        # Release the environment even if a rollout raises.
        env.close()

    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    with open(save_path, "wb") as f:
        pickle.dump(trajectories, f)

    lengths = [len(t) for t in trajectories]
    if lengths:
        print(f"Collected {len(trajectories)} trajectories | "
              f"len: min={min(lengths)} mean={np.mean(lengths):.1f} max={max(lengths)}")
    else:
        # min()/max() raise on an empty sequence, so report the empty case separately.
        print("Collected 0 trajectories")
    return trajectories
67
+
68
+
69
def load_dataset(path: str = "data/offline_dataset.pkl") -> list:
    """Read back the object pickled at ``path`` and return it.

    NOTE: unpickling can execute arbitrary code, so only load files that
    come from a trusted source.
    """
    with open(path, "rb") as stream:
        unpickler = pickle.Unpickler(stream)
        dataset = unpickler.load()
    return dataset
72
+
73
+
74
def dataset_to_arrays(dataset: list) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Flatten dataset into (states, actions_onehot, next_states) arrays.

    Args:
        dataset: List of trajectories, each a sequence of
            ``(state, action, next_state)`` tuples with integer actions.

    Returns:
        Three float32 arrays with one row per transition, in dataset order.
        Actions are one-hot encoded with width ``ACTION_DIM``. For a dataset
        without any transition the arrays have shapes ``(0, STATE_DIM)``,
        ``(0, ACTION_DIM)`` and ``(0, STATE_DIM)``, so callers always
        receive 2-D arrays.

    Raises:
        IndexError: If an action index is outside the one-hot table.
    """
    states, action_ids, next_states = [], [], []
    for traj in dataset:
        for s, a, ns in traj:
            states.append(s)
            action_ids.append(a)
            next_states.append(ns)

    if not states:
        # np.array([]) would be 1-D with shape (0,); keep the column dimension.
        return (
            np.empty((0, STATE_DIM), dtype=np.float32),
            np.empty((0, ACTION_DIM), dtype=np.float32),
            np.empty((0, STATE_DIM), dtype=np.float32),
        )

    # Build the identity once and select all one-hot rows in a single
    # indexing step instead of allocating a new matrix per transition.
    one_hot_table = np.eye(ACTION_DIM, dtype=np.float32)
    actions = one_hot_table[np.asarray(action_ids, dtype=np.intp)]
    return np.array(states, dtype=np.float32), actions, np.array(next_states, dtype=np.float32)
83
+
84
+
85
def extract_sa_trajectories(dataset: list) -> list[list[tuple]]:
    """Drop the next-state from every step, keeping (s, a) pairs for the reward model."""
    sa_dataset = []
    for trajectory in dataset:
        pairs = []
        for state, action, _next_state in trajectory:
            pairs.append((state, action))
        sa_dataset.append(pairs)
    return sa_dataset