Spaces:
Sleeping
Sleeping
| """ | |
| Latent Space Roadmap (LSR). | |
| Graph construction: | |
| 1. Encode all training frames β latent cloud Z. | |
| 2. For each node, find its k nearest neighbours. | |
| 3. Keep an edge (i, j) only when the pair appears as consecutive | |
| timesteps in at least one training trajectory (i.e. a real action | |
| connects them). This prevents the planner from taking shortcuts | |
| through physically unreachable states. | |
| Planning: | |
| Given z_start and z_goal, snap to nearest graph nodes then run | |
| Dijkstra weighted by Euclidean latent distance. | |
| """ | |
| import heapq | |
| from typing import Dict, List, Optional, Tuple | |
| import numpy as np | |
| from scipy.spatial import KDTree | |
| class LSR: | |
| def __init__(self, k: int = 10): | |
| self.k = k | |
| self.latents: Optional[np.ndarray] = None | |
| self.tree: Optional[KDTree] = None | |
| # adjacency: node β [(neighbour, euclidean_weight), ...] | |
| self.graph: Dict[int, List[Tuple[int, float]]] = {} | |
| # valid_transitions: (i, j) β action that takes state_i β state_j | |
| self.valid_transitions: Dict[Tuple[int, int], np.ndarray] = {} | |
| # ------------------------------------------------------------------ | |
| # Graph construction | |
| # ------------------------------------------------------------------ | |
| def build( | |
| self, | |
| latents: np.ndarray, | |
| episodes: List[List[int]], | |
| actions: np.ndarray, | |
| ) -> None: | |
| """ | |
| latents : (N, z_dim) β encoded training frames | |
| episodes : list of index-lists, one list per trajectory | |
| actions : (N, action_dim) β action[i] moves frame i β frame i+1 | |
| """ | |
| self.latents = latents.copy() | |
| self.tree = KDTree(latents) | |
| self.graph = {} | |
| self.valid_transitions = {} | |
| # Record every consecutive transition | |
| for ep in episodes: | |
| for t in range(len(ep) - 1): | |
| i, j = ep[t], ep[t + 1] | |
| self.valid_transitions[(i, j)] = actions[i] | |
| valid_set = set(self.valid_transitions.keys()) | |
| n = len(latents) | |
| for i in range(n): | |
| _, nbrs = self.tree.query(latents[i], k=min(self.k + 1, n)) | |
| for j in nbrs[1:]: # skip self | |
| j = int(j) | |
| if (i, j) in valid_set or (j, i) in valid_set: | |
| w = float(np.linalg.norm(latents[i] - latents[j])) | |
| self.graph.setdefault(i, []).append((j, w)) | |
| self.graph.setdefault(j, []).append((i, w)) | |
| # ------------------------------------------------------------------ | |
| # Planning | |
| # ------------------------------------------------------------------ | |
| def _dijkstra(self, src: int, dst: int) -> Optional[List[int]]: | |
| dist: Dict[int, float] = {src: 0.0} | |
| prev: Dict[int, int] = {} | |
| pq = [(0.0, src)] | |
| visited: set = set() | |
| while pq: | |
| d, u = heapq.heappop(pq) | |
| if u in visited: | |
| continue | |
| visited.add(u) | |
| if u == dst: | |
| path, cur = [], dst | |
| while cur != src: | |
| path.append(cur) | |
| cur = prev[cur] | |
| path.append(src) | |
| return path[::-1] | |
| for v, w in self.graph.get(u, []): | |
| nd = d + w | |
| if nd < dist.get(v, float("inf")): | |
| dist[v] = nd | |
| prev[v] = u | |
| heapq.heappush(pq, (nd, v)) | |
| return None # no path | |
| def plan( | |
| self, | |
| start_z: np.ndarray, | |
| goal_z: np.ndarray, | |
| ) -> Optional[Tuple[List[int], np.ndarray]]: | |
| """ | |
| Returns (node_indices_along_path, latent_path) or None if no path. | |
| """ | |
| if self.tree is None or self.latents is None: | |
| raise RuntimeError("Call build() before plan().") | |
| _, (n_start,) = self.tree.query(start_z.reshape(1, -1), k=1) | |
| _, (n_goal,) = self.tree.query(goal_z.reshape(1, -1), k=1) | |
| n_start, n_goal = int(n_start), int(n_goal) | |
| if n_start == n_goal: | |
| return [n_start], self.latents[[n_start]] | |
| path = self._dijkstra(n_start, n_goal) | |
| if path is None: | |
| return None | |
| return path, self.latents[path] | |
| def get_actions_for_path(self, path: List[int]) -> List[Optional[np.ndarray]]: | |
| """Return the training action for each consecutive node pair in path.""" | |
| actions = [] | |
| for i in range(len(path) - 1): | |
| key = (path[i], path[i + 1]) | |
| rev = (path[i + 1], path[i]) | |
| if key in self.valid_transitions: | |
| actions.append(self.valid_transitions[key]) | |
| elif rev in self.valid_transitions: | |
| actions.append(self.valid_transitions[rev]) | |
| else: | |
| actions.append(None) | |
| return actions | |
| # ------------------------------------------------------------------ | |
| # Persistence | |
| # ------------------------------------------------------------------ | |
| def save(self, path: str) -> None: | |
| import pickle | |
| state = { | |
| "k": self.k, | |
| "latents": self.latents, | |
| "graph": self.graph, | |
| "valid_transitions": self.valid_transitions, | |
| } | |
| with open(path, "wb") as f: | |
| pickle.dump(state, f) | |
| def load(cls, path: str) -> "LSR": | |
| import pickle | |
| with open(path, "rb") as f: | |
| state = pickle.load(f) | |
| lsr = cls(k=state["k"]) | |
| lsr.latents = state["latents"] | |
| lsr.graph = state["graph"] | |
| lsr.valid_transitions = state["valid_transitions"] | |
| if lsr.latents is not None: | |
| lsr.tree = KDTree(lsr.latents) | |
| return lsr | |