# picgen/optimization.py
# Provenance (HuggingFace upload metadata): uploaded by devindevine,
# "Upload 22 files", commit 02be60b (verified).
"""
"""
from typing import Any
from typing import Callable
from typing import ParamSpec
# spaces import REMOVED — spaces.GPU / spaces.aoti_* are HuggingFace ZeroGPU-only APIs.
# They do not exist outside of HF infrastructure and would crash on a local VPS.
# Replaced below with standard torch.compile() which gives equivalent or better
# performance on a dedicated GPU (RTX 4090 / PRO 6000 Blackwell).
import torch
from torch.utils._pytree import tree_map
P = ParamSpec('P')
TRANSFORMER_IMAGE_SEQ_LENGTH_DIM = torch.export.Dim('image_seq_length')
TRANSFORMER_TEXT_SEQ_LENGTH_DIM = torch.export.Dim('text_seq_length')
TRANSFORMER_DYNAMIC_SHAPES = {
'hidden_states': {
1: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
},
'encoder_hidden_states': {
1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
},
'encoder_hidden_states_mask': {
1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
},
'image_rotary_emb': ({
0: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM,
}, {
0: TRANSFORMER_TEXT_SEQ_LENGTH_DIM,
}),
}
INDUCTOR_CONFIGS = {
'conv_1x1_as_mm': True,
'epilogue_fusion': False,
'coordinate_descent_tuning': True,
'coordinate_descent_check_all_directions': True,
'max_autotune': True,
'triton.cudagraphs': True,
}
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None:
    """Compile ``pipeline.transformer`` in place with ``torch.compile``.

    Local replacement for the HuggingFace ZeroGPU ``spaces.aoti_*`` path
    (capture -> torch.export -> inductor compile -> apply): ``torch.compile``
    with the inductor backend and the same ``INDUCTOR_CONFIGS`` achieves the
    same goal — kernel fusion, cudagraphs, coordinate-descent autotuning —
    without any HF-specific infrastructure.

    Args:
        pipeline: Diffusers-style pipeline exposing a ``transformer`` module.
            Mutated in place, so every subsequent ``pipeline(...)`` call uses
            the compiled transformer.
        *args, **kwargs: Accepted only for signature compatibility with the
            original AOTI capture path, which required one example call. They
            are intentionally unused here: ``torch.compile`` traces lazily, so
            the first inference call after this function is slow (warm-up)
            and subsequent calls are fast — the same warm-up behaviour as the
            original AOTI path.

    Returns:
        None. The pipeline is modified in place.
    """
    transformer = pipeline.transformer
    # Idempotence guard: compiling an already-compiled module would nest
    # OptimizedModule wrappers and add redundant tracing overhead. A compiled
    # module exposes its wrapped original as `_orig_mod`.
    if getattr(transformer, "_orig_mod", None) is not None:
        return
    # Optional Float8 quantization hook (disabled, as in the original):
    # quantize_(transformer, Float8DynamicActivationFloat8WeightConfig())
    pipeline.transformer = torch.compile(
        transformer,
        backend="inductor",
        options=INDUCTOR_CONFIGS,
        # fullgraph=False tolerates graph breaks from dynamic control flow
        # inside the transformer instead of hard-failing.
        fullgraph=False,
        # dynamic=True honours the variable image/text sequence lengths
        # declared in TRANSFORMER_DYNAMIC_SHAPES.
        dynamic=True,
    )