""" """ from typing import Any from typing import Callable from typing import ParamSpec # spaces import REMOVED — spaces.GPU / spaces.aoti_* are HuggingFace ZeroGPU-only APIs. # They do not exist outside of HF infrastructure and would crash on a local VPS. # Replaced below with standard torch.compile() which gives equivalent or better # performance on a dedicated GPU (RTX 4090 / PRO 6000 Blackwell). import torch from torch.utils._pytree import tree_map P = ParamSpec('P') TRANSFORMER_IMAGE_SEQ_LENGTH_DIM = torch.export.Dim('image_seq_length') TRANSFORMER_TEXT_SEQ_LENGTH_DIM = torch.export.Dim('text_seq_length') TRANSFORMER_DYNAMIC_SHAPES = { 'hidden_states': { 1: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM, }, 'encoder_hidden_states': { 1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM, }, 'encoder_hidden_states_mask': { 1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM, }, 'image_rotary_emb': ({ 0: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM, }, { 0: TRANSFORMER_TEXT_SEQ_LENGTH_DIM, }), } INDUCTOR_CONFIGS = { 'conv_1x1_as_mm': True, 'epilogue_fusion': False, 'coordinate_descent_tuning': True, 'coordinate_descent_check_all_directions': True, 'max_autotune': True, 'triton.cudagraphs': True, } def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs): # ------------------------------------------------------------------------- # CHANGE 1 of 1: spaces.GPU / spaces.aoti_capture / spaces.aoti_compile / # spaces.aoti_apply REMOVED. # # What the original did on HuggingFace ZeroGPU: # 1. spaces.aoti_capture() — traced one forward pass to record all inputs # 2. torch.export.export() — exported the transformer to a portable graph # 3. spaces.aoti_compile() — compiled via torch inductor with INDUCTOR_CONFIGS # 4. spaces.aoti_apply() — swapped the live module for the compiled artifact # # Local replacement: # torch.compile() with the same inductor backend and the same INDUCTOR_CONFIGS # achieves the identical goal — persistent kernel fusion, cudagraphs, and # coordinate-descent autotuning — without any HF-specific infrastructure. # The compiled transformer is set back onto the pipeline in-place, so every # call to pipeline() automatically uses the optimised version, exactly as # the original spaces.aoti_apply() did. # # NOTE: The first inference call after optimize_pipeline_() will be slow # while torch.compile() traces and compiles the kernels. Subsequent calls # are fast. This is the same warm-up behaviour as the original AOTI path. # # To enable Float8 quantisation (commented out in the original too), uncomment: # quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig()) # ------------------------------------------------------------------------- # quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig()) pipeline.transformer = torch.compile( pipeline.transformer, backend="inductor", options=INDUCTOR_CONFIGS, fullgraph=False, # False = safe for models with dynamic control flow dynamic=True, # honours the dynamic sequence-length dims declared above )