"""
Twill's Main Search Procedure (Algorithm 1 from the paper).

Combines Phase 1 (ILP modulo scheduling) and Phase 2 (SMT joint SWP+WS)
in an iterative search over initiation intervals and schedule lengths.

Algorithm 1: Twill(G)
    I ← 0
    while true:
        I ← I + 1
        M ← Optimal-Modulo-Schedule(G, I)
        if M = failure: continue
        L ← Len(M)
        while ⌈L/I⌉ = ⌈Len(M)/I⌉:
            (M*, A*) ← SWP-and-WS(G, M, I, L)
            if (M*, A*) = failure: L ← L+1; continue
            return (M*, I, A*)
"""

import time
import math
from typing import Optional, Tuple

from twill.graph import DependenceGraph
from twill.cost_normalization import normalize_costs
from twill.modulo_scheduler import optimal_modulo_schedule, ModuloScheduleResult, validate_schedule
from twill.smt_joint import swp_and_ws, JointSWPWSResult


class TwillResult:
    """Complete result from the Twill solver.

    Attributes:
        joint_result: The JointSWPWSResult containing schedule and warp assignment
        initial_modulo_schedule: The Phase 1 modulo schedule that seeded the search
        normalized_costs: The cost normalization result (if used)
        solve_time_seconds: Total elapsed time for the solver, in seconds
        iterations_tried: Number of I values tried before finding a solution
    """

    def __init__(
        self,
        joint_result: JointSWPWSResult,
        initial_schedule: ModuloScheduleResult,
        solve_time: float,
        iterations_tried: int,
        normalized_costs: Optional[dict] = None,
    ):
        self.joint_result = joint_result
        self.initial_modulo_schedule = initial_schedule
        self.solve_time_seconds = solve_time
        self.iterations_tried = iterations_tried
        self.normalized_costs = normalized_costs

    @property
    def schedule(self):
        """Final schedule, forwarded from ``joint_result``."""
        return self.joint_result.schedule

    @property
    def I(self):
        """Initiation interval, forwarded from ``joint_result``."""
        return self.joint_result.I

    @property
    def warp_assignment(self):
        """Warp assignment, forwarded from ``joint_result``."""
        return self.joint_result.warp_assignment

    def __repr__(self):
        return (
            f"TwillResult(\n"
            f" solve_time={self.solve_time_seconds:.2f}s\n"
            f" iterations_tried={self.iterations_tried}\n"
            f" {self.joint_result}\n"
            f")"
        )


def _maybe_normalize_costs(
    graph: DependenceGraph,
    cost_norm_U: int,
    verbose: bool,
) -> Optional[dict]:
    """Run cost normalization (Section 5.2) when instruction costs are large.

    Returns the normalized cost mapping produced by ``normalize_costs``, or
    None when the graph has no instructions or the largest cost is already
    within ``cost_norm_U // number_of_instructions``.
    """
    # Map each instruction name to its cycle count (edges are not included).
    cost_items = {v.name: v.cycles for v in graph.V}

    # Nothing to normalize for an empty graph. The guard also keeps max()
    # from raising on an empty sequence and avoids a zero divisor below.
    if not cost_items:
        return None

    if max(cost_items.values()) <= cost_norm_U // len(cost_items):
        return None

    if verbose:
        print(f"Cost Normalization (U={cost_norm_U}):")
        print(f" Original costs: {cost_items}")

    normalized_costs, distortion = normalize_costs(cost_items, U=cost_norm_U)

    if verbose:
        print(f" Normalized costs: {normalized_costs}")
        print(f" Distortion F: {distortion}")
        print()

    # Note: In a full implementation, we would rebuild the graph with
    # normalized costs. For this implementation, costs are typically
    # already small (from the input specification) so normalization
    # is primarily for real GPU cycle counts (e.g., ~1000 cycles for WGMMA).
    return normalized_costs


def twill_solve(
    graph: DependenceGraph,
    max_I: int = 20,
    enable_cost_normalization: bool = True,
    cost_norm_U: int = 300,
    enable_memory_constraints: bool = True,
    enable_warp_constraints: bool = True,
    modulo_solver_timeout: int = 120,
    smt_solver_timeout_ms: int = 120000,
    verbose: bool = True,
) -> Optional[TwillResult]:
    """Run the full Twill search procedure.

    This is the main entry point implementing Algorithm 1 from the paper.
    Unlike the pseudocode, the search over I starts at
    ``max(1, graph.compute_min_initiation_interval())`` and stops after
    ``max_I`` instead of looping forever.

    The normalized costs are only computed and reported on the result; the
    graph passed to the solvers is not rebuilt with them.

    Args:
        graph: Loop dependence graph with machine description
        max_I: Maximum initiation interval to search up to
        enable_cost_normalization: Apply cost normalization before solving
        cost_norm_U: Upper bound for cost normalization (Section 5.2)
        enable_memory_constraints: Include memory capacity constraints (Section 4.2)
        enable_warp_constraints: Include warp assignment constraints (Section 4.3)
        modulo_solver_timeout: Timeout for Phase 1 ILP solver (seconds)
        smt_solver_timeout_ms: Timeout for Phase 2 SMT solver (milliseconds)
        verbose: Print progress information

    Returns:
        TwillResult if a valid schedule is found, None otherwise
    """
    # perf_counter is monotonic, so the reported duration cannot be skewed
    # by system clock adjustments during a long solve.
    start_time = time.perf_counter()

    if verbose:
        print("=" * 60)
        print("Twill Solver v0.1")
        print("=" * 60)
        print(f"Graph: {graph}")
        print(f"Instructions: {[v.name for v in graph.V]}")
        print(f"Edges: {graph.E}")
        print(f"Machine: {graph.machine.name}")
        print(f"Functional units: {graph.machine.functional_units}")
        print(f"Capacities: {graph.machine.capacities}")
        print()

    # Step 0: Cost normalization (Section 5.2)
    normalized_costs_dict = None
    if enable_cost_normalization:
        normalized_costs_dict = _maybe_normalize_costs(graph, cost_norm_U, verbose)

    # Compute resource lower bound on I
    min_I = graph.compute_min_initiation_interval()
    if verbose:
        print(f"Minimum I (resource bound): {min_I}")
        print()

    iterations_tried = 0

    # Algorithm 1: Main search loop
    for I in range(max(1, min_I), max_I + 1):
        iterations_tried += 1

        if verbose:
            print(f"--- Trying I = {I} ---")
            # Phase 1: Optimal Modulo Schedule
            print(" Phase 1: ILP Modulo Scheduling...")

        M = optimal_modulo_schedule(
            graph,
            I,
            solver_time_limit=modulo_solver_timeout,
            verbose=False,
        )

        if M is None:
            if verbose:
                print(f" Phase 1: INFEASIBLE for I={I}")
            continue

        if verbose:
            print(f" Phase 1: Found M with L={M.length}, copies={M.num_copies}")
            print(f" Schedule: {M.schedule}")

        # Validate the Phase 1 schedule. A failure is reported even when
        # verbose is off, but the search still proceeds with M.
        valid, violations = validate_schedule(graph, M)
        if not valid:
            print(" WARNING: Schedule validation failed!")
            for violation in violations:
                print(f" {violation}")

        # Phase 2: Joint SWP + WS, searching over L. L grows only while the
        # number of overlapping copies stays the same as in the initial M.
        L = M.length
        initial_num_copies = M.num_copies

        while math.ceil(L / I) == initial_num_copies:
            if verbose:
                print(f" Phase 2: SMT Joint SWP+WS with L={L}...")

            result = swp_and_ws(
                graph=graph,
                initial_schedule=M,
                I=I,
                L=L,
                enable_memory_constraints=enable_memory_constraints,
                enable_warp_constraints=enable_warp_constraints,
                timeout_ms=smt_solver_timeout_ms,
                verbose=verbose,
            )

            if result is not None:
                solve_time = time.perf_counter() - start_time

                if verbose:
                    print()
                    print("=" * 60)
                    print(f"SOLUTION FOUND in {solve_time:.2f}s")
                    print("=" * 60)
                    print(f" Initiation Interval I = {I}")
                    print(f" Schedule Length L = {L}")
                    print(f" Overlapping copies = {result.num_copies}")
                    print(f" Schedule M*: {result.schedule}")
                    print(f" {result.warp_assignment}")

                return TwillResult(
                    joint_result=result,
                    initial_schedule=M,
                    solve_time=solve_time,
                    iterations_tried=iterations_tried,
                    normalized_costs=normalized_costs_dict,
                )

            if verbose:
                print(f" Phase 2: UNSAT for L={L}, trying L={L+1}")
            L += 1

        if verbose:
            print(f" Exhausted L search for I={I} (would change num_copies)")

    if verbose:
        print(f"\nNo solution found up to I={max_I}")

    return None