File size: 5,757 Bytes
1624b35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""
Latent Space Roadmap (LSR).

Graph construction:
  1. Encode all training frames → latent cloud Z.
  2. For each node, find its k nearest neighbours.
  3. Keep an edge (i, j) only when the pair appears as consecutive
     timesteps in at least one training trajectory (i.e. a real action
     connects them).  This prevents the planner from taking shortcuts
     through physically unreachable states.

Planning:
  Given z_start and z_goal, snap to nearest graph nodes then run
  Dijkstra weighted by Euclidean latent distance.
"""

import heapq
from typing import Dict, List, Optional, Tuple

import numpy as np
from scipy.spatial import KDTree


class LSR:
    """Latent Space Roadmap: a sparse graph over encoded training frames.

    Nodes are the rows of the latent array passed to :meth:`build`.  An
    undirected edge joins nodes i and j when one is among the other's ``k``
    nearest latent neighbours AND (i, j) or (j, i) was a consecutive pair
    in some training episode.  :meth:`plan` snaps query latents to their
    nearest nodes and runs Dijkstra with Euclidean latent edge weights.
    """

    def __init__(self, k: int = 10):
        """
        k : number of nearest latent neighbours examined per node when the
            graph is built.
        """
        self.k = k
        self.latents: Optional[np.ndarray] = None
        self.tree: Optional[KDTree] = None
        # adjacency: node -> [(neighbour, euclidean_weight), ...]; each
        # undirected edge appears exactly once in each endpoint's list.
        self.graph: Dict[int, List[Tuple[int, float]]] = {}
        # valid_transitions: (i, j) -> actions[i], recorded when frame j
        # directly followed frame i in a training episode.
        self.valid_transitions: Dict[Tuple[int, int], np.ndarray] = {}

    # ------------------------------------------------------------------
    # Graph construction
    # ------------------------------------------------------------------

    def build(
        self,
        latents: np.ndarray,
        episodes: List[List[int]],
        actions: np.ndarray,
    ) -> None:
        """
        Build the roadmap from encoded training frames.

        latents  : (N, z_dim)  — encoded training frames
        episodes : list of index-lists, one list per trajectory
        actions  : (N, action_dim) — action[i] moves frame i to the frame
                   that follows it in its episode

        Any previously built graph is discarded.  An edge (i, j) is kept
        only when j is among the k nearest neighbours of i (or vice versa)
        and (i, j) or (j, i) is a recorded consecutive transition.
        """
        self.latents = latents.copy()
        self.tree = KDTree(self.latents)
        self.graph = {}
        self.valid_transitions = {}

        # Record every consecutive transition; if the same (i, j) pair
        # occurs in several episodes, the last one seen wins.
        for ep in episodes:
            for i, j in zip(ep[:-1], ep[1:]):
                self.valid_transitions[(i, j)] = actions[i]

        n = len(self.latents)
        if n < 2:
            return  # zero or one frame: nothing to connect

        valid_set = set(self.valid_transitions)
        k_query = min(self.k + 1, n)  # +1 because each point finds itself
        # One batched query.  Reshape because scipy drops the neighbour
        # axis when k == 1, which made the per-point slicing crash.
        _, nbr_idx = self.tree.query(self.latents, k=k_query)
        nbr_idx = np.asarray(nbr_idx).reshape(n, -1)

        seen_edges = set()
        for i in range(n):
            # Exclude i explicitly instead of assuming it is the first hit:
            # with duplicate latents another index can tie at distance 0.
            nbrs = [int(j) for j in nbr_idx[i] if j != i][: self.k]
            for j in nbrs:
                if (i, j) not in valid_set and (j, i) not in valid_set:
                    continue
                edge = (i, j) if i < j else (j, i)
                if edge in seen_edges:
                    continue  # already added from the other endpoint
                seen_edges.add(edge)
                w = float(np.linalg.norm(self.latents[i] - self.latents[j]))
                self.graph.setdefault(i, []).append((j, w))
                self.graph.setdefault(j, []).append((i, w))

    # ------------------------------------------------------------------
    # Planning
    # ------------------------------------------------------------------

    def _dijkstra(self, src: int, dst: int) -> Optional[List[int]]:
        """Return the shortest node path src -> dst, or None if unreachable.

        Uses a lazy-deletion heap: stale queue entries are skipped via the
        ``visited`` set.
        """
        dist: Dict[int, float] = {src: 0.0}
        prev: Dict[int, int] = {}
        pq = [(0.0, src)]
        visited: set = set()

        while pq:
            d, u = heapq.heappop(pq)
            if u in visited:
                continue
            visited.add(u)

            if u == dst:
                # Walk predecessor links back to src, then reverse.
                path, cur = [], dst
                while cur != src:
                    path.append(cur)
                    cur = prev[cur]
                path.append(src)
                return path[::-1]

            for v, w in self.graph.get(u, []):
                nd = d + w
                if nd < dist.get(v, float("inf")):
                    dist[v] = nd
                    prev[v] = u
                    heapq.heappush(pq, (nd, v))

        return None  # no path

    def plan(
        self,
        start_z: np.ndarray,
        goal_z: np.ndarray,
    ) -> Optional[Tuple[List[int], np.ndarray]]:
        """
        Returns (node_indices_along_path, latent_path) or None if no path.

        start_z / goal_z are snapped to their nearest graph nodes first; if
        both snap to the same node the path is that single node.

        Raises RuntimeError if build() (or load()) has not been called.
        """
        if self.tree is None or self.latents is None:
            raise RuntimeError("Call build() before plan().")

        start = np.asarray(start_z).reshape(1, -1)
        goal = np.asarray(goal_z).reshape(1, -1)
        _, (n_start,) = self.tree.query(start, k=1)
        _, (n_goal,) = self.tree.query(goal, k=1)
        n_start, n_goal = int(n_start), int(n_goal)

        if n_start == n_goal:
            return [n_start], self.latents[[n_start]]

        path = self._dijkstra(n_start, n_goal)
        if path is None:
            return None

        return path, self.latents[path]

    def get_actions_for_path(self, path: List[int]) -> List[Optional[np.ndarray]]:
        """Return the training action for each consecutive node pair in path.

        Entry t corresponds to the step path[t] -> path[t + 1]; it is None
        when neither direction of that pair was a recorded transition.

        NOTE(review): graph edges are undirected, so a path may cross a
        recorded transition (b, a) in the reverse direction a -> b.  In
        that case the recorded forward action (for b -> a) is returned,
        which probably does not move a -> b.  Confirm callers expect this.
        """
        actions = []
        for a, b in zip(path[:-1], path[1:]):
            if (a, b) in self.valid_transitions:
                actions.append(self.valid_transitions[(a, b)])
            elif (b, a) in self.valid_transitions:
                actions.append(self.valid_transitions[(b, a)])
            else:
                actions.append(None)
        return actions

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save(self, path: str) -> None:
        """Pickle k, latents, graph and valid_transitions to ``path``.

        The KD-tree is not stored; load() rebuilds it from the latents.
        """
        import pickle
        state = {
            "k": self.k,
            "latents": self.latents,
            "graph": self.graph,
            "valid_transitions": self.valid_transitions,
        }
        with open(path, "wb") as f:
            pickle.dump(state, f)

    @classmethod
    def load(cls, path: str) -> "LSR":
        """Restore an LSR written by save() and rebuild its KD-tree.

        WARNING: pickle.load can execute arbitrary code embedded in the
        file; only load files from a trusted source.
        """
        import pickle
        with open(path, "rb") as f:
            state = pickle.load(f)
        lsr = cls(k=state["k"])
        lsr.latents = state["latents"]
        lsr.graph = state["graph"]
        lsr.valid_transitions = state["valid_transitions"]
        if lsr.latents is not None:
            lsr.tree = KDTree(lsr.latents)
        return lsr