jiadisu Claude Opus 4.6 commited on
Commit ·
e6066e8
1
Parent(s): 1b389ac
Switch back to Docker SDK with local pkgs
Browse files- Dockerfile: CUDA 12.4 base image, install MagiCompiler from pkgs/,
flash-attn, and stable-audio whl
- README.md: sdk: docker
- app.py: remove spaces.GPU decorator
- pkgs/: MagiCompiler source + stable-audio whl
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +2 -0
- Dockerfile +30 -18
- README.md +1 -2
- app.py +1 -4
- pkgs/MagiCompiler/.gitignore +216 -0
- pkgs/MagiCompiler/.pre-commit-config.yaml +60 -0
- pkgs/MagiCompiler/LICENSE +201 -0
- pkgs/MagiCompiler/README.md +186 -0
- pkgs/MagiCompiler/docs/AutoCudaGraphDesign.md +174 -0
- pkgs/MagiCompiler/docs/Hunyuan15Benchmark.md +79 -0
- pkgs/MagiCompiler/docs/Wan2.2Benchmark.md +72 -0
- pkgs/MagiCompiler/docs/WhyMagiCompiler.md +246 -0
- pkgs/MagiCompiler/docs/WhyMagiDepyf.md +175 -0
- pkgs/MagiCompiler/docs/assets/submod_0_rank_0.pdf +3 -0
- pkgs/MagiCompiler/magi_compiler/__init__.py +17 -0
- pkgs/MagiCompiler/magi_compiler/_cache_data_cls.py +28 -0
- pkgs/MagiCompiler/magi_compiler/api.py +666 -0
- pkgs/MagiCompiler/magi_compiler/compile_artifacts.py +125 -0
- pkgs/MagiCompiler/magi_compiler/config.py +282 -0
- pkgs/MagiCompiler/magi_compiler/cuda/cudart.py +60 -0
- pkgs/MagiCompiler/magi_compiler/cuda_graph_mgr.py +931 -0
- pkgs/MagiCompiler/magi_compiler/joint_graph_partition.py +180 -0
- pkgs/MagiCompiler/magi_compiler/magi_backend.py +607 -0
- pkgs/MagiCompiler/magi_compiler/magi_compiler_base.py +219 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/__init__.py +21 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/__init__.py +19 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/__init__.py +22 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/decompile_context.py +53 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handler_registry.py +62 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/__init__.py +22 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/arithmetic.py +144 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/calls.py +200 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/containers.py +200 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/control_flow.py +273 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/load_store.py +262 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/stack_ops.py +84 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/instruction.py +129 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/source_emitter.py +153 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/decompiler.py +230 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/postprocess/__init__.py +35 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/postprocess/branch_dedup.py +99 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/postprocess/for_temps.py +57 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/postprocess/inline_temps.py +165 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/recompiler.py +53 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/demo_toy_example.py +54 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/inspect/__init__.py +57 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/inspect/dump_src.py +78 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/inspect/introspect.py +524 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/inspect/model.py +241 -0
- pkgs/MagiCompiler/magi_compiler/magi_depyf/inspect/result.py +51 -0
.gitattributes
CHANGED
|
@@ -4,3 +4,5 @@
|
|
| 4 |
*.png filter=lfs diff=lfs merge=lfs -text
|
| 5 |
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 6 |
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 4 |
*.png filter=lfs diff=lfs merge=lfs -text
|
| 5 |
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 6 |
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.pdf filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.whl filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
CHANGED
|
@@ -1,28 +1,46 @@
|
|
| 1 |
# =============================================================================
|
| 2 |
# HF Spaces Docker image for daVinci-MagiHuman
|
| 3 |
-
# Hardware: A100-80GB (
|
| 4 |
# =============================================================================
|
| 5 |
-
|
| 6 |
-
# - CUDA 12.4, cuDNN, Python 3.12, PyTorch 2.9
|
| 7 |
-
# - MagiCompiler (pre-installed)
|
| 8 |
-
# - Flash Attention 3 (Hopper) (pre-installed)
|
| 9 |
-
# =============================================================================
|
| 10 |
-
FROM sandai/magi-compiler:latest
|
| 11 |
|
| 12 |
ENV DEBIAN_FRONTEND=noninteractive
|
| 13 |
ENV PYTHONUNBUFFERED=1
|
| 14 |
ENV GRADIO_SERVER_NAME=0.0.0.0
|
| 15 |
ENV GRADIO_SERVER_PORT=7860
|
| 16 |
|
| 17 |
-
# System deps
|
| 18 |
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
WORKDIR /app
|
| 23 |
|
| 24 |
# ---------------------------------------------------------------------------
|
| 25 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
# ---------------------------------------------------------------------------
|
| 27 |
COPY requirements.txt requirements-nodeps.txt ./
|
| 28 |
RUN pip install --no-cache-dir -r requirements.txt && \
|
|
@@ -36,16 +54,10 @@ COPY inference/ inference/
|
|
| 36 |
COPY example/ example/
|
| 37 |
COPY app.py .
|
| 38 |
|
| 39 |
-
# ---------------------------------------------------------------------------
|
| 40 |
# Model weights are downloaded at runtime from HF Hub.
|
| 41 |
-
#
|
| 42 |
-
#
|
| 43 |
-
# Persistent storage (/data) is recommended on HF Spaces so weights survive
|
| 44 |
-
# container restarts. Enable it in Space settings → "Persistent storage".
|
| 45 |
-
# ---------------------------------------------------------------------------
|
| 46 |
ENV MODEL_ROOT=/data/models
|
| 47 |
|
| 48 |
-
# HF Spaces requires the app to listen on port 7860
|
| 49 |
EXPOSE 7860
|
| 50 |
|
| 51 |
CMD ["python", "app.py"]
|
|
|
|
| 1 |
# =============================================================================
|
| 2 |
# HF Spaces Docker image for daVinci-MagiHuman
|
| 3 |
+
# Hardware: A100-80GB (recommended)
|
| 4 |
# =============================================================================
|
| 5 |
+
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
ENV DEBIAN_FRONTEND=noninteractive
|
| 8 |
ENV PYTHONUNBUFFERED=1
|
| 9 |
ENV GRADIO_SERVER_NAME=0.0.0.0
|
| 10 |
ENV GRADIO_SERVER_PORT=7860
|
| 11 |
|
| 12 |
+
# System deps
|
| 13 |
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 14 |
+
python3.12 python3.12-dev python3.12-venv python3-pip \
|
| 15 |
+
git ffmpeg libsndfile1 ninja-build && \
|
| 16 |
+
rm -rf /var/lib/apt/lists/* && \
|
| 17 |
+
ln -sf /usr/bin/python3.12 /usr/bin/python && \
|
| 18 |
+
ln -sf /usr/bin/python3.12 /usr/bin/python3
|
| 19 |
|
| 20 |
WORKDIR /app
|
| 21 |
|
| 22 |
# ---------------------------------------------------------------------------
|
| 23 |
+
# PyTorch (must be installed first — MagiCompiler build depends on it)
|
| 24 |
+
# ---------------------------------------------------------------------------
|
| 25 |
+
RUN pip install --no-cache-dir --upgrade pip && \
|
| 26 |
+
pip install --no-cache-dir torch torchvision torchaudio \
|
| 27 |
+
--index-url https://download.pytorch.org/whl/cu124
|
| 28 |
+
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
# Local packages: MagiCompiler + stable-audio whl
|
| 31 |
+
# ---------------------------------------------------------------------------
|
| 32 |
+
COPY pkgs/ pkgs/
|
| 33 |
+
RUN pip install -e ./pkgs/MagiCompiler \
|
| 34 |
+
--no-build-isolation --config-settings editable_mode=compat && \
|
| 35 |
+
pip install --no-cache-dir pkgs/magife_stable_audio_open-1.0.0+mav.1-py3-none-any.whl
|
| 36 |
+
|
| 37 |
+
# ---------------------------------------------------------------------------
|
| 38 |
+
# Flash Attention (pre-built wheel for CUDA 12.4 + PyTorch 2.9)
|
| 39 |
+
# ---------------------------------------------------------------------------
|
| 40 |
+
RUN pip install --no-cache-dir flash-attn --no-build-isolation
|
| 41 |
+
|
| 42 |
+
# ---------------------------------------------------------------------------
|
| 43 |
+
# Project Python dependencies
|
| 44 |
# ---------------------------------------------------------------------------
|
| 45 |
COPY requirements.txt requirements-nodeps.txt ./
|
| 46 |
RUN pip install --no-cache-dir -r requirements.txt && \
|
|
|
|
| 54 |
COPY example/ example/
|
| 55 |
COPY app.py .
|
| 56 |
|
|
|
|
| 57 |
# Model weights are downloaded at runtime from HF Hub.
|
| 58 |
+
# Enable "Persistent storage" in Space settings so /data survives restarts.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
ENV MODEL_ROOT=/data/models
|
| 60 |
|
|
|
|
| 61 |
EXPOSE 7860
|
| 62 |
|
| 63 |
CMD ["python", "app.py"]
|
README.md
CHANGED
|
@@ -3,8 +3,7 @@ title: daVinci-MagiHuman
|
|
| 3 |
emoji: 🎬
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: purple
|
| 6 |
-
sdk:
|
| 7 |
-
sdk_version: 5.23.0
|
| 8 |
app_port: 7860
|
| 9 |
---
|
| 10 |
|
|
|
|
| 3 |
emoji: 🎬
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: purple
|
| 6 |
+
sdk: docker
|
|
|
|
| 7 |
app_port: 7860
|
| 8 |
---
|
| 9 |
|
app.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
"""
|
| 3 |
Gradio frontend for daVinci-MagiHuman distilled model.
|
| 4 |
|
| 5 |
-
Designed for Hugging Face Spaces
|
| 6 |
Accepts an image + text prompt + duration, generates audio-video output.
|
| 7 |
"""
|
| 8 |
|
|
@@ -12,8 +12,6 @@ import sys
|
|
| 12 |
import tempfile
|
| 13 |
import uuid
|
| 14 |
|
| 15 |
-
import spaces
|
| 16 |
-
|
| 17 |
# ---------------------------------------------------------------------------
|
| 18 |
# 1. Download all model weights from HF Hub (runs on CPU, cached)
|
| 19 |
# ---------------------------------------------------------------------------
|
|
@@ -132,7 +130,6 @@ print("[app] Pipeline ready.")
|
|
| 132 |
# 4. Inference wrapper — @spaces.GPU requests a ZeroGPU allocation
|
| 133 |
# duration= sets the max GPU time in seconds (default 60, max 300)
|
| 134 |
# ---------------------------------------------------------------------------
|
| 135 |
-
@spaces.GPU(duration=300)
|
| 136 |
def generate_video(
|
| 137 |
image,
|
| 138 |
prompt: str,
|
|
|
|
| 2 |
"""
|
| 3 |
Gradio frontend for daVinci-MagiHuman distilled model.
|
| 4 |
|
| 5 |
+
Designed for Hugging Face Spaces (Docker SDK, A100-80GB GPU).
|
| 6 |
Accepts an image + text prompt + duration, generates audio-video output.
|
| 7 |
"""
|
| 8 |
|
|
|
|
| 12 |
import tempfile
|
| 13 |
import uuid
|
| 14 |
|
|
|
|
|
|
|
| 15 |
# ---------------------------------------------------------------------------
|
| 16 |
# 1. Download all model weights from HF Hub (runs on CPU, cached)
|
| 17 |
# ---------------------------------------------------------------------------
|
|
|
|
| 130 |
# 4. Inference wrapper — @spaces.GPU requests a ZeroGPU allocation
|
| 131 |
# duration= sets the max GPU time in seconds (default 60, max 300)
|
| 132 |
# ---------------------------------------------------------------------------
|
|
|
|
| 133 |
def generate_video(
|
| 134 |
image,
|
| 135 |
prompt: str,
|
pkgs/MagiCompiler/.gitignore
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# magi_compiler
|
| 2 |
+
magi_compiler/_version.py
|
| 3 |
+
magi_dump_src_dir/
|
| 4 |
+
*.nsys-rep
|
| 5 |
+
*.ncu-rep
|
| 6 |
+
|
| 7 |
+
# Byte-compiled / optimized / DLL files
|
| 8 |
+
__pycache__/
|
| 9 |
+
*.py[codz]
|
| 10 |
+
*$py.class
|
| 11 |
+
|
| 12 |
+
# C extensions
|
| 13 |
+
*.so
|
| 14 |
+
|
| 15 |
+
# Distribution / packaging
|
| 16 |
+
.Python
|
| 17 |
+
build/
|
| 18 |
+
develop-eggs/
|
| 19 |
+
dist/
|
| 20 |
+
downloads/
|
| 21 |
+
eggs/
|
| 22 |
+
.eggs/
|
| 23 |
+
lib/
|
| 24 |
+
lib64/
|
| 25 |
+
parts/
|
| 26 |
+
sdist/
|
| 27 |
+
var/
|
| 28 |
+
wheels/
|
| 29 |
+
share/python-wheels/
|
| 30 |
+
*.egg-info/
|
| 31 |
+
.installed.cfg
|
| 32 |
+
*.egg
|
| 33 |
+
MANIFEST
|
| 34 |
+
|
| 35 |
+
# PyInstaller
|
| 36 |
+
# Usually these files are written by a python script from a template
|
| 37 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 38 |
+
*.manifest
|
| 39 |
+
*.spec
|
| 40 |
+
|
| 41 |
+
# Installer logs
|
| 42 |
+
pip-log.txt
|
| 43 |
+
pip-delete-this-directory.txt
|
| 44 |
+
|
| 45 |
+
# Unit test / coverage reports
|
| 46 |
+
htmlcov/
|
| 47 |
+
.tox/
|
| 48 |
+
.nox/
|
| 49 |
+
.coverage
|
| 50 |
+
.coverage.*
|
| 51 |
+
.cache
|
| 52 |
+
nosetests.xml
|
| 53 |
+
coverage.xml
|
| 54 |
+
*.cover
|
| 55 |
+
*.py.cover
|
| 56 |
+
.hypothesis/
|
| 57 |
+
.pytest_cache/
|
| 58 |
+
cover/
|
| 59 |
+
|
| 60 |
+
# Translations
|
| 61 |
+
*.mo
|
| 62 |
+
*.pot
|
| 63 |
+
|
| 64 |
+
# Vscode stuff:
|
| 65 |
+
.vscode
|
| 66 |
+
|
| 67 |
+
# Django stuff:
|
| 68 |
+
*.log
|
| 69 |
+
local_settings.py
|
| 70 |
+
db.sqlite3
|
| 71 |
+
db.sqlite3-journal
|
| 72 |
+
|
| 73 |
+
# Flask stuff:
|
| 74 |
+
instance/
|
| 75 |
+
.webassets-cache
|
| 76 |
+
|
| 77 |
+
# Scrapy stuff:
|
| 78 |
+
.scrapy
|
| 79 |
+
|
| 80 |
+
# Sphinx documentation
|
| 81 |
+
docs/_build/
|
| 82 |
+
|
| 83 |
+
# PyBuilder
|
| 84 |
+
.pybuilder/
|
| 85 |
+
target/
|
| 86 |
+
|
| 87 |
+
# Jupyter Notebook
|
| 88 |
+
.ipynb_checkpoints
|
| 89 |
+
|
| 90 |
+
# IPython
|
| 91 |
+
profile_default/
|
| 92 |
+
ipython_config.py
|
| 93 |
+
|
| 94 |
+
# pyenv
|
| 95 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 96 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 97 |
+
# .python-version
|
| 98 |
+
|
| 99 |
+
# pipenv
|
| 100 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 101 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 102 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 103 |
+
# install all needed dependencies.
|
| 104 |
+
#Pipfile.lock
|
| 105 |
+
|
| 106 |
+
# UV
|
| 107 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 108 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 109 |
+
# commonly ignored for libraries.
|
| 110 |
+
#uv.lock
|
| 111 |
+
|
| 112 |
+
# poetry
|
| 113 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 114 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 115 |
+
# commonly ignored for libraries.
|
| 116 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 117 |
+
#poetry.lock
|
| 118 |
+
#poetry.toml
|
| 119 |
+
|
| 120 |
+
# pdm
|
| 121 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 122 |
+
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
| 123 |
+
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
| 124 |
+
#pdm.lock
|
| 125 |
+
pdm.toml
|
| 126 |
+
.pdm-python
|
| 127 |
+
.pdm-build/
|
| 128 |
+
|
| 129 |
+
# pixi
|
| 130 |
+
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
| 131 |
+
#pixi.lock
|
| 132 |
+
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
| 133 |
+
# in the .venv directory. It is recommended not to include this directory in version control.
|
| 134 |
+
.pixi
|
| 135 |
+
|
| 136 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 137 |
+
__pypackages__/
|
| 138 |
+
|
| 139 |
+
# Celery stuff
|
| 140 |
+
celerybeat-schedule
|
| 141 |
+
celerybeat.pid
|
| 142 |
+
|
| 143 |
+
# SageMath parsed files
|
| 144 |
+
*.sage.py
|
| 145 |
+
|
| 146 |
+
# Environments
|
| 147 |
+
.env
|
| 148 |
+
.envrc
|
| 149 |
+
.venv
|
| 150 |
+
env/
|
| 151 |
+
venv/
|
| 152 |
+
ENV/
|
| 153 |
+
env.bak/
|
| 154 |
+
venv.bak/
|
| 155 |
+
|
| 156 |
+
# Spyder project settings
|
| 157 |
+
.spyderproject
|
| 158 |
+
.spyproject
|
| 159 |
+
|
| 160 |
+
# Rope project settings
|
| 161 |
+
.ropeproject
|
| 162 |
+
|
| 163 |
+
# mkdocs documentation
|
| 164 |
+
/site
|
| 165 |
+
|
| 166 |
+
# mypy
|
| 167 |
+
.mypy_cache/
|
| 168 |
+
.dmypy.json
|
| 169 |
+
dmypy.json
|
| 170 |
+
|
| 171 |
+
# Pyre type checker
|
| 172 |
+
.pyre/
|
| 173 |
+
|
| 174 |
+
# pytype static type analyzer
|
| 175 |
+
.pytype/
|
| 176 |
+
|
| 177 |
+
# Cython debug symbols
|
| 178 |
+
cython_debug/
|
| 179 |
+
|
| 180 |
+
# PyCharm
|
| 181 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 182 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 183 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 184 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 185 |
+
#.idea/
|
| 186 |
+
|
| 187 |
+
# Abstra
|
| 188 |
+
# Abstra is an AI-powered process automation framework.
|
| 189 |
+
# Ignore directories containing user credentials, local state, and settings.
|
| 190 |
+
# Learn more at https://abstra.io/docs
|
| 191 |
+
.abstra/
|
| 192 |
+
|
| 193 |
+
# Visual Studio Code
|
| 194 |
+
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
| 195 |
+
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
| 196 |
+
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
| 197 |
+
# you could uncomment the following to ignore the entire vscode folder
|
| 198 |
+
# .vscode/
|
| 199 |
+
|
| 200 |
+
# Ruff stuff:
|
| 201 |
+
.ruff_cache/
|
| 202 |
+
|
| 203 |
+
# PyPI configuration file
|
| 204 |
+
.pypirc
|
| 205 |
+
|
| 206 |
+
# Cursor
|
| 207 |
+
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
|
| 208 |
+
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
|
| 209 |
+
# refer to https://docs.cursor.com/context/ignore-files
|
| 210 |
+
.cursorignore
|
| 211 |
+
.cursorindexingignore
|
| 212 |
+
|
| 213 |
+
# Marimo
|
| 214 |
+
marimo/_static/
|
| 215 |
+
marimo/_lsp/
|
| 216 |
+
__marimo__/
|
pkgs/MagiCompiler/.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
exclude: \.patch$
|
| 2 |
+
repos:
|
| 3 |
+
- repo: local
|
| 4 |
+
hooks:
|
| 5 |
+
- id: copyright_checker
|
| 6 |
+
name: copyright_checker
|
| 7 |
+
entry: python3 ./.github/.codestyle/copyright.hook
|
| 8 |
+
language: system
|
| 9 |
+
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py|sh)$
|
| 10 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 11 |
+
rev: v4.4.0
|
| 12 |
+
hooks:
|
| 13 |
+
- id: check-added-large-files
|
| 14 |
+
args:
|
| 15 |
+
- --maxkb=30720
|
| 16 |
+
- id: check-merge-conflict
|
| 17 |
+
- id: check-symlinks
|
| 18 |
+
- id: detect-private-key
|
| 19 |
+
files: (?!.*third_party)^.*$ | (?!.*book)^.*$
|
| 20 |
+
- id: end-of-file-fixer
|
| 21 |
+
- id: trailing-whitespace
|
| 22 |
+
- id: requirements-txt-fixer
|
| 23 |
+
- id: sort-simple-yaml
|
| 24 |
+
- repo: https://github.com/Lucas-C/pre-commit-hooks.git
|
| 25 |
+
rev: v1.5.1
|
| 26 |
+
hooks:
|
| 27 |
+
- id: remove-crlf
|
| 28 |
+
files: (?!.*third_party)^.*$ | (?!.*book)^.*$
|
| 29 |
+
- id: remove-tabs
|
| 30 |
+
name: Tabs remover (C++)
|
| 31 |
+
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|xpu|kps)$
|
| 32 |
+
args: [--whitespaces-count, '2']
|
| 33 |
+
- id: remove-tabs
|
| 34 |
+
name: Tabs remover (Python)
|
| 35 |
+
files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
|
| 36 |
+
args: [--whitespaces-count, '4']
|
| 37 |
+
- repo: https://github.com/psf/black.git
|
| 38 |
+
rev: 23.3.0
|
| 39 |
+
hooks:
|
| 40 |
+
- id: black
|
| 41 |
+
args: [--line-length=127, --skip-string-normalization, --skip-magic-trailing-comma]
|
| 42 |
+
files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
|
| 43 |
+
- repo: https://github.com/pre-commit/mirrors-isort
|
| 44 |
+
rev: v5.10.1
|
| 45 |
+
hooks:
|
| 46 |
+
- id: isort
|
| 47 |
+
args: [--profile=black, --line-length=127, --multi-line=3, --force-grid-wrap=0]
|
| 48 |
+
files: \.py$
|
| 49 |
+
- repo: https://github.com/PyCQA/autoflake
|
| 50 |
+
rev: v2.3.1
|
| 51 |
+
hooks:
|
| 52 |
+
- id: autoflake
|
| 53 |
+
args: [--remove-all-unused-imports, --remove-unused-variables, --in-place, --ignore-init-module-imports, --ignore-pass-after-docstring]
|
| 54 |
+
files: \.py$
|
| 55 |
+
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks.git
|
| 56 |
+
rev: v2.9.0
|
| 57 |
+
hooks:
|
| 58 |
+
- id: pretty-format-yaml
|
| 59 |
+
args: [--autofix, --indent, '4']
|
| 60 |
+
additional_dependencies: [setuptools]
|
pkgs/MagiCompiler/LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
pkgs/MagiCompiler/README.md
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## MagiCompiler
|
| 2 |
+
|
| 3 |
+
An engineering-oriented compiler and execution augmentation library for PyTorch 2.8+, providing module-level compilation decorators, backend adapters, graph partitioning strategies, readable and reusable compile artifacts, and tightly integrated runtime scheduling for any inference engine. The design goal is to systematically expose capabilities of PyTorch Dynamo / AOTAutograd / Inductor / Triton while prioritizing correctness, stability, observability, and maintainability.
|
| 4 |
+
|
| 5 |
+
### Design Overview
|
| 6 |
+
|
| 7 |
+
- Compilation entrypoint: the `@magi_compile` decorator augments `nn.Module.forward` with compilation, including dynamic-shape annotation and argument validation.
|
| 8 |
+
- Partitioning and passes: configurable graph partitioning and pass management (e.g., `InductorPass`, `PostGradPassManager`) for fusion, kernel generation, and tuning.
|
| 9 |
+
- Artifact system: persists compile artifacts using a Python file/directory layout for readability, auditability, and portability (see “Compile Cache Overview”).
|
| 10 |
+
- Configuration: `CompileConfig` centralizes backend, partition rules, cache root, runtime shapes, and other key parameters.
|
| 11 |
+
|
| 12 |
+
### Key Features
|
| 13 |
+
|
| 14 |
+
- Dynamic-shape annotations:
|
| 15 |
+
- Automatic inference: when a `forward` parameter is annotated as `torch.Tensor` or `torch.Tensor | None`, dimension 0 is treated as dynamic by default.
|
| 16 |
+
- Explicit specification: use `@magi_compile(dynamic_arg_dims={...})` to mark dimensions (negative indices supported).
|
| 17 |
+
- Consistency constraints: parameters that alternately appear as `None` and non-`None` across the model lifetime cannot be captured into the same computation graph.
|
| 18 |
+
- Backend selection and standalone compilation:
|
| 19 |
+
- `inductor` mode defaults to PyTorch 2.8+ `standalone_compile`, producing reusable artifacts.
|
| 20 |
+
- `eager` mode is available for debugging or fallback paths.
|
| 21 |
+
- Partitioning and passes: operator-set-driven partition rules and pass contexts that stabilize subgraph boundaries and kernel generation across runtime shapes.
|
| 22 |
+
- Readable, portable artifacts: structured directories with Python files for quick triage and cross-environment debugging.
|
| 23 |
+
- Engine integration: the decorator reads engine-level `CompileConfig` to stay aligned with distributed/scheduling components.
|
| 24 |
+
|
| 25 |
+
## Installation and Requirements
|
| 26 |
+
|
| 27 |
+
- Python ≥ 3.10
|
| 28 |
+
- PyTorch ≥ 2.8 (with `torch._inductor.standalone_compile` available)
|
| 29 |
+
- Recommended to be used within the Athena environment, together with its dependencies and distributed components (e.g., CUDA Graph manager).
|
| 30 |
+
|
| 31 |
+
For local development, install in editable mode:
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
pip install -e . --no-build-isolation --config-settings editable_mode=compat
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
## Quick Start
|
| 38 |
+
|
| 39 |
+
### Minimal Example (automatic dynamic-dim inference)
|
| 40 |
+
|
| 41 |
+
```python
|
| 42 |
+
import torch
|
| 43 |
+
from torch import nn
|
| 44 |
+
from magi_compiler.decorator import magi_compile
|
| 45 |
+
|
| 46 |
+
@magi_compile
|
| 47 |
+
class MyModel(nn.Module):
|
| 48 |
+
def __init__(self, *, model_config):
|
| 49 |
+
super().__init__()
|
| 50 |
+
self.linear = nn.Linear(10, 5)
|
| 51 |
+
|
| 52 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor | None) -> torch.Tensor:
|
| 53 |
+
if y is not None:
|
| 54 |
+
return self.linear(x + y)
|
| 55 |
+
return self.linear(x)
|
| 56 |
+
|
| 57 |
+
# In Athena, model_config is typically provided by the engine
|
| 58 |
+
model = MyModel(model_config=...)
|
| 59 |
+
out1 = model(torch.randn(4, 10), torch.randn(4, 10))
|
| 60 |
+
out2 = model(torch.randn(8, 10), None) # dynamic batch dimension
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
### Explicit Dynamic-Dim Specification
|
| 64 |
+
|
| 65 |
+
```python
|
| 66 |
+
@magi_compile(dynamic_arg_dims={"x": -1}) # mark the last dimension as dynamic
|
| 67 |
+
class DynamicDimModel(nn.Module):
|
| 68 |
+
def __init__(self, *, model_config):
|
| 69 |
+
super().__init__()
|
| 70 |
+
self.proj = nn.Linear(16, 16, bias=False)
|
| 71 |
+
|
| 72 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 73 |
+
return self.proj(x)
|
| 74 |
+
|
| 75 |
+
m = DynamicDimModel(model_config=...)
|
| 76 |
+
_ = m(torch.randn(2, 16))
|
| 77 |
+
_ = m(torch.randn(2, 32)) # allow the last dimension to vary
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
## Configuration and Modes
|
| 81 |
+
|
| 82 |
+
- `CompileConfig`: centralizes compile parameters (backend, cache paths, partition strategy, dynamic shapes, traced files, etc.).
|
| 83 |
+
- `CompileMode`: typical setting is `CompileMode.TORCH_COMPILE`.
|
| 84 |
+
- Backends:
|
| 85 |
+
- `inductor`: uses `standalone_compile` to produce reusable artifacts, ideal for production deployments.
|
| 86 |
+
- `eager`: convenient for rapid debugging or as a fallback.
|
| 87 |
+
|
| 88 |
+
## Architecture and Execution Flow (Brief)
|
| 89 |
+
|
| 90 |
+
1. `@magi_compile` wraps `nn.Module`:
|
| 91 |
+
- infers/validates `dynamic_arg_dims`;
|
| 92 |
+
- extends MRO by injecting `MagiCompilerBase`;
|
| 93 |
+
- reads engine-level `CompileConfig` in MagiCompiler.
|
| 94 |
+
2. `CompilerManager`:
|
| 95 |
+
- defines cache keys using `(runtime_shape, graph_index, backend)`;
|
| 96 |
+
- dispatches to backends via `CompilerInterface` (`InductorStandaloneAdaptor` or `EagerAdaptor`);
|
| 97 |
+
- applies partition rules and pass contexts within `compile_context(...)`;
|
| 98 |
+
- serializes compile artifacts into a human-readable directory structure.
|
| 99 |
+
3. Monitoring and statistics:
|
| 100 |
+
- counters and timestamps report per-shape/per-subgraph latencies and milestones.
|
| 101 |
+
|
| 102 |
+
## Compile Cache Overview
|
| 103 |
+
|
| 104 |
+
This document summarizes the cache files generated by `torch.compile` (TorchDynamo + TorchInductor + AOTAutograd + Triton). Reference path: `cache/`.
|
| 105 |
+
|
| 106 |
+
### Directory Layout (Tree)
|
| 107 |
+
|
| 108 |
+
```text
|
| 109 |
+
cache/
|
| 110 |
+
├─ depyf/
|
| 111 |
+
│ └─ rank_0/
|
| 112 |
+
│ ├─ __transformed_code_0_for_forward.py
|
| 113 |
+
│ ├─ decompiled_code.py
|
| 114 |
+
│ ├─ full_code_for_forward_0.py
|
| 115 |
+
│ ├─ __compiled_fn_1.BEFORE_PRE_GRAD.{0..N}.py
|
| 116 |
+
│ ├─ __compiled_fn_1.kernel_{0..K}.py
|
| 117 |
+
│ ├─ __compiled_fn_1.__compiled_fn_1_<uuid>.0.py
|
| 118 |
+
│ ├─ __compiled_fn_1.Before_split.0.py
|
| 119 |
+
│ ├─ __compiled_fn_1.After_split.0.py
|
| 120 |
+
│ ├─ __compiled_fn_1.pre_split_module.0.py
|
| 121 |
+
│ ├─ __compiled_fn_1.post_split_module.0.py
|
| 122 |
+
│ └─ __compiled_fn_1.pre_insert_deferred_runtime_asserts__<uuid>.0.py
|
| 123 |
+
│
|
| 124 |
+
└─ torch_compile_cache/
|
| 125 |
+
└─ bfa0df33ea/ # Hash for graph + compile options + device, etc.
|
| 126 |
+
└─ rank_0/ # Rank id in distributed/multi-GPU runs
|
| 127 |
+
└─ backbone/
|
| 128 |
+
├─ computation_graph.py
|
| 129 |
+
├─ magi_compile_cache.py
|
| 130 |
+
├─ artifact_shape_None_subgraph_0/
|
| 131 |
+
├─ artifact_shape_None_subgraph_1/
|
| 132 |
+
├─ ...
|
| 133 |
+
└─ artifact_shape_None_subgraph_30/
|
| 134 |
+
├─ ir/
|
| 135 |
+
│ └─ *.py # Python/Triton kernels generated by Inductor for this subgraph
|
| 136 |
+
├─ fxgraph/
|
| 137 |
+
│ └─ */*/<hash> # Binary FX IR/metadata snapshots (not human-readable)
|
| 138 |
+
├─ aotautograd/
|
| 139 |
+
│ └─ */*/<hash> # AOTAutograd partition/capture metadata and artifacts
|
| 140 |
+
├─ 44/
|
| 141 |
+
│ └─ *.py # Other sharded/generated code buckets
|
| 142 |
+
└─ ... # Structure may vary slightly across subgraphs
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
### What Each File/Dir Is For
|
| 146 |
+
|
| 147 |
+
- `cache/depyf/` (TorchDynamo/Depyf debug exports)
|
| 148 |
+
- `__transformed_code_0_for_forward.py`: The Dynamo-transformed `forward` code (diff-friendly view of pre/post transformation).
|
| 149 |
+
- `decompiled_code.py`: Decompiled snapshot to help map traced graphs back to original Python.
|
| 150 |
+
- `full_code_for_forward_0.py`: A more complete expanded `forward` for inspection.
|
| 151 |
+
- `__compiled_fn_1.BEFORE_PRE_GRAD.{i}.py`: Intermediate wrapper snapshots at specific compile stages (e.g., before autodiff).
|
| 152 |
+
- `__compiled_fn_1.kernel_{k}.py`: Entrypoints/wrappers for kernels generated at various stages.
|
| 153 |
+
- `Before_split` / `After_split` / `pre_split_module` / `post_split_module`: Intermediate forms around graph partitioning.
|
| 154 |
+
- `pre_insert_deferred_runtime_asserts__*.py`: Snapshot before inserting deferred runtime assertions (dynamic shapes/guards).
|
| 155 |
+
|
| 156 |
+
- `cache/torch_compile_cache/` (TorchInductor artifacts)
|
| 157 |
+
- `bfa0df33ea/`: Namespace keyed by a hash of model structure, compile settings, and device info.
|
| 158 |
+
- `rank_0/`: Bucket per process rank for distributed runs.
|
| 159 |
+
- `backbone/`:
|
| 160 |
+
- `computation_graph.py`: Full model FX GraphModule with symbolic dims; shared across subgraph kernels.
|
| 161 |
+
- `magi_compile_cache.py`:
|
| 162 |
+
- Maps subgraph indices to artifact directories, e.g. `(None, i, 'inductor_standalone') -> artifact_shape_None_subgraph_i/`.
|
| 163 |
+
- Registers and asynchronously compiles Triton kernels via `AsyncCompile.triton(...)`, including autotune metadata, device properties, scheduling hints, etc.
|
| 164 |
+
- `artifact_shape_None_subgraph_{N}/`:
|
| 165 |
+
- `ir/*.py`: Inductor-generated Python/Triton kernels and scheduling code for this subgraph (readable).
|
| 166 |
+
- `fxgraph/*/*/<hash>`: FX IR/metadata snapshots for fast graph reconstruction (binary; do not edit).
|
| 167 |
+
- `aotautograd/*/*/<hash>`: AOTAutograd partitions/captures and replay requirements.
|
| 168 |
+
- Additional hashed/prefixed buckets (e.g., `44/`, `o5/`, `55/`, `br/`) containing generated operator/subtask code.
|
| 169 |
+
|
| 170 |
+
### FAQ
|
| 171 |
+
|
| 172 |
+
- How are these caches produced?
|
| 173 |
+
- At runtime by `torch.compile(...)`, after TorchDynamo tracing, AOTAutograd partitioning, TorchInductor lowering/fusion, and Triton codegen.
|
| 174 |
+
- Will they change across runs?
|
| 175 |
+
- Yes. Different input shapes, env vars, device info, or compile options can produce different hash namespaces (e.g., a new `bfa0df33ea`).
|
| 176 |
+
- Is it safe to delete them?
|
| 177 |
+
- Yes. You can delete `cache/`. It will be rebuilt on demand; the next run will be slower due to recompilation.
|
| 178 |
+
|
| 179 |
+
## Compatibility and Recommendations
|
| 180 |
+
|
| 181 |
+
- Prefer official PyTorch ≥ 2.8 builds to ensure `standalone_compile` availability.
|
| 182 |
+
- For highly dynamic models, explicitly mark key dynamic dimensions to improve graph capture and cache reuse.
|
| 183 |
+
|
| 184 |
+
## Acknowledgments
|
| 185 |
+
|
| 186 |
+
This library builds upon capabilities of PyTorch Dynamo, AOTAutograd, Inductor, and Triton, and incorporates engineering practices and interface designs inspired by the vLLM community. We thank the relevant open-source communities and contributors.
|
pkgs/MagiCompiler/docs/AutoCudaGraphDesign.md
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## AutoCudaGraph Design
|
| 2 |
+
|
| 3 |
+
Author: ZhiyaoCen
|
| 4 |
+
|
| 5 |
+
## Overview
|
| 6 |
+
AutoCudaGraph is a CUDA Graph optimization module integrated into the MagiCompiler framework, designed to automate CUDA Graph capture, caching, replay, and tensor memory management for PyTorch-based neural network inference. It targets Transformer architectures with dynamic sequence lengths, optimizing kernel execution by reusing pre-captured computation graphs and static tensor buffers. Core Goals:
|
| 7 |
+
* Automate CUDA Graph lifecycle (capture/replay/cache) with minimal code intrusion
|
| 8 |
+
* Support dynamic shape adaptation (sequence length expansion)
|
| 9 |
+
* Optimize memory efficiency via global memory pool and static tensor reuse
|
| 10 |
+
* Ensure consistency between cached graphs and runtime inputs/outputs
|
| 11 |
+
## Key Components
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
### CudaGraphMgr (Core Manager)
|
| 15 |
+
Singleton class managing all CUDA Graph operations:
|
| 16 |
+
```python
|
| 17 |
+
class CudaGraphMgr:
|
| 18 |
+
def __init__(self):
|
| 19 |
+
self.cache: Dict[StaticSignature, StaticTensorEntry] = dict()
|
| 20 |
+
self.graph_mem_pool: Optional[torch.cuda.graph_pool_handle] = None
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
**Core Methods**
|
| 24 |
+
| Method | Purpose |
|
| 25 |
+
|---------------------------------|----------------------------------------|
|
| 26 |
+
| run() | Main entry: Replay cached graph or warm up & capture new graph|
|
| 27 |
+
| wrapped_graph_capture() | Capture CUDA Graph with sliced static input/output tensors |
|
| 28 |
+
| wrapped_graph_replay() | Replay cached CUDA Graph with sliced static tensors and output template wrapping
|
| 29 |
+
| get_expanded_static_tensors() | Expand static tensors, reuse buffers if dimensionally compatible|
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
### Signature System
|
| 33 |
+
|
| 34 |
+
StaticSignature
|
| 35 |
+
```python
|
| 36 |
+
@dataclass(unsafe_hash=True)
|
| 37 |
+
class StaticSignature(HashableDataclass):
|
| 38 |
+
func_name: str = ""
|
| 39 |
+
tensor_static_infos: Tuple[TensorStaticInfo, ...] = tuple()
|
| 40 |
+
```
|
| 41 |
+
* Encodes fixed properties of input tensors (dtype, static dimensions)
|
| 42 |
+
* Used as primary key for static tensor buffer caching
|
| 43 |
+
|
| 44 |
+
DynamicSignature
|
| 45 |
+
```python
|
| 46 |
+
@dataclass(unsafe_hash=True)
|
| 47 |
+
class DynamicSignature(HashableDataclass):
|
| 48 |
+
tensor_dynamic_infos: Tuple[TensorDynamicInfo, ...] = tuple()
|
| 49 |
+
literals_info: LiteralsInfo = None
|
| 50 |
+
```
|
| 51 |
+
* Tracks dynamic dimensions (sequence length) and literal parameters
|
| 52 |
+
* Secondary key for graph entry lookup
|
| 53 |
+
|
| 54 |
+
### Tensor Management
|
| 55 |
+
```python
|
| 56 |
+
@dataclass
|
| 57 |
+
class StaticTensorEntry:
|
| 58 |
+
input_tensors: Optional[List[torch.Tensor]] = None
|
| 59 |
+
output_tensors: Optional[List[torch.Tensor]] = None
|
| 60 |
+
template_entry_dict: Dict[DynamicSignature, OutputTemplateEntry] = None
|
| 61 |
+
```
|
| 62 |
+
* Memory Reuse: Reuse existing tensor buffers when possible to avoid reallocation
|
| 63 |
+
* Dynamic Expansion: Only expand static tensors when new input dimensions exceed current buffer size
|
| 64 |
+
* Shape Validation: Ensure static dimensions (non-sequence) match between cached and new tensors
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
### Graph Management
|
| 68 |
+
```python
|
| 69 |
+
@dataclass
|
| 70 |
+
class GraphEntry:
|
| 71 |
+
graph: Optional[torch.cuda.CUDAGraph] = None
|
| 72 |
+
inconsistent: bool = False
|
| 73 |
+
invalid: bool = False
|
| 74 |
+
|
| 75 |
+
@dataclass
|
| 76 |
+
class OutputTemplateEntry:
|
| 77 |
+
graph_entry_dict: Dict[int, GraphEntry] = None
|
| 78 |
+
output_template: Any = None
|
| 79 |
+
```
|
| 80 |
+
* Graph State Tracking: GraphEntry tracks CUDA Graph instances and validity states to control replay eligibility.
|
| 81 |
+
* Layer-wise Organization: OutputTemplateEntry maps dynamic signatures to per-layer GraphEntry for layer-specific graph reuse.
|
| 82 |
+
* Output Consistency: output_template preserves output object structure to ensure consistent result wrapping during replay.
|
| 83 |
+
|
| 84 |
+
## Execution Flow
|
| 85 |
+
### Inline Replay (Fast Path)
|
| 86 |
+
* Extract input signatures from runtime arguments
|
| 87 |
+
* Look up cached CUDA Graph via StaticSignature + DynamicSignature + layer number
|
| 88 |
+
* Validate graph consistency (not inconsistent/invalid)
|
| 89 |
+
* Reuse static tensors with dynamic slicing
|
| 90 |
+
* Replay graph and return sliced output
|
| 91 |
+
### Graph Capture (Slow Path)
|
| 92 |
+
Triggered when no valid cached graph exists or tensor expansion is needed:
|
| 93 |
+
* Execute function to get output tensors
|
| 94 |
+
* Ensure input signatures match post-warmup
|
| 95 |
+
* Expand static buffers if new shapes require it
|
| 96 |
+
* Capture new CUDA Graph with static tensors
|
| 97 |
+
* Store new graph and update tensor entries
|
| 98 |
+
* Return warmup execution output as final result
|
| 99 |
+
### Sequence Length Handling
|
| 100 |
+
* Only the last dimension is static for ND tensors (ND > 1)
|
| 101 |
+
* All dimensions are dynamic for 1D tensors (ND=1)
|
| 102 |
+
* Automatic buffer expansion for increasing sequence lengths
|
| 103 |
+
* Invalidates old graphs when tensors are expanded
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
## Examples
|
| 107 |
+
```python
|
| 108 |
+
import torch
|
| 109 |
+
import torch.nn as nn
|
| 110 |
+
from magi_compiler.cuda_graph_mgr import cuda_graph_mgr, cuda_graph_enable_if
|
| 111 |
+
|
| 112 |
+
class SimpleTransformerLayer(nn.Module):
|
| 113 |
+
def __init__(self, hidden_dim: int = 1024, num_heads: int = 8):
|
| 114 |
+
super().__init__()
|
| 115 |
+
self.self_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
|
| 116 |
+
self.linear = nn.Linear(hidden_dim, hidden_dim)
|
| 117 |
+
self.layer_norm = nn.LayerNorm(hidden_dim)
|
| 118 |
+
self.layer_number = 0
|
| 119 |
+
|
| 120 |
+
@cuda_graph_enable_if(lambda: torch.cuda.is_available())
|
| 121 |
+
def forward(self, x: torch.Tensor):
|
| 122 |
+
attn_out, _ = self.self_attn(x, x, x)
|
| 123 |
+
out = self.linear(self.layer_norm(x + attn_out))
|
| 124 |
+
return out
|
| 125 |
+
|
| 126 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 127 |
+
model = SimpleTransformerLayer(hidden_dim=1024, num_heads=8).to(device).eval()
|
| 128 |
+
graph_mgr = cuda_graph_mgr()
|
| 129 |
+
|
| 130 |
+
with torch.no_grad():
|
| 131 |
+
input_1 = torch.randn(2, 512, 1024, device=device)
|
| 132 |
+
output_1 = model(input_1)
|
| 133 |
+
print(f"First run (graph capture): Output shape = {output_1.shape}")
|
| 134 |
+
print(f"Cached graphs count: {graph_mgr.graph_count}")
|
| 135 |
+
|
| 136 |
+
input_2 = torch.randn(2, 512, 1024, device=device)
|
| 137 |
+
output_2 = model(input_2)
|
| 138 |
+
print(f"Second run (graph replay): Output shape = {output_2.shape}")
|
| 139 |
+
print(f"Cached graphs count: {graph_mgr.graph_count}")
|
| 140 |
+
|
| 141 |
+
input_3 = torch.randn(2, 1024, 1024, device=device)
|
| 142 |
+
output_3 = model(input_3)
|
| 143 |
+
print(f"Third run (tensor expansion): Output shape = {output_3.shape}")
|
| 144 |
+
print(f"Cached graphs count: {graph_mgr.graph_count}")
|
| 145 |
+
print(f"Static tensor memory usage: {graph_mgr.tensor_mem_size:.2f} MB")
|
| 146 |
+
|
| 147 |
+
print("\nCUDA Graph Cache Details:")
|
| 148 |
+
print(graph_mgr.formatted_cache_str())
|
| 149 |
+
|
| 150 |
+
# StaticSignature: StaticSignature(_cached_hash=None, func_name='SimpleTransformerLayer.forward', tensor_static_infos=(TensorStaticInfo(_cached_hash=None, name='', shapes=(-1, -1, 1024), dtype='torch.float32'),))
|
| 151 |
+
# Input Static Tensors: [shape=[2, 1024, 1024],dtype=torch.float32]
|
| 152 |
+
# Output Static Tensors: [shape=[2, 1024, 1024],dtype=torch.float32]
|
| 153 |
+
# DynamicSignature: DynamicSignature(_cached_hash=None, tensor_dynamic_infos=(TensorDynamicInfo(_cached_hash=None, name='', shapes=(2, 512, -1)),), literals_info=LiteralsInfo(_cached_hash=None, literals=()))
|
| 154 |
+
# Output Template: FakeTensor(shape=[2, 512, 1024], dtype='torch.float32', device='cuda:0')
|
| 155 |
+
# Layer 0: Graph Status: Invalid
|
| 156 |
+
# DynamicSignature: DynamicSignature(_cached_hash=None, tensor_dynamic_infos=(TensorDynamicInfo(_cached_hash=None, name='', shapes=(2, 1024, -1)),), literals_info=LiteralsInfo(_cached_hash=None, literals=()))
|
| 157 |
+
# Output Template: FakeTensor(shape=[2, 1024, 1024], dtype='torch.float32', device='cuda:0')
|
| 158 |
+
# Layer 0: Graph Status: Valid
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
## Limitations and Constraints
|
| 162 |
+
* No support for data-dependent control flow in captured functions
|
| 163 |
+
* Graph capture fails if function contains CPU/GPU synchronization
|
| 164 |
+
* Only supports CUDA tensors (CPU tensors trigger fallback)
|
| 165 |
+
* Custom input classes must inherit from InplaceSubstituteFakeClass
|
| 166 |
+
* Assumes input tensors of captured graphs are not reused externally (risk of cross-scenario static tensor reuse)
|
| 167 |
+
* Relies on identical function, input tensor shapes, and constants for valid graph reuse
|
| 168 |
+
* No support for multi-stream execution scenarios
|
| 169 |
+
|
| 170 |
+
## Best Practices
|
| 171 |
+
* Dynamic Dimensions: Tensors should use sequence length as dimension 0 where possible
|
| 172 |
+
* Monitor Memory Usage: Track graph_mem_pool_size and tensor_mem_size to avoid OOM
|
| 173 |
+
* Specify Layer IDs: Use layer_number to distinguish graphs across different models/layers
|
| 174 |
+
* LRU Cache (Future): Implement cache eviction to limit total graph/tensor count
|
pkgs/MagiCompiler/docs/Hunyuan15Benchmark.md
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Hunyuan1.5 Benchmark
|
| 2 |
+
|
| 3 |
+
### Executive Summary
|
| 4 |
+
This report presents a comprehensive performance evaluation of the **[Athena](https://github.com/world-sim-dev/athena)** framework compared to the baseline **[LightX2V](https://github.com/ModelTC/LightX2V)** framework. The benchmarks were conducted using the **[Hunyuan-1.5](https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5)** model on NVIDIA H100 hardware.
|
| 5 |
+
|
| 6 |
+
---
|
| 7 |
+
### 🎯Test Environment & Versioning
|
| 8 |
+
#### Hardware & Settings
|
| 9 |
+
|
| 10 |
+
| Parameter | Value |
|
| 11 |
+
| ------------------- | -------------- |
|
| 12 |
+
| Hardware | NVIDIA H100 |
|
| 13 |
+
| Model | Hunyuan-1.5 480p_t2v_distilled |
|
| 14 |
+
| Precision | torch.bfloat16 |
|
| 15 |
+
| Inference Steps | 20 |
|
| 16 |
+
| Resolution | 480p |
|
| 17 |
+
| FPS | 24 |
|
| 18 |
+
| CFG | Disable |
|
| 19 |
+
#### Software Versioning
|
| 20 |
+
To ensure reproducibility, the following specific commits were used for this benchmark:
|
| 21 |
+
| Framework | Branch / Tag | Commit |
|
| 22 |
+
| --------- | ------------ | ------ |
|
| 23 |
+
| Athena | main | [5e6086b](https://github.com/world-sim-dev/athena/commit/5e6086b4dc2ab60bc4d44dbe39745b4354075121) |
|
| 24 |
+
| LightX2V | main | [5573905](https://github.com/ModelTC/LightX2V/commit/5573905f3f38d876d468b815f86d417a608975b6) |
|
| 25 |
+
|
| 26 |
+
### 🏆 Performance Benchmarks
|
| 27 |
+
📊 We compared the iteration speed (seconds per iteration) between Athena and LightX2V across four distinct Context Parallel (CP) configurations.
|
| 28 |
+
| Configuration | Frames | LightX2V (s/it) | Athena (s/it) | Speedup |
|
| 29 |
+
| ------------- | ------ | -------------- | -------------- | ------- |
|
| 30 |
+
| CP1 | 121 | 2.42 | **2.06** | **1.17x** 🚀|
|
| 31 |
+
| CP2 | 121 | 1.38 | **1.13** | **1.22x** 🚀|
|
| 32 |
+
| CP4 | 241 | 2.25 | **1.85** | **1.22x** 🚀|
|
| 33 |
+
| CP8 | 241 | 1.28 | **1.01** | **1.27x** 🚀|
|
| 34 |
+
|
| 35 |
+
---
|
| 36 |
+
### 📹 Output Comparison
|
| 37 |
+
| Framework | Video Result |
|
| 38 |
+
| --------- | ---------------------------- |
|
| 39 |
+
| Athena | <img src="../../../assets/athena_hunyuan_1_5_test_videos_20260213_155842_idx0_A_close-up815965.gif" width="480" /> |
|
| 40 |
+
| LightX2V | <img src="../../../assets/lightx2v_hunyuan_1_5_result_A_close-up122526.gif" width="480" /> |
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
### 💡 Reproduction Guide
|
| 44 |
+
To reproduce the results presented in this report, follow the steps below using the specified commit hashes.
|
| 45 |
+
#### Setup
|
| 46 |
+
```bash
|
| 47 |
+
git clone https://github.com/world-sim-dev/athena
|
| 48 |
+
cd athena
|
| 49 |
+
git checkout 5e6086b
|
| 50 |
+
pip install -r requirements.txt
|
| 51 |
+
pip install -r requirements-nodeps.txt
|
| 52 |
+
pip install -e ./pkgs/MagiCompiler --no-build-isolation --config-settings editable_mode=compat
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Clone and install LightX2V (for baseline comparison)
|
| 56 |
+
git clone https://github.com/ModelTC/LightX2V
|
| 57 |
+
cd lightx2v
|
| 58 |
+
git checkout 5573905
|
| 59 |
+
pip install -v .
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
#### Running Benchmarks
|
| 63 |
+
For Athena, run:
|
| 64 |
+
```
|
| 65 |
+
RESOLUTION=480p CFG_DISTILLED=true TASK=t2v CHECKPOINT_PATH=path/to/480p_t2v_distilled bash ./scripts/run_hunyuan.sh
|
| 66 |
+
```
|
| 67 |
+
For LightX2V:
|
| 68 |
+
Clone the scripts from [Benchmark for LightX2V](https://gist.github.com/wtr0504/d80bbebb7da1ef7b58f3e6faf1c68880) and run:
|
| 69 |
+
```
|
| 70 |
+
git clone https://gist.github.com/wtr0504/d80bbebb7da1ef7b58f3e6faf1c68880
|
| 71 |
+
MODEL_PATH=path/to/HunyuanVideo-1.5 DISTILL_CKPT=path/to/480p_t2v_distilled/diffusion_pytorch_model.safetensors bash run_hunyuan.sh
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
### 🔎 MagiCompiler Optimization Methodology
|
| 75 |
+
**Whole Graph Compilation**
|
| 76 |
+
Constant Folding & Dead Code Elimination: Streamlining the computation graph prior to execution.
|
| 77 |
+
|
| 78 |
+
**Coarse-grained Kernel Fusion**
|
| 79 |
+
MagiCompiler aggregates multiple smaller operators into larger, fused kernels. This optimization is critical for efficient execution on the GPU.
|
pkgs/MagiCompiler/docs/Wan2.2Benchmark.md
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Wan2.2 Benchmark
|
| 2 |
+
|
| 3 |
+
### Executive Summary
|
| 4 |
+
This report presents a comprehensive performance evaluation of the **[Athena](https://github.com/world-sim-dev/athena)** framework compared to the baseline **[LightX2V](https://github.com/ModelTC/LightX2V)** framework. The benchmarks were conducted using the **[Wan2.2-TI2V-5B](https://huggingface.co/Wan-AI)** model on NVIDIA H100 hardware.
|
| 5 |
+
|
| 6 |
+
---
|
| 7 |
+
### 🎯Test Environment & Versioning
|
| 8 |
+
#### Hardware & Settings
|
| 9 |
+
|
| 10 |
+
| Parameter | Value |
|
| 11 |
+
| ------------------- | -------------- |
|
| 12 |
+
| Hardware | NVIDIA H100 |
|
| 13 |
+
| Model | Wan2.2-TI2V-5B |
|
| 14 |
+
| Precision | torch.bfloat16 |
|
| 15 |
+
| Inference Steps | 50 |
|
| 16 |
+
| Resolution | 704 × 1280 (720p) |
|
| 17 |
+
| FPS | 24 |
|
| 18 |
+
| CFG | Enabled |
|
| 19 |
+
#### Software Versioning
|
| 20 |
+
To ensure reproducibility, the following specific commits were used for this benchmark:
|
| 21 |
+
| Framework | Branch / Tag | Commit |
|
| 22 |
+
| --------- | ------------ | ------ |
|
| 23 |
+
| Athena | main|[f676ae6](https://github.com/world-sim-dev/athena/commit/f676ae64ad2fc581289d1c3ae5eb51c15ce76f1d) |
|
| 24 |
+
| LightX2V | main | [33f0f67](https://github.com/ModelTC/LightX2V/commit/33f0f67f4ecdff86b1db676d3e0786628cc31c7b) |
|
| 25 |
+
|
| 26 |
+
### 🏆 Performance Benchmarks
|
| 27 |
+
📊 We compared the iteration speed (seconds per iteration) between Athena and LightX2V across three distinct Context Parallel (CP) configurations.
|
| 28 |
+
| Configuration | Frames | LightX2V (s/it) | Athena (s/it) | Speedup |
|
| 29 |
+
| ------------- | ------ | -------------- | -------------- | ------- |
|
| 30 |
+
| CP1 | 121 | 1.928 | **1.69** | **1.14x** 🚀|
|
| 31 |
+
| CP2 | 121 | 1.197 | **1.06** | **1.13x** 🚀|
|
| 32 |
+
| CP4 | 241 | 1.767 | **1.32** | **1.34x** 🚀|
|
| 33 |
+
| CP8 | 241 | 1.507 | **1.35** | **1.12x** 🚀|
|
| 34 |
+
|
| 35 |
+
---
|
| 36 |
+
|
| 37 |
+
### 💡 Reproduction Guide
|
| 38 |
+
To reproduce the results presented in this report, follow the steps below using the specified commit hashes.
|
| 39 |
+
#### Setup
|
| 40 |
+
```bash
|
| 41 |
+
git clone https://github.com/world-sim-dev/athena
|
| 42 |
+
cd athena
|
| 43 |
+
git checkout f676ae6
|
| 44 |
+
pip install -r requirements.txt
|
| 45 |
+
|
| 46 |
+
# Clone and install LightX2V (for baseline comparison)
|
| 47 |
+
git clone https://github.com/ModelTC/LightX2V
|
| 48 |
+
cd LightX2V
|
| 49 |
+
git checkout 33f0f67
|
| 50 |
+
pip install -r requirements.txt
|
| 51 |
+
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
#### Running Benchmarks
|
| 55 |
+
For Athena, run:
|
| 56 |
+
```
|
| 57 |
+
bash ./scripts/run_wan2_2_ti2v_i2v.sh
|
| 58 |
+
```
|
| 59 |
+
For LightX2V:
|
| 60 |
+
Clone the scripts from [Benchmark for LightX2V](https://gist.github.com/wtr0504/629388f17ed38d1c12d5ef5c25a15197) and run:
|
| 61 |
+
```
|
| 62 |
+
git clone https://gist.github.com/wtr0504/629388f17ed38d1c12d5ef5c25a15197
|
| 63 |
+
bash run_wan.sh
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
### 🔎 MagiCompiler Optimization Methodology
|
| 67 |
+
**Whole Graph Compilation**
|
| 68 |
+
Constant Folding & Dead Code Elimination: Streamlining the computation graph prior to execution.
|
| 69 |
+
**Coarse-grained Kernel Fusion**
|
| 70 |
+
MagiCompiler aggregates multiple smaller operators into larger, fused kernels. This optimization is critical for efficient execution on the GPU.
|
| 71 |
+
**All to All Communication**
|
| 72 |
+
MagiCompiler uses ``all_to_all_single`` (1 communication op per sync point) while LightX2V uses ``all_to_all`` x 3 (3 separate communication ops).
|
pkgs/MagiCompiler/docs/WhyMagiCompiler.md
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Why MagiCompiler?
|
| 2 |
+
|
| 3 |
+
## 1. Compiler Overview
|
| 4 |
+
|
| 5 |
+
### 1.1 Background
|
| 6 |
+
|
| 7 |
+
We have long encountered several significant challenges in model optimization:
|
| 8 |
+
|
| 9 |
+
1. **Blurred Acceleration Boundaries:** There is ambiguity regarding the extent of optimization required to achieve "extreme" performance.
|
| 10 |
+
2. **Complex Performance Tuning:** Optimization strategies are often tightly coupled with model architectures, necessitating extensive and repetitive manual intervention.
|
| 11 |
+
3. **Deficiency in Optimization Tools:** The infrastructure lacks sufficient mechanisms for computational graph-level optimizations, such as operator substitution and communication overlap.
|
| 12 |
+
|
| 13 |
+
MagiCompiler addresses these issues through the following approaches:
|
| 14 |
+
|
| 15 |
+
* **Addressing Challenge 1:** It adopts **whole-graph compilation**, thoroughly transcending the boundaries of `TransformerLayer` to maximize the scope of kernel fusion.
|
| 16 |
+
* **Addressing Challenge 2:** It integrates infrastructure optimizations directly into MagiCompiler, implementing features such as `AutoCudaGraph` and `AutoCheckpointing(WIP)`.
|
| 17 |
+
* **Addressing Challenge 3:** It leverages the dynamic-to-static capabilities provided by **Dynamo**, capturing `fx.graph` IR in eager mode to perform pass optimizations at the IR level.
|
| 18 |
+
|
| 19 |
+
#### Illustrative Example
|
| 20 |
+
|
| 21 |
+
```python
|
| 22 |
+
from magi_compiler import magi_compile
|
| 23 |
+
|
| 24 |
+
@magi_compile()
|
| 25 |
+
class TinyModel(nn.Module):
|
| 26 |
+
def __init__(self):
|
| 27 |
+
super().__init__()
|
| 28 |
+
self.linear = nn.Linear(1024, 1024, device="cuda")
|
| 29 |
+
|
| 30 |
+
@no_grad()
|
| 31 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
|
| 32 |
+
return self.linear(x + y - z + 1)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def magi_compiler_demo():
|
| 36 |
+
model = TinyModel()
|
| 37 |
+
x = torch.randn(1024, 1024, device="cuda")
|
| 38 |
+
y = torch.randn(1024, 1024, device="cuda")
|
| 39 |
+
z = torch.randn(1024, 1024, device="cuda")
|
| 40 |
+
model(x, y, z)
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
**Optimized Code (Triton Kernel):**
|
| 44 |
+
|
| 45 |
+
```python
|
| 46 |
+
triton_poi_fused_add_sub_0 = async_compile.triton('triton_poi_fused_add_sub_0', '''
|
| 47 |
+
import triton
|
| 48 |
+
import triton.language as tl
|
| 49 |
+
|
| 50 |
+
from torch._inductor.runtime import triton_helpers, triton_heuristics
|
| 51 |
+
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
|
| 52 |
+
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
|
| 53 |
+
triton_helpers.set_driver_to_gpu()
|
| 54 |
+
|
| 55 |
+
@triton_heuristics.pointwise(
|
| 56 |
+
size_hints={'x': 1048576},
|
| 57 |
+
filename=__file__,
|
| 58 |
+
triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
|
| 59 |
+
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'B8F4209CBFC2377D6AF9CF3C65D610CA2B56C138A443862350DE1E56F5BF54C3', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
|
| 60 |
+
min_elem_per_thread=0
|
| 61 |
+
)
|
| 62 |
+
@triton.jit
|
| 63 |
+
def triton_poi_fused_add_sub_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
|
| 64 |
+
xoffset = tl.program_id(0) * XBLOCK
|
| 65 |
+
xindex = xoffset + tl.arange(0, XBLOCK)[:]
|
| 66 |
+
xmask = xindex < xnumel
|
| 67 |
+
x0 = xindex
|
| 68 |
+
tmp0 = tl.load(in_ptr0 + (x0), xmask)
|
| 69 |
+
tmp1 = tl.load(in_ptr1 + (x0), xmask)
|
| 70 |
+
tmp3 = tl.load(in_ptr2 + (x0), xmask)
|
| 71 |
+
tmp2 = tmp0 + tmp1
|
| 72 |
+
tmp4 = tmp2 - tmp3
|
| 73 |
+
tmp5 = 1.0
|
| 74 |
+
tmp6 = tmp4 + tmp5
|
| 75 |
+
tl.store(out_ptr0 + (x0), tmp6, xmask)
|
| 76 |
+
''', device_str='cuda')
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
### 1.2 Frontend (Dynamo)
|
| 80 |
+
|
| 81 |
+

|
| 82 |
+
|
| 83 |
+
* **PyFrameObject (Dynamic Call Stack):**
|
| 84 |
+
* Represents the context environment during function execution. Python creates a new `PyFrameObject` for each function call.
|
| 85 |
+
* **PyCodeObject (Static Bytecode):**
|
| 86 |
+
* The compiled product of Python code, which is static and read-only. A single `PyCodeObject` exists regardless of how many times the function is invoked.
|
| 87 |
+
|
| 88 |
+
```python
|
| 89 |
+
def f(x, mod):
|
| 90 |
+
for guard, transformed_code in f.compiled_entries:
|
| 91 |
+
if guard(x, mod):
|
| 92 |
+
return transformed_code(x, mod)
|
| 93 |
+
try:
|
| 94 |
+
guard, transformed_code = compile_and_optimize(x, mod)
|
| 95 |
+
f.compiled_entries.append([guard, transformed_code])
|
| 96 |
+
return transformed_code(x, mod)
|
| 97 |
+
except FailToCompileError:
|
| 98 |
+
y = mod(x)
|
| 99 |
+
z = torch.log(y)
|
| 100 |
+
return z
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
#### Symbolic Shape
|
| 104 |
+
|
| 105 |
+
MagiCompiler specifically targets the Transformer architecture and supports custom `dynamic_arg_dims` (typically for `seq_len`).
|
| 106 |
+
|
| 107 |
+
**Example:**
|
| 108 |
+
|
| 109 |
+
```python
|
| 110 |
+
@magi_compile(dynamic_arg_dims={"x": 0, "y": 0, "z": 0})
|
| 111 |
+
class TinyModel(nn.Module):
|
| 112 |
+
def __init__(self):
|
| 113 |
+
super().__init__()
|
| 114 |
+
self.linear = nn.Linear(1024, 1024, device="cuda")
|
| 115 |
+
|
| 116 |
+
@no_grad()
|
| 117 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
|
| 118 |
+
return self.linear(x + y - z + 1)
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
**Guard Mechanism and Elimination in Symbolic Shape Deduction:**
|
| 122 |
+
|
| 123 |
+
```log
|
| 124 |
+
I1204 16:31:35.745000 1859360 torch/_dynamo/symbolic_convert.py:3842] [0/0] Step 1: torchdynamo start tracing inner /usr/local/lib/python3.12/dist-packages/torch/_dynamo/external_utils.py:66
|
| 125 |
+
I1204 16:31:35.746000 1859360 torch/fx/experimental/symbolic_shapes.py:3775] [0/0] create_env
|
| 126 |
+
I1204 16:31:35.781000 1859360 torch/fx/experimental/symbolic_shapes.py:5120] [0/0] create_symbol s33 = 1024 for L['args'][0].size()[0] [2, int_oo] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_dynamo/variables/builder.py:3501 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s33" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
|
| 127 |
+
I1204 16:31:35.785000 1859360 torch/fx/experimental/symbolic_shapes.py:5120] [0/0] create_symbol s6 = 1024 for L['args'][1].size()[0] [2, int_oo] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_dynamo/variables/builder.py:3501 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s6" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
|
| 128 |
+
I1204 16:31:35.794000 1859360 torch/fx/experimental/symbolic_shapes.py:7213] [0/0] eval Eq(s33, s6) [guard added] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s33, s6)"
|
| 129 |
+
I1204 16:31:35.795000 1859360 torch/fx/experimental/symbolic_shapes.py:6792] [0/0] set_replacement s6 = s33 (solve) VR[2, int_oo]
|
| 130 |
+
I1204 16:31:35.800000 1859360 torch/fx/experimental/symbolic_shapes.py:5120] [0/0] create_symbol s21 = 1024 for L['args'][2].size()[0] [2, int_oo] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_dynamo/variables/builder.py:3501 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s21" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
|
| 131 |
+
I1204 16:31:35.806000 1859360 torch/fx/experimental/symbolic_shapes.py:7213] [0/0] eval Eq(s33, s21) [guard added] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s33, s21)"
|
| 132 |
+
I1204 16:31:35.807000 1859360 torch/fx/experimental/symbolic_shapes.py:6792] [0/0] set_replacement s33 = s21 (solve) VR[2, int_oo]
|
| 133 |
+
I1204 16:31:35.828000 1859360 torch/_dynamo/symbolic_convert.py:4059] [0/0] Step 1: torchdynamo done tracing inner (RETURN_VALUE)
|
| 134 |
+
I1204 16:31:35.837000 1859360 torch/fx/experimental/symbolic_shapes.py:6792] [0/0] set_replacement s6 = s21 (find) VR[2, int_oo]
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
### 1.3 Backend (Inductor, MagiBackend, etc.)
|
| 138 |
+
|
| 139 |
+

|
| 140 |
+
|
| 141 |
+
MagiCompiler hijacks the `torch.compile` logic through the following components:
|
| 142 |
+
|
| 143 |
+
* **`custom_partitioner_fn`:** Segments the forward and backward computational graphs and determines which intermediate results are transmitted to the backward pass.
|
| 144 |
+
* **`post_grad_custom_pre_pass`:** Performs pass optimizations at the whole-graph level (computational graph matching and rewriting).
|
| 145 |
+
* **`PartitionFunc`:** Implements custom subgraph partitioning logic, utilizing attention mechanisms as splitting points.
|
| 146 |
+
|
| 147 |
+

|
| 148 |
+
|
| 149 |
+
* **`post_grad_custom_post_pass`:** Executes pass optimizations at the subgraph level (computation/communication overlap).
|
| 150 |
+
|
| 151 |
+
---
|
| 152 |
+
|
| 153 |
+
## 2. Best Practices
|
| 154 |
+
|
| 155 |
+
### 2.1 Model Adaptation
|
| 156 |
+
|
| 157 |
+
MagiCompiler has certain limitations, such as mandatory whole-graph capture and the inability to support implicit subgraph interruptions. Consequently, manual adaptation is required in specific scenarios:
|
| 158 |
+
|
| 159 |
+
**1. Computational Graph Dependencies or CPU/GPU Synchronization**
|
| 160 |
+
|
| 161 |
+
```python
|
| 162 |
+
@magi_compile
|
| 163 |
+
class MeanModule(torch.nn.Module):
|
| 164 |
+
def __init__(self):
|
| 165 |
+
super().__init__()
|
| 166 |
+
|
| 167 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
| 168 |
+
x = x.cos().sin()
|
| 169 |
+
if x.mean() > 0.5:
|
| 170 |
+
x = x - 1
|
| 171 |
+
return x * y
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
> **Note:** In typical Transformer models, certain pre/post-processing operations are unavoidable. Therefore, the recommended practice for `magi_compiler` is to perform **whole-graph capture at the `TransformerBlock` level**, as `TransformerBlock` computations constitute over 95% of the total workload.
|
| 175 |
+
|
| 176 |
+
**2. Custom Operators (e.g., FlashAttention, FlexFlashAttention, MoE kernels)**
|
| 177 |
+
|
| 178 |
+
* **Operator Registration:** A logic for operator registration is provided. Commonly used operators like FlashAttention (FA) and FlexFlashAttention (FFA) are already registered.
|
| 179 |
+
|
| 180 |
+
```python
|
| 181 |
+
# Operator Registration
|
| 182 |
+
@torch.library.custom_op("athena::flash_attn_func", mutates_args=())
|
| 183 |
+
def flash_attn_func(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
|
| 184 |
+
...
|
| 185 |
+
|
| 186 |
+
# Operator Deduce Function
|
| 187 |
+
@flash_attn_func.register_fake
|
| 188 |
+
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
|
| 189 |
+
return torch.empty_like(query)
|
| 190 |
+
|
| 191 |
+
# Call flash_attn_func
|
| 192 |
+
self_attn_out = torch.ops.athena.flash_attn_func(q, k, v)
|
| 193 |
+
out, _ = torch.ops.athena.flex_flash_attn_func(q, k, v, q_ranges=ffa_handler.q_ranges, k_ranges=ffa_handler.k_ranges)
|
| 194 |
+
```
|
| 195 |
+
|
| 196 |
+
* **Unit Testing:** Independent unit tests for operators should be provided in the production environment.
|
| 197 |
+
|
| 198 |
+
```python
|
| 199 |
+
@pytest.mark.parametrize("batch_size", [1])
|
| 200 |
+
@pytest.mark.parametrize("seq_len", [1024, 2048, 4096])
|
| 201 |
+
@pytest.mark.parametrize("query_head", [48])
|
| 202 |
+
@pytest.mark.parametrize("kv_head", [4, 8])
|
| 203 |
+
@pytest.mark.parametrize("head_dim", [128, 256])
|
| 204 |
+
def test_fake_fa3(batch_size, seq_len, query_head, kv_head, head_dim):
|
| 205 |
+
q = torch.randn((batch_size, seq_len, query_head, head_dim), device="cuda", dtype=torch.bfloat16)
|
| 206 |
+
k = torch.randn((batch_size, seq_len, kv_head, head_dim), device="cuda", dtype=torch.bfloat16)
|
| 207 |
+
v = torch.randn((batch_size, seq_len, kv_head, head_dim), device="cuda", dtype=torch.bfloat16)
|
| 208 |
+
torch.library.opcheck(torch.ops.athena.flash_attn_func, (q, k, v))
|
| 209 |
+
```
|
| 210 |
+
|
| 211 |
+
### 2.2 Debugging Methods
|
| 212 |
+
|
| 213 |
+
Key questions for debugging:
|
| 214 |
+
* Is the bug originating from the compiler?
|
| 215 |
+
* Which specific component of the compiler is causing the bug?
|
| 216 |
+
|
| 217 |
+

|
| 218 |
+
|
| 219 |
+
```python
|
| 220 |
+
class CompileConfig(BaseModel):
|
| 221 |
+
# Basic configs
|
| 222 |
+
backend: str = Field("inductor", description="Compilation backend.")
|
| 223 |
+
compile_mode: CompileMode = Field(CompileMode.MAGI_COMPILE, description="Compilation mode.")
|
| 224 |
+
...
|
| 225 |
+
|
| 226 |
+
# Cudagraph configs
|
| 227 |
+
cudagraph_mode: CudaGraphMode = Field(CudaGraphMode.NONE, description="Cudagraph mode.")
|
| 228 |
+
...
|
| 229 |
+
|
| 230 |
+
# Pass configs
|
| 231 |
+
pass_config: PassConfig = Field(PassConfig(), description="Pass configuration.")
|
| 232 |
+
...
|
| 233 |
+
```
|
| 234 |
+
|
| 235 |
+
### 2.3 Profiling Results
|
| 236 |
+
|
| 237 |
+
For further details, please refer to the [**Wan2.2 Benchmark**](Wan2.2Benchmark.md).
|
| 238 |
+
|
| 239 |
+
---
|
| 240 |
+
|
| 241 |
+
## References
|
| 242 |
+
|
| 243 |
+
1. [PyTorch 2.0 Overview](https://docs.pytorch.org/assets/pytorch2-2.pdf)
|
| 244 |
+
2. [TorchDynamo: An Experiment in Dynamic Python Bytecode Transformation](https://dev-discuss.pytorch.org/t/torchdynamo-an-experiment-in-dynamic-python-bytecode-transformation/361)
|
| 245 |
+
3. [Depyf Walkthrough](https://depyf.readthedocs.io/en/latest/walk_through.html)
|
| 246 |
+
4. [Getting Started with PyTorch Compiler](https://docs.pytorch.org/docs/main/torch.compiler_get_started.html)
|
pkgs/MagiCompiler/docs/WhyMagiDepyf.md
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# magi_depyf
|
| 2 |
+
|
| 3 |
+
A structured inspector for `torch.compile` and MagiCompiler compilation
|
| 4 |
+
artifacts — decompiled source, Inductor kernels, guard conditions, graph break
|
| 5 |
+
chains, and more — all organized in a navigable directory tree.
|
| 6 |
+
|
| 7 |
+
## Why
|
| 8 |
+
|
| 9 |
+
### The problem: compilation is a black box
|
| 10 |
+
|
| 11 |
+
`torch.compile` and MagiCompiler accelerate models by transforming Python
|
| 12 |
+
functions through a deep pipeline: Dynamo captures bytecode into FX graphs,
|
| 13 |
+
a backend (Inductor, etc.) compiles them into optimized kernels, and the
|
| 14 |
+
runtime dispatches through a chain of cache entries, compiled functions, and
|
| 15 |
+
resume functions. The result is fast — but opaque.
|
| 16 |
+
|
| 17 |
+
When something goes wrong — a correctness bug, an unexpected graph break, a
|
| 18 |
+
performance cliff — you need to see what the compiler actually produced.
|
| 19 |
+
What bytecode did Dynamo generate? Which subgraphs went to Inductor vs.
|
| 20 |
+
eager fallback? What do the kernels look like? How do resume functions
|
| 21 |
+
chain together? MagiCompiler adds further layers: CUDA graph capture
|
| 22 |
+
regions, piecewise subgraph splits, and its own dispatch logic.
|
| 23 |
+
|
| 24 |
+
None of this is easily accessible.
|
| 25 |
+
|
| 26 |
+
### depyf: a pioneering effort
|
| 27 |
+
|
| 28 |
+
[depyf](https://github.com/thuml/depyf) was the first tool to address this,
|
| 29 |
+
hooking into `torch._dynamo` to dump decompiled source, FX graphs, and
|
| 30 |
+
Inductor output. It made `torch.compile` significantly more transparent.
|
| 31 |
+
|
| 32 |
+
### Why a new tool?
|
| 33 |
+
|
| 34 |
+
magi_depyf is purpose-built for MagiCompiler's compilation stack, and takes
|
| 35 |
+
a fundamentally different approach from depyf:
|
| 36 |
+
|
| 37 |
+
| | depyf | magi_depyf |
|
| 38 |
+
|-|-------|------------|
|
| 39 |
+
| **When artifacts are collected** | During compilation, via monkey-patching internal hooks | **After** compilation completes, by walking the final CacheEntry chain — a single, clean post-hoc pass |
|
| 40 |
+
| **Output structure** | Flat files (`full_code_0.py`, `__transformed_code_0_for_xxx.py`, …) — hard to navigate for complex models | **Hierarchical directory tree** mirroring the compilation structure: function → entries → compiled\_fns / resume\_fns |
|
| 41 |
+
| **MagiCompiler support** | None | First-class: per-subgraph Inductor source, CUDA graph mode, piecewise split metadata |
|
| 42 |
+
| **Decompiler** | Monolithic class supporting Python 3.8–3.12 | Modular handler registry; focused on 3.10+ |
|
| 43 |
+
|
| 44 |
+
### Key features
|
| 45 |
+
|
| 46 |
+
**See everything the compiler produced, in one structured tree.**
|
| 47 |
+
One context manager call gives you a complete, navigable dump: decompiled
|
| 48 |
+
bytecode (before and after Dynamo), Inductor kernel source for every compiled
|
| 49 |
+
function, guard conditions, bytecode metadata (`co_flags`, `co_consts`,
|
| 50 |
+
`dis` output), and the full resume function chain — recursively.
|
| 51 |
+
|
| 52 |
+
**MagiCompiler-native.**
|
| 53 |
+
Understands MagiCompiler's backend, extracting per-subgraph Inductor source,
|
| 54 |
+
CUDA graph capture mode (full / piecewise), and split metadata that
|
| 55 |
+
`torch.compile`-only tools cannot see.
|
| 56 |
+
|
| 57 |
+
**Post-hoc introspection.**
|
| 58 |
+
Artifacts are collected after compilation finishes, by walking the CacheEntry
|
| 59 |
+
linked list and extracting what Dynamo and the backend actually produced.
|
| 60 |
+
No monkey-patching of internal compilation hooks, no interference with the
|
| 61 |
+
compilation process itself.
|
| 62 |
+
|
| 63 |
+
## Usage
|
| 64 |
+
|
| 65 |
+
### `dump_src` — the main entry point
|
| 66 |
+
|
| 67 |
+
```python
|
| 68 |
+
import torch
|
| 69 |
+
from magi_compiler.magi_depyf.inspect import dump_src
|
| 70 |
+
|
| 71 |
+
@torch.compile
|
| 72 |
+
def toy_example(a, b):
|
| 73 |
+
x = a / (torch.abs(a) + 1)
|
| 74 |
+
if b.sum() < 0:
|
| 75 |
+
b = b * -1
|
| 76 |
+
return x * b
|
| 77 |
+
|
| 78 |
+
with dump_src("./output"):
|
| 79 |
+
for _ in range(100):
|
| 80 |
+
toy_example(torch.randn(10), torch.randn(10))
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
This produces:
|
| 84 |
+
|
| 85 |
+
```
|
| 86 |
+
output/
|
| 87 |
+
toy_example/
|
| 88 |
+
overview.md # Navigable index with links to everything
|
| 89 |
+
decompiled_code.py # Original function source
|
| 90 |
+
bytecode_info.txt # CodeType metadata + dis output
|
| 91 |
+
entry_0/
|
| 92 |
+
decompiled_code.py # Dynamo-transformed bytecode → Python
|
| 93 |
+
bytecode_info.txt # Transformed code metadata
|
| 94 |
+
guards.txt # Guard conditions for this cache entry
|
| 95 |
+
compiled_fns/
|
| 96 |
+
__compiled_fn_1_xxx.py # FX graph (readable)
|
| 97 |
+
__compiled_fn_1_xxx_post_grad.py # Post-grad graph
|
| 98 |
+
__compiled_fn_1_xxx_runnable.py # Inductor kernel source
|
| 99 |
+
resume_fns/
|
| 100 |
+
__resume_at_94_2/ # Resume function after graph break
|
| 101 |
+
overview.md
|
| 102 |
+
decompiled_code.py # Resume function source
|
| 103 |
+
bytecode_info.txt
|
| 104 |
+
entry_0/ # Dynamo compiles resume fns too
|
| 105 |
+
decompiled_code.py
|
| 106 |
+
guards.txt
|
| 107 |
+
compiled_fns/
|
| 108 |
+
...
|
| 109 |
+
__resume_at_104_3/
|
| 110 |
+
...
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
### Programmatic API
|
| 114 |
+
|
| 115 |
+
```python
|
| 116 |
+
from magi_compiler.magi_depyf import decompile
|
| 117 |
+
|
| 118 |
+
# Decompile a code object to Python source
|
| 119 |
+
source = decompile(my_function.__code__)
|
| 120 |
+
|
| 121 |
+
# Introspect a compiled function
|
| 122 |
+
from magi_compiler.magi_depyf.inspect import Introspector
|
| 123 |
+
info = Introspector.build_function_info(fn, fn_globals=fn.__globals__)
|
| 124 |
+
# info.entries[0].decompiled_src — decompiled transformed code
|
| 125 |
+
# info.entries[0].compiled_fns — backend-compiled functions
|
| 126 |
+
# info.entries[0].resume_fns — resume functions after graph breaks
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
### Tested model architectures
|
| 130 |
+
|
| 131 |
+
The test suite verifies the decompile → recompile round-trip on real model
|
| 132 |
+
structures, ensuring the decompiler produces correct source for Dynamo output:
|
| 133 |
+
|
| 134 |
+
| Category | Models |
|
| 135 |
+
|----------|--------|
|
| 136 |
+
| **PyTorch core** | MLP, Conv-BN-ReLU, MultiheadAttention, TransformerEncoderLayer, Embedding, residual blocks, depthwise separable conv |
|
| 137 |
+
| **Diffusion blocks** | GEGLU, RMSNorm, sinusoidal embeddings, cross-attention, AdaLayerNorm, DiT blocks, timestep MLP |
|
| 138 |
+
| **HuggingFace transformers** | BERT, GPT-2, T5 encoder (tiny configs) |
|
| 139 |
+
| **HuggingFace diffusers** | Attention (self / cross), BasicTransformerBlock |
|
| 140 |
+
| **timm** | ResNet-18, MobileNetV3, EfficientNet-B0, ViT, ConvNeXt, Swin, DeiT |
|
| 141 |
+
| **Graph breaks** | `print()` breaks, explicit `graph_break()`, multi-break chains — with recursive resume function round-tripping |
|
| 142 |
+
|
| 143 |
+
## Code structure
|
| 144 |
+
|
| 145 |
+
```
|
| 146 |
+
magi_depyf/
|
| 147 |
+
├── __init__.py # Public API: decompile, safe_decompile
|
| 148 |
+
│
|
| 149 |
+
├── decompile/ # Bytecode → Python source (no torch dependency)
|
| 150 |
+
│ ├── decompiler.py # Decompiler: orchestrates the pipeline
|
| 151 |
+
│ ├── recompiler.py # CodeRecompiler: decompile → compile() → CodeType
|
| 152 |
+
│ ├── bytecode/
|
| 153 |
+
│ │ ├── instruction.py # Mutable wrapper over dis.Instruction
|
| 154 |
+
│ │ ├── source_emitter.py # Stack machine + source accumulator
|
| 155 |
+
│ │ ├── decompile_context.py # Read-only context for handlers
|
| 156 |
+
│ │ ├── handler_registry.py # Opcode → handler dispatch table
|
| 157 |
+
│ │ └── handlers/ # One module per opcode category
|
| 158 |
+
│ └── postprocess/ # Source-level cleanup passes
|
| 159 |
+
│
|
| 160 |
+
└── inspect/ # torch.compile introspection (requires torch)
|
| 161 |
+
├── dump_src.py # dump_src(): the main entry point
|
| 162 |
+
├── introspect.py # Introspector: walk CacheEntry chain
|
| 163 |
+
├── model.py # Data model (FunctionInfo, EntryInfo, ...)
|
| 164 |
+
├── writer.py # Serialize to directory tree
|
| 165 |
+
├── session.py # CaptureSession: bytecode hook lifecycle
|
| 166 |
+
└── result.py # CaptureResult: one compilation event
|
| 167 |
+
```
|
| 168 |
+
|
| 169 |
+
## Compatibility
|
| 170 |
+
|
| 171 |
+
| Requirement | Version |
|
| 172 |
+
|-------------|---------|
|
| 173 |
+
| **Python** | >= 3.10 |
|
| 174 |
+
| **PyTorch** | >= 2.0 (requires `torch._dynamo` internals) |
|
| 175 |
+
| **depyf** | Optional; used as fallback by `safe_decompile` |
|
pkgs/MagiCompiler/docs/assets/submod_0_rank_0.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:583d97460eb7ebf48efbdeb7a6ae424f640de9ff99e6f33bbadef432583f40d3
|
| 3 |
+
size 16122
|
pkgs/MagiCompiler/magi_compiler/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .api import magi_compile
|
| 16 |
+
|
| 17 |
+
__all__ = ["magi_compile"]
|
pkgs/MagiCompiler/magi_compiler/_cache_data_cls.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import dataclasses
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclasses.dataclass(frozen=True)
class CacheHandle:
    """Immutable handle pairing a cache lookup key with an artifact path."""

    # Cache lookup key; None presumably means no key is available — confirm at call sites.
    key: str | None
    # Filesystem path of the cached artifact.
    path: str
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclasses.dataclass(frozen=True)
class CacheEntry:
    """Immutable record identifying one compiled-graph cache entry."""

    # Runtime shape the entry was compiled for; None likely denotes a
    # shape-agnostic/dynamic entry — confirm against the cache writer.
    runtime_shape: int | None
    # Index of the graph within the compiled model.
    graph_index: int
    # Name of the compiler backend that produced this entry.
    backend_name: str
|
pkgs/MagiCompiler/magi_compiler/api.py
ADDED
|
@@ -0,0 +1,666 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import copy
|
| 16 |
+
import gc
|
| 17 |
+
import inspect
|
| 18 |
+
import os
|
| 19 |
+
from contextlib import contextmanager
|
| 20 |
+
from typing import Callable, TypeVar, get_args, get_origin, overload
|
| 21 |
+
from unittest.mock import patch
|
| 22 |
+
|
| 23 |
+
import magi_compiler.utils.envs as envs
|
| 24 |
+
import torch
|
| 25 |
+
from magi_compiler.cuda.cudart import pin_memory_in_place
|
| 26 |
+
from magi_compiler.magi_compiler_base import MagiCompilerBase
|
| 27 |
+
from magi_compiler.utils import compilation_counter, magi_logger
|
| 28 |
+
from magi_compiler.utils.compile_time_monitor import CompileMonitor
|
| 29 |
+
from torch import distributed as dist
|
| 30 |
+
from torch import nn
|
| 31 |
+
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
|
| 32 |
+
|
| 33 |
+
from .config import CompileConfig, CompileMode, get_compile_config
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# =============================================================================
|
| 37 |
+
# Workaround: TorchInductor autotune get_raw_stream
|
| 38 |
+
# =============================================================================
|
| 39 |
+
# TorchInductor autotune code blocks may reference get_raw_stream() without
|
| 40 |
+
# defining it, causing "name 'get_raw_stream' is not defined" at runtime.
|
| 41 |
+
# Register it as a builtin so the exec'd autotune snippets can always find it.
|
| 42 |
+
def _patch_get_raw_stream():
|
| 43 |
+
try:
|
| 44 |
+
import builtins
|
| 45 |
+
|
| 46 |
+
from torch._C import _cuda_getCurrentRawStream as _get_raw_stream
|
| 47 |
+
except Exception:
|
| 48 |
+
return
|
| 49 |
+
if not hasattr(builtins, "get_raw_stream"):
|
| 50 |
+
builtins.get_raw_stream = _get_raw_stream
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
_patch_get_raw_stream()
|
| 54 |
+
|
| 55 |
+
# =============================================================================
|
| 56 |
+
# Dynamo Config Isolation
|
| 57 |
+
# =============================================================================
|
| 58 |
+
# Capture the default dynamo config at module load time (before any torch.compile).
|
| 59 |
+
# This ensures we have a "clean" baseline config that hasn't been modified by
|
| 60 |
+
# external torch.compile calls (e.g., with dynamic=True).
|
| 61 |
+
_DEFAULT_DYNAMO_CONFIG: dict = torch._dynamo.config.get_config_copy()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@contextmanager
|
| 65 |
+
def _isolated_dynamo_config():
|
| 66 |
+
"""
|
| 67 |
+
Context manager that provides an isolated dynamo config environment.
|
| 68 |
+
"""
|
| 69 |
+
with torch._dynamo.config.patch(**_DEFAULT_DYNAMO_CONFIG):
|
| 70 |
+
yield
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
_T = TypeVar("_T", bound=type[nn.Module])
_W = TypeVar("_W", bound="MagiCompilerBase")


# NOTE: `Callable[None, bool]` (the original spelling) is invalid typing —
# typing.Callable requires a list of argument types or `...`; subscripting
# with None raises TypeError when the annotation is evaluated. A nullary
# predicate is spelled `Callable[[], bool]`.
@overload
def magi_compile(*, enable_if: Callable[[], bool] | None = None) -> Callable[[_T], _T]:
    ...


@overload
def magi_compile(*, dynamic_arg_dims: dict[str, int | list[int]] | None) -> Callable[[_T], _T]:
    ...


@overload
def magi_compile(*, config_patch: Callable[[CompileConfig], CompileConfig] | None = None) -> Callable[[_T], _T]:
    ...


@overload
def magi_compile(cls: _T) -> _T:
    ...


def magi_compile(
    cls: _T | None = None,
    *,
    model_tag: str | None = None,
    dynamic_arg_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[], bool] | None = None,
    config_patch: Callable[[CompileConfig], CompileConfig] | None = None,
) -> Callable[[_T], _T] | _T:
    """
    A decorator to add support for compiling the forward method of a class.

    Usage:
    1. use directly as a decorator without arguments:
    ```python
    @magi_compile
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    ```

    2. use as a decorator with arguments:
    ```python
    @magi_compile(dynamic_arg_dims={"x": 0, "y": 0})
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    ```

    Arguments:
    - model_tag: optional tag in cache path (e.g. "wan_ti2v"). If not set, class name is used.
      Path segment: model_{idx}_{model_tag}_rank_{rank}.
    - dynamic_arg_dims: a dictionary that maps argument names to the dynamic
      dimensions of the argument. The dynamic dimensions can be either a single
      integer or a list of integers.
    - enable_if: a zero-argument callable returning whether to compile the model.
      This is useful if you want to compile the model only when certain conditions are met.
    - config_patch: optional callable that receives the global CompileConfig and
      returns the (possibly modified) config to use for this model.

    Notes:
    - dynamic_arg_dims will be inferred from the type annotation of the forward method if not provided,
      if the argument is annotated as `torch.Tensor` or `Optional[torch.Tensor]`,
      the first dimension will be marked as dynamic.

    - if an argument is `None`, it should always be passed as `None` during
      the lifetime of the model, otherwise, it cannot be captured as a single
      computation graph.
    """

    def cls_decorator_helper(cls: _T) -> _T:
        # Infer dynamic dims from annotations only when the caller did not
        # provide them explicitly.
        nonlocal dynamic_arg_dims
        dynamic_arg_dims = dynamic_arg_dims or _infer_dynamic_arg_dims(cls)

        # Accuracy check
        assert hasattr(cls, "forward"), "decorated class should have a forward method."
        assert len(dynamic_arg_dims) > 0, (
            "No dynamic dimensions found in the forward method of " f"{cls}. Please provide dynamic_arg_dims explicitly."
        )
        for k in dynamic_arg_dims:
            assert k in inspect.signature(cls.forward).parameters, f"Argument {k} not found in the forward method of {cls}"

        return _magi_compile(cls, dynamic_arg_dims, enable_if, config_patch, model_tag=model_tag)

    if cls is not None:
        # use `magi_compile` as a decorator without arguments, cls is the class to be decorated
        assert isinstance(cls, type)
        return cls_decorator_helper(cls)

    return cls_decorator_helper
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def offload(obj):
    """Recursively move every tensor inside ``obj`` to CPU memory.

    Dicts, lists and tuples are rebuilt with their elements offloaded;
    tensors are copied to CPU; any other value is returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        return obj.cpu()
    if isinstance(obj, dict):
        return {key: offload(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        container_type = type(obj)
        return container_type(offload(element) for element in obj)
    return obj
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def _magi_compile(
    cls: _T,
    dynamic_arg_dims: dict[str, int | list[int]],
    enable_if: Callable[[], bool] | None = None,
    config_patch: Callable[[CompileConfig], CompileConfig] | None = None,
    model_tag: str | None = None,
) -> _T:
    """
    A decorator to add support for compiling the forward method of a class.

    Injects ``MagiCompilerBase`` into ``cls``'s bases, optionally installs
    CPU-offload handling for parameters/buffers, and wraps ``__init__`` and
    ``__call__`` so the first forward call triggers (or loads) a compiled
    version of the model.
    """
    # Idempotence guard: the class was already instrumented by this decorator.
    if MagiCompilerBase in cls.__bases__:
        return cls

    # take care of method resolution order, make sure super().__init__ is called on the base class
    # other than MagiCompilerBase
    cls.__bases__ = cls.__bases__ + (MagiCompilerBase,)

    if get_compile_config().offload_config.model_cpu_offload:
        magi_logger.info(f"Enabling CPU offload for {cls}")
        _orig_apply = cls._apply

        def _cpu_apply(self, fn):
            # Only the first _apply call performs the offload setup; later
            # calls go straight through to the original _apply.
            if getattr(self, "_magi_offloaded_once", False):
                return _orig_apply(self, fn)

            # First, move all parameters/buffers to CPU
            def _force_cpu(t):
                return fn(t).cpu()

            _orig_apply(self, _force_cpu)

            # create shared memory tensors for all parameters/buffers on CPU
            if dist.is_initialized():
                # NOTE(review): assumes all local ranks see the same path for
                # MAGI_SHARED_BIN_PATH — confirm on multi-node setups.
                local_rank = int(os.environ.get("LOCAL_RANK", 0))
                full_state_dict = self.state_dict()

                grouped_params = {}  # {dtype: [(name, tensor), ...]}
                for name, tensor in full_state_dict.items():
                    if tensor.device.type == 'cpu':
                        dt = tensor.dtype
                        if dt not in grouped_params:
                            grouped_params[dt] = []
                        grouped_params[dt].append((name, tensor))

                shared_state_dict = {}
                # Keep the flat shared buffers alive for the model's lifetime;
                # the per-parameter tensors below are views into them.
                self._magi_giant_buffers = []

                dist.barrier()

                for dtype, param_list in grouped_params.items():
                    dtype_str = str(dtype).split('.')[-1]
                    shared_bin_path = (
                        f"{envs.MAGI_SHARED_BIN_PATH}/magi_model_shared_{dtype_str}_{self.__class__.__name__}.bin"
                    )

                    total_numel = sum(t.numel() for _, t in param_list)

                    # Rank 0 flattens all same-dtype params into one file; the
                    # other ranks then map that file shared instead of holding
                    # their own copies.
                    if local_rank == 0:
                        flat_buffer = torch.zeros(total_numel, dtype=dtype)
                        offset = 0
                        for _, tensor in param_list:
                            numel = tensor.numel()
                            flat_buffer[offset : offset + numel].copy_(tensor.view(-1))
                            offset += numel

                        # Reinterpret bf16/fp8 as same-width integer dtypes
                        # before .numpy(), since numpy has no such dtypes.
                        if dtype == torch.bfloat16:
                            flat_buffer.view(torch.int16).numpy().tofile(shared_bin_path)
                        elif dtype.itemsize == 1 and dtype.is_floating_point:
                            # fp8
                            flat_buffer.view(torch.uint8).numpy().tofile(shared_bin_path)
                        else:
                            flat_buffer.numpy().tofile(shared_bin_path)

                        del flat_buffer
                        gc.collect()

                    dist.barrier()

                    # Every rank maps the same file with shared=True, so the
                    # physical pages are shared across processes.
                    giant_shared_tensor = torch.from_file(
                        shared_bin_path, shared=True, size=total_numel, dtype=dtype, device="cpu"
                    )
                    self._magi_giant_buffers.append(giant_shared_tensor)

                    pin_memory_in_place(giant_shared_tensor)

                    # Re-slice per-parameter views out of the flat buffer.
                    offset = 0
                    for name, original_tensor in param_list:
                        numel = original_tensor.numel()
                        shared_param = giant_shared_tensor[offset : offset + numel].view(original_tensor.shape)

                        if original_tensor.requires_grad:
                            shared_param.requires_grad_(True)

                        shared_state_dict[name] = shared_param
                        offset += numel

                    dist.barrier()
                    # All ranks have mapped the file by now (barrier above),
                    # so the backing file can be removed.
                    if local_rank == 0:
                        if os.path.exists(shared_bin_path):
                            os.remove(shared_bin_path)

                self.load_state_dict(shared_state_dict, assign=True)

            else:
                # Single-process path: just pin memory for fast H2D transfers.

                def _pinner(t):
                    return t.pin_memory()

                _orig_apply(self, _pinner)

            self._magi_offloaded_once = True
            return self

        cls._apply = _cpu_apply

    old_init = cls.__init__

    def __init__(self: _W, *args, **kwargs):
        old_init(self, *args, **kwargs)
        compile_config = get_compile_config()
        if config_patch is not None:
            compile_config = config_patch(compile_config)
        # deepcopy the compile config to avoid modifying the original compile config
        self.compile_config = copy.deepcopy(compile_config)

        enable_compile = enable_if is None or enable_if()
        self.enable_compile = self.compile_config.compile_mode != CompileMode.NONE and enable_compile
        if not self.enable_compile:
            return

        compilation_counter.num_models_seen += 1
        self.compile_config.model_idx = compilation_counter.num_models_seen
        self.compile_config.model_tag = model_tag if model_tag is not None else self.__class__.__name__
        MagiCompilerBase.__init__(self, compile_config=self.compile_config)

    cls.__init__ = __init__

    old_call = cls.__call__

    def __call__(self: _W, *args, **kwargs):
        ### Step1: Run compiled module directly if disable compile or captured before ###
        if self.compile_config.offload_config.model_cpu_offload and self.compiled_code is None:
            args = offload(args)
            kwargs = offload(kwargs)

        if not self.enable_compile or torch.compiler.is_compiling():
            # Skip compiling the model if inside the compilation process.
            return old_call(self, *args, **kwargs)

        if self.compiled_code is not None:
            # Run the compiled function if compiled code is available.
            with self.dispatch_to_compiled_fwd(mode="jit"):
                return old_call(self, *args, **kwargs)

        if envs.MAGI_AOT_COMPILE:
            # Try load AOT artifacts from cache and run directly.
            self.aot_compiled_fn = self.try_load_aot_compile_artifacts()
            if self.aot_compiled_fn is not None:
                with self.dispatch_to_compiled_fwd(mode="aot"):
                    return old_call(self, *args, **kwargs)

        ### Step2: Mark dynamic shapes for the first compilation ###
        bound_args = inspect.signature(self.__class__.forward).bind(self, *args, **kwargs)
        bound_args.apply_defaults()
        for k, dims in dynamic_arg_dims.items():
            arg = bound_args.arguments.get(k)
            if arg is None:
                continue
            dims = [dims] if isinstance(dims, int) else dims
            assert isinstance(arg, torch.Tensor), f"Unsupported dynamic dim {dims} for argument {k} with type {type(arg)}."
            # Normalize negative dims before handing them to dynamo.
            dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
            torch._dynamo.mark_dynamic(arg, dims)

        ### Step3: Start compiling the model ###
        magi_logger.info(f"Start compiling function {self.original_code_object}")

        CompileMonitor().start(
            self.compile_config.compile_mode == CompileMode.MAGI_COMPILE, self.compile_config.debug_dump_path()
        )
        # Dynamo reuse the compilation across instances, but we need to make sure the compiled code is not reused.
        torch._dynamo.eval_frame.remove_from_cache(self.original_code_object)

        with (
            _hijack_inline_call_to_collect_traced_files(self),
            patch.object(torch.compiler.config, "dynamic_sources", self.compile_config.dynamic_sources),
            patch.object(torch._dynamo.config, "enable_cpp_symbolic_shape_guards", False),
            # Let mark_dynamic take effect on tensors reached through module attribute chains
            # (the default True forces module-property tensors to static shapes, ignoring mark_dynamic).
            patch.object(torch._dynamo.config, "force_nn_module_property_static_shapes", False),
            patch.dict(
                os.environ, {"TORCHINDUCTOR_CACHE_DIR": (self.compile_config.cache_dump_path() / "inductor_cache").as_posix()}
            ),
        ):
            if envs.MAGI_AOT_COMPILE:
                self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
                self.aot_compiled_fn.save_compiled_function(self.aot_compilation_path)
                with self.dispatch_to_compiled_fwd(mode="aot"):
                    output = old_call(self, *args, **kwargs)
            else:
                with patch.object(self, "forward", self.jit_compile):
                    output = old_call(self, *args, **kwargs)

        return output

    # Wrap the whole __call__ with @torch.compiler.disable and _isolated_dynamo_config
    # so magi compile keeps working independently even when nested inside an external torch.compile.
    isolated_call = _isolated_dynamo_config()(__call__)
    cls.__call__ = torch.compiler.disable(isolated_call)
    return cls
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
# Collect all relevant files traced by Dynamo, re-compile the model when any of these files change.
|
| 387 |
+
# 1. the file containing the top-level forward function
|
| 388 |
+
# 2. hijack function to know all the functions called during Dynamo tracing, every time Dynamo sees a function call, it will inline
|
| 389 |
+
# the function by calling InliningInstructionTranslator.inline_call_
|
| 390 |
+
def _hijack_inline_call_to_collect_traced_files(owner: _W):
|
| 391 |
+
owner.compile_config.traced_files.add(owner.original_code_object.co_filename)
|
| 392 |
+
inline_call = InliningInstructionTranslator.inline_call_
|
| 393 |
+
|
| 394 |
+
def patched_inline_call(self_):
|
| 395 |
+
code = self_.f_code
|
| 396 |
+
owner.compile_config.traced_files.add(code.co_filename)
|
| 397 |
+
return inline_call(self_)
|
| 398 |
+
|
| 399 |
+
return patch.object(InliningInstructionTranslator, "inline_call_", patched_inline_call)
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def _infer_dynamic_arg_dims(cls: _T) -> dict[str, int | list[int]]:
    """Derive default dynamic dims from ``cls.forward``'s type annotations.

    Every parameter annotated exactly as ``torch.Tensor`` or
    ``torch.Tensor | None`` is treated as dynamic along dimension 0.
    """
    tensor_annotations = (torch.Tensor, torch.Tensor | None)
    params = inspect.signature(cls.forward).parameters
    inferred_dynamic_arg_dims = {
        name: 0 for name, param in params.items() if param.annotation in tensor_annotations
    }

    magi_logger.info(f"Inferred dynamic dimensions for forward method of {cls}: {list(inferred_dynamic_arg_dims.keys())}")
    return inferred_dynamic_arg_dims
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def _get_num_outputs_from_return_annotation(fn: Callable) -> int:
|
| 414 |
+
"""
|
| 415 |
+
Get the number of outputs from the function's return type annotation.
|
| 416 |
+
|
| 417 |
+
Returns:
|
| 418 |
+
- 1 if the return type is a single Tensor
|
| 419 |
+
- N if the return type is tuple[Tensor, Tensor, ...] with N elements
|
| 420 |
+
- 1 if no annotation or unrecognized annotation (default to single output)
|
| 421 |
+
"""
|
| 422 |
+
sig = inspect.signature(fn)
|
| 423 |
+
return_annotation = sig.return_annotation
|
| 424 |
+
|
| 425 |
+
if return_annotation is inspect.Parameter.empty:
|
| 426 |
+
return 1
|
| 427 |
+
|
| 428 |
+
# Check if it's a tuple type (e.g., tuple[Tensor, Tensor])
|
| 429 |
+
origin = get_origin(return_annotation)
|
| 430 |
+
if origin is tuple:
|
| 431 |
+
args = get_args(return_annotation)
|
| 432 |
+
# Filter out ellipsis (for variable-length tuples like tuple[Tensor, ...])
|
| 433 |
+
if args and args[-1] is not ...:
|
| 434 |
+
return len(args)
|
| 435 |
+
return 1
|
| 436 |
+
|
| 437 |
+
return 1
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def _generate_op_name(fn: Callable) -> str:
|
| 441 |
+
"""
|
| 442 |
+
Generate a unique operator name from function's name and source file.
|
| 443 |
+
|
| 444 |
+
The generated name follows the format: namespace::op_name
|
| 445 |
+
- namespace: derived from the source file path (module-like structure)
|
| 446 |
+
- op_name: the function name
|
| 447 |
+
|
| 448 |
+
Example:
|
| 449 |
+
Function `_my_custom_op` in file `/path/to/my_module.py`
|
| 450 |
+
-> "my_module::_my_custom_op"
|
| 451 |
+
"""
|
| 452 |
+
import re
|
| 453 |
+
from pathlib import Path
|
| 454 |
+
|
| 455 |
+
func_name = fn.__name__
|
| 456 |
+
|
| 457 |
+
# Get the source file path
|
| 458 |
+
try:
|
| 459 |
+
source_file = inspect.getfile(fn)
|
| 460 |
+
# Extract the file stem (without extension) as namespace
|
| 461 |
+
namespace = Path(source_file).stem
|
| 462 |
+
# Clean up namespace: replace invalid characters with underscores
|
| 463 |
+
namespace = re.sub(r"[^a-zA-Z0-9_]", "_", namespace)
|
| 464 |
+
except (TypeError, OSError):
|
| 465 |
+
# If we can't get the source file, use a default namespace
|
| 466 |
+
namespace = "magi_custom"
|
| 467 |
+
|
| 468 |
+
return f"{namespace}::{func_name}"
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def _create_identity_meta_fn(fn: Callable) -> Callable:
    """Build a fallback meta function mirroring input tensor metadata.

    The number of outputs is read from ``fn``'s return annotation; output i
    gets the metadata (shape, dtype, device) of the i-th tensor-valued
    argument, in declaration order.

    For example, with:
        def my_op(a: Tensor, b: Tensor, scale: float) -> tuple[Tensor, Tensor]:
    the generated meta function returns:
        (torch.empty_like(a), torch.empty_like(b))
    """
    num_outputs = _get_num_outputs_from_return_annotation(fn)
    sig = inspect.signature(fn)
    # Parameter names in declaration order, skipping any leading 'self'.
    param_names = [p for p in sig.parameters if p != "self"]

    def identity_meta_fn(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()

        # Pick out the first `num_outputs` tensor-valued arguments.
        tensor_args = []
        for param in param_names:
            value = bound.arguments.get(param)
            if isinstance(value, torch.Tensor):
                tensor_args.append(value)
            if len(tensor_args) >= num_outputs:
                break

        if len(tensor_args) < num_outputs:
            raise ValueError(
                f"identity_meta_fn requires at least {num_outputs} tensor inputs to match "
                f"{num_outputs} outputs, but only found {len(tensor_args)} tensor inputs. "
                f"Please provide a custom infer_output_meta_fn."
            )

        # Single output unwraps to a bare tensor, matching op conventions.
        if num_outputs == 1:
            return torch.empty_like(tensor_args[0])
        return tuple(torch.empty_like(t) for t in tensor_args[:num_outputs])

    return identity_meta_fn
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
def _create_meta_fn_from_param_names(fn: Callable, param_names: list[str]) -> Callable:
|
| 519 |
+
"""
|
| 520 |
+
Create a meta function that returns torch.empty_like() for each specified parameter.
|
| 521 |
+
|
| 522 |
+
This is useful when output tensors have the same shape/dtype/device as specific input
|
| 523 |
+
parameters, but not necessarily in positional order.
|
| 524 |
+
|
| 525 |
+
Example:
|
| 526 |
+
param_names = ["weight", "bias"]
|
| 527 |
+
def my_op(grad: Tensor, weight: Tensor, bias: Tensor) -> tuple[Tensor, Tensor]:
|
| 528 |
+
...
|
| 529 |
+
|
| 530 |
+
Generated meta function returns:
|
| 531 |
+
(torch.empty_like(weight), torch.empty_like(bias))
|
| 532 |
+
"""
|
| 533 |
+
sig = inspect.signature(fn)
|
| 534 |
+
|
| 535 |
+
def meta_fn(*args, **kwargs):
|
| 536 |
+
# Bind arguments to get a mapping of param_name -> value
|
| 537 |
+
bound = sig.bind(*args, **kwargs)
|
| 538 |
+
bound.apply_defaults()
|
| 539 |
+
|
| 540 |
+
# Collect tensors for each specified parameter name
|
| 541 |
+
tensor_outputs = []
|
| 542 |
+
for name in param_names:
|
| 543 |
+
if name not in bound.arguments:
|
| 544 |
+
raise ValueError(
|
| 545 |
+
f"Parameter '{name}' not found in function signature. "
|
| 546 |
+
f"Available parameters: {list(bound.arguments.keys())}"
|
| 547 |
+
)
|
| 548 |
+
arg = bound.arguments[name]
|
| 549 |
+
if not isinstance(arg, torch.Tensor):
|
| 550 |
+
raise ValueError(
|
| 551 |
+
f"Parameter '{name}' is not a Tensor (got {type(arg).__name__}). "
|
| 552 |
+
f"infer_output_meta_fn list should only contain tensor parameter names."
|
| 553 |
+
)
|
| 554 |
+
tensor_outputs.append(torch.empty_like(arg))
|
| 555 |
+
|
| 556 |
+
# Return single tensor or tuple based on number of outputs
|
| 557 |
+
if len(tensor_outputs) == 1:
|
| 558 |
+
return tensor_outputs[0]
|
| 559 |
+
return tuple(tensor_outputs)
|
| 560 |
+
|
| 561 |
+
return meta_fn
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
def magi_register_custom_op(
    name: str | None = None,
    mutates_args: tuple[str, ...] = (),
    infer_output_meta_fn: Callable | list[str] | None = None,
    setup_context_fn: Callable | None = None,
    backward_fn: Callable | None = None,
):
    """
    A unified decorator to register a custom operator with PyTorch's library.

    This decorator combines the functionality of:
    - @torch.library.custom_op
    - @torch.library.register_fake
    - fn.register_autograd

    Arguments:
        name: The fully qualified name of the operator (e.g., "namespace::op_name").
            If None, auto-generated from the function name and source file.
        mutates_args: Tuple of argument names that are mutated by the operator.
        infer_output_meta_fn: Specifies output tensor metadata (shape, dtype, device) for tracing.
            - None (default): Assumes each output has the same metadata as the corresponding
              input tensor (1st output matches 1st tensor input, 2nd matches 2nd, etc.).
            - list[str]: Parameter names whose metadata to use for outputs.
              E.g., ["weight", "bias"] means output[0] has same shape as `weight`,
              output[1] has same shape as `bias`.
            - Callable: Custom function with same signature as the op, returns torch.empty_like()
              tensors matching the expected output shapes.
        setup_context_fn: Function to save tensors/values for backward.
            Signature: setup_context_fn(ctx, inputs, output)
        backward_fn: Function to compute gradients.
            Signature: backward_fn(ctx, *grad_outputs) -> tuple of gradients

    Returns:
        The registered custom operator function.

    Examples:
        1. Basic usage (forward only, auto-generated name and meta function):

        >>> @magi_register_custom_op()
        ... def my_relu(x: torch.Tensor) -> torch.Tensor:
        ...     return torch.maximum(x, torch.zeros_like(x))

        2. Multiple outputs with explicit output metadata via parameter names:

        >>> @magi_register_custom_op(
        ...     infer_output_meta_fn=["weight", "bias"],  # output shapes match weight and bias
        ... )
        ... def compute_gradients(
        ...     grad_output: torch.Tensor,
        ...     weight: torch.Tensor,
        ...     bias: torch.Tensor,
        ... ) -> tuple[torch.Tensor, torch.Tensor]:
        ...     grad_weight = grad_output.sum(dim=0).view_as(weight)
        ...     grad_bias = grad_output.sum(dim=0).view_as(bias)
        ...     return grad_weight, grad_bias

        3. Full custom op with autograd support:

        >>> def _square_meta(x: torch.Tensor) -> torch.Tensor:
        ...     return torch.empty_like(x)
        ...
        >>> def _square_setup_context(ctx, inputs, output):
        ...     (x,) = inputs
        ...     ctx.save_for_backward(x)
        ...
        >>> def _square_backward(ctx, grad_output):
        ...     (x,) = ctx.saved_tensors
        ...     return grad_output * 2 * x
        ...
        >>> @magi_register_custom_op(
        ...     name="my_ops::square",
        ...     infer_output_meta_fn=_square_meta,
        ...     setup_context_fn=_square_setup_context,
        ...     backward_fn=_square_backward,
        ... )
        ... def square(x: torch.Tensor) -> torch.Tensor:
        ...     return x * x
    """

    def decorator(fn: Callable) -> Callable:
        # Auto-generate name if not provided
        op_name = name if name is not None else _generate_op_name(fn)

        # Step 1: Register the custom op with torch.library.custom_op.
        # This registers `fn` as the CPU/CUDA implementation under `op_name`.
        registered_op = torch.library.custom_op(op_name, mutates_args=mutates_args)(fn)

        # Step 2: Register the output meta inference function.
        # The fake/meta kernel is what Dynamo/FakeTensor tracing calls instead
        # of the real implementation; it must only produce metadata.
        # Determine meta_fn based on the type of infer_output_meta_fn
        if infer_output_meta_fn is None:
            meta_fn = _create_identity_meta_fn(fn)
        elif isinstance(infer_output_meta_fn, list):
            meta_fn = _create_meta_fn_from_param_names(fn, infer_output_meta_fn)
        else:
            meta_fn = infer_output_meta_fn
        torch.library.register_fake(op_name)(meta_fn)

        # Step 3: Register autograd if backward_fn is provided.
        # Note: setup_context_fn is only consulted here; without a backward_fn
        # it is silently ignored.
        if backward_fn is not None:
            registered_op.register_autograd(backward_fn, setup_context=setup_context_fn)

        return registered_op

    return decorator
|
pkgs/MagiCompiler/magi_compiler/compile_artifacts.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
| 17 |
+
|
| 18 |
+
import inspect
|
| 19 |
+
import pickle
|
| 20 |
+
from unittest.mock import patch
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from torch.utils._pytree import tree_map_only
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
from torch._dynamo.aot_compile import SerializableCallable
|
| 27 |
+
except ImportError:
|
| 28 |
+
SerializableCallable = object
|
| 29 |
+
|
| 30 |
+
assert isinstance(SerializableCallable, type)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class MagiSerializableFunction(SerializableCallable):
    """
    A wrapper around a compiled function by vllm. It will forward the tensor
    inputs to the compiled function and return the result.
    It also implements a serialization interface to support PyTorch's precompile
    with custom backend, so that we can save and load the compiled function on
    disk. There's no need to wrap around the compiled function if we don't want
    to serialize them in particular cases.
    Right now serialization for the custom backend is done via
    serializing the Dynamo fx graph plus example inputs.
    """

    def __init__(self, graph_module, example_inputs, model_tag, optimized_call):
        # graph_module: the Dynamo-captured torch.fx.GraphModule.
        # example_inputs: inputs used for (re)compilation; tensors may be
        #   masked to None after deserialization (see serialize below).
        # optimized_call: the callable actually invoked by __call__.
        assert isinstance(graph_module, torch.fx.GraphModule)
        self.graph_module = graph_module
        self.example_inputs = example_inputs
        self.model_tag = model_tag
        self.optimized_call = optimized_call
        self.shape_env = None
        # Grab the ShapeEnv from the first symbolic input, if any, so dynamic
        # shape information is available alongside the graph.
        sym_input = next((i for i in self.example_inputs if isinstance(i, torch.SymInt)), None)
        if sym_input is not None:
            self.shape_env = sym_input.node.shape_env

    def __call__(self, *args, **kwargs):
        # Pure forwarding; `optimized_call` may be swapped out after the first
        # post-deserialization invocation (see deserialize_compile_artifacts).
        return self.optimized_call(*args, **kwargs)

    @classmethod
    def serialize_compile_artifacts(cls, compiled_fn: "MagiSerializableFunction") -> bytes:
        """Serialize the wrapper to bytes: fx graph + (masked) example inputs.

        The live callable and ShapeEnv are dropped; they are rebuilt lazily on
        first call after deserialization.
        """
        import sympy
        from torch._subclasses import FakeTensorMode
        from torch.fx._graph_pickler import GraphPickler, Options

        state = compiled_fn.__dict__.copy()
        # Not picklable / rebuilt at load time.
        state.pop("optimized_call")
        state.pop("shape_env")
        # These node metas hold references to live module/function objects and
        # are not needed for re-compilation.
        for node in state["graph_module"].graph.nodes:
            node.meta.pop("source_fn_stack", None)
            node.meta.pop("nn_module_stack", None)

        graph_reducer_override = GraphPickler.reducer_override

        def _graph_reducer_override(self, obj):
            # Torch-registered sympy functions carry their own unpickler hook.
            if inspect.isclass(obj) and issubclass(obj, sympy.Function) and hasattr(obj, "_torch_unpickler"):
                return obj._torch_unpickler, (obj._torch_handler_name,)
            # FakeTensorMode instances are replaced with None; a fresh mode is
            # created on load.
            if isinstance(obj, FakeTensorMode):
                return type(None), ()
            return graph_reducer_override(self, obj)

        # Mask off tensor inputs since they are large and not needed.
        state["example_inputs"] = tree_map_only(torch.Tensor, lambda _: None, state["example_inputs"])
        with patch.object(GraphPickler, "reducer_override", _graph_reducer_override):
            state["graph_module"] = GraphPickler.dumps(state["graph_module"], Options(ops_filter=None))
            state["example_inputs"] = GraphPickler.dumps(state["example_inputs"])
        return pickle.dumps(state)

    @classmethod
    def deserialize_compile_artifacts(cls, data: bytes) -> "MagiSerializableFunction":
        """Rebuild a wrapper from bytes produced by serialize_compile_artifacts.

        The returned object recompiles lazily: the first call re-runs the Magi
        backend (expected to hit its on-disk cache), then replaces
        `optimized_call` so later calls go straight to the compiled function.
        """
        from torch._guards import TracingContext, tracing
        from torch._subclasses import FakeTensorMode
        from torch.fx._graph_pickler import GraphPickler
        from torch.fx.experimental.symbolic_shapes import ShapeEnv

        from .config import get_compile_config
        from .magi_backend import MagiBackend

        state = pickle.loads(data)
        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
        state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
        state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
        magi_backend = MagiBackend(get_compile_config(), state["model_tag"])

        def optimized_call(*example_inputs):
            """
            On the first run of the optimized call, we rerun the compiler
            backend which should result in a cache hit. After the backend
            call returns, we just do a one-time replacement of the optimized
            call with the compiled function, so that subsequent calls are on
            the AOT compiled path.
            """
            # `fn` is late-bound: it is assigned below, before this closure can
            # ever run. Serialized example inputs have their tensors masked to
            # None, so fill those slots from the real runtime inputs.
            # NOTE(review): `inp or example_inputs[i]` falls through for any
            # falsy saved value (e.g. 0), not just None — confirm saved
            # non-tensor inputs can never be falsy, else use `is not None`.
            compile_inputs = [inp or example_inputs[i] for i, inp in enumerate(fn.example_inputs)]
            with tracing(TracingContext(fake_mode)):
                fn.optimized_call = magi_backend(state["graph_module"], compile_inputs).optimized_call
            return fn.optimized_call(*example_inputs)

        fn = cls(**state, optimized_call=optimized_call)
        return fn

    @property
    def co_name(self):
        """
        Used for depyf debugging.
        """
        return "MagiSerializableFunction"
|
pkgs/MagiCompiler/magi_compiler/config.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import os
|
| 17 |
+
from enum import Enum, unique
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Any, Literal
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from pydantic import BaseModel, Field
|
| 23 |
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 24 |
+
|
| 25 |
+
from .utils import OrderedSet, compute_hash, magi_logger
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@unique
class CompileMode(Enum):
    """Selects how the model is compiled via the torch.compile stack.

    NONE: run the model as-is in fully eager PyTorch mode; no compilation.
    TORCH_COMPILE: the stock `torch.compile` compilation pipeline.
    MAGI_COMPILE: the custom Inductor-based backend with caching, piecewise
        compilation, shape specialization, and custom passes.
    """

    NONE = 'NONE'
    TORCH_COMPILE = 'TORCH_COMPILE'
    MAGI_COMPILE = 'MAGI_COMPILE'
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@unique
class CudaGraphMode(Enum):
    """CUDA graph capture mode used by CompileConfig.

    Unlike the LLM-side CUDAGraphMode, diffusion models only need the
    PIECEWISE and FULL variants.

    NONE: CUDA graphs disabled.
    PIECEWISE: capture CUDA graphs per piecewise-compiled subgraph.
    FULL: capture a single CUDA graph over the full compilation.
    """

    NONE = 'NONE'
    PIECEWISE = 'PIECEWISE'
    FULL = 'FULL'
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class PassConfig(BaseModel):
    """Configuration for custom Inductor passes"""

    enable_fusion: bool = Field(False, description="Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass.")
    enable_attn_fusion: bool = Field(False, description="Whether to enable the custom attention+quant fusion pass.")
    enable_noop: bool = Field(False, description="Whether to enable the custom no-op elimination pass.")
    enable_sequence_parallelism: bool = Field(False, description="Whether to enable sequence parallelism.")
    enable_async_tp: bool = Field(False, description="Whether to enable async TP.")
    enable_fi_allreduce_fusion: bool = Field(False, description="Whether to enable flashinfer allreduce fusion.")
    enable_sage_attn: bool = Field(False, description="Whether to replace flash attention with sage attention.")
    fi_allreduce_fusion_max_token_num: int = Field(
        16384, description="Max number of tokens to used in flashinfer allreduce fusion."
    )

    # Fix: pydantic v2 BaseModel never invokes dataclass-style `__post_init__`,
    # so the consistency warnings below were dead code. The pydantic v2 hook
    # that runs right after model construction is `model_post_init`.
    def model_post_init(self, __context: Any) -> None:
        # The fusion passes rely on no-op/reshape elimination to match their
        # patterns; warn when they are enabled without it.
        if not self.enable_noop:
            if self.enable_fusion:
                magi_logger.warning(
                    "Fusion enabled but reshape elimination disabled. " "RMSNorm/SiluMul + quant (fp8) fusion might not work"
                )
            if self.enable_attn_fusion:
                magi_logger.warning(
                    "Fusion enabled but reshape elimination disabled. " "Attention + quant (fp8) fusion might not work"
                )

    @property
    def hash(self) -> str:
        """Stable content hash over the JSON dump of this config."""
        return compute_hash(self.model_dump(mode="json"))

    # Compatible with torch pass (pass managers key their caches on uuid()).
    def uuid(self) -> str:
        return self.hash
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@unique
class RecomputePolicy(Enum):
    """Strategy for activation recomputation (rematerialization), trading
    memory usage against recompute overhead.

    HANDCRAFT:
        User-driven: a `memory_budget` threshold (0.0 to 1.0) sets the target
        fraction of activations to save; everything else is recomputed.

    HEURISTIC:
        Rule-based: activations produced by compute-bound operators (e.g.
        MatMul, Attention) are saved, while outputs of memory-bound or
        element-wise operators are preferred for recomputation.

    AUTOSEARCH:
        Automated: searches for the best set of saved tensors given available
        device memory, preferring tensors whose compute cost is high relative
        to their memory footprint.

    .. note::
        A `repeat_number` argument is currently required to stabilize the
        profiling/search phase. This is temporary and will be deprecated once
        full-graph capture is natively supported.
    """

    HANDCRAFT = "HANDCRAFT"
    HEURISTIC = "HEURISTIC"
    AUTOSEARCH = "AUTOSEARCH"
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class RecomputeConfig(BaseModel):
    # Strategy for choosing which activations to recompute; see RecomputePolicy.
    recompute_policy: RecomputePolicy = Field(RecomputePolicy.HEURISTIC, description="Recompute policy.")
    # Target fraction of activations to save; only consulted for HANDCRAFT.
    memory_budget: float = Field(0.5, description="Activation memory budget for recomputation, only used for handcraft.")
    # Profiling repeats used to stabilize the AUTOSEARCH search phase.
    repeat_number: int = Field(default=1, description="Repeat number for recomputation, only used for autosearch.")
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@unique
class OffloadPolicy(Enum):
    """Policy controlling how model weights are offloaded to CPU.

    BASE:
        Offload every submodule to CPU.
    COST_EFFECTIVE:
        Offload submodules selected by the cost-effective policy.
    HEURISTIC:
        Offload submodules selected by the heuristic policy.
    """

    BASE = "BASE"
    COST_EFFECTIVE = "COST_EFFECTIVE"
    HEURISTIC = "HEURISTIC"
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class OffloadConfig(BaseModel):
    # Master switch for CPU offloading of model weights.
    model_cpu_offload: bool = Field(False, description="Whether to offload the model to CPU.")
    # Fraction of weights kept resident on GPU when offloading is enabled.
    gpu_resident_weight_ratio: float = Field(
        0.3, description="The ratio of GPU memory to keep when offloading the model to CPU."
    )
    # Submodule-selection strategy; see OffloadPolicy.
    offload_policy: OffloadPolicy = Field(
        OffloadPolicy.COST_EFFECTIVE, description="The policy for offloading the model to CPU."
    )
    # Multiplier (<1.0) applied to the measured host-to-device bandwidth to
    # leave headroom when scheduling prefetches.
    bandwidth_safety_factor: float = Field(0.9, description="The safety factor for the H2D bandwidth.")
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class CompileConfig(BaseSettings):
    """Top-level compilation settings for the Magi compiler.

    As a pydantic BaseSettings subclass, values can also come from CLI
    arguments (cli_parse_args=True); unknown CLI args are ignored.
    """

    model_config = SettingsConfigDict(cli_parse_args=True, cli_ignore_unknown_args=True, cli_implicit_flags=True)

    # Basic configs
    backend: Literal["inductor", "eager"] = Field("inductor", description="Compilation backend.")
    compile_mode: CompileMode = Field(CompileMode.MAGI_COMPILE, description="Compilation mode.")
    cache_root_dir: str = Field(
        default=os.path.expanduser("~/.cache/magi_compiler"), description="Directory to cache the compiled model."
    )
    dynamic_sources: str = Field(
        default=os.environ.get("TORCH_COMPILE_DYNAMIC_SOURCES", ""),
        description="Comma delimited list of sources that should be marked as dynamic.",
    )

    # CPU Offload
    offload_config: OffloadConfig = Field(OffloadConfig(), description="Offload configuration.")

    # Inductor configs
    # TODO(hongyu): Add unittest for compile_sizes
    compile_sizes: list[int] = Field(default_factory=list, description="Sizes to compile the model for.")
    use_inductor_graph_partition: bool = Field(
        False, description="Whether to use inductor graph partition. Not fully supported yet."
    )
    # TODO(hongyu): Find a better way to specify the splitting ops.
    splitting_ops: list[str] = Field(
        default_factory=lambda: [
            "athena::flash_attn_func",
            "athena::flex_flash_attn_func",
            "athena::sage_attn_func",
            "athena::flash_attn_with_cp",
            "athena::flex_flash_attn_with_cp",
        ],
        description="Operators to split the graph into piecewise graphs.",
    )

    # Pass configs
    pass_config: PassConfig = Field(PassConfig(), description="Pass configuration.")

    # Recompute configs
    recompute_config: RecomputeConfig = Field(RecomputeConfig(), description="Recompute configuration.")

    # Cudagraph configs
    cudagraph_mode: CudaGraphMode = Field(CudaGraphMode.NONE, description="Cudagraph mode.")
    cudagraph_copy_inputs: bool = Field(True, description="Whether to copy inputs for cudagraph.")

    # Runtime configs, maybe changed at runtime
    model_idx: int = Field(0, description="Index of the model.")
    model_tag: str | None = Field(
        default=None, description="Tag in cache path: model_{idx}_{model_tag}_rank_{rank}. Class name if unset."
    )
    inductor_compile_config: dict[str, Any] = Field(default_factory=dict, description="Inductor compilation configuration.")
    traced_files: OrderedSet[str] = Field(default_factory=OrderedSet, description="Files traced by Dynamo.")

    def _model_rank_dir_name(self) -> str:
        """Directory name for this model instance: model_{idx}[_{model_tag}]_rank_{rank}."""
        # Rank 0 is assumed when torch.distributed is not initialized.
        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
        if self.model_tag:
            return f"model_{self.model_idx}_{self.model_tag}_rank_{rank}"
        return f"model_{self.model_idx}_rank_{rank}"

    def debug_dump_path(self) -> Path:
        """Per-model/rank directory for magi_depyf debug dumps."""
        return Path(self.cache_root_dir) / "magi_depyf" / self._model_rank_dir_name()

    def cache_dump_path(self) -> Path:
        """Per-model/rank directory for torch.compile cache artifacts."""
        return Path(self.cache_root_dir) / "torch_compile_cache" / self._model_rank_dir_name()

    @property
    def hash(self) -> str:
        """Stable content hash of this config, used as a cache key.

        `inductor_compile_config` is excluded from the normal dump and
        serialized manually, because its values (e.g. pass managers) are not
        JSON-serializable.
        """
        # Create a copy of the config data for serialization
        data = self.model_dump(mode="json", exclude={"inductor_compile_config"})

        # Handle inductor_compile_config separately to serialize objects with uuid() method
        # This is a workaround to support serialization of PostGradPassManager in Pydantic models.
        if self.inductor_compile_config:
            serialized_inductor_config = {}
            for key, value in self.inductor_compile_config.items():
                # If the value has a uuid() method (like PostGradPassManager), use it
                if hasattr(value, "uuid") and callable(getattr(value, "uuid", None)):
                    try:
                        serialized_inductor_config[key] = value.uuid()
                    except (AttributeError, RuntimeError):
                        # Fallback to string representation if uuid() fails
                        serialized_inductor_config[key] = str(value)
                else:
                    # For other types, try to serialize normally
                    try:
                        # Try to serialize as JSON-serializable
                        json.dumps(value)
                        serialized_inductor_config[key] = value
                    except (TypeError, ValueError):
                        # If not JSON-serializable, use string representation
                        serialized_inductor_config[key] = str(value)
            data["inductor_compile_config"] = serialized_inductor_config

        return compute_hash(data)

    # NOTE(review): __str__ with an extra parameter is unconventional; Python
    # always calls it with no arguments, so `indent` only matters when called
    # explicitly (as __repr__ does below).
    def __str__(self, indent: int = 4):
        data = self.model_dump(mode="json")
        formatted = json.dumps(data, indent=indent, ensure_ascii=False, sort_keys=False)

        # add configuration class name as title
        class_name = self.__class__.__name__
        return f"{class_name}:\n{formatted}".replace('"', "")

    def __repr__(self, indent: int = 4):
        return self.__str__(indent=indent)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
# Process-wide singleton; lazily constructed by get_compile_config().
_GLOBAL_COMPILE_CONFIG = None


def get_compile_config() -> CompileConfig:
    """Return the process-wide CompileConfig, creating it on first use."""
    global _GLOBAL_COMPILE_CONFIG
    if _GLOBAL_COMPILE_CONFIG is None:
        _GLOBAL_COMPILE_CONFIG = CompileConfig()
        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
            # Log the compile config once on first initialization; at the
            # default WARNING level this produces no output.
            magi_logger.info("compile config: %s", _GLOBAL_COMPILE_CONFIG)
    assert _GLOBAL_COMPILE_CONFIG is not None, "compile config is not initialized"
    return _GLOBAL_COMPILE_CONFIG
|
pkgs/MagiCompiler/magi_compiler/cuda/cudart.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
import ctypes
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
# Memoized handle to the CUDA runtime library; populated by init_cudart().
_cudart = None


def init_cudart():
    """Locate and load the CUDA runtime library (libcudart) via ctypes.

    Returns:
        The cached ``ctypes.CDLL`` handle, or ``None`` if no candidate
        library could be loaded. A successful load is memoized in the
        module-level ``_cudart``.
    """
    global _cudart
    if _cudart is not None:
        return _cudart
    # Fix: the original list contained the single malformed entry
    # "libcudart.so.12.0, libcudart.so.13" (comma inside one string), so the
    # CUDA 12/13 sonames were never actually tried.
    candidates = [
        "libcudart.so",
        "libcudart.so.11.0",
        "libcudart.so.12.0",
        "libcudart.so.13",
    ]
    try:
        # Conventional toolkit layout is CUDA_HOME/lib64/libcudart.so.
        # Fix: the original applied os.path.dirname() to CUDA_HOME, which
        # pointed the lib64 join one directory too high.
        cuda_home = torch.utils.cpp_extension._find_cuda_home()
        if cuda_home:
            candidates.append(os.path.join(cuda_home, "lib64", "libcudart.so"))
    except Exception:
        # _find_cuda_home is a private torch helper; tolerate it being
        # missing or failing and fall back to the bare sonames.
        pass
    for lib in candidates:
        try:
            _cudart = ctypes.CDLL(lib)
            return _cudart
        except OSError:
            continue
    return None
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def pin_memory_in_place(tensor: torch.Tensor):
    """
    Pin memory in-place using cudaHostRegister.

    Best-effort: the tensor is returned unchanged if it already lives on the
    GPU or if the CUDA runtime library cannot be loaded. On success the
    tensor's existing host storage is page-locked without copying.

    Raises:
        RuntimeError: if cudaHostRegister returns a non-zero error code.
            NOTE(review): registering already-pinned memory returns an error
            and would raise here — confirm callers never double-pin.
    """
    if tensor.is_cuda:
        return tensor
    cudart = init_cudart()
    if cudart is None:
        return tensor

    # Register the tensor's whole backing buffer (flags=0: default behavior).
    ptr = tensor.data_ptr()
    size = tensor.numel() * tensor.element_size()
    res = cudart.cudaHostRegister(ctypes.c_void_p(ptr), ctypes.c_size_t(size), 0)

    # cudaSuccess == 0
    if res == 0:
        return tensor
    else:
        raise RuntimeError(f"cudaHostRegister failed with error code {res}")
|
pkgs/MagiCompiler/magi_compiler/cuda_graph_mgr.py
ADDED
|
@@ -0,0 +1,931 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass, fields, is_dataclass
|
| 16 |
+
from functools import wraps
|
| 17 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
from .utils import magi_logger, nvtx
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class InplaceSubstituteFakeClass:
    """Marker base class for in-place attribute substitution.

    The class which inherits from this class will not be replaced with a new instance,
    but the attributes will be updated in-place.
    For example, InferenceParams.
    """

    pass
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
class FakeTensor:
    """Lightweight stand-in for a real tensor inside an output template.

    Records only the metadata needed to rebuild a sliced view of a static
    buffer at replay time (see ArgsUtils.replace_static_with_sliced).
    """

    # Tensor shape; extract_output_template stores a list here in practice.
    shape: Optional[Tuple[int, ...]] = None
    # String form of torch dtype, e.g. "torch.float16".
    dtype: Optional[str] = None
    # String form of torch device, e.g. "cuda:0".
    device: Optional[str] = None
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@dataclass
class HashableDataclass:
    """Base dataclass providing a cached, field-based hash and hash-based equality.

    Only fields holding None-free hashable primitives (str/int/float/bool),
    nested HashableDataclass instances, or tuples thereof contribute to the
    hash; everything else (including None) is silently skipped.

    NOTE(review): subclasses here are declared with @dataclass(unsafe_hash=True),
    which generates fresh __hash__/__eq__ on the subclass and shadows the cached
    implementations below — confirm this is intended.
    """

    # Lazily computed 64-bit hash; None until first use.
    _cached_hash: Optional[int] = None

    @nvtx.instrument_nvtx
    def _get_hashable_fields(self) -> Tuple[Any, ...]:
        """Collect the hash-relevant field values, recursing into nested instances."""
        hashable_values = []
        for f in fields(self):
            # The cache field itself must never feed back into the hash.
            if f.name == "_cached_hash":
                continue
            value = getattr(self, f.name)
            if value is None:
                continue
            if isinstance(value, HashableDataclass):
                hashable_values.append(value._get_cached_hash())
            elif isinstance(value, tuple):
                tuple_vals = []
                for item in value:
                    if isinstance(item, (HashableDataclass, str, int, float, bool)):
                        if isinstance(item, HashableDataclass):
                            tuple_vals.append(item._get_cached_hash())
                        else:
                            tuple_vals.append(item)
                # Empty tuples contribute nothing to the hash.
                if tuple_vals:
                    hashable_values.append(tuple(tuple_vals))
            elif isinstance(value, (str, int, float, bool)):
                hashable_values.append(value)
        return tuple(hashable_values)

    @nvtx.instrument_nvtx
    def _compute_hash(self) -> int:
        """Computes a hash value based on the dataclass's hashable fields."""
        hashable_fields = self._get_hashable_fields()
        return hash(hashable_fields) % (1 << 64)  # clamp to 64 bits

    @nvtx.instrument_nvtx
    def _get_cached_hash(self) -> int:
        """Return the cached hash, computing it on first access."""
        if self._cached_hash is None:
            self._cached_hash = self._compute_hash()
        return self._cached_hash

    @nvtx.instrument_nvtx
    def __hash__(self) -> int:
        return self._get_cached_hash()

    @nvtx.instrument_nvtx
    def __eq__(self, other: Any) -> bool:
        """Equality is decided purely by hash value for speed.

        NOTE(review): a hash collision makes two distinct objects compare
        equal — acceptable only if collisions are deemed negligible here.
        """
        if not isinstance(other, self.__class__):
            return False
        if self._get_cached_hash() != other._get_cached_hash():
            return False
        return True
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@dataclass(unsafe_hash=True)
class LiteralsInfo(HashableDataclass):
    """Hashable wrapper around the literal (non-tensor) arguments of a call."""

    # Literal values (int/float/str/bool) in argument order.
    literals: Tuple[Any, ...] = tuple()
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
@dataclass(unsafe_hash=True)
class TensorStaticInfo(HashableDataclass):
    """Static-view metadata of one tensor argument.

    By convention (see ArgsUtils.generate_both_signatures_from_tensors), the
    last dimension is concrete and every other dimension is -1.
    """

    name: str = ""
    shapes: Tuple[int, ...] = tuple()
    dtype: str = ""
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@dataclass(unsafe_hash=True)
class TensorDynamicInfo(HashableDataclass):
    """Dynamic-view metadata of one tensor argument.

    Complement of TensorStaticInfo: the last dimension is -1 and the other
    dimensions are concrete (1-D tensors keep their full shape).
    """

    name: str = ""
    shapes: Tuple[int, ...] = tuple()
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@dataclass(unsafe_hash=True)
class StaticSignature(HashableDataclass):
    """Cache key for static tensor buffers: function name + per-tensor static info."""

    func_name: str = ""
    tensor_static_infos: Tuple[TensorStaticInfo, ...] = tuple()
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@dataclass(unsafe_hash=True)
class DynamicSignature(HashableDataclass):
    """Cache key within one StaticSignature: dynamic shapes + literal arguments."""

    tensor_dynamic_infos: Tuple[TensorDynamicInfo, ...] = tuple()
    literals_info: Optional[LiteralsInfo] = None
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
@dataclass
class GraphEntry:
    """One captured CUDA graph plus its validity flags."""

    graph: Optional[torch.cuda.CUDAGraph] = None
    # True when replay output was observed to diverge from eager output.
    inconsistent: bool = False
    # True when the graph was invalidated (e.g. static buffers were expanded).
    invalid: bool = False
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@dataclass
class OutputTemplateEntry:
    """Per-DynamicSignature record: graphs keyed by layer + output structure template."""

    graph_entry_dict: Optional[Dict[int, GraphEntry]] = None  # key = layer_number
    output_template: Any = None  # structure template (with literals) of the output object
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
@dataclass
class StaticTensorEntry:
    """Per-StaticSignature record: static I/O buffers + per-dynamic-signature templates."""

    # Static input buffers shared by all graphs under this signature; may be
    # None for placeholder entries created by set_graph_inconsistent.
    input_tensors: Optional[List[torch.Tensor]] = None
    output_tensors: Optional[List[torch.Tensor]] = None
    template_entry_dict: Optional[Dict[DynamicSignature, OutputTemplateEntry]] = None
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class ArgsUtils:
    """Stateless helpers for flattening/rebuilding argument structures around
    CUDA-graph capture and replay.

    All methods are static. Tensors are always matched to static buffers by
    traversal order, so the same traversal must be used on both sides.
    """

    @staticmethod
    @nvtx.instrument_nvtx
    def generate_both_signatures_from_tensors(
        func_name: str, tensors: List[torch.Tensor], names: List[str], literals: List[Any]
    ) -> Tuple[StaticSignature, DynamicSignature]:
        """Build the static and dynamic signatures for one call site.

        Convention: for ndim > 1 tensors the last dimension is static (kept in
        StaticSignature) and the others are dynamic (kept in DynamicSignature);
        a 1-D tensor is (-1,) in the static view and fully concrete in the
        dynamic view. `names` is currently unused but kept for interface parity.
        """
        num_tensors = len(tensors)
        tensor_static_infos = [TensorStaticInfo() for _ in range(num_tensors)]
        tensor_dynamic_infos = [TensorDynamicInfo() for _ in range(num_tensors)]

        # Local references for performance in the hot loop.
        TensorStaticInfo_setattr = TensorStaticInfo.__setattr__
        TensorDynamicInfo_setattr = TensorDynamicInfo.__setattr__
        _tuple = tuple

        for i in range(num_tensors):
            t = tensors[i]
            t_dim = t.dim()
            t_shape = t.shape
            t_dtype_str = str(t.dtype)
            # Last dimension is static, others are dynamic (except for 1D tensor)
            static_shapes = (
                _tuple(-1 if idx != t_dim - 1 else dim_size for idx, dim_size in enumerate(t_shape)) if t_dim > 1 else (-1,)
            )
            static_info = tensor_static_infos[i]
            TensorStaticInfo_setattr(static_info, "shapes", static_shapes)
            TensorStaticInfo_setattr(static_info, "dtype", t_dtype_str)

            # BUGFIX: the original line also rebound `static_shapes` here
            # (`dynamic_shapes = static_shapes = ...`), clobbering the static
            # view after it had been consumed; only dynamic_shapes is needed.
            dynamic_shapes = (
                _tuple(-1 if idx == t_dim - 1 else dim_size for idx, dim_size in enumerate(t_shape))
                if t_dim > 1
                else _tuple(t_shape)
            )
            dynamic_info = tensor_dynamic_infos[i]
            TensorDynamicInfo_setattr(dynamic_info, "shapes", dynamic_shapes)

        literals_info = LiteralsInfo(literals=_tuple(literals))
        static_sig = StaticSignature(func_name=func_name, tensor_static_infos=_tuple(tensor_static_infos))
        dynamic_sig = DynamicSignature(tensor_dynamic_infos=_tuple(tensor_dynamic_infos), literals_info=literals_info)
        return static_sig, dynamic_sig

    @staticmethod
    @nvtx.instrument_nvtx
    def replace_sliced_with_static(obj: Any, static_tensors: List[torch.Tensor]) -> Any:
        """Copy every plain tensor in `obj` into its static buffer and return the
        same structure rebuilt around slices of the static buffers.

        Parameters are left untouched. Tensors are matched to `static_tensors`
        in traversal order. InplaceSubstituteFakeClass instances are mutated
        in place rather than rebuilt.
        """
        tensor_idx = 0

        def recursive_replace(o: Any) -> Any:
            nonlocal tensor_idx
            if isinstance(o, torch.Tensor) and not isinstance(o, torch.nn.Parameter):
                # Copy data to the corresponding static tensor slice
                static_tensor = static_tensors[tensor_idx]
                slices = [slice(None)] * static_tensor.ndim
                for i in range(min(o.ndim, static_tensor.ndim)):
                    slices[i] = slice(0, o.shape[i])
                # Only copy if the data_ptrs are different
                if not o.data_ptr() == static_tensor[tuple(slices)].data_ptr():
                    static_tensor[tuple(slices)].copy_(o)
                tensor_idx += 1
                return static_tensor[tuple(slices)]
            elif isinstance(o, dict):
                return {k: recursive_replace(v) for k, v in o.items()}
            elif isinstance(o, (list, tuple)):
                return type(o)(recursive_replace(item) for item in o)
            elif is_dataclass(o):
                field_values = {f.name: recursive_replace(getattr(o, f.name)) for f in fields(o)}
                return type(o)(**field_values)
            elif issubclass(o.__class__, InplaceSubstituteFakeClass):
                # Do not create a new instance, but modify attributes in place (to keep original initialization logic)
                for k, v in o.__dict__.items():
                    if not callable(v):
                        o.__dict__[k] = recursive_replace(v)
                return o
            else:
                # None, basic literals, and any unrecognized type pass through unchanged.
                return o

        return recursive_replace(obj)

    @staticmethod
    @nvtx.instrument_nvtx
    def replace_sliced_with_static_simple(
        sliced_tensors: List[torch.Tensor], static_tensors: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        """Flat-list fast path of replace_sliced_with_static: copy each sliced
        tensor into the leading region of its static buffer, skipping aliases."""
        for sliced_tensor, static_tensor in zip(sliced_tensors, static_tensors):
            if not sliced_tensor.data_ptr() == static_tensor.data_ptr():
                slices = [slice(None)] * static_tensor.ndim
                for i in range(sliced_tensor.ndim):
                    slices[i] = slice(0, sliced_tensor.shape[i])
                static_tensor[tuple(slices)].copy_(sliced_tensor)

    @staticmethod
    @nvtx.instrument_nvtx
    def replace_static_with_sliced(obj: Any, static_tensors: List[torch.Tensor]) -> Any:
        """Inverse of replace_sliced_with_static: rebuild `obj` (usually an
        output template containing FakeTensor placeholders) as views into the
        static buffers, sliced to each placeholder's recorded shape."""
        tensor_idx = 0

        def recursive_replace(o: Any) -> Any:
            nonlocal tensor_idx
            if (isinstance(o, torch.Tensor) and not isinstance(o, torch.nn.Parameter)) or isinstance(o, FakeTensor):
                # Replace with the corresponding sliced tensor
                static_tensor = static_tensors[tensor_idx]
                shape_to_slice = o.shape
                slices = [slice(0, dim_size) for dim_size in shape_to_slice]
                result_tensor = static_tensor[tuple(slices)]
                tensor_idx += 1
                return result_tensor
            elif isinstance(o, dict):
                return {k: recursive_replace(v) for k, v in o.items()}
            elif isinstance(o, (list, tuple)):
                return type(o)(recursive_replace(item) for item in o)
            elif is_dataclass(o):
                field_values = {f.name: recursive_replace(getattr(o, f.name)) for f in fields(o)}
                return type(o)(**field_values)
            elif issubclass(o.__class__, InplaceSubstituteFakeClass):
                # Do not create a new instance, but modify attributes in place (to keep original initialization logic)
                for k, v in o.__dict__.items():
                    if not callable(v):
                        o.__dict__[k] = recursive_replace(v)
                return o
            else:
                # None, basic literals, and any unrecognized type pass through unchanged.
                return o

        return recursive_replace(obj)

    @staticmethod
    @nvtx.instrument_nvtx
    def try_fx_extract_core(
        obj: Any, extract_tensors: bool = True, extract_literals: bool = True, with_names: bool = False
    ) -> Tuple[List[torch.Tensor], List[str], List[Any]]:
        """Fast-path extraction for the fx calling convention
        {"args": (...), "kwargs": {}}.

        Returns (tensors, names, literals), or (None, None, None) when the
        object does not match the expected flat shape (has kwargs, or args is
        not a sequence). `with_names` is accepted for interface parity; names
        are always empty strings on this path.
        """
        failed_tuple = None, None, None
        tensors = []
        names = []
        literals = []

        if not isinstance(obj, dict) or "args" not in obj or "kwargs" not in obj:
            return failed_tuple
        args, kwargs = obj["args"], obj["kwargs"]
        if kwargs:
            return failed_tuple
        if not isinstance(args, (list, tuple)):
            return failed_tuple

        # idx was unused in the original enumerate loop; iterate directly.
        for item in args:
            if extract_tensors and isinstance(item, torch.Tensor) and not isinstance(item, torch.nn.Parameter):
                tensors.append(item)
            elif extract_literals and isinstance(item, (int, float, str, bool)):
                literals.append(item)

        names = [""] * len(tensors)
        return tensors, names, literals

    @staticmethod
    @nvtx.instrument_nvtx
    def recursive_extract_core(
        obj: Any, extract_tensors: bool = True, extract_literals: bool = True, with_names: bool = False
    ) -> Tuple[List[torch.Tensor], List[str], List[Any]]:
        """General extraction: walk an arbitrary nested structure collecting
        plain tensors (with dotted/indexed path names when with_names=True)
        and literal values.

        Returns (tensors, names, literals); names falls back to empty strings
        when with_names is False, literals is None when extract_literals is False.
        """
        tensors = []
        names = []
        literals = []

        def recursive_traverse(o: Any, prefix: str = ""):
            # 1. Extract tensors (if enabled)
            if extract_tensors and isinstance(o, torch.Tensor) and not isinstance(o, torch.nn.Parameter):
                tensors.append(o)
                if with_names:
                    names.append(prefix)
            elif extract_literals and isinstance(o, (int, float, str, bool)):
                # The guard above already ensures extract_literals is True.
                literals.append(o)
            elif isinstance(o, dict):
                for k, v in o.items():
                    new_prefix = f"{prefix}.{k}" if (with_names and extract_tensors) else prefix
                    recursive_traverse(v, new_prefix)
            elif isinstance(o, (list, tuple)):
                for idx, item in enumerate(o):
                    new_prefix = f"{prefix}[{idx}]" if (with_names and extract_tensors) else prefix
                    recursive_traverse(item, new_prefix)
            elif is_dataclass(o):
                for f in fields(o):
                    new_prefix = f"{prefix}.{f.name}" if (with_names and extract_tensors) else prefix
                    recursive_traverse(getattr(o, f.name), new_prefix)
            elif issubclass(o.__class__, InplaceSubstituteFakeClass):
                for k, v in o.__dict__.items():
                    if not callable(v):
                        new_prefix = f"{prefix}.{k}" if (with_names and extract_tensors) else prefix
                        recursive_traverse(v, new_prefix)
            else:
                # None and unrecognized types are ignored.
                pass

        recursive_traverse(obj)
        return tensors, names if with_names else [""] * len(tensors), literals if extract_literals else None

    @staticmethod
    @nvtx.instrument_nvtx
    def extract_output_template(obj: Any) -> Any:
        """Build a structure template of `obj` in which every plain tensor is
        replaced by a FakeTensor recording shape/dtype/device; literals and
        structure are preserved."""

        def recursive_template(o: Any) -> Any:
            if isinstance(o, torch.Tensor) and not isinstance(o, torch.nn.Parameter):
                return FakeTensor(shape=list(o.shape), dtype=str(o.dtype), device=str(o.device))
            elif isinstance(o, dict):
                return {k: recursive_template(v) for k, v in o.items()}
            elif isinstance(o, (list, tuple)):
                return type(o)(recursive_template(item) for item in o)
            elif is_dataclass(o):
                field_values = {f.name: recursive_template(getattr(o, f.name)) for f in fields(o)}
                return type(o)(**field_values)
            elif issubclass(o.__class__, InplaceSubstituteFakeClass):
                # Do not re-create the instance; mutate attributes in place
                # (keeps the original initialization logic).
                for k, v in o.__dict__.items():
                    if not callable(v):
                        o.__dict__[k] = recursive_template(v)
                return o
            else:
                # None, basic literals, and any unrecognized type pass through unchanged.
                return o

        return recursive_template(obj)
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
class CudaGraphMgr:
|
| 369 |
+
"""CUDA Graph Manager for caching and managing CUDA Graphs and static tensors."""
|
| 370 |
+
|
| 371 |
+
def __init__(self):
    # Maps a static signature to its static tensor buffers and
    # per-dynamic-signature output templates / graphs.
    self.cache: Dict[StaticSignature, StaticTensorEntry] = dict()
    # Shared memory pool handle used by all captured graphs; created lazily.
    self.graph_mem_pool: Optional[torch.cuda.graph_pool_handle] = None
    self.check_output_inconsistency = False  # Not enabled by default
|
| 375 |
+
|
| 376 |
+
@property
def graph_count(self) -> int:
    """Number of captured graphs that are currently usable
    (present, consistent, and not invalidated)."""
    return sum(
        1
        for tensor_entry in self.cache.values()
        if tensor_entry.template_entry_dict is not None
        for template_entry in tensor_entry.template_entry_dict.values()
        for graph_entry in template_entry.graph_entry_dict.values()
        if graph_entry.graph is not None and not graph_entry.inconsistent and not graph_entry.invalid
    )
|
| 386 |
+
|
| 387 |
+
@property
def tensor_entry_count(self) -> int:
    """Number of cache entries that hold both input and output static buffers."""
    return sum(
        1
        for tensor_entry in self.cache.values()
        if tensor_entry.input_tensors is not None and tensor_entry.output_tensors is not None
    )
|
| 394 |
+
|
| 395 |
+
@property
def graph_mem_pool_size(self) -> float:
    """Currently allocated size of the shared CUDA-graph memory pool, in MB.

    Returns 0.0 when no pool has been created yet.
    """
    if not hasattr(self, "graph_mem_pool") or self.graph_mem_pool is None:
        return 0.0
    # NOTE(review): torch.cuda.memory.memory_stats is documented to take a
    # device argument; passing a graph pool handle here should be verified
    # against the torch version in use.
    pool_stats = torch.cuda.memory.memory_stats(self.graph_mem_pool)
    used_mem = pool_stats.get("allocated_bytes.all.current", 0)
    return used_mem / (1024 * 1024)  # convert to MB
|
| 402 |
+
|
| 403 |
+
@property
def tensor_mem_size(self) -> float:
    """Total memory held by all cached static input/output tensors, in MB."""
    total_bytes = 0  # accumulated in bytes
    for tensor_entry in self.cache.values():
        for tensor_group in (tensor_entry.input_tensors, tensor_entry.output_tensors):
            if tensor_group is None:
                continue
            for t in tensor_group:
                total_bytes += t.element_size() * t.nelement()
    return total_bytes / (1024 * 1024)  # convert to MB
|
| 414 |
+
|
| 415 |
+
@nvtx.instrument_nvtx
def formatted_cache_str(self) -> str:
    """Format the cache content as a string for debugging.

    BUGFIX: entries created by set_graph_inconsistent carry None
    input/output tensor lists; the original iterated them unguarded and
    raised TypeError. Both lists are now skipped when absent.
    """
    lines = []
    for static_sig, tensor_entry in self.cache.items():
        lines.append(f"StaticSignature: {static_sig}")
        if tensor_entry.input_tensors is not None:
            s = " Input Static Tensors: "
            for it in tensor_entry.input_tensors:
                s += f"[shape={list(it.shape)},dtype={str(it.dtype)}] "
            lines.append(s)
        if tensor_entry.output_tensors is not None:
            s = " Output Static Tensors: "
            for ot in tensor_entry.output_tensors:
                s += f"[shape={list(ot.shape)},dtype={str(ot.dtype)}] "
            lines.append(s)
        if tensor_entry.template_entry_dict is not None:
            for dynamic_sig, template_entry in tensor_entry.template_entry_dict.items():
                lines.append(f" DynamicSignature: {dynamic_sig}")
                lines.append(f" Output Template: {template_entry.output_template}")
                for layer_number, graph_entry in template_entry.graph_entry_dict.items():
                    status = "Valid"
                    if graph_entry.inconsistent:
                        status = "Inconsistent"
                    elif graph_entry.invalid:
                        status = "Invalid"
                    lines.append(f" Layer {layer_number}: Graph Status: {status}")
    return "\n".join(lines)
|
| 441 |
+
|
| 442 |
+
@nvtx.instrument_nvtx
def try_get_cuda_graph(
    self, static_sig: StaticSignature, dynamic_sig: DynamicSignature, layer_number: int
) -> Optional[torch.cuda.CUDAGraph]:
    """Return the cached graph for this (static, dynamic, layer) triple,
    or None when it is missing, inconsistent, or invalidated."""
    entry = self.try_get_graph_entry(static_sig, dynamic_sig, layer_number)
    if entry is None or entry.graph is None:
        return None
    if entry.inconsistent or entry.invalid:
        return None
    return entry.graph
|
| 455 |
+
|
| 456 |
+
@nvtx.instrument_nvtx
def get_static_tensors(self, input_static_sig: StaticSignature) -> Optional[Tuple[List[torch.Tensor], List[torch.Tensor]]]:
    """Return (input_tensors, output_tensors) for a cached static signature.

    Raises ValueError when the signature has no cache entry.
    """
    if input_static_sig not in self.cache:
        raise ValueError("Cached input/output tensors not found for the given static signature.")
    cached_entry = self.cache[input_static_sig]
    return cached_entry.input_tensors, cached_entry.output_tensors
|
| 462 |
+
|
| 463 |
+
@nvtx.instrument_nvtx
def warmup_run(self, func: Callable, *args, **kwargs) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Run func once on a side stream (under no_grad) to warm up kernels
    before capture, then synchronize the current stream with it."""
    warmup_stream = torch.cuda.Stream()
    warmup_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(warmup_stream), torch.no_grad():
        # The original looped `for _ in range(1)`; a single call is equivalent.
        warmup_outputs = func(*args, **kwargs)
    torch.cuda.current_stream().wait_stream(warmup_stream)
    return warmup_outputs
|
| 473 |
+
|
| 474 |
+
@nvtx.instrument_nvtx
def add_static_entry(
    self,
    static_sig: StaticSignature,
    input_tensors: Optional[List[torch.Tensor]] = None,
    output_tensors: Optional[List[torch.Tensor]] = None,
) -> None:
    """Create a fresh cache entry for a static signature.

    The signature must not already be present (asserted).
    """
    assert static_sig not in self.cache
    new_entry = StaticTensorEntry(
        input_tensors=input_tensors,
        output_tensors=output_tensors,
        template_entry_dict=dict(),
    )
    self.cache[static_sig] = new_entry
|
| 485 |
+
|
| 486 |
+
@nvtx.instrument_nvtx
def add_template_entry(
    self, input_static_sig: StaticSignature, input_dynamic_sig: DynamicSignature, output_obj: Any = None
) -> None:
    """Record the output structure template for a (static, dynamic) pair.

    Raises ValueError when the static signature has no cache entry yet.
    """
    output_template = ArgsUtils.extract_output_template(output_obj)
    # Narrow try body: only the cache lookup can legitimately raise KeyError,
    # and chaining (`from err`) preserves the original cause for debugging.
    try:
        static_entry = self.cache[input_static_sig]
    except KeyError as err:
        raise ValueError("StaticSignature not found in cache when adding template entry.") from err
    static_entry.template_entry_dict[input_dynamic_sig] = OutputTemplateEntry(
        graph_entry_dict=dict(), output_template=output_template
    )
|
| 497 |
+
|
| 498 |
+
@nvtx.instrument_nvtx
def add_graph_entry(
    self,
    input_static_sig: StaticSignature,
    input_dynamic_sig: DynamicSignature,
    layer_number: int,
    graph: torch.cuda.CUDAGraph,
) -> None:
    """Store a freshly captured graph under (static, dynamic, layer).

    Raises ValueError when either signature has no cache entry yet.
    """
    # Narrow try body + exception chaining (`from err`) so the failing lookup
    # is preserved as the cause of the ValueError.
    try:
        template_entry = self.cache[input_static_sig].template_entry_dict[input_dynamic_sig]
    except KeyError as err:
        raise ValueError("StaticSignature or DynamicSignature not found in cache when adding graph entry.") from err
    template_entry.graph_entry_dict[layer_number] = GraphEntry(
        graph=graph, inconsistent=False, invalid=False
    )
|
| 512 |
+
|
| 513 |
+
@nvtx.instrument_nvtx
def try_get_graph_entry(
    self, input_static_sig: StaticSignature, input_dynamic_sig: DynamicSignature, layer_number: int
) -> Optional[GraphEntry]:
    """Look up the GraphEntry for (static, dynamic, layer); None when any level is missing."""
    try:
        return self.cache[input_static_sig].template_entry_dict[input_dynamic_sig].graph_entry_dict[layer_number]
    except KeyError:
        return None
|
| 522 |
+
|
| 523 |
+
@nvtx.instrument_nvtx
def batch_set_graph_invalid(self, static_sig: StaticSignature) -> None:
    """Mark every graph under a static signature as invalid (e.g. after its
    static buffers were replaced). No-op when the signature is absent."""
    entry = self.cache.get(static_sig)
    if entry is None or entry.template_entry_dict is None:
        return
    for template_entry in entry.template_entry_dict.values():
        for graph_entry in template_entry.graph_entry_dict.values():
            graph_entry.invalid = True
|
| 531 |
+
|
| 532 |
+
@nvtx.instrument_nvtx
def set_graph_inconsistent(
    self, input_static_sig: StaticSignature, input_dynamic_sig: DynamicSignature, layer_number: int
) -> None:
    """Flag the graph at (static, dynamic, layer) as inconsistent, creating
    placeholder entries (with None tensors / template) along the way."""
    if input_static_sig not in self.cache:
        self.add_static_entry(input_static_sig, None, None)
    if input_dynamic_sig not in self.cache[input_static_sig].template_entry_dict:
        self.add_template_entry(input_static_sig, input_dynamic_sig, None)
    # Hoist the deep lookup once; the original repeated it on every line.
    graph_dict = self.cache[input_static_sig].template_entry_dict[input_dynamic_sig].graph_entry_dict
    if layer_number not in graph_dict:
        graph_dict[layer_number] = GraphEntry(graph=None, inconsistent=True, invalid=False)
    graph_dict[layer_number].inconsistent = True
|
| 545 |
+
|
| 546 |
+
@nvtx.instrument_nvtx
def wrapped_graph_capture(
    self,
    func: Callable,
    input_obj: Any,
    static_input_tensors: List[torch.Tensor],
    static_output_tensors: List[torch.Tensor],
) -> torch.cuda.CUDAGraph:
    """Capture one invocation of `func` into a CUDA graph.

    Inputs are first copied into the static input buffers; the call runs under
    torch.cuda.graph with the shared memory pool, and its outputs are copied
    into the static output buffers as part of the captured work.

    Raises whatever `func` or the capture machinery raises, after ending the
    capture bracket.
    """
    init_cudagraph_global_pool()
    _set_capture_start()
    try:
        graph = torch.cuda.CUDAGraph()
        _static_input_obj = ArgsUtils.replace_sliced_with_static(input_obj, static_input_tensors)
        s = None  # future: s = GreenCtxManager(0).create_stream()
        with torch.cuda.graph(graph, pool=self.graph_mem_pool, stream=s), torch.no_grad():
            _sliced_output_obj = func(*_static_input_obj["args"], **_static_input_obj["kwargs"])
            # Called for its side effect: copies the sliced outputs into the
            # static output buffers inside the captured region.
            ArgsUtils.replace_sliced_with_static(_sliced_output_obj, static_output_tensors)
    except Exception:
        torch.cuda.synchronize()  # wait for all async operations to finish
        _set_capture_end()
        # BUGFIX: bare `raise` preserves the original traceback,
        # unlike the previous `raise e`.
        raise
    _set_capture_end()
    return graph
|
| 569 |
+
|
| 570 |
+
@nvtx.instrument_nvtx
def wrapped_graph_replay(
    self,
    graph: torch.cuda.CUDAGraph,
    static_input_tensors: List[torch.Tensor],
    static_output_tensors: List[torch.Tensor],
    input_obj: Any,
    output_template: Any,
) -> Any:
    """Copy inputs into the static buffers, replay the graph, and rebuild the
    output object as sliced views of the static output buffers."""
    # Called for its side effect: copies the caller's tensors into the
    # capture-time static input buffers.
    ArgsUtils.replace_sliced_with_static(input_obj, static_input_tensors)
    graph.replay()
    return ArgsUtils.replace_static_with_sliced(output_template, static_output_tensors)
|
| 583 |
+
|
| 584 |
+
@nvtx.instrument_nvtx
def replay_graph(
    self, input_static_sig: StaticSignature, input_dynamic_sig: DynamicSignature, input_obj: Any, layer_number: int
) -> Any:
    """Replay the cached graph for (static, dynamic, layer) on `input_obj`.

    Asserts that a usable graph exists; returns the rebuilt output object.
    """
    static_entry = self.cache[input_static_sig]
    template_entry = static_entry.template_entry_dict[input_dynamic_sig]
    graph = self.try_get_cuda_graph(input_static_sig, input_dynamic_sig, layer_number=layer_number)
    assert graph is not None, "CUDA Graph not found for replay."
    return self.wrapped_graph_replay(
        graph=graph,
        static_input_tensors=static_entry.input_tensors,
        static_output_tensors=static_entry.output_tensors,
        input_obj=input_obj,
        output_template=template_entry.output_template,
    )
|
| 601 |
+
|
| 602 |
+
@nvtx.instrument_nvtx
def capture_and_cache(
    self,
    func: Callable,
    input_obj: Any,
    layer_number: int,
    input_static_sig: StaticSignature,
    input_dynamic_sig: DynamicSignature,
) -> None:
    """Capture a new CUDA Graph for (static, dynamic, layer) and cache it.

    Requires that add_static_entry already populated the static input/output
    buffers for `input_static_sig` (asserted below). Returns nothing; the
    graph is stored in the cache.
    """
    # Access static tensors from cache
    static_tensor_entry = self.cache[input_static_sig]
    assert static_tensor_entry.input_tensors is not None
    assert static_tensor_entry.output_tensors is not None
    static_input_tensors = static_tensor_entry.input_tensors
    static_output_tensors = static_tensor_entry.output_tensors

    # Capture CUDA Graph
    graph = self.wrapped_graph_capture(
        func=func,
        input_obj=input_obj,
        static_input_tensors=static_input_tensors,
        static_output_tensors=static_output_tensors,
    )

    # Cache the captured graph: reuse an existing entry (clearing its
    # inconsistent/invalid flags) or create a new one.
    graph_entry = self.try_get_graph_entry(input_static_sig, input_dynamic_sig, layer_number)
    if graph_entry:
        graph_entry.graph = graph
        graph_entry.inconsistent = False
        graph_entry.invalid = False
    else:
        self.add_graph_entry(
            input_static_sig=input_static_sig, input_dynamic_sig=input_dynamic_sig, layer_number=layer_number, graph=graph
        )
|
| 637 |
+
|
| 638 |
+
@nvtx.instrument_nvtx
def if_need_expand_static_tensors(
    self, static_tensors: List[torch.Tensor], new_tensors: List[torch.Tensor], input_static_sig: StaticSignature
) -> bool:
    """Decide whether cached static tensors are too small for the new inputs.

    Raises:
        AssertionError: on any structural mismatch (tensor count, rank,
            dtype, or a dimension the signature marks as static). Callers
            treat this as "fall back to eager execution".
    """
    res = False
    static_infos = input_static_sig.tensor_static_infos

    # All three sequences must line up element-wise.
    if len(static_tensors) != len(new_tensors) or len(static_tensors) != len(static_infos):
        raise AssertionError(
            f"[CUDA Graph] Tensor count mismatch. {len(static_tensors)=}, {len(new_tensors)=}, {len(static_infos)=}"
        )

    for static_t, new_t, static_info in zip(static_tensors, new_tensors, static_infos):
        # Rank and dtype are never allowed to change for a static signature.
        if static_t.ndim != new_t.ndim:
            raise AssertionError(f"[CUDA Graph] Rank mismatch. {static_t.shape=}, {new_t.shape=}")
        if static_t.dtype != new_t.dtype:
            raise AssertionError(f"[CUDA Graph] Dtype mismatch. {static_t.dtype=}, {new_t.dtype=}")
        for i in range(static_t.ndim):
            fixed = static_info.shapes[i]
            # -1 marks a dynamic dimension; any other value must match exactly.
            if fixed != -1 and fixed != new_t.shape[i]:
                raise AssertionError(
                    f"[CUDA Graph] Static dimension mismatch. {static_t.shape=}, {new_t.shape=}, {static_info.shapes=}, dim={i}"
                )
            # Expansion is needed as soon as any dynamic dim outgrows its buffer;
            # keep scanning so later mismatches still raise.
            res = res or static_t.shape[i] < new_t.shape[i]
    return res
|
| 664 |
+
|
| 665 |
+
@nvtx.instrument_nvtx
def get_expanded_static_tensors(
    self, static_tensors: List[torch.Tensor], new_tensors: List[torch.Tensor]
) -> List[torch.Tensor]:
    """Return static tensors grown (per-dimension max) to cover ``new_tensors``.

    Existing tensors are reused whenever one of them already matches the
    union shape; a fresh uninitialized buffer is allocated only when neither
    tensor covers it.
    """
    expanded: List[torch.Tensor] = []
    for cached_t, incoming_t in zip(static_tensors, new_tensors):
        if cached_t.ndim != incoming_t.ndim:
            raise AssertionError(
                f"[CUDA Graph] Rank mismatch during expansion. Static: {cached_t.shape}, New: {incoming_t.shape}"
            )
        # Element-wise maximum of the two shapes.
        target_shape = tuple(max(a, b) for a, b in zip(cached_t.shape, incoming_t.shape))

        if cached_t.shape == target_shape:
            # The cached buffer is already big enough.
            expanded.append(cached_t)
        elif incoming_t.shape == target_shape:
            # The incoming tensor covers the union shape; adopt it.
            expanded.append(incoming_t)
        else:
            # Neither covers the union: allocate a fresh (uninitialized)
            # buffer on the cached tensor's device/dtype.
            expanded.append(torch.empty(target_shape, dtype=cached_t.dtype, device=cached_t.device))
    return expanded
|
| 686 |
+
|
| 687 |
+
@nvtx.instrument_nvtx
def try_replay_graph_inline(
    self, func: Callable, args: Tuple, kwargs: Dict, layer_number: int
) -> Tuple[bool, Optional[Union[torch.Tensor, List[torch.Tensor]]]]:
    """Try to replay the CUDA Graph inline for fast execution.

    Returns:
        ``(True, output)`` on a successful replay; ``(False, None)`` when a
        required cache entry is missing or unusable. KeyError/AssertionError
        are the expected "cache miss" signals and are swallowed; any other
        exception is logged and re-raised.
    """
    try:
        func_name = func.__qualname__
        input_obj = {"args": args, "kwargs": kwargs}

        # Prefer the fx-based fast extractor; fall back to the generic
        # recursive one when it cannot handle this input structure.
        input_tensors, input_tensor_names, literals = ArgsUtils.try_fx_extract_core(input_obj)
        if None in (input_tensors, input_tensor_names, literals):
            input_tensors, input_tensor_names, literals = ArgsUtils.recursive_extract_core(input_obj)
        input_static_sig, input_dynamic_sig = ArgsUtils.generate_both_signatures_from_tensors(
            func_name, input_tensors, input_tensor_names, literals
        )
        # Any KeyError from these lookups means "no cached graph" and is
        # converted to (False, None) by the handler below.
        static_tensor_entry = self.cache[input_static_sig]
        static_input_tensors = static_tensor_entry.input_tensors
        static_output_tensors = static_tensor_entry.output_tensors

        template_entry = static_tensor_entry.template_entry_dict[input_dynamic_sig]
        output_template = template_entry.output_template

        graph_entry = template_entry.graph_entry_dict[layer_number]
        graph = graph_entry.graph

        # Unusable graphs fall back via the AssertionError handler below.
        assert graph is not None, "CUDA Graph not found for inline replay."
        assert graph_entry.inconsistent is False, "CUDA Graph marked as inconsistent for inline replay."
        assert graph_entry.invalid is False, "CUDA Graph marked as invalid for inline replay."

        # Copy live inputs into the captured static buffers, replay, then
        # slice the static outputs back into the cached template structure.
        ArgsUtils.replace_sliced_with_static_simple(input_tensors, static_input_tensors)
        graph.replay()
        output_obj = ArgsUtils.replace_static_with_sliced(output_template, static_output_tensors)

        if self.check_output_inconsistency:
            # NOTE(review): both signature pairs below are computed from the
            # same freshly-built ``output_obj``, so this comparison can never
            # differ — the "cached" side was presumably meant to come from the
            # stored ``output_template``. Confirm intent before relying on
            # this consistency check.
            cur_output_tensors, cur_output_tensor_names, cur_output_literals = ArgsUtils.recursive_extract_core(output_obj)
            cur_output_static_sig, cur_output_dynamic_sig = ArgsUtils.generate_both_signatures_from_tensors(
                func.__qualname__, cur_output_tensors, cur_output_tensor_names, cur_output_literals
            )
            output_tensors, output_tensor_names, output_literals = ArgsUtils.recursive_extract_core(output_obj)
            cached_output_static_sig, cached_output_dynamic_sig = ArgsUtils.generate_both_signatures_from_tensors(
                func.__qualname__, output_tensors, output_tensor_names, output_literals
            )
            if cur_output_static_sig != cached_output_static_sig or cur_output_dynamic_sig != cached_output_dynamic_sig:
                magi_logger.warning(
                    f"[CUDA Graph] Warning: Output signature changed during inline replay. {func.__qualname__=}, {layer_number=}"
                )
                self.set_graph_inconsistent(input_static_sig, input_dynamic_sig, layer_number)
                return False, None
        return True, output_obj
    except KeyError:
        # Cache miss: no static entry / template / graph for this signature.
        return False, None
    except AssertionError:
        # Graph exists but is unusable (missing, inconsistent, or invalid).
        return False, None
    except Exception as e:
        magi_logger.info(
            f"[CUDA Graph] Exception during inline replay: {e=}, {func.__qualname__=}, {layer_number=}", rank="all"
        )
        raise e
|
| 745 |
+
|
| 746 |
+
@nvtx.instrument_nvtx
def run(self, func: Callable, *args, layer_number: Optional[int], **kwargs) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Run the function with CUDA Graph optimization if possible."""

    # Fast path: replay an already-captured graph without any cache bookkeeping.
    success, output_obj = self.try_replay_graph_inline(func=func, args=args, kwargs=kwargs, layer_number=layer_number)
    if success:
        # print_rank_0(f"[CUDA Graph] Current cache stats: {self.tensor_entry_count=}, {self.graph_count=}.")
        return output_obj

    # Extract input signatures
    func_name = func.__qualname__
    input_obj = {"args": args, "kwargs": kwargs}
    input_tensors, input_tensor_names, literals = ArgsUtils.recursive_extract_core(input_obj)
    input_static_sig, input_dynamic_sig = ArgsUtils.generate_both_signatures_from_tensors(
        func_name, input_tensors, input_tensor_names, literals
    )

    # Judge if the graph is marked as inconsistent; such graphs are never
    # re-captured — run eagerly for this signature.
    graph_entry = self.try_get_graph_entry(input_static_sig, input_dynamic_sig, layer_number)
    if graph_entry is not None and graph_entry.inconsistent:
        return func(*args, **kwargs)

    # Judge if need to expand static tensors (the cached buffers may be
    # smaller than the current inputs along dynamic dimensions).
    if_need_expand_static_tensors = False
    if_cached_tensor_entry = input_static_sig in self.cache
    if if_cached_tensor_entry:
        static_input_tensors, static_output_tensors = self.get_static_tensors(input_static_sig)
        if_need_expand_static_tensors = self.if_need_expand_static_tensors(
            static_input_tensors, input_tensors, input_static_sig
        )

    # Warmup run
    warmup_output_obj = self.warmup_run(func, *args, **kwargs)

    # Check input signature consistency after warmup. If the warmup mutated
    # the inputs' signature, a captured graph would replay against stale
    # buffers — mark inconsistent and return the eager result.
    warmup_input_tensors, warmup_input_tensor_names, warmup_literals = ArgsUtils.recursive_extract_core(input_obj)
    warmup_input_static_sig, warmup_input_dynamic_sig = ArgsUtils.generate_both_signatures_from_tensors(
        func_name, warmup_input_tensors, warmup_input_tensor_names, warmup_literals
    )
    if warmup_input_static_sig != input_static_sig or warmup_input_dynamic_sig != input_dynamic_sig:
        magi_logger.warning(
            f"[CUDA Graph] Warning: Input signature changed during warmup run. {func_name=}, {layer_number=}"
        )
        self.set_graph_inconsistent(input_static_sig, input_dynamic_sig, layer_number)
        return warmup_output_obj

    # Update cache entries
    if if_cached_tensor_entry:
        if if_need_expand_static_tensors:
            output_tensors, _, _ = ArgsUtils.recursive_extract_core(warmup_output_obj, extract_literals=False)
            # Need to expand static tensors
            new_static_input_tensors = self.get_expanded_static_tensors(static_input_tensors, input_tensors)
            new_static_output_tensors = self.get_expanded_static_tensors(static_output_tensors, output_tensors)
            # Register as new cache entries. Graphs captured against the old
            # (smaller) buffers can no longer be replayed safely.
            self.batch_set_graph_invalid(input_static_sig)
            self.cache[input_static_sig].input_tensors = new_static_input_tensors
            self.cache[input_static_sig].output_tensors = new_static_output_tensors

            self.add_template_entry(input_static_sig, input_dynamic_sig, warmup_output_obj)
        else:
            # Simply reuse existing static tensor entry
            static_tensor_entry = self.cache[input_static_sig]
            if input_dynamic_sig not in static_tensor_entry.template_entry_dict:
                self.add_template_entry(input_static_sig, input_dynamic_sig, warmup_output_obj)

    else:
        # Create new static tensor entry
        output_tensors, _, _ = ArgsUtils.recursive_extract_core(warmup_output_obj, extract_literals=False)
        self.add_static_entry(input_static_sig, input_tensors, output_tensors)
        self.add_template_entry(input_static_sig, input_dynamic_sig, warmup_output_obj)

    # Capture and cache new CUDA Graph
    self.capture_and_cache(
        func=func,
        input_obj=input_obj,
        layer_number=layer_number,
        input_static_sig=input_static_sig,
        input_dynamic_sig=input_dynamic_sig,
    )

    magi_logger.info(
        f"[CUDA Graph] Current cache stats: {self.tensor_entry_count=}, {self.graph_count=}, {self.tensor_mem_size=:.2f} MB, {self.graph_mem_pool_size=:.2f} MB"
    )
    return warmup_output_obj
|
| 831 |
+
|
| 832 |
+
|
| 833 |
+
# Module-level flag: True while a CUDA Graph capture is in progress.
_IS_GRAPH_CAPTURING = False


def _is_graph_capturing():
    """Query if currently capturing."""
    # Reading a module global needs no ``global`` declaration; keep this
    # accessor so callers never touch the flag directly.
    return _IS_GRAPH_CAPTURING
|
| 840 |
+
|
| 841 |
+
|
| 842 |
+
def _set_capture_start():
    """Mark that a CUDA Graph capture has begun."""
    global _IS_GRAPH_CAPTURING
    _IS_GRAPH_CAPTURING = True
|
| 846 |
+
|
| 847 |
+
|
| 848 |
+
def _set_capture_end():
    """Mark that the current CUDA Graph capture has finished."""
    global _IS_GRAPH_CAPTURING
    _IS_GRAPH_CAPTURING = False
|
| 852 |
+
|
| 853 |
+
|
| 854 |
+
# Process-wide singleton instance of CudaGraphMgr.
_CUDA_GRAPH_MGR = CudaGraphMgr()


def cuda_graph_mgr() -> CudaGraphMgr:
    """Return the process-wide CudaGraphMgr singleton.

    Returns:
        CudaGraphMgr: The current CudaGraphMgr instance.
    Raises:
        AssertionError: If the CudaGraphMgr has not been initialized.
    """
    assert _CUDA_GRAPH_MGR is not None, "cuda graph manager is not initialized"
    return _CUDA_GRAPH_MGR
|
| 868 |
+
|
| 869 |
+
|
| 870 |
+
def cuda_graph_enable_if(condition: Callable):
    """
    Decorator factory: route calls through CUDA Graph when ``condition()``
    returns True (and we are not already inside a graph capture).

    Args:
        condition (Callable): A callable that returns a bool indicating whether enable CUDA Graph.
    """

    def decorator(func):
        @wraps(func)
        def wrapped_func(*args, **kwargs):
            # Plain call when disabled, or when nested inside an ongoing
            # capture (nested captures are not supported).
            if not condition() or _is_graph_capturing():
                return func(*args, **kwargs)

            # Heuristic: for bound methods args[0] is ``self`` and may carry
            # a ``layer_number`` attribute used as part of the cache key.
            layer_number = getattr(args[0], "layer_number", None) if args else None
            return cuda_graph_mgr().run(func, *args, layer_number=layer_number, **kwargs)

        return wrapped_func

    return decorator
|
| 891 |
+
|
| 892 |
+
|
| 893 |
+
def gen_wrap_func_for_cudagraph(func: Callable, mode_prefix: str, target_prefix=None) -> Callable:
    """
    Wrap ``func`` for CUDA Graph execution:
    1. Assign a unique __qualname__ (the graph-cache key) to both the
       original and the wrapper
    2. Route calls through cuda_graph_mgr().run
    """
    # Stable identifier to avoid cache conflicts; fall back to the object id
    # for callables without a __name__.
    func_id = func.__name__ if hasattr(func, "__name__") else id(func)
    if mode_prefix == "full":
        wrapped_func_name = f"Athena_CUDAGraph_{mode_prefix}_{func_id}"
    else:  # piecewise
        wrapped_func_name = f"Athena_CUDAGraph_{mode_prefix}_{target_prefix}_{func_id}"

    @nvtx.instrument_nvtx
    def wrapped_func(*args, **kwargs):
        layer_number = kwargs.pop("layer_number", None)
        return cuda_graph_mgr().run(func, *args, layer_number=layer_number, **kwargs)

    # The graph cache keys on the callee's __qualname__, so rename the
    # original function as well.
    func.__qualname__ = wrapped_func_name
    magi_logger.info(f"Set original function qualname to {wrapped_func_name} for CUDA Graph caching.")

    # Propagate attributes the compiler pipeline attached to the original.
    wrapped_func.__dict__.update(func.__dict__)
    wrapped_func.__qualname__ = wrapped_func_name
    for attr in ["__is_first_graph", "__is_last_graph", "__sym_shape_indices"]:
        if hasattr(func, attr):
            setattr(wrapped_func, attr, getattr(func, attr))

    return wrapped_func
|
| 923 |
+
|
| 924 |
+
|
| 925 |
+
def init_cudagraph_global_pool():
    """Initialize the global CUDA graph memory pool if not already initialized (idempotent)."""
    # Imported lazily to avoid a circular import at module load time.
    from magi_compiler.cuda_graph_mgr import cuda_graph_mgr

    mgr = cuda_graph_mgr()
    if mgr.graph_mem_pool is None:
        mgr.graph_mem_pool = torch.cuda.graph_pool_handle()
        magi_logger.info("Initialized global CUDA graph pool for Athena.")
|
pkgs/MagiCompiler/magi_compiler/joint_graph_partition.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import Any, Optional, Sequence, Tuple
|
| 17 |
+
from unittest.mock import patch
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.fx as fx
|
| 21 |
+
from torch._functorch.compile_utils import get_aten_target
|
| 22 |
+
from torch._functorch.partitioners import NodeInfo, OpTypes, get_default_op_list, min_cut_rematerialization_partition
|
| 23 |
+
from torch._inductor.custom_graph_pass import CustomPartitionerFn
|
| 24 |
+
from torch.utils._ordered_set import OrderedSet
|
| 25 |
+
|
| 26 |
+
# from magi_compiler.partitioners import min_cut_rematerialization_partition
|
| 27 |
+
from .config import RecomputePolicy, get_compile_config
|
| 28 |
+
from .utils import compute_code_hash, magi_logger
|
| 29 |
+
from .utils.visualize import joint_graph_vis
|
| 30 |
+
|
| 31 |
+
SAVE_TENSOR_NODES: Optional[list[fx.Node]] = None
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def is_memory_increase_by_node(node: fx.Node) -> bool:
    """Return True when a dtype cast enlarges the per-element size.

    Only supports ``prims.convert_element_type`` (the decomposition of
    ``aten.to``) for now.
    """
    assert get_aten_target(node) == torch.ops.prims.convert_element_type
    src_dtype = node.args[0].meta["tensor_meta"].dtype
    dst_dtype = node.args[1]
    assert dst_dtype is not None
    return dst_dtype.itemsize > src_dtype.itemsize
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def is_primal_contribute_to_bwd_directly(primal_node: fx.Node, node_info: NodeInfo, op_types: OpTypes) -> bool:
    """
    FSDP ensures that weights already reside in memory. If there exists a path from the primal to the bwd, and the path does not contain any matmul, then the primal contributes to the bwd directly.
    And we should save this primals.

    Returns:
        True when ``primal_node``, or a node reachable from it without
        crossing a compute-intensive op, is required by the backward pass.
    """
    if node_info.is_required_bw(primal_node):
        return True

    # BFS over users, stopping expansion at compute-intensive ops.
    # ``visited`` prevents re-enqueueing shared nodes: without it, diamond-shaped
    # graphs re-expand the same subgraph repeatedly (worst-case exponential).
    frontier = {primal_node}
    visited = {primal_node}

    while len(frontier) > 0:
        cur_node = frontier.pop()
        for user in cur_node.users:
            if node_info.is_required_bw(user):
                return True
            if op_types.is_compute_intensive(user):
                # Paths through matmul-like ops do not count as "direct".
                continue
            if user not in visited:
                visited.add(user)
                frontier.add(user)
    return False
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def is_compute_intensive_and_has_following_recomputable_ops(
    intermidiate_node: fx.Node, node_info: NodeInfo, op_types: OpTypes
) -> Tuple[bool, Optional[fx.Node]]:
    """
    If compute-intensive node(CIN) is not the output of fwd graph(has following memory-intensive ops in the fwd graph), then we should save this CIN node.
    NOTE: For CIN+aten.to, we should save aten.to op instead of CIN op to save more memory.

    Returns:
        (should_save, node_to_save). ``node_to_save`` is None when the node
        should not be saved. (Annotation fixed: the original declared
        ``Tuple[bool, fx.Node]`` but returns None on several paths.)
    """
    if not op_types.is_compute_intensive(intermidiate_node):
        return False, None

    save_node = intermidiate_node
    topology_start = set({save_node})
    while len(topology_start) > 0:
        cur_node = topology_start.pop()
        # Users of the current node that belong to the forward graph.
        fwd_user_nodes = [user for user in cur_node.users if node_info.is_required_fw(user)]

        if len(fwd_user_nodes) > 1:  # multiple users, save current node
            return True, save_node
        elif len(fwd_user_nodes) == 0:  # output, return
            return False, None

        # Follow the single-user chain; advance the save point when the
        # successor is recomputable and at least as memory-efficient.
        next_node = fwd_user_nodes[0]
        if op_types.is_view(next_node):
            if save_node == cur_node:
                save_node = next_node
            topology_start.add(next_node)
        # Special case for aten.to: only move the save point past the cast
        # when it does not enlarge the tensor, saving the smaller form.
        elif get_aten_target(next_node) == torch.ops.prims.convert_element_type:
            is_memory_increase = is_memory_increase_by_node(next_node)
            if not is_memory_increase:
                save_node = next_node
            topology_start.add(next_node)
        elif next_node.op == "output":
            return False, None
        else:
            return True, save_node
    # Explicit raise instead of ``assert False`` so the guard survives -O.
    raise AssertionError(f"Should not reach here: {intermidiate_node=} {save_node=}")
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# TODO: We find an elegant impl to heuristically save nodes, reconstruct this later
def heuristic_choose_saved_values_set(joint_graph: fx.Graph, node_info: NodeInfo, memory_budget=1) -> list[fx.Node]:
    """Heuristically pick the set of forward values to save for backward.

    Two sources are collected:
      1. forward inputs (primals) that feed the backward without crossing a
         compute-intensive op, and
      2. compute-intensive intermediates that still have recomputable users
         in the forward graph.

    ``memory_budget`` is accepted for signature compatibility; it is not
    used by this heuristic.
    """
    op_types = get_default_op_list()
    output: OrderedSet[fx.Node] = OrderedSet()

    # 1) Inputs required (directly or through cheap ops) by the backward pass.
    for primal_node in node_info.inputs:
        if is_primal_contribute_to_bwd_directly(primal_node, node_info, op_types):
            output.add(primal_node)
    magi_logger.info("MagiCompiler: saved_output forward-input = %s", output)

    # 2) Compute-intensive intermediates still consumed inside the forward.
    for intermidiate_node in node_info.required_fw_nodes:
        is_save, save_node = is_compute_intensive_and_has_following_recomputable_ops(intermidiate_node, node_info, op_types)
        if is_save:
            output.add(save_node)
    magi_logger.info("MagiCompiler: saved_output compute-intensive = %s", output)

    # Stash the decision for later graph visualization.
    global SAVE_TENSOR_NODES
    SAVE_TENSOR_NODES = list(output)
    return list(output)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def custom_joint_graph_partition_fn(
    joint_module: fx.GraphModule,
    _joint_inputs,
    compiler="inductor",
    *,
    num_fwd_outputs,
    static_lifetime_input_indices: Optional[list[int]] = None,
) -> tuple[fx.GraphModule, fx.GraphModule]:
    """Partition a joint fwd/bwd graph according to the configured recompute policy.

    HANDCRAFT patches torch's activation memory budget; HEURISTIC swaps in
    ``heuristic_choose_saved_values_set``. Both then delegate to torch's
    min-cut rematerialization partitioner (the call is shared below instead
    of being duplicated per branch).

    Raises:
        ValueError: for AUTOSEARCH (not implemented yet) or an unknown policy.
    """
    recompute_config = get_compile_config().recompute_config
    policy = recompute_config.recompute_policy

    if policy == RecomputePolicy.HANDCRAFT:
        magi_logger.info("MagiCompiler using handcraft recompute policy")
        # TODO: different memory budget definition from torch
        patcher = patch("torch._functorch.config.activation_memory_budget", recompute_config.memory_budget)
    elif policy == RecomputePolicy.HEURISTIC:
        magi_logger.info("MagiCompiler using heuristic recompute policy")
        patcher = patch("torch._functorch.partitioners.choose_saved_values_set", heuristic_choose_saved_values_set)
    elif policy == RecomputePolicy.AUTOSEARCH:
        raise ValueError("AutoSearch recompute policy is not supported yet")
    else:
        raise ValueError(f"Invalid recompute policy: {policy}")

    with patcher:
        fwd_module, bwd_module = min_cut_rematerialization_partition(
            joint_module,
            _joint_inputs,
            compiler,
            num_fwd_outputs=num_fwd_outputs,
            static_lifetime_input_indices=static_lifetime_input_indices,
        )

    # SAVE_TENSOR_NODES is populated by the heuristic policy (None otherwise).
    joint_graph_vis(joint_module, fwd_module, bwd_module, save_tensor_nodes=SAVE_TENSOR_NODES)

    return fwd_module, bwd_module
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class CustomJointGraphPartitionFn(CustomPartitionerFn):
    """Inductor-facing wrapper around ``custom_joint_graph_partition_fn``."""

    def __call__(
        self, gm: torch.fx.GraphModule, joint_inputs: Sequence[object], **kwargs: Any
    ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
        """
        Implementation of the custom partitioner.
        """
        return custom_joint_graph_partition_fn(gm, joint_inputs, **kwargs)

    def uuid(self) -> Optional[Any]:
        """
        Return an ID to uniquely identify your custom partitioner implementation.
        Return None to skip inductor code caching entirely.

        Here we hash this source file so cached artifacts are invalidated
        whenever the partitioner implementation changes.
        """
        return compute_code_hash({os.path.abspath(__file__)})
|
pkgs/MagiCompiler/magi_compiler/magi_backend.py
ADDED
|
@@ -0,0 +1,607 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
import ast
|
| 17 |
+
import dataclasses
|
| 18 |
+
import pprint
|
| 19 |
+
import time
|
| 20 |
+
from collections.abc import Callable
|
| 21 |
+
from contextlib import contextmanager
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Any
|
| 24 |
+
|
| 25 |
+
import magi_compiler.utils.envs as envs
|
| 26 |
+
import torch
|
| 27 |
+
import torch.fx as fx
|
| 28 |
+
from torch._dispatch.python import enable_python_dispatcher
|
| 29 |
+
from torch._dynamo.utils import lazy_format_graph_code
|
| 30 |
+
from torch._guards import detect_fake_mode
|
| 31 |
+
|
| 32 |
+
from ._cache_data_cls import CacheEntry, CacheHandle
|
| 33 |
+
from .compile_artifacts import MagiSerializableFunction
|
| 34 |
+
from .config import CompileConfig, CompileMode, CudaGraphMode
|
| 35 |
+
from .cuda_graph_mgr import gen_wrap_func_for_cudagraph
|
| 36 |
+
from .joint_graph_partition import CustomJointGraphPartitionFn
|
| 37 |
+
from .offload.offload_warpper import OffloadWrapper
|
| 38 |
+
from .partition_rules import inductor_partition_rule_context, resolve_defined_ops
|
| 39 |
+
from .passes import PostGradPassManager
|
| 40 |
+
from .passes.inductor_pass import pass_context
|
| 41 |
+
from .passes.replace_pass import FullGraphPassManager
|
| 42 |
+
from .piecewise_backend import PiecewiseBackend
|
| 43 |
+
from .piecewise_compiler import CompilerInterface, EagerAdaptor, InductorStandaloneAdaptor
|
| 44 |
+
from .utils import (
|
| 45 |
+
CompileMonitor,
|
| 46 |
+
compilation_counter,
|
| 47 |
+
compute_code_hash,
|
| 48 |
+
compute_hash,
|
| 49 |
+
detect_symbolic_tensor_indices,
|
| 50 |
+
magi_logger,
|
| 51 |
+
)
|
| 52 |
+
from .utils.envs import MAGI_CUSTOM_PARTITIONER_FN, MAGI_MODEL_TAG, MAGI_POST_GRAD_PASS
|
| 53 |
+
from .utils.visualize import save_fx_graph_visualization
|
| 54 |
+
|
| 55 |
+
compilation_start_time: float = 0.0
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _print_with_shape_and_time(runtime_shape: int | None, prefix: str = ""):
    """Log ``prefix`` plus the time elapsed since compilation started,
    tagged with the runtime shape (or "dynamic shape" when None)."""
    elapsed = time.time() - compilation_start_time
    if runtime_shape is not None:
        magi_logger.info("%s for shape %s, took %.3f s", prefix, str(runtime_shape), elapsed)
    else:
        magi_logger.info("%s for dynamic shape, took %.3f s", prefix, elapsed)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@dataclasses.dataclass
class SplitItem:
    """One submodule produced by splitting the full FX graph."""

    # Attribute name of the submodule on the split GraphModule.
    submod_name: str
    # Ordinal of this subgraph within the split.
    graph_id: int
    # Whether this subgraph is one of the designated splitting graphs —
    # TODO(review): confirm exact semantics against the splitting code.
    is_splitting_graph: bool
    graph: fx.GraphModule
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def make_compiler(compile_config: CompileConfig) -> CompilerInterface:
|
| 75 |
+
if compile_config.backend == "inductor":
|
| 76 |
+
# Use standalone_compile with PyTorch 2.8+
|
| 77 |
+
assert hasattr(torch._inductor, "standalone_compile"), "standalone_compile not found in PyTorch Inductor"
|
| 78 |
+
magi_logger.info("Using InductorStandaloneAdaptor")
|
| 79 |
+
return InductorStandaloneAdaptor()
|
| 80 |
+
else:
|
| 81 |
+
assert compile_config.backend == "eager", f"Invalid backend for MagiCompiler: {compile_config.backend}"
|
| 82 |
+
magi_logger.info("Using EagerAdaptor")
|
| 83 |
+
return EagerAdaptor()
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class CompilerManager:
|
| 87 |
+
"""
|
| 88 |
+
Manage the compilation process, including graph compilation, compile artifacts caching and loading.
|
| 89 |
+
|
| 90 |
+
The cache is a dict mapping `(runtime_shape, graph_index, backend_name)` to `any_data` returned from the compiler.
|
| 91 |
+
|
| 92 |
+
When serializing the cache, we save it to a Python file for readability. We don't use json here because json doesn't support int as key.
|
| 93 |
+
"""
|
| 94 |
+
|
| 95 |
+
def __init__(self, compile_config: CompileConfig):
|
| 96 |
+
self.cache: dict[CacheEntry, CacheHandle] = dict()
|
| 97 |
+
self.compile_config = compile_config
|
| 98 |
+
self.compiler = make_compiler(compile_config)
|
| 99 |
+
self.disable_cache = envs.MAGI_DISABLE_COMPILE_CACHE
|
| 100 |
+
|
| 101 |
+
@property
|
| 102 |
+
def hash(self) -> str:
|
| 103 |
+
return self.compiler.hash
|
| 104 |
+
|
| 105 |
+
@contextmanager
|
| 106 |
+
def compile_context(self, runtime_shape: int | None = None):
|
| 107 |
+
"""Provide compilation context for the duration of compilation to set
|
| 108 |
+
any torch global properties we want to scope to a single Inductor
|
| 109 |
+
compilation (e.g. partition rules, pass context)."""
|
| 110 |
+
with pass_context(runtime_shape):
|
| 111 |
+
if self.compile_config.use_inductor_graph_partition:
|
| 112 |
+
inductor_partition_ops = resolve_defined_ops(self.compile_config.splitting_ops)
|
| 113 |
+
with inductor_partition_rule_context(inductor_partition_ops):
|
| 114 |
+
yield
|
| 115 |
+
else:
|
| 116 |
+
yield
|
| 117 |
+
|
| 118 |
+
def initialize_cache(self, cache_dir: Path, prefix: str = ""):
|
| 119 |
+
"""
|
| 120 |
+
Initialize the cache directory for the compiler.
|
| 121 |
+
|
| 122 |
+
The organization of the cache directory is as follows:
|
| 123 |
+
cache_dir=/path/to/torch_compile_cache/rank_i_j/hash_str/prefix/
|
| 124 |
+
inside cache_dir, there will be:
|
| 125 |
+
- magi_compile_cache.py
|
| 126 |
+
- computation_graph.py
|
| 127 |
+
|
| 128 |
+
for multiple prefixes, they can share the same base cache dir of
|
| 129 |
+
/path/to/torch_compile_cache/rank_i_j/hash_str/ to store some
|
| 130 |
+
common compilation artifacts.
|
| 131 |
+
"""
|
| 132 |
+
|
| 133 |
+
self.cache_dir: Path = cache_dir
|
| 134 |
+
self.cache_file_path: Path = cache_dir / "magi_compile_cache.py"
|
| 135 |
+
|
| 136 |
+
if self.disable_cache:
|
| 137 |
+
magi_logger.info("MagiCompiler's cache is disabled.")
|
| 138 |
+
return
|
| 139 |
+
|
| 140 |
+
magi_logger.info("Using cache directory: %s for MagiCompiler", cache_dir)
|
| 141 |
+
if self.cache_file_path.exists():
|
| 142 |
+
# load the cache from the file
|
| 143 |
+
with self.cache_file_path.open() as f:
|
| 144 |
+
# Parse Python literals using ast.literal_eval, which is a safe alternative to eval().
|
| 145 |
+
raw = ast.literal_eval(f.read())
|
| 146 |
+
self.cache = {CacheEntry(*entry): CacheHandle(*handle) for entry, handle in raw.items()}
|
| 147 |
+
|
| 148 |
+
self.compiler.initialize_cache(cache_dir=self.cache_dir, prefix=prefix)
|
| 149 |
+
|
| 150 |
+
def save_to_file(self):
|
| 151 |
+
if self.disable_cache:
|
| 152 |
+
return
|
| 153 |
+
# serialize to a literal-friendly dict
|
| 154 |
+
serializable = {(e.runtime_shape, e.graph_index, e.backend_name): (h.key, h.path) for e, h in self.cache.items()}
|
| 155 |
+
printer = pprint.PrettyPrinter(indent=4)
|
| 156 |
+
data = printer.pformat(serializable)
|
| 157 |
+
with self.cache_file_path.open("w") as f:
|
| 158 |
+
f.write(data)
|
| 159 |
+
|
| 160 |
+
def load(self, graph: fx.GraphModule, example_inputs: list[Any], cache_entry: CacheEntry) -> Callable | None:
|
| 161 |
+
if cache_entry not in self.cache:
|
| 162 |
+
return None
|
| 163 |
+
cache_handle = self.cache[cache_entry]
|
| 164 |
+
_print_with_shape_and_time(
|
| 165 |
+
cache_entry.runtime_shape,
|
| 166 |
+
f"Directly load the {cache_entry.graph_index}-th graph from {cache_entry.backend_name} via handle {cache_handle}",
|
| 167 |
+
)
|
| 168 |
+
return self.compiler.load(graph, example_inputs, cache_entry, cache_handle)
|
| 169 |
+
|
| 170 |
+
# TODO(hongyu): Support training mode here
|
| 171 |
+
def compile(
|
| 172 |
+
self,
|
| 173 |
+
graph: fx.GraphModule,
|
| 174 |
+
example_inputs: tuple[torch.fx.node.Argument, ...],
|
| 175 |
+
compile_config: CompileConfig,
|
| 176 |
+
graph_index: int = 0,
|
| 177 |
+
num_graphs: int = 1,
|
| 178 |
+
runtime_shape: int | None = None,
|
| 179 |
+
) -> Callable:
|
| 180 |
+
# Step0: update some global metrics
|
| 181 |
+
compilation_counter.num_backend_compilations += 1
|
| 182 |
+
if graph_index == 0:
|
| 183 |
+
global compilation_start_time
|
| 184 |
+
compilation_start_time = time.time()
|
| 185 |
+
|
| 186 |
+
# Step1: Try loading from the cache
|
| 187 |
+
cache_entry = CacheEntry(runtime_shape, graph_index, self.compiler.name)
|
| 188 |
+
compiled_graph = self.load(graph, example_inputs, cache_entry)
|
| 189 |
+
if compiled_graph is not None:
|
| 190 |
+
return compiled_graph
|
| 191 |
+
|
| 192 |
+
# Step2: Compile the graph
|
| 193 |
+
key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
|
| 194 |
+
with self.compile_context(runtime_shape):
|
| 195 |
+
compiled_graph, cache_handle = self.compiler.compile(
|
| 196 |
+
graph, example_inputs, compile_config.inductor_compile_config, runtime_shape, key
|
| 197 |
+
)
|
| 198 |
+
assert compiled_graph is not None, "Failed to compile the graph"
|
| 199 |
+
|
| 200 |
+
# Step3: Store the artifact in the cache
|
| 201 |
+
if not self.disable_cache and cache_handle is not None:
|
| 202 |
+
assert cache_entry not in self.cache, "Cache entry already exists"
|
| 203 |
+
self.cache[cache_entry] = cache_handle
|
| 204 |
+
compilation_counter.num_cache_entries += 1
|
| 205 |
+
_print_with_shape_and_time(runtime_shape, f"Compile the {graph_index}/{num_graphs} graph")
|
| 206 |
+
|
| 207 |
+
return compiled_graph
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
# TODO(hongyu): Support training mode here
|
| 211 |
+
class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
| 212 |
+
"""
|
| 213 |
+
Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
|
| 214 |
+
It runs the given graph with fake inputs, and compile some submodules specified by `compile_submod_names` with compilation configs.
|
| 215 |
+
|
| 216 |
+
NOTE: the order in `compile_submod_names` matters, because it will be used to determine the order of the compiled piecewise graphs.
|
| 217 |
+
The first graph will handle logging, and the last graph has some special cudagraph output handling.
|
| 218 |
+
"""
|
| 219 |
+
|
| 220 |
+
def __init__(
|
| 221 |
+
self,
|
| 222 |
+
module: torch.fx.GraphModule,
|
| 223 |
+
compiler_manager: CompilerManager,
|
| 224 |
+
compile_submod_names: list[str],
|
| 225 |
+
compile_config: CompileConfig,
|
| 226 |
+
):
|
| 227 |
+
super().__init__(module)
|
| 228 |
+
|
| 229 |
+
self.fake_mode = detect_fake_mode()
|
| 230 |
+
self.compiler_manager = compiler_manager
|
| 231 |
+
self.compile_submod_names = compile_submod_names
|
| 232 |
+
self.compile_config = compile_config
|
| 233 |
+
# extra_traceback is attribute of torch.fx.Interpreter, when it is True, it annoyingly dumps the torch.fx.Graph on errors.
|
| 234 |
+
self.extra_traceback = False
|
| 235 |
+
|
| 236 |
+
def _fix_graph_device_placement(self, module: torch.nn.Module):
|
| 237 |
+
for name, child in module.named_children():
|
| 238 |
+
self._fix_graph_device_placement(child)
|
| 239 |
+
|
| 240 |
+
if isinstance(module, torch.fx.GraphModule):
|
| 241 |
+
needs_recompile = False
|
| 242 |
+
target_device = torch.cuda.current_device()
|
| 243 |
+
|
| 244 |
+
factory_functions = [
|
| 245 |
+
torch.empty,
|
| 246 |
+
torch.zeros,
|
| 247 |
+
torch.ones,
|
| 248 |
+
torch.full,
|
| 249 |
+
torch.rand,
|
| 250 |
+
torch.randn,
|
| 251 |
+
torch.arange,
|
| 252 |
+
torch.tensor,
|
| 253 |
+
torch.ops.aten.empty.memory_format,
|
| 254 |
+
]
|
| 255 |
+
|
| 256 |
+
for node in module.graph.nodes:
|
| 257 |
+
if node.op == 'call_function':
|
| 258 |
+
is_factory = node.target in factory_functions or (
|
| 259 |
+
hasattr(node.target, '__name__') and node.target.__name__ in ['empty', 'zeros', 'ones', 'full']
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
if is_factory:
|
| 263 |
+
if 'device' in node.kwargs:
|
| 264 |
+
current_dev = node.kwargs['device']
|
| 265 |
+
if str(current_dev) == 'cpu' or current_dev == torch.device('cpu'):
|
| 266 |
+
node.update_kwarg('device', target_device)
|
| 267 |
+
needs_recompile = True
|
| 268 |
+
|
| 269 |
+
if needs_recompile:
|
| 270 |
+
module.recompile()
|
| 271 |
+
|
| 272 |
+
def run(self, *args):
|
| 273 |
+
fake_args = [self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args]
|
| 274 |
+
if self.compile_config.offload_config.model_cpu_offload:
|
| 275 |
+
self._fix_graph_device_placement(self.module)
|
| 276 |
+
for i, arg in enumerate(fake_args):
|
| 277 |
+
if isinstance(arg, torch.Tensor):
|
| 278 |
+
fake_args[i] = arg.cuda()
|
| 279 |
+
|
| 280 |
+
with self.fake_mode, enable_python_dispatcher():
|
| 281 |
+
return super().run(*fake_args)
|
| 282 |
+
|
| 283 |
+
def call_module(
|
| 284 |
+
self, target: torch.fx.node.Target, args: tuple[torch.fx.node.Argument, ...], kwargs: dict[str, Any]
|
| 285 |
+
) -> Any:
|
| 286 |
+
assert isinstance(target, str)
|
| 287 |
+
output = super().call_module(target, args, kwargs)
|
| 288 |
+
if target not in self.compile_submod_names:
|
| 289 |
+
return output
|
| 290 |
+
|
| 291 |
+
index = self.compile_submod_names.index(target)
|
| 292 |
+
submod = self.fetch_attr(target)
|
| 293 |
+
sym_shape_indices = [i for i, x in enumerate(args) if isinstance(x, torch.SymInt)]
|
| 294 |
+
magi_logger.info(f"Compiling {target=}, {sym_shape_indices=}, {args=}")
|
| 295 |
+
|
| 296 |
+
compiled_graph_for_dynamic_shape = self.compiler_manager.compile(
|
| 297 |
+
submod, args, self.compile_config, graph_index=index, num_graphs=len(self.compile_submod_names), runtime_shape=None
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
piecewise_backend = PiecewiseBackend(
|
| 301 |
+
submod,
|
| 302 |
+
compiled_graph_for_dynamic_shape,
|
| 303 |
+
self.compile_config,
|
| 304 |
+
index,
|
| 305 |
+
len(self.compile_submod_names),
|
| 306 |
+
sym_shape_indices,
|
| 307 |
+
self.compiler_manager,
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
if self.compile_config.use_inductor_graph_partition or self.compile_config.cudagraph_mode != CudaGraphMode.PIECEWISE:
|
| 311 |
+
self.module.__dict__[target] = piecewise_backend
|
| 312 |
+
else:
|
| 313 |
+
wrapped_backend = gen_wrap_func_for_cudagraph(
|
| 314 |
+
func=piecewise_backend, mode_prefix=CudaGraphMode.PIECEWISE.name.lower(), target_prefix=target
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
self.module.__dict__[target] = wrapped_backend
|
| 318 |
+
magi_logger.info(
|
| 319 |
+
f"Wrapped piecewise submodule {target} (index {index}) with CUDA Graph "
|
| 320 |
+
f"[PIECEWISE mode, first_graph={piecewise_backend.is_first_graph}, last_graph={piecewise_backend.is_last_graph}]"
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
return output
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
class MagiBackend:
|
| 327 |
+
"""
|
| 328 |
+
The compilation backend for `torch.compile` with MagiCompiler.
|
| 329 |
+
It is used for compilation mode of `CompileMode.MAGI_COMPILE`,
|
| 330 |
+
where we customize the compilation.
|
| 331 |
+
|
| 332 |
+
The major work of this backend is to split the graph into
|
| 333 |
+
piecewise graphs, and pass them to the piecewise backend.
|
| 334 |
+
|
| 335 |
+
This backend also adds the PostGradPassManager to Inductor config,
|
| 336 |
+
which handles the post-grad passes.
|
| 337 |
+
"""
|
| 338 |
+
|
| 339 |
+
compile_config: CompileConfig
|
| 340 |
+
_called_once: bool = False
|
| 341 |
+
# for the graph we compiled
|
| 342 |
+
graph: fx.GraphModule
|
| 343 |
+
compiler_manager: CompilerManager
|
| 344 |
+
# for cudagraph
|
| 345 |
+
sym_tensor_indices: list[int] # indices for tensors that have symbolic shapes
|
| 346 |
+
input_buffers: list[torch.Tensor] # buffers for input tensors that have symbolic shapes
|
| 347 |
+
|
| 348 |
+
def __init__(self, compile_config: CompileConfig, model_tag: str = ""):
|
| 349 |
+
self.model_tag = model_tag or MAGI_MODEL_TAG
|
| 350 |
+
self.compile_config = compile_config
|
| 351 |
+
self._configure_custom_passes()
|
| 352 |
+
self.compiler_manager: CompilerManager = CompilerManager(self.compile_config)
|
| 353 |
+
|
| 354 |
+
self.sym_tensor_indices = []
|
| 355 |
+
self.input_buffers = []
|
| 356 |
+
|
| 357 |
+
def _configure_custom_passes(self):
|
| 358 |
+
# Custom pass 1: full graph passes between Dynamo and AOTAutograd
|
| 359 |
+
self.full_graph_pass_manager = FullGraphPassManager(self.compile_config.pass_config)
|
| 360 |
+
|
| 361 |
+
# Custom pass 2: custom partitioner function
|
| 362 |
+
custom_partitioner_fn = CustomJointGraphPartitionFn()
|
| 363 |
+
if MAGI_CUSTOM_PARTITIONER_FN in self.compile_config.inductor_compile_config:
|
| 364 |
+
existing_fn = self.compile_config.inductor_compile_config[MAGI_CUSTOM_PARTITIONER_FN]
|
| 365 |
+
assert isinstance(existing_fn, CustomJointGraphPartitionFn)
|
| 366 |
+
assert existing_fn.uuid() == custom_partitioner_fn.uuid()
|
| 367 |
+
self.compile_config.inductor_compile_config[MAGI_CUSTOM_PARTITIONER_FN] = custom_partitioner_fn
|
| 368 |
+
|
| 369 |
+
# Custom pass 3: post-grad passes after AOTAutograd
|
| 370 |
+
post_grad_pass_manager = PostGradPassManager()
|
| 371 |
+
post_grad_pass_manager.configure(self.compile_config)
|
| 372 |
+
|
| 373 |
+
# Run post-grad custom passes with post_grad_custom_post_pass hook
|
| 374 |
+
if MAGI_POST_GRAD_PASS in self.compile_config.inductor_compile_config:
|
| 375 |
+
existing_pass = self.compile_config.inductor_compile_config[MAGI_POST_GRAD_PASS]
|
| 376 |
+
assert isinstance(existing_pass, PostGradPassManager)
|
| 377 |
+
assert existing_pass.uuid() == post_grad_pass_manager.uuid()
|
| 378 |
+
|
| 379 |
+
self.compile_config.inductor_compile_config[MAGI_POST_GRAD_PASS] = post_grad_pass_manager
|
| 380 |
+
|
| 381 |
+
def _init_cache(self) -> str:
|
| 382 |
+
hash_key = compute_hash(
|
| 383 |
+
[self.compile_config.hash, self.compiler_manager.hash, compute_code_hash(self.compile_config.traced_files)]
|
| 384 |
+
)
|
| 385 |
+
self.compile_config.traced_files.clear()
|
| 386 |
+
|
| 387 |
+
# Path: .../model_{idx}_{model_tag}_rank_{rank}/{hash}/{model_tag}/ (last segment = class name or user tag)
|
| 388 |
+
self.local_cache_dir: Path = self.compile_config.cache_dump_path() / hash_key / self.model_tag
|
| 389 |
+
self.local_cache_dir.mkdir(parents=True, exist_ok=True)
|
| 390 |
+
|
| 391 |
+
self.compiler_manager.initialize_cache(self.local_cache_dir, self.model_tag)
|
| 392 |
+
|
| 393 |
+
def _save_partitioned_graph(self, split_gm: fx.GraphModule):
|
| 394 |
+
graph_path = self.local_cache_dir / "computation_graph.py"
|
| 395 |
+
if not graph_path.exists():
|
| 396 |
+
# code adapted from
|
| 397 |
+
# https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30
|
| 398 |
+
# use `print_readable` because it can include submodules
|
| 399 |
+
src = "from __future__ import annotations\nimport torch\n" + split_gm.print_readable(print_output=False)
|
| 400 |
+
src = src.replace("<lambda>", "GraphModule")
|
| 401 |
+
with open(graph_path, "w") as f:
|
| 402 |
+
f.write(src)
|
| 403 |
+
magi_logger.info("Computation graph saved to %s", graph_path)
|
| 404 |
+
|
| 405 |
+
def _split_graph(self, graph: fx.GraphModule) -> tuple[fx.GraphModule, list[SplitItem]]:
|
| 406 |
+
# Step 1: resolve the splitting ops
|
| 407 |
+
if self.compile_config.use_inductor_graph_partition:
|
| 408 |
+
# Let Inductor decide partitioning; avoid FX-level pre-splitting.
|
| 409 |
+
fx_split_ops: list[str] = []
|
| 410 |
+
else:
|
| 411 |
+
fx_split_ops = self.compile_config.splitting_ops or []
|
| 412 |
+
resolved_ops: list[torch._ops.OpOverload] = resolve_defined_ops(fx_split_ops)
|
| 413 |
+
magi_logger.info(f"Setting up FX-level graph split with ops: {fx_split_ops=}")
|
| 414 |
+
magi_logger.info(f"Resolved splitting ops for FX-level graph split: {resolved_ops=}")
|
| 415 |
+
|
| 416 |
+
# Step 2: split graph by ops, we split graph based on resolved_ops, which becomes the partitioned single graph.
|
| 417 |
+
subgraph_id = 0
|
| 418 |
+
node_to_subgraph_id = {}
|
| 419 |
+
split_op_graphs = []
|
| 420 |
+
for node in graph.graph.nodes:
|
| 421 |
+
if node.op in ("output", "placeholder"):
|
| 422 |
+
continue
|
| 423 |
+
# Match node.target against resolved_ops, node.target can be OpOverloadPacket, need to check .default
|
| 424 |
+
if node.op == "call_function" and (
|
| 425 |
+
node.target in resolved_ops or (hasattr(node.target, "default") and node.target.default in resolved_ops)
|
| 426 |
+
):
|
| 427 |
+
magi_logger.info(f"Splitting graph at {node=} with {node.target=}")
|
| 428 |
+
subgraph_id += 1
|
| 429 |
+
node_to_subgraph_id[node] = subgraph_id
|
| 430 |
+
split_op_graphs.append(subgraph_id)
|
| 431 |
+
subgraph_id += 1
|
| 432 |
+
else:
|
| 433 |
+
node_to_subgraph_id[node] = subgraph_id
|
| 434 |
+
|
| 435 |
+
# Step 3: split the graph based on node_to_subgraph_id
|
| 436 |
+
# pytorch might reorder the nodes and the semantics of the graph will change when we have mutations in the graph, if we don't set keep_original_order=True
|
| 437 |
+
split_gm = torch.fx.passes.split_module.split_module(
|
| 438 |
+
graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
def _extract_example_values(args) -> list:
|
| 442 |
+
example_values = []
|
| 443 |
+
|
| 444 |
+
def _recurse_extract(arg):
|
| 445 |
+
if isinstance(arg, (list, tuple)):
|
| 446 |
+
for sub_arg in arg:
|
| 447 |
+
_recurse_extract(sub_arg)
|
| 448 |
+
else:
|
| 449 |
+
example_value = arg.meta.get("example_value")
|
| 450 |
+
assert example_value is not None, f"Output arg {arg} has no example_value for tensor_meta recovery"
|
| 451 |
+
example_values.append(example_value)
|
| 452 |
+
|
| 453 |
+
_recurse_extract(args)
|
| 454 |
+
return example_values
|
| 455 |
+
|
| 456 |
+
def _format_output_values(values: list):
|
| 457 |
+
if not values:
|
| 458 |
+
return None
|
| 459 |
+
return tuple(values) if len(values) > 1 else values[0]
|
| 460 |
+
|
| 461 |
+
def _recursive_recover_tensor_meta(gm: fx.GraphModule):
|
| 462 |
+
"""
|
| 463 |
+
递归恢复指定 GraphModule 及其所有嵌套 submodule 中所有 node 的 example_value
|
| 464 |
+
支持任意层级的 submodule 嵌套
|
| 465 |
+
"""
|
| 466 |
+
for node in gm.graph.nodes:
|
| 467 |
+
if node.meta.get("example_value") is not None:
|
| 468 |
+
continue
|
| 469 |
+
|
| 470 |
+
if node.op == "call_module":
|
| 471 |
+
submod: fx.GraphModule = getattr(gm, node.target)
|
| 472 |
+
_recursive_recover_tensor_meta(submod) # 递归调用,处理嵌套 submodule
|
| 473 |
+
output_node = next(n for n in submod.graph.nodes if n.op == "output")
|
| 474 |
+
assert output_node is not None, f"Output node not found in submodule {node.target}"
|
| 475 |
+
output_values = _extract_example_values(output_node.args)
|
| 476 |
+
node.meta["example_value"] = _format_output_values(output_values)
|
| 477 |
+
elif node.op == "call_function":
|
| 478 |
+
if "getitem" in str(node.target):
|
| 479 |
+
prev_node, getitem_index = node.args
|
| 480 |
+
prev_example_value = prev_node.meta.get("example_value")
|
| 481 |
+
assert (
|
| 482 |
+
prev_example_value is not None
|
| 483 |
+
), f"Previous node {prev_node} has no example_value for tensor_meta recovery of node {node}"
|
| 484 |
+
node.meta["example_value"] = prev_example_value[getitem_index]
|
| 485 |
+
elif "grad" in str(node.target) or "device" in str(node.target): # 暂时不做处理
|
| 486 |
+
node.meta["example_value"] = None
|
| 487 |
+
elif node.op == "output":
|
| 488 |
+
output_values = _extract_example_values(node.args[0])
|
| 489 |
+
node.meta["example_value"] = _format_output_values(output_values)
|
| 490 |
+
|
| 491 |
+
else:
|
| 492 |
+
raise ValueError(f"Unsupported node op for tensor_meta recovery: {node.op} for node {node}")
|
| 493 |
+
|
| 494 |
+
magi_logger.info(f"Recovered example_value for node {node.name}: {node.meta['example_value']=}")
|
| 495 |
+
|
| 496 |
+
# Recover tensor_meta for all nodes in split_gm and its submodules
|
| 497 |
+
if envs.MAGI_ENABLE_PROFILE:
|
| 498 |
+
_recursive_recover_tensor_meta(split_gm)
|
| 499 |
+
|
| 500 |
+
# Step 4: fetch all the submodules
|
| 501 |
+
piecewise_graphs = []
|
| 502 |
+
names = [name for (name, module) in split_gm.named_modules()]
|
| 503 |
+
for name in names:
|
| 504 |
+
# Only keep the top-level modules, skip recursive child modules or the root module
|
| 505 |
+
if "." in name or name == "":
|
| 506 |
+
continue
|
| 507 |
+
|
| 508 |
+
module = getattr(split_gm, name)
|
| 509 |
+
assert isinstance(module, fx.GraphModule), f"Expected fx.GraphModule, got {type(module)}"
|
| 510 |
+
|
| 511 |
+
graph_id = int(name.replace("submod_", ""))
|
| 512 |
+
piecewise_graphs.append(SplitItem(name, graph_id, (graph_id in split_op_graphs), module))
|
| 513 |
+
# sort by integer graph_id, rather than string name
|
| 514 |
+
piecewise_graphs.sort(key=lambda x: x.graph_id)
|
| 515 |
+
|
| 516 |
+
# Step 5: visualize the split graph
|
| 517 |
+
# depyf already hooks lazy_format_graph_code and dumps the graph, we do not print the graph here
|
| 518 |
+
lazy_format_graph_code("Before split", graph, print_output=True, include_stride=True, include_device=True)
|
| 519 |
+
lazy_format_graph_code("After split", split_gm, print_output=True, include_stride=True, include_device=True)
|
| 520 |
+
|
| 521 |
+
if envs.MAGI_ENABLE_FX_GRAPH_VIZ:
|
| 522 |
+
save_fx_graph_visualization(split_gm.graph, sub_dir="after_split", filename="split_gm_root")
|
| 523 |
+
for item in piecewise_graphs:
|
| 524 |
+
save_fx_graph_visualization(item.graph.graph, sub_dir="after_split", filename=item.submod_name)
|
| 525 |
+
|
| 526 |
+
return split_gm, piecewise_graphs
|
| 527 |
+
|
| 528 |
+
def __call__(self, graph: fx.GraphModule, example_inputs) -> MagiSerializableFunction:
|
| 529 |
+
assert not self._called_once, "MagiBackend can only be called once cause compilation is a one-time process"
|
| 530 |
+
self._called_once = True
|
| 531 |
+
magi_logger.info("Dynamo traced files (for compilation cache):\n%s", "\n".join(self.compile_config.traced_files))
|
| 532 |
+
compilation_counter.num_graphs_seen += 1
|
| 533 |
+
CompileMonitor().mark("Dynamo bytecode transform")
|
| 534 |
+
|
| 535 |
+
self._init_cache()
|
| 536 |
+
|
| 537 |
+
self.full_graph_pass_manager(graph)
|
| 538 |
+
|
| 539 |
+
split_gm, piecewise_graphs = self._split_graph(graph)
|
| 540 |
+
|
| 541 |
+
submod_names_to_compile = [item.submod_name for item in piecewise_graphs if not item.is_splitting_graph]
|
| 542 |
+
compilation_counter.num_piecewise_graphs_seen += len(piecewise_graphs)
|
| 543 |
+
compilation_counter.num_piecewise_capturable_graphs_seen += len(submod_names_to_compile)
|
| 544 |
+
magi_logger.info(f"Piecewise modules waiting for compilation: {submod_names_to_compile}")
|
| 545 |
+
|
| 546 |
+
# propagate the split graph to the piecewise backend, compile submodules with symbolic shapes
|
| 547 |
+
try:
|
| 548 |
+
PiecewiseCompileInterpreter(split_gm, self.compiler_manager, submod_names_to_compile, self.compile_config).run(
|
| 549 |
+
*example_inputs
|
| 550 |
+
)
|
| 551 |
+
except Exception as e:
|
| 552 |
+
# Magi compile 的集中失败入口:直接打印 ERROR,方便在大模型日志中 grep
|
| 553 |
+
magi_logger.error("Magi compile failed while compiling piecewise submodules %s: %s", submod_names_to_compile, e)
|
| 554 |
+
raise
|
| 555 |
+
self._save_partitioned_graph(split_gm)
|
| 556 |
+
|
| 557 |
+
# TODO: Support DBO and NAT here
|
| 558 |
+
# split_gm = DBOGraphModule(split_gm, self.compile_config)
|
| 559 |
+
if self.compile_config.offload_config.model_cpu_offload:
|
| 560 |
+
split_gm = OffloadWrapper(split_gm, self.compile_config)
|
| 561 |
+
|
| 562 |
+
# if envs.MAGI_ENABLE_TOKENFLOW:
|
| 563 |
+
# from magi_compiler.tokenflow.graph_fork import GraphForkWrapper
|
| 564 |
+
|
| 565 |
+
if envs.MAGI_ENABLE_PROFILE:
|
| 566 |
+
from magi_compiler.tokenflow.graph_profile import gen_profile_wrap_func
|
| 567 |
+
|
| 568 |
+
split_gm = gen_profile_wrap_func(split_gm)
|
| 569 |
+
|
| 570 |
+
if self.compile_config.cudagraph_mode == CudaGraphMode.FULL and self.compile_config.cudagraph_copy_inputs:
|
| 571 |
+
return self._serialize_func_with_cudagraph(graph, split_gm, example_inputs)
|
| 572 |
+
|
| 573 |
+
return MagiSerializableFunction(graph, example_inputs, self.model_tag, split_gm)
|
| 574 |
+
|
| 575 |
+
def _serialize_func_with_cudagraph(
|
| 576 |
+
self, graph: fx.GraphModule, split_gm: fx.GraphModule, example_inputs: list[Any]
|
| 577 |
+
) -> MagiSerializableFunction:
|
| 578 |
+
fake_mode = detect_fake_mode()
|
| 579 |
+
fake_args = [fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in example_inputs]
|
| 580 |
+
|
| 581 |
+
self.sym_tensor_indices = detect_symbolic_tensor_indices(fake_args)
|
| 582 |
+
|
| 583 |
+
wrapped_split_gm = gen_wrap_func_for_cudagraph(func=split_gm, mode_prefix=CudaGraphMode.FULL.name.lower())
|
| 584 |
+
|
| 585 |
+
return MagiSerializableFunction(graph, example_inputs, self.model_tag, wrapped_split_gm)
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
def init_backend(compile_config: CompileConfig) -> str | Callable:
|
| 589 |
+
"""
|
| 590 |
+
Initialize the backend based on CompileConfig.
|
| 591 |
+
"""
|
| 592 |
+
if compile_config.compile_mode is None or compile_config.compile_mode == CompileMode.NONE:
|
| 593 |
+
raise ValueError("No compilation mode is set.")
|
| 594 |
+
|
| 595 |
+
from torch._dynamo.backends.registry import list_backends
|
| 596 |
+
|
| 597 |
+
torch_backends = list_backends(exclude_tags=tuple())
|
| 598 |
+
magi_logger.info("Supported torch backends: %s", torch_backends)
|
| 599 |
+
if compile_config.compile_mode == CompileMode.TORCH_COMPILE:
|
| 600 |
+
assert compile_config.backend in torch_backends, f"Invalid backend for torch compilation: {compile_config.backend}"
|
| 601 |
+
return compile_config.backend
|
| 602 |
+
elif compile_config.compile_mode == CompileMode.MAGI_COMPILE:
|
| 603 |
+
assert compile_config.backend in ["eager", "inductor"], f"Invalid backend for MagiCompiler: {compile_config.backend}"
|
| 604 |
+
model_tag = getattr(compile_config, "model_tag", None) or MAGI_MODEL_TAG
|
| 605 |
+
return MagiBackend(compile_config, model_tag=model_tag)
|
| 606 |
+
else:
|
| 607 |
+
raise ValueError(f"Invalid compile mode: {compile_config.compile_mode}")
|
pkgs/MagiCompiler/magi_compiler/magi_compiler_base.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
| 17 |
+
|
| 18 |
+
import inspect
|
| 19 |
+
import os
|
| 20 |
+
import sys
|
| 21 |
+
from abc import abstractmethod
|
| 22 |
+
from contextlib import contextmanager
|
| 23 |
+
from types import CodeType
|
| 24 |
+
from typing import Callable, Literal
|
| 25 |
+
|
| 26 |
+
import magi_compiler.utils.envs as envs
|
| 27 |
+
import torch
|
| 28 |
+
from magi_compiler.utils import compute_hash, get_git_version
|
| 29 |
+
from magi_compiler.utils.compile_time_monitor import CompileMonitor
|
| 30 |
+
|
| 31 |
+
from .config import CompileConfig, CompileMode
|
| 32 |
+
from .magi_backend import init_backend
|
| 33 |
+
from .utils import compute_code_hash, compute_code_hash_with_content, magi_logger
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _verify_source_unchanged(source_info, compile_config: CompileConfig) -> None:
|
| 37 |
+
file_contents = {}
|
| 38 |
+
for source in source_info.inlined_sources:
|
| 39 |
+
module = sys.modules[source.module]
|
| 40 |
+
file = inspect.getfile(module)
|
| 41 |
+
file_contents[file] = source.content
|
| 42 |
+
compile_config.traced_files.add(file)
|
| 43 |
+
expected_checksum = compute_code_hash_with_content(file_contents)
|
| 44 |
+
actual_checksum = compute_code_hash(set(file_contents.keys()))
|
| 45 |
+
if expected_checksum != actual_checksum:
|
| 46 |
+
raise RuntimeError("Source code has changed since the last compilation. Recompiling the model.")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class MagiCompilerBase:
    """
    A wrapper class for torch.compile, with a custom dispatch logic.
    Subclasses should:
    1. Implement the forward method
    2. Implement the dispatch logic in the __call__ method
    It can use `self.compiled_codes` to access the compiled bytecode,
    and `with self.dispatch_to_compiled_code:` to dispatch to
    the compiled code.
    3. Implement the `__init__` method to determine how to call
    `torch.compile` over the forward method.
    """

    # NOTE: the docstring above was moved before this annotation so it is
    # actually picked up as ``__doc__`` (a string after a statement is not).
    compile_config: CompileConfig

    def __init__(self, compile_config: CompileConfig):
        """Compile ``self.forward`` with the backend selected by *compile_config*."""
        # BUGFIX: store the config. ``aot_compilation_path``, ``bytecode_hook``
        # and ``try_load_aot_compile_artifacts`` all read ``self.compile_config``,
        # but it was never assigned here.
        self.compile_config = compile_config
        backend = init_backend(compile_config)
        options = None
        if isinstance(backend, str) and backend == "inductor":
            options = compile_config.inductor_compile_config
        if envs.MAGI_AOT_COMPILE:
            options = options or {}
            # Drop all the guards in the AOT compile mode as bytecode hook is not used anymore.
            options["guard_filter_fn"] = lambda guards: [False for _ in guards]
            assert hasattr(torch._dynamo.config, "enable_aot_compile"), "enable_aot_compile config not available"
            torch._dynamo.config.enable_aot_compile = True

        self.compiled_callable = torch.compile(self.forward, fullgraph=True, backend=backend, options=options)
        # Original code object of ``forward``; restored after JIT dispatch.
        self.original_code_object: CodeType = self.__class__.forward.__code__
        # Filled in by ``bytecode_hook`` once Dynamo has compiled ``forward``.
        self.compiled_code: CodeType | None = None
        # Filled in by the caller after ``try_load_aot_compile_artifacts``/``aot_compile``.
        self.aot_compiled_fn: Callable | None = None

    @property
    def aot_compilation_path(self) -> str:
        """
        When using torch.compile in AOT mode, we store the cache artifacts
        under cache_root_dir/torch_aot_compile/{hash}/rank_i_j. The {hash}
        contains all of the factors except for the source files being
        traced through, because we don't actually know which source files
        to check at this point (before dynamo runs).
        On loading we will actually look at the source files being traced
        through. If any source file have changed (compared with the
        serialized backend artifacts), then we need to generate a new AOT
        compile artifact from scratch.
        """
        hash_key = compute_hash([self.forward, self.compile_config.model_idx, self.compile_config.hash, get_git_version()])
        cache_dir = os.path.join(self.compile_config.cache_root_dir, "torch_aot_compile", hash_key)
        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
        cache_dir = os.path.join(cache_dir, f"rank_{rank}")
        os.makedirs(cache_dir, exist_ok=True)
        return os.path.join(cache_dir, "model")

    def try_load_aot_compile_artifacts(self) -> Callable | None:
        """Load a previously serialized AOT-compiled function, if one exists.

        Returns the already-loaded function, the freshly deserialized one,
        or ``None`` when no artifact file is present. Raises ``RuntimeError``
        (via ``_verify_source_unchanged``) when traced sources have changed.
        """
        if self.aot_compiled_fn is not None:
            return self.aot_compiled_fn
        if not os.path.exists(self.aot_compilation_path):
            return None
        with open(self.aot_compilation_path, "rb") as f:
            CompileMonitor().start(
                self.compile_config.compile_mode == CompileMode.MAGI_COMPILE, self.compile_config.debug_dump_path()
            )
            loaded_fn = torch.compiler.load_compiled_function(f)
            _verify_source_unchanged(loaded_fn.source_info(), self.compile_config)
            return loaded_fn

    def aot_compile(self, *args, **kwargs):
        """
        Run the model in AOT (Ahead-Of-Time) compile mode.

        All compilation work is completed before execution, suitable for production environment.
        This results in longer compilation time but superior runtime performance.
        """
        assert hasattr(self.compiled_callable, "aot_compile"), "aot_compile is not supported by the current configuration"
        return self.compiled_callable.aot_compile((args, kwargs))

    def jit_compile(self, *args, **kwargs):
        """
        Run the model in JIT (Just-In-Time) compile mode.

        Compilation occurs at runtime, first run may be slower due to compilation overhead.
        """
        handle = torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
        try:
            # The hook must be removed even if compilation raises, so other
            # instances don't receive callbacks for frames they don't own.
            output = self.compiled_callable(*args, **kwargs)
        finally:
            handle.remove()
        return output

    @abstractmethod
    def forward(self, *args, **kwargs):
        ...

    def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
        """Hook to save the compiled bytecode for direct execution."""
        if old_code is not self.original_code_object:
            return
        # Step1: Check if the old bytecode is from the compiled code
        # code borrowed from depyf enable_debugging.py
        frame = sys._getframe()
        while frame and frame.f_back:
            frame = frame.f_back
            code_name = frame.f_code.co_name
            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
            if code_name == "_compile" and file_name == "convert_frame.py":
                break
        frame = frame.f_locals["frame"]
        assert frame.f_code == old_code

        # BUGFIX: ``f_locals`` is a dict, so ``hasattr(f_locals, "self")`` was
        # always False and this guard never fired. Use membership test instead.
        if "self" in frame.f_locals and frame.f_locals["self"] is not self:
            return

        # Step2: Save the compiled bytecode
        self.compiled_code = new_code

        # Step3: Save the decompiled code
        path = self.compile_config.debug_dump_path()
        decompiled_file = os.path.join(path, "decompiled_code.py")
        if os.path.exists(decompiled_file):
            return
        try:
            # usually the decompilation will succeed for most models, as we guarantee a full-graph compilation in Dynamo.
            # but there's no 100% guarantee, since decompliation is not a reversible process.
            from magi_compiler.magi_depyf import decompile as magi_decompile

            src = magi_decompile(new_code)
            with open(decompiled_file, "w") as f:
                f.write(src)
            magi_logger.info("Dynamo transformed code saved to %s", decompiled_file)
        except Exception:
            # Best-effort: decompilation failure must never break compilation.
            pass

    @contextmanager
    def dispatch_to_compiled_fwd(self, mode: Literal["jit", "aot"] = "jit"):
        """
        Context manager to dispatch to the compiled code.
        Why does this work? Because Dynamo guarantees that the compiled
        bytecode has exactly the same arguments, cell variables, and free
        variables as the original code. Therefore we can directly switch
        the code object in the function and call it.

        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7
        for more details.

        NOTE: Why compile `forward` but invoke through `old_call`?

        In torch.nn.Module, `__call__` wraps `forward` with critical runtime logic:
        - Pre/post forward hooks
        - FSDP parameter sharding/gathering and device placement

        Our strategy: use this context manager to temporarily replace `self.forward`
        with the compiled version, then invoke `old_call(self, *args, **kwargs)`.

        This way:
        1. `old_call` executes hooks and FSDP mechanics normally
        2. When `old_call` internally calls `self.forward`, it hits our compiled code
        3. Compiled code runs within the proper FSDP/hook context

        Calling `self.forward()` directly would bypass FSDP (seeing sharded/invalid
        params) and skip hooks that other components may rely on.
        """
        if mode == "jit":
            assert self.compiled_code is not None
            self.__class__.forward.__code__ = self.compiled_code
            try:
                yield
            finally:
                # BUGFIX: restore even when the body raises; otherwise the
                # class-wide code object stays swapped for all instances.
                self.__class__.forward.__code__ = self.original_code_object
        elif mode == "aot":
            assert self.aot_compiled_fn is not None
            old_forward = self.forward
            self.forward = lambda *args, **kwargs: self.aot_compiled_fn(self, *args, **kwargs)
            try:
                yield
            finally:
                self.forward = old_forward
        else:
            raise ValueError(f"Invalid mode: {mode}")
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""magi_depyf — a modern bytecode decompiler and torch.compile inspector."""
|
| 16 |
+
|
| 17 |
+
from .decompile import DecompilationError, Decompiler, decompile, safe_decompile
|
| 18 |
+
|
| 19 |
+
__version__ = "0.1.0"
|
| 20 |
+
|
| 21 |
+
__all__ = ["Decompiler", "decompile", "safe_decompile", "DecompilationError", "__version__"]
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Decompilation: bytecode → Python source, plus recompile/fix/postprocess."""
|
| 16 |
+
|
| 17 |
+
from .decompiler import DecompilationError, Decompiler, decompile, safe_decompile
|
| 18 |
+
|
| 19 |
+
__all__ = ["Decompiler", "decompile", "safe_decompile", "DecompilationError"]
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Bytecode processing — pure Python, no torch dependency."""
|
| 16 |
+
|
| 17 |
+
from .decompile_context import DecompileContext
|
| 18 |
+
from .handler_registry import HandlerRegistry, registry
|
| 19 |
+
from .instruction import Instruction
|
| 20 |
+
from .source_emitter import SourceEmitter
|
| 21 |
+
|
| 22 |
+
__all__ = ["Instruction", "SourceEmitter", "HandlerRegistry", "DecompileContext", "registry"]
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/decompile_context.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""DecompileContext — read-only bag passed to every handler.
|
| 16 |
+
|
| 17 |
+
Handlers receive ``(emitter, inst, ctx)`` — they mutate *emitter*
|
| 18 |
+
and call *ctx* methods but never touch the ``Decompiler`` directly.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
from types import CodeType
|
| 24 |
+
from typing import TYPE_CHECKING, Callable, Dict, Tuple
|
| 25 |
+
|
| 26 |
+
if TYPE_CHECKING:
|
| 27 |
+
from .instruction import Instruction
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class DecompileContext:
    """Read-only context providing handlers with instructions, code object,
    and the ``decompile_range`` callback for recursive sub-block decompilation."""

    def __init__(
        self,
        code: CodeType,
        instructions: Tuple["Instruction", ...],
        indentation: int,
        decompile_range: Callable,
        offset_to_index: Dict[int, int],
    ) -> None:
        # Plain attribute storage — handlers treat this object as immutable.
        self.code = code
        self.instructions = instructions
        self.indentation = indentation
        self.decompile_range = decompile_range
        self._offset_to_index = offset_to_index

    def index_of(self, offset: int) -> int:
        """Return the index of the instruction at *offset* (O(1) lookup)."""
        index = self._offset_to_index.get(offset)
        if index is None:
            raise ValueError(f"No instruction at offset {offset}")
        return index
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handler_registry.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""HandlerRegistry — opcode-to-handler dispatch.
|
| 16 |
+
|
| 17 |
+
A *handler* is a plain function with signature::
|
| 18 |
+
|
| 19 |
+
(emitter: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> Optional[int]
|
| 20 |
+
|
| 21 |
+
Returning ``None`` advances to the next instruction.
|
| 22 |
+
Returning an ``int`` jumps to that instruction index.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from __future__ import annotations
|
| 26 |
+
|
| 27 |
+
from typing import TYPE_CHECKING, Callable, List, Optional
|
| 28 |
+
|
| 29 |
+
if TYPE_CHECKING:
|
| 30 |
+
pass
|
| 31 |
+
|
| 32 |
+
HandlerFn = Callable[..., Optional[int]]


class HandlerRegistry:
    """Maps opcode names -> handler functions."""

    def __init__(self) -> None:
        self._handlers: dict[str, HandlerFn] = {}

    def register(self, *opnames: str) -> Callable[[HandlerFn], HandlerFn]:
        """Decorator that registers *fn* for one or more opcode names."""

        def _bind(fn: HandlerFn) -> HandlerFn:
            # One handler may serve several opcodes (e.g. all BINARY_* ops).
            self._handlers.update((opname, fn) for opname in opnames)
            return fn

        return _bind

    def get(self, opname: str) -> Optional[HandlerFn]:
        return self._handlers.get(opname)

    def __contains__(self, opname: str) -> bool:
        return opname in self._handlers

    def supported_opnames(self) -> List[str]:
        return sorted(self._handlers)


# Singleton registry — handlers register against this at import time.
registry = HandlerRegistry()
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Import every handler module so they register against the global registry."""
|
| 16 |
+
|
| 17 |
+
from . import arithmetic # noqa: F401
|
| 18 |
+
from . import calls # noqa: F401
|
| 19 |
+
from . import containers # noqa: F401
|
| 20 |
+
from . import control_flow # noqa: F401
|
| 21 |
+
from . import load_store # noqa: F401
|
| 22 |
+
from . import stack_ops # noqa: F401
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/arithmetic.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Handlers for unary, binary, inplace, and comparison operations."""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
from ..decompile_context import DecompileContext
|
| 20 |
+
from ..handler_registry import registry
|
| 21 |
+
from ..instruction import Instruction
|
| 22 |
+
from ..source_emitter import SourceEmitter
|
| 23 |
+
|
| 24 |
+
_reg = registry.register
|
| 25 |
+
|
| 26 |
+
# ── Unary ─────────────────────────────────────────────────────────────────
|
| 27 |
+
|
| 28 |
+
_UNARY_SYMBOLS = {"UNARY_NEGATIVE": "-", "UNARY_POSITIVE": "+", "UNARY_INVERT": "~", "UNARY_NOT": "not"}


@_reg(*_UNARY_SYMBOLS)
def _unary(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Pop the operand and push the parenthesized unary expression."""
    operand = em.pop()
    em.push(f"({_UNARY_SYMBOLS[inst.opname]} {operand})")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@_reg("GET_LEN")
def _get_len(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """GET_LEN pushes len(TOS) while leaving TOS in place — hence peek, not pop."""
    tos = em.peek()
    em.push(f"len({tos})")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# ── Binary ────────────────────────────────────────────────────────────────
|
| 42 |
+
|
| 43 |
+
# Legacy (≤3.10) two-operand opcodes and their surface operators.
_BINARY_SYMBOLS = {
    "BINARY_MULTIPLY": "*",
    "BINARY_ADD": "+",
    "BINARY_SUBTRACT": "-",
    "BINARY_TRUE_DIVIDE": "/",
    "BINARY_FLOOR_DIVIDE": "//",
    "BINARY_MODULO": "%",
    "BINARY_POWER": "**",
    "BINARY_AND": "&",
    "BINARY_OR": "|",
    "BINARY_XOR": "^",
    "BINARY_LSHIFT": "<<",
    "BINARY_RSHIFT": ">>",
    "BINARY_MATRIX_MULTIPLY": "@",
}


@_reg(*_BINARY_SYMBOLS)
def _binary(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Pop right then left operand and push the parenthesized infix expression."""
    right = em.pop()
    left = em.pop()
    em.push(f"({left} {_BINARY_SYMBOLS[inst.opname]} {right})")
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@_reg("BINARY_SUBSCR")
def _subscr(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """BINARY_SUBSCR: stack is [container, index] → push ``container[index]``."""
    index = em.pop()
    container = em.pop()
    em.push(f"{container}[{index}]")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
@_reg("BINARY_SLICE")
def _slice(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """BINARY_SLICE (3.12+): stack is [container, start, end] → ``container[start:end]``."""
    stop = em.pop()
    begin = em.pop()
    seq = em.pop()
    em.push(f"{seq}[{begin}:{stop}]")
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# ── Inplace ───────────────────────────────────────────────────────────────
|
| 83 |
+
|
| 84 |
+
# Legacy (≤3.10) augmented-assignment opcodes and their surface operators.
_INPLACE_SYMBOLS = {
    "INPLACE_MULTIPLY": "*",
    "INPLACE_ADD": "+",
    "INPLACE_SUBTRACT": "-",
    "INPLACE_TRUE_DIVIDE": "/",
    "INPLACE_FLOOR_DIVIDE": "//",
    "INPLACE_MODULO": "%",
    "INPLACE_POWER": "**",
    "INPLACE_AND": "&",
    "INPLACE_OR": "|",
    "INPLACE_XOR": "^",
    "INPLACE_LSHIFT": "<<",
    "INPLACE_RSHIFT": ">>",
    "INPLACE_MATRIX_MULTIPLY": "@",
}


@_reg(*_INPLACE_SYMBOLS)
def _inplace(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Emit an augmented-assignment statement and leave the target on the stack."""
    value = em.pop()
    target = em.pop()
    em.emit(f"{target} {_INPLACE_SYMBOLS[inst.opname]}= {value}")
    em.push(target)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@_reg("BINARY_OP")
def _binary_op(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Python 3.12+ unified BINARY_OP."""
    right = em.pop()
    left = em.pop()
    # argrepr containing "=" (e.g. "+=") marks the in-place variant, which is
    # emitted as a statement; plain operators become a stack expression.
    if "=" not in inst.argrepr:
        em.push(f"({left} {inst.argrepr} {right})")
    else:
        em.emit(f"{left} {inst.argrepr} {right}")
        em.push(left)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# ── Comparison ────────────────────────────────────────────────────────────
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
@_reg("COMPARE_OP")
def _compare(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """COMPARE_OP: argval carries the operator text (e.g. "<", "==")."""
    right = em.pop()
    left = em.pop()
    em.push(f"({left} {inst.argval} {right})")
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@_reg("IS_OP")
def _is_op(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """IS_OP: argval 0 → ``is``, argval 1 → ``is not``."""
    right = em.pop()
    left = em.pop()
    keyword = "is not" if inst.argval else "is"
    em.push(f"({left} {keyword} {right})")
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
@_reg("CONTAINS_OP")
def _contains(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """CONTAINS_OP: argval 0 → ``in``, argval 1 → ``not in``."""
    right = em.pop()
    left = em.pop()
    keyword = "not in" if inst.argval else "in"
    em.push(f"({left} {keyword} {right})")
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/calls.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Handlers for function-call and function-creation opcodes."""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import sys
|
| 20 |
+
from typing import Optional
|
| 21 |
+
|
| 22 |
+
from ..decompile_context import DecompileContext
|
| 23 |
+
from ..handler_registry import registry
|
| 24 |
+
from ..instruction import Instruction
|
| 25 |
+
from ..source_emitter import SourceEmitter
|
| 26 |
+
|
| 27 |
+
_reg = registry.register
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@_reg("KW_NAMES")
def _kw_names(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Python 3.11+ instruction passing keyword argument names to the next CALL.

    ``inst.arg`` indexes into ``co_consts`` for the key-name tuple, e.g.
    ``('y', 'z')``. We push its repr so it becomes the string "('y', 'z')";
    the CALL handler later eval()s it back to a tuple.
    """
    kw_tuple = ctx.code.co_consts[inst.arg]
    em.push(repr(kw_tuple))
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@_reg("CALL")
def _call(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Python 3.11+ unified CALL.

    3.12 stack layout: [NULL, callable, arg0, ..., argN-1] (KW_NAMES precedes)
    3.11 stack layout: [NULL, callable, arg0, ..., argN-1] (KW_NAMES → PRECALL → CALL)
    """
    # Check whether KW_NAMES precedes CALL (indicating keyword arguments exist).
    # 3.12: KW_NAMES → CALL; 3.11: KW_NAMES → PRECALL → CALL
    preceding = [x for x in ctx.instructions if x.offset < inst.offset]
    has_kw = False
    if preceding:
        if preceding[-1].opname == "KW_NAMES" or (
            len(preceding) > 1
            and preceding[-2].opname == "KW_NAMES"
            and preceding[-1].opname == "PRECALL"  # 3.11 transitional opcode, removed in 3.12
        ):
            has_kw = True

    kw_names: tuple = ()
    if has_kw:
        # NOTE: eval() here runs on the repr pushed by our own KW_NAMES handler,
        # not on external input — it only ever sees a literal tuple of strings.
        kw_names = eval(em.pop())  # retrieve the tuple stored by KW_NAMES from the stack
    # The last len(kw_names) arguments on the stack are the keyword values;
    # everything before them is positional.
    args = [em.pop() for _ in range(inst.argval)][::-1]
    pos_args = args[: len(args) - len(kw_names)]
    kw_args = args[len(args) - len(kw_names) :]
    kwcalls = [f"{n}={v}" for n, v in zip(kw_names, kw_args)]
    func = em.pop()
    # 3.11+ PUSH_NULL / LOAD_GLOBAL(NULL+name) pushes a NULL sentinel before the call.
    # After popping the callable, the top of stack may be NULL (represented as None); clear it.
    if em.stack_size and em.peek() is None:
        em.pop()
    # GET_ITER produces "iter(x)"; if func happens to be "iter(x)" it is actually an argument
    # (e.g. in the next(iter(x)) pattern), and the real callable is further down the stack.
    # NOTE(review): this is a textual heuristic — a user variable whose rendered
    # expression contains "iter(" would also match; confirm acceptable.
    if "iter(" in str(func):
        pos_args = [func]
        func = em.pop()
    em.push(f"{func}({', '.join(pos_args + kwcalls)})")
    # replace_tos_with_temp: the call result may be referenced multiple times (assignment, passing,
    # method call), so store it in a temp to avoid repeated evaluation and side effects.
    em.replace_tos_with_temp()
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@_reg("CALL_FUNCTION", "CALL_METHOD")
def _call_legacy(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """CALL_FUNCTION / CALL_METHOD (Python ≤3.10)."""
    call_args = [em.pop() for _ in range(inst.argval)]
    call_args.reverse()  # popped in reverse push order
    callee = em.pop()
    em.push(f"{callee}({', '.join(call_args)})")
    em.replace_tos_with_temp()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@_reg("CALL_FUNCTION_KW")
def _call_function_kw(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """CALL_FUNCTION_KW (≤3.10): TOS is the tuple of keyword names, below it
    the keyword values, then the positional arguments, then the callable."""
    # eval() recovers the tuple from the repr pushed for the names constant,
    # not from external input.
    names = eval(em.pop())
    values = [em.pop() for _ in range(len(names))][::-1]
    keyword_parts = [f"{n}={v}" for n, v in zip(names, values)]
    positional = [em.pop() for _ in range(inst.argval - len(names))][::-1]
    callee = em.pop()
    em.push(f"{callee}({', '.join(positional + keyword_parts)})")
    em.replace_tos_with_temp()
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@_reg("CALL_FUNCTION_EX")
def _call_function_ex(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """CALL_FUNCTION_EX: ``f(*args)`` when argval==0, ``f(*args, **kwargs)`` when argval==1.

    3.11+ stack: [NULL, func, args (, kwargs)] — after popping the callable,
    a leftover NULL sentinel (represented as None) must be cleared.
    """
    if inst.argval == 0:
        star_args = em.pop()
        callee = em.pop()
        if em.stack_size and em.peek() is None:
            em.pop()
        em.push(f"{callee}(*{star_args})")
    elif inst.argval == 1:
        star_kwargs = em.pop()
        star_args = em.pop()
        callee = em.pop()
        if em.stack_size and em.peek() is None:
            em.pop()
        em.push(f"{callee}(*{star_args}, **{star_kwargs})")
    em.replace_tos_with_temp()
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# Compiler-internal intrinsics (import *, typealias, generator plumbing, typing
# constructs): nothing to emit at source level, so they are skipped outright.
# Hoisted to a module-level frozenset so it is not rebuilt on every call.
_INTRINSIC_1_SKIP = frozenset({
    "INTRINSIC_1_INVALID",
    "INTRINSIC_IMPORT_STAR",
    "INTRINSIC_STOPITERATION_ERROR",
    "INTRINSIC_ASYNC_GEN_WRAP",
    "INTRINSIC_TYPEVAR",
    "INTRINSIC_PARAMSPEC",
    "INTRINSIC_TYPEVARTUPLE",
    "INTRINSIC_SUBSCRIPT_GENERIC",
    "INTRINSIC_TYPEALIAS",
})


@_reg("CALL_INTRINSIC_1")
def _intrinsic_1(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Python 3.12 instruction replacing some internal C-level calls.

    argrepr identifies the specific operation, e.g. INTRINSIC_PRINT, INTRINSIC_UNARY_POSITIVE.
    Most are compiler-internal operations (import *, typealias) rarely triggered by user code.
    Unlisted, unhandled intrinsics fall through silently (same as the skip set).
    """
    if inst.argrepr in _INTRINSIC_1_SKIP:
        return
    if inst.argrepr == "INTRINSIC_PRINT":
        # print() returns None; push it so the stack depth stays consistent.
        em.emit(f"print({em.pop()})")
        em.push("None")
    elif inst.argrepr == "INTRINSIC_UNARY_POSITIVE":
        em.set_at(0, f"+{em.peek()}")
    elif inst.argrepr == "INTRINSIC_LIST_TO_TUPLE":
        em.push(f"tuple({em.pop()})")
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# ── MAKE_FUNCTION ─────────────────────────────────────────────────────────
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
@_reg("MAKE_FUNCTION")
def _make_function(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> Optional[int]:
    """Handle bytecode for def inner(...) and lambda.

    Bytecode: LOAD_CONST <code_object> → MAKE_FUNCTION → STORE_FAST name
    The handler recursively decompiles the inner code object and emits the full def statement.

    Returns the instruction index to resume from when the function is stored
    immediately (skipping the STORE_FAST), otherwise None for normal fallthrough.
    """
    if sys.version_info < (3, 11):
        # 3.10: qualified_name string is still on the stack
        qual_name = em.pop()
        try:
            # The stack holds the repr of the name constant; eval recovers the str.
            qual_name = eval(qual_name)
        except Exception:
            pass
        func_name = qual_name.split(".")[-1]
        if "<" in func_name:  # <lambda>, <listcomp>, etc. — invalid identifiers
            em.emit(f'"original function name {func_name} is illegal, use a temp name."')
            func_name = em.make_temp()
    else:
        # 3.11+: no qualified name on the stack — synthesize a temp name for now.
        func_name = em.make_temp()

    code = em.pop()  # inner CodeType object pushed by LOAD_CONST
    # argval bit flags indicate whether extra function components remain on the stack
    if inst.argval & 0x08:
        em.pop()  # closure tuple (cell references for freevars)
    if inst.argval & 0x04:
        em.pop()  # annotations dict
    if inst.argval & 0x02:
        em.pop()  # keyword-only defaults tuple
    if inst.argval & 0x01:
        em.pop()  # positional defaults tuple

    # If the next instruction is STORE_FAST, use the target variable name as the function name
    this_idx = ctx.index_of(inst.offset)
    immediately_used = False
    if ctx.instructions[this_idx + 1].opname == "STORE_FAST":
        func_name = ctx.instructions[this_idx + 1].argval
        immediately_used = True

    # Recurse: create a new Decompiler instance for the inner code object
    from ...decompiler import Decompiler

    inner = Decompiler(code).decompile(overwrite_fn_name=func_name)
    em.emit_raw(inner)

    if immediately_used:
        return this_idx + 2  # skip the MAKE_FUNCTION + STORE_FAST pair
    em.push(func_name)  # not immediately assigned — push onto stack for later use (e.g. as an argument)
    return None
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/containers.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Handlers for BUILD_*, UNPACK_*, LIST_EXTEND/APPEND, SET_ADD, MAP_ADD,
|
| 16 |
+
FORMAT_VALUE, and BUILD_SLICE / BUILD_STRING."""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import sys
|
| 21 |
+
|
| 22 |
+
from ..decompile_context import DecompileContext
|
| 23 |
+
from ..handler_registry import registry
|
| 24 |
+
from ..instruction import Instruction
|
| 25 |
+
from ..source_emitter import SourceEmitter
|
| 26 |
+
|
| 27 |
+
_reg = registry.register
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ── BUILD tuple / list / set ──────────────────────────────────────────────
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _safe_str(val) -> str:
|
| 34 |
+
"""Convert a stack value to string, handling None sentinels from PUSH_NULL."""
|
| 35 |
+
return "None" if val is None else str(val)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@_reg("BUILD_TUPLE", "BUILD_TUPLE_UNPACK", "BUILD_TUPLE_UNPACK_WITH_CALL")
def _build_tuple(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Collect argval stack items into a tuple display (star-splatted for *_UNPACK)."""
    elems = [_safe_str(em.pop()) for _ in range(inst.argval)]
    elems.reverse()
    if "UNPACK" in inst.opname:
        elems = [f"*{a}" for a in elems]
    if inst.argval == 1:
        # A one-element tuple needs the trailing comma.
        em.push(f"({elems[0]},)")
    else:
        em.push(f"({', '.join(elems)})")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@_reg("BUILD_LIST", "BUILD_LIST_UNPACK")
def _build_list(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Collect argval stack items into a list display (star-splatted for *_UNPACK)."""
    elems = [_safe_str(em.pop()) for _ in range(inst.argval)]
    elems.reverse()
    if "UNPACK" in inst.opname:
        elems = [f"*{a}" for a in elems]
    em.push(f"[{', '.join(elems)}]")
    # Lists are mutable and may be extended/appended later — hold them in a temp.
    em.replace_tos_with_temp()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@_reg("BUILD_SET", "BUILD_SET_UNPACK")
def _build_set(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Build a set display; the empty set has no literal, so emit set()."""
    if inst.argval == 0:
        em.push("set()")
    else:
        elems = [em.pop() for _ in range(inst.argval)]
        elems.reverse()
        if "UNPACK" in inst.opname:
            elems = [f"*{a}" for a in elems]
        em.push(f"{{{', '.join(elems)}}}")
    em.replace_tos_with_temp()
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# ── BUILD map ─────────────────────────────────────────────────────────────
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@_reg("BUILD_MAP")
def _build_map(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """BUILD_MAP: argval key/value pairs sit interleaved on the stack."""
    flat = [em.pop() for _ in range(inst.argval * 2)]
    flat.reverse()  # back to push order: k1, v1, k2, v2, ...
    pairs = zip(flat[::2], flat[1::2])
    em.push(f"{{{', '.join(f'{k}: {v}' for k, v in pairs)}}}")
    em.replace_tos_with_temp()
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@_reg("BUILD_MAP_UNPACK", "BUILD_MAP_UNPACK_WITH_CALL")
def _build_map_unpack(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Merge argval mappings with ** splats; the empty case becomes dict()."""
    if inst.argval == 0:
        em.push("dict()")
    else:
        maps = [em.pop() for _ in range(inst.argval)]
        maps.reverse()
        em.push(f"{{{', '.join(f'**{a}' for a in maps)}}}")
    em.replace_tos_with_temp()
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@_reg("BUILD_CONST_KEY_MAP")
def _const_key_map(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """BUILD_CONST_KEY_MAP: TOS holds the tuple of constant keys; values below it."""
    # eval() reconstructs the constant key tuple from its repr on the stack;
    # only decompiler-generated text is ever fed through here.
    keys = eval(em.pop())
    vals = [em.pop() for _ in range(inst.argval)]
    vals.reverse()
    em.push(f"{{{', '.join(f'{k!r}: {v}' for k, v in zip(keys, vals))}}}")
    em.replace_tos_with_temp()
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@_reg("BUILD_STRING")
def _build_string(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """f-string assembly: join the formatted pieces with + concatenation."""
    pieces = [em.pop() for _ in range(inst.argval)]
    pieces.reverse()
    em.push(" + ".join(pieces))
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@_reg("BUILD_SLICE")
def _build_slice(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Render a slice(start, stop[, step]) expression from 2 or 3 stack operands."""
    top = em.pop()
    below = em.pop()
    if inst.argval == 2:
        em.push(f"slice({below}, {top})")
    elif inst.argval == 3:
        lowest = em.pop()
        em.push(f"slice({lowest}, {below}, {top})")
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@_reg("LIST_TO_TUPLE")
def _list_to_tuple(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Wrap TOS in a tuple(...) call."""
    src = em.pop()
    em.push(f"tuple({src})")
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# ── Mutating container ops ────────────────────────────────────────────────
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@_reg("LIST_EXTEND")
def _list_extend(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """LIST_EXTEND: extend the list argval entries below TOS with the popped iterable."""
    tail = em.pop()
    # Pin the target list in a temp so the mutation has a stable name to act on.
    target = em.replace_tos_with_temp(depth=inst.argval)
    em.emit(f"{target}.extend({tail})")
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@_reg("LIST_APPEND")
def _list_append(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """LIST_APPEND (comprehension body): append TOS to the list deeper in the stack."""
    # argval == 1 still means "one below the value being appended", i.e. index -2.
    depth = 2 if inst.argval == 1 else inst.argval
    target = em.stack[-depth]
    em.emit(f"{target}.append({em.pop()})")
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@_reg("SET_UPDATE", "DICT_UPDATE", "DICT_MERGE")
def _generic_update(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """SET_UPDATE / DICT_UPDATE / DICT_MERGE: merge TOS into the container below it."""
    assert inst.argval == 1, "Only tested for argval==1"
    incoming = em.pop()
    target = em.replace_tos_with_temp()
    em.emit(f"{target}.update({incoming})")
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
@_reg("SET_ADD")
def _set_add(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """SET_ADD (set-comprehension body): add TOS to the set deeper in the stack."""
    # argval == 1 still means "one below the value being added", i.e. index -2.
    depth = 2 if inst.argval == 1 else inst.argval
    target = em.stack[-depth]
    em.emit(f"{target}.add({em.pop()})")
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
@_reg("MAP_ADD")
def _map_add(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """MAP_ADD (dict-comprehension body): store a key/value pair into the dict below."""
    target = em.stack[-inst.argval - 1]
    # Operand order changed in 3.8: since then the value sits on top of the key.
    if sys.version_info >= (3, 8):
        value = em.pop()
        key = em.pop()
    else:
        key = em.pop()
        value = em.pop()
    em.emit(f"{target}.__setitem__({key}, {value})")
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# ── Unpack ────────────────────────────────────────────────────────────────
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
@_reg("UNPACK_SEQUENCE")
def _unpack_seq(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """UNPACK_SEQUENCE: a, b, c = seq — each target becomes a fresh temp name."""
    source = em.pop()
    temps = [em.make_temp() for _ in range(inst.argval)]
    em.emit("".join(f"{t}, " for t in temps) + f"= {source}")
    # Push in reverse so the first unpacked element ends up on top of the stack.
    for t in reversed(temps):
        em.push(t)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
@_reg("UNPACK_EX")
def _unpack_ex(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """UNPACK_EX: a, b, *rest = seq (argval = number of targets before the star)."""
    source = em.pop()
    temps = [em.make_temp() for _ in range(inst.argval)]
    star = em.make_temp()
    em.emit(f"{', '.join(temps)}, *{star} = {source}")
    # Star target deepest, fixed targets above it, first target on top.
    em.push(star)
    for t in reversed(temps):
        em.push(t)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# ── Format ────────────────────────────────────────────────────────────────
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
@_reg("FORMAT_VALUE")
def _format_value(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """FORMAT_VALUE (one f-string placeholder): argval is (conversion, has_format_spec)."""
    conv, has_spec = inst.argval
    if has_spec:
        spec = em.pop()
        value = em.pop()
        em.push(f"format({value}, {spec})")
    else:
        value = em.pop()
        # No conversion means plain str(); otherwise conv is the converter callable.
        converter = str if conv is None else conv
        em.push(f"{converter.__name__}({value})")
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/control_flow.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Handlers for control-flow opcodes: jumps, if/else, for, return, yield, raise."""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
from typing import Optional
|
| 20 |
+
|
| 21 |
+
from ..decompile_context import DecompileContext
|
| 22 |
+
from ..handler_registry import registry
|
| 23 |
+
from ..instruction import Instruction
|
| 24 |
+
from ..source_emitter import LoopContext, SourceEmitter
|
| 25 |
+
|
| 26 |
+
_reg = registry.register
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# ── Simple returns / yield / raise ────────────────────────────────────────
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@_reg("RETURN_VALUE")
def _return_value(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """RETURN_VALUE: emit a return of TOS and drop it from the stack."""
    em.emit(f"return {em.pop()}")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@_reg("RETURN_CONST")
def _return_const(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """RETURN_CONST (3.12+): return a constant embedded directly in the instruction."""
    em.emit(f"return {inst.argval!r}")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@_reg("YIELD_VALUE")
def _yield_value(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """YIELD_VALUE: emit a yield of TOS; TOS is left in place."""
    import sys

    if sys.version_info < (3, 12):
        em.emit(f"yield {em.peek()}")
        return
    raise NotImplementedError("YIELD_VALUE is not supported in Python 3.12+")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@_reg("RETURN_GENERATOR")
def _return_generator(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Python 3.11+ generator function prologue. Each generator has its own stack frame;
    RETURN_GENERATOR creates the generator object and returns it to the caller,
    subsequent next(gen) resumes from RESUME. Push None as a placeholder during decompilation."""
    em.push(None)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@_reg("GEN_START")
def _gen_start(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Python 3.11 marks generator start (replaced by RESUME in 3.12)."""
    # argval encodes the generator kind; only 0 (plain generator expression) is handled.
    assert inst.argval == 0, "Only generator expression is supported"
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@_reg("RAISE_VARARGS")
def _raise_varargs(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """RAISE_VARARGS: bare raise / raise exc / raise exc from cause (argval = 0, 1, 2)."""
    count = inst.argval
    if count == 0:
        em.emit("raise")
    elif count == 1:
        em.emit(f"raise {em.pop()}")
    elif count == 2:
        # TOS is the cause, TOS1 the exception being raised.
        cause = em.pop()
        exc = em.pop()
        em.emit(f"raise {exc} from {cause}")
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@_reg("BREAK_LOOP")
def _break_loop(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Legacy BREAK_LOOP opcode: translates directly to a break statement."""
    em.emit("break")
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# ── Unconditional jumps ───────────────────────────────────────────────────
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
@_reg("JUMP_ABSOLUTE")
@_reg("JUMP_FORWARD")
@_reg("JUMP_BACKWARD")
@_reg("JUMP_BACKWARD_NO_INTERRUPT")
def _abs_jump(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> Optional[int]:
    """Unconditional jump. Returns len(instructions) to make decompile_range stop immediately."""
    dest_idx = ctx.index_of(inst.jump_target_offset())
    current_loop = em.loop
    if current_loop is None:
        # Outside any loop: just continue decompiling from the jump target.
        return dest_idx
    if dest_idx >= current_loop.end_index:
        # Jumping past the loop boundary is a break.
        em.emit("break")
        return len(ctx.instructions)
    if dest_idx == current_loop.start_index:
        # Jumping back to the loop head mid-body is a continue.
        em.emit("continue")
        return len(ctx.instructions)
    return dest_idx
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# ── Conditional jumps (if / else) ─────────────────────────────────────────
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@_reg("POP_JUMP_IF_TRUE", "POP_JUMP_IF_FALSE")
@_reg("POP_JUMP_FORWARD_IF_TRUE", "POP_JUMP_FORWARD_IF_FALSE")
@_reg("POP_JUMP_BACKWARD_IF_TRUE", "POP_JUMP_BACKWARD_IF_FALSE")
@_reg("POP_JUMP_FORWARD_IF_NONE", "POP_JUMP_FORWARD_IF_NOT_NONE")
@_reg("POP_JUMP_BACKWARD_IF_NONE", "POP_JUMP_BACKWARD_IF_NOT_NONE")
@_reg("JUMP_IF_TRUE_OR_POP", "JUMP_IF_FALSE_OR_POP")
@_reg("POP_JUMP_IF_NOT_NONE", "POP_JUMP_IF_NONE")
def _jump_if(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> Optional[int]:
    """Decompile if/else structure.

    Standard if/else bytecode:
        POP_JUMP_IF_FALSE else_start   ← this_idx
        (if-body)
        JUMP_FORWARD after_else        ← last instruction of if-body
        >> else_start:                 ← jump_idx
        (else-body)
        >> after_else:                 ← merge point (end)

    Returns the instruction index at which decompilation should resume
    (the else-body end / merge point), or None for the backward-jump case.
    """

    jump_offset = inst.jump_target_offset()
    jump_idx = ctx.index_of(jump_offset)
    this_idx = ctx.index_of(inst.offset)

    # ── Step 1: condition expression and branch stack state ──
    cond = em.peek()
    fall_stack = list(em.stack)
    jump_stack = list(em.stack)

    # The source `if` guards the fallthrough branch, so the condition is the
    # negation of the jump condition encoded in the opcode name.
    if "IF_NOT_NONE" in inst.opname:
        cond = f"({cond} is None)"
    elif "IF_NONE" in inst.opname:
        cond = f"({cond} is not None)"
    elif "IF_TRUE" in inst.opname:
        cond = f"(not {cond})"
    else:
        cond = f"{cond}"

    # POP_JUMP variants pop the condition on both paths; *_OR_POP variants
    # pop it only on the fallthrough path.
    if "POP_JUMP" in inst.opname:
        jump_stack.pop()
        fall_stack.pop()
    elif "OR_POP" in inst.opname:
        fall_stack.pop()

    # ── Step 2: merge point candidate upper bounds ──
    merge_upper_bounds = [len(ctx.instructions)]
    if em.loop is not None:
        merge_upper_bounds.append(em.loop.end_index)

    # ── Step 3: find "skip else" JUMPs in the if-body ──
    def _is_forward_past_else(i: Instruction) -> bool:
        return i.is_jump and i.jump_target_offset() >= jump_offset

    forward_targets = [i.jump_target_offset() for i in ctx.instructions[this_idx:jump_idx] if _is_forward_past_else(i)]

    # ── Step 4: compute merge point by case ──
    if not forward_targets:
        if jump_idx <= this_idx:
            # Case C: backward jump (inside loop), emit if cond: continue
            # Here the jump condition itself (not its negation) guards the continue.
            rev_cond = em.peek()
            if "IF_NOT_NONE" in inst.opname:
                rev_cond = f"({rev_cond} is not None)"
            elif "IF_NONE" in inst.opname:
                rev_cond = f"({rev_cond} is None)"
            elif "IF_TRUE" in inst.opname:
                rev_cond = f"{rev_cond}"
            elif "IF_FALSE" in inst.opname:
                rev_cond = f"(not {rev_cond})"
            em.emit(f"if {rev_cond}:")
            em.emit(em.indent("continue\n").rstrip("\n"))
            return None
        # Case B: both branches terminate with RETURN/RAISE
        end = jump_idx
    else:
        # Case A: standard if/else, infer merge point from forward_targets
        max_jump = max(forward_targets)
        max_idx = ctx.index_of(max_jump)
        # Widen once: jumps between this_idx and the first candidate may land further.
        all_targets = [i.jump_target_offset() for i in ctx.instructions[this_idx:max_idx] if _is_forward_past_else(i)]
        max_idx = ctx.index_of(max(all_targets))

        last = ctx.instructions[max_idx - 1]
        if not ("RAISE" in last.opname or "RETURN" in last.opname or "STORE" in last.opname):
            # Candidate does not end a statement — scan forward to a safe boundary.
            old = max_idx
            while max_idx < len(ctx.instructions):
                op = ctx.instructions[max_idx].opname
                if "STORE" in op or "RETURN" in op:
                    max_idx += 1
                    break
                if ("JUMP" in op and max_idx > old) or "FOR_ITER" in op:
                    break
                max_idx += 1

        merge_upper_bounds.append(max_idx)
        end = min(merge_upper_bounds)

    # ── Step 5: else-body end position (PR#91 fix) ──
    else_end = end
    if end == jump_idx and jump_idx < len(ctx.instructions):
        last_if = ctx.instructions[jump_idx - 1]
        if "RETURN" in last_if.opname or "RAISE" in last_if.opname:
            # If-body cannot fall through, so everything after jump_idx (clamped
            # to the enclosing loop, if any) belongs to the else branch.
            else_end = len(ctx.instructions)
            if em.loop is not None:
                else_end = min(else_end, em.loop.end_index)

    # ── Step 6: decompile both branches ──
    with em.fork(stack=fall_stack) as if_em:
        ctx.decompile_range(this_idx + 1, jump_idx, if_em)
    if_body = em.indent(if_em.get_source())
    if_end_stack = list(if_em.stack)
    em.emit_raw(f"if {cond}:\n{if_body}")

    with em.fork(stack=jump_stack) as else_em:
        ctx.decompile_range(jump_idx, else_end, else_em)
    else_body = else_em.get_source()
    if else_body:
        em.emit_raw(f"else:\n{em.indent(else_body)}")

    # Continue with the if-branch's resulting stack (both branches must agree).
    em.stack[:] = if_end_stack
    return else_end
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
# ── FOR_ITER ──────────────────────────────────────────────────────────────
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
@_reg("FOR_ITER")
def _for_iter(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> Optional[int]:
    """Decompile for loop.

    Bytecode layout (3.12):
        FOR_ITER target        ← get next value; jump to target when exhausted (END_FOR)
        (loop body)
        JUMP_BACKWARD for_iter ← normal back-jump (not continue)
        >> target: END_FOR

    Loop body range excludes the trailing JUMP_BACKWARD to avoid emitting a spurious continue.

    Returns the instruction index just past the loop so decompilation resumes after it.
    """
    start_idx = ctx.index_of(inst.offset)
    end_idx = ctx.index_of(inst.jump_target_offset())

    # FOR_ITER conceptually pops the iterator and pushes the next item;
    # model the item as a fresh temp name (the loop variable).
    temp = em.make_temp()
    iterator = em.pop()
    em.push(temp)

    # Determine the actual end position of the loop body:
    # if the instruction at end_idx is a back-jump to FOR_ITER, extend end_idx so
    # the LoopContext boundary is correct (break needs to jump past end_idx)
    if end_idx < len(ctx.instructions):
        at_end = ctx.instructions[end_idx]
        if at_end.is_jump and at_end.jump_target_offset() == inst.offset:
            end_idx += 1

    # Exclude the trailing JUMP_BACKWARD: it is the normal loop back-jump mechanism, not continue.
    # Only JUMP_BACKWARDs in the middle of the loop body are continue (handled by _abs_jump).
    body_end = end_idx
    if body_end > start_idx + 1:
        back_jump = ctx.instructions[body_end - 1]
        if back_jump.is_jump and back_jump.jump_target_offset() == inst.offset:
            body_end -= 1

    # Decompile the body in a forked emitter carrying the loop boundaries so
    # break/continue detection in jump handlers works.
    loop = LoopContext(start_index=start_idx, end_index=end_idx)
    with em.fork(stack=list(em.stack), loop=loop) as body_em:
        ctx.decompile_range(start_idx + 1, body_end, body_em)

    body_src = em.indent(body_em.get_source())
    em.emit_raw(f"for {temp} in {iterator}:\n{body_src}")
    em.stack[:] = body_em.stack
    return end_idx
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/load_store.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Handlers for LOAD_*, STORE_*, DELETE_*, IMPORT_*, PUSH_NULL, GET_ITER."""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
from types import CodeType
|
| 20 |
+
|
| 21 |
+
from ..decompile_context import DecompileContext
|
| 22 |
+
from ..handler_registry import registry
|
| 23 |
+
from ..instruction import Instruction
|
| 24 |
+
from ..source_emitter import SourceEmitter
|
| 25 |
+
|
| 26 |
+
_reg = registry.register
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# ── NOP / unsupported sentinels ──────────────────────────────────────────
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@_reg("NOP", "RESUME", "EXTENDED_ARG", "SETUP_LOOP", "POP_BLOCK")
@_reg("PRECALL", "BEGIN_FINALLY", "END_FINALLY", "MAKE_CELL")
@_reg("RERAISE", "END_FOR", "COPY_FREE_VARS")
def _nop(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Opcodes with no source-level effect: nothing to emit, stack left unchanged."""
    pass
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@_reg("GET_YIELD_FROM_ITER")
@_reg("POP_EXCEPT", "WITH_EXCEPT_START", "JUMP_IF_NOT_EXC_MATCH")
@_reg("CHECK_EG_MATCH", "PUSH_EXC_INFO", "PREP_RERAISE_STAR")
@_reg("WITH_CLEANUP_FINISH", "CALL_FINALLY", "POP_FINALLY")
@_reg("WITH_CLEANUP_START", "SETUP_EXCEPT", "CHECK_EXC_MATCH")
@_reg("CLEANUP_THROW")
@_reg("GET_AWAITABLE", "GET_AITER", "GET_ANEXT", "END_ASYNC_FOR")
@_reg("BEFORE_ASYNC_WITH", "SETUP_ASYNC_WITH", "SEND", "ASYNC_GEN_WRAP")
@_reg("CACHE")
@_reg("PRINT_EXPR", "COPY_DICT_WITHOUT_KEYS")
@_reg("IMPORT_STAR")
@_reg("YIELD_FROM", "SETUP_ANNOTATIONS", "LOAD_BUILD_CLASS")
@_reg("MATCH_MAPPING", "MATCH_SEQUENCE", "MATCH_KEYS", "MATCH_CLASS")
@_reg("CALL_INTRINSIC_2")
@_reg("SETUP_FINALLY", "SETUP_WITH", "BEFORE_WITH")
def _unsupported(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Opcodes the decompiler deliberately does not handle (exceptions, async,
    with-blocks, pattern matching, ...): fail fast with a descriptive error."""
    # Imported inside the function, matching the file's lazy-import pattern.
    from ...decompiler import DecompilationError

    raise DecompilationError(f"Unsupported opcode: {inst.opname}", instruction=inst)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# ── LOAD instructions ────────────────────────────────────────────────────
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@_reg("LOAD_CONST")
def _load_const(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Load a constant onto the simulated stack.

    Branches, tried in order:
      * ``repr`` round-trips to an equal value → push the repr text directly.
      * a ``type`` object → re-import it via ``importlib`` into a temp variable.
      * ``torch.``-prefixed argrepr → emit ``import torch`` and reference it.
      * a ``CodeType`` → pushed as-is; a later MAKE_FUNCTION consumes it.
    Anything else cannot be rendered as source and raises DecompilationError.
    """
    can_repr = False
    try:
        # NOTE: eval() only round-trips reprs of constants already baked into
        # the bytecode being decompiled — not external/untrusted input.
        can_repr = eval(repr(inst.argval)) == inst.argval
    except BaseException:
        pass
    if can_repr:
        em.push(repr(inst.argval))
    elif isinstance(inst.argval, type):
        module = inst.argval.__module__
        name = inst.argval.__name__
        em.emit("import importlib")
        tmp = em.make_temp()
        em.emit(f'{tmp} = importlib.import_module("{module}").{name}')
        em.push(tmp)
    elif inst.argrepr.startswith("torch."):
        em.emit("import torch")
        tmp = em.make_temp()
        em.emit(f"{tmp} = {inst.argval}")
        em.push(tmp)
    elif isinstance(inst.argval, CodeType):
        em.push(inst.argval)
    else:
        from ...decompiler import DecompilationError

        raise DecompilationError(
            # Fixed double-repr: the old message applied !r to repr(argval),
            # yielding a quoted repr-of-a-repr in the error text.
            f"LOAD_CONST: cannot represent co_consts[{inst.arg}] = {inst.argval!r} "
            f"(type {type(inst.argval).__name__}) as source code",
            instruction=inst,
        )
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@_reg("LOAD_FAST", "LOAD_FAST_CHECK")
@_reg("LOAD_GLOBAL", "LOAD_DEREF", "LOAD_NAME")
@_reg("LOAD_CLASSDEREF", "LOAD_CLOSURE")
def _generic_load(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Generic name load.

    On 3.11+ a LOAD_GLOBAL whose argrepr reads "NULL + name" also pushes a
    NULL sentinel first (consumed later by CALL). Pre-3.12 comprehension
    parameters named ".0" are rewritten to the legal name "comp_arg_0".
    """
    if "NULL + " in inst.argrepr:
        em.push(None)
    pushed = inst.argval.replace(".", "comp_arg_") if inst.argrepr.startswith(".") else inst.argval
    em.push(pushed)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# Python 3.12 comprehension variable protection: LOAD_FAST_AND_CLEAR saves old value + STORE_FAST restores.
# During decompilation, temp variables used for loops don't need save/restore; push a sentinel so STORE_FAST skips.
# The sentinel is compared with `is`, so a fresh object() is the right choice.
_CLEAR_SENTINEL = object()
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@_reg("LOAD_FAST_AND_CLEAR")
def _load_fast_and_clear(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Push the clear-sentinel so the matching STORE_FAST restore becomes a no-op."""
    em.push(_CLEAR_SENTINEL)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@_reg("LOAD_LOCALS")
def _load_locals(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """3.12 class body: locals() returns a new dict snapshot, cached in a temp to avoid repeated calls."""
    em.push("locals()")
    em.replace_tos_with_temp()
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
@_reg("LOAD_FROM_DICT_OR_GLOBALS", "LOAD_FROM_DICT_OR_DEREF")
def _load_from_dict(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """3.12 class body: look up the name in the locals dict first, fall back to the outer scope.

    Bug fix: the subscript key must be a quoted string. The old template emitted
    ``d[name]`` (evaluating ``name`` as a variable) while the guard tested
    ``'name' in d`` — both sides now use the string key consistently.
    """
    tos = em.pop()
    em.push(f"{tos}[{inst.argval!r}] if {inst.argval!r} in {tos} else {inst.argval}")
    em.replace_tos_with_temp()
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@_reg("LOAD_ATTR")
def _load_attr(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Attribute access: use dotted syntax for valid identifiers, getattr() otherwise."""
    obj = str(em.pop())
    attr = inst.argval
    if attr.isidentifier():
        em.push(f"{obj}.{attr}")
    else:
        em.push(f"getattr({obj}, {attr!r})")
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
@_reg("LOAD_SUPER_ATTR")
def _load_super_attr(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """3.12 super().attr: stack holds [super, class, self] with self on top."""
    receiver = em.pop()
    owner_cls = em.pop()
    super_fn = em.pop()
    em.push(f"{super_fn}({owner_cls}, {receiver}).{inst.argval}")
    em.replace_tos_with_temp()
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
@_reg("LOAD_METHOD")
def _load_method(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Bound-method load: render as plain attribute access on the popped receiver."""
    receiver = em.pop()
    em.push(f"{receiver}.{inst.argval}")
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
@_reg("LOAD_ASSERTION_ERROR")
def _load_assertion_error(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Push the literal name "AssertionError" (emitted by the assert statement's bytecode)."""
    em.push("AssertionError")
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
@_reg("PUSH_NULL")
def _push_null(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """3.11+ pushes a NULL sentinel before function calls; the CALL handler will clear it."""
    em.push(None)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
@_reg("GET_ITER")
def _get_iter(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Wrap the top-of-stack expression in an explicit iter() call."""
    iterable = em.pop()
    em.push(f"iter({iterable})")
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
# ── STORE instructions ───────────────────────────────────────────────────
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
@_reg("STORE_FAST", "STORE_GLOBAL", "STORE_DEREF", "STORE_NAME")
def _generic_store(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Generic name store.

    Skips the _CLEAR_SENTINEL (comprehension-variable restore) and trivial
    self-assignments. If the target name still appears on the simulated stack,
    its current value is first preserved in a temp so later uses of those stack
    entries are not clobbered by the assignment.
    """
    target = inst.argval
    value = em.pop()
    if value is _CLEAR_SENTINEL:
        return
    if target == value:
        return
    if isinstance(target, str) and target in em.stack:
        alias = em.make_temp()
        em.emit(f"{alias} = {target}")
        em.stack[:] = [alias if entry == target else entry for entry in em.stack]
    em.emit(f"{target} = {value}")
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
@_reg("STORE_SUBSCR")
def _store_subscr(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """obj[index] = value — stack holds [value, obj, index] with index on top."""
    key = em.pop()
    target = em.pop()
    rhs = em.pop()
    em.emit(f"{target}[{key}] = {rhs}")
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
@_reg("STORE_SLICE")
def _store_slice(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """container[start:end] = value — stack holds [value, container, start, end], end on top."""
    stop = em.pop()
    begin = em.pop()
    seq = em.pop()
    rhs = em.pop()
    em.emit(f"{seq}[{begin}:{stop}] = {rhs}")
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
@_reg("STORE_ATTR")
def _store_attr(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """obj.attr = value — stack holds [value, obj] with obj on top."""
    owner = em.pop()
    rhs = em.pop()
    em.emit(f"{owner}.{inst.argval} = {rhs}")
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# ── DELETE instructions ──────────────────────────────────────────────────
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
@_reg("DELETE_SUBSCR")
def _delete_subscr(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """del obj[index] — skipped when the same subscript expression is still live on the stack."""
    key = em.pop()
    target = em.pop()
    subscript = f"{target}[{key}]"
    if subscript not in em.stack:
        em.emit(f"del {subscript}")
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
@_reg("DELETE_NAME", "DELETE_GLOBAL", "DELETE_DEREF")
def _generic_delete(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Emit a plain ``del name`` for name-based delete opcodes."""
    em.emit("del " + inst.argval)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
@_reg("DELETE_FAST")
def _delete_fast(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Dynamo cleans up temp variables; no explicit del needed after decompilation."""
    pass
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
@_reg("DELETE_ATTR")
def _delete_attr(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """del obj.attr — the object expression is on top of the stack."""
    owner = em.pop()
    em.emit(f"del {owner}.{inst.argval}")
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
# ── IMPORT instructions ──────────────────────────────────────────────────
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
@_reg("IMPORT_NAME")
def _import_name(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """import os.path → binds 'os' (top-level module), accesses submodules via os.path.sep.

    NOTE(review): with a non-empty fromlist __import__ returns the leaf module,
    which is then bound to the top-level name here — confirm this matches the
    downstream IMPORT_FROM usage.
    """
    top_level = inst.argval.split(".")[0]
    fromlist = em.pop()
    level = em.pop()
    em.emit(f"{top_level} = __import__({inst.argval!r}, fromlist={fromlist}, level={level})")
    em.push(top_level)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
@_reg("IMPORT_FROM")
def _import_from(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """from-import: bind attr from the module expression left on the stack, then push it."""
    attr = inst.argval
    mod = em.peek()
    em.emit(f"{attr} = {mod}.{attr}")
    em.push(attr)
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/stack_ops.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Handlers for stack-manipulation opcodes: ROT, SWAP, COPY, POP, DUP.
|
| 16 |
+
|
| 17 |
+
See bytecode_explained.py §16 for details.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
from ..decompile_context import DecompileContext
|
| 23 |
+
from ..handler_registry import registry
|
| 24 |
+
from ..instruction import Instruction
|
| 25 |
+
from ..source_emitter import SourceEmitter
|
| 26 |
+
|
| 27 |
+
_reg = registry.register  # shorthand decorator: maps one or more opcode names to a handler
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ── ROT_N family (Python ≤3.10, replaced by SWAP/COPY in 3.11+) ───────────
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@_reg("ROT_N")
@_reg("ROT_TWO")
@_reg("ROT_THREE")
@_reg("ROT_FOUR")
def _rot_n(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Rotate the top n stack items: [a, b, c] → [c, a, b] for n=3.

    ROT_TWO/ROT_THREE/ROT_FOUR fix n at 2/3/4; ROT_N carries n in argval.
    (Python ≤3.10 only — replaced by SWAP/COPY in 3.11+.)
    """
    n = {"ROT_TWO": 2, "ROT_THREE": 3, "ROT_FOUR": 4}.get(inst.opname, inst.argval)
    em.stack[-n:] = em.stack[-1:] + em.stack[-n:-1]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@_reg("SWAP")
def _swap(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Python 3.11+: exchange the top of stack with stack[-n]."""
    depth = inst.argval
    em.stack[-depth], em.stack[-1] = em.stack[-1], em.stack[-depth]
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@_reg("COPY")
def _copy(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Python 3.11+: push a copy of stack[-n] (COPY 1 duplicates the top of stack)."""
    n = inst.argval
    if n == 0:
        return
    em.push(em.peek(n - 1))
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@_reg("POP_TOP")
def _pop_top(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Discard the top of stack; tolerant of an already-empty simulated stack."""
    if em.stack:
        em.pop()
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@_reg("DUP_TOP")
def _dup_top(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Python ≤3.10: duplicate the top stack item (COPY 1 in 3.11+)."""
    em.push(em.stack[-1])
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@_reg("DUP_TOP_TWO")
def _dup_top_two(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Python ≤3.10: duplicate the top two items, preserving order ([b, a] → [b, a, b, a])."""
    second, top = em.stack[-2], em.stack[-1]
    em.push(second)
    em.push(top)
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/instruction.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Enhanced Instruction dataclass with rich querying properties."""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import dataclasses
|
| 20 |
+
import dis
|
| 21 |
+
import sys
|
| 22 |
+
from typing import Any, Optional
|
| 23 |
+
|
| 24 |
+
_ALL_JUMP_OPCODES = frozenset(dis.hasjabs) | frozenset(dis.hasjrel)
_PY311 = sys.version_info >= (3, 11)


def _is_load_name(name: str) -> bool:
    """Opcode names treated as loads (includes PUSH_NULL/GET_ITER when present)."""
    return name.startswith("LOAD_") or name in ("PUSH_NULL", "GET_ITER")


_LOAD_OPCODES = frozenset(filter(_is_load_name, dis.opname))
_STORE_OPCODES = frozenset(filter(lambda n: n.startswith("STORE_"), dis.opname))
_DELETE_OPCODES = frozenset(filter(lambda n: n.startswith("DELETE_"), dis.opname))


@dataclasses.dataclass
class Instruction:
    """Mutable mirror of ``dis.Instruction`` with convenience queries.

    Unlike the stdlib version this is mutable so cleanup passes can modify
    instructions in-place (e.g. NOP-ing unreachable bytecode).
    """

    opcode: int
    opname: str
    # arg: raw integer operand (may index co_consts / co_varnames).
    # argval: object resolved by dis (e.g. the constant, or a name string).
    # argrepr: human-readable form of argval (e.g. "NULL + print", "to 20").
    arg: Optional[int]
    argval: Any
    argrepr: str
    offset: Optional[int] = None
    starts_line: Optional[int] = None
    is_jump_target: bool = False

    # -- identity / hashing: by object identity, never by field value -------

    def __hash__(self) -> int:
        return id(self)

    def __eq__(self, other: object) -> bool:
        return self is other

    def __repr__(self) -> str:
        return f"Instruction({self.opname}, offset={self.offset}, argval={self.argrepr!r})"

    # -- category queries ---------------------------------------------------

    @property
    def is_load(self) -> bool:
        return self.opname in _LOAD_OPCODES

    @property
    def is_store(self) -> bool:
        return self.opname in _STORE_OPCODES

    @property
    def is_delete(self) -> bool:
        return self.opname in _DELETE_OPCODES

    @property
    def is_jump(self) -> bool:
        return self.opcode in _ALL_JUMP_OPCODES

    @property
    def is_conditional_jump(self) -> bool:
        return self.is_jump and any(tag in self.opname for tag in ("IF", "FOR_ITER"))

    @property
    def is_unconditional_jump(self) -> bool:
        return self.is_jump and not self.is_conditional_jump

    @property
    def is_return(self) -> bool:
        return self.opname in {"RETURN_VALUE", "RETURN_CONST"}

    @property
    def is_nop(self) -> bool:
        return self.opname == "NOP"

    # -- jump target --------------------------------------------------------

    def jump_target_offset(self) -> Optional[int]:
        """Absolute bytecode offset this instruction jumps to, or None for non-jumps."""
        if not self.is_jump:
            return None
        # Prefer the "to N" text in argrepr when dis provides it.
        if "to " in self.argrepr:
            return int(self.argrepr.replace("to ", "").strip())
        if self.opcode in dis.hasjabs:
            return self.argval
        if self.opcode in dis.hasjrel:
            # NOTE(review): pre-3.11 this adds offset + argval; whether dis has
            # already resolved argval to an absolute target is version-dependent
            # — confirm against the instruction-producing pipeline.
            return self.offset + self.argval if not _PY311 else self.argval
        return None

    # -- mutation helpers (for cleanup passes) ------------------------------

    def nop_(self) -> None:
        """In-place convert this instruction to a NOP."""
        self.opname, self.opcode = "NOP", dis.opmap["NOP"]
        self.arg = self.argval = 0
        self.argrepr = ""
        self.is_jump_target = False

    # -- factory ------------------------------------------------------------

    @staticmethod
    def from_dis(i: dis.Instruction) -> "Instruction":
        """Create from a stdlib ``dis.Instruction``."""
        return Instruction(i.opcode, i.opname, i.arg, i.argval, i.argrepr, i.offset, i.starts_line, i.is_jump_target)
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/source_emitter.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""SourceEmitter: manages the evaluation stack and source-code emission.
|
| 16 |
+
|
| 17 |
+
This replaces the bare ``DecompilerState`` (just ``source_code: str``
|
| 18 |
+
and ``stack: list``) with a proper class that owns *all* mutable state
|
| 19 |
+
touched during decompilation, including the temp-variable counter
|
| 20 |
+
(instance-level, not class-level — thread-safe by design).
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
|
| 25 |
+
import contextlib
|
| 26 |
+
import dataclasses
|
| 27 |
+
from typing import Any, Iterator, List, Optional
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclasses.dataclass
class LoopContext:
    """Boundaries of the loop currently being decompiled, for break/continue detection.

    Half-open like ``range(start, end)``:
      * ``start_index`` — index of the FOR_ITER itself (inclusive, inside the loop).
      * ``end_index``   — index of the first instruction after the loop (exclusive).

    Used by the jump handlers: a jump target >= end_index decompiles to
    ``break``; a jump target == start_index decompiles to ``continue``.
    """

    start_index: int  # FOR_ITER index — inclusive, part of the loop
    end_index: int  # first instruction past the loop — exclusive
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class SourceEmitter:
    """Owns all mutable state produced while decompiling one code object.

    Compared to a bare ``(source_code, stack)`` pair this class provides:
      * an instance-scoped temp-variable counter (no cross-thread leakage),
      * proper ``push / pop / peek`` stack methods,
      * ``emit()`` which appends a trailing newline automatically,
      * ``fork()`` — a context manager yielding a child emitter for sub-blocks
        (if/else branches, loop bodies) whose source and final stack the
        caller can inspect afterwards.
    """

    def __init__(self, indent_size: int = 4, temp_prefix: str = "__temp_", *, _parent_counter: Optional[list] = None) -> None:
        self._lines: List[str] = []
        self._stack: List[Any] = []
        self._indent_size = indent_size
        self._temp_prefix = temp_prefix
        # One shared counter across all forks of a decompilation so temp names
        # are globally unique, while remaining scoped to this emitter tree.
        self._counter: list = [0] if _parent_counter is None else _parent_counter
        self.loop: Optional[LoopContext] = None

    # -- source emission ----------------------------------------------------

    def emit(self, line: str) -> None:
        """Append *line* plus a newline to the accumulated source."""
        self._lines.append(f"{line}\n")

    def emit_raw(self, text: str) -> None:
        """Append pre-formatted *text* verbatim (e.g. nested function defs)."""
        self._lines.append(text)

    def get_source(self) -> str:
        """Return everything emitted so far as one string."""
        return "".join(self._lines)

    # -- stack operations ---------------------------------------------------

    def push(self, value: Any) -> None:
        self._stack.append(value)

    def pop(self) -> Any:
        return self._stack.pop()

    def peek(self, depth: int = 0) -> Any:
        """Item at ``stack[-(depth+1)]`` without popping it."""
        return self._stack[-depth - 1]

    def set_at(self, depth: int, value: Any) -> None:
        """Overwrite ``stack[-(depth+1)]`` with *value*."""
        self._stack[-depth - 1] = value

    @property
    def stack(self) -> List[Any]:
        """The raw stack list, for multi-item surgery by handlers."""
        return self._stack

    @property
    def stack_size(self) -> int:
        return len(self._stack)

    # -- temp variables (instance-scoped counter) ---------------------------

    def make_temp(self) -> str:
        """Return a fresh, unique temporary variable name."""
        self._counter[0] += 1
        return f"{self._temp_prefix}{self._counter[0]}"

    def replace_tos_with_temp(self, depth: int = 1) -> str:
        """Materialize ``stack[-depth]`` into a temp variable.

        Emits ``__temp_N = <old expression>`` and substitutes the temp name
        on the stack. Returns the temp name.
        """
        expr = self._stack[-depth]
        name = self.make_temp()
        self.emit(f"{name} = {expr}")
        self._stack[-depth] = name
        return name

    # -- sub-block forking --------------------------------------------------

    @contextlib.contextmanager
    def fork(self, stack: Optional[List[Any]] = None, loop: Optional[LoopContext] = None) -> Iterator["SourceEmitter"]:
        """Yield a child emitter for a sub-block (if-branch, loop body, ...).

        The child shares this emitter's temp counter but gets its own lines
        and its own (copied) stack. When *loop* is None the parent's loop
        context is inherited.

        Usage::

            with emitter.fork(stack=my_stack) as child:
                decompile_range(start, end, child)
            child_source = child.get_source()
            child_final_stack = child.stack
        """
        child = SourceEmitter(
            indent_size=self._indent_size,
            temp_prefix=self._temp_prefix,
            _parent_counter=self._counter,
        )
        child._stack = list(self._stack if stack is None else stack)
        child.loop = loop if loop is not None else self.loop
        yield child

    # -- indentation helpers ------------------------------------------------

    def indent(self, text: str) -> str:
        """Return *text* with one extra indentation level on every line."""
        pad = " " * self._indent_size
        return "".join(f"{pad}{ln}\n" for ln in text.splitlines())
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/decompiler.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Decompiler — the orchestrator that ties everything together.
|
| 16 |
+
|
| 17 |
+
This module is the only place that coordinates ``SourceEmitter``,
|
| 18 |
+
``HandlerRegistry``, and ``DecompileContext``.
|
| 19 |
+
Individual handler functions never import from here (except for
|
| 20 |
+
``DecompilationError`` and recursive ``Decompiler`` usage in
|
| 21 |
+
``MAKE_FUNCTION``).
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import dis
|
| 27 |
+
import inspect
|
| 28 |
+
import os
|
| 29 |
+
from types import CodeType
|
| 30 |
+
from typing import Callable, List, Optional, Union
|
| 31 |
+
|
| 32 |
+
# Force handler registration by importing the package.
|
| 33 |
+
import magi_compiler.magi_depyf.decompile.bytecode.handlers # noqa: F401
|
| 34 |
+
|
| 35 |
+
from .bytecode.decompile_context import DecompileContext
|
| 36 |
+
from .bytecode.handler_registry import registry
|
| 37 |
+
from .bytecode.instruction import Instruction
|
| 38 |
+
from .bytecode.source_emitter import SourceEmitter
|
| 39 |
+
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
# Errors
|
| 42 |
+
# ---------------------------------------------------------------------------
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class DecompilationError(Exception):
    """Raised when decompilation fails.

    Optionally carries the offending ``instruction`` so callers can build
    actionable error messages.
    """

    def __init__(self, message: str = "", *, instruction: Optional[Instruction] = None):
        self.message = message
        self.instruction = instruction
        super().__init__(message)

    def __str__(self) -> str:
        suffix = "" if self.instruction is None else f" at {self.instruction}"
        return f"DecompilationError: {self.message}{suffix}"
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# ---------------------------------------------------------------------------
|
| 65 |
+
# Signature builder (lives here — it's a Decompiler concern, not a util)
|
| 66 |
+
# ---------------------------------------------------------------------------
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class SignatureBuilder:
    """Build the ``def fn(args):`` header from a ``CodeType``.

    Fixes over the previous version: keyword-only arguments are now emitted
    after a ``*`` / ``*args`` separator instead of being rendered as
    positional (``co_varnames`` orders locals as positional, then
    keyword-only, then ``*args``, then ``**kwargs``), and positional-only
    parameters get their ``/`` marker.
    """

    @staticmethod
    def build(code: CodeType, overwrite_name: Optional[str] = None) -> str:
        """Render a ``def`` header matching *code*'s calling convention.

        Args:
            code: the code object whose signature to render.
            overwrite_name: optional replacement for ``code.co_name``.

        Returns:
            A string like ``"def f(a, *, b):\\n"`` (trailing newline included).
        """

        def _clean(name: str) -> str:
            # Comprehension arguments are named ".0", ".1", ... which are not
            # valid identifiers in source form.
            return name.replace(".", "comp_arg_") if name.startswith(".") else name

        n_pos = code.co_argcount
        n_kwonly = code.co_kwonlyargcount
        positional = [_clean(x) for x in code.co_varnames[:n_pos]]
        kwonly = [_clean(x) for x in code.co_varnames[n_pos : n_pos + n_kwonly]]

        parts = list(positional)
        # Positional-only parameters are terminated by a "/" marker.
        n_posonly = getattr(code, "co_posonlyargcount", 0)
        if n_posonly:
            parts.insert(n_posonly, "/")

        idx = n_pos + n_kwonly  # *args (if any) follows all named args in co_varnames
        if code.co_flags & inspect.CO_VARARGS:
            parts.append("*" + code.co_varnames[idx])
            idx += 1
        elif kwonly:
            # Bare "*" separator is required when kwonly args exist w/o *args.
            parts.append("*")
        parts.extend(kwonly)
        if code.co_flags & inspect.CO_VARKEYWORDS:
            parts.append("**" + code.co_varnames[idx])

        fn_name = overwrite_name or code.co_name
        return f"def {fn_name}({', '.join(parts)}):\n"
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# ---------------------------------------------------------------------------
|
| 87 |
+
# Decompiler
|
| 88 |
+
# ---------------------------------------------------------------------------
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class Decompiler:
    """Decompile a ``CodeType`` into Python source code.

    Design differences from depyf's ``Decompiler``:

    * Handlers live in separate modules and receive
      ``(emitter, inst, ctx)`` — they never reference this class.
    * All mutable state is on ``SourceEmitter`` (instance-scoped counter).
    * ``decompile_range`` is delegated *through* ``DecompileContext``
      so handlers can recurse without importing this class (except
      ``MAKE_FUNCTION`` which needs a fresh ``Decompiler`` instance).
    """

    # Opcodes that unconditionally leave the frame; instructions after one of
    # these are dead until the next jump target.
    _TERMINATORS = frozenset({"RETURN_VALUE", "RETURN_CONST", "RAISE_VARARGS"})

    def __init__(self, code: Union[CodeType, Callable]) -> None:
        # Accept either a raw code object or any callable owning one;
        # _get_code_owner unwraps bound methods and decorator wrappers.
        if callable(code) and not isinstance(code, CodeType):
            code = _get_code_owner(code).__code__
        self.code: CodeType = code
        self.instructions = [Instruction.from_dis(i) for i in dis.get_instructions(code)]
        self._cleanup()

    # -- bytecode cleanup ---------------------------------------------------

    def _cleanup(self) -> None:
        """Propagate line numbers and NOP dead code after unconditional exits."""
        # Pass 1: carry the last seen source line forward so every
        # instruction ends up with a starts_line value.
        cur: Optional[int] = None
        for inst in self.instructions:
            if inst.starts_line is not None:
                cur = inst.starts_line
            inst.starts_line = cur

        # Pass 2: neutralize unreachable instructions. Control can only
        # re-enter dead code at a jump target.
        in_dead = False
        for inst in self.instructions:
            if in_dead:
                if inst.is_jump_target:
                    in_dead = False
                else:
                    inst.nop_()
            elif inst.opname in self._TERMINATORS:
                in_dead = True

    # -- core loop ----------------------------------------------------------

    def decompile_range(self, start: int, end: int, emitter: SourceEmitter) -> None:
        """Execute instruction handlers from *start* to *end* (exclusive).

        A handler may return the next index to resume from (after consuming a
        whole structured region); otherwise execution advances by one.

        Raises:
            DecompilationError: on a missing handler or any handler failure
                (non-DecompilationError exceptions are wrapped and chained).
        """
        idx = start
        try:
            while idx < end:
                inst = self.instructions[idx]
                handler = registry.get(inst.opname)
                if handler is None:
                    raise DecompilationError(f"No handler for opcode {inst.opname}", instruction=inst)
                # A fresh context per instruction keeps handlers stateless.
                ctx = self._make_context(emitter)
                result = handler(emitter, inst, ctx)
                idx = result if result is not None else idx + 1
        except DecompilationError:
            raise
        except Exception as e:
            # NOTE(review): if `self.instructions[idx]` itself raised on the
            # first iteration, `inst` is unbound here — confirm `end` is
            # always <= len(self.instructions).
            raise DecompilationError(f"Failed at {inst!r} in {self.code.co_name}", instruction=inst) from e

    def _make_context(self, emitter: SourceEmitter) -> DecompileContext:
        """Build the per-instruction view handlers use to inspect surrounding code."""
        return DecompileContext(
            code=self.code,
            instructions=tuple(self.instructions),
            # Reaches into SourceEmitter's private _indent_size attribute.
            indentation=emitter._indent_size,
            # Late-bound delegate so handlers can recurse into sub-ranges
            # without importing this class.
            decompile_range=lambda start, end, em: self.decompile_range(start, end, em),
            offset_to_index={inst.offset: idx for idx, inst in enumerate(self.instructions)},
        )

    # -- public API ---------------------------------------------------------

    def decompile(self, indentation: int = 4, temp_prefix: str = "__temp_", overwrite_fn_name: Optional[str] = None) -> str:
        """Return decompiled Python source code.

        Args:
            indentation: spaces per indent level in the emitted source.
            temp_prefix: prefix used for synthesized temporary names.
            overwrite_fn_name: optional replacement name for the ``def`` header.

        Raises:
            DecompilationError: if any stage fails (other exceptions are
                wrapped and chained).
        """
        try:
            emitter = SourceEmitter(indent_size=indentation, temp_prefix=temp_prefix)
            self.decompile_range(0, len(self.instructions), emitter)
            body = emitter.get_source()

            # Temp-cleanup passes are on by default; set DEPYF_REMOVE_TEMP=0
            # to inspect the raw emitted source.
            if os.environ.get("DEPYF_REMOVE_TEMP", "1") == "1":
                from .postprocess import run_all as _postprocess

                body = _postprocess(body, temp_prefix, indentation)

            header = SignatureBuilder.build(self.code, overwrite_fn_name)

            # Re-declare global/nonlocal names so a recompiled version binds
            # them the same way the original did.
            global_names = {i.argval for i in dis.get_instructions(self.code) if i.opname == "STORE_GLOBAL"}
            preamble = ""
            if global_names:
                preamble += "global " + ", ".join(global_names) + "\n"
            if self.code.co_freevars:
                preamble += "nonlocal " + ", ".join(self.code.co_freevars) + "\n"

            body = preamble + body
            return header + emitter.indent(body)
        except DecompilationError:
            raise
        except Exception as e:
            raise DecompilationError(f"Failed to decompile {self.code.co_name}") from e

    @staticmethod
    def supported_opnames() -> List[str]:
        """Return every opcode name the handler registry can decompile."""
        return registry.supported_opnames()
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
# ---------------------------------------------------------------------------
|
| 197 |
+
# Module-level convenience
|
| 198 |
+
# ---------------------------------------------------------------------------
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def decompile(code: Union[CodeType, Callable]) -> str:
    """Convenience wrapper: construct a ``Decompiler`` and render source in one call."""
    decompiler = Decompiler(code)
    return decompiler.decompile()
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def safe_decompile(code: CodeType) -> str:
    """Decompile *code* without raising; fall back to depyf then placeholder."""
    # First choice: our own decompiler.
    try:
        return Decompiler(code).decompile()
    except Exception:
        pass
    # Second choice: upstream depyf, if importable and successful.
    try:
        from depyf import decompile as _depyf_decompile

        return _depyf_decompile(code)
    except Exception:
        # Last resort: an explanatory stub so output is always produced.
        return f"# Failed to decompile {code.co_name}\n"
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
# ---------------------------------------------------------------------------
|
| 220 |
+
# Private helpers
|
| 221 |
+
# ---------------------------------------------------------------------------
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def _get_code_owner(fn):
    """Walk through wrappers to find the object that owns ``__code__``.

    Bound methods are unwrapped via ``__func__``; decorator chains are
    followed via ``__wrapped__`` until one of the two is found.
    """
    current = fn
    # Descend the __wrapped__ chain, but stop as soon as a __func__ appears:
    # the original returns __func__ of the first object that has one.
    while hasattr(current, "__wrapped__") and not hasattr(current, "__func__"):
        current = current.__wrapped__
    if hasattr(current, "__func__"):
        return current.__func__
    return current
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/postprocess/__init__.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Source-level post-processing pipeline for decompiled code.
|
| 16 |
+
|
| 17 |
+
Each pass is a function ``(source, ...) -> source`` that performs one
|
| 18 |
+
semantics-preserving transformation. ``run_all`` applies them in order.
|
| 19 |
+
All passes are best-effort: on any exception they return the input unchanged.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from .branch_dedup import dedup_branch_tails
|
| 23 |
+
from .for_temps import eliminate_for_temps
|
| 24 |
+
from .inline_temps import eliminate_inline_temps
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def run_all(source: str, temp_prefix: str = "__temp_", indent: int = 4) -> str:
    """Apply all post-processing passes in sequence.

    Order matters: for-loop temps first, then inline temps, then branch-tail
    deduplication on the cleaned-up source.
    """
    passes = (
        lambda s: eliminate_for_temps(s, temp_prefix, indent),
        lambda s: eliminate_inline_temps(s, temp_prefix, indent),
        lambda s: dedup_branch_tails(s, indent),
    )
    result = source
    for apply_pass in passes:
        result = apply_pass(result)
    return result
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
__all__ = ["run_all", "eliminate_for_temps", "eliminate_inline_temps", "dedup_branch_tails"]
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/postprocess/branch_dedup.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Pass 3: if/else branch tail deduplication.
|
| 16 |
+
|
| 17 |
+
Move identical trailing statements from if/else branches to after the block.
|
| 18 |
+
|
| 19 |
+
Example::
|
| 20 |
+
|
| 21 |
+
if cond: if cond:
|
| 22 |
+
x = 1 x = 1
|
| 23 |
+
return x → else:
|
| 24 |
+
else: x = 2
|
| 25 |
+
x = 2 return x
|
| 26 |
+
return x
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
from __future__ import annotations
|
| 30 |
+
|
| 31 |
+
import ast
|
| 32 |
+
from typing import List, Tuple
|
| 33 |
+
|
| 34 |
+
import astor
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def dedup_branch_tails(source: str, indent: int = 4) -> str:
    """Move identical trailing statements from if/else branches to after the block.

    Best-effort: any exception (parse failure, regeneration failure) returns
    the input unchanged.
    """
    try:
        tree = ast.parse(source)
        any_change = False
        for node in ast.walk(tree):
            body = getattr(node, "body", None)
            if isinstance(body, list):
                rewritten, did_change = _dedup_stmts(body)
                if did_change:
                    node.body = rewritten
                    any_change = True
        if not any_change:
            # Nothing moved; skip the regeneration round-trip entirely.
            return source
        ast.fix_missing_locations(tree)
        return astor.to_source(tree, indent_with=" " * indent)
    except Exception:
        return source
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _dedup_stmts(stmts: List[ast.stmt]) -> Tuple[List[ast.stmt], bool]:
    """Process a statement list, extracting common if/else tails.

    Returns the (possibly rewritten) list and a flag saying whether anything
    changed.
    """
    out: List[ast.stmt] = []
    modified = False

    for node in stmts:
        # Recurse into every nested statement list before handling this node.
        for field in ("body", "orelse", "handlers", "finalbody"):
            children = getattr(node, field, None)
            if isinstance(children, list) and children:
                rewritten, child_changed = _dedup_stmts(children)
                if child_changed:
                    setattr(node, field, rewritten)
                    modified = True

        if isinstance(node, ast.If) and node.orelse:
            tail_len = _common_tail_length(node.body, node.orelse)
            if tail_len > 0:
                shared = node.body[-tail_len:]
                # Never leave an empty body; an emptied orelse simply drops.
                node.body = node.body[:-tail_len] or [ast.Pass()]
                node.orelse = node.orelse[:-tail_len] or []
                out.append(node)
                out.extend(shared)
                modified = True
                continue

        out.append(node)

    return out, modified


def _common_tail_length(body: List[ast.stmt], orelse: List[ast.stmt]) -> int:
    """Count identical trailing statements (by AST dump equality)."""
    matched = 0
    for a, b in zip(reversed(body), reversed(orelse)):
        if ast.dump(a) != ast.dump(b):
            break
        matched += 1
    if matched >= len(body) or matched >= len(orelse):
        # Keep at least one distinct statement per branch — never hollow out
        # an entire branch.
        matched = min(len(body), len(orelse)) - 1
    return max(matched, 0)
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/postprocess/for_temps.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Pass 1: for-loop temp elimination.
|
| 16 |
+
|
| 17 |
+
``for __temp in iter: var = __temp; ...`` → ``for var in iter: ...``
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import ast
|
| 23 |
+
|
| 24 |
+
import astor
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def eliminate_for_temps(source: str, temp_prefix: str = "__temp_", indent: int = 4) -> str:
    """Rewrite ``for <temp> in it: var = <temp>; ...`` as ``for var in it: ...``.

    Only applies when the first body statement is a plain assignment from the
    temp to a real variable. Best-effort: any exception returns the input
    unchanged.
    """
    try:
        parsed = ast.parse(source)
        transformed = _ForTempEliminator(temp_prefix).visit(parsed)
        ast.fix_missing_locations(transformed)
        return astor.to_source(transformed, indent_with=" " * indent)
    except Exception:
        return source
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class _ForTempEliminator(ast.NodeTransformer):
    """Fold a leading ``var = <temp>`` assignment into the loop target."""

    def __init__(self, prefix: str):
        # Only loop variables carrying this prefix are considered temps.
        self._prefix = prefix

    def visit_For(self, node: ast.For) -> ast.For:
        # Transform inner loops first.
        self.generic_visit(node)

        target = node.target
        if not (isinstance(target, ast.Name) and target.id.startswith(self._prefix)):
            return node
        if not node.body:
            return node
        first = node.body[0]
        if not (isinstance(first, ast.Assign) and len(first.targets) == 1):
            return node
        rhs = first.value
        if not (isinstance(rhs, ast.Name) and rhs.id == target.id):
            return node

        # Promote the real variable to loop target and drop the copy.
        node.target = first.targets[0]
        node.body = node.body[1:] or [ast.Pass()]
        return node
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/postprocess/inline_temps.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Pass 2: single-use temp inlining.
|
| 16 |
+
|
| 17 |
+
``__temp = expr; use(__temp)`` → ``use(expr)`` for single-use temps.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import ast
|
| 23 |
+
from collections import defaultdict
|
| 24 |
+
from typing import List, Optional
|
| 25 |
+
|
| 26 |
+
import astor
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def eliminate_inline_temps(source: str, temp_prefix: str = "__temp_", indent: int = 4) -> str:
    """Inline single-use temporaries into their use site.

    A temp with exactly two ``Name`` occurrences (assignment target + one
    read) is a candidate. Each temp's occurrence list doubles as a small
    message channel between the passes: this function appends a can-inline
    bool, ``_RemoveAssign`` appends the recorded RHS, and ``_InlineTemp``
    consumes both. Best-effort: any exception returns the input unchanged.
    """
    try:
        tree = ast.parse(source)
        _set_parents(tree)

        # Collect every occurrence of every temp-prefixed Name.
        occurrences: dict[str, list] = defaultdict(list)
        for node in ast.walk(tree):
            if isinstance(node, ast.Name) and node.id.startswith(temp_prefix):
                occurrences[node.id].append(node)

        # Node types that open an indented block. If the relevant divergence
        # ancestor is one of these, the temp would be inlined across a block
        # boundary, which is not attempted.
        _INDENT_NODES = (
            ast.FunctionDef,
            ast.AsyncFunctionDef,
            ast.For,
            ast.AsyncFor,
            ast.While,
            ast.If,
            ast.Try,
            ast.With,
            ast.AsyncWith,
            ast.ClassDef,
        )

        for name in occurrences:
            occ = occurrences[name]
            if len(occ) == 2:
                n1, n2 = occ
                _, p1, p2 = _lowest_common_parent(n1, n2)
                # NOTE(review): when n1 is the assignment target (its parent
                # is an Assign), the divergence branch on n1's side is
                # checked; otherwise n2's side — confirm this choice of side
                # is intended.
                ap = p1 if isinstance(getattr(n1, "parent", None), ast.Assign) else p2
                can = not isinstance(ap, _INDENT_NODES)
                if can:
                    can = _safe_to_inline(tree, n1, n2)
                # occ becomes [def_occurrence, use_occurrence, can_inline].
                occ.append(can)
            # These two transformers read/extend the occ lists in place.
            tree = _RemoveAssign(name, occurrences).visit(tree)
            tree = _InlineTemp(name, occurrences).visit(tree)

        return astor.to_source(tree, indent_with=" " * indent)
    except Exception:
        return source
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# ---------------------------------------------------------------------------
|
| 72 |
+
# AST helpers
|
| 73 |
+
# ---------------------------------------------------------------------------
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _set_parents(node: ast.AST, parent: Optional[ast.AST] = None) -> None:
    """Annotate every descendant of *node* with a ``.parent`` attribute.

    NOTE(review): direct children of *node* receive ``parent`` (``None`` by
    default) rather than *node* itself, so the root never appears in parent
    chains — this mirrors the original recursion; confirm it is intentional.
    """
    pending = [(node, parent)]
    while pending:
        current, assigned_parent = pending.pop()
        for child in ast.iter_child_nodes(current):
            child.parent = assigned_parent  # type: ignore[attr-defined]
            # Grandchildren of `current` get `child` as their parent.
            pending.append((child, child))
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _get_parents(node: ast.AST) -> List[ast.AST]:
    """Return *node* followed by its ancestors, innermost first.

    Walks ``.parent`` links set by ``_set_parents``; stops at the first
    missing/falsy parent.
    """
    chain: List[ast.AST] = []
    current = node
    while current:
        chain.append(current)
        current = getattr(current, "parent", None)
    return chain
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _lowest_common_parent(n1: ast.AST, n2: ast.AST):
    """Return (deepest shared ancestor, divergence child on n1's path, on n2's path).

    When one path is a prefix of the other (or the nodes are identical), the
    divergence children are ``None``.
    """

    def ancestry(node):
        # Node itself plus its .parent chain, ordered root-first.
        chain = []
        while node:
            chain.append(node)
            node = getattr(node, "parent", None)
        chain.reverse()
        return chain

    path1 = ancestry(n1)
    path2 = ancestry(n2)
    shared = branch1 = branch2 = None
    for left, right in zip(path1, path2):
        if left is not right:
            branch1, branch2 = left, right
            break
        shared = left
    return shared, branch1, branch2
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _safe_to_inline(tree: ast.AST, def_node: ast.AST, use_node: ast.AST) -> bool:
    """Verify the RHS variable is not reassigned between definition and use.

    Only the conservative case is checked: when the temp's RHS is itself a
    bare ``Name``, inlining moves that name's read later in the program, so
    an intervening reassignment would change its value. Any non-Name RHS is
    accepted as-is.
    """
    assign_parent = getattr(def_node, "parent", None)
    if not isinstance(assign_parent, ast.Assign):
        return True
    rhs = assign_parent.value
    if not isinstance(rhs, ast.Name):
        return True

    rhs_name = rhs.id
    stmts: List[ast.stmt] = []
    # ast.walk yields the root first, so this picks the outermost statement
    # list (the Module body for decompiled sources).
    for node in ast.walk(tree):
        if hasattr(node, "body") and isinstance(node.body, list):
            stmts = node.body
            break
    try:
        # Identity search: locate the assignment statement in that list.
        def_idx = next(i for i, s in enumerate(stmts) if s is assign_parent)
        # Climb from the use site until reaching a statement in `stmts`.
        use_stmt = getattr(use_node, "parent", None)
        while use_stmt and use_stmt not in stmts:
            use_stmt = getattr(use_stmt, "parent", None)
        use_idx = next(i for i, s in enumerate(stmts) if s is use_stmt)
    except StopIteration:
        # Either statement is not at the top level; be permissive.
        return True

    # NOTE(review): only plain top-level Assign targets are scanned —
    # augmented assignments or writes nested in compound statements would
    # not be detected. Confirm this is acceptable for decompiler-generated
    # code.
    for stmt in stmts[def_idx + 1 : use_idx]:
        if isinstance(stmt, ast.Assign):
            for t in stmt.targets:
                if isinstance(t, ast.Name) and t.id == rhs_name:
                    return False
    return True
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class _RemoveAssign(ast.NodeTransformer):
    """Drop (or neutralize) the ``__temp = value`` assignment for one temp.

    Cooperates with ``eliminate_inline_temps`` through the shared occurrence
    list ``occ[name]``:

    * ``len(occ) == 1`` — the temp is assigned but never read; keep only the
      RHS as an expression statement so its side effects are preserved.
    * ``len(occ) == 3`` ending in a bool — two occurrences plus the
      "can inline" flag; record the RHS by appending it (making it length 4)
      and delete the assignment statement when the flag is True.
    """

    def __init__(self, name: str, occ: dict):
        # name: the temp variable this pass targets.
        self._name = name
        # occ: shared name -> occurrence-list mapping, mutated in place.
        self._occ = occ

    def visit_Assign(self, node: ast.Assign):
        if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name):
            n = node.targets[0].id
            if n == self._name:
                o = self._occ[n]
                if len(o) == 1:
                    # Sole occurrence is the target itself: value never read,
                    # so keep just the RHS expression statement.
                    return ast.Expr(value=node.value)
                if len(o) == 3 and isinstance(o[-1], bool):
                    # Stash the RHS so _InlineTemp can splice it in later.
                    o.append(node.value)
                    if o[-2]:
                        # Inlinable: returning None removes this statement.
                        return None
        return node
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class _InlineTemp(ast.NodeTransformer):
    """Replace the surviving read of an inlined temp with its recorded RHS."""

    def __init__(self, name: str, occ: dict):
        # name: the temp this pass targets; occ: shared occurrence lists.
        self._name = name
        self._occ = occ

    def visit_Name(self, node: ast.Name):
        if node.id != self._name:
            return node
        record = self._occ.get(node.id, [])
        # A full record is [def_occurrence, use_occurrence, can_inline, rhs].
        if len(record) != 4:
            return node
        can_inline = record[-2]
        if isinstance(can_inline, bool) and can_inline:
            return record[-1]
        return node
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/recompiler.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""CodeRecompiler: round-trip decompile -> compile -> extract target CodeType.
|
| 16 |
+
|
| 17 |
+
Pipeline: CodeType -> decompile -> compile -> find target.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
from types import CodeType
|
| 23 |
+
from typing import List
|
| 24 |
+
|
| 25 |
+
from .decompiler import Decompiler
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class CodeRecompiler:
    """Decompile *code*, recompile, and produce a compatible ``CodeType``."""

    @staticmethod
    def recompile(
        code_to_decompile: CodeType, reference_code: CodeType, indentation: int = 4, temp_prefix: str = "__temp_"
    ) -> CodeType:
        """Full round-trip: decompile -> compile -> find target.

        The decompiled source is emitted under the reference code's name so
        the matching code object can be located after compilation.
        """
        target_name = reference_code.co_name

        source = Decompiler(code_to_decompile).decompile(
            indentation=indentation, temp_prefix=temp_prefix, overwrite_fn_name=target_name
        )

        module_code = compile(source, "noname", "exec")
        matches = [c for c in CodeRecompiler.collect_code_objects(module_code) if c.co_name == target_name]
        return matches[0]

    @staticmethod
    def collect_code_objects(code: CodeType) -> List[CodeType]:
        """Recursively collect all ``CodeType`` objects from *code*.

        Results are in pre-order (a code object before its nested ones).
        """
        found: List[CodeType] = []
        stack = [code]
        while stack:
            current = stack.pop()
            found.append(current)
            nested = [c for c in current.co_consts if isinstance(c, CodeType)]
            # Reverse so popping preserves co_consts (pre-order) ordering.
            stack.extend(reversed(nested))
        return found
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/demo_toy_example.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Demo: magi_depyf.dump_src with the depyf tutorial toy_example.
|
| 16 |
+
|
| 17 |
+
Run: PYTHONPATH=. python demo_toy_example.py
|
| 18 |
+
"""
|
| 19 |
+
import torch
from magi_compiler.magi_depyf.inspect import dump_src

# Run on GPU when one is available; the demo also works on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(device)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@torch.compile
def toy_example(a, b):
    """Tiny compiled function with a data-dependent branch.

    The Python-level ``if b.sum() < 0`` exercises Dynamo's handling of
    data-dependent control flow, which is what makes this the standard demo
    input for the dump_src capture machinery.
    """
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def main():
    """Call the compiled toy example repeatedly so both branches get traced."""
    num_calls = 100
    for _ in range(num_calls):
        toy_example(torch.randn(10), torch.randn(10))
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
if __name__ == "__main__":
    import os
    import shutil

    # Start from a clean output directory each run.
    output_dir = "./magi_dump_src_dir"
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)

    with dump_src(output_dir):
        main()

    # Show the directory tree the capture produced.
    print("\n=== Generated files ===")
    for root, _dirs, files in os.walk(output_dir):
        depth = root.replace(output_dir, "").count(os.sep)
        print(f"{' ' * depth}{os.path.basename(root)}/")
        child_prefix = " " * (depth + 1)
        for filename in files:
            print(f"{child_prefix}{filename}")
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/inspect/__init__.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Inspection layer: capture torch.compile events, introspect artifacts, write structured output."""
|
| 16 |
+
|
| 17 |
+
from typing import Optional
|
| 18 |
+
|
| 19 |
+
from .dump_src import dump_src
|
| 20 |
+
from .introspect import Introspector
|
| 21 |
+
from .model import CompiledFnInfo, EntryInfo, FunctionInfo, GuardInfo, GuardNode, SubgraphInfo
|
| 22 |
+
from .result import CaptureResult
|
| 23 |
+
from .session import CaptureSession
|
| 24 |
+
from .writer import FunctionWriter, write_function
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def debug_compiled(fn, output_dir: Optional[str] = None) -> FunctionInfo:
    """Introspect a compiled function and optionally write debug output.

    Args:
        fn: The original (uncompiled) function.
        output_dir: If provided, write organized files to this directory.

    Returns:
        FunctionInfo with full compilation state.
    """
    result = Introspector.build_function_info(fn)
    if output_dir is None:
        return result
    write_function(result, output_dir)
    return result
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
__all__ = [
|
| 44 |
+
"CompiledFnInfo",
|
| 45 |
+
"EntryInfo",
|
| 46 |
+
"FunctionInfo",
|
| 47 |
+
"GuardInfo",
|
| 48 |
+
"GuardNode",
|
| 49 |
+
"SubgraphInfo",
|
| 50 |
+
"Introspector",
|
| 51 |
+
"FunctionWriter",
|
| 52 |
+
"write_function",
|
| 53 |
+
"dump_src",
|
| 54 |
+
"debug_compiled",
|
| 55 |
+
"CaptureSession",
|
| 56 |
+
"CaptureResult",
|
| 57 |
+
]
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/inspect/dump_src.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""dump_src: context manager that captures torch.compile artifacts and
|
| 16 |
+
writes structured source output to disk.
|
| 17 |
+
|
| 18 |
+
Usage::
|
| 19 |
+
|
| 20 |
+
from magi_compiler.magi_depyf.inspect import dump_src
|
| 21 |
+
|
| 22 |
+
@torch.compile
|
| 23 |
+
def my_fn(x):
|
| 24 |
+
return x.sum()
|
| 25 |
+
|
| 26 |
+
with dump_src("./output_dir"):
|
| 27 |
+
my_fn(torch.randn(10))
|
| 28 |
+
|
| 29 |
+
Internally uses ``CaptureSession`` to intercept compilation events,
|
| 30 |
+
then runs ``Introspector`` post-hoc for full introspection.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
from __future__ import annotations
|
| 34 |
+
|
| 35 |
+
import contextlib
|
| 36 |
+
from pathlib import Path
|
| 37 |
+
from typing import Set
|
| 38 |
+
|
| 39 |
+
from magi_compiler.utils import magi_logger
|
| 40 |
+
|
| 41 |
+
from .introspect import Introspector
|
| 42 |
+
from .session import CaptureSession
|
| 43 |
+
from .writer import write_function
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@contextlib.contextmanager
def dump_src(dump_src_dir: str):
    """Context manager that captures torch.compile artifacts and writes output.

    Uses CaptureSession for hook management and post-hoc introspection
    of CacheEntries after execution completes.

    Args:
        dump_src_dir: Directory (created if missing) that receives one
            subtree per captured function, each with an ``overview.md``.
    """
    dump_dir = Path(dump_src_dir)
    dump_dir.mkdir(parents=True, exist_ok=True)

    # Hooks are active only while the caller's code runs inside the
    # ``with`` body; introspection happens afterwards, post-hoc.
    with CaptureSession() as session:
        yield

    seen: Set[str] = set()
    overview_paths: list[Path] = []
    for r in session.results:
        name = r.original_code.co_name
        # De-duplicate: a function recompiled multiple times yields
        # multiple results but only needs one report.
        if name in seen:
            continue
        # Skip Dynamo-generated resume functions at the top level;
        # presumably they are reached via their parent function's
        # introspection instead — TODO confirm against build_entry_info.
        if name.startswith("torch_dynamo_resume_in_"):
            continue
        seen.add(name)

        try:
            info = Introspector.build_function_info(r.original_code, fn_globals=r.fn_globals)
            root = write_function(info, dump_dir)
            overview_paths.append(root / "overview.md")
        except Exception as e:
            # Best-effort: one broken function must not abort the others.
            magi_logger.warning("[magi_depyf] failed to process '%s': %s", name, e)

    # Print the generated overview files so users can find the output.
    for p in overview_paths:
        if p.exists():
            magi_logger.info("[magi_depyf] %s", p)
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/inspect/introspect.py
ADDED
|
@@ -0,0 +1,524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Runtime introspection of torch.compile artifacts.
|
| 16 |
+
|
| 17 |
+
Walk actual runtime state (CacheEntry chain, guard trees, __compiled_fn
|
| 18 |
+
objects) to build the structured model. All torch imports are lazy so
|
| 19 |
+
this module can be imported without torch.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
|
| 24 |
+
import io
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
from typing import Any, Dict, Optional
|
| 27 |
+
|
| 28 |
+
from ..decompile import safe_decompile
|
| 29 |
+
from .model import CompiledFnInfo, EntryInfo, FunctionInfo, GuardInfo, GuardNode, SubgraphInfo
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class Introspector:
    """Namespace for runtime introspection helpers (all static methods).

    Walks live torch.compile state (Dynamo CacheEntry chains, guard
    manager trees, ``__compiled_fn_*`` globals) and converts it into the
    plain dataclasses defined in ``model.py``. All torch imports are
    function-local so the module imports without torch installed.
    """

    @staticmethod
    def get_cache_entries(fn) -> list:
        """Return CacheEntry list for *fn* (function or code object)."""
        # Lazy import: keeps this module importable without torch.
        from torch._dynamo.eval_frame import _debug_get_cache_entry_list

        code = fn.__code__ if hasattr(fn, "__code__") else fn
        return _debug_get_cache_entry_list(code)

    @staticmethod
    def build_guard_tree(node, max_depth: int = 32, _depth: int = 0) -> GuardNode:
        """Recursively build a GuardNode from a GuardManager C++ object.

        Args:
            node: Guard-manager node exposing ``get_leaf_guards`` /
                ``get_child_managers``.
            max_depth: Recursion cutoff; subtrees below it are dropped.
            _depth: Internal recursion counter — do not pass explicitly.
        """
        type_name = type(node).__name__
        leaf_guards = []
        for lg in node.get_leaf_guards():
            for part in lg.verbose_code_parts():
                # Truncate each guard expression to keep reports readable.
                leaf_guards.append(part.strip()[:120])
        children = []
        if _depth < max_depth:
            for child in node.get_child_managers():
                children.append(Introspector.build_guard_tree(child, max_depth, _depth + 1))
        return GuardNode(type_name=type_name, leaf_guards=leaf_guards, children=children)

    @staticmethod
    def extract_guard_info(entry) -> Optional[GuardInfo]:
        """Extract structured guard info from a CacheEntry (post-hoc introspection).

        Operates on the persisted CacheEntry and builds a full GuardNode tree.
        Returns ``None`` if any part of the guard structure is unreadable.
        """
        try:
            gm = entry.guard_manager
            tree = Introspector.build_guard_tree(gm.root)
            closure_vars: Dict[str, str] = {}
            if hasattr(gm, "closure_vars") and gm.closure_vars:
                # Cap at 10 vars / 100 chars of repr each to bound output size.
                for k, v in list(gm.closure_vars.items())[:10]:
                    closure_vars[k] = repr(v)[:100]
            return GuardInfo(tree=tree, closure_vars=closure_vars or None)
        except Exception:
            return None

    @staticmethod
    def extract_compiled_fn_info(name: str, fn_globals: dict) -> Optional[CompiledFnInfo]:
        """Inspect a __compiled_fn_xxx from fn.__globals__.

        Handles three backend types:
            eager:        wrapper -> closure[0]=GraphModule.forward (bound method)
            inductor:     wrapper -> closure[0]=aot_forward -> ... -> CompiledFxGraph
            magi_compile: MagiSerializableFunction -> split_gm -> PiecewiseBackend(s)

        Returns ``None`` when *name* is absent from *fn_globals*.
        """
        obj = fn_globals.get(name)
        if obj is None:
            return None

        # Magi backend is checked first: its objects would otherwise be
        # misclassified by the generic eager/inductor probing below.
        magi_info = Introspector._try_extract_magi_info(name, obj)
        if magi_info is not None:
            return magi_info

        info = CompiledFnInfo(name=name, backend="eager")

        gm = Introspector._find_graph_module(obj)
        if gm is not None:
            Introspector._fill_graph_module_info(info, gm)

        # Presence of a CompiledFxGraph in the closure chain marks inductor.
        cfx = Introspector._find_compiled_fx_graph(obj)
        if cfx is not None:
            info.backend = "inductor"
            Introspector._fill_compiled_fx_graph_info(info, cfx)

        return info

    @staticmethod
    def _fill_graph_module_info(info: CompiledFnInfo, gm) -> None:
        """Populate readable-code / code / tabular fields from a GraphModule.

        Each field is extracted independently and best-effort: a failure in
        one representation must not lose the others.
        """
        try:
            info.readable_code = gm.print_readable(print_output=False)
        except Exception:
            pass
        try:
            info.graph_module_code = str(gm.code) if hasattr(gm, "code") else None
        except Exception:
            pass
        try:
            buf = io.StringIO()
            gm.graph.print_tabular(file=buf)
            info.fx_graph_tabular = buf.getvalue()
        except Exception:
            pass

    @staticmethod
    def _fill_compiled_fx_graph_info(info: CompiledFnInfo, cfx) -> None:
        """Populate inductor-specific fields from a CompiledFxGraph.

        Attribute availability varies across PyTorch versions, hence the
        independent best-effort try/except per field.
        """
        try:
            info.source_code = cfx.source_code
        except Exception:
            pass
        try:
            info.inductor_post_grad_graph = cfx.inductor_post_grad_graph_str
        except Exception:
            pass
        try:
            info.cache_key = cfx.cache_key
        except Exception:
            pass
        try:
            info.runnable_graph_str = cfx.runnable_graph_str
        except Exception:
            pass

    # -- Magi backend introspection ----------------------------------------

    @staticmethod
    def _try_extract_magi_info(name: str, obj) -> Optional[CompiledFnInfo]:
        """Detect MagiSerializableFunction and walk its hierarchy.

        MagiSerializableFunction hierarchy:
            .graph_module   -> fx.GraphModule (full graph before splitting)
            .optimized_call -> split_gm (fx.GraphModule with PiecewiseBackend submodules)
                .submod_N   -> PiecewiseBackend
                    .graph  -> fx.GraphModule (the subgraph)
                    .compiled_graph_for_general_shape -> inductor compiled output

        Dynamo wraps the backend result in a DisableContext closure, so the
        MagiSerializableFunction may live one level deep in the closure chain.

        Returns ``None`` if *obj* is not (and does not wrap) a Magi function.
        """
        # Duck-typed detection: both attributes together identify the type.
        msf = obj if (hasattr(obj, "graph_module") and hasattr(obj, "optimized_call")) else None
        if msf is None and callable(obj) and getattr(obj, "__closure__", None):
            for cell in obj.__closure__:
                try:
                    val = cell.cell_contents
                except ValueError:  # empty cell
                    continue
                if hasattr(val, "graph_module") and hasattr(val, "optimized_call"):
                    msf = val
                    break
        if msf is None:
            return None
        obj = msf

        import torch.fx

        info = CompiledFnInfo(name=name, backend="magi_compile")

        full_gm = getattr(obj, "graph_module", None)
        if isinstance(full_gm, torch.fx.GraphModule):
            Introspector._fill_graph_module_info(info, full_gm)

        split_gm = getattr(obj, "optimized_call", None)

        # In FULL cudagraph mode, optimized_call is a wrapper function whose
        # __dict__ carries the GraphModule's attributes (via __dict__.update).
        # Unwrap to find the actual GraphModule for print_readable / named_children.
        actual_gm = split_gm if isinstance(split_gm, torch.fx.GraphModule) else None
        if actual_gm is None and split_gm is not None:
            actual_gm = Introspector._find_graph_module_deep(split_gm)

        info.cudagraph_mode = Introspector._detect_cudagraph_mode(split_gm, actual_gm)

        if actual_gm is not None:
            try:
                info.split_graph_readable = actual_gm.print_readable(print_output=False)
            except Exception:
                pass

            # PiecewiseCompileInterpreter replaces submodules via __dict__,
            # so named_children() still sees the original GraphModules while
            # __dict__ contains the PiecewiseBackend (or cudagraph wrapper).
            # In FULL cudagraph mode, those __dict__ entries are copied onto
            # the wrapper function, so we look up runtime objects from
            # split_gm (the wrapper) rather than actual_gm.
            runtime_source = split_gm if split_gm is not None else actual_gm
            for sub_name, original_gm in actual_gm.named_children():
                runtime_obj = runtime_source.__dict__.get(sub_name, original_gm)
                sg_info = Introspector._extract_subgraph_info(sub_name, runtime_obj, original_gm)
                if sg_info is not None:
                    info.subgraph_infos.append(sg_info)

            # Stable, name-ordered output regardless of discovery order.
            info.subgraph_infos.sort(key=lambda s: s.name)

        return info

    @staticmethod
    def _extract_subgraph_info(sub_name: str, runtime_obj, original_gm=None) -> Optional[SubgraphInfo]:
        """Extract info from one submodule of the split graph.

        Args:
            sub_name: The submodule name (e.g. "submod_0").
            runtime_obj: The actual runtime object — PiecewiseBackend,
                cudagraph wrapper, or the original GraphModule.
            original_gm: The original GraphModule before replacement (from _modules).

        Returns:
            A SubgraphInfo, or ``None`` if neither a PiecewiseBackend nor a
            GraphModule could be identified.
        """
        import torch.fx

        piecewise = Introspector._unwrap_piecewise_backend(runtime_obj)

        if piecewise is not None:
            sg = SubgraphInfo(name=sub_name, is_splitting_graph=False)
            inner_gm = getattr(piecewise, "graph", None)
            if isinstance(inner_gm, torch.fx.GraphModule):
                Introspector._fill_subgraph_gm_info(sg, inner_gm)

            compiled = getattr(piecewise, "compiled_graph_for_general_shape", None)
            if compiled is not None:
                sg.inductor_code = Introspector._try_extract_inductor_source(compiled)

            # Fall back to the on-disk artifact when the in-memory object
            # does not expose its generated source.
            if sg.inductor_code is None:
                sg.inductor_code = Introspector._read_artifact_source_from_piecewise(piecewise)

            return sg

        # Not a PiecewiseBackend: treat it as a plain (splitting) GraphModule.
        gm = original_gm if isinstance(original_gm, torch.fx.GraphModule) else None
        if gm is None and isinstance(runtime_obj, torch.fx.GraphModule):
            gm = runtime_obj

        if gm is not None:
            sg = SubgraphInfo(name=sub_name, is_splitting_graph=True)
            Introspector._fill_subgraph_gm_info(sg, gm)
            return sg

        return None

    @staticmethod
    def _unwrap_piecewise_backend(obj):
        """Find a PiecewiseBackend from obj, unwrapping closures/wrappers if needed.

        Detection is duck-typed: the pair (graph, compiled_graph_for_general_shape)
        identifies a PiecewiseBackend. Returns ``None`` when not found.
        """
        if hasattr(obj, "graph") and hasattr(obj, "compiled_graph_for_general_shape"):
            return obj

        # Cudagraph wrappers hold the backend one level deep in the closure.
        if callable(obj) and hasattr(obj, "__closure__") and obj.__closure__:
            for cell in obj.__closure__:
                try:
                    val = cell.cell_contents
                except ValueError:  # empty cell
                    continue
                if hasattr(val, "graph") and hasattr(val, "compiled_graph_for_general_shape"):
                    return val
        return None

    @staticmethod
    def _fill_subgraph_gm_info(sg: SubgraphInfo, gm) -> None:
        """Populate a SubgraphInfo's code fields from a GraphModule (best-effort)."""
        try:
            sg.readable_code = gm.print_readable(print_output=False)
        except Exception:
            pass
        try:
            sg.graph_module_code = str(gm.code) if hasattr(gm, "code") else None
        except Exception:
            pass
        try:
            buf = io.StringIO()
            gm.graph.print_tabular(file=buf)
            sg.fx_graph_tabular = buf.getvalue()
        except Exception:
            pass

    @staticmethod
    def _try_extract_inductor_source(compiled) -> Optional[str]:
        """Try to extract inductor kernel source from a compiled graph object.

        Handles CompiledFxGraph, CompiledArtifact, and closure-wrapped variants.
        Returns ``None`` if no representation can be produced.
        """
        # Direct attributes first — cheapest path.
        for attr in ("source_code", "_source_code"):
            val = getattr(compiled, attr, None)
            if isinstance(val, str) and val:
                return val

        # Next, search the closure chain for a CompiledFxGraph.
        cfx = Introspector._find_compiled_fx_graph(compiled)
        if cfx is not None:
            try:
                return cfx.source_code
            except Exception:
                pass

        # Last resort: a GraphModule-like object can at least print itself.
        if hasattr(compiled, "print_readable"):
            try:
                return compiled.print_readable(print_output=False)
            except Exception:
                pass

        return None

    @staticmethod
    def _read_artifact_source_from_piecewise(piecewise) -> Optional[str]:
        """Read Inductor-generated source from the saved artifact directory.

        PiecewiseBackend stores a compiler_manager whose cache maps
        CacheEntry(runtime_shape, graph_index, backend_name) -> CacheHandle(key, path).
        The artifact at CacheHandle.path is an unpacked directory containing
        ``py/*.py`` — the full Inductor output code.

        Returns ``None`` on any failure; this is a best-effort fallback.
        """
        try:
            compiler_manager = getattr(piecewise, "compiler_manager", None)
            if compiler_manager is None:
                return None
            cache = getattr(compiler_manager, "cache", None)
            if not cache:
                return None
            index = getattr(piecewise, "piecewise_compile_index", None)
            if index is None:
                return None

            # runtime_shape is None for the general-shape compilation entry.
            for cache_entry, cache_handle in cache.items():
                if cache_entry.graph_index == index and cache_entry.runtime_shape is None:
                    artifact_path = getattr(cache_handle, "path", None)
                    if artifact_path:
                        return Introspector._read_py_from_artifact(artifact_path)
            return None
        except Exception:
            return None

    @staticmethod
    def _read_py_from_artifact(artifact_path: str) -> Optional[str]:
        """Read the Inductor-generated Python wrapper from an artifact directory.

        The unpacked artifact layout varies across PyTorch versions; the
        wrapper ``.py`` file has been observed under ``yb/`` and ``py/``.
        We try known directories first, then fall back to scanning all
        immediate subdirectories.
        """
        root = Path(artifact_path)
        if not root.is_dir():
            return None

        # Known layouts first.
        for candidate_dir in ("yb", "py"):
            d = root / candidate_dir
            if d.is_dir():
                py_files = sorted(d.glob("*.py"))
                if py_files:
                    try:
                        return py_files[0].read_text(encoding="utf-8")
                    except Exception:
                        pass

        # Unknown layout: scan every immediate subdirectory.
        for d in sorted(root.iterdir()):
            if d.is_dir():
                py_files = sorted(d.glob("*.py"))
                if py_files:
                    try:
                        return py_files[0].read_text(encoding="utf-8")
                    except Exception:
                        pass
        return None

    @staticmethod
    def _detect_cudagraph_mode(split_gm, actual_gm) -> str:
        """Detect cudagraph wrapping mode from the split graph structure.

        - FULL: split_gm itself is a cudagraph wrapper (not a GraphModule),
          with __qualname__ containing 'Athena_CUDAGraph_full'.
        - PIECEWISE: split_gm is a GraphModule, but its __dict__ submodules are
          cudagraph wrappers with __qualname__ 'Athena_CUDAGraph_piecewise'.
        - NONE: no cudagraph wrapping detected.
        """
        _CG_PREFIX = "Athena_CUDAGraph_"

        qualname = getattr(split_gm, "__qualname__", "") or ""
        if qualname.startswith(f"{_CG_PREFIX}full"):
            return "FULL"

        if actual_gm is not None:
            for key, val in actual_gm.__dict__.items():
                if not key.startswith("submod_"):
                    continue
                sub_qualname = getattr(val, "__qualname__", "") or ""
                if sub_qualname.startswith(f"{_CG_PREFIX}piecewise"):
                    return "PIECEWISE"

        return "NONE"

    @staticmethod
    def _find_graph_module_deep(obj, _depth: int = 0, _max_depth: int = 4) -> Optional[Any]:
        """Recursively walk closure chain to find a ``torch.fx.GraphModule``.

        This is needed for FULL cudagraph mode where the split GraphModule is
        wrapped by ``gen_wrap_func_for_cudagraph`` (+ ``@instrument_nvtx``),
        placing the GraphModule 2-3 levels deep in the closure chain.
        """
        import torch.fx

        if isinstance(obj, torch.fx.GraphModule):
            return obj
        if _depth >= _max_depth:
            return None
        if not callable(obj) or not getattr(obj, "__closure__", None):
            return None
        for cell in obj.__closure__:
            try:
                val = cell.cell_contents
            except ValueError:  # empty cell
                continue
            if isinstance(val, torch.fx.GraphModule):
                return val
            if callable(val):
                found = Introspector._find_graph_module_deep(val, _depth + 1, _max_depth)
                if found is not None:
                    return found
        return None

    @staticmethod
    def _find_graph_module(obj) -> Optional[Any]:
        """Walk closure chain (one level) to find a torch.fx.GraphModule.

        Also handles bound methods: a GraphModule.forward bound method is
        resolved through ``__self__``.
        """
        import torch.fx

        if isinstance(obj, torch.fx.GraphModule):
            return obj
        if hasattr(obj, "__self__") and isinstance(obj.__self__, torch.fx.GraphModule):
            return obj.__self__
        if not callable(obj) or not hasattr(obj, "__closure__") or not obj.__closure__:
            return None
        for cell in obj.__closure__:
            try:
                val = cell.cell_contents
            except ValueError:  # empty cell
                continue
            if isinstance(val, torch.fx.GraphModule):
                return val
            if hasattr(val, "__self__") and isinstance(val.__self__, torch.fx.GraphModule):
                return val.__self__
        return None

    @staticmethod
    def _find_compiled_fx_graph(obj, _depth: int = 0) -> Optional[Any]:
        """Walk closure chain (up to 4 levels) to find a CompiledFxGraph."""
        if _depth > 4:
            return None
        try:
            from torch._inductor.codecache import CompiledFxGraph
        except ImportError:
            return None
        if isinstance(obj, CompiledFxGraph):
            return obj
        if not callable(obj) or not hasattr(obj, "__closure__") or not obj.__closure__:
            return None
        for cell in obj.__closure__:
            try:
                val = cell.cell_contents
            except ValueError:  # empty cell
                continue
            if isinstance(val, CompiledFxGraph):
                return val
            if callable(val):
                found = Introspector._find_compiled_fx_graph(val, _depth + 1)
                if found is not None:
                    return found
        return None

    @staticmethod
    def build_entry_info(entry, index: int, fn_globals: dict) -> EntryInfo:
        """Build an EntryInfo from a CacheEntry.

        Decompiles the Dynamo-transformed code, resolves the
        ``__compiled_fn_*`` and ``__resume_*`` names it references against
        *fn_globals*, and extracts the entry's guard tree.
        """
        tc = entry.code
        decompiled = safe_decompile(tc)

        compiled_names = [n for n in tc.co_names if n.startswith("__compiled")]
        compiled_fns = []
        for cn in compiled_names:
            cf = Introspector.extract_compiled_fn_info(cn, fn_globals)
            if cf:
                compiled_fns.append(cf)

        # Resume functions are introspected recursively via
        # build_function_info (they have their own CacheEntry chains).
        resume_names = [n for n in tc.co_names if n.startswith("__resume")]
        resume_fns = []
        for rn in resume_names:
            rfn = fn_globals.get(rn)
            if rfn is not None and hasattr(rfn, "__code__"):
                resume_info = Introspector.build_function_info(rfn, fn_globals=fn_globals)
                resume_info.name = rn
                resume_fns.append(resume_info)

        guard = Introspector.extract_guard_info(entry)

        return EntryInfo(
            index=index,
            dynamo_code=tc,
            decompiled_source=decompiled,
            guard=guard,
            compiled_fns=compiled_fns,
            resume_fns=resume_fns,
        )

    @staticmethod
    def build_function_info(fn, fn_globals: Optional[dict] = None) -> FunctionInfo:
        """Build full FunctionInfo by walking CacheEntry chain.

        Args:
            fn: A function or bare code object.
            fn_globals: Globals to resolve compiled/resume names in;
                defaults to ``fn.__globals__`` when available.
        """
        if fn_globals is None:
            fn_globals = fn.__globals__ if hasattr(fn, "__globals__") else {}

        code = fn.__code__ if hasattr(fn, "__code__") else fn
        name = code.co_name
        original_source = safe_decompile(code)

        entries_raw = Introspector.get_cache_entries(fn)
        entries = []
        for i, raw_entry in enumerate(entries_raw):
            entries.append(Introspector.build_entry_info(raw_entry, i, fn_globals))

        return FunctionInfo(name=name, original_code=code, original_source=original_source, entries=entries)
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/inspect/model.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Data model for structured compilation output.
|
| 16 |
+
|
| 17 |
+
These dataclasses represent the full compilation state that
|
| 18 |
+
torch.compile produces, organized to reflect the actual runtime
|
| 19 |
+
structure: CacheEntry linked list, fn/resume recursion,
|
| 20 |
+
compiled_fn → backend mapping, and guard trees.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
|
| 25 |
+
import dataclasses
|
| 26 |
+
import dis
|
| 27 |
+
import inspect
|
| 28 |
+
import io
|
| 29 |
+
from types import CodeType
|
| 30 |
+
from typing import Dict, List, Optional
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def format_code_info(code: CodeType) -> str:
|
| 34 |
+
"""Format key attributes of a CodeType for debugging."""
|
| 35 |
+
lines: List[str] = []
|
| 36 |
+
lines.append(f"co_name: {code.co_name}")
|
| 37 |
+
if hasattr(code, "co_qualname"):
|
| 38 |
+
lines.append(f"co_qualname: {code.co_qualname}")
|
| 39 |
+
lines.append(f"co_filename: {code.co_filename}")
|
| 40 |
+
lines.append(f"co_firstlineno: {code.co_firstlineno}")
|
| 41 |
+
lines.append(f"co_argcount: {code.co_argcount}")
|
| 42 |
+
lines.append(f"co_kwonlyargcount:{code.co_kwonlyargcount}")
|
| 43 |
+
lines.append(f"co_varnames: {code.co_varnames}")
|
| 44 |
+
lines.append(f"co_freevars: {code.co_freevars}")
|
| 45 |
+
lines.append(f"co_cellvars: {code.co_cellvars}")
|
| 46 |
+
lines.append(f"co_names: {code.co_names}")
|
| 47 |
+
flags = code.co_flags
|
| 48 |
+
flag_strs = [name for name, val in _CODE_FLAGS.items() if flags & val]
|
| 49 |
+
lines.append(f"co_flags: 0x{flags:04x} ({' | '.join(flag_strs) if flag_strs else 'none'})")
|
| 50 |
+
lines.append(f"co_stacksize: {code.co_stacksize}")
|
| 51 |
+
lines.append("")
|
| 52 |
+
lines.append("co_consts:")
|
| 53 |
+
for i, c in enumerate(code.co_consts):
|
| 54 |
+
lines.append(f" [{i:3d}] {type(c).__name__:12s} {_safe_repr(c)}")
|
| 55 |
+
lines.append("")
|
| 56 |
+
lines.append("dis:")
|
| 57 |
+
buf = io.StringIO()
|
| 58 |
+
dis.dis(code, file=buf)
|
| 59 |
+
lines.append(buf.getvalue())
|
| 60 |
+
return "\n".join(lines)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
_CODE_FLAGS = {
|
| 64 |
+
"CO_OPTIMIZED": inspect.CO_OPTIMIZED,
|
| 65 |
+
"CO_NEWLOCALS": inspect.CO_NEWLOCALS,
|
| 66 |
+
"CO_VARARGS": inspect.CO_VARARGS,
|
| 67 |
+
"CO_VARKEYWORDS": inspect.CO_VARKEYWORDS,
|
| 68 |
+
"CO_NESTED": inspect.CO_NESTED,
|
| 69 |
+
"CO_GENERATOR": inspect.CO_GENERATOR,
|
| 70 |
+
"CO_COROUTINE": inspect.CO_COROUTINE,
|
| 71 |
+
"CO_ASYNC_GENERATOR": inspect.CO_ASYNC_GENERATOR,
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _safe_repr(obj, max_len: int = 120) -> str:
|
| 76 |
+
try:
|
| 77 |
+
r = repr(obj)
|
| 78 |
+
except Exception:
|
| 79 |
+
r = f"<repr failed: {type(obj).__name__}>"
|
| 80 |
+
if len(r) > max_len:
|
| 81 |
+
r = r[: max_len - 3] + "..."
|
| 82 |
+
return r
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@dataclasses.dataclass
class GuardNode:
    """One node in the guard tree (mirrors RootGuardManager / GuardManager)."""

    type_name: str
    leaf_guards: List[str]
    children: List["GuardNode"] = dataclasses.field(default_factory=list)

    def format(self, depth: int = 0, max_depth: int = 32) -> str:
        """Render this node, its leaf guards, then recurse into children.

        Children below *max_depth* are elided with a one-line summary to
        keep pathological trees bounded.
        """
        pad = " " * depth
        header = f"{pad}[{self.type_name}] " f"({len(self.leaf_guards)} leaf guards, {len(self.children)} children)"
        out = [header]
        for guard in self.leaf_guards:
            out.append(f"{pad} LEAF: {guard}")
        if depth >= max_depth and self.children:
            out.append(f"{pad} ... ({len(self.children)} children omitted)")
        elif depth < max_depth:
            for idx, node in enumerate(self.children):
                out.append(f"{pad} child[{idx}]:")
                out.append(node.format(depth + 2, max_depth))
        return "\n".join(out)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@dataclasses.dataclass
class SubgraphInfo:
    """One piecewise subgraph in the magi split pipeline."""

    name: str
    is_splitting_graph: bool = False
    readable_code: Optional[str] = None
    graph_module_code: Optional[str] = None
    fx_graph_tabular: Optional[str] = None
    inductor_code: Optional[str] = None

    def format(self) -> str:
        """Best available textual form: inductor > readable > graph-module.

        Falls back to a one-line placeholder comment when no code was
        captured for this subgraph.
        """
        for text in (self.inductor_code, self.readable_code, self.graph_module_code):
            if text:
                return text
        tag = "splitting_op" if self.is_splitting_graph else "compiled"
        return f"# {self.name} ({tag})\n"
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
@dataclasses.dataclass
class CompiledFnInfo:
    """What __compiled_fn_xxx actually points to in the backend."""

    name: str
    backend: str  # "eager", "inductor", or "magi_compile"
    cudagraph_mode: Optional[str] = None  # "NONE", "PIECEWISE", "FULL" (magi_compile only)
    readable_code: Optional[str] = None
    graph_module_code: Optional[str] = None
    fx_graph_tabular: Optional[str] = None
    source_code: Optional[str] = None
    inductor_post_grad_graph: Optional[str] = None
    runnable_graph_str: Optional[str] = None
    cache_key: Optional[str] = None
    split_graph_readable: Optional[str] = None
    subgraph_infos: List["SubgraphInfo"] = dataclasses.field(default_factory=list)

    def format(self) -> str:
        """Full content for writing to file (compiled output)."""
        # Prefer the richest representation that was captured.
        for text in (self.source_code, self.readable_code, self.graph_module_code):
            if text:
                return text
        return f"# {self.name} (backend={self.backend})\n"

    def format_summary(self) -> str:
        """Short summary for overview / full_code."""
        mode = f", cudagraph={self.cudagraph_mode}" if self.cudagraph_mode else ""
        out = [f"{self.name} (backend={self.backend}{mode})"]
        if self.cache_key:
            out.append(f" cache_key: {self.cache_key}")
        if self.graph_module_code:
            out.append(" GraphModule.code:")
            out.extend(f" {line}" for line in self.graph_module_code.strip().splitlines())
        if self.subgraph_infos:
            out.append(f" piecewise subgraphs: {len(self.subgraph_infos)}")
            for sub in self.subgraph_infos:
                kind = "splitting_op" if sub.is_splitting_graph else "compiled"
                out.append(f" {sub.name} ({kind})")
        return "\n".join(out)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
@dataclasses.dataclass
class GuardInfo:
    """Guard information for a CacheEntry."""

    tree: Optional[GuardNode] = None
    closure_vars: Optional[Dict[str, str]] = None

    def format(self) -> str:
        """Render the guard tree followed by (up to 8) closure variables."""
        out: List[str] = []
        if self.tree:
            out.append(self.tree.format())
        if self.closure_vars:
            out.append(" closure_vars:")
            # Cap at 8 entries to keep the summary short.
            for name, value in list(self.closure_vars.items())[:8]:
                out.append(f" {name} = {value}")
        return "\n".join(out)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
@dataclasses.dataclass
class EntryInfo:
    """One CacheEntry in the linked list."""

    index: int
    dynamo_code: Optional[CodeType] = None
    decompiled_source: str = ""
    guard: Optional[GuardInfo] = None
    compiled_fns: List[CompiledFnInfo] = dataclasses.field(default_factory=list)
    resume_fns: List["FunctionInfo"] = dataclasses.field(default_factory=list)

    def format(self, indent: int = 0) -> str:
        """Render this entry: decompiled code, compiled fns, guards, resumes."""
        pad = " " * indent
        out = [f"{pad}entry[{self.index}]:"]
        if self.decompiled_source:
            out.append(f"{pad} dynamo_code (decompiled):")
            out.extend(f"{pad} {line}" for line in self.decompiled_source.splitlines())
        if self.compiled_fns:
            out.append(f"{pad} compiled functions:")
            out.extend(fn.format_summary() for fn in self.compiled_fns)
        if self.guard:
            out.append(f"{pad} guards:")
            out.append(self.guard.format())
        if self.resume_fns:
            out.append(f"{pad} resume functions:")
            out.extend(fn.format(indent + 2) for fn in self.resume_fns)
        return "\n".join(out)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
@dataclasses.dataclass
class FunctionInfo:
    """A compiled function and its CacheEntry chain."""

    name: str
    original_code: Optional[CodeType] = None
    original_source: str = ""
    entries: List[EntryInfo] = dataclasses.field(default_factory=list)

    def format(self, indent: int = 0) -> str:
        """One header line plus each cache entry, indented one level deeper."""
        pad = " " * indent
        out = [f"{pad}{self.name}: {len(self.entries)} cache entries"]
        out.extend(entry.format(indent + 1) for entry in self.entries)
        return "\n".join(out)
|
pkgs/MagiCompiler/magi_compiler/magi_depyf/inspect/result.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""CaptureResult — structured data model for one compilation event."""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import dataclasses
|
| 20 |
+
import time
|
| 21 |
+
from types import CodeType
|
| 22 |
+
from typing import List, Optional
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclasses.dataclass
class CaptureResult:
    """Data captured from a single ``torch.compile`` bytecode event.

    - original_code: the user's original function code
    - dynamo_code: the code after Dynamo transformation (with __compiled_fn / __resume calls)
    - decompiled_source: dynamo_code decompiled back to Python source
    - fn_globals: the function's global namespace (for post-hoc introspection)
    """

    function_name: str
    original_code: CodeType
    dynamo_code: CodeType
    decompiled_source: str
    fn_globals: Optional[dict] = None
    guards: List[str] = dataclasses.field(default_factory=list)
    graph_source: Optional[str] = None
    timestamp: float = dataclasses.field(default_factory=time.time)

    def summary(self) -> str:
        """One-line overview: function, original code name, guard count, graph presence."""
        has_graph = "yes" if self.graph_source else "no"
        return (
            f"[{self.function_name}] "
            f"original={self.original_code.co_name}, "
            f"guards={len(self.guards)}, "
            f"graph={has_graph}"
        )
|