lsr-lang / models /lsr.py
singhanshuman's picture
Upload models/lsr.py with huggingface_hub
1624b35 verified
"""
Latent Space Roadmap (LSR).
Graph construction:
1. Encode all training frames β†’ latent cloud Z.
2. For each node, find its k nearest neighbours.
3. Keep an edge (i, j) only when the pair appears as consecutive
timesteps in at least one training trajectory (i.e. a real action
connects them). This prevents the planner from taking shortcuts
through physically unreachable states.
Planning:
Given z_start and z_goal, snap to nearest graph nodes then run
Dijkstra weighted by Euclidean latent distance.
"""
import heapq
from typing import Dict, List, Optional, Tuple
import numpy as np
from scipy.spatial import KDTree
class LSR:
def __init__(self, k: int = 10):
self.k = k
self.latents: Optional[np.ndarray] = None
self.tree: Optional[KDTree] = None
# adjacency: node β†’ [(neighbour, euclidean_weight), ...]
self.graph: Dict[int, List[Tuple[int, float]]] = {}
# valid_transitions: (i, j) β†’ action that takes state_i β†’ state_j
self.valid_transitions: Dict[Tuple[int, int], np.ndarray] = {}
# ------------------------------------------------------------------
# Graph construction
# ------------------------------------------------------------------
def build(
self,
latents: np.ndarray,
episodes: List[List[int]],
actions: np.ndarray,
) -> None:
"""
latents : (N, z_dim) β€” encoded training frames
episodes : list of index-lists, one list per trajectory
actions : (N, action_dim) β€” action[i] moves frame i β†’ frame i+1
"""
self.latents = latents.copy()
self.tree = KDTree(latents)
self.graph = {}
self.valid_transitions = {}
# Record every consecutive transition
for ep in episodes:
for t in range(len(ep) - 1):
i, j = ep[t], ep[t + 1]
self.valid_transitions[(i, j)] = actions[i]
valid_set = set(self.valid_transitions.keys())
n = len(latents)
for i in range(n):
_, nbrs = self.tree.query(latents[i], k=min(self.k + 1, n))
for j in nbrs[1:]: # skip self
j = int(j)
if (i, j) in valid_set or (j, i) in valid_set:
w = float(np.linalg.norm(latents[i] - latents[j]))
self.graph.setdefault(i, []).append((j, w))
self.graph.setdefault(j, []).append((i, w))
# ------------------------------------------------------------------
# Planning
# ------------------------------------------------------------------
def _dijkstra(self, src: int, dst: int) -> Optional[List[int]]:
dist: Dict[int, float] = {src: 0.0}
prev: Dict[int, int] = {}
pq = [(0.0, src)]
visited: set = set()
while pq:
d, u = heapq.heappop(pq)
if u in visited:
continue
visited.add(u)
if u == dst:
path, cur = [], dst
while cur != src:
path.append(cur)
cur = prev[cur]
path.append(src)
return path[::-1]
for v, w in self.graph.get(u, []):
nd = d + w
if nd < dist.get(v, float("inf")):
dist[v] = nd
prev[v] = u
heapq.heappush(pq, (nd, v))
return None # no path
def plan(
self,
start_z: np.ndarray,
goal_z: np.ndarray,
) -> Optional[Tuple[List[int], np.ndarray]]:
"""
Returns (node_indices_along_path, latent_path) or None if no path.
"""
if self.tree is None or self.latents is None:
raise RuntimeError("Call build() before plan().")
_, (n_start,) = self.tree.query(start_z.reshape(1, -1), k=1)
_, (n_goal,) = self.tree.query(goal_z.reshape(1, -1), k=1)
n_start, n_goal = int(n_start), int(n_goal)
if n_start == n_goal:
return [n_start], self.latents[[n_start]]
path = self._dijkstra(n_start, n_goal)
if path is None:
return None
return path, self.latents[path]
def get_actions_for_path(self, path: List[int]) -> List[Optional[np.ndarray]]:
"""Return the training action for each consecutive node pair in path."""
actions = []
for i in range(len(path) - 1):
key = (path[i], path[i + 1])
rev = (path[i + 1], path[i])
if key in self.valid_transitions:
actions.append(self.valid_transitions[key])
elif rev in self.valid_transitions:
actions.append(self.valid_transitions[rev])
else:
actions.append(None)
return actions
# ------------------------------------------------------------------
# Persistence
# ------------------------------------------------------------------
def save(self, path: str) -> None:
import pickle
state = {
"k": self.k,
"latents": self.latents,
"graph": self.graph,
"valid_transitions": self.valid_transitions,
}
with open(path, "wb") as f:
pickle.dump(state, f)
@classmethod
def load(cls, path: str) -> "LSR":
import pickle
with open(path, "rb") as f:
state = pickle.load(f)
lsr = cls(k=state["k"])
lsr.latents = state["latents"]
lsr.graph = state["graph"]
lsr.valid_transitions = state["valid_transitions"]
if lsr.latents is not None:
lsr.tree = KDTree(lsr.latents)
return lsr