# Anshuman Singh
# Initial implementation of minimum-violation LTL planner
# ba4f7c6
"""
Product automaton: GridWorld × Büchi_1 × ... × Büchi_n
A product state is (grid_pos, aut_state_1, ..., aut_state_n).
The reachable part of the graph is built by BFS from the initial state when
a ProductGraph is constructed; Tarjan's SCC algorithm can then be run on it.
"""
from collections import defaultdict, deque
from typing import Dict, FrozenSet, List, Optional, Set, Tuple
from .grid_world import GridWorld
from .automata import BuchiAut
# A product state is a tuple: (grid_pos, q1, q2, ..., qn) -- the grid position
# followed by the current state of each automaton, in automaton order.
ProductState = tuple
class ProductGraph:
    """Reachable product of a grid world with one or more automata.

    A product state is ``(grid_pos, q_1, ..., q_n)``.  States are numbered in
    BFS discovery order, so the initial state always has index 0; every graph
    algorithm below works on those integer indices.

    Attributes set by the constructor:
        grid        -- object providing ``start``, ``successors(pos)`` and
                       ``label(pos)``.
        automata    -- objects providing ``initial``, ``step(q, label)`` and
                       ``is_accepting(q)``; one per specification.
        n_aut       -- number of automata.
        initial     -- the initial product state.
        states      -- index -> product state.
        state_index -- product state -> index.
        adj / radj  -- forward / reverse adjacency lists over indices.
    """

    def __init__(self, grid: GridWorld, automata: List[BuchiAut]) -> None:
        """Store the components and build the reachable product graph."""
        self.grid = grid
        self.automata = automata
        self.n_aut = len(automata)

        # Initial product state.
        # NOTE(review): the automata start in ``initial`` and are only ever
        # advanced on the label of the cell being *entered* (see _build), so
        # the start cell's own label is never fed to them -- confirm this is
        # the intended semantics.
        init_aut = tuple(a.initial for a in automata)
        self.initial: ProductState = (grid.start,) + init_aut

        # Build graph
        self.states: List[ProductState] = []
        self.state_index: Dict[ProductState, int] = {}
        self.adj: Dict[int, List[int]] = defaultdict(list)   # forward edges
        self.radj: Dict[int, List[int]] = defaultdict(list)  # reverse edges
        self._build()

    # ── graph construction ────────────────────────────────────────────────────
    def _build(self) -> None:
        """Explore, by BFS, every product state reachable from ``initial``.

        A grid move is dropped when any automaton has no transition
        (``step`` returns None) on the label of the destination cell.
        """
        queue = deque([self.initial])
        self._add_state(self.initial)
        while queue:
            ps = queue.popleft()
            src_idx = self.state_index[ps]
            grid_pos, aut_states = ps[0], ps[1:]
            for _, next_pos in self.grid.successors(grid_pos):
                # Transition happens when entering the next cell, so the
                # automata read the label of the destination.
                next_aut = self._advance_automata(
                    aut_states, self.grid.label(next_pos)
                )
                if next_aut is None:
                    continue
                next_ps: ProductState = (next_pos,) + next_aut
                if next_ps not in self.state_index:
                    self._add_state(next_ps)
                    queue.append(next_ps)
                dst_idx = self.state_index[next_ps]
                self.adj[src_idx].append(dst_idx)
                self.radj[dst_idx].append(src_idx)

    def _advance_automata(self, aut_states: tuple, label) -> Optional[tuple]:
        """Step every automaton on ``label``; None if any of them blocks."""
        advanced = []
        for aut, q in zip(self.automata, aut_states):
            nq = aut.step(q, label)
            if nq is None:
                return None
            advanced.append(nq)
        return tuple(advanced)

    def _add_state(self, ps: ProductState) -> int:
        """Register ``ps`` under the next free index and return that index."""
        idx = len(self.states)
        self.states.append(ps)
        self.state_index[ps] = idx
        return idx

    # ── Tarjan's SCC ─────────────────────────────────────────────────────────
    def compute_sccs(self) -> List[List[int]]:
        """Returns list of SCCs (each a list of state indices), largest first.

        Tarjan's algorithm, written iteratively with an explicit work stack so
        that deep graphs neither hit Python's recursion limit nor require
        changing it.  SCCs of equal size keep the order in which Tarjan's
        algorithm emits them (the final sort is stable).
        """
        index: Dict[int, int] = {}
        lowlink: Dict[int, int] = {}
        on_stack: Set[int] = set()
        stack: List[int] = []
        sccs: List[List[int]] = []
        counter = 0

        for root in range(len(self.states)):
            if root in index:
                continue
            index[root] = lowlink[root] = counter
            counter += 1
            stack.append(root)
            on_stack.add(root)
            # Each frame is (vertex, iterator over its remaining successors);
            # the iterator remembers where to resume after a "recursive call".
            work = [(root, iter(self.adj.get(root, ())))]
            while work:
                v, successors = work[-1]
                descended = False
                for w in successors:
                    if w not in index:
                        index[w] = lowlink[w] = counter
                        counter += 1
                        stack.append(w)
                        on_stack.add(w)
                        work.append((w, iter(self.adj.get(w, ()))))
                        descended = True
                        break
                    if w in on_stack:
                        lowlink[v] = min(lowlink[v], index[w])
                if descended:
                    continue

                # v is fully explored: "return" to its DFS parent.
                work.pop()
                if work:
                    parent = work[-1][0]
                    lowlink[parent] = min(lowlink[parent], lowlink[v])
                if lowlink[v] == index[v]:
                    # v is the root of an SCC: pop it off the vertex stack.
                    scc = []
                    while True:
                        w = stack.pop()
                        on_stack.discard(w)
                        scc.append(w)
                        if w == v:
                            break
                    sccs.append(scc)

        sccs.sort(key=len, reverse=True)
        return sccs

    # ── SCC reward analysis ───────────────────────────────────────────────────
    def scc_satisfied_specs(self, scc: List[int], rewards: List[float]) -> Tuple[float, Set[int]]:
        """
        For an SCC, compute which specs have their accepting states inside it.

        Spec ``i`` counts as satisfied when at least one state of ``scc`` has
        an accepting component for automaton ``i``.
        Returns (total_reward, set_of_satisfied_spec_indices), where
        total_reward is the sum of ``rewards[i]`` over the satisfied specs.
        """
        satisfied = {
            i
            for idx in scc
            for i, aut in enumerate(self.automata)
            # Component 0 of a product state is the grid position.
            if aut.is_accepting(self.states[idx][1 + i])
        }
        # Sum in index order so the floating-point result is reproducible.
        total = sum(rewards[i] for i in sorted(satisfied))
        return total, satisfied

    def is_nontrivial_scc(self, scc: List[int]) -> bool:
        """An SCC is nontrivial if it has >1 state, or 1 state with a self-loop.

        An empty list is reported as trivial.
        """
        if not scc:
            return False
        if len(scc) > 1:
            return True
        v = scc[0]
        return v in self.adj.get(v, ())

    # ── reachability ─────────────────────────────────────────────────────────
    def reachable_from_initial(self) -> Set[int]:
        """Return the indices of all states reachable from the initial state."""
        root = self.state_index[self.initial]
        seen = {root}
        queue = deque([root])
        while queue:
            v = queue.popleft()
            for w in self.adj.get(v, ()):
                if w not in seen:
                    seen.add(w)
                    queue.append(w)
        return seen

    def bfs_path(self, src: int, targets: Set[int]) -> Optional[List[int]]:
        """BFS from src to any state in targets. Returns list of state indices.

        The path starts at ``src`` and ends at the first target found; it is
        ``[src]`` when ``src`` is itself a target, and None when no target is
        reachable.
        """
        return self._shortest_path(src, targets)

    def find_cycle_through(self, scc_set: Set[int], required_accepting: List[Set[int]]) -> Optional[List[int]]:
        """
        Find a cycle within the SCC that passes through at least one accepting
        state for each required spec.

        Strategy: pick an arbitrary start state in ``scc_set`` and one
        checkpoint per entry of ``required_accepting`` that intersects
        ``scc_set`` (entries with no state inside the SCC are skipped), then
        chain shortest paths start -> cp0 -> cp1 -> ... -> start, never
        leaving ``scc_set``.

        Returns a list of state indices forming the cycle (first == last) that
        always traverses at least one edge, or None when ``scc_set`` is empty
        or no such cycle exists.
        """
        if not scc_set:
            return None
        start = next(iter(scc_set))

        # Build checkpoints: for each spec, one state in scc that is accepting
        checkpoints = []
        for acc_set in required_accepting:
            candidates = acc_set & scc_set
            if candidates:
                checkpoints.append(next(iter(candidates)))

        # Outbound part: start -> cp0 -> cp1 -> ... -> cpN
        path = [start]
        for checkpoint in checkpoints:
            seg = self._bfs_in_scc(path[-1], {checkpoint}, scc_set)
            if seg is None:
                return None
            path.extend(seg[1:])  # skip duplicate junction

        if len(path) == 1:
            # No checkpoint, or every checkpoint is the start state itself:
            # the walk has not moved yet, so close it with a genuine cycle
            # (a self-loop yields [start, start]) instead of returning the
            # edge-less path [start].
            return self._shortest_path(start, {start}, scc_set, allow_empty=False)

        closing = self._bfs_in_scc(path[-1], {start}, scc_set)
        if closing is None:
            return None
        path.extend(closing[1:])
        return path

    def _bfs_in_scc(self, src: int, targets: Set[int], scc_set: Set[int]) -> Optional[List[int]]:
        """BFS from src to any target, restricted to scc_set."""
        return self._shortest_path(src, targets, scc_set)

    def _shortest_path(
        self,
        src: int,
        targets: Set[int],
        allowed: Optional[Set[int]] = None,
        allow_empty: bool = True,
    ) -> Optional[List[int]]:
        """Shortest path (by edge count) from ``src`` to any state in ``targets``.

        ``allowed``, when given, restricts which states may be entered.
        With ``allow_empty`` true, ``[src]`` is returned if ``src`` is a
        target; otherwise the path must use at least one edge, which lets a
        caller ask for a cycle back to ``src``.  Returns None if unreachable.
        """
        if allow_empty and src in targets:
            return [src]
        parent: Dict[int, Optional[int]] = {src: None}
        queue = deque([src])
        while queue:
            v = queue.popleft()
            for w in self.adj.get(v, ()):
                if allowed is not None and w not in allowed:
                    continue
                if w in targets:
                    # Reconstruct: w, then v and its ancestors back to src.
                    path = [w]
                    cur: Optional[int] = v
                    while cur is not None:
                        path.append(cur)
                        cur = parent[cur]
                    path.reverse()
                    return path
                if w not in parent:
                    parent[w] = v
                    queue.append(w)
        return None