singhanshuman commited on
Commit
1624b35
·
verified ·
1 Parent(s): 619b31e

Upload models/lsr.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/lsr.py +169 -0
models/lsr.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Latent Space Roadmap (LSR).
3
+
4
+ Graph construction:
5
+ 1. Encode all training frames → latent cloud Z.
6
+ 2. For each node, find its k nearest neighbours.
7
+ 3. Keep an edge (i, j) only when the pair appears as consecutive
8
+ timesteps in at least one training trajectory (i.e. a real action
9
+ connects them). This prevents the planner from taking shortcuts
10
+ through physically unreachable states.
11
+
12
+ Planning:
13
+ Given z_start and z_goal, snap to nearest graph nodes then run
14
+ Dijkstra weighted by Euclidean latent distance.
15
+ """
16
+
17
+ import heapq
18
+ from typing import Dict, List, Optional, Tuple
19
+
20
+ import numpy as np
21
+ from scipy.spatial import KDTree
22
+
23
+
24
+ class LSR:
25
+ def __init__(self, k: int = 10):
26
+ self.k = k
27
+ self.latents: Optional[np.ndarray] = None
28
+ self.tree: Optional[KDTree] = None
29
+ # adjacency: node → [(neighbour, euclidean_weight), ...]
30
+ self.graph: Dict[int, List[Tuple[int, float]]] = {}
31
+ # valid_transitions: (i, j) → action that takes state_i → state_j
32
+ self.valid_transitions: Dict[Tuple[int, int], np.ndarray] = {}
33
+
34
+ # ------------------------------------------------------------------
35
+ # Graph construction
36
+ # ------------------------------------------------------------------
37
+
38
+ def build(
39
+ self,
40
+ latents: np.ndarray,
41
+ episodes: List[List[int]],
42
+ actions: np.ndarray,
43
+ ) -> None:
44
+ """
45
+ latents : (N, z_dim) — encoded training frames
46
+ episodes : list of index-lists, one list per trajectory
47
+ actions : (N, action_dim) — action[i] moves frame i → frame i+1
48
+ """
49
+ self.latents = latents.copy()
50
+ self.tree = KDTree(latents)
51
+ self.graph = {}
52
+ self.valid_transitions = {}
53
+
54
+ # Record every consecutive transition
55
+ for ep in episodes:
56
+ for t in range(len(ep) - 1):
57
+ i, j = ep[t], ep[t + 1]
58
+ self.valid_transitions[(i, j)] = actions[i]
59
+
60
+ valid_set = set(self.valid_transitions.keys())
61
+ n = len(latents)
62
+
63
+ for i in range(n):
64
+ _, nbrs = self.tree.query(latents[i], k=min(self.k + 1, n))
65
+ for j in nbrs[1:]: # skip self
66
+ j = int(j)
67
+ if (i, j) in valid_set or (j, i) in valid_set:
68
+ w = float(np.linalg.norm(latents[i] - latents[j]))
69
+ self.graph.setdefault(i, []).append((j, w))
70
+ self.graph.setdefault(j, []).append((i, w))
71
+
72
+ # ------------------------------------------------------------------
73
+ # Planning
74
+ # ------------------------------------------------------------------
75
+
76
+ def _dijkstra(self, src: int, dst: int) -> Optional[List[int]]:
77
+ dist: Dict[int, float] = {src: 0.0}
78
+ prev: Dict[int, int] = {}
79
+ pq = [(0.0, src)]
80
+ visited: set = set()
81
+
82
+ while pq:
83
+ d, u = heapq.heappop(pq)
84
+ if u in visited:
85
+ continue
86
+ visited.add(u)
87
+
88
+ if u == dst:
89
+ path, cur = [], dst
90
+ while cur != src:
91
+ path.append(cur)
92
+ cur = prev[cur]
93
+ path.append(src)
94
+ return path[::-1]
95
+
96
+ for v, w in self.graph.get(u, []):
97
+ nd = d + w
98
+ if nd < dist.get(v, float("inf")):
99
+ dist[v] = nd
100
+ prev[v] = u
101
+ heapq.heappush(pq, (nd, v))
102
+
103
+ return None # no path
104
+
105
+ def plan(
106
+ self,
107
+ start_z: np.ndarray,
108
+ goal_z: np.ndarray,
109
+ ) -> Optional[Tuple[List[int], np.ndarray]]:
110
+ """
111
+ Returns (node_indices_along_path, latent_path) or None if no path.
112
+ """
113
+ if self.tree is None or self.latents is None:
114
+ raise RuntimeError("Call build() before plan().")
115
+
116
+ _, (n_start,) = self.tree.query(start_z.reshape(1, -1), k=1)
117
+ _, (n_goal,) = self.tree.query(goal_z.reshape(1, -1), k=1)
118
+ n_start, n_goal = int(n_start), int(n_goal)
119
+
120
+ if n_start == n_goal:
121
+ return [n_start], self.latents[[n_start]]
122
+
123
+ path = self._dijkstra(n_start, n_goal)
124
+ if path is None:
125
+ return None
126
+
127
+ return path, self.latents[path]
128
+
129
+ def get_actions_for_path(self, path: List[int]) -> List[Optional[np.ndarray]]:
130
+ """Return the training action for each consecutive node pair in path."""
131
+ actions = []
132
+ for i in range(len(path) - 1):
133
+ key = (path[i], path[i + 1])
134
+ rev = (path[i + 1], path[i])
135
+ if key in self.valid_transitions:
136
+ actions.append(self.valid_transitions[key])
137
+ elif rev in self.valid_transitions:
138
+ actions.append(self.valid_transitions[rev])
139
+ else:
140
+ actions.append(None)
141
+ return actions
142
+
143
+ # ------------------------------------------------------------------
144
+ # Persistence
145
+ # ------------------------------------------------------------------
146
+
147
+ def save(self, path: str) -> None:
148
+ import pickle
149
+ state = {
150
+ "k": self.k,
151
+ "latents": self.latents,
152
+ "graph": self.graph,
153
+ "valid_transitions": self.valid_transitions,
154
+ }
155
+ with open(path, "wb") as f:
156
+ pickle.dump(state, f)
157
+
158
+ @classmethod
159
+ def load(cls, path: str) -> "LSR":
160
+ import pickle
161
+ with open(path, "rb") as f:
162
+ state = pickle.load(f)
163
+ lsr = cls(k=state["k"])
164
+ lsr.latents = state["latents"]
165
+ lsr.graph = state["graph"]
166
+ lsr.valid_transitions = state["valid_transitions"]
167
+ if lsr.latents is not None:
168
+ lsr.tree = KDTree(lsr.latents)
169
+ return lsr