| """ |
| Product automaton: GridWorld Γ BΓΌchi_1 Γ ... Γ BΓΌchi_n |
| |
| A product state is (grid_pos, aut_state_1, ..., aut_state_n). |
| We build the graph lazily via BFS, then run Tarjan's SCC algorithm. |
| """ |
|
|
| from collections import defaultdict, deque |
| from typing import Dict, FrozenSet, List, Optional, Set, Tuple |
|
|
| from .grid_world import GridWorld |
| from .automata import BuchiAut |
|
|
|
|
| |
| ProductState = tuple |
|
|
|
|
| class ProductGraph: |
| def __init__(self, grid: GridWorld, automata: List[BuchiAut]): |
| self.grid = grid |
| self.automata = automata |
| self.n_aut = len(automata) |
|
|
| |
| init_aut = tuple(a.initial for a in automata) |
| self.initial: ProductState = (grid.start,) + init_aut |
|
|
| |
| self.states: List[ProductState] = [] |
| self.state_index: Dict[ProductState, int] = {} |
| self.adj: Dict[int, List[int]] = defaultdict(list) |
| self.radj: Dict[int, List[int]] = defaultdict(list) |
| self._build() |
|
|
| |
|
|
| def _build(self): |
| queue = deque([self.initial]) |
| self._add_state(self.initial) |
|
|
| while queue: |
| ps = queue.popleft() |
| src_idx = self.state_index[ps] |
| grid_pos = ps[0] |
| aut_states = ps[1:] |
|
|
| label = self.grid.label(grid_pos) |
|
|
| for _, next_pos in self.grid.successors(grid_pos): |
| next_label = self.grid.label(next_pos) |
| |
| |
| next_aut = [] |
| valid = True |
| for i, aut in enumerate(self.automata): |
| nq = aut.step(aut_states[i], next_label) |
| if nq is None: |
| valid = False |
| break |
| next_aut.append(nq) |
|
|
| if not valid: |
| continue |
|
|
| next_ps: ProductState = (next_pos,) + tuple(next_aut) |
| if next_ps not in self.state_index: |
| self._add_state(next_ps) |
| queue.append(next_ps) |
|
|
| dst_idx = self.state_index[next_ps] |
| self.adj[src_idx].append(dst_idx) |
| self.radj[dst_idx].append(src_idx) |
|
|
| def _add_state(self, ps: ProductState) -> int: |
| idx = len(self.states) |
| self.states.append(ps) |
| self.state_index[ps] = idx |
| return idx |
|
|
| |
|
|
| def compute_sccs(self) -> List[List[int]]: |
| """Returns list of SCCs (each a list of state indices), largest first.""" |
| n = len(self.states) |
| index_counter = [0] |
| stack = [] |
| lowlink = {} |
| index = {} |
| on_stack = {} |
| sccs = [] |
|
|
| def strongconnect(v): |
| index[v] = index_counter[0] |
| lowlink[v] = index_counter[0] |
| index_counter[0] += 1 |
| stack.append(v) |
| on_stack[v] = True |
|
|
| for w in self.adj[v]: |
| if w not in index: |
| strongconnect(w) |
| lowlink[v] = min(lowlink[v], lowlink[w]) |
| elif on_stack.get(w): |
| lowlink[v] = min(lowlink[v], index[w]) |
|
|
| if lowlink[v] == index[v]: |
| scc = [] |
| while True: |
| w = stack.pop() |
| on_stack[w] = False |
| scc.append(w) |
| if w == v: |
| break |
| sccs.append(scc) |
|
|
| import sys |
| sys.setrecursionlimit(100000) |
|
|
| for v in range(n): |
| if v not in index: |
| strongconnect(v) |
|
|
| return sccs |
|
|
| |
|
|
| def scc_satisfied_specs(self, scc: List[int], rewards: List[float]) -> Tuple[float, Set[int]]: |
| """ |
| For an SCC, compute which specs have their accepting states inside it. |
| Returns (total_reward, set_of_satisfied_spec_indices). |
| """ |
| satisfied = set() |
| for idx in scc: |
| ps = self.states[idx] |
| aut_states = ps[1:] |
| for i, aut in enumerate(self.automata): |
| if aut.is_accepting(aut_states[i]): |
| satisfied.add(i) |
|
|
| total = sum(rewards[i] for i in satisfied) |
| return total, satisfied |
|
|
| def is_nontrivial_scc(self, scc: List[int]) -> bool: |
| """An SCC is nontrivial if it has >1 state, or 1 state with a self-loop.""" |
| if len(scc) > 1: |
| return True |
| v = scc[0] |
| return v in self.adj[v] |
|
|
| |
|
|
| def reachable_from_initial(self) -> Set[int]: |
| visited = set() |
| queue = deque([self.state_index[self.initial]]) |
| while queue: |
| v = queue.popleft() |
| if v in visited: |
| continue |
| visited.add(v) |
| for w in self.adj[v]: |
| if w not in visited: |
| queue.append(w) |
| return visited |
|
|
| def bfs_path(self, src: int, targets: Set[int]) -> Optional[List[int]]: |
| """BFS from src to any state in targets. Returns list of state indices.""" |
| if src in targets: |
| return [src] |
| parent = {src: None} |
| queue = deque([src]) |
| while queue: |
| v = queue.popleft() |
| for w in self.adj[v]: |
| if w not in parent: |
| parent[w] = v |
| if w in targets: |
| |
| path = [] |
| cur = w |
| while cur is not None: |
| path.append(cur) |
| cur = parent[cur] |
| return list(reversed(path)) |
| queue.append(w) |
| return None |
|
|
| def find_cycle_through(self, scc_set: Set[int], required_accepting: List[Set[int]]) -> Optional[List[int]]: |
| """ |
| Find a cycle within the SCC that passes through at least one accepting |
| state for each required spec. |
| Returns a list of state indices forming the cycle (first == last). |
| """ |
| |
| |
| |
| |
|
|
| if not scc_set: |
| return None |
|
|
| start = next(iter(scc_set)) |
|
|
| |
| checkpoints = [] |
| for acc_set in required_accepting: |
| candidates = acc_set & scc_set |
| if candidates: |
| checkpoints.append(next(iter(candidates))) |
|
|
| if not checkpoints: |
| |
| if start in self.adj.get(start, []): |
| return [start, start] |
| |
| path = self._bfs_in_scc(start, {start}, scc_set) |
| return path |
|
|
| |
| waypoints = [start] + checkpoints + [start] |
| full_path = [] |
| for i in range(len(waypoints) - 1): |
| seg = self._bfs_in_scc(waypoints[i], {waypoints[i + 1]}, scc_set) |
| if seg is None: |
| return None |
| if full_path: |
| full_path.extend(seg[1:]) |
| else: |
| full_path.extend(seg) |
|
|
| return full_path |
|
|
| def _bfs_in_scc(self, src: int, targets: Set[int], scc_set: Set[int]) -> Optional[List[int]]: |
| """BFS from src to any target, restricted to scc_set.""" |
| if src in targets: |
| return [src] |
| parent = {src: None} |
| queue = deque([src]) |
| while queue: |
| v = queue.popleft() |
| for w in self.adj[v]: |
| if w in scc_set and w not in parent: |
| parent[w] = v |
| if w in targets: |
| path = [] |
| cur = w |
| while cur is not None: |
| path.append(cur) |
| cur = parent[cur] |
| return list(reversed(path)) |
| queue.append(w) |
| return None |
|
|