jiadisu Claude Opus 4.6 committed on
Commit
e6066e8
·
1 Parent(s): 1b389ac

Switch back to Docker SDK with local pkgs

Browse files

- Dockerfile: CUDA 12.4 base image, install MagiCompiler from pkgs/,
flash-attn, and stable-audio whl
- README.md: sdk: docker
- app.py: remove spaces.GPU decorator
- pkgs/: MagiCompiler source + stable-audio whl

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. Dockerfile +30 -18
  3. README.md +1 -2
  4. app.py +1 -4
  5. pkgs/MagiCompiler/.gitignore +216 -0
  6. pkgs/MagiCompiler/.pre-commit-config.yaml +60 -0
  7. pkgs/MagiCompiler/LICENSE +201 -0
  8. pkgs/MagiCompiler/README.md +186 -0
  9. pkgs/MagiCompiler/docs/AutoCudaGraphDesign.md +174 -0
  10. pkgs/MagiCompiler/docs/Hunyuan15Benchmark.md +79 -0
  11. pkgs/MagiCompiler/docs/Wan2.2Benchmark.md +72 -0
  12. pkgs/MagiCompiler/docs/WhyMagiCompiler.md +246 -0
  13. pkgs/MagiCompiler/docs/WhyMagiDepyf.md +175 -0
  14. pkgs/MagiCompiler/docs/assets/submod_0_rank_0.pdf +3 -0
  15. pkgs/MagiCompiler/magi_compiler/__init__.py +17 -0
  16. pkgs/MagiCompiler/magi_compiler/_cache_data_cls.py +28 -0
  17. pkgs/MagiCompiler/magi_compiler/api.py +666 -0
  18. pkgs/MagiCompiler/magi_compiler/compile_artifacts.py +125 -0
  19. pkgs/MagiCompiler/magi_compiler/config.py +282 -0
  20. pkgs/MagiCompiler/magi_compiler/cuda/cudart.py +60 -0
  21. pkgs/MagiCompiler/magi_compiler/cuda_graph_mgr.py +931 -0
  22. pkgs/MagiCompiler/magi_compiler/joint_graph_partition.py +180 -0
  23. pkgs/MagiCompiler/magi_compiler/magi_backend.py +607 -0
  24. pkgs/MagiCompiler/magi_compiler/magi_compiler_base.py +219 -0
  25. pkgs/MagiCompiler/magi_compiler/magi_depyf/__init__.py +21 -0
  26. pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/__init__.py +19 -0
  27. pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/__init__.py +22 -0
  28. pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/decompile_context.py +53 -0
  29. pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handler_registry.py +62 -0
  30. pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/__init__.py +22 -0
  31. pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/arithmetic.py +144 -0
  32. pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/calls.py +200 -0
  33. pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/containers.py +200 -0
  34. pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/control_flow.py +273 -0
  35. pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/load_store.py +262 -0
  36. pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/stack_ops.py +84 -0
  37. pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/instruction.py +129 -0
  38. pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/source_emitter.py +153 -0
  39. pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/decompiler.py +230 -0
  40. pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/postprocess/__init__.py +35 -0
  41. pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/postprocess/branch_dedup.py +99 -0
  42. pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/postprocess/for_temps.py +57 -0
  43. pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/postprocess/inline_temps.py +165 -0
  44. pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/recompiler.py +53 -0
  45. pkgs/MagiCompiler/magi_compiler/magi_depyf/demo_toy_example.py +54 -0
  46. pkgs/MagiCompiler/magi_compiler/magi_depyf/inspect/__init__.py +57 -0
  47. pkgs/MagiCompiler/magi_compiler/magi_depyf/inspect/dump_src.py +78 -0
  48. pkgs/MagiCompiler/magi_compiler/magi_depyf/inspect/introspect.py +524 -0
  49. pkgs/MagiCompiler/magi_compiler/magi_depyf/inspect/model.py +241 -0
  50. pkgs/MagiCompiler/magi_compiler/magi_depyf/inspect/result.py +51 -0
.gitattributes CHANGED
@@ -4,3 +4,5 @@
4
  *.png filter=lfs diff=lfs merge=lfs -text
5
  *.jpg filter=lfs diff=lfs merge=lfs -text
6
  *.jpeg filter=lfs diff=lfs merge=lfs -text
 
 
 
4
  *.png filter=lfs diff=lfs merge=lfs -text
5
  *.jpg filter=lfs diff=lfs merge=lfs -text
6
  *.jpeg filter=lfs diff=lfs merge=lfs -text
7
+ *.pdf filter=lfs diff=lfs merge=lfs -text
8
+ *.whl filter=lfs diff=lfs merge=lfs -text
Dockerfile CHANGED
@@ -1,28 +1,46 @@
1
  # =============================================================================
2
  # HF Spaces Docker image for daVinci-MagiHuman
3
- # Hardware: A100-80GB (or H100)
4
  # =============================================================================
5
- # Based on the official MagiCompiler image which includes:
6
- # - CUDA 12.4, cuDNN, Python 3.12, PyTorch 2.9
7
- # - MagiCompiler (pre-installed)
8
- # - Flash Attention 3 (Hopper) (pre-installed)
9
- # =============================================================================
10
- FROM sandai/magi-compiler:latest
11
 
12
  ENV DEBIAN_FRONTEND=noninteractive
13
  ENV PYTHONUNBUFFERED=1
14
  ENV GRADIO_SERVER_NAME=0.0.0.0
15
  ENV GRADIO_SERVER_PORT=7860
16
 
17
- # System deps needed for audio/video processing
18
  RUN apt-get update && apt-get install -y --no-install-recommends \
19
- ffmpeg libsndfile1 && \
20
- rm -rf /var/lib/apt/lists/*
 
 
 
21
 
22
  WORKDIR /app
23
 
24
  # ---------------------------------------------------------------------------
25
- # Python dependencies
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  # ---------------------------------------------------------------------------
27
  COPY requirements.txt requirements-nodeps.txt ./
28
  RUN pip install --no-cache-dir -r requirements.txt && \
@@ -36,16 +54,10 @@ COPY inference/ inference/
36
  COPY example/ example/
37
  COPY app.py .
38
 
39
- # ---------------------------------------------------------------------------
40
  # Model weights are downloaded at runtime from HF Hub.
41
- # Set HF_TOKEN as a Space secret if any repos are gated/private.
42
- #
43
- # Persistent storage (/data) is recommended on HF Spaces so weights survive
44
- # container restarts. Enable it in Space settings → "Persistent storage".
45
- # ---------------------------------------------------------------------------
46
  ENV MODEL_ROOT=/data/models
47
 
48
- # HF Spaces requires the app to listen on port 7860
49
  EXPOSE 7860
50
 
51
  CMD ["python", "app.py"]
 
1
  # =============================================================================
2
  # HF Spaces Docker image for daVinci-MagiHuman
3
+ # Hardware: A100-80GB (recommended)
4
  # =============================================================================
5
+ FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
 
 
 
 
 
6
 
7
  ENV DEBIAN_FRONTEND=noninteractive
8
  ENV PYTHONUNBUFFERED=1
9
  ENV GRADIO_SERVER_NAME=0.0.0.0
10
  ENV GRADIO_SERVER_PORT=7860
11
 
12
+ # System deps
13
  RUN apt-get update && apt-get install -y --no-install-recommends \
14
+ python3.12 python3.12-dev python3.12-venv python3-pip \
15
+ git ffmpeg libsndfile1 ninja-build && \
16
+ rm -rf /var/lib/apt/lists/* && \
17
+ ln -sf /usr/bin/python3.12 /usr/bin/python && \
18
+ ln -sf /usr/bin/python3.12 /usr/bin/python3
19
 
20
  WORKDIR /app
21
 
22
  # ---------------------------------------------------------------------------
23
+ # PyTorch (must be installed first — MagiCompiler build depends on it)
24
+ # ---------------------------------------------------------------------------
25
+ RUN pip install --no-cache-dir --upgrade pip && \
26
+ pip install --no-cache-dir torch torchvision torchaudio \
27
+ --index-url https://download.pytorch.org/whl/cu124
28
+
29
+ # ---------------------------------------------------------------------------
30
+ # Local packages: MagiCompiler + stable-audio whl
31
+ # ---------------------------------------------------------------------------
32
+ COPY pkgs/ pkgs/
33
+ RUN pip install -e ./pkgs/MagiCompiler \
34
+ --no-build-isolation --config-settings editable_mode=compat && \
35
+ pip install --no-cache-dir pkgs/magife_stable_audio_open-1.0.0+mav.1-py3-none-any.whl
36
+
37
+ # ---------------------------------------------------------------------------
38
+ # Flash Attention (pre-built wheel for CUDA 12.4 + PyTorch 2.9)
39
+ # ---------------------------------------------------------------------------
40
+ RUN pip install --no-cache-dir flash-attn --no-build-isolation
41
+
42
+ # ---------------------------------------------------------------------------
43
+ # Project Python dependencies
44
  # ---------------------------------------------------------------------------
45
  COPY requirements.txt requirements-nodeps.txt ./
46
  RUN pip install --no-cache-dir -r requirements.txt && \
 
54
  COPY example/ example/
55
  COPY app.py .
56
 
 
57
  # Model weights are downloaded at runtime from HF Hub.
58
+ # Enable "Persistent storage" in Space settings so /data survives restarts.
 
 
 
 
59
  ENV MODEL_ROOT=/data/models
60
 
 
61
  EXPOSE 7860
62
 
63
  CMD ["python", "app.py"]
README.md CHANGED
@@ -3,8 +3,7 @@ title: daVinci-MagiHuman
3
  emoji: 🎬
4
  colorFrom: blue
5
  colorTo: purple
6
- sdk: gradio
7
- sdk_version: 5.23.0
8
  app_port: 7860
9
  ---
10
 
 
3
  emoji: 🎬
4
  colorFrom: blue
5
  colorTo: purple
6
+ sdk: docker
 
7
  app_port: 7860
8
  ---
9
 
app.py CHANGED
@@ -2,7 +2,7 @@
2
  """
3
  Gradio frontend for daVinci-MagiHuman distilled model.
4
 
5
- Designed for Hugging Face Spaces with ZeroGPU (Gradio SDK).
6
  Accepts an image + text prompt + duration, generates audio-video output.
7
  """
8
 
@@ -12,8 +12,6 @@ import sys
12
  import tempfile
13
  import uuid
14
 
15
- import spaces
16
-
17
  # ---------------------------------------------------------------------------
18
  # 1. Download all model weights from HF Hub (runs on CPU, cached)
19
  # ---------------------------------------------------------------------------
@@ -132,7 +130,6 @@ print("[app] Pipeline ready.")
132
  # 4. Inference wrapper — @spaces.GPU requests a ZeroGPU allocation
133
  # duration= sets the max GPU time in seconds (default 60, max 300)
134
  # ---------------------------------------------------------------------------
135
- @spaces.GPU(duration=300)
136
  def generate_video(
137
  image,
138
  prompt: str,
 
2
  """
3
  Gradio frontend for daVinci-MagiHuman distilled model.
4
 
5
+ Designed for Hugging Face Spaces (Docker SDK, A100-80GB GPU).
6
  Accepts an image + text prompt + duration, generates audio-video output.
7
  """
8
 
 
12
  import tempfile
13
  import uuid
14
 
 
 
15
  # ---------------------------------------------------------------------------
16
  # 1. Download all model weights from HF Hub (runs on CPU, cached)
17
  # ---------------------------------------------------------------------------
 
130
  # 4. Inference wrapper — @spaces.GPU requests a ZeroGPU allocation
131
  # duration= sets the max GPU time in seconds (default 60, max 300)
132
  # ---------------------------------------------------------------------------
 
133
  def generate_video(
134
  image,
135
  prompt: str,
pkgs/MagiCompiler/.gitignore ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # magi_compiler
2
+ magi_compiler/_version.py
3
+ magi_dump_src_dir/
4
+ *.nsys-rep
5
+ *.ncu-rep
6
+
7
+ # Byte-compiled / optimized / DLL files
8
+ __pycache__/
9
+ *.py[codz]
10
+ *$py.class
11
+
12
+ # C extensions
13
+ *.so
14
+
15
+ # Distribution / packaging
16
+ .Python
17
+ build/
18
+ develop-eggs/
19
+ dist/
20
+ downloads/
21
+ eggs/
22
+ .eggs/
23
+ lib/
24
+ lib64/
25
+ parts/
26
+ sdist/
27
+ var/
28
+ wheels/
29
+ share/python-wheels/
30
+ *.egg-info/
31
+ .installed.cfg
32
+ *.egg
33
+ MANIFEST
34
+
35
+ # PyInstaller
36
+ # Usually these files are written by a python script from a template
37
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
38
+ *.manifest
39
+ *.spec
40
+
41
+ # Installer logs
42
+ pip-log.txt
43
+ pip-delete-this-directory.txt
44
+
45
+ # Unit test / coverage reports
46
+ htmlcov/
47
+ .tox/
48
+ .nox/
49
+ .coverage
50
+ .coverage.*
51
+ .cache
52
+ nosetests.xml
53
+ coverage.xml
54
+ *.cover
55
+ *.py.cover
56
+ .hypothesis/
57
+ .pytest_cache/
58
+ cover/
59
+
60
+ # Translations
61
+ *.mo
62
+ *.pot
63
+
64
+ # Vscode stuff:
65
+ .vscode
66
+
67
+ # Django stuff:
68
+ *.log
69
+ local_settings.py
70
+ db.sqlite3
71
+ db.sqlite3-journal
72
+
73
+ # Flask stuff:
74
+ instance/
75
+ .webassets-cache
76
+
77
+ # Scrapy stuff:
78
+ .scrapy
79
+
80
+ # Sphinx documentation
81
+ docs/_build/
82
+
83
+ # PyBuilder
84
+ .pybuilder/
85
+ target/
86
+
87
+ # Jupyter Notebook
88
+ .ipynb_checkpoints
89
+
90
+ # IPython
91
+ profile_default/
92
+ ipython_config.py
93
+
94
+ # pyenv
95
+ # For a library or package, you might want to ignore these files since the code is
96
+ # intended to run in multiple environments; otherwise, check them in:
97
+ # .python-version
98
+
99
+ # pipenv
100
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
101
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
102
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
103
+ # install all needed dependencies.
104
+ #Pipfile.lock
105
+
106
+ # UV
107
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
108
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
109
+ # commonly ignored for libraries.
110
+ #uv.lock
111
+
112
+ # poetry
113
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
114
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
115
+ # commonly ignored for libraries.
116
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
117
+ #poetry.lock
118
+ #poetry.toml
119
+
120
+ # pdm
121
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
122
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
123
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
124
+ #pdm.lock
125
+ pdm.toml
126
+ .pdm-python
127
+ .pdm-build/
128
+
129
+ # pixi
130
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
131
+ #pixi.lock
132
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
133
+ # in the .venv directory. It is recommended not to include this directory in version control.
134
+ .pixi
135
+
136
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
137
+ __pypackages__/
138
+
139
+ # Celery stuff
140
+ celerybeat-schedule
141
+ celerybeat.pid
142
+
143
+ # SageMath parsed files
144
+ *.sage.py
145
+
146
+ # Environments
147
+ .env
148
+ .envrc
149
+ .venv
150
+ env/
151
+ venv/
152
+ ENV/
153
+ env.bak/
154
+ venv.bak/
155
+
156
+ # Spyder project settings
157
+ .spyderproject
158
+ .spyproject
159
+
160
+ # Rope project settings
161
+ .ropeproject
162
+
163
+ # mkdocs documentation
164
+ /site
165
+
166
+ # mypy
167
+ .mypy_cache/
168
+ .dmypy.json
169
+ dmypy.json
170
+
171
+ # Pyre type checker
172
+ .pyre/
173
+
174
+ # pytype static type analyzer
175
+ .pytype/
176
+
177
+ # Cython debug symbols
178
+ cython_debug/
179
+
180
+ # PyCharm
181
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
182
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
183
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
184
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
185
+ #.idea/
186
+
187
+ # Abstra
188
+ # Abstra is an AI-powered process automation framework.
189
+ # Ignore directories containing user credentials, local state, and settings.
190
+ # Learn more at https://abstra.io/docs
191
+ .abstra/
192
+
193
+ # Visual Studio Code
194
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
195
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
196
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
197
+ # you could uncomment the following to ignore the entire vscode folder
198
+ # .vscode/
199
+
200
+ # Ruff stuff:
201
+ .ruff_cache/
202
+
203
+ # PyPI configuration file
204
+ .pypirc
205
+
206
+ # Cursor
207
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
208
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
209
+ # refer to https://docs.cursor.com/context/ignore-files
210
+ .cursorignore
211
+ .cursorindexingignore
212
+
213
+ # Marimo
214
+ marimo/_static/
215
+ marimo/_lsp/
216
+ __marimo__/
pkgs/MagiCompiler/.pre-commit-config.yaml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: \.patch$
2
+ repos:
3
+ - repo: local
4
+ hooks:
5
+ - id: copyright_checker
6
+ name: copyright_checker
7
+ entry: python3 ./.github/.codestyle/copyright.hook
8
+ language: system
9
+ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py|sh)$
10
+ - repo: https://github.com/pre-commit/pre-commit-hooks
11
+ rev: v4.4.0
12
+ hooks:
13
+ - id: check-added-large-files
14
+ args:
15
+ - --maxkb=30720
16
+ - id: check-merge-conflict
17
+ - id: check-symlinks
18
+ - id: detect-private-key
19
+ files: (?!.*third_party)^.*$ | (?!.*book)^.*$
20
+ - id: end-of-file-fixer
21
+ - id: trailing-whitespace
22
+ - id: requirements-txt-fixer
23
+ - id: sort-simple-yaml
24
+ - repo: https://github.com/Lucas-C/pre-commit-hooks.git
25
+ rev: v1.5.1
26
+ hooks:
27
+ - id: remove-crlf
28
+ files: (?!.*third_party)^.*$ | (?!.*book)^.*$
29
+ - id: remove-tabs
30
+ name: Tabs remover (C++)
31
+ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|xpu|kps)$
32
+ args: [--whitespaces-count, '2']
33
+ - id: remove-tabs
34
+ name: Tabs remover (Python)
35
+ files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
36
+ args: [--whitespaces-count, '4']
37
+ - repo: https://github.com/psf/black.git
38
+ rev: 23.3.0
39
+ hooks:
40
+ - id: black
41
+ args: [--line-length=127, --skip-string-normalization, --skip-magic-trailing-comma]
42
+ files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
43
+ - repo: https://github.com/pre-commit/mirrors-isort
44
+ rev: v5.10.1
45
+ hooks:
46
+ - id: isort
47
+ args: [--profile=black, --line-length=127, --multi-line=3, --force-grid-wrap=0]
48
+ files: \.py$
49
+ - repo: https://github.com/PyCQA/autoflake
50
+ rev: v2.3.1
51
+ hooks:
52
+ - id: autoflake
53
+ args: [--remove-all-unused-imports, --remove-unused-variables, --in-place, --ignore-init-module-imports, --ignore-pass-after-docstring]
54
+ files: \.py$
55
+ - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks.git
56
+ rev: v2.9.0
57
+ hooks:
58
+ - id: pretty-format-yaml
59
+ args: [--autofix, --indent, '4']
60
+ additional_dependencies: [setuptools]
pkgs/MagiCompiler/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
pkgs/MagiCompiler/README.md ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## MagiCompiler
2
+
3
+ An engineering-oriented compiler and execution augmentation library for PyTorch 2.8+, providing module-level compilation decorators, backend adapters, graph partitioning strategies, readable and reusable compile artifacts, and tightly integrated runtime scheduling for any inference engine. The design goal is to systematically expose capabilities of PyTorch Dynamo / AOTAutograd / Inductor / Triton while prioritizing correctness, stability, observability, and maintainability.
4
+
5
+ ### Design Overview
6
+
7
+ - Compilation entrypoint: the `@magi_compile` decorator augments `nn.Module.forward` with compilation, including dynamic-shape annotation and argument validation.
8
+ - Partitioning and passes: configurable graph partitioning and pass management (e.g., `InductorPass`, `PostGradPassManager`) for fusion, kernel generation, and tuning.
9
+ - Artifact system: persists compile artifacts using a Python file/directory layout for readability, auditability, and portability (see “Compile Cache Overview”).
10
+ - Configuration: `CompileConfig` centralizes backend, partition rules, cache root, runtime shapes, and other key parameters.
11
+
12
+ ### Key Features
13
+
14
+ - Dynamic-shape annotations:
15
+ - Automatic inference: when a `forward` parameter is annotated as `torch.Tensor` or `torch.Tensor | None`, dimension 0 is treated as dynamic by default.
16
+ - Explicit specification: use `@magi_compile(dynamic_arg_dims={...})` to mark dimensions (negative indices supported).
17
+ - Consistency constraints: parameters that alternately appear as `None` and non-`None` across the model lifetime cannot be captured into the same computation graph.
18
+ - Backend selection and standalone compilation:
19
+ - `inductor` mode defaults to PyTorch 2.8+ `standalone_compile`, producing reusable artifacts.
20
+ - `eager` mode is available for debugging or fallback paths.
21
+ - Partitioning and passes: operator-set-driven partition rules and pass contexts that stabilize subgraph boundaries and kernel generation across runtime shapes.
22
+ - Readable, portable artifacts: structured directories with Python files for quick triage and cross-environment debugging.
23
+ - Engine integration: the decorator reads engine-level `CompileConfig` to stay aligned with distributed/scheduling components.
24
+
25
+ ## Installation and Requirements
26
+
27
+ - Python ≥ 3.10
28
+ - PyTorch ≥ 2.8 (with `torch._inductor.standalone_compile` available)
29
+ - Recommended to be used within the Athena environment, together with its dependencies and distributed components (e.g., CUDA Graph manager).
30
+
31
+ For local development, install in editable mode:
32
+
33
+ ```bash
34
+ pip install -e . --no-build-isolation --config-settings editable_mode=compat
35
+ ```
36
+
37
+ ## Quick Start
38
+
39
+ ### Minimal Example (automatic dynamic-dim inference)
40
+
41
+ ```python
42
+ import torch
43
+ from torch import nn
44
+ from magi_compiler.decorator import magi_compile
45
+
46
+ @magi_compile
47
+ class MyModel(nn.Module):
48
+ def __init__(self, *, model_config):
49
+ super().__init__()
50
+ self.linear = nn.Linear(10, 5)
51
+
52
+ def forward(self, x: torch.Tensor, y: torch.Tensor | None) -> torch.Tensor:
53
+ if y is not None:
54
+ return self.linear(x + y)
55
+ return self.linear(x)
56
+
57
+ # In Athena, model_config is typically provided by the engine
58
+ model = MyModel(model_config=...)
59
+ out1 = model(torch.randn(4, 10), torch.randn(4, 10))
60
+ out2 = model(torch.randn(8, 10), None) # dynamic batch dimension
61
+ ```
62
+
63
+ ### Explicit Dynamic-Dim Specification
64
+
65
+ ```python
66
+ @magi_compile(dynamic_arg_dims={"x": -1}) # mark the last dimension as dynamic
67
+ class DynamicDimModel(nn.Module):
68
+ def __init__(self, *, model_config):
69
+ super().__init__()
70
+ self.proj = nn.Linear(16, 16, bias=False)
71
+
72
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
73
+ return self.proj(x)
74
+
75
+ m = DynamicDimModel(model_config=...)
76
+ _ = m(torch.randn(2, 16))
77
+ _ = m(torch.randn(2, 32)) # allow the last dimension to vary
78
+ ```
79
+
80
+ ## Configuration and Modes
81
+
82
+ - `CompileConfig`: centralizes compile parameters (backend, cache paths, partition strategy, dynamic shapes, traced files, etc.).
83
+ - `CompileMode`: typical setting is `CompileMode.TORCH_COMPILE`.
84
+ - Backends:
85
+ - `inductor`: uses `standalone_compile` to produce reusable artifacts, ideal for production deployments.
86
+ - `eager`: convenient for rapid debugging or as a fallback.
87
+
88
+ ## Architecture and Execution Flow (Brief)
89
+
90
+ 1. `@magi_compile` wraps `nn.Module`:
91
+ - infers/validates `dynamic_arg_dims`;
92
+ - extends MRO by injecting `MagiCompilerBase`;
93
+ - reads engine-level `CompileConfig` in MagiCompiler.
94
+ 2. `CompilerManager`:
95
+ - defines cache keys using `(runtime_shape, graph_index, backend)`;
96
+ - dispatches to backends via `CompilerInterface` (`InductorStandaloneAdaptor` or `EagerAdaptor`);
97
+ - applies partition rules and pass contexts within `compile_context(...)`;
98
+ - serializes compile artifacts into a human-readable directory structure.
99
+ 3. Monitoring and statistics:
100
+ - counters and timestamps report per-shape/per-subgraph latencies and milestones.
101
+
102
+ ## Compile Cache Overview
103
+
104
+ This document summarizes the cache files generated by `torch.compile` (TorchDynamo + TorchInductor + AOTAutograd + Triton). Reference path: `cache/`.
105
+
106
+ ### Directory Layout (Tree)
107
+
108
+ ```text
109
+ cache/
110
+ ├─ depyf/
111
+ │ └─ rank_0/
112
+ │ ├─ __transformed_code_0_for_forward.py
113
+ │ ├─ decompiled_code.py
114
+ │ ├─ full_code_for_forward_0.py
115
+ │ ├─ __compiled_fn_1.BEFORE_PRE_GRAD.{0..N}.py
116
+ │ ├─ __compiled_fn_1.kernel_{0..K}.py
117
+ │ ├─ __compiled_fn_1.__compiled_fn_1_<uuid>.0.py
118
+ │ ├─ __compiled_fn_1.Before_split.0.py
119
+ │ ├─ __compiled_fn_1.After_split.0.py
120
+ │ ├─ __compiled_fn_1.pre_split_module.0.py
121
+ │ ├─ __compiled_fn_1.post_split_module.0.py
122
+ │ └─ __compiled_fn_1.pre_insert_deferred_runtime_asserts__<uuid>.0.py
123
+
124
+ └─ torch_compile_cache/
125
+ └─ bfa0df33ea/ # Hash for graph + compile options + device, etc.
126
+ └─ rank_0/ # Rank id in distributed/multi-GPU runs
127
+ └─ backbone/
128
+ ├─ computation_graph.py
129
+ ├─ magi_compile_cache.py
130
+ ├─ artifact_shape_None_subgraph_0/
131
+ ├─ artifact_shape_None_subgraph_1/
132
+ ├─ ...
133
+ └─ artifact_shape_None_subgraph_30/
134
+ ├─ ir/
135
+ │ └─ *.py # Python/Triton kernels generated by Inductor for this subgraph
136
+ ├─ fxgraph/
137
+ │ └─ */*/<hash> # Binary FX IR/metadata snapshots (not human-readable)
138
+ ├─ aotautograd/
139
+ │ └─ */*/<hash> # AOTAutograd partition/capture metadata and artifacts
140
+ ├─ 44/
141
+ │ └─ *.py # Other sharded/generated code buckets
142
+ └─ ... # Structure may vary slightly across subgraphs
143
+ ```
144
+
145
+ ### What Each File/Dir Is For
146
+
147
+ - `cache/depyf/` (TorchDynamo/Depyf debug exports)
148
+ - `__transformed_code_0_for_forward.py`: The Dynamo-transformed `forward` code (diff-friendly view of pre/post transformation).
149
+ - `decompiled_code.py`: Decompiled snapshot to help map traced graphs back to original Python.
150
+ - `full_code_for_forward_0.py`: A more complete expanded `forward` for inspection.
151
+ - `__compiled_fn_1.BEFORE_PRE_GRAD.{i}.py`: Intermediate wrapper snapshots at specific compile stages (e.g., before autodiff).
152
+ - `__compiled_fn_1.kernel_{k}.py`: Entrypoints/wrappers for kernels generated at various stages.
153
+ - `Before_split` / `After_split` / `pre_split_module` / `post_split_module`: Intermediate forms around graph partitioning.
154
+ - `pre_insert_deferred_runtime_asserts__*.py`: Snapshot before inserting deferred runtime assertions (dynamic shapes/guards).
155
+
156
+ - `cache/torch_compile_cache/` (TorchInductor artifacts)
157
+ - `bfa0df33ea/`: Namespace keyed by a hash of model structure, compile settings, and device info.
158
+ - `rank_0/`: Bucket per process rank for distributed runs.
159
+ - `backbone/`:
160
+ - `computation_graph.py`: Full model FX GraphModule with symbolic dims; shared across subgraph kernels.
161
+ - `magi_compile_cache.py`:
162
+ - Maps subgraph indices to artifact directories, e.g. `(None, i, 'inductor_standalone') -> artifact_shape_None_subgraph_i/`.
163
+ - Registers and asynchronously compiles Triton kernels via `AsyncCompile.triton(...)`, including autotune metadata, device properties, scheduling hints, etc.
164
+ - `artifact_shape_None_subgraph_{N}/`:
165
+ - `ir/*.py`: Inductor-generated Python/Triton kernels and scheduling code for this subgraph (readable).
166
+ - `fxgraph/*/*/<hash>`: FX IR/metadata snapshots for fast graph reconstruction (binary; do not edit).
167
+ - `aotautograd/*/*/<hash>`: AOTAutograd partitions/captures and replay requirements.
168
+ - Additional hashed/prefixed buckets (e.g., `44/`, `o5/`, `55/`, `br/`) containing generated operator/subtask code.
169
+
170
+ ### FAQ
171
+
172
+ - How are these caches produced?
173
+ - At runtime by `torch.compile(...)`, after TorchDynamo tracing, AOTAutograd partitioning, TorchInductor lowering/fusion, and Triton codegen.
174
+ - Will they change across runs?
175
+ - Yes. Different input shapes, env vars, device info, or compile options can produce different hash namespaces (e.g., a new `bfa0df33ea`).
176
+ - Is it safe to delete them?
177
+ - Yes. You can delete `cache/`. It will be rebuilt on demand; the next run will be slower due to recompilation.
178
+
179
+ ## Compatibility and Recommendations
180
+
181
+ - Prefer official PyTorch ≥ 2.8 builds to ensure `standalone_compile` availability.
182
+ - For highly dynamic models, explicitly mark key dynamic dimensions to improve graph capture and cache reuse.
183
+
184
+ ## Acknowledgments
185
+
186
+ This library builds upon capabilities of PyTorch Dynamo, AOTAutograd, Inductor, and Triton, and incorporates engineering practices and interface designs inspired by the vLLM community. We thank the relevant open-source communities and contributors.
pkgs/MagiCompiler/docs/AutoCudaGraphDesign.md ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## AutoCudaGraph Design
2
+
3
+ Author: ZhiyaoCen
4
+
5
+ ## Overview
6
+ AutoCudaGraph is a CUDA Graph optimization module integrated into the MagiCompiler framework, designed to automate CUDA Graph capture, caching, replay, and tensor memory management for PyTorch-based neural network inference. It targets Transformer architectures with dynamic sequence lengths, optimizing kernel execution by reusing pre-captured computation graphs and static tensor buffers. Core Goals:
7
+ * Automate CUDA Graph lifecycle (capture/replay/cache) with minimal code intrusion
8
+ * Support dynamic shape adaptation (sequence length expansion)
9
+ * Optimize memory efficiency via global memory pool and static tensor reuse
10
+ * Ensure consistency between cached graphs and runtime inputs/outputs
11
+ ## Key Components
12
+
13
+
14
+ ### CudaGraphMgr (Core Manager)
15
+ Singleton class managing all CUDA Graph operations:
16
+ ```python
17
+ class CudaGraphMgr:
18
+ def __init__(self):
19
+ self.cache: Dict[StaticSignature, StaticTensorEntry] = dict()
20
+ self.graph_mem_pool: Optional[torch.cuda.graph_pool_handle] = None
21
+ ```
22
+
23
+ **Core Methods**
24
+ | Method | Purpose |
25
+ |---------------------------------|----------------------------------------|
26
+ | run() | Main entry: Replay cached graph or warm up & capture new graph|
27
+ | wrapped_graph_capture() | Capture CUDA Graph with sliced static input/output tensors |
28
+ | wrapped_graph_replay() | Replay cached CUDA Graph with sliced static tensors and output template wrapping |
29
+ | get_expanded_static_tensors() | Expand static tensors, reuse buffers if dimensionally compatible|
30
+
31
+
32
+ ### Signature System
33
+
34
+ StaticSignature
35
+ ```python
36
+ @dataclass(unsafe_hash=True)
37
+ class StaticSignature(HashableDataclass):
38
+ func_name: str = ""
39
+ tensor_static_infos: Tuple[TensorStaticInfo, ...] = tuple()
40
+ ```
41
+ * Encodes fixed properties of input tensors (dtype, static dimensions)
42
+ * Used as primary key for static tensor buffer caching
43
+
44
+ DynamicSignature
45
+ ```python
46
+ @dataclass(unsafe_hash=True)
47
+ class DynamicSignature(HashableDataclass):
48
+ tensor_dynamic_infos: Tuple[TensorDynamicInfo, ...] = tuple()
49
+ literals_info: LiteralsInfo = None
50
+ ```
51
+ * Tracks dynamic dimensions (sequence length) and literal parameters
52
+ * Secondary key for graph entry lookup
53
+
54
+ ### Tensor Management
55
+ ```python
56
+ @dataclass
57
+ class StaticTensorEntry:
58
+ input_tensors: Optional[List[torch.Tensor]] = None
59
+ output_tensors: Optional[List[torch.Tensor]] = None
60
+ template_entry_dict: Dict[DynamicSignature, OutputTemplateEntry] = None
61
+ ```
62
+ * Memory Reuse: Reuse existing tensor buffers when possible to avoid reallocation
63
+ * Dynamic Expansion: Only expand static tensors when new input dimensions exceed current buffer size
64
+ * Shape Validation: Ensure static dimensions (non-sequence) match between cached and new tensors
65
+
66
+
67
+ ### Graph Management
68
+ ```python
69
+ @dataclass
70
+ class GraphEntry:
71
+ graph: Optional[torch.cuda.CUDAGraph] = None
72
+ inconsistent: bool = False
73
+ invalid: bool = False
74
+
75
+ @dataclass
76
+ class OutputTemplateEntry:
77
+ graph_entry_dict: Dict[int, GraphEntry] = None
78
+ output_template: Any = None
79
+ ```
80
+ * Graph State Tracking: GraphEntry tracks CUDA Graph instances and validity states to control replay eligibility.
81
+ * Layer-wise Organization: OutputTemplateEntry maps dynamic signatures to per-layer GraphEntry for layer-specific graph reuse.
82
+ * Output Consistency: output_template preserves output object structure to ensure consistent result wrapping during replay.
83
+
84
+ ## Execution Flow
85
+ ### Inline Replay (Fast Path)
86
+ * Extract input signatures from runtime arguments
87
+ * Look up cached CUDA Graph via StaticSignature + DynamicSignature + layer number
88
+ * Validate graph consistency (not inconsistent/invalid)
89
+ * Reuse static tensors with dynamic slicing
90
+ * Replay graph and return sliced output
91
+ ### Graph Capture (Slow Path)
92
+ Triggered when no valid cached graph exists or tensor expansion is needed:
93
+ * Execute function to get output tensors
94
+ * Ensure input signatures match post-warmup
95
+ * Expand static buffers if new shapes require it
96
+ * Capture new CUDA Graph with static tensors
97
+ * Store new graph and update tensor entries
98
+ * Return warmup execution output as final result
99
+ ### Sequence Length Handling
100
+ * Only last dimension is static for ND tensors (ND > 1)
101
+ * All dimensions are dynamic for 1D tensors (ND=1)
102
+ * Automatic buffer expansion for increasing sequence lengths
103
+ * Invalidates old graphs when tensors are expanded
104
+
105
+
106
+ ## Examples
107
+ ```python
108
+ import torch
109
+ import torch.nn as nn
110
+ from magi_compiler.cuda_graph_mgr import cuda_graph_mgr, cuda_graph_enable_if
111
+
112
+ class SimpleTransformerLayer(nn.Module):
113
+ def __init__(self, hidden_dim: int = 1024, num_heads: int = 8):
114
+ super().__init__()
115
+ self.self_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
116
+ self.linear = nn.Linear(hidden_dim, hidden_dim)
117
+ self.layer_norm = nn.LayerNorm(hidden_dim)
118
+ self.layer_number = 0
119
+
120
+ @cuda_graph_enable_if(lambda: torch.cuda.is_available())
121
+ def forward(self, x: torch.Tensor):
122
+ attn_out, _ = self.self_attn(x, x, x)
123
+ out = self.linear(self.layer_norm(x + attn_out))
124
+ return out
125
+
126
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
127
+ model = SimpleTransformerLayer(hidden_dim=1024, num_heads=8).to(device).eval()
128
+ graph_mgr = cuda_graph_mgr()
129
+
130
+ with torch.no_grad():
131
+ input_1 = torch.randn(2, 512, 1024, device=device)
132
+ output_1 = model(input_1)
133
+ print(f"First run (graph capture): Output shape = {output_1.shape}")
134
+ print(f"Cached graphs count: {graph_mgr.graph_count}")
135
+
136
+ input_2 = torch.randn(2, 512, 1024, device=device)
137
+ output_2 = model(input_2)
138
+ print(f"Second run (graph replay): Output shape = {output_2.shape}")
139
+ print(f"Cached graphs count: {graph_mgr.graph_count}")
140
+
141
+ input_3 = torch.randn(2, 1024, 1024, device=device)
142
+ output_3 = model(input_3)
143
+ print(f"Third run (tensor expansion): Output shape = {output_3.shape}")
144
+ print(f"Cached graphs count: {graph_mgr.graph_count}")
145
+ print(f"Static tensor memory usage: {graph_mgr.tensor_mem_size:.2f} MB")
146
+
147
+ print("\nCUDA Graph Cache Details:")
148
+ print(graph_mgr.formatted_cache_str())
149
+
150
+ # StaticSignature: StaticSignature(_cached_hash=None, func_name='SimpleTransformerLayer.forward', tensor_static_infos=(TensorStaticInfo(_cached_hash=None, name='', shapes=(-1, -1, 1024), dtype='torch.float32'),))
151
+ # Input Static Tensors: [shape=[2, 1024, 1024],dtype=torch.float32]
152
+ # Output Static Tensors: [shape=[2, 1024, 1024],dtype=torch.float32]
153
+ # DynamicSignature: DynamicSignature(_cached_hash=None, tensor_dynamic_infos=(TensorDynamicInfo(_cached_hash=None, name='', shapes=(2, 512, -1)),), literals_info=LiteralsInfo(_cached_hash=None, literals=()))
154
+ # Output Template: FakeTensor(shape=[2, 512, 1024], dtype='torch.float32', device='cuda:0')
155
+ # Layer 0: Graph Status: Invalid
156
+ # DynamicSignature: DynamicSignature(_cached_hash=None, tensor_dynamic_infos=(TensorDynamicInfo(_cached_hash=None, name='', shapes=(2, 1024, -1)),), literals_info=LiteralsInfo(_cached_hash=None, literals=()))
157
+ # Output Template: FakeTensor(shape=[2, 1024, 1024], dtype='torch.float32', device='cuda:0')
158
+ # Layer 0: Graph Status: Valid
159
+ ```
160
+
161
+ ## Limitations and Constraints
162
+ * No support for data-dependent control flow in captured functions
163
+ * Graph capture fails if function contains CPU/GPU synchronization
164
+ * Only supports CUDA tensors (CPU tensors trigger fallback)
165
+ * Custom input classes must inherit from InplaceSubstituteFakeClass
166
+ * Assumes input tensors of captured graphs are not reused externally (risk of cross-scenario static tensor reuse)
167
+ * Relies on identical function, input tensors shapes, and constants for valid graph reuse
168
+ * No support for multi-stream execution scenarios
169
+
170
+ ## Best Practices
171
+ * Dynamic Dimensions: Tensors should use sequence length as dimension 0 where possible
172
+ * Monitor Memory Usage: Track graph_mem_pool_size and tensor_mem_size to avoid OOM
173
+ * Specify Layer IDs: Use layer_number to distinguish graphs across different models/layers
174
+ * LRU Cache (Future): Implement cache eviction to limit total graph/tensor count
pkgs/MagiCompiler/docs/Hunyuan15Benchmark.md ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Hunyuan1.5 Benchmark
2
+
3
+ ### Executive Summary
4
+ This report presents a comprehensive performance evaluation of the **[Athena](https://github.com/world-sim-dev/athena)** framework compared to the baseline **[LightX2V](https://github.com/ModelTC/LightX2V)** framework. The benchmarks were conducted using the **[Hunyuan-1.5](https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5)** model on NVIDIA H100 hardware.
5
+
6
+ ---
7
+ ### 🎯Test Environment & Versioning
8
+ #### Hardware & Settings
9
+
10
+ | Parameter | Value |
11
+ | ------------------- | -------------- |
12
+ | Hardware | NVIDIA H100 |
13
+ | Model | Hunyuan-1.5 480p_t2v_distilled |
14
+ | Precision | torch.bfloat16 |
15
+ | Inference Steps | 20 |
16
+ | Resolution | 480p |
17
+ | FPS | 24 |
18
+ | CFG | Disable |
19
+ #### Software Versioning
20
+ To ensure reproducibility, the following specific commits were used for this benchmark:
21
+ | Framework | Branch / Tag | Commit |
22
+ | --------- | ------------ | ------ |
23
+ | Athena | main|[5e6086b](https://github.com/world-sim-dev/athena/commit/5e6086b4dc2ab60bc4d44dbe39745b4354075121) |
24
+ | LightX2V | main | [5573905](https://github.com/ModelTC/LightX2V/commit/5573905f3f38d876d468b815f86d417a608975b6) |
25
+
26
+ ### 🏆 Performance Benchmarks
27
+ 📊 We compared the iteration speed (seconds per iteration) between Athena and LightX2V across four distinct Context Parallel (CP) configurations.
28
+ | Configuration | Frames | LightX2V (s/it) | Athena (s/it) | Speedup |
29
+ | ------------- | ------ | -------------- | -------------- | ------- |
30
+ | CP1 | 121 | 2.42 | **2.06** | **1.17x** 🚀|
31
+ | CP2 | 121 | 1.38 | **1.13** | **1.22x** 🚀|
32
+ | CP4 | 241 | 2.25 | **1.85** | **1.22x** 🚀|
33
+ | CP8 | 241 | 1.28 | **1.01** | **1.27x** 🚀|
34
+
35
+ ---
36
+ ### 📹 Output Comparison
37
+ | Framework | Video Result |
38
+ | --------- | ---------------------------- |
39
+ | Athena | <img src="../../../assets/athena_hunyuan_1_5_test_videos_20260213_155842_idx0_A_close-up815965.gif" width="480" /> |
40
+ | LightX2V | <img src="../../../assets/lightx2v_hunyuan_1_5_result_A_close-up122526.gif" width="480" /> |
41
+
42
+
43
+ ### 💡 Reproduction Guide
44
+ To reproduce the results presented in this report, follow the steps below using the specified commit hashes.
45
+ #### Setup
46
+ ```bash
47
+ git clone https://github.com/world-sim-dev/athena
48
+ cd athena
49
+ git checkout 5e6086b
50
+ pip install -r requirements.txt
51
+ pip install -r requirements-nodeps.txt
52
+ pip install -e ./pkgs/MagiCompiler --no-build-isolation --config-settings editable_mode=compat
53
+
54
+
55
+ # Clone and install LightX2V (for baseline comparison)
56
+ git clone https://github.com/ModelTC/LightX2V
57
+ cd lightx2v
58
+ git checkout 5573905
59
+ pip install -v .
60
+ ```
61
+
62
+ #### Running Benchmarks
63
+ For Athena, run:
64
+ ```
65
+ RESOLUTION=480p CFG_DISTILLED=true TASK=t2v CHECKPOINT_PATH=path/to/480p_t2v_distilled bash ./scripts/run_hunyuan.sh
66
+ ```
67
+ For LightX2V:
68
+ Clone the scripts from [Benchmark for LightX2V](https://gist.github.com/wtr0504/d80bbebb7da1ef7b58f3e6faf1c68880) and run:
69
+ ```
70
+ git clone https://gist.github.com/wtr0504/d80bbebb7da1ef7b58f3e6faf1c68880
71
+ MODEL_PATH=path/to/HunyuanVideo-1.5 DISTILL_CKPT=path/to/480p_t2v_distilled/diffusion_pytorch_model.safetensors bash run_hunyuan.sh
72
+ ```
73
+
74
+ ### 🔎 MagiCompiler Optimization Methodology
75
+ **Whole Graph Compilation**
76
+ Constant Folding & Dead Code Elimination: Streamlining the computation graph prior to execution.
77
+
78
+ **Coarse-grained Kernel Fusion**
79
+ MagiCompiler aggregates multiple smaller operators into larger, fused kernels. This optimization is critical for efficient execution on the GPU.
pkgs/MagiCompiler/docs/Wan2.2Benchmark.md ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Wan2.2 Benchmark
2
+
3
+ ### Executive Summary
4
+ This report presents a comprehensive performance evaluation of the **[Athena](https://github.com/world-sim-dev/athena)** framework compared to the baseline **[LightX2V](https://github.com/ModelTC/LightX2V)** framework. The benchmarks were conducted using the **[Wan2.2-TI2V-5B](https://huggingface.co/Wan-AI)** model on NVIDIA H100 hardware.
5
+
6
+ ---
7
+ ### 🎯Test Environment & Versioning
8
+ #### Hardware & Settings
9
+
10
+ | Parameter | Value |
11
+ | ------------------- | -------------- |
12
+ | Hardware | NVIDIA H100 |
13
+ | Model | Wan2.2-TI2V-5B |
14
+ | Precision | torch.bfloat16 |
15
+ | Inference Steps | 50 |
16
+ | Resolution | 704 × 1280(720p)|
17
+ | FPS | 24 |
18
+ | CFG | Enabled |
19
+ #### Software Versioning
20
+ To ensure reproducibility, the following specific commits were used for this benchmark:
21
+ | Framework | Branch / Tag | Commit |
22
+ | --------- | ------------ | ------ |
23
+ | Athena | main|[f676ae6](https://github.com/world-sim-dev/athena/commit/f676ae64ad2fc581289d1c3ae5eb51c15ce76f1d) |
24
+ | LightX2V | main | [33f0f67](https://github.com/ModelTC/LightX2V/commit/33f0f67f4ecdff86b1db676d3e0786628cc31c7b) |
25
+
26
+ ### 🏆 Performance Benchmarks
27
+ 📊 We compared the iteration speed (seconds per iteration) between Athena and LightX2V across three distinct Context Parallel (CP) configurations.
28
+ | Configuration | Frames | LightX2V (s/it) | Athena (s/it) | Speedup |
29
+ | ------------- | ------ | -------------- | -------------- | ------- |
30
+ | CP1 | 121 | 1.928 | **1.69** | **1.14x** 🚀|
31
+ | CP2 | 121 | 1.197 | **1.06** | **1.13x** 🚀|
32
+ | CP4 | 241 | 1.767 | **1.32** | **1.34x** 🚀|
33
+ | CP8 | 241 | 1.507 | **1.35** | **1.12x** 🚀|
34
+
35
+ ---
36
+
37
+ ### 💡 Reproduction Guide
38
+ To reproduce the results presented in this report, follow the steps below using the specified commit hashes.
39
+ #### Setup
40
+ ```bash
41
+ git clone https://github.com/world-sim-dev/athena
42
+ cd athena
43
+ git checkout f676ae6
44
+ pip install -r requirements.txt
45
+
46
+ # Clone and install LightX2V (for baseline comparison)
47
+ git clone https://github.com/ModelTC/LightX2V
48
+ cd lightx2v
49
+ git checkout 33f0f67
50
+ pip install -r requirements.txt
51
+
52
+ ```
53
+
54
+ #### Running Benchmarks
55
+ For Athena, run:
56
+ ```
57
+ bash ./scripts/run_wan2_2_ti2v_i2v.sh
58
+ ```
59
+ For LightX2V:
60
+ Clone the scripts from [Benchmark for LightX2V](https://gist.github.com/wtr0504/629388f17ed38d1c12d5ef5c25a15197) and run:
61
+ ```
62
+ git clone https://gist.github.com/wtr0504/629388f17ed38d1c12d5ef5c25a15197
63
+ bash run_wan.sh
64
+ ```
65
+
66
+ ### 🔎 MagiCompiler Optimization Methodology
67
+ **Whole Graph Compilation**
68
+ Constant Folding & Dead Code Elimination: Streamlining the computation graph prior to execution.
69
+ **Coarse-grained Kernel Fusion**
70
+ MagiCompiler aggregates multiple smaller operators into larger, fused kernels. This optimization is critical for efficient execution on the GPU.
71
+ **All to All Communication**
72
+ MagiCompiler Uses ``all_to_all_single`` (1 communication op per sync point) while LightX2V Uses all_to_all x 3 (3 separate communication ops).
pkgs/MagiCompiler/docs/WhyMagiCompiler.md ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Why MagiCompiler?
2
+
3
+ ## 1. Compiler Overview
4
+
5
+ ### 1.1 Background
6
+
7
+ We have long encountered several significant challenges in model optimization:
8
+
9
+ 1. **Blurred Acceleration Boundaries:** There is ambiguity regarding the extent of optimization required to achieve "extreme" performance.
10
+ 2. **Complex Performance Tuning:** Optimization strategies are often tightly coupled with model architectures, necessitating extensive and repetitive manual intervention.
11
+ 3. **Deficiency in Optimization Tools:** The infrastructure lacks sufficient mechanisms for computational graph-level optimizations, such as operator substitution and communication overlap.
12
+
13
+ MagiCompiler addresses these issues through the following approaches:
14
+
15
+ * **Addressing Challenge 1:** It adopts **whole-graph compilation**, thoroughly transcending the boundaries of `TransformerLayer` to maximize the scope of kernel fusion.
16
+ * **Addressing Challenge 2:** It integrates infrastructure optimizations directly into MagiCompiler, implementing features such as `AutoCudaGraph` and `AutoCheckpointing(WIP)`.
17
+ * **Addressing Challenge 3:** It leverages the dynamic-to-static capabilities provided by **Dynamo**, capturing `fx.graph` IR in eager mode to perform pass optimizations at the IR level.
18
+
19
+ #### Illustrative Example
20
+
21
+ ```python
22
+ from magi_compiler import magi_compile
23
+
24
+ @magi_compile()
25
+ class TinyModel(nn.Module):
26
+ def __init__(self):
27
+ super().__init__()
28
+ self.linear = nn.Linear(1024, 1024, device="cuda")
29
+
30
+ @no_grad()
31
+ def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
32
+ return self.linear(x + y - z + 1)
33
+
34
+
35
+ def magi_compiler_demo():
36
+ model = TinyModel()
37
+ x = torch.randn(1024, 1024, device="cuda")
38
+ y = torch.randn(1024, 1024, device="cuda")
39
+ z = torch.randn(1024, 1024, device="cuda")
40
+ model(x, y, z)
41
+ ```
42
+
43
+ **Optimized Code (Triton Kernel):**
44
+
45
+ ```python
46
+ triton_poi_fused_add_sub_0 = async_compile.triton('triton_poi_fused_add_sub_0', '''
47
+ import triton
48
+ import triton.language as tl
49
+
50
+ from torch._inductor.runtime import triton_helpers, triton_heuristics
51
+ from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
52
+ from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
53
+ triton_helpers.set_driver_to_gpu()
54
+
55
+ @triton_heuristics.pointwise(
56
+ size_hints={'x': 1048576},
57
+ filename=__file__,
58
+ triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
59
+ inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_sub_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'B8F4209CBFC2377D6AF9CF3C65D610CA2B56C138A443862350DE1E56F5BF54C3', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
60
+ min_elem_per_thread=0
61
+ )
62
+ @triton.jit
63
+ def triton_poi_fused_add_sub_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
64
+ xoffset = tl.program_id(0) * XBLOCK
65
+ xindex = xoffset + tl.arange(0, XBLOCK)[:]
66
+ xmask = xindex < xnumel
67
+ x0 = xindex
68
+ tmp0 = tl.load(in_ptr0 + (x0), xmask)
69
+ tmp1 = tl.load(in_ptr1 + (x0), xmask)
70
+ tmp3 = tl.load(in_ptr2 + (x0), xmask)
71
+ tmp2 = tmp0 + tmp1
72
+ tmp4 = tmp2 - tmp3
73
+ tmp5 = 1.0
74
+ tmp6 = tmp4 + tmp5
75
+ tl.store(out_ptr0 + (x0), tmp6, xmask)
76
+ ''', device_str='cuda')
77
+ ```
78
+
79
+ ### 1.2 Frontend (Dynamo)
80
+
81
+ ![Dynamo](./assets/why_magicompiler_1_dynamo.jpeg)
82
+
83
+ * **PyFrameObject (Dynamic Call Stack):**
84
+ * Represents the context environment during function execution. Python creates a new `PyFrameObject` for each function call.
85
+ * **PyCodeObject (Static Bytecode):**
86
+ * The compiled product of Python code, which is static and read-only. A single `PyCodeObject` exists regardless of how many times the function is invoked.
87
+
88
+ ```python
89
+ def f(x, mod):
90
+ for guard, transformed_code in f.compiled_entries:
91
+ if guard(x, mod):
92
+ return transformed_code(x, mod)
93
+ try:
94
+ guard, transformed_code = compile_and_optimize(x, mod)
95
+ f.compiled_entries.append([guard, transformed_code])
96
+ return transformed_code(x, mod)
97
+ except FailToCompileError:
98
+ y = mod(x)
99
+ z = torch.log(y)
100
+ return z
101
+ ```
102
+
103
+ #### Symbolic Shape
104
+
105
+ MagiCompiler specifically targets the Transformer architecture and supports custom `dynamic_arg_dims` (typically for `seq_len`).
106
+
107
+ **Example:**
108
+
109
+ ```python
110
+ @magi_compile(dynamic_arg_dims={"x": 0, "y": 0, "z": 0})
111
+ class TinyModel(nn.Module):
112
+ def __init__(self):
113
+ super().__init__()
114
+ self.linear = nn.Linear(1024, 1024, device="cuda")
115
+
116
+ @no_grad()
117
+ def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
118
+ return self.linear(x + y - z + 1)
119
+ ```
120
+
121
+ **Guard Mechanism and Elimination in Symbolic Shape Deduction:**
122
+
123
+ ```log
124
+ I1204 16:31:35.745000 1859360 torch/_dynamo/symbolic_convert.py:3842] [0/0] Step 1: torchdynamo start tracing inner /usr/local/lib/python3.12/dist-packages/torch/_dynamo/external_utils.py:66
125
+ I1204 16:31:35.746000 1859360 torch/fx/experimental/symbolic_shapes.py:3775] [0/0] create_env
126
+ I1204 16:31:35.781000 1859360 torch/fx/experimental/symbolic_shapes.py:5120] [0/0] create_symbol s33 = 1024 for L['args'][0].size()[0] [2, int_oo] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_dynamo/variables/builder.py:3501 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s33" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
127
+ I1204 16:31:35.785000 1859360 torch/fx/experimental/symbolic_shapes.py:5120] [0/0] create_symbol s6 = 1024 for L['args'][1].size()[0] [2, int_oo] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_dynamo/variables/builder.py:3501 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s6" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
128
+ I1204 16:31:35.794000 1859360 torch/fx/experimental/symbolic_shapes.py:7213] [0/0] eval Eq(s33, s6) [guard added] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s33, s6)"
129
+ I1204 16:31:35.795000 1859360 torch/fx/experimental/symbolic_shapes.py:6792] [0/0] set_replacement s6 = s33 (solve) VR[2, int_oo]
130
+ I1204 16:31:35.800000 1859360 torch/fx/experimental/symbolic_shapes.py:5120] [0/0] create_symbol s21 = 1024 for L['args'][2].size()[0] [2, int_oo] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_dynamo/variables/builder.py:3501 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s21" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
131
+ I1204 16:31:35.806000 1859360 torch/fx/experimental/symbolic_shapes.py:7213] [0/0] eval Eq(s33, s21) [guard added] return self.linear(x + y - z + 1) # ome/niubility2/hongyu/athena/integration_test/scripts/linear_demo.py:50 in forward (_subclasses/fake_impls.py:1148 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s33, s21)"
132
+ I1204 16:31:35.807000 1859360 torch/fx/experimental/symbolic_shapes.py:6792] [0/0] set_replacement s33 = s21 (solve) VR[2, int_oo]
133
+ I1204 16:31:35.828000 1859360 torch/_dynamo/symbolic_convert.py:4059] [0/0] Step 1: torchdynamo done tracing inner (RETURN_VALUE)
134
+ I1204 16:31:35.837000 1859360 torch/fx/experimental/symbolic_shapes.py:6792] [0/0] set_replacement s6 = s21 (find) VR[2, int_oo]
135
+ ```
136
+
137
+ ### 1.3 Backend (Inductor, MagiBackend, etc.)
138
+
139
+ ![Backend Architecture](./assets/why_magicompiler_2_arch.png)
140
+
141
+ MagiCompiler hijacks the `torch.compile` logic through the following components:
142
+
143
+ * **`custom_partitioner_fn`:** Segments the forward and backward computational graphs and determines which intermediate results are transmitted to the backward pass.
144
+ * **`post_grad_custom_pre_pass`:** Performs pass optimizations at the whole-graph level (computational graph matching and rewriting).
145
+ * **`PartitionFunc`:** Implements custom subgraph partitioning logic, utilizing attention mechanisms as splitting points.
146
+
147
+ ![Partition](./assets/why_magicompiler_3_partition.png)
148
+
149
+ * **`post_grad_custom_post_pass`:** Executes pass optimizations at the subgraph level (computation/communication overlap).
150
+
151
+ ---
152
+
153
+ ## 2. Best Practices
154
+
155
+ ### 2.1 Model Adaptation
156
+
157
+ MagiCompiler has certain limitations, such as mandatory whole-graph capture and the inability to support implicit subgraph interruptions. Consequently, manual adaptation is required in specific scenarios:
158
+
159
+ **1. Computational Graph Dependencies or CPU/GPU Synchronization**
160
+
161
+ ```python
162
+ @magi_compile
163
+ class MeanModule(torch.nn.Module):
164
+ def __init__(self):
165
+ super().__init__()
166
+
167
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
168
+ x = x.cos().sin()
169
+ if x.mean() > 0.5:
170
+ x = x - 1
171
+ return x * y
172
+ ```
173
+
174
+ > **Note:** In typical Transformer models, certain pre/post-processing operations are unavoidable. Therefore, the recommended practice for `magi_compiler` is to perform **whole-graph capture at the `TransformerBlock` level**, as `TransformerBlock` computations constitute over 95% of the total workload.
175
+
176
+ **2. Custom Operators (e.g., FlashAttention, FlexFlashAttention, MoE kernels)**
177
+
178
+ * **Operator Registration:** A mechanism for operator registration is provided. Commonly used operators like FlashAttention (FA) and FlexFlashAttention (FFA) are already registered.
179
+
180
+ ```python
181
+ # Operator Registration
182
+ @torch.library.custom_op("athena::flash_attn_func", mutates_args=())
183
+ def flash_attn_func(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
184
+ ...
185
+
186
+ # Operator Deduce Function
187
+ @flash_attn_func.register_fake
188
+ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
189
+ return torch.empty_like(query)
190
+
191
+ # Call flash_attn_func
192
+ self_attn_out = torch.ops.athena.flash_attn_func(q, k, v)
193
+ out, _ = torch.ops.athena.flex_flash_attn_func(q, k, v, q_ranges=ffa_handler.q_ranges, k_ranges=ffa_handler.k_ranges)
194
+ ```
195
+
196
+ * **Unit Testing:** Provide independent unit tests for each custom operator in the production environment.
197
+
198
+ ```python
199
+ @pytest.mark.parametrize("batch_size", [1])
200
+ @pytest.mark.parametrize("seq_len", [1024, 2048, 4096])
201
+ @pytest.mark.parametrize("query_head", [48])
202
+ @pytest.mark.parametrize("kv_head", [4, 8])
203
+ @pytest.mark.parametrize("head_dim", [128, 256])
204
+ def test_fake_fa3(batch_size, seq_len, query_head, kv_head, head_dim):
205
+ q = torch.randn((batch_size, seq_len, query_head, head_dim), device="cuda", dtype=torch.bfloat16)
206
+ k = torch.randn((batch_size, seq_len, kv_head, head_dim), device="cuda", dtype=torch.bfloat16)
207
+ v = torch.randn((batch_size, seq_len, kv_head, head_dim), device="cuda", dtype=torch.bfloat16)
208
+ torch.library.opcheck(torch.ops.athena.flash_attn_func, (q, k, v))
209
+ ```
210
+
211
+ ### 2.2 Debugging Methods
212
+
213
+ Key questions for debugging:
214
+ * Is the bug originating from the compiler?
215
+ * Which specific component of the compiler is causing the bug?
216
+
217
+ ![Debugging](./assets/why_magicompiler_4_debug.png)
218
+
219
+ ```python
220
+ class CompileConfig(BaseModel):
221
+ # Basic configs
222
+ backend: str = Field("inductor", description="Compilation backend.")
223
+ compile_mode: CompileMode = Field(CompileMode.MAGI_COMPILE, description="Compilation mode.")
224
+ ...
225
+
226
+ # Cudagraph configs
227
+ cudagraph_mode: CudaGraphMode = Field(CudaGraphMode.NONE, description="Cudagraph mode.")
228
+ ...
229
+
230
+ # Pass configs
231
+ pass_config: PassConfig = Field(PassConfig(), description="Pass configuration.")
232
+ ...
233
+ ```
234
+
235
+ ### 2.3 Profiling Results
236
+
237
+ For further details, please refer to the [**Wan2.2 Benchmark**](Wan2.2Benchmark.md).
238
+
239
+ ---
240
+
241
+ ## References
242
+
243
+ 1. [PyTorch 2.0 Overview](https://docs.pytorch.org/assets/pytorch2-2.pdf)
244
+ 2. [TorchDynamo: An Experiment in Dynamic Python Bytecode Transformation](https://dev-discuss.pytorch.org/t/torchdynamo-an-experiment-in-dynamic-python-bytecode-transformation/361)
245
+ 3. [Depyf Walkthrough](https://depyf.readthedocs.io/en/latest/walk_through.html)
246
+ 4. [Getting Started with PyTorch Compiler](https://docs.pytorch.org/docs/main/torch.compiler_get_started.html)
pkgs/MagiCompiler/docs/WhyMagiDepyf.md ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # magi_depyf
2
+
3
+ A structured inspector for `torch.compile` and MagiCompiler compilation
4
+ artifacts — decompiled source, Inductor kernels, guard conditions, graph break
5
+ chains, and more — all organized in a navigable directory tree.
6
+
7
+ ## Why
8
+
9
+ ### The problem: compilation is a black box
10
+
11
+ `torch.compile` and MagiCompiler accelerate models by transforming Python
12
+ functions through a deep pipeline: Dynamo captures bytecode into FX graphs,
13
+ a backend (Inductor, etc.) compiles them into optimized kernels, and the
14
+ runtime dispatches through a chain of cache entries, compiled functions, and
15
+ resume functions. The result is fast — but opaque.
16
+
17
+ When something goes wrong — a correctness bug, an unexpected graph break, a
18
+ performance cliff — you need to see what the compiler actually produced.
19
+ What bytecode did Dynamo generate? Which subgraphs went to Inductor vs.
20
+ eager fallback? What do the kernels look like? How do resume functions
21
+ chain together? MagiCompiler adds further layers: CUDA graph capture
22
+ regions, piecewise subgraph splits, and its own dispatch logic.
23
+
24
+ None of this is easily accessible.
25
+
26
+ ### depyf: a pioneering effort
27
+
28
+ [depyf](https://github.com/thuml/depyf) was the first tool to address this,
29
+ hooking into `torch._dynamo` to dump decompiled source, FX graphs, and
30
+ Inductor output. It made `torch.compile` significantly more transparent.
31
+
32
+ ### Why a new tool?
33
+
34
+ magi_depyf is purpose-built for MagiCompiler's compilation stack, and takes
35
+ a fundamentally different approach from depyf:
36
+
37
+ | | depyf | magi_depyf |
38
+ |-|-------|------------|
39
+ | **When artifacts are collected** | During compilation, via monkey-patching internal hooks | **After** compilation completes, by walking the final CacheEntry chain — a single, clean post-hoc pass |
40
+ | **Output structure** | Flat files (`full_code_0.py`, `__transformed_code_0_for_xxx.py`, …) — hard to navigate for complex models | **Hierarchical directory tree** mirroring the compilation structure: function → entries → compiled\_fns / resume\_fns |
41
+ | **MagiCompiler support** | None | First-class: per-subgraph Inductor source, CUDA graph mode, piecewise split metadata |
42
+ | **Decompiler** | Monolithic class supporting Python 3.8–3.12 | Modular handler registry; focused on 3.10+ |
43
+
44
+ ### Key features
45
+
46
+ **See everything the compiler produced, in one structured tree.**
47
+ One context manager call gives you a complete, navigable dump: decompiled
48
+ bytecode (before and after Dynamo), Inductor kernel source for every compiled
49
+ function, guard conditions, bytecode metadata (`co_flags`, `co_consts`,
50
+ `dis` output), and the full resume function chain — recursively.
51
+
52
+ **MagiCompiler-native.**
53
+ Understands MagiCompiler's backend, extracting per-subgraph Inductor source,
54
+ CUDA graph capture mode (full / piecewise), and split metadata that
55
+ `torch.compile`-only tools cannot see.
56
+
57
+ **Post-hoc introspection.**
58
+ Artifacts are collected after compilation finishes, by walking the CacheEntry
59
+ linked list and extracting what Dynamo and the backend actually produced.
60
+ No monkey-patching of internal compilation hooks, no interference with the
61
+ compilation process itself.
62
+
63
+ ## Usage
64
+
65
+ ### `dump_src` — the main entry point
66
+
67
+ ```python
68
+ import torch
69
+ from magi_compiler.magi_depyf.inspect import dump_src
70
+
71
+ @torch.compile
72
+ def toy_example(a, b):
73
+ x = a / (torch.abs(a) + 1)
74
+ if b.sum() < 0:
75
+ b = b * -1
76
+ return x * b
77
+
78
+ with dump_src("./output"):
79
+ for _ in range(100):
80
+ toy_example(torch.randn(10), torch.randn(10))
81
+ ```
82
+
83
+ This produces:
84
+
85
+ ```
86
+ output/
87
+ toy_example/
88
+ overview.md # Navigable index with links to everything
89
+ decompiled_code.py # Original function source
90
+ bytecode_info.txt # CodeType metadata + dis output
91
+ entry_0/
92
+ decompiled_code.py # Dynamo-transformed bytecode → Python
93
+ bytecode_info.txt # Transformed code metadata
94
+ guards.txt # Guard conditions for this cache entry
95
+ compiled_fns/
96
+ __compiled_fn_1_xxx.py # FX graph (readable)
97
+ __compiled_fn_1_xxx_post_grad.py # Post-grad graph
98
+ __compiled_fn_1_xxx_runnable.py # Inductor kernel source
99
+ resume_fns/
100
+ __resume_at_94_2/ # Resume function after graph break
101
+ overview.md
102
+ decompiled_code.py # Resume function source
103
+ bytecode_info.txt
104
+ entry_0/ # Dynamo compiles resume fns too
105
+ decompiled_code.py
106
+ guards.txt
107
+ compiled_fns/
108
+ ...
109
+ __resume_at_104_3/
110
+ ...
111
+ ```
112
+
113
+ ### Programmatic API
114
+
115
+ ```python
116
+ from magi_compiler.magi_depyf import decompile
117
+
118
+ # Decompile a code object to Python source
119
+ source = decompile(my_function.__code__)
120
+
121
+ # Introspect a compiled function
122
+ from magi_compiler.magi_depyf.inspect import Introspector
123
+ info = Introspector.build_function_info(fn, fn_globals=fn.__globals__)
124
+ # info.entries[0].decompiled_src — decompiled transformed code
125
+ # info.entries[0].compiled_fns — backend-compiled functions
126
+ # info.entries[0].resume_fns — resume functions after graph breaks
127
+ ```
128
+
129
+ ### Tested model architectures
130
+
131
+ The test suite verifies the decompile → recompile round-trip on real model
132
+ structures, ensuring the decompiler produces correct source for Dynamo output:
133
+
134
+ | Category | Models |
135
+ |----------|--------|
136
+ | **PyTorch core** | MLP, Conv-BN-ReLU, MultiheadAttention, TransformerEncoderLayer, Embedding, residual blocks, depthwise separable conv |
137
+ | **Diffusion blocks** | GEGLU, RMSNorm, sinusoidal embeddings, cross-attention, AdaLayerNorm, DiT blocks, timestep MLP |
138
+ | **HuggingFace transformers** | BERT, GPT-2, T5 encoder (tiny configs) |
139
+ | **HuggingFace diffusers** | Attention (self / cross), BasicTransformerBlock |
140
+ | **timm** | ResNet-18, MobileNetV3, EfficientNet-B0, ViT, ConvNeXt, Swin, DeiT |
141
+ | **Graph breaks** | `print()` breaks, explicit `graph_break()`, multi-break chains — with recursive resume function round-tripping |
142
+
143
+ ## Code structure
144
+
145
+ ```
146
+ magi_depyf/
147
+ ├── __init__.py # Public API: decompile, safe_decompile
148
+
149
+ ├── decompile/ # Bytecode → Python source (no torch dependency)
150
+ │ ├── decompiler.py # Decompiler: orchestrates the pipeline
151
+ │ ├── recompiler.py # CodeRecompiler: decompile → compile() → CodeType
152
+ │ ├── bytecode/
153
+ │ │ ├── instruction.py # Mutable wrapper over dis.Instruction
154
+ │ │ ├── source_emitter.py # Stack machine + source accumulator
155
+ │ │ ├── decompile_context.py # Read-only context for handlers
156
+ │ │ ├── handler_registry.py # Opcode → handler dispatch table
157
+ │ │ └── handlers/ # One module per opcode category
158
+ │ └── postprocess/ # Source-level cleanup passes
159
+
160
+ └── inspect/ # torch.compile introspection (requires torch)
161
+ ├── dump_src.py # dump_src(): the main entry point
162
+ ├── introspect.py # Introspector: walk CacheEntry chain
163
+ ├── model.py # Data model (FunctionInfo, EntryInfo, ...)
164
+ ├── writer.py # Serialize to directory tree
165
+ ├── session.py # CaptureSession: bytecode hook lifecycle
166
+ └── result.py # CaptureResult: one compilation event
167
+ ```
168
+
169
+ ## Compatibility
170
+
171
+ | Requirement | Version |
172
+ |-------------|---------|
173
+ | **Python** | >= 3.10 |
174
+ | **PyTorch** | >= 2.0 (requires `torch._dynamo` internals) |
175
+ | **depyf** | Optional; used as fallback by `safe_decompile` |
pkgs/MagiCompiler/docs/assets/submod_0_rank_0.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:583d97460eb7ebf48efbdeb7a6ae424f640de9ff99e6f33bbadef432583f40d3
3
+ size 16122
pkgs/MagiCompiler/magi_compiler/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .api import magi_compile
16
+
17
+ __all__ = ["magi_compile"]
pkgs/MagiCompiler/magi_compiler/_cache_data_cls.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import dataclasses
16
+
17
+
18
+ @dataclasses.dataclass(frozen=True)
19
+ class CacheHandle:
20
+ key: str | None
21
+ path: str
22
+
23
+
24
+ @dataclasses.dataclass(frozen=True)
25
+ class CacheEntry:
26
+ runtime_shape: int | None
27
+ graph_index: int
28
+ backend_name: str
pkgs/MagiCompiler/magi_compiler/api.py ADDED
@@ -0,0 +1,666 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import copy
16
+ import gc
17
+ import inspect
18
+ import os
19
+ from contextlib import contextmanager
20
+ from typing import Callable, TypeVar, get_args, get_origin, overload
21
+ from unittest.mock import patch
22
+
23
+ import magi_compiler.utils.envs as envs
24
+ import torch
25
+ from magi_compiler.cuda.cudart import pin_memory_in_place
26
+ from magi_compiler.magi_compiler_base import MagiCompilerBase
27
+ from magi_compiler.utils import compilation_counter, magi_logger
28
+ from magi_compiler.utils.compile_time_monitor import CompileMonitor
29
+ from torch import distributed as dist
30
+ from torch import nn
31
+ from torch._dynamo.symbolic_convert import InliningInstructionTranslator
32
+
33
+ from .config import CompileConfig, CompileMode, get_compile_config
34
+
35
+
36
+ # =============================================================================
37
+ # Workaround: TorchInductor autotune get_raw_stream
38
+ # =============================================================================
39
+ # TorchInductor autotune code blocks may reference get_raw_stream() without
40
+ # defining it, causing "name 'get_raw_stream' is not defined" at runtime.
41
+ # Register it as a builtin so the exec'd autotune snippets can always find it.
42
+ def _patch_get_raw_stream():
43
+ try:
44
+ import builtins
45
+
46
+ from torch._C import _cuda_getCurrentRawStream as _get_raw_stream
47
+ except Exception:
48
+ return
49
+ if not hasattr(builtins, "get_raw_stream"):
50
+ builtins.get_raw_stream = _get_raw_stream
51
+
52
+
53
+ _patch_get_raw_stream()
54
+
55
+ # =============================================================================
56
+ # Dynamo Config Isolation
57
+ # =============================================================================
58
+ # Capture the default dynamo config at module load time (before any torch.compile).
59
+ # This ensures we have a "clean" baseline config that hasn't been modified by
60
+ # external torch.compile calls (e.g., with dynamic=True).
61
+ _DEFAULT_DYNAMO_CONFIG: dict = torch._dynamo.config.get_config_copy()
62
+
63
+
64
+ @contextmanager
65
+ def _isolated_dynamo_config():
66
+ """
67
+ Context manager that provides an isolated dynamo config environment.
68
+ """
69
+ with torch._dynamo.config.patch(**_DEFAULT_DYNAMO_CONFIG):
70
+ yield
71
+
72
+
73
+ _T = TypeVar("_T", bound=type[nn.Module])
74
+ _W = TypeVar("_W", bound="MagiCompilerBase")
75
+
76
+
77
+ @overload
78
+ def magi_compile(*, enable_if: Callable[None, bool] | None = None) -> Callable[[_T], _T]:
79
+ ...
80
+
81
+
82
+ @overload
83
+ def magi_compile(*, dynamic_arg_dims: dict[str, int | list[int]] | None) -> Callable[[_T], _T]:
84
+ ...
85
+
86
+
87
+ @overload
88
+ def magi_compile(*, config_patch: Callable[[CompileConfig], CompileConfig] | None = None) -> Callable[[_T], _T]:
89
+ ...
90
+
91
+
92
+ @overload
93
+ def magi_compile(cls: _T) -> _T:
94
+ ...
95
+
96
+
97
+ def magi_compile(
98
+ cls: _T | None = None,
99
+ *,
100
+ model_tag: str | None = None,
101
+ dynamic_arg_dims: dict[str, int | list[int]] | None = None,
102
+ enable_if: Callable[None, bool] | None = None,
103
+ config_patch: Callable[[CompileConfig], CompileConfig] | None = None,
104
+ ) -> Callable[[_T], _T] | _T:
105
+ """
106
+ A decorator to add support for compiling the forward method of a class.
107
+
108
+ Usage:
109
+ 1. use directly as a decorator without arguments:
110
+ ```python
111
+ @magi_compile
112
+ class MyModel(nn.Module):
113
+ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
114
+ ```
115
+
116
+ 2. use as a decorator with arguments:
117
+ ```python
118
+ @magi_compile(dynamic_arg_dims={"x": 0, "y": 0})
119
+ class MyModel(nn.Module):
120
+ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
121
+ ```
122
+
123
+ Arguments:
124
+ - model_tag: optional tag in cache path (e.g. "wan_ti2v"). If not set, class name is used.
125
+ Path segment: model_{idx}_{model_tag}_rank_{rank}.
126
+ - dynamic_arg_dims: a dictionary that maps argument names to the dynamic
127
+ dimensions of the argument. The dynamic dimensions can be either a single
128
+ integer or a list of integers.
129
+ - enable_if: a function that returns a boolean value indicating whether to compile the model or not.
130
+ This is useful if you want to compile the model only when certain conditions are met.
131
+
132
+ Notes:
133
+ - dynamic_arg_dims will be inferred from the type annotation of the forward method if not provided,
134
+ if the argument is annotated as `torch.Tensor` or `Optional[torch.Tensor]`,
135
+ the first dimension will be marked as dynamic.
136
+
137
+ - if an argument is `None`, it should always be passed as `None` during
138
+ the lifetime of the model, otherwise, it cannot be captured as a single
139
+ computation graph.
140
+
141
+ """
142
+
143
+ def cls_decorator_helper(cls: _T) -> _T:
144
+ nonlocal dynamic_arg_dims
145
+ dynamic_arg_dims = dynamic_arg_dims or _infer_dynamic_arg_dims(cls)
146
+
147
+ # Accuracy check
148
+ assert hasattr(cls, "forward"), "decorated class should have a forward method."
149
+ assert len(dynamic_arg_dims) > 0, (
150
+ "No dynamic dimensions found in the forward method of " f"{cls}. Please provide dynamic_arg_dims explicitly."
151
+ )
152
+ for k in dynamic_arg_dims:
153
+ assert k in inspect.signature(cls.forward).parameters, f"Argument {k} not found in the forward method of {cls}"
154
+
155
+ return _magi_compile(cls, dynamic_arg_dims, enable_if, config_patch, model_tag=model_tag)
156
+
157
+ if cls is not None:
158
+ # use `magi_compile` as a decorator without arguments, cls is the class to be decorated
159
+ assert isinstance(cls, type)
160
+ return cls_decorator_helper(cls)
161
+
162
+ return cls_decorator_helper
163
+
164
+
165
+ def offload(obj):
166
+ if isinstance(obj, torch.Tensor):
167
+ return obj.cpu()
168
+ elif isinstance(obj, dict):
169
+ return {k: offload(v) for k, v in obj.items()}
170
+ elif isinstance(obj, (list, tuple)):
171
+ return type(obj)(offload(item) for item in obj)
172
+ return obj
173
+
174
+
175
+ def _magi_compile(
176
+ cls: _T,
177
+ dynamic_arg_dims: dict[str, int | list[int]],
178
+ enable_if: Callable[None, bool] | None = None,
179
+ config_patch: Callable[[CompileConfig], CompileConfig] | None = None,
180
+ model_tag: str | None = None,
181
+ ) -> _T:
182
+ """
183
+ A decorator to add support for compiling the forward method of a class.
184
+ """
185
+ if MagiCompilerBase in cls.__bases__:
186
+ return cls
187
+
188
+ # take care of method resolution order, make sure super().__init__ is called on the base class
189
+ # other than MagiCompilerBase
190
+ cls.__bases__ = cls.__bases__ + (MagiCompilerBase,)
191
+
192
+ if get_compile_config().offload_config.model_cpu_offload:
193
+ magi_logger.info(f"Enabling CPU offload for {cls}")
194
+ _orig_apply = cls._apply
195
+
196
+ def _cpu_apply(self, fn):
197
+ if getattr(self, "_magi_offloaded_once", False):
198
+ return _orig_apply(self, fn)
199
+
200
+ # First, move all parameters/buffers to CPU
201
+ def _force_cpu(t):
202
+ return fn(t).cpu()
203
+
204
+ _orig_apply(self, _force_cpu)
205
+
206
+ # create shared memory tensors for all parameters/buffers on CPU
207
+ if dist.is_initialized():
208
+ local_rank = int(os.environ.get("LOCAL_RANK", 0))
209
+ full_state_dict = self.state_dict()
210
+
211
+ grouped_params = {} # {dtype: [(name, tensor), ...]}
212
+ for name, tensor in full_state_dict.items():
213
+ if tensor.device.type == 'cpu':
214
+ dt = tensor.dtype
215
+ if dt not in grouped_params:
216
+ grouped_params[dt] = []
217
+ grouped_params[dt].append((name, tensor))
218
+
219
+ shared_state_dict = {}
220
+ self._magi_giant_buffers = []
221
+
222
+ dist.barrier()
223
+
224
+ for dtype, param_list in grouped_params.items():
225
+ dtype_str = str(dtype).split('.')[-1]
226
+ shared_bin_path = (
227
+ f"{envs.MAGI_SHARED_BIN_PATH}/magi_model_shared_{dtype_str}_{self.__class__.__name__}.bin"
228
+ )
229
+
230
+ total_numel = sum(t.numel() for _, t in param_list)
231
+
232
+ if local_rank == 0:
233
+ flat_buffer = torch.zeros(total_numel, dtype=dtype)
234
+ offset = 0
235
+ for _, tensor in param_list:
236
+ numel = tensor.numel()
237
+ flat_buffer[offset : offset + numel].copy_(tensor.view(-1))
238
+ offset += numel
239
+
240
+ if dtype == torch.bfloat16:
241
+ flat_buffer.view(torch.int16).numpy().tofile(shared_bin_path)
242
+ elif dtype.itemsize == 1 and dtype.is_floating_point:
243
+ # fp8
244
+ flat_buffer.view(torch.uint8).numpy().tofile(shared_bin_path)
245
+ else:
246
+ flat_buffer.numpy().tofile(shared_bin_path)
247
+
248
+ del flat_buffer
249
+ gc.collect()
250
+
251
+ dist.barrier()
252
+
253
+ giant_shared_tensor = torch.from_file(
254
+ shared_bin_path, shared=True, size=total_numel, dtype=dtype, device="cpu"
255
+ )
256
+ self._magi_giant_buffers.append(giant_shared_tensor)
257
+
258
+ pin_memory_in_place(giant_shared_tensor)
259
+
260
+ offset = 0
261
+ for name, original_tensor in param_list:
262
+ numel = original_tensor.numel()
263
+ shared_param = giant_shared_tensor[offset : offset + numel].view(original_tensor.shape)
264
+
265
+ if original_tensor.requires_grad:
266
+ shared_param.requires_grad_(True)
267
+
268
+ shared_state_dict[name] = shared_param
269
+ offset += numel
270
+
271
+ dist.barrier()
272
+ if local_rank == 0:
273
+ if os.path.exists(shared_bin_path):
274
+ os.remove(shared_bin_path)
275
+
276
+ self.load_state_dict(shared_state_dict, assign=True)
277
+
278
+ else:
279
+
280
+ def _pinner(t):
281
+ return t.pin_memory()
282
+
283
+ _orig_apply(self, _pinner)
284
+
285
+ self._magi_offloaded_once = True
286
+ return self
287
+
288
+ cls._apply = _cpu_apply
289
+
290
+ old_init = cls.__init__
291
+
292
def __init__(self: _W, *args, **kwargs):
    """
    Replacement constructor installed on the wrapped class: build the module
    normally, then attach compilation state (config, enable flag, model index).
    """
    # Run the wrapped class's original constructor first so the module is fully
    # built before any compilation bookkeeping happens.
    old_init(self, *args, **kwargs)
    compile_config = get_compile_config()
    # `config_patch` (decorator argument) lets the caller tweak the global config
    # for this class only.
    if config_patch is not None:
        compile_config = config_patch(compile_config)
    # deepcopy the compile config to avoid modifying the original compile config
    self.compile_config = copy.deepcopy(compile_config)

    # `enable_if` (decorator argument) allows per-instance opt-out at construction time.
    enable_compile = enable_if is None or enable_if()
    self.enable_compile = self.compile_config.compile_mode != CompileMode.NONE and enable_compile
    if not self.enable_compile:
        return

    # Assign a unique model index/tag so per-model cache directories do not collide.
    compilation_counter.num_models_seen += 1
    self.compile_config.model_idx = compilation_counter.num_models_seen
    self.compile_config.model_tag = model_tag if model_tag is not None else self.__class__.__name__
    MagiCompilerBase.__init__(self, compile_config=self.compile_config)
309
+
310
+ cls.__init__ = __init__
311
+
312
+ old_call = cls.__call__
313
+
314
def __call__(self: _W, *args, **kwargs):
    """
    Replacement __call__ installed on the wrapped class.

    Dispatch order: run eagerly when compilation is disabled or we are already
    inside a compilation; run the JIT-compiled code if it exists; try cached AOT
    artifacts; otherwise mark dynamic shapes and trigger the first compilation.
    """
    ### Step1: Run compiled module directly if disable compile or captured before ###
    # NOTE(review): offload() presumably relocates inputs for CPU-offload mode
    # before the first (uncompiled) run — confirm against its definition.
    if self.compile_config.offload_config.model_cpu_offload and self.compiled_code is None:
        args = offload(args)
        kwargs = offload(kwargs)

    if not self.enable_compile or torch.compiler.is_compiling():
        # Skip compiling the model if inside the compilation process.
        return old_call(self, *args, **kwargs)

    if self.compiled_code is not None:
        # Run the compiled function if compiled code is available.
        with self.dispatch_to_compiled_fwd(mode="jit"):
            return old_call(self, *args, **kwargs)

    if envs.MAGI_AOT_COMPILE:
        # Try load AOT artifacts from cache and run directly.
        self.aot_compiled_fn = self.try_load_aot_compile_artifacts()
        if self.aot_compiled_fn is not None:
            with self.dispatch_to_compiled_fwd(mode="aot"):
                return old_call(self, *args, **kwargs)

    ### Step2: Mark dynamic shapes for the first compilation ###
    # Bind the call against the real forward() signature so dynamic dims can be
    # looked up by parameter name regardless of positional/keyword usage.
    bound_args = inspect.signature(self.__class__.forward).bind(self, *args, **kwargs)
    bound_args.apply_defaults()
    for k, dims in dynamic_arg_dims.items():
        arg = bound_args.arguments.get(k)
        if arg is None:
            continue
        dims = [dims] if isinstance(dims, int) else dims
        assert isinstance(arg, torch.Tensor), f"Unsupported dynamic dim {dims} for argument {k} with type {type(arg)}."
        # Normalize negative dims (e.g. -1) to absolute axis indices.
        dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
        torch._dynamo.mark_dynamic(arg, dims)

    ### Step3: Start compiling the model ###
    magi_logger.info(f"Start compiling function {self.original_code_object}")

    CompileMonitor().start(
        self.compile_config.compile_mode == CompileMode.MAGI_COMPILE, self.compile_config.debug_dump_path()
    )
    # Dynamo reuse the compilation across instances, but we need to make sure the compiled code is not reused.
    torch._dynamo.eval_frame.remove_from_cache(self.original_code_object)

    with (
        _hijack_inline_call_to_collect_traced_files(self),
        patch.object(torch.compiler.config, "dynamic_sources", self.compile_config.dynamic_sources),
        patch.object(torch._dynamo.config, "enable_cpp_symbolic_shape_guards", False),
        # Allow mark_dynamic to take effect on tensors reached through module
        # attribute chains (the default True forces module-property tensors to
        # static shapes, ignoring mark_dynamic).
        patch.object(torch._dynamo.config, "force_nn_module_property_static_shapes", False),
        # Redirect the Inductor cache into this model's private cache directory.
        patch.dict(
            os.environ, {"TORCHINDUCTOR_CACHE_DIR": (self.compile_config.cache_dump_path() / "inductor_cache").as_posix()}
        ),
    ):
        if envs.MAGI_AOT_COMPILE:
            # AOT path: compile ahead-of-time, persist the artifacts, then run.
            self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
            self.aot_compiled_fn.save_compiled_function(self.aot_compilation_path)
            with self.dispatch_to_compiled_fwd(mode="aot"):
                output = old_call(self, *args, **kwargs)
        else:
            # JIT path: temporarily swap forward() for the compiling wrapper.
            with patch.object(self, "forward", self.jit_compile):
                output = old_call(self, *args, **kwargs)

    return output
378
+
379
+ # 使用 @torch.compiler.disable 和 _isolated_dynamo_config 包裹整个 __call__
380
+ # 确保 magi compile 在外部嵌套 torch.compile 时也能独立工作不受影响
381
+ isolated_call = _isolated_dynamo_config()(__call__)
382
+ cls.__call__ = torch.compiler.disable(isolated_call)
383
+ return cls
384
+
385
+
386
+ # Collect all relevant files traced by Dynamo, re-compile the model when any of these files change.
387
+ # 1. the file containing the top-level forward function
388
+ # 2. hijack function to know all the functions called during Dynamo tracing, every time Dynamo sees a function call, it will inline
389
+ # the function by calling InliningInstructionTranslator.inline_call_
390
def _hijack_inline_call_to_collect_traced_files(owner: _W):
    """
    Patch Dynamo's inline-call hook so every source file it traces is recorded
    on ``owner.compile_config.traced_files``.

    Dynamo inlines each called function through
    ``InliningInstructionTranslator.inline_call_``; wrapping that entry point
    lets us observe every file touched during tracing. The file holding the
    top-level forward is recorded up front. Returns the ``patch`` context
    manager (not yet entered).
    """
    traced = owner.compile_config.traced_files
    traced.add(owner.original_code_object.co_filename)
    original_inline_call = InliningInstructionTranslator.inline_call_

    def _recording_inline_call(translator):
        # Record the file of the function being inlined, then delegate.
        traced.add(translator.f_code.co_filename)
        return original_inline_call(translator)

    return patch.object(InliningInstructionTranslator, "inline_call_", _recording_inline_call)
400
+
401
+
402
def _infer_dynamic_arg_dims(cls: _T) -> dict[str, int | list[int]]:
    """
    Infer which ``forward`` parameters should get a dynamic dim 0.

    Every parameter annotated as ``torch.Tensor`` (or ``torch.Tensor | None``)
    is mapped to dimension 0; everything else is ignored.
    """
    tensor_annotations = (torch.Tensor, torch.Tensor | None)
    params = inspect.signature(cls.forward).parameters
    inferred_dynamic_arg_dims = {
        name: 0 for name, param in params.items() if param.annotation in tensor_annotations
    }
    magi_logger.info(f"Inferred dynamic dimensions for forward method of {cls}: {list(inferred_dynamic_arg_dims.keys())}")
    return inferred_dynamic_arg_dims
411
+
412
+
413
+ def _get_num_outputs_from_return_annotation(fn: Callable) -> int:
414
+ """
415
+ Get the number of outputs from the function's return type annotation.
416
+
417
+ Returns:
418
+ - 1 if the return type is a single Tensor
419
+ - N if the return type is tuple[Tensor, Tensor, ...] with N elements
420
+ - 1 if no annotation or unrecognized annotation (default to single output)
421
+ """
422
+ sig = inspect.signature(fn)
423
+ return_annotation = sig.return_annotation
424
+
425
+ if return_annotation is inspect.Parameter.empty:
426
+ return 1
427
+
428
+ # Check if it's a tuple type (e.g., tuple[Tensor, Tensor])
429
+ origin = get_origin(return_annotation)
430
+ if origin is tuple:
431
+ args = get_args(return_annotation)
432
+ # Filter out ellipsis (for variable-length tuples like tuple[Tensor, ...])
433
+ if args and args[-1] is not ...:
434
+ return len(args)
435
+ return 1
436
+
437
+ return 1
438
+
439
+
440
+ def _generate_op_name(fn: Callable) -> str:
441
+ """
442
+ Generate a unique operator name from function's name and source file.
443
+
444
+ The generated name follows the format: namespace::op_name
445
+ - namespace: derived from the source file path (module-like structure)
446
+ - op_name: the function name
447
+
448
+ Example:
449
+ Function `_my_custom_op` in file `/path/to/my_module.py`
450
+ -> "my_module::_my_custom_op"
451
+ """
452
+ import re
453
+ from pathlib import Path
454
+
455
+ func_name = fn.__name__
456
+
457
+ # Get the source file path
458
+ try:
459
+ source_file = inspect.getfile(fn)
460
+ # Extract the file stem (without extension) as namespace
461
+ namespace = Path(source_file).stem
462
+ # Clean up namespace: replace invalid characters with underscores
463
+ namespace = re.sub(r"[^a-zA-Z0-9_]", "_", namespace)
464
+ except (TypeError, OSError):
465
+ # If we can't get the source file, use a default namespace
466
+ namespace = "magi_custom"
467
+
468
+ return f"{namespace}::{func_name}"
469
+
470
+
471
def _create_identity_meta_fn(fn: Callable) -> Callable:
    """
    Build a default meta function that mirrors input tensor metadata onto outputs.

    The output count is derived from ``fn``'s return annotation; output k gets
    the shape/dtype/device of the k-th tensor argument in signature order
    (``self`` excluded). The returned meta function raises ValueError when the
    call does not supply enough tensor arguments to cover every output.

    Example: for ``def my_op(a: Tensor, b: Tensor, scale: float) ->
    tuple[Tensor, Tensor]`` the meta function yields
    ``(torch.empty_like(a), torch.empty_like(b))``.
    """
    num_outputs = _get_num_outputs_from_return_annotation(fn)
    sig = inspect.signature(fn)
    # Parameter names in declaration order, skipping a potential `self`.
    ordered_params = [p for p in sig.parameters if p != "self"]

    def identity_meta_fn(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()

        # Take the first `num_outputs` tensor-valued arguments in order.
        matched = []
        for param in ordered_params:
            candidate = bound.arguments.get(param)
            if isinstance(candidate, torch.Tensor):
                matched.append(candidate)
                if len(matched) >= num_outputs:
                    break

        if len(matched) < num_outputs:
            raise ValueError(
                f"identity_meta_fn requires at least {num_outputs} tensor inputs to match "
                f"{num_outputs} outputs, but only found {len(matched)} tensor inputs. "
                f"Please provide a custom infer_output_meta_fn."
            )

        # Mirror metadata of the matched inputs onto the outputs.
        if num_outputs == 1:
            return torch.empty_like(matched[0])
        return tuple(torch.empty_like(t) for t in matched[:num_outputs])

    return identity_meta_fn
516
+
517
+
518
+ def _create_meta_fn_from_param_names(fn: Callable, param_names: list[str]) -> Callable:
519
+ """
520
+ Create a meta function that returns torch.empty_like() for each specified parameter.
521
+
522
+ This is useful when output tensors have the same shape/dtype/device as specific input
523
+ parameters, but not necessarily in positional order.
524
+
525
+ Example:
526
+ param_names = ["weight", "bias"]
527
+ def my_op(grad: Tensor, weight: Tensor, bias: Tensor) -> tuple[Tensor, Tensor]:
528
+ ...
529
+
530
+ Generated meta function returns:
531
+ (torch.empty_like(weight), torch.empty_like(bias))
532
+ """
533
+ sig = inspect.signature(fn)
534
+
535
+ def meta_fn(*args, **kwargs):
536
+ # Bind arguments to get a mapping of param_name -> value
537
+ bound = sig.bind(*args, **kwargs)
538
+ bound.apply_defaults()
539
+
540
+ # Collect tensors for each specified parameter name
541
+ tensor_outputs = []
542
+ for name in param_names:
543
+ if name not in bound.arguments:
544
+ raise ValueError(
545
+ f"Parameter '{name}' not found in function signature. "
546
+ f"Available parameters: {list(bound.arguments.keys())}"
547
+ )
548
+ arg = bound.arguments[name]
549
+ if not isinstance(arg, torch.Tensor):
550
+ raise ValueError(
551
+ f"Parameter '{name}' is not a Tensor (got {type(arg).__name__}). "
552
+ f"infer_output_meta_fn list should only contain tensor parameter names."
553
+ )
554
+ tensor_outputs.append(torch.empty_like(arg))
555
+
556
+ # Return single tensor or tuple based on number of outputs
557
+ if len(tensor_outputs) == 1:
558
+ return tensor_outputs[0]
559
+ return tuple(tensor_outputs)
560
+
561
+ return meta_fn
562
+
563
+
564
+ def magi_register_custom_op(
565
+ name: str | None = None,
566
+ mutates_args: tuple[str, ...] = (),
567
+ infer_output_meta_fn: Callable | list[str] | None = None,
568
+ setup_context_fn: Callable | None = None,
569
+ backward_fn: Callable | None = None,
570
+ ):
571
+ """
572
+ A unified decorator to register a custom operator with PyTorch's library.
573
+
574
+ This decorator combines the functionality of:
575
+ - @torch.library.custom_op
576
+ - @torch.library.register_fake
577
+ - fn.register_autograd
578
+
579
+ Arguments:
580
+ name: The fully qualified name of the operator (e.g., "namespace::op_name").
581
+ If None, auto-generated from the function name and source file.
582
+ mutates_args: Tuple of argument names that are mutated by the operator.
583
+ infer_output_meta_fn: Specifies output tensor metadata (shape, dtype, device) for tracing.
584
+ - None (default): Assumes each output has the same metadata as the corresponding
585
+ input tensor (1st output matches 1st tensor input, 2nd matches 2nd, etc.).
586
+ - list[str]: Parameter names whose metadata to use for outputs.
587
+ E.g., ["weight", "bias"] means output[0] has same shape as `weight`,
588
+ output[1] has same shape as `bias`.
589
+ - Callable: Custom function with same signature as the op, returns torch.empty_like()
590
+ tensors matching the expected output shapes.
591
+ setup_context_fn: Function to save tensors/values for backward.
592
+ Signature: setup_context_fn(ctx, inputs, output)
593
+ backward_fn: Function to compute gradients.
594
+ Signature: backward_fn(ctx, *grad_outputs) -> tuple of gradients
595
+
596
+ Returns:
597
+ The registered custom operator function.
598
+
599
+ Examples:
600
+ 1. Basic usage (forward only, auto-generated name and meta function):
601
+
602
+ >>> @magi_register_custom_op()
603
+ ... def my_relu(x: torch.Tensor) -> torch.Tensor:
604
+ ... return torch.maximum(x, torch.zeros_like(x))
605
+
606
+ 2. Multiple outputs with explicit output metadata via parameter names:
607
+
608
+ >>> @magi_register_custom_op(
609
+ ... infer_output_meta_fn=["weight", "bias"], # output shapes match weight and bias
610
+ ... )
611
+ ... def compute_gradients(
612
+ ... grad_output: torch.Tensor,
613
+ ... weight: torch.Tensor,
614
+ ... bias: torch.Tensor,
615
+ ... ) -> tuple[torch.Tensor, torch.Tensor]:
616
+ ... grad_weight = grad_output.sum(dim=0).view_as(weight)
617
+ ... grad_bias = grad_output.sum(dim=0).view_as(bias)
618
+ ... return grad_weight, grad_bias
619
+
620
+ 3. Full custom op with autograd support:
621
+
622
+ >>> def _square_meta(x: torch.Tensor) -> torch.Tensor:
623
+ ... return torch.empty_like(x)
624
+ ...
625
+ >>> def _square_setup_context(ctx, inputs, output):
626
+ ... (x,) = inputs
627
+ ... ctx.save_for_backward(x)
628
+ ...
629
+ >>> def _square_backward(ctx, grad_output):
630
+ ... (x,) = ctx.saved_tensors
631
+ ... return grad_output * 2 * x
632
+ ...
633
+ >>> @magi_register_custom_op(
634
+ ... name="my_ops::square",
635
+ ... infer_output_meta_fn=_square_meta,
636
+ ... setup_context_fn=_square_setup_context,
637
+ ... backward_fn=_square_backward,
638
+ ... )
639
+ ... def square(x: torch.Tensor) -> torch.Tensor:
640
+ ... return x * x
641
+ """
642
+
643
+ def decorator(fn: Callable) -> Callable:
644
+ # Auto-generate name if not provided
645
+ op_name = name if name is not None else _generate_op_name(fn)
646
+
647
+ # Step 1: Register the custom op with torch.library.custom_op
648
+ registered_op = torch.library.custom_op(op_name, mutates_args=mutates_args)(fn)
649
+
650
+ # Step 2: Register the output meta inference function
651
+ # Determine meta_fn based on the type of infer_output_meta_fn
652
+ if infer_output_meta_fn is None:
653
+ meta_fn = _create_identity_meta_fn(fn)
654
+ elif isinstance(infer_output_meta_fn, list):
655
+ meta_fn = _create_meta_fn_from_param_names(fn, infer_output_meta_fn)
656
+ else:
657
+ meta_fn = infer_output_meta_fn
658
+ torch.library.register_fake(op_name)(meta_fn)
659
+
660
+ # Step 3: Register autograd if backward_fn is provided
661
+ if backward_fn is not None:
662
+ registered_op.register_autograd(backward_fn, setup_context=setup_context_fn)
663
+
664
+ return registered_op
665
+
666
+ return decorator
pkgs/MagiCompiler/magi_compiler/compile_artifacts.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
17
+
18
+ import inspect
19
+ import pickle
20
+ from unittest.mock import patch
21
+
22
+ import torch
23
+ from torch.utils._pytree import tree_map_only
24
+
25
+ try:
26
+ from torch._dynamo.aot_compile import SerializableCallable
27
+ except ImportError:
28
+ SerializableCallable = object
29
+
30
+ assert isinstance(SerializableCallable, type)
31
+
32
+
33
class MagiSerializableFunction(SerializableCallable):
    """
    A wrapper around a compiled function by vllm. It will forward the tensor
    inputs to the compiled function and return the result.
    It also implements a serialization interface to support PyTorch's precompile
    with custom backend, so that we can save and load the compiled function on
    disk. There's no need to wrap around the compiled function if we don't want
    to serialize them in particular cases.
    Right now serialization for the custom backend is done via
    serializing the Dynamo fx graph plus example inputs.
    """

    def __init__(self, graph_module, example_inputs, model_tag, optimized_call):
        # graph_module: the Dynamo-captured fx graph; example_inputs: the inputs
        # it was traced with; optimized_call: the actual compiled callable.
        assert isinstance(graph_module, torch.fx.GraphModule)
        self.graph_module = graph_module
        self.example_inputs = example_inputs
        self.model_tag = model_tag
        self.optimized_call = optimized_call
        self.shape_env = None
        # If any example input is symbolic, remember its ShapeEnv for later use.
        sym_input = next((i for i in self.example_inputs if isinstance(i, torch.SymInt)), None)
        if sym_input is not None:
            self.shape_env = sym_input.node.shape_env

    def __call__(self, *args, **kwargs):
        # Delegate directly to the compiled callable.
        return self.optimized_call(*args, **kwargs)

    @classmethod
    def serialize_compile_artifacts(cls, compiled_fn: "MagiSerializableFunction") -> bytes:
        """Serialize the wrapper to bytes (fx graph + masked example inputs)."""
        import sympy
        from torch._subclasses import FakeTensorMode
        from torch.fx._graph_pickler import GraphPickler, Options

        state = compiled_fn.__dict__.copy()
        # The compiled callable and ShapeEnv are not picklable; both are
        # reconstructed on load.
        state.pop("optimized_call")
        state.pop("shape_env")
        # Strip per-node metadata that is large and not needed for replay.
        for node in state["graph_module"].graph.nodes:
            node.meta.pop("source_fn_stack", None)
            node.meta.pop("nn_module_stack", None)

        graph_reducer_override = GraphPickler.reducer_override

        def _graph_reducer_override(self, obj):
            # Sympy functions with a torch unpickler serialize via that hook;
            # FakeTensorMode instances are dropped (restored at load time).
            if inspect.isclass(obj) and issubclass(obj, sympy.Function) and hasattr(obj, "_torch_unpickler"):
                return obj._torch_unpickler, (obj._torch_handler_name,)
            if isinstance(obj, FakeTensorMode):
                return type(None), ()
            return graph_reducer_override(self, obj)

        # Mask off tensor inputs since they are large and not needed.
        state["example_inputs"] = tree_map_only(torch.Tensor, lambda _: None, state["example_inputs"])
        with patch.object(GraphPickler, "reducer_override", _graph_reducer_override):
            state["graph_module"] = GraphPickler.dumps(state["graph_module"], Options(ops_filter=None))
            state["example_inputs"] = GraphPickler.dumps(state["example_inputs"])
        return pickle.dumps(state)

    @classmethod
    def deserialize_compile_artifacts(cls, data: bytes) -> "MagiSerializableFunction":
        """Rebuild a wrapper from bytes; recompilation happens lazily on first call."""
        from torch._guards import TracingContext, tracing
        from torch._subclasses import FakeTensorMode
        from torch.fx._graph_pickler import GraphPickler
        from torch.fx.experimental.symbolic_shapes import ShapeEnv

        from .config import get_compile_config
        from .magi_backend import MagiBackend

        state = pickle.loads(data)
        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
        state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
        state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
        magi_backend = MagiBackend(get_compile_config(), state["model_tag"])

        def optimized_call(*example_inputs):
            """
            On the first run of the optimized call, we rerun the compiler
            backend which should result in a cache hit. After the backend
            call returns, we just do a one-time replacement of the optimized
            call with the compiled function, so that subsequent calls are on
            the AOT compiled path.
            """
            # NOTE: `fn` is the closure variable assigned below; by the time this
            # runs, the wrapper instance exists. Tensor inputs were masked to
            # None during serialization, so fill them from the live call.
            compile_inputs = [inp or example_inputs[i] for i, inp in enumerate(fn.example_inputs)]
            with tracing(TracingContext(fake_mode)):
                fn.optimized_call = magi_backend(state["graph_module"], compile_inputs).optimized_call
            return fn.optimized_call(*example_inputs)

        fn = cls(**state, optimized_call=optimized_call)
        return fn

    @property
    def co_name(self):
        """
        Used for depyf debugging.
        """
        return "MagiSerializableFunction"
pkgs/MagiCompiler/magi_compiler/config.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import os
17
+ from enum import Enum, unique
18
+ from pathlib import Path
19
+ from typing import Any, Literal
20
+
21
+ import torch
22
+ from pydantic import BaseModel, Field
23
+ from pydantic_settings import BaseSettings, SettingsConfigDict
24
+
25
+ from .utils import OrderedSet, compute_hash, magi_logger
26
+
27
+
28
@unique
class CompileMode(Enum):
    """
    How torch.compile-based compilation is applied to the model.

    NONE: no compilation; the model runs as-is in fully eager PyTorch.
    TORCH_COMPILE: the stock `torch.compile` pipeline.
    MAGI_COMPILE: the custom Inductor-based backend with caching, piecewise
        compilation, shape specialization, and custom passes.
    """

    NONE = "NONE"
    TORCH_COMPILE = "TORCH_COMPILE"
    MAGI_COMPILE = "MAGI_COMPILE"
41
+
42
+
43
@unique
class CudaGraphMode(Enum):
    """
    Cudagraph mode constants used by CompileConfig.

    Unlike the LLM-side CUDAGraphMode, diffusion models only need the
    PIECEWISE and FULL variants.

    NONE: cudagraphs are not used.
    PIECEWISE: cudagraphs wrap each piecewise-compiled subgraph.
    FULL: a single cudagraph wraps the full compiled model.
    """

    NONE = "NONE"
    PIECEWISE = "PIECEWISE"
    FULL = "FULL"
57
+
58
+
59
class PassConfig(BaseModel):
    """Configuration for custom Inductor passes"""

    enable_fusion: bool = Field(False, description="Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass.")
    enable_attn_fusion: bool = Field(False, description="Whether to enable the custom attention+quant fusion pass.")
    enable_noop: bool = Field(False, description="Whether to enable the custom no-op elimination pass.")
    enable_sequence_parallelism: bool = Field(False, description="Whether to enable sequence parallelism.")
    enable_async_tp: bool = Field(False, description="Whether to enable async TP.")
    enable_fi_allreduce_fusion: bool = Field(False, description="Whether to enable flashinfer allreduce fusion.")
    enable_sage_attn: bool = Field(False, description="Whether to replace flash attention with sage attention.")
    fi_allreduce_fusion_max_token_num: int = Field(
        16384, description="Max number of tokens to used in flashinfer allreduce fusion."
    )

    def model_post_init(self, __context: Any) -> None:
        # BUGFIX: this hook was previously named `__post_init__`, which is a
        # dataclasses hook that pydantic's BaseModel never calls, so these
        # sanity warnings were dead code. Pydantic v2 invokes
        # `model_post_init` after validation.
        if not self.enable_noop:
            if self.enable_fusion:
                magi_logger.warning(
                    "Fusion enabled but reshape elimination disabled. " "RMSNorm/SiluMul + quant (fp8) fusion might not work"
                )
            if self.enable_attn_fusion:
                magi_logger.warning(
                    "Fusion enabled but reshape elimination disabled. " "Attention + quant (fp8) fusion might not work"
                )

    @property
    def hash(self) -> str:
        """Stable content hash of this pass configuration (cache-key component)."""
        return compute_hash(self.model_dump(mode="json"))

    # Compatible with torch pass manager interfaces, which expect a uuid() method.
    def uuid(self) -> str:
        return self.hash
91
+
92
+
93
@unique
class RecomputePolicy(Enum):
    """
    Strategy for activation recomputation (rematerialization), trading memory
    footprint against extra compute.

    HANDCRAFT:
        User-driven: a `memory_budget` threshold in [0.0, 1.0] sets the target
        fraction of activations to keep resident rather than recompute.

    HEURISTIC:
        Rule-based: activations produced by compute-bound operators (MatMul,
        Attention, ...) are saved, while memory-bound / element-wise operator
        outputs are preferentially recomputed to free memory.

    AUTOSEARCH:
        Automated: searches for the best set of saved tensors given available
        device memory, preferring tensors whose recomputation cost is high
        relative to their memory footprint.

    .. note::
        A `repeat_number` argument is currently needed to stabilize the
        profiling/search phase; this requirement goes away once full-graph
        capture is natively supported.
    """

    HANDCRAFT = "HANDCRAFT"
    HEURISTIC = "HEURISTIC"
    AUTOSEARCH = "AUTOSEARCH"
+ AUTOSEARCH = "AUTOSEARCH"
123
+
124
+
125
class RecomputeConfig(BaseModel):
    """Settings controlling activation recomputation (see RecomputePolicy)."""

    # Which recomputation strategy to apply.
    recompute_policy: RecomputePolicy = Field(RecomputePolicy.HEURISTIC, description="Recompute policy.")
    # Fraction of activation memory to keep; only consulted by HANDCRAFT.
    memory_budget: float = Field(0.5, description="Activation memory budget for recomputation, only used for handcraft.")
    # Profiling repetitions to stabilize the search; only consulted by AUTOSEARCH.
    repeat_number: int = Field(default=1, description="Repeat number for recomputation, only used for autosearch.")
129
+
130
+
131
@unique
class OffloadPolicy(Enum):
    """
    Policy used when offloading model weights to CPU.

    BASE:
        Baseline policy: offload every submodule to CPU.
    COST_EFFECTIVE:
        Choose which submodules to offload based on a cost-effectiveness
        criterion.
    HEURISTIC:
        Choose which submodules to offload based on heuristic rules.
    """

    BASE = "BASE"
    COST_EFFECTIVE = "COST_EFFECTIVE"
    HEURISTIC = "HEURISTIC"
+
151
+
152
class OffloadConfig(BaseModel):
    """Settings controlling CPU offload of model weights (see OffloadPolicy)."""

    # Master switch for CPU offload.
    model_cpu_offload: bool = Field(False, description="Whether to offload the model to CPU.")
    # Fraction of weights kept resident on the GPU when offload is enabled.
    gpu_resident_weight_ratio: float = Field(
        0.3, description="The ratio of GPU memory to keep when offloading the model to CPU."
    )
    # Strategy for picking which submodules are offloaded.
    offload_policy: OffloadPolicy = Field(
        OffloadPolicy.COST_EFFECTIVE, description="The policy for offloading the model to CPU."
    )
    # Multiplier (<1.0) applied to the measured host-to-device bandwidth.
    bandwidth_safety_factor: float = Field(0.9, description="The safety factor for the H2D bandwidth.")
161
+
162
+
163
class CompileConfig(BaseSettings):
    """
    Top-level compilation settings for MagiCompiler.

    As a pydantic ``BaseSettings`` with CLI parsing enabled, values can come
    from the command line (unknown args ignored) as well as the defaults below.
    """

    model_config = SettingsConfigDict(cli_parse_args=True, cli_ignore_unknown_args=True, cli_implicit_flags=True)

    # Basic configs
    backend: Literal["inductor", "eager"] = Field("inductor", description="Compilation backend.")
    compile_mode: CompileMode = Field(CompileMode.MAGI_COMPILE, description="Compilation mode.")
    cache_root_dir: str = Field(
        default=os.path.expanduser("~/.cache/magi_compiler"), description="Directory to cache the compiled model."
    )
    dynamic_sources: str = Field(
        default=os.environ.get("TORCH_COMPILE_DYNAMIC_SOURCES", ""),
        description="Comma delimited list of sources that should be marked as dynamic.",
    )

    # CPU Offload
    offload_config: OffloadConfig = Field(OffloadConfig(), description="Offload configuration.")

    # Inductor configs
    # TODO(hongyu): Add unittest for compile_sizes
    compile_sizes: list[int] = Field(default_factory=list, description="Sizes to compile the model for.")
    use_inductor_graph_partition: bool = Field(
        False, description="Whether to use inductor graph partition. Not fully supported yet."
    )
    # TODO(hongyu): Find a better way to specify the splitting ops.
    splitting_ops: list[str] = Field(
        default_factory=lambda: [
            "athena::flash_attn_func",
            "athena::flex_flash_attn_func",
            "athena::sage_attn_func",
            "athena::flash_attn_with_cp",
            "athena::flex_flash_attn_with_cp",
        ],
        description="Operators to split the graph into piecewise graphs.",
    )

    # Pass configs
    pass_config: PassConfig = Field(PassConfig(), description="Pass configuration.")

    # Recompute configs
    recompute_config: RecomputeConfig = Field(RecomputeConfig(), description="Recompute configuration.")

    # Cudagraph configs
    cudagraph_mode: CudaGraphMode = Field(CudaGraphMode.NONE, description="Cudagraph mode.")
    cudagraph_copy_inputs: bool = Field(True, description="Whether to copy inputs for cudagraph.")

    # Runtime configs, maybe changed at runtime
    model_idx: int = Field(0, description="Index of the model.")
    model_tag: str | None = Field(
        default=None, description="Tag in cache path: model_{idx}_{model_tag}_rank_{rank}. Class name if unset."
    )
    inductor_compile_config: dict[str, Any] = Field(default_factory=dict, description="Inductor compilation configuration.")
    traced_files: OrderedSet[str] = Field(default_factory=OrderedSet, description="Files traced by Dynamo.")

    def _model_rank_dir_name(self) -> str:
        """Directory name for this model instance: model_{idx}[_{model_tag}]_rank_{rank}."""
        # Rank 0 is assumed when torch.distributed has not been initialized.
        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
        if self.model_tag:
            return f"model_{self.model_idx}_{self.model_tag}_rank_{rank}"
        return f"model_{self.model_idx}_rank_{rank}"

    def debug_dump_path(self) -> Path:
        """Per-model directory for magi_depyf debug dumps."""
        return Path(self.cache_root_dir) / "magi_depyf" / self._model_rank_dir_name()

    def cache_dump_path(self) -> Path:
        """Per-model directory for the torch.compile / Inductor cache."""
        return Path(self.cache_root_dir) / "torch_compile_cache" / self._model_rank_dir_name()

    @property
    def hash(self) -> str:
        """
        Stable content hash of this configuration, used for cache invalidation.

        `inductor_compile_config` may hold non-JSON-serializable objects (e.g.
        pass managers), so it is excluded from the normal dump and serialized
        separately via `uuid()` / `str()` fallbacks.
        """
        # Create a copy of the config data for serialization
        data = self.model_dump(mode="json", exclude={"inductor_compile_config"})

        # Handle inductor_compile_config separately to serialize objects with uuid() method
        # This is a workaround to support serialization of PostGradPassManager in Pydantic models.
        if self.inductor_compile_config:
            serialized_inductor_config = {}
            for key, value in self.inductor_compile_config.items():
                # If the value has a uuid() method (like PostGradPassManager), use it
                if hasattr(value, "uuid") and callable(getattr(value, "uuid", None)):
                    try:
                        serialized_inductor_config[key] = value.uuid()
                    except (AttributeError, RuntimeError):
                        # Fallback to string representation if uuid() fails
                        serialized_inductor_config[key] = str(value)
                else:
                    # For other types, try to serialize normally
                    try:
                        # Try to serialize as JSON-serializable
                        json.dumps(value)
                        serialized_inductor_config[key] = value
                    except (TypeError, ValueError):
                        # If not JSON-serializable, use string representation
                        serialized_inductor_config[key] = str(value)
            data["inductor_compile_config"] = serialized_inductor_config

        return compute_hash(data)

    def __str__(self, indent: int = 4):
        """Pretty JSON-style rendering of the config (double quotes stripped)."""
        data = self.model_dump(mode="json")
        formatted = json.dumps(data, indent=indent, ensure_ascii=False, sort_keys=False)

        # add configuration class name as title
        class_name = self.__class__.__name__
        return f"{class_name}:\n{formatted}".replace('"', "")

    def __repr__(self, indent: int = 4):
        return self.__str__(indent=indent)
269
+
270
+
271
+ _GLOBAL_COMPILE_CONFIG = None
272
+
273
+
274
def get_compile_config() -> CompileConfig:
    """Return the process-wide CompileConfig singleton, creating it lazily."""
    global _GLOBAL_COMPILE_CONFIG
    if _GLOBAL_COMPILE_CONFIG is None:
        _GLOBAL_COMPILE_CONFIG = CompileConfig()
        is_rank_zero = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
        if is_rank_zero:
            # Log the config once on first initialization; hidden at the
            # default WARNING log level.
            magi_logger.info("compile config: %s", _GLOBAL_COMPILE_CONFIG)
    assert _GLOBAL_COMPILE_CONFIG is not None, "compile config is not initialized"
    return _GLOBAL_COMPILE_CONFIG
pkgs/MagiCompiler/magi_compiler/cuda/cudart.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import ctypes
17
+ import os
18
+
19
+ import torch
20
+
21
# Lazily-loaded handle to the CUDA runtime shared library (None until loaded).
_cudart = None


def init_cudart():
    """Load libcudart via ctypes, caching the handle in the module global.

    Tries common soname variants first, then the lib64 directory of the CUDA
    home reported by torch.

    Returns:
        The ctypes.CDLL handle for libcudart, or None if no CUDA runtime
        library could be loaded.
    """
    global _cudart
    if _cudart is not None:
        return _cudart
    # BUGFIX: the original list contained the single malformed entry
    # "libcudart.so.12.0, libcudart.so.13" (comma inside one string), so
    # neither of those sonames could ever be loaded. Split into two entries.
    candidates = ["libcudart.so", "libcudart.so.11.0", "libcudart.so.12.0", "libcudart.so.13"]
    try:
        # A bare "import torch" does not import the cpp_extension submodule.
        import torch.utils.cpp_extension

        cuda_path = os.path.dirname(torch.utils.cpp_extension._find_cuda_home())
        candidates.append(os.path.join(cuda_path, "lib64", "libcudart.so"))
    except Exception:
        # Best-effort only: _find_cuda_home is private and may be missing or
        # return None (dirname(None) raises TypeError).
        pass
    for lib in candidates:
        try:
            _cudart = ctypes.CDLL(lib)
            return _cudart
        except OSError:
            continue
    return None
41
+
42
+
43
def pin_memory_in_place(tensor: torch.Tensor):
    """
    Pin memory in-place using cudaHostRegister.
    """
    # Device tensors need no host pinning.
    if tensor.is_cuda:
        return tensor
    runtime = init_cudart()
    if runtime is None:
        # No CUDA runtime available: hand the tensor back unpinned (best effort).
        return tensor

    addr = ctypes.c_void_p(tensor.data_ptr())
    nbytes = ctypes.c_size_t(tensor.numel() * tensor.element_size())
    status = runtime.cudaHostRegister(addr, nbytes, 0)
    if status != 0:
        raise RuntimeError(f"cudaHostRegister failed with error code {status}")
    return tensor
pkgs/MagiCompiler/magi_compiler/cuda_graph_mgr.py ADDED
@@ -0,0 +1,931 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass, fields, is_dataclass
16
+ from functools import wraps
17
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
+
19
+ import torch
20
+
21
+ from .utils import magi_logger, nvtx
22
+
23
+
24
class InplaceSubstituteFakeClass:
    """
    Marker base class for objects that must keep their identity during
    argument substitution: instead of being rebuilt as new instances,
    subclasses have their attributes rewritten in-place.
    For example, InferenceParams.
    """

    pass
32
+
33
+
34
@dataclass
class FakeTensor:
    # Lightweight stand-in recording a real tensor's metadata inside cached
    # output templates; holds no data.
    # NOTE(review): extract_output_template stores shape as list(o.shape)
    # despite the Tuple annotation — confirm before tightening these types.
    shape: Tuple[int, ...] = None
    dtype: str = None
    device: str = None
39
+
40
+
41
@dataclass
class HashableDataclass:
    """Base dataclass providing a cached, field-derived hash and hash-based equality.

    Only hashable field values (str/int/float/bool, nested HashableDataclass
    instances, and tuples thereof) participate in the hash; None fields and
    non-hashable values are skipped.
    """

    # Memoized 64-bit hash; computed on first use and never invalidated, so
    # fields must not be mutated after the first hash/eq call.
    _cached_hash: Optional[int] = None

    @nvtx.instrument_nvtx
    def _get_hashable_fields(self) -> Tuple[Any, ...]:
        """Collect the hashable field values (skipping _cached_hash and None)."""
        hashable_values = []
        for f in fields(self):
            if f.name == "_cached_hash":
                continue
            value = getattr(self, f.name)
            if value is None:
                continue
            if isinstance(value, HashableDataclass):
                # Nested dataclasses contribute their cached hash, not the object.
                hashable_values.append(value._get_cached_hash())
            elif isinstance(value, tuple):
                tuple_vals = []
                for item in value:
                    if isinstance(item, (HashableDataclass, str, int, float, bool)):
                        if isinstance(item, HashableDataclass):
                            tuple_vals.append(item._get_cached_hash())
                        else:
                            tuple_vals.append(item)
                # Empty tuples (or tuples of unhashable items) are dropped entirely.
                if tuple_vals:
                    hashable_values.append(tuple(tuple_vals))
            elif isinstance(value, (str, int, float, bool)):
                hashable_values.append(value)
        return tuple(hashable_values)

    @nvtx.instrument_nvtx
    def _compute_hash(self) -> int:
        """Computes a hash value based on the dataclass's hashable fields."""
        hashable_fields = self._get_hashable_fields()
        return hash(hashable_fields) % (1 << 64)  # limit to 64 bits

    @nvtx.instrument_nvtx
    def _get_cached_hash(self) -> int:
        # Lazily compute and memoize; see the mutation caveat on _cached_hash.
        if self._cached_hash is None:
            self._cached_hash = self._compute_hash()
        return self._cached_hash

    @nvtx.instrument_nvtx
    def __hash__(self) -> int:
        return self._get_cached_hash()

    @nvtx.instrument_nvtx
    def __eq__(self, other: Any) -> bool:
        # NOTE(review): equality is decided purely by comparing cached hashes,
        # so a 64-bit hash collision would make distinct instances compare
        # equal — confirm this is an acceptable trade-off for cache lookups.
        if not isinstance(other, self.__class__):
            return False
        if self._get_cached_hash() != other._get_cached_hash():
            return False
        return True
93
+
94
+
95
@dataclass(unsafe_hash=True)
class LiteralsInfo(HashableDataclass):
    # Scalar (int/float/str/bool) arguments extracted from a call, in order.
    literals: Tuple[Any, ...] = tuple()
98
+
99
+
100
@dataclass(unsafe_hash=True)
class TensorStaticInfo(HashableDataclass):
    # Static view of a tensor argument: dimensions allowed to vary are -1.
    name: str = ""
    shapes: Tuple[int, ...] = tuple()
    dtype: str = ""
105
+
106
+
107
@dataclass(unsafe_hash=True)
class TensorDynamicInfo(HashableDataclass):
    # Dynamic view of a tensor argument: static dimensions are -1.
    name: str = ""
    shapes: Tuple[int, ...] = tuple()
111
+
112
+
113
@dataclass(unsafe_hash=True)
class StaticSignature(HashableDataclass):
    # Cache key for the static-buffer level: function name + static tensor infos.
    func_name: str = ""
    tensor_static_infos: Tuple[TensorStaticInfo, ...] = tuple()
117
+
118
+
119
@dataclass(unsafe_hash=True)
class DynamicSignature(HashableDataclass):
    # Cache key for the per-shape level: dynamic tensor shapes + call literals.
    tensor_dynamic_infos: Tuple[TensorDynamicInfo, ...] = tuple()
    literals_info: LiteralsInfo = None
123
+
124
+
125
@dataclass
class GraphEntry:
    # A captured CUDA graph plus two failure flags; only an entry with a
    # non-None graph and both flags False is considered usable.
    graph: Optional[torch.cuda.CUDAGraph] = None
    inconsistent: bool = False
    invalid: bool = False
130
+
131
+
132
@dataclass
class OutputTemplateEntry:
    graph_entry_dict: Dict[int, GraphEntry] = None  # key = layer_number
    output_template: Any = None  # structure template for the output object's literals
136
+
137
+
138
@dataclass
class StaticTensorEntry:
    # Static (pre-allocated, max-shape) buffers shared by all graphs captured
    # under one StaticSignature, plus the per-DynamicSignature templates.
    input_tensors: Optional[List[torch.Tensor]] = None
    output_tensors: Optional[List[torch.Tensor]] = None
    template_entry_dict: Dict[DynamicSignature, OutputTemplateEntry] = None
143
+
144
+
145
class ArgsUtils:
    """Static helpers for CUDA-graph argument handling.

    Responsibilities: extracting tensors/literals from nested argument
    structures, building hashable signatures from them, and copying data
    between caller tensors and the cached static buffers.
    """

    @staticmethod
    @nvtx.instrument_nvtx
    def generate_both_signatures_from_tensors(
        func_name: str, tensors: List[torch.Tensor], names: List[str], literals: List[Any]
    ) -> Tuple[StaticSignature, DynamicSignature]:
        """Build the (static, dynamic) signature pair for one call.

        Convention: for tensors with more than one dimension, the last
        dimension is static and the others are dynamic; 1-D tensors are fully
        dynamic. `names` is currently unused here (infos keep their default name).
        """
        num_tensors = len(tensors)
        tensor_static_infos = [TensorStaticInfo() for _ in range(num_tensors)]
        tensor_dynamic_infos = [TensorDynamicInfo() for _ in range(num_tensors)]

        # Local references for performance
        TensorStaticInfo_setattr = TensorStaticInfo.__setattr__
        TensorDynamicInfo_setattr = TensorDynamicInfo.__setattr__
        _tuple = tuple

        for i in range(num_tensors):
            t = tensors[i]
            t_dim = t.dim()
            t_shape = t.shape
            t_dtype_str = str(t.dtype)
            # Last dimension is static, others are dynamic (except for 1D tensor)
            static_shapes = (
                _tuple(-1 if idx != t_dim - 1 else dim_size for idx, dim_size in enumerate(t_shape)) if t_dim > 1 else (-1,)
            )
            static_info = tensor_static_infos[i]
            TensorStaticInfo_setattr(static_info, "shapes", static_shapes)
            TensorStaticInfo_setattr(static_info, "dtype", t_dtype_str)

            # BUGFIX: the original chained assignment
            # `dynamic_shapes = static_shapes = (...)` accidentally rebound
            # static_shapes as well; only dynamic_shapes is intended here.
            dynamic_shapes = (
                _tuple(-1 if idx == t_dim - 1 else dim_size for idx, dim_size in enumerate(t_shape))
                if t_dim > 1
                else _tuple(t_shape)
            )
            dynamic_info = tensor_dynamic_infos[i]
            TensorDynamicInfo_setattr(dynamic_info, "shapes", dynamic_shapes)

        literals_info = LiteralsInfo(literals=_tuple(literals))
        static_sig = StaticSignature(func_name=func_name, tensor_static_infos=_tuple(tensor_static_infos))
        dynamic_sig = DynamicSignature(tensor_dynamic_infos=_tuple(tensor_dynamic_infos), literals_info=literals_info)
        return static_sig, dynamic_sig

    @staticmethod
    @nvtx.instrument_nvtx
    def replace_sliced_with_static(obj: Any, static_tensors: List[torch.Tensor]) -> Any:
        """Copy each tensor found in obj into its static buffer and return a
        mirror of obj whose tensors are views into those buffers.

        Tensor order must match the extraction order used to build
        static_tensors. Parameters are left untouched.
        """
        tensor_idx = 0

        def recursive_replace(o: Any) -> Any:
            nonlocal tensor_idx
            if isinstance(o, torch.Tensor) and not isinstance(o, torch.nn.Parameter):
                # Copy data to the corresponding static tensor slice
                static_tensor = static_tensors[tensor_idx]
                slices = [slice(None)] * static_tensor.ndim
                for i in range(min(o.ndim, static_tensor.ndim)):
                    slices[i] = slice(0, o.shape[i])
                # Only copy if the data_ptrs are different
                if not o.data_ptr() == static_tensor[tuple(slices)].data_ptr():
                    static_tensor[tuple(slices)].copy_(o)
                tensor_idx += 1
                return static_tensor[tuple(slices)]

            elif isinstance(o, dict):
                return {k: recursive_replace(v) for k, v in o.items()}
            elif isinstance(o, (list, tuple)):
                return type(o)(recursive_replace(item) for item in o)
            elif is_dataclass(o):
                field_values = {f.name: recursive_replace(getattr(o, f.name)) for f in fields(o)}
                return type(o)(**field_values)
            elif issubclass(o.__class__, InplaceSubstituteFakeClass):
                # Do not create a new instance, but modify attributes in place (to keep original initialization logic)
                for k, v in o.__dict__.items():
                    if not callable(v):
                        o.__dict__[k] = recursive_replace(v)
                return o
            elif o is None or isinstance(o, (int, float, str, bool)):
                return o  # Keep None and basic types
            else:
                return o

        return recursive_replace(obj)

    @staticmethod
    @nvtx.instrument_nvtx
    def replace_sliced_with_static_simple(
        sliced_tensors: List[torch.Tensor], static_tensors: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        """Fast path: copy each sliced tensor into the matching static buffer
        slice, skipping tensors already backed by the same storage address."""
        for sliced_tensor, static_tensor in zip(sliced_tensors, static_tensors):
            if not sliced_tensor.data_ptr() == static_tensor.data_ptr():
                slices = [slice(None)] * static_tensor.ndim
                for i in range(sliced_tensor.ndim):
                    slices[i] = slice(0, sliced_tensor.shape[i])
                static_tensor[tuple(slices)].copy_(sliced_tensor)

    @staticmethod
    @nvtx.instrument_nvtx
    def replace_static_with_sliced(obj: Any, static_tensors: List[torch.Tensor]) -> Any:
        """Inverse of replace_sliced_with_static: rebuild obj (usually an output
        template containing FakeTensor placeholders) with views into the static
        buffers, sliced to each placeholder's recorded shape."""
        tensor_idx = 0

        def recursive_replace(o: Any) -> Any:
            nonlocal tensor_idx
            if (isinstance(o, torch.Tensor) and not isinstance(o, torch.nn.Parameter)) or isinstance(o, FakeTensor):
                # Replace with the corresponding sliced tensor
                static_tensor = static_tensors[tensor_idx]
                shape_to_slice = o.shape
                slices = [slice(0, dim_size) for dim_size in shape_to_slice]
                result_tensor = static_tensor[tuple(slices)]
                tensor_idx += 1
                return result_tensor

            elif isinstance(o, dict):
                return {k: recursive_replace(v) for k, v in o.items()}
            elif isinstance(o, (list, tuple)):
                return type(o)(recursive_replace(item) for item in o)
            elif is_dataclass(o):
                field_values = {f.name: recursive_replace(getattr(o, f.name)) for f in fields(o)}
                return type(o)(**field_values)
            elif issubclass(o.__class__, InplaceSubstituteFakeClass):
                # Do not create a new instance, but modify attributes in place (to keep original initialization logic)
                for k, v in o.__dict__.items():
                    if not callable(v):
                        o.__dict__[k] = recursive_replace(v)
                return o
            elif o is None or isinstance(o, (int, float, str, bool)):
                return o  # Keep None and basic types
            else:
                return o

        return recursive_replace(obj)

    @staticmethod
    @nvtx.instrument_nvtx
    def try_fx_extract_core(
        obj: Any, extract_tensors: bool = True, extract_literals: bool = True, with_names: bool = False
    ) -> Tuple[List[torch.Tensor], List[str], List[Any]]:
        """Fast extraction for the common FX calling convention:
        obj == {"args": flat sequence, "kwargs": {}}.

        Returns (None, None, None) when obj does not match that shape, so the
        caller can fall back to recursive_extract_core. `with_names` is
        accepted for signature parity but names are always empty strings here.
        """
        failed_tuple = None, None, None
        tensors = []
        names = []
        literals = []

        if not isinstance(obj, dict) or "args" not in obj or "kwargs" not in obj:
            return failed_tuple
        args, kwargs = obj["args"], obj["kwargs"]
        if kwargs:
            return failed_tuple
        if not isinstance(args, (list, tuple)):
            return failed_tuple

        for idx, item in enumerate(args):
            if extract_tensors and isinstance(item, torch.Tensor) and not isinstance(item, torch.nn.Parameter):
                tensors.append(item)
            elif extract_literals and isinstance(item, (int, float, str, bool)):
                literals.append(item)

        names = [""] * len(tensors)
        return tensors, names, literals

    @staticmethod
    @nvtx.instrument_nvtx
    def recursive_extract_core(
        obj: Any, extract_tensors: bool = True, extract_literals: bool = True, with_names: bool = False
    ) -> Tuple[List[torch.Tensor], List[str], List[Any]]:
        """Walk obj depth-first and collect tensors, optional dotted path names,
        and scalar literals, in traversal order."""
        tensors = []
        names = []
        literals = []

        def recursive_traverse(o: Any, prefix: str = ""):
            # 1. Extract tensors (if enabled)
            if extract_tensors and isinstance(o, torch.Tensor) and not isinstance(o, torch.nn.Parameter):
                tensors.append(o)
                if with_names:
                    names.append(prefix)
            elif extract_literals and isinstance(o, (int, float, str, bool)):
                # FIX: the original re-checked extract_literals in a
                # conditional-expression statement; the elif guard already
                # guarantees it here.
                literals.append(o)
            elif isinstance(o, dict):
                for k, v in o.items():
                    new_prefix = f"{prefix}.{k}" if (with_names and extract_tensors) else prefix
                    recursive_traverse(v, new_prefix)
            elif isinstance(o, (list, tuple)):
                for idx, item in enumerate(o):
                    new_prefix = f"{prefix}[{idx}]" if (with_names and extract_tensors) else prefix
                    recursive_traverse(item, new_prefix)
            elif is_dataclass(o):
                for f in fields(o):
                    new_prefix = f"{prefix}.{f.name}" if (with_names and extract_tensors) else prefix
                    recursive_traverse(getattr(o, f.name), new_prefix)
            elif issubclass(o.__class__, InplaceSubstituteFakeClass):
                for k, v in o.__dict__.items():
                    if not callable(v):
                        new_prefix = f"{prefix}.{k}" if (with_names and extract_tensors) else prefix
                        recursive_traverse(v, new_prefix)
            elif o is None:
                pass
            else:
                pass

        recursive_traverse(obj)
        return tensors, names if with_names else [""] * len(tensors), literals if extract_literals else None

    @staticmethod
    @nvtx.instrument_nvtx
    def extract_output_template(obj: Any) -> Any:
        """Build a data-free structural template of obj: tensors become
        FakeTensor metadata records, containers and dataclasses are rebuilt,
        and scalars are kept as-is."""
        def recursive_template(o: Any) -> Any:
            if isinstance(o, torch.Tensor) and not isinstance(o, torch.nn.Parameter):
                return FakeTensor(shape=list(o.shape), dtype=str(o.dtype), device=str(o.device))
            elif isinstance(o, dict):
                return {k: recursive_template(v) for k, v in o.items()}
            elif isinstance(o, (list, tuple)):
                return type(o)(recursive_template(item) for item in o)
            elif is_dataclass(o):
                field_values = {f.name: recursive_template(getattr(o, f.name)) for f in fields(o)}
                return type(o)(**field_values)
            elif issubclass(o.__class__, InplaceSubstituteFakeClass):
                # Do not re-create the instance; modify attributes in place
                # (keeps the original initialization logic).
                for k, v in o.__dict__.items():
                    if not callable(v):
                        o.__dict__[k] = recursive_template(v)
                return o
            elif o is None or isinstance(o, (int, float, str, bool)):
                return o
            else:
                return o

        return recursive_template(obj)
366
+
367
+
368
+ class CudaGraphMgr:
369
+ """CUDA Graph Manager for caching and managing CUDA Graphs and static tensors."""
370
+
371
+ def __init__(self):
372
+ self.cache: Dict[StaticSignature, StaticTensorEntry] = dict()
373
+ self.graph_mem_pool: Optional[torch.cuda.graph_pool_handle] = None
374
+ self.check_output_inconsistency = False # Not enabled by default
375
+
376
+ @property
377
+ def graph_count(self) -> int:
378
+ count = 0
379
+ for tensor_entry in self.cache.values():
380
+ if tensor_entry.template_entry_dict is not None:
381
+ for template_entry in tensor_entry.template_entry_dict.values():
382
+ for graph_entry in template_entry.graph_entry_dict.values():
383
+ if graph_entry.graph is not None and not graph_entry.inconsistent and not graph_entry.invalid:
384
+ count += 1
385
+ return count
386
+
387
+ @property
388
+ def tensor_entry_count(self) -> int:
389
+ count = 0
390
+ for tensor_entry in self.cache.values():
391
+ if tensor_entry.input_tensors is not None and tensor_entry.output_tensors is not None:
392
+ count += 1
393
+ return count
394
+
395
+ @property
396
+ def graph_mem_pool_size(self) -> float:
397
+ if not hasattr(self, "graph_mem_pool") or self.graph_mem_pool is None:
398
+ return 0.0
399
+ pool_stats = torch.cuda.memory.memory_stats(self.graph_mem_pool)
400
+ used_mem = pool_stats.get("allocated_bytes.all.current", 0)
401
+ return used_mem / (1024 * 1024) # 转换为MB
402
+
403
+ @property
404
+ def tensor_mem_size(self) -> float:
405
+ total_size = 0 # 字节
406
+ for tensor_entry in self.cache.values():
407
+ if tensor_entry.input_tensors is not None:
408
+ for t in tensor_entry.input_tensors:
409
+ total_size += t.element_size() * t.nelement()
410
+ if tensor_entry.output_tensors is not None:
411
+ for t in tensor_entry.output_tensors:
412
+ total_size += t.element_size() * t.nelement()
413
+ return total_size / (1024 * 1024) # 转换为MB
414
+
415
    @nvtx.instrument_nvtx
    def formatted_cache_str(self) -> str:
        """Format the cache content as a string for debugging."""
        # NOTE(review): entries created by set_graph_inconsistent can have
        # input_tensors/output_tensors == None, which would raise TypeError in
        # the loops below — confirm this is only called on populated caches.
        lines = []
        for static_sig, tensor_entry in self.cache.items():
            lines.append(f"StaticSignature: {static_sig}")
            s = " Input Static Tensors: "
            for it in tensor_entry.input_tensors:
                s += f"[shape={list(it.shape)},dtype={str(it.dtype)}] "
            lines.append(s)
            s = " Output Static Tensors: "
            for ot in tensor_entry.output_tensors:
                s += f"[shape={list(ot.shape)},dtype={str(ot.dtype)}] "
            lines.append(s)
            if tensor_entry.template_entry_dict is not None:
                for dynamic_sig, template_entry in tensor_entry.template_entry_dict.items():
                    lines.append(f" DynamicSignature: {dynamic_sig}")
                    lines.append(f" Output Template: {template_entry.output_template}")
                    for layer_number, graph_entry in template_entry.graph_entry_dict.items():
                        # inconsistent takes display precedence over invalid.
                        status = "Valid"
                        if graph_entry.inconsistent:
                            status = "Inconsistent"
                        elif graph_entry.invalid:
                            status = "Invalid"
                        lines.append(f" Layer {layer_number}: Graph Status: {status}")
        return "\n".join(lines)
441
+
442
+ @nvtx.instrument_nvtx
443
+ def try_get_cuda_graph(
444
+ self, static_sig: StaticSignature, dynamic_sig: DynamicSignature, layer_number: int
445
+ ) -> Optional[torch.cuda.CUDAGraph]:
446
+ graph_entry = self.try_get_graph_entry(static_sig, dynamic_sig, layer_number)
447
+ if (
448
+ graph_entry is not None
449
+ and graph_entry.graph is not None
450
+ and not graph_entry.inconsistent
451
+ and not graph_entry.invalid
452
+ ):
453
+ return graph_entry.graph
454
+ return None
455
+
456
+ @nvtx.instrument_nvtx
457
+ def get_static_tensors(self, input_static_sig: StaticSignature) -> Optional[Tuple[List[torch.Tensor], List[torch.Tensor]]]:
458
+ if input_static_sig in self.cache:
459
+ cached_entry = self.cache[input_static_sig]
460
+ return cached_entry.input_tensors, cached_entry.output_tensors
461
+ raise ValueError("Cached input/output tensors not found for the given static signature.")
462
+
463
+ @nvtx.instrument_nvtx
464
+ def warmup_run(self, func: Callable, *args, **kwargs) -> Union[torch.Tensor, List[torch.Tensor]]:
465
+ warmup_outputs = None
466
+ s = torch.cuda.Stream()
467
+ s.wait_stream(torch.cuda.current_stream())
468
+ with torch.cuda.stream(s), torch.no_grad():
469
+ for _ in range(1):
470
+ warmup_outputs = func(*args, **kwargs)
471
+ torch.cuda.current_stream().wait_stream(s)
472
+ return warmup_outputs
473
+
474
+ @nvtx.instrument_nvtx
475
+ def add_static_entry(
476
+ self,
477
+ static_sig: StaticSignature,
478
+ input_tensors: Optional[List[torch.Tensor]] = None,
479
+ output_tensors: Optional[List[torch.Tensor]] = None,
480
+ ) -> None:
481
+ assert static_sig not in self.cache
482
+ self.cache[static_sig] = StaticTensorEntry(
483
+ input_tensors=input_tensors, output_tensors=output_tensors, template_entry_dict=dict()
484
+ )
485
+
486
+ @nvtx.instrument_nvtx
487
+ def add_template_entry(
488
+ self, input_static_sig: StaticSignature, input_dynamic_sig: DynamicSignature, output_obj: Any = None
489
+ ) -> None:
490
+ try:
491
+ output_template = ArgsUtils.extract_output_template(output_obj)
492
+ self.cache[input_static_sig].template_entry_dict[input_dynamic_sig] = OutputTemplateEntry(
493
+ graph_entry_dict=dict(), output_template=output_template
494
+ )
495
+ except KeyError:
496
+ raise ValueError("StaticSignature not found in cache when adding template entry.")
497
+
498
+ @nvtx.instrument_nvtx
499
+ def add_graph_entry(
500
+ self,
501
+ input_static_sig: StaticSignature,
502
+ input_dynamic_sig: DynamicSignature,
503
+ layer_number: int,
504
+ graph: torch.cuda.CUDAGraph,
505
+ ) -> None:
506
+ try:
507
+ self.cache[input_static_sig].template_entry_dict[input_dynamic_sig].graph_entry_dict[layer_number] = GraphEntry(
508
+ graph=graph, inconsistent=False, invalid=False
509
+ )
510
+ except KeyError:
511
+ raise ValueError("StaticSignature or DynamicSignature not found in cache when adding graph entry.")
512
+
513
+ @nvtx.instrument_nvtx
514
+ def try_get_graph_entry(
515
+ self, input_static_sig: StaticSignature, input_dynamic_sig: DynamicSignature, layer_number: int
516
+ ) -> Optional[GraphEntry]:
517
+ try:
518
+ return self.cache[input_static_sig].template_entry_dict[input_dynamic_sig].graph_entry_dict[layer_number]
519
+ except KeyError:
520
+ pass
521
+ return None
522
+
523
+ @nvtx.instrument_nvtx
524
+ def batch_set_graph_invalid(self, static_sig: StaticSignature) -> None:
525
+ if static_sig in self.cache:
526
+ static_tensor_entry = self.cache[static_sig]
527
+ if static_tensor_entry.template_entry_dict is not None:
528
+ for template_entry in static_tensor_entry.template_entry_dict.values():
529
+ for graph_entry in template_entry.graph_entry_dict.values():
530
+ graph_entry.invalid = True
531
+
532
    @nvtx.instrument_nvtx
    def set_graph_inconsistent(
        self, input_static_sig: StaticSignature, input_dynamic_sig: DynamicSignature, layer_number: int
    ) -> None:
        """Mark the graph for (static, dynamic, layer) as inconsistent, creating
        any missing cache levels on the way so the flag always sticks."""
        if input_static_sig not in self.cache:
            # Placeholder entry: static tensors stay None until a real capture.
            self.add_static_entry(input_static_sig, None, None)
        if input_dynamic_sig not in self.cache[input_static_sig].template_entry_dict:
            self.add_template_entry(input_static_sig, input_dynamic_sig, None)
        if layer_number not in self.cache[input_static_sig].template_entry_dict[input_dynamic_sig].graph_entry_dict:
            self.cache[input_static_sig].template_entry_dict[input_dynamic_sig].graph_entry_dict[layer_number] = GraphEntry(
                graph=None, inconsistent=True, invalid=False
            )
        # Set unconditionally: the entry may have pre-existed with inconsistent=False.
        self.cache[input_static_sig].template_entry_dict[input_dynamic_sig].graph_entry_dict[layer_number].inconsistent = True
545
+
546
    @nvtx.instrument_nvtx
    def wrapped_graph_capture(
        self,
        func: Callable,
        input_obj: Any,
        static_input_tensors: List[torch.Tensor],
        static_output_tensors: List[torch.Tensor],
    ) -> torch.cuda.CUDAGraph:
        """Capture func into a new CUDA graph backed by the shared memory pool.

        Inputs are first redirected into the static input buffers; the outputs
        produced during capture are copied into the static output buffers so
        replays read/write stable addresses.
        """
        init_cudagraph_global_pool()
        _set_capture_start()
        try:
            graph = torch.cuda.CUDAGraph()
            _static_input_obj = ArgsUtils.replace_sliced_with_static(input_obj, static_input_tensors)
            s = None  # future: s = GreenCtxManager(0).create_stream()
            with torch.cuda.graph(graph, pool=self.graph_mem_pool, stream=s), torch.no_grad():
                _sliced_output_obj = func(*_static_input_obj["args"], **_static_input_obj["kwargs"])
                _static_output_obj = ArgsUtils.replace_sliced_with_static(_sliced_output_obj, static_output_tensors)
        except Exception as e:
            torch.cuda.synchronize()  # wait for all pending async work to finish
            _set_capture_end()
            raise e
        _set_capture_end()
        return graph
569
+
570
+ @nvtx.instrument_nvtx
571
+ def wrapped_graph_replay(
572
+ self,
573
+ graph: torch.cuda.CUDAGraph,
574
+ static_input_tensors: List[torch.Tensor],
575
+ static_output_tensors: List[torch.Tensor],
576
+ input_obj: Any,
577
+ output_template: Any,
578
+ ) -> Any:
579
+ _static_input_obj = ArgsUtils.replace_sliced_with_static(input_obj, static_input_tensors)
580
+ graph.replay()
581
+ output_obj = ArgsUtils.replace_static_with_sliced(output_template, static_output_tensors)
582
+ return output_obj
583
+
584
    @nvtx.instrument_nvtx
    def replay_graph(
        self, input_static_sig: StaticSignature, input_dynamic_sig: DynamicSignature, input_obj: Any, layer_number: int
    ) -> Any:
        """Replay the cached graph for (static, dynamic, layer) on input_obj.

        Asserts the graph exists and is usable; raises KeyError if either
        signature level is missing from the cache.
        """
        output_template = self.cache[input_static_sig].template_entry_dict[input_dynamic_sig].output_template
        static_input_tensors = self.cache[input_static_sig].input_tensors
        static_output_tensors = self.cache[input_static_sig].output_tensors
        graph = self.try_get_cuda_graph(input_static_sig, input_dynamic_sig, layer_number=layer_number)
        assert graph is not None, "CUDA Graph not found for replay."
        output_obj = self.wrapped_graph_replay(
            graph=graph,
            static_input_tensors=static_input_tensors,
            static_output_tensors=static_output_tensors,
            input_obj=input_obj,
            output_template=output_template,
        )
        return output_obj
601
+
602
    @nvtx.instrument_nvtx
    def capture_and_cache(
        self,
        func: Callable,
        input_obj: Any,
        layer_number: int,
        input_static_sig: StaticSignature,
        input_dynamic_sig: DynamicSignature,
    ) -> Any:
        """Capture a new CUDA Graph and cache it."""
        # NOTE(review): annotated -> Any but the method returns None; callers
        # presumably replay separately — confirm before relying on a return value.
        # Access static tensors from cache
        static_tensor_entry = self.cache[input_static_sig]
        assert static_tensor_entry.input_tensors is not None
        assert static_tensor_entry.output_tensors is not None
        static_input_tensors = static_tensor_entry.input_tensors
        static_output_tensors = static_tensor_entry.output_tensors

        # Capture CUDA Graph
        graph = self.wrapped_graph_capture(
            func=func,
            input_obj=input_obj,
            static_input_tensors=static_input_tensors,
            static_output_tensors=static_output_tensors,
        )

        # Cache the captured graph, reusing an existing entry (e.g. one that was
        # previously marked inconsistent/invalid) when present.
        graph_entry = self.try_get_graph_entry(input_static_sig, input_dynamic_sig, layer_number)
        if graph_entry:
            graph_entry.graph = graph
            graph_entry.inconsistent = False
            graph_entry.invalid = False
        else:
            self.add_graph_entry(
                input_static_sig=input_static_sig, input_dynamic_sig=input_dynamic_sig, layer_number=layer_number, graph=graph
            )
637
+
638
    @nvtx.instrument_nvtx
    def if_need_expand_static_tensors(
        self, static_tensors: List[torch.Tensor], new_tensors: List[torch.Tensor], input_static_sig: StaticSignature
    ) -> bool:
        """Judge whether static tensors need to be expanded based on new tensors.

        Returns True when any dynamic dimension of a new tensor exceeds the
        current static buffer; raises AssertionError on count/rank/dtype
        mismatches or when a static (fixed) dimension differs.
        """
        res = False
        static_infos = input_static_sig.tensor_static_infos

        if len(static_tensors) != len(new_tensors) or len(static_tensors) != len(static_infos):
            raise AssertionError(
                f"[CUDA Graph] Tensor count mismatch. {len(static_tensors)=}, {len(new_tensors)=}, {len(static_infos)=}"
            )

        for static_t, new_t, static_info in zip(static_tensors, new_tensors, static_infos):
            if static_t.ndim != new_t.ndim:
                raise AssertionError(f"[CUDA Graph] Rank mismatch. {static_t.shape=}, {new_t.shape=}")
            if static_t.dtype != new_t.dtype:
                raise AssertionError(f"[CUDA Graph] Dtype mismatch. {static_t.dtype=}, {new_t.dtype=}")
            for i in range(static_t.ndim):
                # shapes[i] != -1 marks a static dimension: it must match exactly.
                if static_info.shapes[i] != -1 and static_info.shapes[i] != new_t.shape[i]:
                    raise AssertionError(
                        f"[CUDA Graph] Static dimension mismatch. {static_t.shape=}, {new_t.shape=}, {static_info.shapes=}, dim={i}"
                    )
                if static_t.shape[i] < new_t.shape[i]:
                    res = True
        return res
664
+
665
+ @nvtx.instrument_nvtx
666
+ def get_expanded_static_tensors(
667
+ self, static_tensors: List[torch.Tensor], new_tensors: List[torch.Tensor]
668
+ ) -> List[torch.Tensor]:
669
+ """Get expanded static tensors based on new tensors. Reuses existing tensors when possible."""
670
+ expanded_tensors = []
671
+ for static_t, new_t in zip(static_tensors, new_tensors):
672
+ if static_t.ndim != new_t.ndim:
673
+ raise AssertionError(
674
+ f"[CUDA Graph] Rank mismatch during expansion. Static: {static_t.shape}, New: {new_t.shape}"
675
+ )
676
+ new_shape = tuple(max(s, n) for s, n in zip(static_t.shape, new_t.shape))
677
+
678
+ if static_t.shape == new_shape:
679
+ expanded_tensors.append(static_t)
680
+ elif new_shape == new_t.shape:
681
+ expanded_tensors.append(new_t)
682
+ else:
683
+ expanded_tensor = torch.empty(new_shape, dtype=static_t.dtype, device=static_t.device)
684
+ expanded_tensors.append(expanded_tensor)
685
+ return expanded_tensors
686
+
687
    @nvtx.instrument_nvtx
    def try_replay_graph_inline(
        self, func: Callable, args: Tuple, kwargs: Dict, layer_number: int
    ) -> Tuple[bool, Optional[Union[torch.Tensor, List[torch.Tensor]]]]:
        """Try to replay the CUDA Graph inline for fast execution.

        Fast path of `run`: recompute the input signatures, look up the cached
        static buffers / output template / captured graph, copy the live inputs
        into the static buffers, replay the graph, and rebuild the outputs from
        the static output buffers.

        Returns:
            (True, output_obj) on a successful replay; (False, None) when a
            cache lookup (KeyError) or validity assertion fails, in which case
            the caller falls back to the warmup/capture path.
        """
        try:
            func_name = func.__qualname__
            input_obj = {"args": args, "kwargs": kwargs}

            # Prefer the fx-based extraction; fall back to the generic recursive
            # walk when it returns None for any component.
            input_tensors, input_tensor_names, literals = ArgsUtils.try_fx_extract_core(input_obj)
            if None in (input_tensors, input_tensor_names, literals):
                input_tensors, input_tensor_names, literals = ArgsUtils.recursive_extract_core(input_obj)
            input_static_sig, input_dynamic_sig = ArgsUtils.generate_both_signatures_from_tensors(
                func_name, input_tensors, input_tensor_names, literals
            )
            # Any missing cache level raises KeyError -> handled below as "no replay".
            static_tensor_entry = self.cache[input_static_sig]
            static_input_tensors = static_tensor_entry.input_tensors
            static_output_tensors = static_tensor_entry.output_tensors

            template_entry = static_tensor_entry.template_entry_dict[input_dynamic_sig]
            output_template = template_entry.output_template

            graph_entry = template_entry.graph_entry_dict[layer_number]
            graph = graph_entry.graph

            assert graph is not None, "CUDA Graph not found for inline replay."
            assert graph_entry.inconsistent is False, "CUDA Graph marked as inconsistent for inline replay."
            assert graph_entry.invalid is False, "CUDA Graph marked as invalid for inline replay."

            # Copy live inputs into the captured static buffers, replay, and
            # reconstruct outputs from the static output buffers.
            ArgsUtils.replace_sliced_with_static_simple(input_tensors, static_input_tensors)
            graph.replay()
            output_obj = ArgsUtils.replace_static_with_sliced(output_template, static_output_tensors)

            if self.check_output_inconsistency:
                # NOTE(review): both signature pairs below are derived from the
                # same freshly-built `output_obj`, so this comparison can never
                # differ; presumably the "cached" side should come from the
                # template entry captured at warmup time — verify intent.
                cur_output_tensors, cur_output_tensor_names, cur_output_literals = ArgsUtils.recursive_extract_core(output_obj)
                cur_output_static_sig, cur_output_dynamic_sig = ArgsUtils.generate_both_signatures_from_tensors(
                    func.__qualname__, cur_output_tensors, cur_output_tensor_names, cur_output_literals
                )
                output_tensors, output_tensor_names, output_literals = ArgsUtils.recursive_extract_core(output_obj)
                cached_output_static_sig, cached_output_dynamic_sig = ArgsUtils.generate_both_signatures_from_tensors(
                    func.__qualname__, output_tensors, output_tensor_names, output_literals
                )
                if cur_output_static_sig != cached_output_static_sig or cur_output_dynamic_sig != cached_output_dynamic_sig:
                    magi_logger.warning(
                        f"[CUDA Graph] Warning: Output signature changed during inline replay. {func.__qualname__=}, {layer_number=}"
                    )
                    self.set_graph_inconsistent(input_static_sig, input_dynamic_sig, layer_number)
                    return False, None
            return True, output_obj
        except KeyError:
            # Cache miss at some level: signal the caller to take the slow path.
            return False, None
        except AssertionError:
            # Graph missing / inconsistent / invalid: fall back to the slow path.
            return False, None
        except Exception as e:
            # Anything else is unexpected — log and re-raise.
            magi_logger.info(
                f"[CUDA Graph] Exception during inline replay: {e=}, {func.__qualname__=}, {layer_number=}", rank="all"
            )
            raise e
745
+
746
    @nvtx.instrument_nvtx
    def run(self, func: Callable, *args, layer_number: Optional[int], **kwargs) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Run the function with CUDA Graph optimization if possible.

        Protocol:
          1. Try the fast inline-replay path.
          2. On miss, recompute signatures; execute eagerly if the graph is
             marked inconsistent.
          3. Otherwise do a warmup run, re-check the input signature, update
             the static-tensor / template cache (expanding static buffers if
             the new shapes exceed them), then capture and cache a new graph.

        Returns the warmup outputs (the captured graph serves later calls).
        """

        # Try inline replay first
        success, output_obj = self.try_replay_graph_inline(func=func, args=args, kwargs=kwargs, layer_number=layer_number)
        if success:
            # print_rank_0(f"[CUDA Graph] Current cache stats: {self.tensor_entry_count=}, {self.graph_count=}.")
            return output_obj

        # Extract input signatures
        func_name = func.__qualname__
        input_obj = {"args": args, "kwargs": kwargs}
        input_tensors, input_tensor_names, literals = ArgsUtils.recursive_extract_core(input_obj)
        input_static_sig, input_dynamic_sig = ArgsUtils.generate_both_signatures_from_tensors(
            func_name, input_tensors, input_tensor_names, literals
        )

        # Judge if the graph is marked as inconsistent
        graph_entry = self.try_get_graph_entry(input_static_sig, input_dynamic_sig, layer_number)
        if graph_entry is not None and graph_entry.inconsistent:
            # Inconsistent graphs are never replayed or re-captured: run eagerly.
            return func(*args, **kwargs)

        # Judge if need to expand static tensors
        # (local flag deliberately mirrors the method name `self.if_need_expand_static_tensors`)
        if_need_expand_static_tensors = False
        if_cached_tensor_entry = input_static_sig in self.cache
        if if_cached_tensor_entry:
            static_input_tensors, static_output_tensors = self.get_static_tensors(input_static_sig)
            if_need_expand_static_tensors = self.if_need_expand_static_tensors(
                static_input_tensors, input_tensors, input_static_sig
            )

        # Warmup run
        warmup_output_obj = self.warmup_run(func, *args, **kwargs)

        # Check input signature consistency after warmup
        # (the warmup may mutate inputs in place; re-extract from the same input_obj)
        warmup_input_tensors, warmup_input_tensor_names, warmup_literals = ArgsUtils.recursive_extract_core(input_obj)
        warmup_input_static_sig, warmup_input_dynamic_sig = ArgsUtils.generate_both_signatures_from_tensors(
            func_name, warmup_input_tensors, warmup_input_tensor_names, warmup_literals
        )
        if warmup_input_static_sig != input_static_sig or warmup_input_dynamic_sig != input_dynamic_sig:
            magi_logger.warning(
                f"[CUDA Graph] Warning: Input signature changed during warmup run. {func_name=}, {layer_number=}"
            )
            self.set_graph_inconsistent(input_static_sig, input_dynamic_sig, layer_number)
            return warmup_output_obj

        # Update cache entries
        if if_cached_tensor_entry:
            if if_need_expand_static_tensors:
                output_tensors, _, _ = ArgsUtils.recursive_extract_core(warmup_output_obj, extract_literals=False)
                # Need to expand static tensors
                new_static_input_tensors = self.get_expanded_static_tensors(static_input_tensors, input_tensors)
                new_static_output_tensors = self.get_expanded_static_tensors(static_output_tensors, output_tensors)
                # Register as new cache entries
                # (old graphs captured on the smaller buffers are invalidated first)
                self.batch_set_graph_invalid(input_static_sig)
                self.cache[input_static_sig].input_tensors = new_static_input_tensors
                self.cache[input_static_sig].output_tensors = new_static_output_tensors

                self.add_template_entry(input_static_sig, input_dynamic_sig, warmup_output_obj)
            else:
                # Simply reuse existing static tensor entry
                static_tensor_entry = self.cache[input_static_sig]
                if input_dynamic_sig not in static_tensor_entry.template_entry_dict:
                    self.add_template_entry(input_static_sig, input_dynamic_sig, warmup_output_obj)

        else:
            # Create new static tensor entry
            output_tensors, _, _ = ArgsUtils.recursive_extract_core(warmup_output_obj, extract_literals=False)
            self.add_static_entry(input_static_sig, input_tensors, output_tensors)
            self.add_template_entry(input_static_sig, input_dynamic_sig, warmup_output_obj)

        # Capture and cache new CUDA Graph
        self.capture_and_cache(
            func=func,
            input_obj=input_obj,
            layer_number=layer_number,
            input_static_sig=input_static_sig,
            input_dynamic_sig=input_dynamic_sig,
        )

        magi_logger.info(
            f"[CUDA Graph] Current cache stats: {self.tensor_entry_count=}, {self.graph_count=}, {self.tensor_mem_size=:.2f} MB, {self.graph_mem_pool_size=:.2f} MB"
        )
        return warmup_output_obj
831
+
832
+
833
+ _IS_GRAPH_CAPTURING = False
834
+
835
+
836
+ def _is_graph_capturing():
837
+ """Query if currently capturing."""
838
+ global _IS_GRAPH_CAPTURING
839
+ return _IS_GRAPH_CAPTURING
840
+
841
+
842
+ def _set_capture_start():
843
+ """Set graph capture has started."""
844
+ global _IS_GRAPH_CAPTURING
845
+ _IS_GRAPH_CAPTURING = True
846
+
847
+
848
+ def _set_capture_end():
849
+ """Set graph capture has ended."""
850
+ global _IS_GRAPH_CAPTURING
851
+ _IS_GRAPH_CAPTURING = False
852
+
853
+
854
# Singleton instance of CudaGraphMgr
_CUDA_GRAPH_MGR = CudaGraphMgr()


def cuda_graph_mgr() -> CudaGraphMgr:
    """
    Get the current CudaGraphMgr instance.
    Returns:
        CudaGraphMgr: The current CudaGraphMgr instance.
    Raises:
        AssertionError: If the CudaGraphMgr has not been initialized.
    """
    # NOTE(review): _CUDA_GRAPH_MGR is assigned at import time above, so this
    # assert can only fire if other code later rebinds it to None.
    assert _CUDA_GRAPH_MGR is not None, "cuda graph manager is not initialized"
    return _CUDA_GRAPH_MGR
868
+
869
+
870
def cuda_graph_enable_if(condition: Callable):
    """
    Decorator factory to enable CUDA Graph execution for a function.

    The wrapped function is executed through the CUDA Graph manager only when
    `condition()` returns True and no graph capture is currently in progress;
    otherwise it runs eagerly.

    Args:
        condition (Callable): A callable that returns a bool indicating whether enable CUDA Graph.
    """
    # NOTE: the docstring previously lived on the inner `decorator`, where
    # help()/introspection never surfaced it; it now documents the factory.

    def decorator(func):
        @wraps(func)
        def wrapped_func(*args, **kwargs):
            enable_cuda_graph = condition()
            # Run eagerly while a capture is in progress to avoid nested capture.
            if not enable_cuda_graph or _is_graph_capturing():
                return func(*args, **kwargs)

            # Convention: bound methods may expose `layer_number` on self (args[0]).
            layer_number = getattr(args[0], "layer_number", None) if args else None

            return cuda_graph_mgr().run(func, *args, layer_number=layer_number, **kwargs)

        return wrapped_func

    return decorator
891
+
892
+
893
def gen_wrap_func_for_cudagraph(func: Callable, mode_prefix: str, target_prefix=None) -> Callable:
    """
    Wrap the given function for CUDA Graph:
    1. Generate a unique __qualname__ for caching
    2. Built-in call to cuda_graph_mgr().run

    Args:
        func: The callable to wrap.
        mode_prefix: "full" or a piecewise-mode prefix; controls the cache key format.
        target_prefix: Submodule name, included in the key for piecewise mode only.

    Returns:
        A wrapper that pops `layer_number` from kwargs and dispatches through
        the CUDA Graph manager.
    """
    # Generate a unique identifier to avoid cache conflicts
    func_id = id(func) if not hasattr(func, "__name__") else func.__name__
    if mode_prefix == "full":
        wrapped_func_name = f"Athena_CUDAGraph_{mode_prefix}_{func_id}"
    else:  # piecewise
        wrapped_func_name = f"Athena_CUDAGraph_{mode_prefix}_{target_prefix}_{func_id}"

    @nvtx.instrument_nvtx
    def wrapped_func(*args, **kwargs):
        # `layer_number` travels via kwargs; it must not reach `func` itself.
        layer_number = kwargs.pop("layer_number", None)
        res = cuda_graph_mgr().run(func, *args, layer_number=layer_number, **kwargs)
        return res

    # Deliberately mutate the ORIGINAL function's __qualname__ too: the CUDA
    # Graph cache keys on func.__qualname__, so both must agree.
    func.__qualname__ = wrapped_func_name
    magi_logger.info(f"Set original function qualname to {wrapped_func_name} for CUDA Graph caching.")

    # Copy attributes from the original function to the wrapped function
    wrapped_func.__dict__.update(func.__dict__)
    wrapped_func.__qualname__ = wrapped_func_name
    # Name-mangled markers must be forwarded explicitly (not in __dict__ for all callables).
    for attr in ["__is_first_graph", "__is_last_graph", "__sym_shape_indices"]:
        if hasattr(func, attr):
            setattr(wrapped_func, attr, getattr(func, attr))

    return wrapped_func
923
+
924
+
925
def init_cudagraph_global_pool():
    """Initialize the global CUDA graph memory pool if not already initialized.

    Idempotent: the pool handle is created once and shared by all captured
    graphs managed by the singleton CudaGraphMgr.
    """
    # Local import — presumably to avoid a circular import when this helper is
    # re-exported from another module; TODO confirm.
    from magi_compiler.cuda_graph_mgr import cuda_graph_mgr

    if cuda_graph_mgr().graph_mem_pool is None:
        cuda_graph_mgr().graph_mem_pool = torch.cuda.graph_pool_handle()
        magi_logger.info("Initialized global CUDA graph pool for Athena.")
pkgs/MagiCompiler/magi_compiler/joint_graph_partition.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ from typing import Any, Optional, Sequence, Tuple
17
+ from unittest.mock import patch
18
+
19
+ import torch
20
+ import torch.fx as fx
21
+ from torch._functorch.compile_utils import get_aten_target
22
+ from torch._functorch.partitioners import NodeInfo, OpTypes, get_default_op_list, min_cut_rematerialization_partition
23
+ from torch._inductor.custom_graph_pass import CustomPartitionerFn
24
+ from torch.utils._ordered_set import OrderedSet
25
+
26
+ # from magi_compiler.partitioners import min_cut_rematerialization_partition
27
+ from .config import RecomputePolicy, get_compile_config
28
+ from .utils import compute_code_hash, magi_logger
29
+ from .utils.visualize import joint_graph_vis
30
+
31
+ SAVE_TENSOR_NODES: Optional[list[fx.Node]] = None
32
+
33
+
34
def is_memory_increase_by_node(node: fx.Node) -> bool:
    """Return True if this dtype-conversion node enlarges the per-element size.

    Only `prims.convert_element_type` (the decomposition of `aten.to`) is
    supported, enforced by the assert below.
    """
    # Only support aten.to now
    assert get_aten_target(node) == torch.ops.prims.convert_element_type
    # assumes args[0] carries tensor metadata under meta["tensor_meta"] — TODO confirm
    input_dtype = node.args[0].meta["tensor_meta"].dtype
    # Second positional arg of convert_element_type is the target dtype.
    output_dtype = node.args[1]
    assert output_dtype is not None
    return output_dtype.itemsize > input_dtype.itemsize
41
+
42
+
43
+ def is_primal_contribute_to_bwd_directly(primal_node: fx.Node, node_info: NodeInfo, op_types: OpTypes) -> bool:
44
+ """
45
+ FSDP ensures that weights already reside in memory. If there exists a path from the primal to the bwd, and the path does not contain any matmul, then the primal contributes to the bwd directly.
46
+ And we should save this primals.
47
+ """
48
+ if node_info.is_required_bw(primal_node):
49
+ return True
50
+ topology_start = set({primal_node})
51
+
52
+ while len(topology_start) > 0:
53
+ cur_node = topology_start.pop()
54
+ for user in cur_node.users:
55
+ if node_info.is_required_bw(user):
56
+ return True
57
+ if op_types.is_compute_intensive(user):
58
+ continue
59
+ topology_start.add(user)
60
+ return False
61
+
62
+
63
def is_compute_intensive_and_has_following_recomputable_ops(
    intermidiate_node: fx.Node, node_info: NodeInfo, op_types: OpTypes
) -> Tuple[bool, fx.Node]:
    """
    If compute-intensive node(CIN) is not the output of fwd graph(has following memory-intensive ops in the fwd graph), then we should save this CIN node.
    NOTE: For CIN+aten.to, we should save aten.to op instead of CIN op to save more memory.

    Returns:
        (True, node_to_save) when the CIN has further recomputable forward
        users; (False, None) when it is not compute-intensive or flows straight
        to the graph output.
    """
    if not op_types.is_compute_intensive(intermidiate_node):
        return False, None

    # Walk the single-user forward chain, possibly shifting the save point past
    # view ops / non-enlarging dtype downcasts to save less memory.
    save_node = intermidiate_node
    topology_start = set({save_node})
    while len(topology_start) > 0:
        cur_node = topology_start.pop()
        fwd_user_nodes = []
        for user in cur_node.users:
            if node_info.is_required_fw(user):
                fwd_user_nodes.append(user)

        if len(fwd_user_nodes) > 1:  # multiple users, save current node
            return True, save_node
        elif len(fwd_user_nodes) == 0:  # output, return
            return False, None

        # save current node if it's user is recomputable
        next_node = fwd_user_nodes[0]
        if op_types.is_view(next_node):
            # Views are free; only advance the save point if it hasn't diverged yet.
            if save_node == cur_node:
                save_node = next_node
            topology_start.add(next_node)
        # Special case for aten.to, memory efficient case
        elif get_aten_target(next_node) == torch.ops.prims.convert_element_type:
            is_memory_increase = is_memory_increase_by_node(next_node)
            if not is_memory_increase:
                # Downcast shrinks memory: prefer saving the downcast output.
                save_node = next_node
            topology_start.add(next_node)
        elif next_node.op == "output":
            return False, None
        else:
            return True, save_node
    assert False, f"Should not reach here: {intermidiate_node=} {save_node=}"
104
+
105
+
106
# TODO: We find an elegant impl to heuristically save nodes, reconstruct this later
def heuristic_choose_saved_values_set(joint_graph: fx.Graph, node_info: NodeInfo, memory_budget=1) -> list[fx.Node]:
    """Heuristically pick which joint-graph values to save for the backward pass.

    Two rules:
      1. Save forward inputs (primals) that feed backward without crossing a
         compute-intensive op.
      2. Save compute-intensive forward nodes that have further recomputable
         forward users (the save point may be shifted past view/downcast ops).

    The chosen nodes are also stashed in the module-global SAVE_TENSOR_NODES
    for later visualization. `memory_budget` is unused here; it is kept for
    signature compatibility with torch's `choose_saved_values_set`.
    """
    output: OrderedSet[fx.Node] = OrderedSet()
    op_types = get_default_op_list()
    # Select the inputs that are required by the backward pass
    for primal_node in node_info.inputs:
        if is_primal_contribute_to_bwd_directly(primal_node, node_info, op_types):
            output.add(primal_node)
    magi_logger.info("MagiCompiler: saved_output forward-input = %s", output)
    # Select the compute-intensive nodes that are required by the forward pass
    for intermidiate_node in node_info.required_fw_nodes:
        is_save, save_node = is_compute_intensive_and_has_following_recomputable_ops(intermidiate_node, node_info, op_types)
        if is_save:
            output.add(save_node)
    magi_logger.info("MagiCompiler: saved_output compute-intensive = %s", output)
    global SAVE_TENSOR_NODES
    SAVE_TENSOR_NODES = list(output)
    return list(output)
124
+
125
+
126
def custom_joint_graph_partition_fn(
    joint_module: fx.GraphModule,
    _joint_inputs,
    compiler="inductor",
    *,
    num_fwd_outputs,
    static_lifetime_input_indices: Optional[list[int]] = None,
) -> tuple[fx.GraphModule, fx.GraphModule]:
    """Partition the joint fwd/bwd graph according to the configured recompute policy.

    HANDCRAFT patches torch's activation-memory budget; HEURISTIC patches
    `choose_saved_values_set` with our heuristic. Both then delegate to
    `min_cut_rematerialization_partition`.

    Raises:
        ValueError: for AUTOSEARCH (not yet supported) or an unknown policy.
    """
    recompute_config = get_compile_config().recompute_config
    policy = recompute_config.recompute_policy
    # The two supported policies differ only in which torch internal is
    # patched; select the patch context, then run the partitioner once
    # (previously the identical partitioner call was duplicated per branch).
    if policy == RecomputePolicy.HANDCRAFT:
        magi_logger.info("MagiCompiler using handcraft recompute policy")
        # TODO: different memory budget definition from torch
        patch_ctx = patch("torch._functorch.config.activation_memory_budget", recompute_config.memory_budget)
    elif policy == RecomputePolicy.HEURISTIC:
        magi_logger.info("MagiCompiler using heuristic recompute policy")
        patch_ctx = patch("torch._functorch.partitioners.choose_saved_values_set", heuristic_choose_saved_values_set)
    elif policy == RecomputePolicy.AUTOSEARCH:
        # Plain string: the old f-string had no placeholders (F541).
        raise ValueError("AutoSearch recompute policy is not supported yet")
    else:
        raise ValueError(f"Invalid recompute policy: {policy}")

    with patch_ctx:
        fwd_module, bwd_module = min_cut_rematerialization_partition(
            joint_module,
            _joint_inputs,
            compiler,
            num_fwd_outputs=num_fwd_outputs,
            static_lifetime_input_indices=static_lifetime_input_indices,
        )

    joint_graph_vis(joint_module, fwd_module, bwd_module, save_tensor_nodes=SAVE_TENSOR_NODES)

    return fwd_module, bwd_module
164
+
165
+
166
class CustomJointGraphPartitionFn(CustomPartitionerFn):
    """Inductor `CustomPartitionerFn` delegating to MagiCompiler's recompute-policy-aware partitioner."""

    def __call__(
        self, gm: torch.fx.GraphModule, joint_inputs: Sequence[object], **kwargs: Any
    ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
        """
        Implementation of the custom partitioner.
        """
        return custom_joint_graph_partition_fn(gm, joint_inputs, **kwargs)

    def uuid(self) -> Optional[Any]:
        """
        Return an ID to uniquely identify your custom partitioner implementation.
        Return None to skip inductor code caching entirely.
        """
        # Hash this file's source so any edit here invalidates Inductor's cached artifacts.
        return compute_code_hash({os.path.abspath(__file__)})
pkgs/MagiCompiler/magi_compiler/magi_backend.py ADDED
@@ -0,0 +1,607 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import ast
17
+ import dataclasses
18
+ import pprint
19
+ import time
20
+ from collections.abc import Callable
21
+ from contextlib import contextmanager
22
+ from pathlib import Path
23
+ from typing import Any
24
+
25
+ import magi_compiler.utils.envs as envs
26
+ import torch
27
+ import torch.fx as fx
28
+ from torch._dispatch.python import enable_python_dispatcher
29
+ from torch._dynamo.utils import lazy_format_graph_code
30
+ from torch._guards import detect_fake_mode
31
+
32
+ from ._cache_data_cls import CacheEntry, CacheHandle
33
+ from .compile_artifacts import MagiSerializableFunction
34
+ from .config import CompileConfig, CompileMode, CudaGraphMode
35
+ from .cuda_graph_mgr import gen_wrap_func_for_cudagraph
36
+ from .joint_graph_partition import CustomJointGraphPartitionFn
37
+ from .offload.offload_warpper import OffloadWrapper
38
+ from .partition_rules import inductor_partition_rule_context, resolve_defined_ops
39
+ from .passes import PostGradPassManager
40
+ from .passes.inductor_pass import pass_context
41
+ from .passes.replace_pass import FullGraphPassManager
42
+ from .piecewise_backend import PiecewiseBackend
43
+ from .piecewise_compiler import CompilerInterface, EagerAdaptor, InductorStandaloneAdaptor
44
+ from .utils import (
45
+ CompileMonitor,
46
+ compilation_counter,
47
+ compute_code_hash,
48
+ compute_hash,
49
+ detect_symbolic_tensor_indices,
50
+ magi_logger,
51
+ )
52
+ from .utils.envs import MAGI_CUSTOM_PARTITIONER_FN, MAGI_MODEL_TAG, MAGI_POST_GRAD_PASS
53
+ from .utils.visualize import save_fx_graph_visualization
54
+
55
+ compilation_start_time: float = 0.0
56
+
57
+
58
+ def _print_with_shape_and_time(runtime_shape: int | None, prefix: str = ""):
59
+ elapsed = time.time() - compilation_start_time
60
+ if runtime_shape is None:
61
+ magi_logger.info("%s for dynamic shape, took %.3f s", prefix, elapsed)
62
+ else:
63
+ magi_logger.info("%s for shape %s, took %.3f s", prefix, str(runtime_shape), elapsed)
64
+
65
+
66
@dataclasses.dataclass
class SplitItem:
    # Name of the submodule produced by graph splitting (presumably "submod_<i>") — TODO confirm.
    submod_name: str
    # Position of this piece in the split sequence.
    graph_id: int
    # Whether this piece contains a splitting op (boundary graph) — TODO confirm against splitter.
    is_splitting_graph: bool
    # The fx subgraph itself.
    graph: fx.GraphModule
73
+
74
def make_compiler(compile_config: CompileConfig) -> CompilerInterface:
    """Instantiate the compiler adaptor matching the configured backend ("inductor" or "eager")."""
    backend = compile_config.backend
    if backend == "inductor":
        # Use standalone_compile with PyTorch 2.8+
        assert hasattr(torch._inductor, "standalone_compile"), "standalone_compile not found in PyTorch Inductor"
        magi_logger.info("Using InductorStandaloneAdaptor")
        return InductorStandaloneAdaptor()
    assert backend == "eager", f"Invalid backend for MagiCompiler: {backend}"
    magi_logger.info("Using EagerAdaptor")
    return EagerAdaptor()
84
+
85
+
86
class CompilerManager:
    """
    Manage the compilation process, including graph compilation, compile artifacts caching and loading.

    The cache is a dict mapping `(runtime_shape, graph_index, backend_name)` to `any_data` returned from the compiler.

    When serializing the cache, we save it to a Python file for readability. We don't use json here because json doesn't support int as key.
    """

    def __init__(self, compile_config: CompileConfig):
        # Maps CacheEntry(runtime_shape, graph_index, backend_name) -> CacheHandle(key, path).
        self.cache: dict[CacheEntry, CacheHandle] = dict()
        self.compile_config = compile_config
        self.compiler = make_compiler(compile_config)
        self.disable_cache = envs.MAGI_DISABLE_COMPILE_CACHE

    @property
    def hash(self) -> str:
        """Stable identifier of the underlying compiler adaptor (used for cache keying)."""
        return self.compiler.hash

    @contextmanager
    def compile_context(self, runtime_shape: int | None = None):
        """Provide compilation context for the duration of compilation to set
        any torch global properties we want to scope to a single Inductor
        compilation (e.g. partition rules, pass context)."""
        with pass_context(runtime_shape):
            if self.compile_config.use_inductor_graph_partition:
                inductor_partition_ops = resolve_defined_ops(self.compile_config.splitting_ops)
                with inductor_partition_rule_context(inductor_partition_ops):
                    yield
            else:
                yield

    def initialize_cache(self, cache_dir: Path, prefix: str = ""):
        """
        Initialize the cache directory for the compiler.

        The organization of the cache directory is as follows:
        cache_dir=/path/to/torch_compile_cache/rank_i_j/hash_str/prefix/
        inside cache_dir, there will be:
        - magi_compile_cache.py
        - computation_graph.py

        for multiple prefixes, they can share the same base cache dir of
        /path/to/torch_compile_cache/rank_i_j/hash_str/ to store some
        common compilation artifacts.
        """

        self.cache_dir: Path = cache_dir
        self.cache_file_path: Path = cache_dir / "magi_compile_cache.py"

        if self.disable_cache:
            magi_logger.info("MagiCompiler's cache is disabled.")
            return

        magi_logger.info("Using cache directory: %s for MagiCompiler", cache_dir)
        if self.cache_file_path.exists():
            # load the cache from the file
            with self.cache_file_path.open() as f:
                # Parse Python literals using ast.literal_eval, which is a safe alternative to eval().
                raw = ast.literal_eval(f.read())
            self.cache = {CacheEntry(*entry): CacheHandle(*handle) for entry, handle in raw.items()}

        self.compiler.initialize_cache(cache_dir=self.cache_dir, prefix=prefix)

    def save_to_file(self):
        """Persist the cache index to `magi_compile_cache.py` as a Python literal dict."""
        if self.disable_cache:
            return
        # serialize to a literal-friendly dict
        serializable = {(e.runtime_shape, e.graph_index, e.backend_name): (h.key, h.path) for e, h in self.cache.items()}
        printer = pprint.PrettyPrinter(indent=4)
        data = printer.pformat(serializable)
        with self.cache_file_path.open("w") as f:
            f.write(data)

    def load(self, graph: fx.GraphModule, example_inputs: list[Any], cache_entry: CacheEntry) -> Callable | None:
        """Load a previously compiled graph for `cache_entry`, or return None on a cache miss."""
        if cache_entry not in self.cache:
            return None
        cache_handle = self.cache[cache_entry]
        _print_with_shape_and_time(
            cache_entry.runtime_shape,
            f"Directly load the {cache_entry.graph_index}-th graph from {cache_entry.backend_name} via handle {cache_handle}",
        )
        return self.compiler.load(graph, example_inputs, cache_entry, cache_handle)

    # TODO(hongyu): Support training mode here
    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: tuple[torch.fx.node.Argument, ...],
        compile_config: CompileConfig,
        graph_index: int = 0,
        num_graphs: int = 1,
        runtime_shape: int | None = None,
    ) -> Callable:
        """Compile one (sub)graph, consulting and updating the artifact cache.

        Returns the compiled callable; `runtime_shape=None` means dynamic shape.
        """
        # Step0: update some global metrics
        compilation_counter.num_backend_compilations += 1
        if graph_index == 0:
            # First subgraph starts the wall clock used by _print_with_shape_and_time.
            global compilation_start_time
            compilation_start_time = time.time()

        # Step1: Try loading from the cache
        cache_entry = CacheEntry(runtime_shape, graph_index, self.compiler.name)
        compiled_graph = self.load(graph, example_inputs, cache_entry)
        if compiled_graph is not None:
            return compiled_graph

        # Step2: Compile the graph
        key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
        with self.compile_context(runtime_shape):
            compiled_graph, cache_handle = self.compiler.compile(
                graph, example_inputs, compile_config.inductor_compile_config, runtime_shape, key
            )
        assert compiled_graph is not None, "Failed to compile the graph"

        # Step3: Store the artifact in the cache
        if not self.disable_cache and cache_handle is not None:
            assert cache_entry not in self.cache, "Cache entry already exists"
            self.cache[cache_entry] = cache_handle
            compilation_counter.num_cache_entries += 1
        _print_with_shape_and_time(runtime_shape, f"Compile the {graph_index}/{num_graphs} graph")

        return compiled_graph
209
+
210
# TODO(hongyu): Support training mode here
class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """
    Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
    It runs the given graph with fake inputs, and compile some submodules specified by `compile_submod_names` with compilation configs.

    NOTE: the order in `compile_submod_names` matters, because it will be used to determine the order of the compiled piecewise graphs.
    The first graph will handle logging, and the last graph has some special cudagraph output handling.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        compiler_manager: CompilerManager,
        compile_submod_names: list[str],
        compile_config: CompileConfig,
    ):
        super().__init__(module)

        self.fake_mode = detect_fake_mode()
        self.compiler_manager = compiler_manager
        self.compile_submod_names = compile_submod_names
        self.compile_config = compile_config
        # extra_traceback is attribute of torch.fx.Interpreter, when it is True, it annoyingly dumps the torch.fx.Graph on errors.
        self.extra_traceback = False

    def _fix_graph_device_placement(self, module: torch.nn.Module):
        # Recursively rewrite tensor-factory calls that hard-code device='cpu'
        # to the current CUDA device, recompiling any modified GraphModule.
        # Used when model CPU offload is enabled (see `run`).
        for name, child in module.named_children():
            self._fix_graph_device_placement(child)

        if isinstance(module, torch.fx.GraphModule):
            needs_recompile = False
            target_device = torch.cuda.current_device()

            # Factory ops whose 'device' kwarg may pin tensors to CPU.
            factory_functions = [
                torch.empty,
                torch.zeros,
                torch.ones,
                torch.full,
                torch.rand,
                torch.randn,
                torch.arange,
                torch.tensor,
                torch.ops.aten.empty.memory_format,
            ]

            for node in module.graph.nodes:
                if node.op == 'call_function':
                    # Also match by name to catch aten overloads not listed above.
                    is_factory = node.target in factory_functions or (
                        hasattr(node.target, '__name__') and node.target.__name__ in ['empty', 'zeros', 'ones', 'full']
                    )

                    if is_factory:
                        if 'device' in node.kwargs:
                            current_dev = node.kwargs['device']
                            if str(current_dev) == 'cpu' or current_dev == torch.device('cpu'):
                                node.update_kwarg('device', target_device)
                                needs_recompile = True

            if needs_recompile:
                module.recompile()

    def run(self, *args):
        # Convert real tensors to fake tensors so submodules can be traced/compiled
        # without executing real kernels.
        fake_args = [self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args]
        if self.compile_config.offload_config.model_cpu_offload:
            self._fix_graph_device_placement(self.module)
            # With CPU offload, example inputs may live on CPU; move them to CUDA.
            for i, arg in enumerate(fake_args):
                if isinstance(arg, torch.Tensor):
                    fake_args[i] = arg.cuda()

        with self.fake_mode, enable_python_dispatcher():
            return super().run(*fake_args)

    def call_module(
        self, target: torch.fx.node.Target, args: tuple[torch.fx.node.Argument, ...], kwargs: dict[str, Any]
    ) -> Any:
        # Run the submodule under fake mode first so downstream nodes get real
        # output metadata; then, if it is one of the pieces to compile, replace
        # it in-place with a PiecewiseBackend (optionally CUDA-Graph-wrapped).
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)
        if target not in self.compile_submod_names:
            return output

        index = self.compile_submod_names.index(target)
        submod = self.fetch_attr(target)
        # Positions of symbolic-shape args; the backend re-specializes on these.
        sym_shape_indices = [i for i, x in enumerate(args) if isinstance(x, torch.SymInt)]
        magi_logger.info(f"Compiling {target=}, {sym_shape_indices=}, {args=}")

        compiled_graph_for_dynamic_shape = self.compiler_manager.compile(
            submod, args, self.compile_config, graph_index=index, num_graphs=len(self.compile_submod_names), runtime_shape=None
        )

        piecewise_backend = PiecewiseBackend(
            submod,
            compiled_graph_for_dynamic_shape,
            self.compile_config,
            index,
            len(self.compile_submod_names),
            sym_shape_indices,
            self.compiler_manager,
        )

        if self.compile_config.use_inductor_graph_partition or self.compile_config.cudagraph_mode != CudaGraphMode.PIECEWISE:
            # Inductor handles partition-level cudagraphs itself, or piecewise
            # cudagraphs are disabled: install the backend directly.
            self.module.__dict__[target] = piecewise_backend
        else:
            wrapped_backend = gen_wrap_func_for_cudagraph(
                func=piecewise_backend, mode_prefix=CudaGraphMode.PIECEWISE.name.lower(), target_prefix=target
            )

            self.module.__dict__[target] = wrapped_backend
            magi_logger.info(
                f"Wrapped piecewise submodule {target} (index {index}) with CUDA Graph "
                f"[PIECEWISE mode, first_graph={piecewise_backend.is_first_graph}, last_graph={piecewise_backend.is_last_graph}]"
            )

        return output
324
+
325
+
326
class MagiBackend:
    """
    The compilation backend for `torch.compile` with MagiCompiler.
    It is used for compilation mode of `CompileMode.MAGI_COMPILE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.
    """

    compile_config: CompileConfig
    # A backend instance is single-use: __call__ asserts it only runs once.
    _called_once: bool = False
    # for the graph we compiled
    graph: fx.GraphModule
    compiler_manager: CompilerManager
    # for cudagraph
    sym_tensor_indices: list[int]  # indices for tensors that have symbolic shapes
    input_buffers: list[torch.Tensor]  # buffers for input tensors that have symbolic shapes

    def __init__(self, compile_config: CompileConfig, model_tag: str = ""):
        # Fall back to the global default tag when the caller passes none.
        self.model_tag = model_tag or MAGI_MODEL_TAG
        self.compile_config = compile_config
        self._configure_custom_passes()
        self.compiler_manager: CompilerManager = CompilerManager(self.compile_config)

        self.sym_tensor_indices = []
        self.input_buffers = []

    def _configure_custom_passes(self) -> None:
        """Install the three custom pass hooks into the Inductor config.

        If a hook is already present (e.g. config reuse across models), its
        uuid must match the freshly-built one so compilation cache keys stay
        stable; otherwise the asserts fire.
        """
        # Custom pass 1: full graph passes between Dynamo and AOTAutograd
        self.full_graph_pass_manager = FullGraphPassManager(self.compile_config.pass_config)

        # Custom pass 2: custom partitioner function
        custom_partitioner_fn = CustomJointGraphPartitionFn()
        if MAGI_CUSTOM_PARTITIONER_FN in self.compile_config.inductor_compile_config:
            existing_fn = self.compile_config.inductor_compile_config[MAGI_CUSTOM_PARTITIONER_FN]
            assert isinstance(existing_fn, CustomJointGraphPartitionFn)
            assert existing_fn.uuid() == custom_partitioner_fn.uuid()
        self.compile_config.inductor_compile_config[MAGI_CUSTOM_PARTITIONER_FN] = custom_partitioner_fn

        # Custom pass 3: post-grad passes after AOTAutograd
        post_grad_pass_manager = PostGradPassManager()
        post_grad_pass_manager.configure(self.compile_config)

        # Run post-grad custom passes with post_grad_custom_post_pass hook
        if MAGI_POST_GRAD_PASS in self.compile_config.inductor_compile_config:
            existing_pass = self.compile_config.inductor_compile_config[MAGI_POST_GRAD_PASS]
            assert isinstance(existing_pass, PostGradPassManager)
            assert existing_pass.uuid() == post_grad_pass_manager.uuid()

        self.compile_config.inductor_compile_config[MAGI_POST_GRAD_PASS] = post_grad_pass_manager

    def _init_cache(self) -> None:
        """Compute the cache key and create the on-disk cache directory.

        The key folds together the config hash, the compiler-manager hash,
        and a hash of the files Dynamo traced through; ``traced_files`` is
        cleared afterwards so it only reflects the current compilation.
        """
        hash_key = compute_hash(
            [self.compile_config.hash, self.compiler_manager.hash, compute_code_hash(self.compile_config.traced_files)]
        )
        self.compile_config.traced_files.clear()

        # Path: .../model_{idx}_{model_tag}_rank_{rank}/{hash}/{model_tag}/ (last segment = class name or user tag)
        self.local_cache_dir: Path = self.compile_config.cache_dump_path() / hash_key / self.model_tag
        self.local_cache_dir.mkdir(parents=True, exist_ok=True)

        self.compiler_manager.initialize_cache(self.local_cache_dir, self.model_tag)

    def _save_partitioned_graph(self, split_gm: fx.GraphModule) -> None:
        """Dump the split graph as readable Python source into the cache dir (once)."""
        graph_path = self.local_cache_dir / "computation_graph.py"
        if not graph_path.exists():
            # code adapted from
            # https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30
            # use `print_readable` because it can include submodules
            src = "from __future__ import annotations\nimport torch\n" + split_gm.print_readable(print_output=False)
            src = src.replace("<lambda>", "GraphModule")
            with open(graph_path, "w") as f:
                f.write(src)
            magi_logger.info("Computation graph saved to %s", graph_path)

    def _split_graph(self, graph: fx.GraphModule) -> tuple[fx.GraphModule, list[SplitItem]]:
        """Split the Dynamo graph at the configured splitting ops.

        Each splitting op becomes its own single-node subgraph; the runs of
        nodes between them become piecewise subgraphs to compile. Returns
        the root split GraphModule plus SplitItems sorted by graph id.
        """
        # Step 1: resolve the splitting ops
        if self.compile_config.use_inductor_graph_partition:
            # Let Inductor decide partitioning; avoid FX-level pre-splitting.
            fx_split_ops: list[str] = []
        else:
            fx_split_ops = self.compile_config.splitting_ops or []
        resolved_ops: list[torch._ops.OpOverload] = resolve_defined_ops(fx_split_ops)
        magi_logger.info(f"Setting up FX-level graph split with ops: {fx_split_ops=}")
        magi_logger.info(f"Resolved splitting ops for FX-level graph split: {resolved_ops=}")

        # Step 2: split graph by ops, we split graph based on resolved_ops, which becomes the partitioned single graph.
        subgraph_id = 0
        node_to_subgraph_id = {}
        split_op_graphs = []
        for node in graph.graph.nodes:
            if node.op in ("output", "placeholder"):
                continue
            # Match node.target against resolved_ops, node.target can be OpOverloadPacket, need to check .default
            if node.op == "call_function" and (
                node.target in resolved_ops or (hasattr(node.target, "default") and node.target.default in resolved_ops)
            ):
                magi_logger.info(f"Splitting graph at {node=} with {node.target=}")
                # The splitting op gets its own subgraph id; ids on either
                # side differ so neighbouring nodes land in separate pieces.
                subgraph_id += 1
                node_to_subgraph_id[node] = subgraph_id
                split_op_graphs.append(subgraph_id)
                subgraph_id += 1
            else:
                node_to_subgraph_id[node] = subgraph_id

        # Step 3: split the graph based on node_to_subgraph_id
        # pytorch might reorder the nodes and the semantics of the graph will change when we have mutations in the graph, if we don't set keep_original_order=True
        split_gm = torch.fx.passes.split_module.split_module(
            graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True
        )

        def _extract_example_values(args) -> list:
            # Flatten nested list/tuple structure and collect each arg's
            # recorded example_value (must exist for meta recovery).
            example_values = []

            def _recurse_extract(arg):
                if isinstance(arg, (list, tuple)):
                    for sub_arg in arg:
                        _recurse_extract(sub_arg)
                else:
                    example_value = arg.meta.get("example_value")
                    assert example_value is not None, f"Output arg {arg} has no example_value for tensor_meta recovery"
                    example_values.append(example_value)

            _recurse_extract(args)
            return example_values

        def _format_output_values(values: list):
            # Mirror FX output conventions: None / single value / tuple.
            if not values:
                return None
            return tuple(values) if len(values) > 1 else values[0]

        def _recursive_recover_tensor_meta(gm: fx.GraphModule):
            """
            Recursively restore ``example_value`` metadata for every node in the
            given GraphModule and all of its nested submodules.
            Supports arbitrarily deep submodule nesting.
            """
            for node in gm.graph.nodes:
                if node.meta.get("example_value") is not None:
                    continue

                if node.op == "call_module":
                    submod: fx.GraphModule = getattr(gm, node.target)
                    _recursive_recover_tensor_meta(submod)  # recurse into nested submodules
                    output_node = next(n for n in submod.graph.nodes if n.op == "output")
                    assert output_node is not None, f"Output node not found in submodule {node.target}"
                    output_values = _extract_example_values(output_node.args)
                    node.meta["example_value"] = _format_output_values(output_values)
                elif node.op == "call_function":
                    if "getitem" in str(node.target):
                        prev_node, getitem_index = node.args
                        prev_example_value = prev_node.meta.get("example_value")
                        assert (
                            prev_example_value is not None
                        ), f"Previous node {prev_node} has no example_value for tensor_meta recovery of node {node}"
                        node.meta["example_value"] = prev_example_value[getitem_index]
                    elif "grad" in str(node.target) or "device" in str(node.target):  # not handled for now
                        node.meta["example_value"] = None
                    # NOTE(review): other call_function targets fall through
                    # without setting example_value; the log line below would
                    # then raise KeyError — presumably all targets are covered
                    # in practice. Confirm.
                elif node.op == "output":
                    output_values = _extract_example_values(node.args[0])
                    node.meta["example_value"] = _format_output_values(output_values)

                else:
                    raise ValueError(f"Unsupported node op for tensor_meta recovery: {node.op} for node {node}")

                magi_logger.info(f"Recovered example_value for node {node.name}: {node.meta['example_value']=}")

        # Recover tensor_meta for all nodes in split_gm and its submodules
        if envs.MAGI_ENABLE_PROFILE:
            _recursive_recover_tensor_meta(split_gm)

        # Step 4: fetch all the submodules
        piecewise_graphs = []
        names = [name for (name, module) in split_gm.named_modules()]
        for name in names:
            # Only keep the top-level modules, skip recursive child modules or the root module
            if "." in name or name == "":
                continue

            module = getattr(split_gm, name)
            assert isinstance(module, fx.GraphModule), f"Expected fx.GraphModule, got {type(module)}"

            graph_id = int(name.replace("submod_", ""))
            piecewise_graphs.append(SplitItem(name, graph_id, (graph_id in split_op_graphs), module))
        # sort by integer graph_id, rather than string name
        piecewise_graphs.sort(key=lambda x: x.graph_id)

        # Step 5: visualize the split graph
        # depyf already hooks lazy_format_graph_code and dumps the graph, we do not print the graph here
        # NOTE(review): print_output=True below seems to contradict the comment above — confirm intent.
        lazy_format_graph_code("Before split", graph, print_output=True, include_stride=True, include_device=True)
        lazy_format_graph_code("After split", split_gm, print_output=True, include_stride=True, include_device=True)

        if envs.MAGI_ENABLE_FX_GRAPH_VIZ:
            save_fx_graph_visualization(split_gm.graph, sub_dir="after_split", filename="split_gm_root")
            for item in piecewise_graphs:
                save_fx_graph_visualization(item.graph.graph, sub_dir="after_split", filename=item.submod_name)

        return split_gm, piecewise_graphs

    def __call__(self, graph: fx.GraphModule, example_inputs) -> MagiSerializableFunction:
        """Entry point invoked by Dynamo with the traced graph.

        Pipeline: full-graph passes -> FX split -> compile each piecewise
        submodule -> optional offload/profile/cudagraph wrapping -> return a
        serializable callable. May only be invoked once per instance.
        """
        assert not self._called_once, "MagiBackend can only be called once cause compilation is a one-time process"
        self._called_once = True
        magi_logger.info("Dynamo traced files (for compilation cache):\n%s", "\n".join(self.compile_config.traced_files))
        compilation_counter.num_graphs_seen += 1
        CompileMonitor().mark("Dynamo bytecode transform")

        self._init_cache()

        self.full_graph_pass_manager(graph)

        split_gm, piecewise_graphs = self._split_graph(graph)

        submod_names_to_compile = [item.submod_name for item in piecewise_graphs if not item.is_splitting_graph]
        compilation_counter.num_piecewise_graphs_seen += len(piecewise_graphs)
        compilation_counter.num_piecewise_capturable_graphs_seen += len(submod_names_to_compile)
        magi_logger.info(f"Piecewise modules waiting for compilation: {submod_names_to_compile}")

        # propagate the split graph to the piecewise backend, compile submodules with symbolic shapes
        try:
            PiecewiseCompileInterpreter(split_gm, self.compiler_manager, submod_names_to_compile, self.compile_config).run(
                *example_inputs
            )
        except Exception as e:
            # Centralized failure entry for Magi compile: log at ERROR so it
            # is easy to grep in large-model logs.
            magi_logger.error("Magi compile failed while compiling piecewise submodules %s: %s", submod_names_to_compile, e)
            raise
        self._save_partitioned_graph(split_gm)

        # TODO: Support DBO and NAT here
        # split_gm = DBOGraphModule(split_gm, self.compile_config)
        if self.compile_config.offload_config.model_cpu_offload:
            split_gm = OffloadWrapper(split_gm, self.compile_config)

        # if envs.MAGI_ENABLE_TOKENFLOW:
        #     from magi_compiler.tokenflow.graph_fork import GraphForkWrapper

        if envs.MAGI_ENABLE_PROFILE:
            from magi_compiler.tokenflow.graph_profile import gen_profile_wrap_func

            split_gm = gen_profile_wrap_func(split_gm)

        if self.compile_config.cudagraph_mode == CudaGraphMode.FULL and self.compile_config.cudagraph_copy_inputs:
            return self._serialize_func_with_cudagraph(graph, split_gm, example_inputs)

        return MagiSerializableFunction(graph, example_inputs, self.model_tag, split_gm)

    def _serialize_func_with_cudagraph(
        self, graph: fx.GraphModule, split_gm: fx.GraphModule, example_inputs: list[Any]
    ) -> MagiSerializableFunction:
        """Wrap the whole split graph for FULL-mode CUDA graph capture.

        Records which input positions carry symbolic shapes (needed to copy
        inputs into the capture buffers) before wrapping.
        """
        fake_mode = detect_fake_mode()
        fake_args = [fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in example_inputs]

        self.sym_tensor_indices = detect_symbolic_tensor_indices(fake_args)

        wrapped_split_gm = gen_wrap_func_for_cudagraph(func=split_gm, mode_prefix=CudaGraphMode.FULL.name.lower())

        return MagiSerializableFunction(graph, example_inputs, self.model_tag, wrapped_split_gm)
586
+
587
+
588
def init_backend(compile_config: CompileConfig) -> str | Callable:
    """
    Initialize the backend based on CompileConfig.

    Returns either a registered torch backend name (TORCH_COMPILE mode) or
    a fresh ``MagiBackend`` instance (MAGI_COMPILE mode); raises ValueError
    when no usable compilation mode is configured.
    """
    mode = compile_config.compile_mode
    if mode is None or mode == CompileMode.NONE:
        raise ValueError("No compilation mode is set.")

    from torch._dynamo.backends.registry import list_backends

    available = list_backends(exclude_tags=tuple())
    magi_logger.info("Supported torch backends: %s", available)

    if mode == CompileMode.TORCH_COMPILE:
        # Plain torch.compile: hand back the backend name unchanged.
        assert compile_config.backend in available, f"Invalid backend for torch compilation: {compile_config.backend}"
        return compile_config.backend

    if mode == CompileMode.MAGI_COMPILE:
        assert compile_config.backend in ["eager", "inductor"], f"Invalid backend for MagiCompiler: {compile_config.backend}"
        tag = getattr(compile_config, "model_tag", None) or MAGI_MODEL_TAG
        return MagiBackend(compile_config, model_tag=tag)

    raise ValueError(f"Invalid compile mode: {compile_config.compile_mode}")
pkgs/MagiCompiler/magi_compiler/magi_compiler_base.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
17
+
18
+ import inspect
19
+ import os
20
+ import sys
21
+ from abc import abstractmethod
22
+ from contextlib import contextmanager
23
+ from types import CodeType
24
+ from typing import Callable, Literal
25
+
26
+ import magi_compiler.utils.envs as envs
27
+ import torch
28
+ from magi_compiler.utils import compute_hash, get_git_version
29
+ from magi_compiler.utils.compile_time_monitor import CompileMonitor
30
+
31
+ from .config import CompileConfig, CompileMode
32
+ from .magi_backend import init_backend
33
+ from .utils import compute_code_hash, compute_code_hash_with_content, magi_logger
34
+
35
+
36
def _verify_source_unchanged(source_info, compile_config: CompileConfig) -> None:
    """Raise if any traced source file changed since the artifact was saved.

    Side effect: every inlined source file path is recorded into
    ``compile_config.traced_files``. The checksum of the serialized file
    contents must match the checksum of the files currently on disk.
    """
    recorded = {}
    for entry in source_info.inlined_sources:
        path = inspect.getfile(sys.modules[entry.module])
        recorded[path] = entry.content
        compile_config.traced_files.add(path)

    snapshot_checksum = compute_code_hash_with_content(recorded)
    on_disk_checksum = compute_code_hash(set(recorded.keys()))
    if snapshot_checksum != on_disk_checksum:
        raise RuntimeError("Source code has changed since the last compilation. Recompiling the model.")
47
+
48
+
49
class MagiCompilerBase:
    """
    A wrapper class for torch.compile, with a custom dispatch logic.
    Subclasses should:
    1. Implement the forward method
    2. Implement the dispatch logic in the __call__ method
       It can use `self.compiled_codes` to access the compiled bytecode,
       and `with self.dispatch_to_compiled_code:` to dispatch to
       the compiled code.
    3. Implement the `__init__` method to determine how to call
       `torch.compile` over the forward method.
    """

    # NOTE: the docstring above was previously placed *after* this
    # annotation, which made it a dead string statement instead of __doc__.
    compile_config: CompileConfig

    def __init__(self, compile_config: CompileConfig):
        # Store the config: aot_compilation_path, bytecode_hook and
        # try_load_aot_compile_artifacts all read self.compile_config.
        # (Previously never assigned here, relying on subclasses to do it.)
        self.compile_config = compile_config
        backend = init_backend(compile_config)
        options = None
        if isinstance(backend, str) and backend == "inductor":
            options = compile_config.inductor_compile_config
        if envs.MAGI_AOT_COMPILE:
            options = options or {}
            # Drop all the guards in the AOT compile mode as bytecode hook is not used anymore.
            options["guard_filter_fn"] = lambda guards: [False for _ in guards]
            assert hasattr(torch._dynamo.config, "enable_aot_compile"), "enable_aot_compile config not available"
            torch._dynamo.config.enable_aot_compile = True

        self.compiled_callable = torch.compile(self.forward, fullgraph=True, backend=backend, options=options)
        self.original_code_object: CodeType = self.__class__.forward.__code__
        self.compiled_code: CodeType | None = None
        self.aot_compiled_fn: Callable | None = None

    @property
    def aot_compilation_path(self) -> str:
        """
        When using torch.compile in AOT mode, we store the cache artifacts
        under cache_root_dir/torch_aot_compile/{hash}/rank_i_j. The {hash}
        contains all of the factors except for the source files being
        traced through, because we don't actually know which source files
        to check at this point (before dynamo runs).
        On loading we will actually look at the source files being traced
        through. If any source file have changed (compared with the
        serialized backend artifacts), then we need to generate a new AOT
        compile artifact from scratch.
        """
        hash_key = compute_hash([self.forward, self.compile_config.model_idx, self.compile_config.hash, get_git_version()])
        cache_dir = os.path.join(self.compile_config.cache_root_dir, "torch_aot_compile", hash_key)
        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
        cache_dir = os.path.join(cache_dir, f"rank_{rank}")
        os.makedirs(cache_dir, exist_ok=True)
        aot_compilation_path = os.path.join(cache_dir, "model")
        return aot_compilation_path

    def try_load_aot_compile_artifacts(self) -> Callable | None:
        """Load a previously serialized AOT-compiled function, if one exists.

        Returns the cached callable (memoized on first successful load) or
        None when no artifact is on disk. Raises RuntimeError from
        `_verify_source_unchanged` if a traced source file changed since
        the artifact was serialized.
        """
        if self.aot_compiled_fn is not None:
            return self.aot_compiled_fn
        if not os.path.exists(self.aot_compilation_path):
            return None
        with open(self.aot_compilation_path, "rb") as f:
            CompileMonitor().start(
                self.compile_config.compile_mode == CompileMode.MAGI_COMPILE, self.compile_config.debug_dump_path()
            )
            loaded_fn = torch.compiler.load_compiled_function(f)
        # Verify BEFORE memoizing so a stale artifact is never cached.
        _verify_source_unchanged(loaded_fn.source_info(), self.compile_config)
        # FIX: previously the memo checked at the top of this method was
        # never populated; cache the verified function here.
        self.aot_compiled_fn = loaded_fn
        return loaded_fn

    def aot_compile(self, *args, **kwargs):
        """
        Run the model in AOT (Ahead-Of-Time) compile mode.

        All compilation work is completed before execution, suitable for production environment.
        This results in longer compilation time but superior runtime performance.
        """
        assert hasattr(self.compiled_callable, "aot_compile"), "aot_compile is not supported by the current configuration"
        return self.compiled_callable.aot_compile((args, kwargs))

    def jit_compile(self, *args, **kwargs):
        """
        Run the model in JIT (Just-In-Time) compile mode.

        Compilation occurs at runtime, first run may be slower due to compilation overhead.
        """
        handle = torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
        # FIX: remove the hook even if compilation/execution raises, so a
        # failed run does not leave a stale hook registered globally.
        try:
            return self.compiled_callable(*args, **kwargs)
        finally:
            handle.remove()

    @abstractmethod
    def forward(self, *args, **kwargs):
        ...

    def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
        """Hook to save the compiled bytecode for direct execution."""
        if old_code is not self.original_code_object:
            return
        # Step1: Check if the old bytecode is from the compiled code
        # code borrowed from depyf enable_debugging.py
        frame = sys._getframe()
        while frame and frame.f_back:
            frame = frame.f_back
            code_name = frame.f_code.co_name
            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
            if code_name == "_compile" and file_name == "convert_frame.py":
                break
            frame = frame.f_locals["frame"]
        assert frame.f_code == old_code

        # FIX: `hasattr(frame.f_locals, "self")` looked for an *attribute*
        # named "self" on the locals dict, which is always False, so this
        # guard never fired. Key membership is what was intended.
        if "self" in frame.f_locals and frame.f_locals["self"] is not self:
            return

        # Step2: Save the compiled bytecode
        self.compiled_code = new_code

        # Step3: Save the decompiled code
        path = self.compile_config.debug_dump_path()
        decompiled_file = os.path.join(path, "decompiled_code.py")
        if os.path.exists(decompiled_file):
            return
        try:
            # usually the decompilation will succeed for most models, as we guarantee a full-graph compilation in Dynamo.
            # but there's no 100% guarantee, since decompliation is not a reversible process.
            from magi_compiler.magi_depyf import decompile as magi_decompile

            src = magi_decompile(new_code)
            with open(decompiled_file, "w") as f:
                f.write(src)
            magi_logger.info("Dynamo transformed code saved to %s", decompiled_file)
        except Exception:
            # Best-effort: decompilation failure must never break compilation.
            pass

    @contextmanager
    def dispatch_to_compiled_fwd(self, mode: Literal["jit", "aot"] = "jit"):
        """
        Context manager to dispatch to the compiled code.
        Why does this work? Because Dynamo guarantees that the compiled
        bytecode has exactly the same arguments, cell variables, and free
        variables as the original code. Therefore we can directly switch
        the code object in the function and call it.

        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7
        for more details.

        NOTE: Why compile `forward` but invoke through `old_call`?

        In torch.nn.Module, `__call__` wraps `forward` with critical runtime logic:
        - Pre/post forward hooks
        - FSDP parameter sharding/gathering and device placement

        Our strategy: use this context manager to temporarily replace `self.forward`
        with the compiled version, then invoke `old_call(self, *args, **kwargs)`.

        This way:
        1. `old_call` executes hooks and FSDP mechanics normally
        2. When `old_call` internally calls `self.forward`, it hits our compiled code
        3. Compiled code runs within the proper FSDP/hook context

        Calling `self.forward()` directly would bypass FSDP (seeing sharded/invalid
        params) and skip hooks that other components may rely on.
        """
        if mode == "jit":
            assert self.compiled_code is not None
            self.__class__.forward.__code__ = self.compiled_code
            # FIX: restore in a finally block — otherwise an exception in the
            # managed body leaves the *class-wide* forward permanently
            # pointing at the compiled bytecode.
            try:
                yield
            finally:
                self.__class__.forward.__code__ = self.original_code_object
        elif mode == "aot":
            assert self.aot_compiled_fn is not None
            old_forward = self.forward
            self.forward = lambda *args, **kwargs: self.aot_compiled_fn(self, *args, **kwargs)
            try:
                yield
            finally:
                self.forward = old_forward
        else:
            raise ValueError(f"Invalid mode: {mode}")
pkgs/MagiCompiler/magi_compiler/magi_depyf/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """magi_depyf — a modern bytecode decompiler and torch.compile inspector."""
16
+
17
+ from .decompile import DecompilationError, Decompiler, decompile, safe_decompile
18
+
19
+ __version__ = "0.1.0"
20
+
21
+ __all__ = ["Decompiler", "decompile", "safe_decompile", "DecompilationError", "__version__"]
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Decompilation: bytecode → Python source, plus recompile/fix/postprocess."""
16
+
17
+ from .decompiler import DecompilationError, Decompiler, decompile, safe_decompile
18
+
19
+ __all__ = ["Decompiler", "decompile", "safe_decompile", "DecompilationError"]
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Bytecode processing — pure Python, no torch dependency."""
16
+
17
+ from .decompile_context import DecompileContext
18
+ from .handler_registry import HandlerRegistry, registry
19
+ from .instruction import Instruction
20
+ from .source_emitter import SourceEmitter
21
+
22
+ __all__ = ["Instruction", "SourceEmitter", "HandlerRegistry", "DecompileContext", "registry"]
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/decompile_context.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """DecompileContext — read-only bag passed to every handler.
16
+
17
+ Handlers receive ``(emitter, inst, ctx)`` — they mutate *emitter*
18
+ and call *ctx* methods but never touch the ``Decompiler`` directly.
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ from types import CodeType
24
+ from typing import TYPE_CHECKING, Callable, Dict, Tuple
25
+
26
+ if TYPE_CHECKING:
27
+ from .instruction import Instruction
28
+
29
+
30
class DecompileContext:
    """Read-only bag of per-code-object state handed to opcode handlers.

    Exposes the code object being decompiled, its instruction sequence, the
    current indentation level, and the ``decompile_range`` callback used for
    recursive sub-block decompilation.
    """

    def __init__(
        self,
        code: CodeType,
        instructions: Tuple["Instruction", ...],
        indentation: int,
        decompile_range: Callable,
        offset_to_index: Dict[int, int],
    ) -> None:
        # Attributes consumed directly by handlers.
        self.code = code
        self.instructions = instructions
        self.indentation = indentation
        self.decompile_range = decompile_range
        # Private map: bytecode offset -> position within `instructions`.
        self._offset_to_index = offset_to_index

    def index_of(self, offset: int) -> int:
        """Return the index of the instruction at *offset* (O(1) lookup)."""
        if offset not in self._offset_to_index:
            raise ValueError(f"No instruction at offset {offset}")
        return self._offset_to_index[offset]
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handler_registry.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """HandlerRegistry — opcode-to-handler dispatch.
16
+
17
+ A *handler* is a plain function with signature::
18
+
19
+ (emitter: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> Optional[int]
20
+
21
+ Returning ``None`` advances to the next instruction.
22
+ Returning an ``int`` jumps to that instruction index.
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ from typing import TYPE_CHECKING, Callable, List, Optional
28
+
29
+ if TYPE_CHECKING:
30
+ pass
31
+
32
+ HandlerFn = Callable[..., Optional[int]]
33
+
34
+
35
class HandlerRegistry:
    """Dispatch table mapping opcode names to their handler functions."""

    def __init__(self) -> None:
        # opname -> handler; populated exclusively via `register`.
        self._handlers: dict[str, HandlerFn] = {}

    def register(self, *opnames: str) -> Callable[[HandlerFn], HandlerFn]:
        """Decorator that registers *fn* for one or more opcode names."""

        def decorator(fn: HandlerFn) -> HandlerFn:
            for opname in opnames:
                self._handlers[opname] = fn
            return fn

        return decorator

    def get(self, opname: str) -> Optional[HandlerFn]:
        """Return the handler registered for *opname*, or None."""
        return self._handlers.get(opname)

    def __contains__(self, opname: str) -> bool:
        return opname in self._handlers

    def supported_opnames(self) -> List[str]:
        """All registered opcode names, sorted alphabetically."""
        return sorted(self._handlers)
59
+
60
+
61
+ # Singleton registry — handlers register against this at import time.
62
+ registry = HandlerRegistry()
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Import every handler module so they register against the global registry."""
16
+
17
+ from . import arithmetic # noqa: F401
18
+ from . import calls # noqa: F401
19
+ from . import containers # noqa: F401
20
+ from . import control_flow # noqa: F401
21
+ from . import load_store # noqa: F401
22
+ from . import stack_ops # noqa: F401
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/arithmetic.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Handlers for unary, binary, inplace, and comparison operations."""
16
+
17
+ from __future__ import annotations
18
+
19
+ from ..decompile_context import DecompileContext
20
+ from ..handler_registry import registry
21
+ from ..instruction import Instruction
22
+ from ..source_emitter import SourceEmitter
23
+
24
_reg = registry.register

# ── Unary ─────────────────────────────────────────────────────────────────

# Opcode name -> prefix operator text.
_UNARY_SYMBOLS = {"UNARY_NEGATIVE": "-", "UNARY_POSITIVE": "+", "UNARY_INVERT": "~", "UNARY_NOT": "not"}


@_reg(*_UNARY_SYMBOLS)
def _unary(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Pop TOS and push the corresponding parenthesized prefix expression."""
    operand = em.pop()
    symbol = _UNARY_SYMBOLS[inst.opname]
    em.push(f"({symbol} {operand})")


@_reg("GET_LEN")
def _get_len(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Push ``len(TOS)`` on top of the stack, leaving TOS itself in place."""
    em.push(f"len({em.peek()})")


# ── Binary ────────────────────────────────────────────────────────────────

# Opcode name -> infix operator text (specialized binary opcodes of older CPython).
_BINARY_SYMBOLS = {
    "BINARY_MULTIPLY": "*",
    "BINARY_ADD": "+",
    "BINARY_SUBTRACT": "-",
    "BINARY_TRUE_DIVIDE": "/",
    "BINARY_FLOOR_DIVIDE": "//",
    "BINARY_MODULO": "%",
    "BINARY_POWER": "**",
    "BINARY_AND": "&",
    "BINARY_OR": "|",
    "BINARY_XOR": "^",
    "BINARY_LSHIFT": "<<",
    "BINARY_RSHIFT": ">>",
    "BINARY_MATRIX_MULTIPLY": "@",
}


@_reg(*_BINARY_SYMBOLS)
def _binary(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Pop two operands and push the parenthesized infix expression."""
    right = em.pop()
    left = em.pop()
    symbol = _BINARY_SYMBOLS[inst.opname]
    em.push(f"({left} {symbol} {right})")


@_reg("BINARY_SUBSCR")
def _subscr(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Pop index and object; push the subscription expression ``obj[index]``."""
    index = em.pop()
    obj = em.pop()
    em.push(f"{obj}[{index}]")


@_reg("BINARY_SLICE")
def _slice(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Pop stop, start, and container; push ``container[start:stop]``."""
    stop = em.pop()
    start = em.pop()
    container = em.pop()
    em.push(f"{container}[{start}:{stop}]")
80
+
81
+
82
+ # ── Inplace ───────────────────────────────────────────────────────────────
83
+
84
# Opcode name -> augmented-assignment operator text (pre-unified inplace opcodes).
_INPLACE_SYMBOLS = {
    "INPLACE_MULTIPLY": "*",
    "INPLACE_ADD": "+",
    "INPLACE_SUBTRACT": "-",
    "INPLACE_TRUE_DIVIDE": "/",
    "INPLACE_FLOOR_DIVIDE": "//",
    "INPLACE_MODULO": "%",
    "INPLACE_POWER": "**",
    "INPLACE_AND": "&",
    "INPLACE_OR": "|",
    "INPLACE_XOR": "^",
    "INPLACE_LSHIFT": "<<",
    "INPLACE_RSHIFT": ">>",
    "INPLACE_MATRIX_MULTIPLY": "@",
}


@_reg(*_INPLACE_SYMBOLS)
def _inplace(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Emit an augmented-assignment statement and leave the target on the stack."""
    right = em.pop()
    target = em.pop()
    symbol = _INPLACE_SYMBOLS[inst.opname]
    em.emit(f"{target} {symbol}= {right}")
    em.push(target)


@_reg("BINARY_OP")
def _binary_op(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Unified BINARY_OP: ``argrepr`` carries the operator text.

    Augmented variants ("+=", ...) contain "=" and become statements; plain
    variants become parenthesized expressions.
    """
    right = em.pop()
    left = em.pop()
    if "=" not in inst.argrepr:
        em.push(f"({left} {inst.argrepr} {right})")
    else:
        em.emit(f"{left} {inst.argrepr} {right}")
        em.push(left)


# ── Comparison ────────────────────────────────────────────────────────────


@_reg("COMPARE_OP")
def _compare(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Pop two operands and push a comparison; the operator text is in ``argval``."""
    right = em.pop()
    left = em.pop()
    em.push(f"({left} {inst.argval} {right})")


@_reg("IS_OP")
def _is_op(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Identity test: argval 0 -> ``is``, otherwise ``is not``."""
    right = em.pop()
    left = em.pop()
    keyword = "is" if not inst.argval else "is not"
    em.push(f"({left} {keyword} {right})")


@_reg("CONTAINS_OP")
def _contains(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Membership test: argval 0 -> ``in``, otherwise ``not in``."""
    right = em.pop()
    left = em.pop()
    keyword = "in" if not inst.argval else "not in"
    em.push(f"({left} {keyword} {right})")
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/calls.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Handlers for function-call and function-creation opcodes."""
16
+
17
+ from __future__ import annotations
18
+
19
+ import sys
20
+ from typing import Optional
21
+
22
+ from ..decompile_context import DecompileContext
23
+ from ..handler_registry import registry
24
+ from ..instruction import Instruction
25
+ from ..source_emitter import SourceEmitter
26
+
27
+ _reg = registry.register
28
+
29
+
30
+ @_reg("KW_NAMES")
31
+ def _kw_names(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
32
+ # Python 3.11+ instruction that passes keyword argument names to the subsequent CALL.
33
+ # inst.arg indexes into co_consts for the key-name tuple, e.g. ('y', 'z').
34
+ # Push repr so it becomes the string "('y', 'z')"; the CALL handler later eval()s it back to a tuple.
35
+ names = ctx.code.co_consts[inst.arg]
36
+ em.push(repr(names))
37
+
38
+
39
+ @_reg("CALL")
40
+ def _call(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
41
+ """Python 3.11+ unified CALL.
42
+
43
+ 3.12 stack layout: [NULL, callable, arg0, ..., argN-1] (KW_NAMES precedes)
44
+ 3.11 stack layout: [NULL, callable, arg0, ..., argN-1] (KW_NAMES → PRECALL → CALL)
45
+ """
46
+ # Check whether KW_NAMES precedes CALL (indicating keyword arguments exist).
47
+ # 3.12: KW_NAMES → CALL; 3.11: KW_NAMES → PRECALL → CALL
48
+ preceding = [x for x in ctx.instructions if x.offset < inst.offset]
49
+ has_kw = False
50
+ if preceding:
51
+ if preceding[-1].opname == "KW_NAMES" or (
52
+ len(preceding) > 1
53
+ and preceding[-2].opname == "KW_NAMES"
54
+ and preceding[-1].opname == "PRECALL" # 3.11 transitional opcode, removed in 3.12
55
+ ):
56
+ has_kw = True
57
+
58
+ kw_names: tuple = ()
59
+ if has_kw:
60
+ kw_names = eval(em.pop()) # retrieve the tuple stored by KW_NAMES from the stack
61
+ args = [em.pop() for _ in range(inst.argval)][::-1]
62
+ pos_args = args[: len(args) - len(kw_names)]
63
+ kw_args = args[len(args) - len(kw_names) :]
64
+ kwcalls = [f"{n}={v}" for n, v in zip(kw_names, kw_args)]
65
+ func = em.pop()
66
+ # 3.11+ PUSH_NULL / LOAD_GLOBAL(NULL+name) pushes a NULL sentinel before the call.
67
+ # After popping the callable, the top of stack may be NULL (represented as None); clear it.
68
+ if em.stack_size and em.peek() is None:
69
+ em.pop()
70
+ # GET_ITER produces "iter(x)"; if func happens to be "iter(x)" it is actually an argument
71
+ # (e.g. in the next(iter(x)) pattern), and the real callable is further down the stack.
72
+ if "iter(" in str(func):
73
+ pos_args = [func]
74
+ func = em.pop()
75
+ em.push(f"{func}({', '.join(pos_args + kwcalls)})")
76
+ # replace_tos_with_temp: the call result may be referenced multiple times (assignment, passing,
77
+ # method call), so store it in a temp to avoid repeated evaluation and side effects.
78
+ em.replace_tos_with_temp()
79
+
80
+
81
+ @_reg("CALL_FUNCTION", "CALL_METHOD")
82
+ def _call_legacy(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
83
+ """CALL_FUNCTION / CALL_METHOD (Python ≤3.10)."""
84
+ args = [em.pop() for _ in range(inst.argval)][::-1]
85
+ func = em.pop()
86
+ em.push(f"{func}({', '.join(args)})")
87
+ em.replace_tos_with_temp()
88
+
89
+
90
+ @_reg("CALL_FUNCTION_KW")
91
+ def _call_function_kw(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
92
+ kw_args = eval(em.pop())
93
+ kw_vals = [em.pop() for _ in range(len(kw_args))]
94
+ kw_vals.reverse()
95
+ kwcalls = [f"{n}={v}" for n, v in zip(kw_args, kw_vals)]
96
+ pos_args = [em.pop() for _ in range(inst.argval - len(kw_args))][::-1]
97
+ func = em.pop()
98
+ em.push(f"{func}({', '.join(pos_args + kwcalls)})")
99
+ em.replace_tos_with_temp()
100
+
101
+
102
+ @_reg("CALL_FUNCTION_EX")
103
+ def _call_function_ex(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
104
+ # 3.11+ stack: [NULL, func, args (, kwargs)]
105
+ # After popping func, clear the NULL sentinel before pushing the result
106
+ if inst.argval == 0:
107
+ a = em.pop()
108
+ f = em.pop()
109
+ if em.stack_size and em.peek() is None:
110
+ em.pop()
111
+ em.push(f"{f}(*{a})")
112
+ elif inst.argval == 1:
113
+ kw = em.pop()
114
+ a = em.pop()
115
+ f = em.pop()
116
+ if em.stack_size and em.peek() is None:
117
+ em.pop()
118
+ em.push(f"{f}(*{a}, **{kw})")
119
+ em.replace_tos_with_temp()
120
+
121
+
122
+ @_reg("CALL_INTRINSIC_1")
123
+ def _intrinsic_1(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
124
+ """Python 3.12 instruction replacing some internal C-level calls.
125
+ argrepr identifies the specific operation, e.g. INTRINSIC_PRINT, INTRINSIC_UNARY_POSITIVE.
126
+ Most are compiler-internal operations (import *, typealias) rarely triggered by user code."""
127
+ _SKIP = {
128
+ "INTRINSIC_1_INVALID",
129
+ "INTRINSIC_IMPORT_STAR",
130
+ "INTRINSIC_STOPITERATION_ERROR",
131
+ "INTRINSIC_ASYNC_GEN_WRAP",
132
+ "INTRINSIC_TYPEVAR",
133
+ "INTRINSIC_PARAMSPEC",
134
+ "INTRINSIC_TYPEVARTUPLE",
135
+ "INTRINSIC_SUBSCRIPT_GENERIC",
136
+ "INTRINSIC_TYPEALIAS",
137
+ }
138
+ if inst.argrepr in _SKIP:
139
+ return
140
+ if inst.argrepr == "INTRINSIC_PRINT":
141
+ em.emit(f"print({em.pop()})")
142
+ em.push("None")
143
+ elif inst.argrepr == "INTRINSIC_UNARY_POSITIVE":
144
+ em.set_at(0, f"+{em.peek()}")
145
+ elif inst.argrepr == "INTRINSIC_LIST_TO_TUPLE":
146
+ em.push(f"tuple({em.pop()})")
147
+
148
+
149
+ # ── MAKE_FUNCTION ─────────────────────────────────────────────────────────
150
+
151
+
152
+ @_reg("MAKE_FUNCTION")
153
+ def _make_function(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> Optional[int]:
154
+ """Handle bytecode for def inner(...) and lambda.
155
+
156
+ Bytecode: LOAD_CONST <code_object> → MAKE_FUNCTION → STORE_FAST name
157
+ The handler recursively decompiles the inner code object and emits the full def statement.
158
+ """
159
+ if sys.version_info < (3, 11):
160
+ # 3.10: qualified_name string is still on the stack
161
+ qual_name = em.pop()
162
+ try:
163
+ qual_name = eval(qual_name)
164
+ except Exception:
165
+ pass
166
+ func_name = qual_name.split(".")[-1]
167
+ if "<" in func_name: # <lambda>, <listcomp>, etc. — invalid identifiers
168
+ em.emit(f'"original function name {func_name} is illegal, use a temp name."')
169
+ func_name = em.make_temp()
170
+ else:
171
+ func_name = em.make_temp()
172
+
173
+ code = em.pop() # inner CodeType object pushed by LOAD_CONST
174
+ # argval bit flags indicate whether extra function components remain on the stack
175
+ if inst.argval & 0x08:
176
+ em.pop() # closure tuple (cell references for freevars)
177
+ if inst.argval & 0x04:
178
+ em.pop() # annotations dict
179
+ if inst.argval & 0x02:
180
+ em.pop() # keyword-only defaults tuple
181
+ if inst.argval & 0x01:
182
+ em.pop() # positional defaults tuple
183
+
184
+ # If the next instruction is STORE_FAST, use the target variable name as the function name
185
+ this_idx = ctx.index_of(inst.offset)
186
+ immediately_used = False
187
+ if ctx.instructions[this_idx + 1].opname == "STORE_FAST":
188
+ func_name = ctx.instructions[this_idx + 1].argval
189
+ immediately_used = True
190
+
191
+ # Recurse: create a new Decompiler instance for the inner code object
192
+ from ...decompiler import Decompiler
193
+
194
+ inner = Decompiler(code).decompile(overwrite_fn_name=func_name)
195
+ em.emit_raw(inner)
196
+
197
+ if immediately_used:
198
+ return this_idx + 2 # skip the MAKE_FUNCTION + STORE_FAST pair
199
+ em.push(func_name) # not immediately assigned — push onto stack for later use (e.g. as an argument)
200
+ return None
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/containers.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Handlers for BUILD_*, UNPACK_*, LIST_EXTEND/APPEND, SET_ADD, MAP_ADD,
16
+ FORMAT_VALUE, and BUILD_SLICE / BUILD_STRING."""
17
+
18
+ from __future__ import annotations
19
+
20
+ import sys
21
+
22
+ from ..decompile_context import DecompileContext
23
+ from ..handler_registry import registry
24
+ from ..instruction import Instruction
25
+ from ..source_emitter import SourceEmitter
26
+
27
+ _reg = registry.register
28
+
29
+
30
+ # ── BUILD tuple / list / set ──────────────────────────────────────────────
31
+
32
+
33
+ def _safe_str(val) -> str:
34
+ """Convert a stack value to string, handling None sentinels from PUSH_NULL."""
35
+ return "None" if val is None else str(val)
36
+
37
+
38
@_reg("BUILD_TUPLE", "BUILD_TUPLE_UNPACK", "BUILD_TUPLE_UNPACK_WITH_CALL")
def _build_tuple(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Pop argval elements and push a tuple display (starred for UNPACK variants)."""
    elems = [_safe_str(em.pop()) for _ in range(inst.argval)]
    elems.reverse()
    if "UNPACK" in inst.opname:
        elems = [f"*{e}" for e in elems]
    if inst.argval == 1:
        em.push(f"({elems[0]},)")  # single-element tuple needs the trailing comma
    else:
        em.push(f"({', '.join(elems)})")


@_reg("BUILD_LIST", "BUILD_LIST_UNPACK")
def _build_list(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Pop argval elements and push a list display; the result lives in a temp."""
    elems = [_safe_str(em.pop()) for _ in range(inst.argval)]
    elems.reverse()
    if "UNPACK" in inst.opname:
        elems = [f"*{e}" for e in elems]
    em.push(f"[{', '.join(elems)}]")
    em.replace_tos_with_temp()


@_reg("BUILD_SET", "BUILD_SET_UNPACK")
def _build_set(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Pop argval elements and push a set display (``set()`` when empty)."""
    if inst.argval:
        elems = [em.pop() for _ in range(inst.argval)]
        elems.reverse()
        if "UNPACK" in inst.opname:
            elems = [f"*{e}" for e in elems]
        em.push(f"{{{', '.join(elems)}}}")
    else:
        em.push("set()")  # "{}" would be a dict display
    em.replace_tos_with_temp()


# ── BUILD map ─────────────────────────────────────────────────────────────


@_reg("BUILD_MAP")
def _build_map(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Pop argval key/value pairs (interleaved) and push a dict display."""
    flat = [em.pop() for _ in range(inst.argval * 2)]
    flat.reverse()
    pairs = zip(flat[::2], flat[1::2])
    em.push(f"{{{', '.join(f'{k}: {v}' for k, v in pairs)}}}")
    em.replace_tos_with_temp()


@_reg("BUILD_MAP_UNPACK", "BUILD_MAP_UNPACK_WITH_CALL")
def _build_map_unpack(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Pop argval mappings and push a ``{**a, **b}`` merge display."""
    if inst.argval:
        mappings = [em.pop() for _ in range(inst.argval)]
        mappings.reverse()
        em.push(f"{{{', '.join(f'**{m}' for m in mappings)}}}")
    else:
        em.push("dict()")
    em.replace_tos_with_temp()


@_reg("BUILD_CONST_KEY_MAP")
def _const_key_map(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Keys arrive as one constant tuple (repr on the stack); values are popped singly."""
    key_tuple = eval(em.pop())
    values = [em.pop() for _ in range(inst.argval)]
    values.reverse()
    em.push(f"{{{', '.join(f'{k!r}: {v}' for k, v in zip(key_tuple, values))}}}")
    em.replace_tos_with_temp()


@_reg("BUILD_STRING")
def _build_string(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """f-string assembly: pop argval fragments and join with ``+`` concatenation."""
    pieces = [em.pop() for _ in range(inst.argval)]
    pieces.reverse()
    em.push(" + ".join(pieces))


@_reg("BUILD_SLICE")
def _build_slice(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Push a ``slice(...)`` call built from two or three popped components."""
    top = em.pop()
    below = em.pop()
    if inst.argval == 2:
        em.push(f"slice({below}, {top})")
    elif inst.argval == 3:
        start = em.pop()
        em.push(f"slice({start}, {below}, {top})")


@_reg("LIST_TO_TUPLE")
def _list_to_tuple(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Wrap TOS in ``tuple(...)``."""
    em.push(f"tuple({em.pop()})")
116
+
117
+
118
+ # ── Mutating container ops ────────────────────────────────────────────────
119
+
120
+
121
@_reg("LIST_EXTEND")
def _list_extend(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Emit ``temp.extend(values)`` against the list argval slots down the stack."""
    extension = em.pop()
    list_temp = em.replace_tos_with_temp(depth=inst.argval)
    em.emit(f"{list_temp}.extend({extension})")


@_reg("LIST_APPEND")
def _list_append(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Comprehension-style append: the list sits argval slots below the value."""
    depth = 2 if inst.argval == 1 else inst.argval
    target = em.stack[-depth]  # read before popping the value
    item = em.pop()
    em.emit(f"{target}.append({item})")


@_reg("SET_UPDATE", "DICT_UPDATE", "DICT_MERGE")
def _generic_update(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Emit ``temp.update(values)`` on the container just below the popped values."""
    assert inst.argval == 1, "Only tested for argval==1"
    update_src = em.pop()
    container_temp = em.replace_tos_with_temp()
    em.emit(f"{container_temp}.update({update_src})")


@_reg("SET_ADD")
def _set_add(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Comprehension-style add: the set sits argval slots below the value."""
    depth = 2 if inst.argval == 1 else inst.argval
    target = em.stack[-depth]  # read before popping the value
    item = em.pop()
    em.emit(f"{target}.add({item})")


@_reg("MAP_ADD")
def _map_add(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Dict-comprehension insert; key/value pop order flipped in Python 3.8."""
    target = em.stack[-inst.argval - 1]
    if sys.version_info >= (3, 8):
        item = em.pop()
        key = em.pop()
    else:
        key = em.pop()
        item = em.pop()
    em.emit(f"{target}.__setitem__({key}, {item})")
162
+
163
+
164
+ # ── Unpack ────────────────────────────────────────────────────────────────
165
+
166
+
167
@_reg("UNPACK_SEQUENCE")
def _unpack_seq(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Emit ``t0, t1, ... = source`` into fresh temps; push temps so TOS is the first element."""
    source = em.pop()
    targets = [em.make_temp() for _ in range(inst.argval)]
    em.emit("".join(f"{t}, " for t in targets) + f"= {source}")
    for target in reversed(targets):
        em.push(target)


@_reg("UNPACK_EX")
def _unpack_ex(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Emit ``t0, ..., *rest = source``; the star temp goes on the stack first."""
    source = em.pop()
    targets = [em.make_temp() for _ in range(inst.argval)]
    rest = em.make_temp()
    em.emit(f"{', '.join(targets)}, *{rest} = {source}")
    em.push(rest)
    for target in reversed(targets):
        em.push(target)


# ── Format ────────────────────────────────────────────────────────────────


@_reg("FORMAT_VALUE")
def _format_value(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """f-string conversion: ``argval`` is (conversion function, has-format-spec flag)."""
    conv, has_spec = inst.argval
    if has_spec:
        spec = em.pop()
        value = em.pop()
        em.push(f"format({value}, {spec})")
    else:
        value = em.pop()
        conv_fn = str if conv is None else conv
        em.push(f"{conv_fn.__name__}({value})")
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/control_flow.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Handlers for control-flow opcodes: jumps, if/else, for, return, yield, raise."""
16
+
17
+ from __future__ import annotations
18
+
19
+ from typing import Optional
20
+
21
+ from ..decompile_context import DecompileContext
22
+ from ..handler_registry import registry
23
+ from ..instruction import Instruction
24
+ from ..source_emitter import LoopContext, SourceEmitter
25
+
26
+ _reg = registry.register
27
+
28
+
29
+ # ── Simple returns / yield / raise ────────────────────────────────────────
30
+
31
+
32
@_reg("RETURN_VALUE")
def _return_value(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Emit ``return TOS`` and drop TOS."""
    em.emit(f"return {em.pop()}")


@_reg("RETURN_CONST")
def _return_const(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Return a constant directly: the value comes from ``argval``, not the stack."""
    em.emit(f"return {inst.argval!r}")


@_reg("YIELD_VALUE")
def _yield_value(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Emit ``yield TOS``; TOS is left on the stack."""
    import sys

    if sys.version_info >= (3, 12):
        raise NotImplementedError("YIELD_VALUE is not supported in Python 3.12+")
    em.emit(f"yield {em.peek()}")


@_reg("RETURN_GENERATOR")
def _return_generator(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Python 3.11+ generator function prologue. Each generator has its own stack frame;
    RETURN_GENERATOR creates the generator object and returns it to the caller,
    subsequent next(gen) resumes from RESUME. Push None as a placeholder during decompilation."""
    em.push(None)


@_reg("GEN_START")
def _gen_start(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Generator start marker; only the generator-expression kind (argval 0) is handled."""
    assert inst.argval == 0, "Only generator expression is supported"


@_reg("RAISE_VARARGS")
def _raise_varargs(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Emit ``raise`` / ``raise exc`` / ``raise exc from cause`` depending on argval."""
    if inst.argval == 0:
        em.emit("raise")
    elif inst.argval == 1:
        em.emit(f"raise {em.pop()}")
    elif inst.argval == 2:
        cause = em.pop()  # TOS is the from-clause
        exc = em.pop()
        em.emit(f"raise {exc} from {cause}")


@_reg("BREAK_LOOP")
def _break_loop(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Emit ``break``."""
    em.emit("break")
81
+
82
+
83
+ # ── Unconditional jumps ───────────────────────────────────────────────────
84
+
85
+
86
@_reg("JUMP_ABSOLUTE", "JUMP_FORWARD", "JUMP_BACKWARD", "JUMP_BACKWARD_NO_INTERRUPT")
def _abs_jump(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> Optional[int]:
    """Unconditional jump.

    Inside a loop, a jump past the loop end becomes ``break`` and a jump back
    to the loop head becomes ``continue``; both return ``len(instructions)``
    so decompile_range stops immediately. Otherwise the target index is
    returned so decompilation resumes there.
    """
    target_idx = ctx.index_of(inst.jump_target_offset())
    active_loop = em.loop
    if active_loop is None:
        return target_idx
    if target_idx >= active_loop.end_index:
        em.emit("break")
        return len(ctx.instructions)
    if target_idx == active_loop.start_index:
        em.emit("continue")
        return len(ctx.instructions)
    return target_idx
103
+
104
+
105
+ # ── Conditional jumps (if / else) ─────────────────────────────────────────
106
+
107
+
108
@_reg("POP_JUMP_IF_TRUE", "POP_JUMP_IF_FALSE")
@_reg("POP_JUMP_FORWARD_IF_TRUE", "POP_JUMP_FORWARD_IF_FALSE")
@_reg("POP_JUMP_BACKWARD_IF_TRUE", "POP_JUMP_BACKWARD_IF_FALSE")
@_reg("POP_JUMP_FORWARD_IF_NONE", "POP_JUMP_FORWARD_IF_NOT_NONE")
@_reg("POP_JUMP_BACKWARD_IF_NONE", "POP_JUMP_BACKWARD_IF_NOT_NONE")
@_reg("JUMP_IF_TRUE_OR_POP", "JUMP_IF_FALSE_OR_POP")
@_reg("POP_JUMP_IF_NOT_NONE", "POP_JUMP_IF_NONE")
def _jump_if(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> Optional[int]:
    """Decompile if/else structure.

    Standard if/else bytecode::

        POP_JUMP_IF_FALSE else_start   <- this_idx
        (if-body)
        JUMP_FORWARD after_else        <- last instruction of if-body
     >> else_start:                    <- jump_idx
        (else-body)
     >> after_else:                    <- merge point (end)

    Returns the instruction index at which decompilation should resume
    (the end of the else-body), or ``None`` for the backward-jump
    ``if cond: continue`` case.
    """

    jump_offset = inst.jump_target_offset()
    jump_idx = ctx.index_of(jump_offset)
    this_idx = ctx.index_of(inst.offset)

    # ── Step 1: condition expression and branch stack state ──
    # The emitted condition is NEGATED relative to the jump sense: the jump
    # target is the else-branch, so e.g. POP_JUMP_IF_TRUE guards the if-body
    # with "not cond".
    cond = em.peek()
    fall_stack = list(em.stack)
    jump_stack = list(em.stack)

    if "IF_NOT_NONE" in inst.opname:
        cond = f"({cond} is None)"
    elif "IF_NONE" in inst.opname:
        cond = f"({cond} is not None)"
    elif "IF_TRUE" in inst.opname:
        cond = f"(not {cond})"
    else:
        cond = f"{cond}"

    # POP_JUMP_* pops the condition on both paths; *_OR_POP pops only on fall-through.
    if "POP_JUMP" in inst.opname:
        jump_stack.pop()
        fall_stack.pop()
    elif "OR_POP" in inst.opname:
        fall_stack.pop()

    # ── Step 2: merge point candidate upper bounds ──
    merge_upper_bounds = [len(ctx.instructions)]
    if em.loop is not None:
        merge_upper_bounds.append(em.loop.end_index)

    # ── Step 3: find "skip else" JUMPs in the if-body ──
    def _is_forward_past_else(i: Instruction) -> bool:
        return i.is_jump and i.jump_target_offset() >= jump_offset

    forward_targets = [i.jump_target_offset() for i in ctx.instructions[this_idx:jump_idx] if _is_forward_past_else(i)]

    # ── Step 4: compute merge point by case ──
    if not forward_targets:
        if jump_idx <= this_idx:
            # Case C: backward jump (inside loop), emit if cond: continue
            # Here the jump sense is NOT negated: jumping back means the
            # condition held and the loop continues.
            rev_cond = em.peek()
            if "IF_NOT_NONE" in inst.opname:
                rev_cond = f"({rev_cond} is not None)"
            elif "IF_NONE" in inst.opname:
                rev_cond = f"({rev_cond} is None)"
            elif "IF_TRUE" in inst.opname:
                rev_cond = f"{rev_cond}"
            elif "IF_FALSE" in inst.opname:
                rev_cond = f"(not {rev_cond})"
            em.emit(f"if {rev_cond}:")
            em.emit(em.indent("continue\n").rstrip("\n"))
            return None
        # Case B: both branches terminate with RETURN/RAISE
        end = jump_idx
    else:
        # Case A: standard if/else, infer merge point from forward_targets
        max_jump = max(forward_targets)
        max_idx = ctx.index_of(max_jump)
        # Re-scan up to the furthest jump so chained elif jumps are included.
        all_targets = [i.jump_target_offset() for i in ctx.instructions[this_idx:max_idx] if _is_forward_past_else(i)]
        max_idx = ctx.index_of(max(all_targets))

        # NOTE(review): heuristic widening — if the instruction before the merge
        # candidate is not a RAISE/RETURN/STORE, scan forward for a natural
        # statement boundary; confirm against regression traces before changing.
        last = ctx.instructions[max_idx - 1]
        if not ("RAISE" in last.opname or "RETURN" in last.opname or "STORE" in last.opname):
            old = max_idx
            while max_idx < len(ctx.instructions):
                op = ctx.instructions[max_idx].opname
                if "STORE" in op or "RETURN" in op:
                    max_idx += 1
                    break
                if ("JUMP" in op and max_idx > old) or "FOR_ITER" in op:
                    break
                max_idx += 1

        merge_upper_bounds.append(max_idx)
        end = min(merge_upper_bounds)

    # ── Step 5: else-body end position (PR#91 fix) ──
    else_end = end
    if end == jump_idx and jump_idx < len(ctx.instructions):
        last_if = ctx.instructions[jump_idx - 1]
        if "RETURN" in last_if.opname or "RAISE" in last_if.opname:
            # The if-body never falls through, so the "else" is simply the rest
            # of the code (clamped to the enclosing loop if any).
            else_end = len(ctx.instructions)
            if em.loop is not None:
                else_end = min(else_end, em.loop.end_index)

    # ── Step 6: decompile both branches ──
    with em.fork(stack=fall_stack) as if_em:
        ctx.decompile_range(this_idx + 1, jump_idx, if_em)

    if_body = em.indent(if_em.get_source())
    if_end_stack = list(if_em.stack)
    em.emit_raw(f"if {cond}:\n{if_body}")

    with em.fork(stack=jump_stack) as else_em:
        ctx.decompile_range(jump_idx, else_end, else_em)

    else_body = else_em.get_source()
    if else_body:
        em.emit_raw(f"else:\n{em.indent(else_body)}")

    # Continue with the if-branch's resulting stack (assumes both branches
    # leave equivalent stack shapes at the merge point).
    em.stack[:] = if_end_stack
    return else_end
226
+
227
+
228
+ # ── FOR_ITER ──────────────────────────────────────────────────────────────
229
+
230
+
231
@_reg("FOR_ITER")
def _for_iter(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> Optional[int]:
    """Decompile for loop.

    Bytecode layout (3.12)::

        FOR_ITER target        <- get next value; jump to target when exhausted (END_FOR)
        (loop body)
        JUMP_BACKWARD for_iter <- normal back-jump (not continue)
     >> target: END_FOR

    Loop body range excludes the trailing JUMP_BACKWARD to avoid emitting a spurious continue.
    Returns the index just past the loop so decompilation resumes there.
    """
    start_idx = ctx.index_of(inst.offset)
    end_idx = ctx.index_of(inst.jump_target_offset())

    # The loop variable is always a fresh temp; a tuple target such as
    # ``for a, b in ...`` is presumably reconstructed by a following
    # UNPACK_SEQUENCE inside the body — TODO confirm.
    temp = em.make_temp()
    iterator = em.pop()
    em.push(temp)

    # Determine the actual end position of the loop body:
    # if the instruction at end_idx is a back-jump to FOR_ITER, extend end_idx so
    # the LoopContext boundary is correct (break needs to jump past end_idx)
    if end_idx < len(ctx.instructions):
        at_end = ctx.instructions[end_idx]
        if at_end.is_jump and at_end.jump_target_offset() == inst.offset:
            end_idx += 1

    # Exclude the trailing JUMP_BACKWARD: it is the normal loop back-jump mechanism, not continue.
    # Only JUMP_BACKWARDs in the middle of the loop body are continue (handled by _abs_jump).
    body_end = end_idx
    if body_end > start_idx + 1:
        back_jump = ctx.instructions[body_end - 1]
        if back_jump.is_jump and back_jump.jump_target_offset() == inst.offset:
            body_end -= 1

    loop = LoopContext(start_index=start_idx, end_index=end_idx)
    with em.fork(stack=list(em.stack), loop=loop) as body_em:
        ctx.decompile_range(start_idx + 1, body_end, body_em)

    body_src = em.indent(body_em.get_source())
    em.emit_raw(f"for {temp} in {iterator}:\n{body_src}")
    # Adopt the body's final stack (the pushed temp has been consumed inside).
    em.stack[:] = body_em.stack
    return end_idx
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/load_store.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Handlers for LOAD_*, STORE_*, DELETE_*, IMPORT_*, PUSH_NULL, GET_ITER."""
16
+
17
+ from __future__ import annotations
18
+
19
+ from types import CodeType
20
+
21
+ from ..decompile_context import DecompileContext
22
+ from ..handler_registry import registry
23
+ from ..instruction import Instruction
24
+ from ..source_emitter import SourceEmitter
25
+
26
+ _reg = registry.register
27
+
28
+
29
+ # ── NOP / unsupported sentinels ──────────────────────────────────────────
30
+
31
+
32
@_reg("NOP", "RESUME", "EXTENDED_ARG", "SETUP_LOOP", "POP_BLOCK")
@_reg("PRECALL", "BEGIN_FINALLY", "END_FINALLY", "MAKE_CELL")
@_reg("RERAISE", "END_FOR", "COPY_FREE_VARS")
def _nop(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    # Bookkeeping-only opcodes (frame setup, block management, interpreter
    # hints): they neither touch the value stack nor produce source text.
    pass
37
+
38
+
39
@_reg("GET_YIELD_FROM_ITER")
@_reg("POP_EXCEPT", "WITH_EXCEPT_START", "JUMP_IF_NOT_EXC_MATCH")
@_reg("CHECK_EG_MATCH", "PUSH_EXC_INFO", "PREP_RERAISE_STAR")
@_reg("WITH_CLEANUP_FINISH", "CALL_FINALLY", "POP_FINALLY")
@_reg("WITH_CLEANUP_START", "SETUP_EXCEPT", "CHECK_EXC_MATCH")
@_reg("CLEANUP_THROW")
@_reg("GET_AWAITABLE", "GET_AITER", "GET_ANEXT", "END_ASYNC_FOR")
@_reg("BEFORE_ASYNC_WITH", "SETUP_ASYNC_WITH", "SEND", "ASYNC_GEN_WRAP")
@_reg("CACHE")
@_reg("PRINT_EXPR", "COPY_DICT_WITHOUT_KEYS")
@_reg("IMPORT_STAR")
@_reg("YIELD_FROM", "SETUP_ANNOTATIONS", "LOAD_BUILD_CLASS")
@_reg("MATCH_MAPPING", "MATCH_SEQUENCE", "MATCH_KEYS", "MATCH_CLASS")
@_reg("CALL_INTRINSIC_2")
@_reg("SETUP_FINALLY", "SETUP_WITH", "BEFORE_WITH")
def _unsupported(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Explicitly-unsupported opcodes (exception tables, async, match, with …):
    fail fast with the offending instruction attached rather than producing
    silently wrong source."""
    # Deferred import: decompiler.py imports this handlers package at module
    # load time, so a top-level import here would be circular.
    from ...decompiler import DecompilationError

    raise DecompilationError(f"Unsupported opcode: {inst.opname}", instruction=inst)
58
+
59
+
60
+ # ── LOAD instructions ────────────────────────────────────────────────────
61
+
62
+
63
@_reg("LOAD_CONST")
def _load_const(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Load a constant. Branches: can_repr → direct repr / type → importlib /
    torch prefix → import torch / CodeType → push as-is for MAKE_FUNCTION."""
    # Round-trip test: the constant is representable as source iff
    # eval(repr(x)) == x. ``eval`` here is safe in the usual sense that the
    # input comes from the code object's co_consts, not external data.
    can_repr = False
    try:
        can_repr = eval(repr(inst.argval)) == inst.argval
    except BaseException:
        # repr/eval/__eq__ can raise almost anything for exotic constants.
        pass
    if can_repr:
        em.push(repr(inst.argval))
    elif isinstance(inst.argval, type):
        # Classes: re-import by dotted module path and cache in a temp.
        module = inst.argval.__module__
        name = inst.argval.__name__
        em.emit("import importlib")
        tmp = em.make_temp()
        em.emit(f'{tmp} = importlib.import_module("{module}").{name}')
        em.push(tmp)
    elif inst.argrepr.startswith("torch."):
        # torch objects (dtypes, devices, …) whose repr is a dotted path.
        em.emit("import torch")
        tmp = em.make_temp()
        em.emit(f"{tmp} = {inst.argval}")
        em.push(tmp)
    elif isinstance(inst.argval, CodeType):
        # Nested code objects stay as raw objects on the simulated stack;
        # the MAKE_FUNCTION handler decompiles them recursively.
        em.push(inst.argval)
    else:
        from ...decompiler import DecompilationError

        raise DecompilationError(
            f"LOAD_CONST: cannot represent co_consts[{inst.arg}] = {repr(inst.argval)!r} "
            f"(type {type(inst.argval).__name__}) as source code",
            instruction=inst,
        )
96
+
97
+
98
@_reg("LOAD_FAST", "LOAD_FAST_CHECK")
@_reg("LOAD_GLOBAL", "LOAD_DEREF", "LOAD_NAME")
@_reg("LOAD_CLASSDEREF", "LOAD_CLOSURE")
def _generic_load(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Push a variable name onto the simulated stack.

    On 3.11+ a LOAD_GLOBAL whose argrepr reads "NULL + name" first pushes a
    NULL sentinel (consumed later by the CALL handler). On Python <3.12 the
    comprehension parameter named ".0" is renamed to "comp_arg_0"."""
    if "NULL + " in inst.argrepr:
        em.push(None)
    name = inst.argval
    if inst.argrepr.startswith("."):
        name = name.replace(".", "comp_arg_")
    em.push(name)
110
+
111
+
112
# Python 3.12 comprehension variable protection: LOAD_FAST_AND_CLEAR saves old value + STORE_FAST restores.
# During decompilation, temp variables used for loops don't need save/restore; push a sentinel so STORE_FAST skips.
# Identity-compared (``is``) in _generic_store, so a plain unique object suffices.
_CLEAR_SENTINEL = object()
115
+
116
+
117
@_reg("LOAD_FAST_AND_CLEAR")
def _load_fast_and_clear(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """3.12 comprehension prologue: instead of saving the shadowed variable,
    push a sentinel — the matching STORE_FAST will skip the restore."""
    em.push(_CLEAR_SENTINEL)
120
+
121
+
122
@_reg("LOAD_LOCALS")
def _load_locals(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """3.12 class body: snapshot ``locals()`` into a temp so repeated reads
    observe one consistent dict rather than calling locals() again."""
    em.push("locals()")
    em.replace_tos_with_temp()
127
+
128
+
129
@_reg("LOAD_FROM_DICT_OR_GLOBALS", "LOAD_FROM_DICT_OR_DEREF")
def _load_from_dict(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """3.12 class body: look up *name* in the locals dict on TOS first,
    falling back to the enclosing scope (globals / cell) when absent.

    Fix: the dict subscript must quote the key — ``d['name']`` — to match the
    quoted membership test. The original emitted ``d[name]``, which evaluated
    the looked-up *value* of ``name`` as the key at runtime.
    """
    mapping = em.pop()
    name = inst.argval
    em.push(f"{mapping}[{name!r}] if {name!r} in {mapping} else {name}")
    em.replace_tos_with_temp()
135
+
136
+
137
@_reg("LOAD_ATTR")
def _load_attr(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Attribute access: emit ``obj.name`` when the attribute name is a valid
    identifier, otherwise fall back to ``getattr(obj, name)``."""
    obj = str(em.pop())
    attr = inst.argval
    if attr.isidentifier():
        em.push(f"{obj}.{attr}")
    else:
        em.push(f"getattr({obj}, {attr!r})")
143
+
144
+
145
@_reg("LOAD_SUPER_ATTR")
def _load_super_attr(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """3.12 ``super().attr``: the stack holds (top first) self, class, super;
    rebuild the explicit two-argument ``super(cls, self)`` call and cache the
    attribute access in a temp."""
    self_expr = em.pop()
    cls_expr = em.pop()
    super_expr = em.pop()
    em.push(f"{super_expr}({cls_expr}, {self_expr}).{inst.argval}")
    em.replace_tos_with_temp()
152
+
153
+
154
@_reg("LOAD_METHOD")
def _load_method(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Method lookup: replace TOS object with the bound-access expression."""
    receiver = em.pop()
    em.push(f"{receiver}.{inst.argval}")
157
+
158
+
159
@_reg("LOAD_ASSERTION_ERROR")
def _load_assertion_error(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """``assert`` lowering pushes AssertionError; mirror it by name."""
    em.push("AssertionError")
162
+
163
+
164
@_reg("PUSH_NULL")
def _push_null(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """3.11+ pushes a NULL sentinel before function calls; modelled as ``None``
    on the simulated stack and cleared by the CALL handler."""
    em.push(None)
168
+
169
+
170
@_reg("GET_ITER")
def _get_iter(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Wrap TOS in an explicit ``iter(...)`` call."""
    inner = em.pop()
    em.push(f"iter({inner})")
173
+
174
+
175
+ # ── STORE instructions ───────────────────────────────────────────────────
176
+
177
+
178
@_reg("STORE_FAST", "STORE_GLOBAL", "STORE_DEREF", "STORE_NAME")
def _generic_store(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Generic store. Skips _CLEAR_SENTINEL and self-assignment, protects variable names on the stack that are about to be overwritten."""
    left = inst.argval
    right = em.pop()
    if right is _CLEAR_SENTINEL:
        # Value came from LOAD_FAST_AND_CLEAR: this store would only restore a
        # pre-comprehension value that was never saved — emit nothing.
        return
    if left != right:
        # left == right means a no-op self-assignment (x = x): skip entirely.
        if isinstance(left, str) and left in em.stack:
            # Another stack entry still refers to ``left``'s current value;
            # snapshot it into a temp and rewrite those references before the
            # overwrite so later uses see the pre-store value.
            tmp = em.make_temp()
            em.emit(f"{tmp} = {left}")
            em.stack[:] = [tmp if x == left else x for x in em.stack]
        em.emit(f"{left} = {right}")
191
+
192
+
193
@_reg("STORE_SUBSCR")
def _store_subscr(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """``obj[key] = value`` — stack (top first): key, obj, value."""
    key = em.pop()
    container = em.pop()
    value = em.pop()
    em.emit(f"{container}[{key}] = {value}")
199
+
200
+
201
@_reg("STORE_SLICE")
def _store_slice(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """3.12 ``obj[start:end] = value`` — stack (top first): end, start, obj, value."""
    hi = em.pop()
    lo = em.pop()
    container = em.pop()
    value = em.pop()
    em.emit(f"{container}[{lo}:{hi}] = {value}")
208
+
209
+
210
@_reg("STORE_ATTR")
def _store_attr(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """``obj.attr = value`` — stack (top first): obj, value."""
    target = em.pop()
    value = em.pop()
    em.emit(f"{target}.{inst.argval} = {value}")
215
+
216
+
217
+ # ── DELETE instructions ──────────────────────────────────────────────────
218
+
219
+
220
@_reg("DELETE_SUBSCR")
def _delete_subscr(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """``del obj[key]`` — skipped when the identical subscript expression is
    still referenced on the stack (deleting would invalidate that use)."""
    key = em.pop()
    container = em.pop()
    expr = f"{container}[{key}]"
    if expr not in em.stack:
        em.emit(f"del {expr}")
226
+
227
+
228
@_reg("DELETE_NAME", "DELETE_GLOBAL", "DELETE_DEREF")
def _generic_delete(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Emit a plain ``del name`` for name/global/cell deletions."""
    em.emit(f"del {inst.argval}")
231
+
232
+
233
@_reg("DELETE_FAST")
def _delete_fast(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Dynamo cleans up temp variables; no explicit del needed after decompilation."""
    # Intentionally a no-op (unlike DELETE_NAME/GLOBAL/DEREF above).
    pass
237
+
238
+
239
@_reg("DELETE_ATTR")
def _delete_attr(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """``del obj.attr`` — the object expression is TOS."""
    target = em.pop()
    em.emit(f"del {target}.{inst.argval}")
242
+
243
+
244
+ # ── IMPORT instructions ──────────────────────────────────────────────────
245
+
246
+
247
@_reg("IMPORT_NAME")
def _import_name(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """``import os.path`` binds the *top-level* module ('os'); submodules are
    then reached via attribute access (os.path.sep). Stack (top first):
    fromlist, level."""
    top_level = inst.argval.split(".")[0]
    fromlist = em.pop()
    level = em.pop()
    em.emit(f"{top_level} = __import__({inst.argval!r}, fromlist={fromlist}, level={level})")
    em.push(top_level)
255
+
256
+
257
@_reg("IMPORT_FROM")
def _import_from(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """``from mod import name``: bind *name* from the module left on TOS —
    the module stays on the stack for further IMPORT_FROMs."""
    attr = inst.argval
    em.emit(f"{attr} = {em.peek()}.{attr}")
    em.push(attr)
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/handlers/stack_ops.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Handlers for stack-manipulation opcodes: ROT, SWAP, COPY, POP, DUP.
16
+
17
+ See bytecode_explained.py §16 for details.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ from ..decompile_context import DecompileContext
23
+ from ..handler_registry import registry
24
+ from ..instruction import Instruction
25
+ from ..source_emitter import SourceEmitter
26
+
27
+ _reg = registry.register
28
+
29
+
30
+ # ── ROT_N family (Python ≤3.10, replaced by SWAP/COPY in 3.11+) ───────────
31
+
32
+
33
@_reg("ROT_N")
@_reg("ROT_TWO")
@_reg("ROT_THREE")
@_reg("ROT_FOUR")
def _rot_n(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Top-n stack rotation: [a, b, c] → [c, a, b] (n=3).

    ROT_TWO:   a, b → b, a          (swap, used for a, b = b, a)
    ROT_THREE: a, b, c → c, a, b    (3-element rotation)
    ROT_FOUR:  a, b, c, d → d, a, b, c
    ROT_N:     generic rotation, n taken from argval
    """
    fixed = {"ROT_TWO": 2, "ROT_THREE": 3, "ROT_FOUR": 4}
    n = fixed.get(inst.opname, inst.argval)
    # Pop TOS and re-insert it n-1 slots deeper: one rotation step.
    top = em.stack.pop()
    em.stack.insert(len(em.stack) - (n - 1), top)
48
+
49
+
50
@_reg("SWAP")
def _swap(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Python 3.11+: exchange stack[-1] and stack[-n]."""
    n = inst.argval
    top, deep = em.stack[-1], em.stack[-n]
    em.stack[-1] = deep
    em.stack[-n] = top
55
+
56
+
57
@_reg("COPY")
def _copy(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Python 3.11+: push a duplicate of stack[-n] (COPY 1 == DUP_TOP)."""
    n = inst.argval
    if not n:
        return
    em.push(em.peek(n - 1))
64
+
65
+
66
@_reg("POP_TOP")
def _pop_top(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Discard TOS; tolerate an already-empty simulated stack."""
    if em.stack_size:
        em.pop()
70
+
71
+
72
@_reg("DUP_TOP")
def _dup_top(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Python ≤3.10: duplicate TOS (superseded by COPY 1 in 3.11+)."""
    em.push(em.peek(0))
76
+
77
+
78
@_reg("DUP_TOP_TWO")
def _dup_top_two(em: SourceEmitter, inst: Instruction, ctx: DecompileContext) -> None:
    """Python ≤3.10: duplicate the top two items in order (a, b → a, b, a, b)."""
    second, first = em.peek(1), em.peek(0)
    em.push(second)
    em.push(first)
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/instruction.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Enhanced Instruction dataclass with rich querying properties."""
16
+
17
+ from __future__ import annotations
18
+
19
+ import dataclasses
20
+ import dis
21
+ import sys
22
+ from typing import Any, Optional
23
+
24
# Every opcode that can transfer control: absolute plus relative jumps.
_ALL_JUMP_OPCODES = frozenset(dis.hasjabs) | frozenset(dis.hasjrel)
# On 3.11+ the dis module already reports relative-jump targets as absolute
# offsets in argval; jump_target_offset() branches on this.
_PY311 = sys.version_info >= (3, 11)

# Name-derived opcode categories backing the is_load/is_store/is_delete queries.
# PUSH_NULL and GET_ITER are counted as loads because they also push one value.
_LOAD_OPCODES = frozenset(n for n in dis.opname if n.startswith("LOAD_") or n in ("PUSH_NULL", "GET_ITER"))
_STORE_OPCODES = frozenset(n for n in dis.opname if n.startswith("STORE_"))
_DELETE_OPCODES = frozenset(n for n in dis.opname if n.startswith("DELETE_"))
30
+
31
+
32
@dataclasses.dataclass
class Instruction:
    """Mutable mirror of ``dis.Instruction`` with convenience queries.

    Unlike the stdlib version this is mutable so cleanup passes can
    modify instructions in-place (e.g. NOP-ing unreachable bytecode).

    Note: the user-defined ``__eq__``/``__hash__`` below take precedence over
    the dataclass-generated ones (the dataclass machinery skips methods the
    class already defines), giving identity semantics despite mutability.
    """

    opcode: int
    opname: str
    # arg: raw integer argument (the number in the bytecode), may be an index into co_consts/co_varnames
    # argval: Python object resolved by the dis module (value of co_consts[arg], or a variable name string)
    # argrepr: human-readable string of argval (e.g. "NULL + print", "to 20")
    # See bytecode_explained.py §1 for details
    arg: Optional[int]
    argval: Any
    argrepr: str
    offset: Optional[int] = None
    starts_line: Optional[int] = None
    is_jump_target: bool = False

    # -- identity / hashing (by object id, not value) ----------------------

    def __hash__(self) -> int:
        return id(self)

    def __eq__(self, other: object) -> bool:
        return self is other

    def __repr__(self) -> str:
        return f"Instruction({self.opname}, offset={self.offset}, argval={self.argrepr!r})"

    # -- category queries ---------------------------------------------------

    @property
    def is_load(self) -> bool:
        """True for LOAD_* opcodes plus PUSH_NULL / GET_ITER."""
        return self.opname in _LOAD_OPCODES

    @property
    def is_store(self) -> bool:
        return self.opname in _STORE_OPCODES

    @property
    def is_delete(self) -> bool:
        return self.opname in _DELETE_OPCODES

    @property
    def is_jump(self) -> bool:
        return self.opcode in _ALL_JUMP_OPCODES

    @property
    def is_conditional_jump(self) -> bool:
        # Name-based heuristic: covers *_IF_* conditional jumps and FOR_ITER
        # (which falls through on the next value, jumps when exhausted).
        return self.is_jump and ("IF" in self.opname or "FOR_ITER" in self.opname)

    @property
    def is_unconditional_jump(self) -> bool:
        return self.is_jump and not self.is_conditional_jump

    @property
    def is_return(self) -> bool:
        return self.opname in ("RETURN_VALUE", "RETURN_CONST")

    @property
    def is_nop(self) -> bool:
        return self.opname == "NOP"

    # -- jump target --------------------------------------------------------

    def jump_target_offset(self) -> Optional[int]:
        """Return the absolute bytecode offset this instruction jumps to,
        or ``None`` if it is not a jump instruction."""
        if not self.is_jump:
            return None
        # Preferred path: dis renders jump targets as "to N" in argrepr on
        # the versions this decompiler supports.
        if "to " in self.argrepr:
            return int(self.argrepr.replace("to ", "").strip())
        if self.opcode in dis.hasjabs:
            return self.argval
        if self.opcode in dis.hasjrel:
            # Pre-3.11 dis keeps relative argval as a delta from this offset.
            return self.argval if _PY311 else self.offset + self.argval
        return None

    # -- mutation helpers (for cleanup passes) ------------------------------

    def nop_(self) -> None:
        """In-place convert this instruction to a NOP."""
        self.opname = "NOP"
        self.opcode = dis.opmap["NOP"]
        self.arg = 0
        self.argval = 0
        self.argrepr = ""
        self.is_jump_target = False

    # -- factory ------------------------------------------------------------

    @staticmethod
    def from_dis(i: dis.Instruction) -> "Instruction":
        """Create from a stdlib ``dis.Instruction``.

        NOTE(review): CPython 3.13 changed ``dis.Instruction.starts_line``
        semantics (boolean, with line numbers moved to ``line_number``) —
        confirm before supporting 3.13.
        """
        return Instruction(i.opcode, i.opname, i.arg, i.argval, i.argrepr, i.offset, i.starts_line, i.is_jump_target)
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/bytecode/source_emitter.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """SourceEmitter: manages the evaluation stack and source-code emission.
16
+
17
+ This replaces the bare ``DecompilerState`` (just ``source_code: str``
18
+ and ``stack: list``) with a proper class that owns *all* mutable state
19
+ touched during decompilation, including the temp-variable counter
20
+ (instance-level, not class-level — thread-safe by design).
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import contextlib
26
+ import dataclasses
27
+ from typing import Any, Iterator, List, Optional
28
+
29
+
30
@dataclasses.dataclass
class LoopContext:
    """Current loop boundaries, used for break/continue determination.

    Range semantics similar to range(start, end):
        start_index: index of FOR_ITER itself (inclusive — part of the loop)
        end_index:   index of the first instruction outside the loop
                     (exclusive — not part of the loop)

    break/continue determination (used in _abs_jump):
        jump target >= end_index   → break (jump out of the loop)
        jump target == start_index → continue (jump back to loop head)
    """

    start_index: int  # index of FOR_ITER, inclusive (part of the loop)
    end_index: int  # first instruction outside the loop, exclusive (not part of the loop)
45
+
46
+
47
class SourceEmitter:
    """Stateful accumulator for the decompiler's output.

    Improvements over depyf's ``DecompilerState``:
      * ``_temp_counter`` is **instance-level** (no thread-safety issues).
      * Stack operations (``push / pop / peek``) are proper methods.
      * ``emit()`` appends with a trailing newline automatically.
      * ``fork()`` context-manager creates a child emitter for sub-blocks
        (if-else branches, loop bodies, etc.) and returns it so the caller
        can inspect the generated source and final stack.
    """

    def __init__(self, indent_size: int = 4, temp_prefix: str = "__temp_", *, _parent_counter: Optional[list] = None) -> None:
        self._lines: List[str] = []
        self._stack: List[Any] = []
        self._indent_size = indent_size
        self._temp_prefix = temp_prefix
        # One shared single-element list per Decompiler invocation: forks
        # alias it, so temp names stay globally unique yet instance-scoped.
        self._counter: list = [0] if _parent_counter is None else _parent_counter
        self.loop: Optional[LoopContext] = None

    # -- source emission ----------------------------------------------------

    def emit(self, line: str) -> None:
        """Append *line* to the accumulated source, newline-terminated."""
        self._lines.append(f"{line}\n")

    def emit_raw(self, text: str) -> None:
        """Append pre-formatted *text* verbatim (e.g. nested function defs)."""
        self._lines.append(text)

    def get_source(self) -> str:
        """Concatenate everything emitted so far."""
        return "".join(self._lines)

    # -- stack operations ---------------------------------------------------

    def push(self, value: Any) -> None:
        self._stack.append(value)

    def pop(self) -> Any:
        return self._stack.pop()

    def peek(self, depth: int = 0) -> Any:
        """Return item at ``stack[-(depth+1)]`` without popping."""
        return self._stack[-(depth + 1)]

    def set_at(self, depth: int, value: Any) -> None:
        """Overwrite ``stack[-(depth+1)]`` with *value*."""
        self._stack[-(depth + 1)] = value

    @property
    def stack(self) -> List[Any]:
        """Direct list access (for complex multi-item operations)."""
        return self._stack

    @property
    def stack_size(self) -> int:
        return len(self._stack)

    # -- temp variables (instance-scoped counter) ---------------------------

    def make_temp(self) -> str:
        """Return a unique temporary variable name."""
        counter = self._counter
        counter[0] += 1
        return self._temp_prefix + str(counter[0])

    def replace_tos_with_temp(self, depth: int = 1) -> str:
        """Replace ``stack[-depth]`` with a fresh temp, emitting the
        assignment ``__temp_N = <old_value>``. Returns the temp name."""
        fresh = self.make_temp()
        self.emit(f"{fresh} = {self._stack[-depth]}")
        self._stack[-depth] = fresh
        return fresh

    # -- sub-block forking --------------------------------------------------

    @contextlib.contextmanager
    def fork(self, stack: Optional[List[Any]] = None, loop: Optional[LoopContext] = None) -> Iterator["SourceEmitter"]:
        """Create a child emitter for a sub-block (if-branch, loop body …).

        The child shares the temp counter but owns fresh ``_lines`` and
        ``_stack``. When *loop* is ``None`` the parent's loop context is
        inherited (matching depyf's ``new_state`` semantics).

        Usage::

            with emitter.fork(stack=my_stack) as child:
                decompile_range(start, end, child)
            child_source = child.get_source()
            child_final_stack = child.stack
        """
        child = SourceEmitter(indent_size=self._indent_size, temp_prefix=self._temp_prefix, _parent_counter=self._counter)
        child._stack = list(self._stack if stack is None else stack)
        child.loop = loop if loop is not None else self.loop
        yield child

    # -- indentation helpers ------------------------------------------------

    def indent(self, text: str) -> str:
        """Return *text* with one indent level added to every line."""
        pad = " " * self._indent_size
        return "".join(f"{pad}{ln}\n" for ln in text.splitlines())
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/decompiler.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Decompiler — the orchestrator that ties everything together.
16
+
17
+ This module is the only place that coordinates ``SourceEmitter``,
18
+ ``HandlerRegistry``, and ``DecompileContext``.
19
+ Individual handler functions never import from here (except for
20
+ ``DecompilationError`` and recursive ``Decompiler`` usage in
21
+ ``MAKE_FUNCTION``).
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import dis
27
+ import inspect
28
+ import os
29
+ from types import CodeType
30
+ from typing import Callable, List, Optional, Union
31
+
32
+ # Force handler registration by importing the package.
33
+ import magi_compiler.magi_depyf.decompile.bytecode.handlers # noqa: F401
34
+
35
+ from .bytecode.decompile_context import DecompileContext
36
+ from .bytecode.handler_registry import registry
37
+ from .bytecode.instruction import Instruction
38
+ from .bytecode.source_emitter import SourceEmitter
39
+
40
+ # ---------------------------------------------------------------------------
41
+ # Errors
42
+ # ---------------------------------------------------------------------------
43
+
44
+
45
class DecompilationError(Exception):
    """Raised when decompilation fails.

    Optionally carries the offending ``instruction`` so callers can build
    actionable error messages.
    """

    def __init__(self, message: str = "", *, instruction: Optional[Instruction] = None):
        super().__init__(message)
        self.message = message
        self.instruction = instruction

    def __str__(self) -> str:
        suffix = f" at {self.instruction}" if self.instruction is not None else ""
        return f"DecompilationError: {self.message}{suffix}"
62
+
63
+
64
+ # ---------------------------------------------------------------------------
65
+ # Signature builder (lives here — it's a Decompiler concern, not a util)
66
+ # ---------------------------------------------------------------------------
67
+
68
+
69
class SignatureBuilder:
    """Build the ``def fn(args):`` header from a ``CodeType``."""

    @staticmethod
    def build(code: CodeType, overwrite_name: Optional[str] = None) -> str:
        """Return ``"def name(params):\\n"`` for *code*.

        Reconstructs the parameter list in true Python order: positional
        args, then ``*args`` (or a bare ``*`` when keyword-only args exist
        without varargs), then keyword-only args, then ``**kwargs``.

        Fix: ``co_varnames`` stores keyword-only names *before* the varargs
        name, and the previous implementation emitted them in that raw order
        (``def f(a, kwonly, *args)``), which binds call arguments differently
        from the original function; it also dropped the ``*`` separator, so
        keyword-only parameters silently became positional.
        """

        def _clean(name: str) -> str:
            # Pre-3.12 comprehension parameter ".0" is not a valid identifier.
            return name.replace(".", "comp_arg_") if name.startswith(".") else name

        pos_count = code.co_argcount
        kw_count = code.co_kwonlyargcount
        params = [_clean(v) for v in code.co_varnames[:pos_count]]
        kwonly = [_clean(v) for v in code.co_varnames[pos_count : pos_count + kw_count]]
        idx = pos_count + kw_count  # varargs/varkw names follow the kw-only block
        if code.co_flags & inspect.CO_VARARGS:
            params.append("*" + _clean(code.co_varnames[idx]))
            idx += 1
        elif kwonly:
            params.append("*")  # bare star keeps kw-only parameters kw-only
        params.extend(kwonly)
        if code.co_flags & inspect.CO_VARKEYWORDS:
            params.append("**" + _clean(code.co_varnames[idx]))
        fn_name = overwrite_name or code.co_name
        return f"def {fn_name}({', '.join(params)}):\n"
84
+
85
+
86
+ # ---------------------------------------------------------------------------
87
+ # Decompiler
88
+ # ---------------------------------------------------------------------------
89
+
90
+
91
class Decompiler:
    """Decompile a ``CodeType`` into Python source code.

    Design differences from depyf's ``Decompiler``:

    * Handlers live in separate modules and receive
      ``(emitter, inst, ctx)`` — they never reference this class.
    * All mutable state is on ``SourceEmitter`` (instance-scoped counter).
    * ``decompile_range`` is delegated *through* ``DecompileContext``
      so handlers can recurse without importing this class (except
      ``MAKE_FUNCTION`` which needs a fresh ``Decompiler`` instance).
    """

    # Opcodes after which straight-line code is unreachable.
    _TERMINATORS = frozenset({"RETURN_VALUE", "RETURN_CONST", "RAISE_VARARGS"})

    def __init__(self, code: Union[CodeType, Callable]) -> None:
        """Accept a raw code object or any callable (unwrapped to its code)."""
        if callable(code) and not isinstance(code, CodeType):
            code = _get_code_owner(code).__code__
        self.code: CodeType = code
        self.instructions = [Instruction.from_dis(i) for i in dis.get_instructions(code)]
        self._cleanup()

    # -- bytecode cleanup ---------------------------------------------------

    def _cleanup(self) -> None:
        """Propagate line numbers and NOP dead code after unconditional exits."""
        # Pass 1: dis only marks the first instruction of each source line;
        # fill starts_line for every instruction.
        cur: Optional[int] = None
        for inst in self.instructions:
            if inst.starts_line is not None:
                cur = inst.starts_line
            inst.starts_line = cur

        # Pass 2: after a return/raise, everything up to the next jump target
        # is unreachable — convert it to NOPs in place.
        in_dead = False
        for inst in self.instructions:
            if in_dead:
                if inst.is_jump_target:
                    in_dead = False
                else:
                    inst.nop_()
            elif inst.opname in self._TERMINATORS:
                in_dead = True

    # -- core loop ----------------------------------------------------------

    def decompile_range(self, start: int, end: int, emitter: SourceEmitter) -> None:
        """Execute instruction handlers from *start* to *end* (exclusive).

        Raises:
            DecompilationError: when an opcode has no handler or a handler fails.
        """
        # Build the context once per range, not once per instruction: the
        # original re-ran _make_context (which constructs an O(n)
        # offset_to_index dict) for every handler call, making decompilation
        # quadratic in the number of instructions.
        ctx = self._make_context(emitter)
        inst: Optional[Instruction] = None  # stays bound for the except clause
        idx = start
        try:
            while idx < end:
                inst = self.instructions[idx]
                handler = registry.get(inst.opname)
                if handler is None:
                    raise DecompilationError(f"No handler for opcode {inst.opname}", instruction=inst)
                result = handler(emitter, inst, ctx)
                # A handler may return the next index (e.g. after recursively
                # decompiling a loop body); None means "advance by one".
                idx = result if result is not None else idx + 1
        except DecompilationError:
            raise
        except Exception as e:
            raise DecompilationError(f"Failed at {inst!r} in {self.code.co_name}", instruction=inst) from e

    def _make_context(self, emitter: SourceEmitter) -> DecompileContext:
        """Bundle the read-only state handlers need, plus a recursion hook."""
        return DecompileContext(
            code=self.code,
            instructions=tuple(self.instructions),
            indentation=emitter._indent_size,
            decompile_range=lambda start, end, em: self.decompile_range(start, end, em),
            offset_to_index={inst.offset: idx for idx, inst in enumerate(self.instructions)},
        )

    # -- public API ---------------------------------------------------------

    def decompile(self, indentation: int = 4, temp_prefix: str = "__temp_", overwrite_fn_name: Optional[str] = None) -> str:
        """Return decompiled Python source code.

        Args:
            indentation: spaces per indent level in the generated source.
            temp_prefix: prefix for generated temporary variable names.
            overwrite_fn_name: replacement for ``co_name`` in the header.

        Raises:
            DecompilationError: on any failure (original cause chained).
        """
        try:
            emitter = SourceEmitter(indent_size=indentation, temp_prefix=temp_prefix)
            self.decompile_range(0, len(self.instructions), emitter)
            body = emitter.get_source()

            # Optional temp-variable elimination passes (enabled by default).
            if os.environ.get("DEPYF_REMOVE_TEMP", "1") == "1":
                from .postprocess import run_all as _postprocess

                body = _postprocess(body, temp_prefix, indentation)

            header = SignatureBuilder.build(self.code, overwrite_fn_name)

            # Re-declare global/nonlocal names so the regenerated function
            # binds them the same way the original bytecode did.
            global_names = {i.argval for i in dis.get_instructions(self.code) if i.opname == "STORE_GLOBAL"}
            preamble = ""
            if global_names:
                preamble += "global " + ", ".join(global_names) + "\n"
            if self.code.co_freevars:
                preamble += "nonlocal " + ", ".join(self.code.co_freevars) + "\n"

            body = preamble + body
            return header + emitter.indent(body)
        except DecompilationError:
            raise
        except Exception as e:
            raise DecompilationError(f"Failed to decompile {self.code.co_name}") from e

    @staticmethod
    def supported_opnames() -> List[str]:
        """All opcode names that have a registered handler."""
        return registry.supported_opnames()
194
+
195
+
196
+ # ---------------------------------------------------------------------------
197
+ # Module-level convenience
198
+ # ---------------------------------------------------------------------------
199
+
200
+
201
def decompile(code: Union[CodeType, Callable]) -> str:
    """One-liner: decompile a code object or callable to source."""
    decompiler = Decompiler(code)
    return decompiler.decompile()
204
+
205
+
206
def safe_decompile(code: CodeType) -> str:
    """Decompile *code* without raising: try our decompiler, fall back to
    depyf, and as a last resort return a placeholder comment."""
    try:
        return Decompiler(code).decompile()
    except Exception:
        pass
    try:
        from depyf import decompile as _depyf_decompile

        return _depyf_decompile(code)
    except Exception:
        return f"# Failed to decompile {code.co_name}\n"
217
+
218
+
219
+ # ---------------------------------------------------------------------------
220
+ # Private helpers
221
+ # ---------------------------------------------------------------------------
222
+
223
+
224
+ def _get_code_owner(fn):
225
+ """Walk through wrappers to find the object that owns ``__code__``."""
226
+ if hasattr(fn, "__func__"):
227
+ return fn.__func__
228
+ if hasattr(fn, "__wrapped__"):
229
+ return _get_code_owner(fn.__wrapped__)
230
+ return fn
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/postprocess/__init__.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Source-level post-processing pipeline for decompiled code.
16
+
17
+ Each pass is a function ``(source, ...) -> source`` that performs one
18
+ semantics-preserving transformation. ``run_all`` applies them in order.
19
+ All passes are best-effort: on any exception they return the input unchanged.
20
+ """
21
+
22
+ from .branch_dedup import dedup_branch_tails
23
+ from .for_temps import eliminate_for_temps
24
+ from .inline_temps import eliminate_inline_temps
25
+
26
+
27
def run_all(source: str, temp_prefix: str = "__temp_", indent: int = 4) -> str:
    """Run every post-processing pass over *source*, in pipeline order.

    Order matters: for-loop temps are rewritten first, generic single-use
    temps are inlined next, and branch-tail dedup runs on the cleaned-up
    source last.
    """
    result = eliminate_for_temps(source, temp_prefix, indent)
    result = eliminate_inline_temps(result, temp_prefix, indent)
    return dedup_branch_tails(result, indent)
33
+
34
+
35
+ __all__ = ["run_all", "eliminate_for_temps", "eliminate_inline_temps", "dedup_branch_tails"]
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/postprocess/branch_dedup.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Pass 3: if/else branch tail deduplication.
16
+
17
+ Move identical trailing statements from if/else branches to after the block.
18
+
19
+ Example::
20
+
21
+ if cond: if cond:
22
+ x = 1 x = 1
23
+ return x → else:
24
+ else: x = 2
25
+ x = 2 return x
26
+ return x
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ import ast
32
+ from typing import List, Tuple
33
+
34
+ import astor
35
+
36
+
37
def dedup_branch_tails(source: str, indent: int = 4) -> str:
    """Hoist identical trailing statements out of if/else branches.

    Best-effort: on any parse or transform failure, *source* is returned
    unchanged. When nothing was deduplicated the original text is returned
    verbatim (no re-rendering).
    """
    try:
        module = ast.parse(source)
        any_change = False
        for candidate in ast.walk(module):
            stmt_list = getattr(candidate, "body", None)
            if not isinstance(stmt_list, list):
                continue
            deduped, did_change = _dedup_stmts(stmt_list)
            if did_change:
                candidate.body = deduped
                any_change = True
        if any_change:
            ast.fix_missing_locations(module)
            return astor.to_source(module, indent_with=" " * indent)
        return source
    except Exception:
        return source
54
+
55
+
56
+ def _dedup_stmts(stmts: List[ast.stmt]) -> Tuple[List[ast.stmt], bool]:
57
+ """Process a statement list, extracting common if/else tails."""
58
+ result: List[ast.stmt] = []
59
+ changed = False
60
+
61
+ for stmt in stmts:
62
+ for attr in ("body", "orelse", "handlers", "finalbody"):
63
+ sub = getattr(stmt, attr, None)
64
+ if isinstance(sub, list) and sub:
65
+ new_sub, c = _dedup_stmts(sub)
66
+ if c:
67
+ setattr(stmt, attr, new_sub)
68
+ changed = True
69
+
70
+ if isinstance(stmt, ast.If) and stmt.orelse:
71
+ n = _common_tail_length(stmt.body, stmt.orelse)
72
+ if n > 0:
73
+ common = stmt.body[-n:]
74
+ stmt.body = stmt.body[:-n] or [ast.Pass()]
75
+ stmt.orelse = stmt.orelse[:-n] or []
76
+ result.append(stmt)
77
+ result.extend(common)
78
+ changed = True
79
+ continue
80
+
81
+ result.append(stmt)
82
+
83
+ return result, changed
84
+
85
+
86
+ def _common_tail_length(body: List[ast.stmt], orelse: List[ast.stmt]) -> int:
87
+ """Count identical trailing statements (by AST dump equality)."""
88
+ count = 0
89
+ i, j = len(body) - 1, len(orelse) - 1
90
+ while i >= 0 and j >= 0:
91
+ if ast.dump(body[i]) == ast.dump(orelse[j]):
92
+ count += 1
93
+ i -= 1
94
+ j -= 1
95
+ else:
96
+ break
97
+ if count >= len(body) or count >= len(orelse):
98
+ count = min(len(body), len(orelse)) - 1
99
+ return max(count, 0)
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/postprocess/for_temps.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Pass 1: for-loop temp elimination.
16
+
17
+ ``for __temp in iter: var = __temp; ...`` → ``for var in iter: ...``
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import ast
23
+
24
+ import astor
25
+
26
+
27
def eliminate_for_temps(source: str, temp_prefix: str = "__temp_", indent: int = 4) -> str:
    """Rewrite ``for <temp> in it: var = <temp>; ...`` as ``for var in it: ...``.

    The rewrite only fires when the loop body's first statement is a plain
    assignment from the temp to a real variable. Best-effort: any failure
    returns *source* unchanged.
    """
    try:
        module = ast.parse(source)
        transformed = _ForTempEliminator(temp_prefix).visit(module)
        ast.fix_missing_locations(transformed)
        return astor.to_source(transformed, indent_with=" " * indent)
    except Exception:
        return source
37
+
38
+
39
+ class _ForTempEliminator(ast.NodeTransformer):
40
+ def __init__(self, prefix: str):
41
+ self._prefix = prefix
42
+
43
+ def visit_For(self, node: ast.For) -> ast.For:
44
+ self.generic_visit(node)
45
+ if not (
46
+ isinstance(node.target, ast.Name)
47
+ and node.target.id.startswith(self._prefix)
48
+ and node.body
49
+ and isinstance(node.body[0], ast.Assign)
50
+ and len(node.body[0].targets) == 1
51
+ and isinstance(node.body[0].value, ast.Name)
52
+ and node.body[0].value.id == node.target.id
53
+ ):
54
+ return node
55
+ node.target = node.body[0].targets[0]
56
+ node.body = node.body[1:] or [ast.Pass()]
57
+ return node
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/postprocess/inline_temps.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Pass 2: single-use temp inlining.
16
+
17
+ ``__temp = expr; use(__temp)`` → ``use(expr)`` for single-use temps.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import ast
23
+ from collections import defaultdict
24
+ from typing import List, Optional
25
+
26
+ import astor
27
+
28
+
29
def eliminate_inline_temps(source: str, temp_prefix: str = "__temp_", indent: int = 4) -> str:
    """Inline single-use temporaries into their use site.

    A temp that occurs exactly twice (one definition, one use) is inlined
    by dropping the ``__temp = expr`` assignment and substituting ``expr``
    at the use site. Best-effort: any failure returns *source* unchanged.
    """
    try:
        tree = ast.parse(source)
        # Annotate .parent links so node positions can be reasoned about.
        _set_parents(tree)

        # Collect every Name node whose id carries the temp prefix,
        # in source (walk) order.
        occurrences: dict[str, list] = defaultdict(list)
        for node in ast.walk(tree):
            if isinstance(node, ast.Name) and node.id.startswith(temp_prefix):
                occurrences[node.id].append(node)

        # Node types that open an indented block; inlining across such a
        # boundary could move the expression into a different control scope.
        _INDENT_NODES = (
            ast.FunctionDef,
            ast.AsyncFunctionDef,
            ast.For,
            ast.AsyncFor,
            ast.While,
            ast.If,
            ast.Try,
            ast.With,
            ast.AsyncWith,
            ast.ClassDef,
        )

        for name in occurrences:
            occ = occurrences[name]
            if len(occ) == 2:
                # Exactly one definition + one use: decide if inlining is safe.
                n1, n2 = occ
                _, p1, p2 = _lowest_common_parent(n1, n2)
                # `ap` is the divergence ancestor on the side of the
                # ``__temp = expr`` assignment (occurrences arrive in walk
                # order, so n1 is normally the definition — TODO confirm).
                ap = p1 if isinstance(getattr(n1, "parent", None), ast.Assign) else p2
                can = not isinstance(ap, _INDENT_NODES)
                if can:
                    can = _safe_to_inline(tree, n1, n2)
                # Protocol: append the verdict so the visitors below see
                # [use1, use2, can]; _RemoveAssign later appends the RHS.
                occ.append(can)
            tree = _RemoveAssign(name, occurrences).visit(tree)
            tree = _InlineTemp(name, occurrences).visit(tree)

        return astor.to_source(tree, indent_with=" " * indent)
    except Exception:
        return source
69
+
70
+
71
+ # ---------------------------------------------------------------------------
72
+ # AST helpers
73
+ # ---------------------------------------------------------------------------
74
+
75
+
76
+ def _set_parents(node: ast.AST, parent: Optional[ast.AST] = None) -> None:
77
+ for child in ast.iter_child_nodes(node):
78
+ child.parent = parent # type: ignore[attr-defined]
79
+ _set_parents(child, child)
80
+
81
+
82
+ def _get_parents(node: ast.AST) -> List[ast.AST]:
83
+ out = []
84
+ while node:
85
+ out.append(node)
86
+ node = getattr(node, "parent", None)
87
+ return out
88
+
89
+
90
+ def _lowest_common_parent(n1: ast.AST, n2: ast.AST):
91
+ p1 = _get_parents(n1)
92
+ p2 = _get_parents(n2)
93
+ p1.reverse()
94
+ p2.reverse()
95
+ last = c1 = c2 = None
96
+ for a, b in zip(p1, p2):
97
+ if a is b:
98
+ last = a
99
+ else:
100
+ c1, c2 = a, b
101
+ break
102
+ return last, c1, c2
103
+
104
+
105
def _safe_to_inline(tree: ast.AST, def_node: ast.AST, use_node: ast.AST) -> bool:
    """Verify the RHS variable is not reassigned between definition and use.

    *def_node* is the temp's Name at its ``__temp = rhs`` definition and
    *use_node* the Name at its use site. Only the simple ``__temp = <Name>``
    shape is analysed; every other shape is treated as safe (returns True).
    """
    # The definition must sit directly under an Assign whose value is a
    # plain Name — otherwise there is nothing a reassignment can invalidate.
    assign_parent = getattr(def_node, "parent", None)
    if not isinstance(assign_parent, ast.Assign):
        return True
    rhs = assign_parent.value
    if not isinstance(rhs, ast.Name):
        return True

    rhs_name = rhs.id
    # NOTE(review): only the FIRST statement list found by ast.walk is
    # scanned — ast.walk yields the root first, so this is the module body.
    # Statements in nested scopes fall through to the permissive True below.
    stmts: List[ast.stmt] = []
    for node in ast.walk(tree):
        if hasattr(node, "body") and isinstance(node.body, list):
            stmts = node.body
            break
    try:
        # Locate the definition statement by identity.
        def_idx = next(i for i, s in enumerate(stmts) if s is assign_parent)
        # Climb from the use site to its enclosing statement within `stmts`.
        use_stmt = getattr(use_node, "parent", None)
        while use_stmt and use_stmt not in stmts:
            use_stmt = getattr(use_stmt, "parent", None)
        use_idx = next(i for i, s in enumerate(stmts) if s is use_stmt)
    except StopIteration:
        # Definition or use not found at this level: assume safe.
        return True

    # Any plain reassignment of the RHS name strictly between definition
    # and use blocks the inline.
    for stmt in stmts[def_idx + 1 : use_idx]:
        if isinstance(stmt, ast.Assign):
            for t in stmt.targets:
                if isinstance(t, ast.Name) and t.id == rhs_name:
                    return False
    return True
135
+
136
+
137
class _RemoveAssign(ast.NodeTransformer):
    """Drop (or demote) the ``__temp = expr`` assignment for one temp name.

    Shares a mutable occurrence record with ``eliminate_inline_temps``:
    for an inlinable temp the record arrives as ``[use1, use2, can_inline]``
    and this pass appends the assignment's RHS, leaving
    ``[use1, use2, can_inline, rhs]`` for ``_InlineTemp`` to consume.
    """

    def __init__(self, name: str, occ: dict):
        # Temp name handled by this pass and the shared occurrence map.
        self._name = name
        self._occ = occ

    def visit_Assign(self, node: ast.Assign):
        if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name):
            n = node.targets[0].id
            if n == self._name:
                o = self._occ[n]
                if len(o) == 1:
                    # Temp is written but never read: keep the RHS for its
                    # side effects as a bare expression statement.
                    return ast.Expr(value=node.value)
                if len(o) == 3 and isinstance(o[-1], bool):
                    # Record the RHS so _InlineTemp can substitute it.
                    o.append(node.value)
                    if o[-2]:
                        # Inlining approved: remove the assignment entirely.
                        return None
        return node
154
+
155
+
156
+ class _InlineTemp(ast.NodeTransformer):
157
+ def __init__(self, name: str, occ: dict):
158
+ self._name = name
159
+ self._occ = occ
160
+
161
+ def visit_Name(self, node: ast.Name):
162
+ o = self._occ.get(node.id, [])
163
+ if node.id == self._name and len(o) == 4 and isinstance(o[-2], bool) and o[-2]:
164
+ return o[-1]
165
+ return node
pkgs/MagiCompiler/magi_compiler/magi_depyf/decompile/recompiler.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """CodeRecompiler: round-trip decompile -> compile -> extract target CodeType.
16
+
17
+ Pipeline: CodeType -> decompile -> compile -> find target.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ from types import CodeType
23
+ from typing import List
24
+
25
+ from .decompiler import Decompiler
26
+
27
+
28
class CodeRecompiler:
    """Round-trip helper: decompile a code object, re-``compile`` the
    source, and extract the ``CodeType`` matching a reference function."""

    @staticmethod
    def recompile(
        code_to_decompile: CodeType, reference_code: CodeType, indentation: int = 4, temp_prefix: str = "__temp_"
    ) -> CodeType:
        """Decompile → compile → return the code object named like *reference_code*."""
        fn_name = reference_code.co_name

        source = Decompiler(code_to_decompile).decompile(
            indentation=indentation, temp_prefix=temp_prefix, overwrite_fn_name=fn_name
        )

        module_code = compile(source, "noname", "exec")
        matches = [c for c in CodeRecompiler.collect_code_objects(module_code) if c.co_name == fn_name]
        # Same semantics as the original: IndexError when nothing matches.
        return matches[0]

    @staticmethod
    def collect_code_objects(code: CodeType) -> List[CodeType]:
        """Pre-order collection of *code* and every nested code object."""
        collected: List[CodeType] = [code]
        for const in code.co_consts:
            if isinstance(const, CodeType):
                collected.extend(CodeRecompiler.collect_code_objects(const))
        return collected
pkgs/MagiCompiler/magi_compiler/magi_depyf/demo_toy_example.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Demo: magi_depyf.dump_src with the depyf tutorial toy_example.
16
+
17
+ Run: PYTHONPATH=. python demo_toy_example.py
18
+ """
19
import torch
from magi_compiler.magi_depyf.inspect import dump_src

# Default tensor allocation to the GPU when one is available so the demo
# runs the same code path as real workloads; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(device)


@torch.compile
def toy_example(a, b):
    """Tiny compiled function with a data-dependent branch.

    The ``b.sum() < 0`` test depends on runtime tensor values, so Dynamo
    presumably splits/specializes here — this is the depyf tutorial's
    canonical example for producing multiple compiled artifacts.
    """
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b
32
+
33
+
34
def main():
    """Drive ``toy_example`` repeatedly so both branches get compiled.

    With 100 random inputs it is overwhelmingly likely that ``b.sum()``
    is negative at least once, exercising both compiled code paths.
    """
    for _ in range(100):
        toy_example(torch.randn(10), torch.randn(10))
37
+
38
+
39
+ if __name__ == "__main__":
40
+ import os
41
+ import shutil
42
+
43
+ out = "./magi_dump_src_dir"
44
+ if os.path.exists(out):
45
+ shutil.rmtree(out)
46
+ with dump_src(out):
47
+ main()
48
+
49
+ print("\n=== Generated files ===")
50
+ for root, dirs, files in os.walk(out):
51
+ level = root.replace(out, "").count(os.sep)
52
+ print(f"{' ' * level}{os.path.basename(root)}/")
53
+ for f in files:
54
+ print(f"{' ' * (level + 1)}{f}")
pkgs/MagiCompiler/magi_compiler/magi_depyf/inspect/__init__.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Inspection layer: capture torch.compile events, introspect artifacts, write structured output."""
16
+
17
+ from typing import Optional
18
+
19
+ from .dump_src import dump_src
20
+ from .introspect import Introspector
21
+ from .model import CompiledFnInfo, EntryInfo, FunctionInfo, GuardInfo, GuardNode, SubgraphInfo
22
+ from .result import CaptureResult
23
+ from .session import CaptureSession
24
+ from .writer import FunctionWriter, write_function
25
+
26
+
27
def debug_compiled(fn, output_dir: Optional[str] = None) -> FunctionInfo:
    """Introspect a compiled function and optionally persist debug output.

    Args:
        fn: The original (uncompiled) function.
        output_dir: Target directory for the organized debug files; when
            ``None`` nothing is written to disk.

    Returns:
        A ``FunctionInfo`` snapshot of the full compilation state.
    """
    function_info = Introspector.build_function_info(fn)
    if output_dir is None:
        return function_info
    write_function(function_info, output_dir)
    return function_info
41
+
42
+
43
+ __all__ = [
44
+ "CompiledFnInfo",
45
+ "EntryInfo",
46
+ "FunctionInfo",
47
+ "GuardInfo",
48
+ "GuardNode",
49
+ "SubgraphInfo",
50
+ "Introspector",
51
+ "FunctionWriter",
52
+ "write_function",
53
+ "dump_src",
54
+ "debug_compiled",
55
+ "CaptureSession",
56
+ "CaptureResult",
57
+ ]
pkgs/MagiCompiler/magi_compiler/magi_depyf/inspect/dump_src.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """dump_src: context manager that captures torch.compile artifacts and
16
+ writes structured source output to disk.
17
+
18
+ Usage::
19
+
20
+ from magi_compiler.magi_depyf.inspect import dump_src
21
+
22
+ @torch.compile
23
+ def my_fn(x):
24
+ return x.sum()
25
+
26
+ with dump_src("./output_dir"):
27
+ my_fn(torch.randn(10))
28
+
29
+ Internally uses ``CaptureSession`` to intercept compilation events,
30
+ then runs ``Introspector`` post-hoc for full introspection.
31
+ """
32
+
33
+ from __future__ import annotations
34
+
35
+ import contextlib
36
+ from pathlib import Path
37
+ from typing import Set
38
+
39
+ from magi_compiler.utils import magi_logger
40
+
41
+ from .introspect import Introspector
42
+ from .session import CaptureSession
43
+ from .writer import write_function
44
+
45
+
46
@contextlib.contextmanager
def dump_src(dump_src_dir: str):
    """Capture torch.compile artifacts produced inside the ``with`` body
    and write structured source output under *dump_src_dir*.

    Compilation events are recorded via ``CaptureSession``; once the body
    finishes, each captured function is introspected post-hoc and written
    out. Dynamo resume functions and duplicate names are skipped, and
    per-function failures are logged rather than raised.
    """
    out_dir = Path(dump_src_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    with CaptureSession() as session:
        yield

    processed: Set[str] = set()
    overviews: list[Path] = []
    for result in session.results:
        fn_name = result.original_code.co_name
        if fn_name in processed or fn_name.startswith("torch_dynamo_resume_in_"):
            continue
        processed.add(fn_name)

        try:
            info = Introspector.build_function_info(result.original_code, fn_globals=result.fn_globals)
            root = write_function(info, out_dir)
            overviews.append(root / "overview.md")
        except Exception as exc:
            magi_logger.warning("[magi_depyf] failed to process '%s': %s", fn_name, exc)

    for overview in overviews:
        if overview.exists():
            magi_logger.info("[magi_depyf] %s", overview)
pkgs/MagiCompiler/magi_compiler/magi_depyf/inspect/introspect.py ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Runtime introspection of torch.compile artifacts.
16
+
17
+ Walk actual runtime state (CacheEntry chain, guard trees, __compiled_fn
18
+ objects) to build the structured model. All torch imports are lazy so
19
+ this module can be imported without torch.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import io
25
+ from pathlib import Path
26
+ from typing import Any, Dict, Optional
27
+
28
+ from ..decompile import safe_decompile
29
+ from .model import CompiledFnInfo, EntryInfo, FunctionInfo, GuardInfo, GuardNode, SubgraphInfo
30
+
31
+
32
+ class Introspector:
33
+ """Namespace for runtime introspection helpers (all static methods)."""
34
+
35
+ @staticmethod
36
+ def get_cache_entries(fn) -> list:
37
+ """Return CacheEntry list for *fn* (function or code object)."""
38
+ from torch._dynamo.eval_frame import _debug_get_cache_entry_list
39
+
40
+ code = fn.__code__ if hasattr(fn, "__code__") else fn
41
+ return _debug_get_cache_entry_list(code)
42
+
43
+ @staticmethod
44
+ def build_guard_tree(node, max_depth: int = 32, _depth: int = 0) -> GuardNode:
45
+ """Recursively build a GuardNode from a GuardManager C++ object."""
46
+ type_name = type(node).__name__
47
+ leaf_guards = []
48
+ for lg in node.get_leaf_guards():
49
+ for part in lg.verbose_code_parts():
50
+ leaf_guards.append(part.strip()[:120])
51
+ children = []
52
+ if _depth < max_depth:
53
+ for child in node.get_child_managers():
54
+ children.append(Introspector.build_guard_tree(child, max_depth, _depth + 1))
55
+ return GuardNode(type_name=type_name, leaf_guards=leaf_guards, children=children)
56
+
57
+ @staticmethod
58
+ def extract_guard_info(entry) -> Optional[GuardInfo]:
59
+ """Extract structured guard info from a CacheEntry (post-hoc introspection).
60
+
61
+ Operates on the persisted CacheEntry and builds a full GuardNode tree.
62
+ """
63
+ try:
64
+ gm = entry.guard_manager
65
+ tree = Introspector.build_guard_tree(gm.root)
66
+ closure_vars: Dict[str, str] = {}
67
+ if hasattr(gm, "closure_vars") and gm.closure_vars:
68
+ for k, v in list(gm.closure_vars.items())[:10]:
69
+ closure_vars[k] = repr(v)[:100]
70
+ return GuardInfo(tree=tree, closure_vars=closure_vars or None)
71
+ except Exception:
72
+ return None
73
+
74
+ @staticmethod
75
+ def extract_compiled_fn_info(name: str, fn_globals: dict) -> Optional[CompiledFnInfo]:
76
+ """Inspect a __compiled_fn_xxx from fn.__globals__.
77
+
78
+ Handles three backend types:
79
+ eager: wrapper -> closure[0]=GraphModule.forward (bound method)
80
+ inductor: wrapper -> closure[0]=aot_forward -> ... -> CompiledFxGraph
81
+ magi_compile: MagiSerializableFunction -> split_gm -> PiecewiseBackend(s)
82
+ """
83
+ obj = fn_globals.get(name)
84
+ if obj is None:
85
+ return None
86
+
87
+ magi_info = Introspector._try_extract_magi_info(name, obj)
88
+ if magi_info is not None:
89
+ return magi_info
90
+
91
+ info = CompiledFnInfo(name=name, backend="eager")
92
+
93
+ gm = Introspector._find_graph_module(obj)
94
+ if gm is not None:
95
+ Introspector._fill_graph_module_info(info, gm)
96
+
97
+ cfx = Introspector._find_compiled_fx_graph(obj)
98
+ if cfx is not None:
99
+ info.backend = "inductor"
100
+ Introspector._fill_compiled_fx_graph_info(info, cfx)
101
+
102
+ return info
103
+
104
    @staticmethod
    def _fill_graph_module_info(info: CompiledFnInfo, gm) -> None:
        """Populate *info*'s code fields from an ``fx.GraphModule``.

        Each probe is independent and best-effort: a failure leaves the
        corresponding field at its previous value without affecting the rest.
        """
        try:
            info.readable_code = gm.print_readable(print_output=False)
        except Exception:
            pass
        try:
            info.graph_module_code = str(gm.code) if hasattr(gm, "code") else None
        except Exception:
            pass
        try:
            # print_tabular writes to a file object; capture into a string.
            buf = io.StringIO()
            gm.graph.print_tabular(file=buf)
            info.fx_graph_tabular = buf.getvalue()
        except Exception:
            pass
120
+
121
    @staticmethod
    def _fill_compiled_fx_graph_info(info: CompiledFnInfo, cfx) -> None:
        """Copy inductor artifacts from a ``CompiledFxGraph`` onto *info*.

        Each attribute is probed independently; missing attributes
        (these presumably vary across torch versions) are silently skipped.
        """
        try:
            info.source_code = cfx.source_code
        except Exception:
            pass
        try:
            info.inductor_post_grad_graph = cfx.inductor_post_grad_graph_str
        except Exception:
            pass
        try:
            info.cache_key = cfx.cache_key
        except Exception:
            pass
        try:
            info.runnable_graph_str = cfx.runnable_graph_str
        except Exception:
            pass
139
+
140
+ # -- Magi backend introspection ----------------------------------------
141
+
142
    @staticmethod
    def _try_extract_magi_info(name: str, obj) -> Optional[CompiledFnInfo]:
        """Detect a MagiSerializableFunction and walk its hierarchy.

        MagiSerializableFunction hierarchy::

            .graph_module   → fx.GraphModule (full graph before splitting)
            .optimized_call → split_gm (fx.GraphModule with PiecewiseBackend submodules)
                .submod_N   → PiecewiseBackend
                    .graph  → fx.GraphModule (the subgraph)
                    .compiled_graph_for_general_shape → inductor compiled output

        Dynamo wraps the backend result in a DisableContext closure, so the
        MagiSerializableFunction may live one level deep in the closure chain.
        Returns ``None`` when *obj* is not a Magi-compiled artifact.
        """
        # Duck-type detection: anything with both attributes is treated as
        # a MagiSerializableFunction, directly or inside a closure cell.
        msf = obj if (hasattr(obj, "graph_module") and hasattr(obj, "optimized_call")) else None
        if msf is None and callable(obj) and getattr(obj, "__closure__", None):
            for cell in obj.__closure__:
                try:
                    val = cell.cell_contents
                except ValueError:
                    # Empty (not-yet-filled) closure cell.
                    continue
                if hasattr(val, "graph_module") and hasattr(val, "optimized_call"):
                    msf = val
                    break
        if msf is None:
            return None
        obj = msf

        import torch.fx

        info = CompiledFnInfo(name=name, backend="magi_compile")

        full_gm = getattr(obj, "graph_module", None)
        if isinstance(full_gm, torch.fx.GraphModule):
            Introspector._fill_graph_module_info(info, full_gm)

        split_gm = getattr(obj, "optimized_call", None)

        # In FULL cudagraph mode, optimized_call is a wrapper function whose
        # __dict__ carries the GraphModule's attributes (via __dict__.update).
        # Unwrap to find the actual GraphModule for print_readable / named_children.
        actual_gm = split_gm if isinstance(split_gm, torch.fx.GraphModule) else None
        if actual_gm is None and split_gm is not None:
            actual_gm = Introspector._find_graph_module_deep(split_gm)

        info.cudagraph_mode = Introspector._detect_cudagraph_mode(split_gm, actual_gm)

        if actual_gm is not None:
            try:
                info.split_graph_readable = actual_gm.print_readable(print_output=False)
            except Exception:
                pass

            # PiecewiseCompileInterpreter replaces submodules via __dict__,
            # so named_children() still sees the original GraphModules while
            # __dict__ contains the PiecewiseBackend (or cudagraph wrapper).
            # In FULL cudagraph mode, those __dict__ entries are copied onto
            # the wrapper function, so we look up runtime objects from
            # split_gm (the wrapper) rather than actual_gm.
            runtime_source = split_gm if split_gm is not None else actual_gm
            for sub_name, original_gm in actual_gm.named_children():
                runtime_obj = runtime_source.__dict__.get(sub_name, original_gm)
                sg_info = Introspector._extract_subgraph_info(sub_name, runtime_obj, original_gm)
                if sg_info is not None:
                    info.subgraph_infos.append(sg_info)

        # Deterministic ordering for writers/diffs (no-op when empty).
        info.subgraph_infos.sort(key=lambda s: s.name)

        return info
211
+
212
    @staticmethod
    def _extract_subgraph_info(sub_name: str, runtime_obj, original_gm=None) -> Optional[SubgraphInfo]:
        """Extract info from one submodule of the split graph.

        Args:
            sub_name: The submodule name (e.g. "submod_0").
            runtime_obj: The actual runtime object — PiecewiseBackend,
                cudagraph wrapper, or the original GraphModule.
            original_gm: The original GraphModule before replacement
                (from ``_modules``).

        Returns:
            A ``SubgraphInfo``, or ``None`` when neither a PiecewiseBackend
            nor a GraphModule can be identified.
        """
        import torch.fx

        piecewise = Introspector._unwrap_piecewise_backend(runtime_obj)

        if piecewise is not None:
            # Compiled (non-splitting) subgraph: record its FX graph and try
            # to recover the inductor kernel source.
            sg = SubgraphInfo(name=sub_name, is_splitting_graph=False)
            inner_gm = getattr(piecewise, "graph", None)
            if isinstance(inner_gm, torch.fx.GraphModule):
                Introspector._fill_subgraph_gm_info(sg, inner_gm)

            compiled = getattr(piecewise, "compiled_graph_for_general_shape", None)
            if compiled is not None:
                sg.inductor_code = Introspector._try_extract_inductor_source(compiled)

            if sg.inductor_code is None:
                # Fallback: read the source from the on-disk artifact cache.
                sg.inductor_code = Introspector._read_artifact_source_from_piecewise(piecewise)

            return sg

        # No PiecewiseBackend: treat the submodule as a splitting-graph
        # GraphModule (prefer the pre-replacement original when available).
        gm = original_gm if isinstance(original_gm, torch.fx.GraphModule) else None
        if gm is None and isinstance(runtime_obj, torch.fx.GraphModule):
            gm = runtime_obj

        if gm is not None:
            sg = SubgraphInfo(name=sub_name, is_splitting_graph=True)
            Introspector._fill_subgraph_gm_info(sg, gm)
            return sg

        return None
251
+
252
@staticmethod
def _unwrap_piecewise_backend(obj):
    """Find a PiecewiseBackend from obj, unwrapping closures/wrappers if needed."""

    def _looks_like_piecewise(candidate) -> bool:
        # Duck-typed: a PiecewiseBackend exposes both of these attributes.
        return hasattr(candidate, "graph") and hasattr(candidate, "compiled_graph_for_general_shape")

    if _looks_like_piecewise(obj):
        return obj

    closure = getattr(obj, "__closure__", None) if callable(obj) else None
    for cell in closure or ():
        try:
            contents = cell.cell_contents
        except ValueError:
            # Cell not yet populated.
            continue
        if _looks_like_piecewise(contents):
            return contents
    return None
267
+
268
@staticmethod
def _fill_subgraph_gm_info(sg: SubgraphInfo, gm) -> None:
    """Best-effort population of sg's readable/code/tabular fields from a GraphModule.

    Each extraction is independently guarded; a failure leaves the
    corresponding field untouched (None).
    """
    try:
        sg.readable_code = gm.print_readable(print_output=False)
    except Exception:
        pass
    try:
        if hasattr(gm, "code"):
            sg.graph_module_code = str(gm.code)
        else:
            sg.graph_module_code = None
    except Exception:
        pass
    try:
        stream = io.StringIO()
        gm.graph.print_tabular(file=stream)
        sg.fx_graph_tabular = stream.getvalue()
    except Exception:
        pass
284
+
285
@staticmethod
def _try_extract_inductor_source(compiled) -> Optional[str]:
    """Try to extract inductor kernel source from a compiled graph object.

    Handles CompiledFxGraph, CompiledArtifact, and closure-wrapped variants.
    """
    # 1) Direct attribute lookup (CompiledArtifact-style objects).
    for attr_name in ("source_code", "_source_code"):
        candidate = getattr(compiled, attr_name, None)
        if isinstance(candidate, str) and candidate:
            return candidate

    # 2) A CompiledFxGraph hidden in the closure chain.
    fx_graph = Introspector._find_compiled_fx_graph(compiled)
    if fx_graph is not None:
        try:
            return fx_graph.source_code
        except Exception:
            pass

    # 3) Last resort: anything with a print_readable() method.
    if hasattr(compiled, "print_readable"):
        try:
            return compiled.print_readable(print_output=False)
        except Exception:
            pass

    return None
310
+
311
@staticmethod
def _read_artifact_source_from_piecewise(piecewise) -> Optional[str]:
    """Read Inductor-generated source from the saved artifact directory.

    PiecewiseBackend stores a compiler_manager whose cache maps
    CacheEntry(runtime_shape, graph_index, backend_name) → CacheHandle(key, path).
    The artifact at CacheHandle.path is an unpacked directory containing
    ``py/*.py`` — the full Inductor output code.
    """
    try:
        manager = getattr(piecewise, "compiler_manager", None)
        if manager is None:
            return None
        cache = getattr(manager, "cache", None)
        graph_index = getattr(piecewise, "piecewise_compile_index", None)
        if not cache or graph_index is None:
            return None

        # Match the general-shape entry (runtime_shape is None) for this
        # piecewise graph index.
        for entry, handle in cache.items():
            if entry.graph_index != graph_index or entry.runtime_shape is not None:
                continue
            artifact_path = getattr(handle, "path", None)
            if artifact_path:
                return Introspector._read_py_from_artifact(artifact_path)
        return None
    except Exception:
        # Cache layout is an internal detail; introspection must not raise.
        return None
339
+
340
@staticmethod
def _read_py_from_artifact(artifact_path: str) -> Optional[str]:
    """Read the Inductor-generated Python wrapper from an artifact directory.

    The unpacked artifact layout varies across PyTorch versions; the
    wrapper ``.py`` file has been observed under ``yb/`` and ``py/``.
    Known directories are tried first, then the remaining immediate
    subdirectories are scanned (the previous implementation duplicated
    the read logic and redundantly re-scanned ``yb``/``py`` in the
    fallback pass).

    Args:
        artifact_path: Path to the unpacked artifact directory.

    Returns:
        The text of the first ``*.py`` file found, or None.
    """

    def _first_py_text(directory: Path) -> Optional[str]:
        # Best-effort: only the lexicographically first .py file is tried,
        # matching the original behavior.
        if not directory.is_dir():
            return None
        py_files = sorted(directory.glob("*.py"))
        if not py_files:
            return None
        try:
            return py_files[0].read_text(encoding="utf-8")
        except Exception:
            return None

    root = Path(artifact_path)
    if not root.is_dir():
        return None

    known_dirs = ("yb", "py")
    for name in known_dirs:
        text = _first_py_text(root / name)
        if text is not None:
            return text

    # Fallback: scan all other immediate subdirectories.
    for child in sorted(root.iterdir()):
        if child.name in known_dirs:
            continue  # already tried above
        text = _first_py_text(child)
        if text is not None:
            return text
    return None
372
+
373
@staticmethod
def _detect_cudagraph_mode(split_gm, actual_gm) -> str:
    """Detect cudagraph wrapping mode from the split graph structure.

    - FULL: split_gm itself is a cudagraph wrapper (not a GraphModule),
      with __qualname__ containing 'Athena_CUDAGraph_full'.
    - PIECEWISE: split_gm is a GraphModule, but its __dict__ submodules are
      cudagraph wrappers with __qualname__ 'Athena_CUDAGraph_piecewise'.
    - NONE: no cudagraph wrapping detected.
    """
    full_prefix = "Athena_CUDAGraph_full"
    piecewise_prefix = "Athena_CUDAGraph_piecewise"

    def _qualname_of(obj) -> str:
        return getattr(obj, "__qualname__", "") or ""

    if _qualname_of(split_gm).startswith(full_prefix):
        return "FULL"

    if actual_gm is not None:
        submodules = (
            value for key, value in actual_gm.__dict__.items() if key.startswith("submod_")
        )
        if any(_qualname_of(sub).startswith(piecewise_prefix) for sub in submodules):
            return "PIECEWISE"

    return "NONE"
398
+
399
@staticmethod
def _find_graph_module_deep(obj, _depth: int = 0, _max_depth: int = 4) -> Optional[Any]:
    """Recursively walk the closure chain looking for a ``torch.fx.GraphModule``.

    Needed for FULL cudagraph mode, where the split GraphModule is wrapped
    by ``gen_wrap_func_for_cudagraph`` (+ ``@instrument_nvtx``) and thus
    sits 2-3 levels deep in nested closures.
    """
    import torch.fx

    if isinstance(obj, torch.fx.GraphModule):
        return obj
    if _depth >= _max_depth:
        return None
    closure = getattr(obj, "__closure__", None) if callable(obj) else None
    if not closure:
        return None
    for cell in closure:
        try:
            candidate = cell.cell_contents
        except ValueError:
            # Cell not yet populated.
            continue
        if isinstance(candidate, torch.fx.GraphModule):
            return candidate
        if callable(candidate):
            nested = Introspector._find_graph_module_deep(candidate, _depth + 1, _max_depth)
            if nested is not None:
                return nested
    return None
427
+
428
@staticmethod
def _find_graph_module(obj) -> Optional[Any]:
    """Walk closure chain to find a torch.fx.GraphModule."""
    import torch.fx

    def _as_graph_module(candidate):
        # Either the object itself, or a method bound to a GraphModule.
        if isinstance(candidate, torch.fx.GraphModule):
            return candidate
        bound_to = getattr(candidate, "__self__", None)
        if isinstance(bound_to, torch.fx.GraphModule):
            return bound_to
        return None

    direct = _as_graph_module(obj)
    if direct is not None:
        return direct

    if not callable(obj):
        return None
    for cell in getattr(obj, "__closure__", None) or ():
        try:
            contents = cell.cell_contents
        except ValueError:
            continue
        found = _as_graph_module(contents)
        if found is not None:
            return found
    return None
449
+
450
@staticmethod
def _find_compiled_fx_graph(obj, _depth: int = 0) -> Optional[Any]:
    """Walk closure chain (up to 4 levels) to find a CompiledFxGraph."""
    if _depth > 4:
        return None
    try:
        from torch._inductor.codecache import CompiledFxGraph
    except ImportError:
        return None

    if isinstance(obj, CompiledFxGraph):
        return obj
    closure = getattr(obj, "__closure__", None) if callable(obj) else None
    if not closure:
        return None
    for cell in closure:
        try:
            contents = cell.cell_contents
        except ValueError:
            # Cell not yet populated.
            continue
        if isinstance(contents, CompiledFxGraph):
            return contents
        if callable(contents):
            nested = Introspector._find_compiled_fx_graph(contents, _depth + 1)
            if nested is not None:
                return nested
    return None
475
+
476
@staticmethod
def build_entry_info(entry, index: int, fn_globals: dict) -> EntryInfo:
    """Build an EntryInfo from a CacheEntry."""
    transformed_code = entry.code
    decompiled_src = safe_decompile(transformed_code)

    # __compiled_fn_* globals map to backend-compiled graphs.
    compiled_fns = []
    for name in transformed_code.co_names:
        if not name.startswith("__compiled"):
            continue
        cf_info = Introspector.extract_compiled_fn_info(name, fn_globals)
        if cf_info:
            compiled_fns.append(cf_info)

    # __resume_* globals are recursively introspected as functions.
    resume_fns = []
    for name in transformed_code.co_names:
        if not name.startswith("__resume"):
            continue
        resume_fn = fn_globals.get(name)
        if resume_fn is None or not hasattr(resume_fn, "__code__"):
            continue
        resume_info = Introspector.build_function_info(resume_fn, fn_globals=fn_globals)
        resume_info.name = name
        resume_fns.append(resume_info)

    guard = Introspector.extract_guard_info(entry)

    return EntryInfo(
        index=index,
        dynamo_code=transformed_code,
        decompiled_source=decompiled_src,
        guard=guard,
        compiled_fns=compiled_fns,
        resume_fns=resume_fns,
    )
508
+
509
@staticmethod
def build_function_info(fn, fn_globals: Optional[dict] = None) -> FunctionInfo:
    """Build full FunctionInfo by walking CacheEntry chain."""
    if fn_globals is None:
        fn_globals = getattr(fn, "__globals__", {})

    # Accept either a function or a bare code object.
    code_obj = getattr(fn, "__code__", fn)
    source = safe_decompile(code_obj)

    entries = [
        Introspector.build_entry_info(raw_entry, i, fn_globals)
        for i, raw_entry in enumerate(Introspector.get_cache_entries(fn))
    ]

    return FunctionInfo(
        name=code_obj.co_name,
        original_code=code_obj,
        original_source=source,
        entries=entries,
    )
pkgs/MagiCompiler/magi_compiler/magi_depyf/inspect/model.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Data model for structured compilation output.
16
+
17
+ These dataclasses represent the full compilation state that
18
+ torch.compile produces, organized to reflect the actual runtime
19
+ structure: CacheEntry linked list, fn/resume recursion,
20
+ compiled_fn → backend mapping, and guard trees.
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import dataclasses
26
+ import dis
27
+ import inspect
28
+ import io
29
+ from types import CodeType
30
+ from typing import Dict, List, Optional
31
+
32
+
33
def format_code_info(code: CodeType) -> str:
    """Format key attributes of a CodeType for debugging."""
    out: List[str] = []
    out.append(f"co_name:          {code.co_name}")
    # co_qualname only exists on Python 3.11+.
    if hasattr(code, "co_qualname"):
        out.append(f"co_qualname:      {code.co_qualname}")
    out.append(f"co_filename:      {code.co_filename}")
    out.append(f"co_firstlineno:   {code.co_firstlineno}")
    out.append(f"co_argcount:      {code.co_argcount}")
    out.append(f"co_kwonlyargcount:{code.co_kwonlyargcount}")
    out.append(f"co_varnames:      {code.co_varnames}")
    out.append(f"co_freevars:      {code.co_freevars}")
    out.append(f"co_cellvars:      {code.co_cellvars}")
    out.append(f"co_names:         {code.co_names}")
    flag_bits = code.co_flags
    active_flags = [name for name, bit in _CODE_FLAGS.items() if flag_bits & bit]
    out.append(f"co_flags:         0x{flag_bits:04x} ({' | '.join(active_flags) if active_flags else 'none'})")
    out.append(f"co_stacksize:     {code.co_stacksize}")
    out.append("")
    out.append("co_consts:")
    for idx, const in enumerate(code.co_consts):
        out.append(f"  [{idx:3d}] {type(const).__name__:12s} {_safe_repr(const)}")
    out.append("")
    out.append("dis:")
    sink = io.StringIO()
    dis.dis(code, file=sink)
    out.append(sink.getvalue())
    return "\n".join(out)
61
+
62
+
63
# Code-object flag bits surfaced by format_code_info, keyed by flag name.
_CODE_FLAGS = {
    flag_name: getattr(inspect, flag_name)
    for flag_name in (
        "CO_OPTIMIZED",
        "CO_NEWLOCALS",
        "CO_VARARGS",
        "CO_VARKEYWORDS",
        "CO_NESTED",
        "CO_GENERATOR",
        "CO_COROUTINE",
        "CO_ASYNC_GENERATOR",
    )
}
73
+
74
+
75
def _safe_repr(obj, max_len: int = 120) -> str:
    """repr(obj), guarded against raising and truncated to max_len characters."""
    try:
        text = repr(obj)
    except Exception:
        text = f"<repr failed: {type(obj).__name__}>"
    if len(text) <= max_len:
        return text
    return text[: max_len - 3] + "..."
83
+
84
+
85
@dataclasses.dataclass
class GuardNode:
    """One node in the guard tree (mirrors RootGuardManager / GuardManager)."""

    type_name: str
    leaf_guards: List[str]
    children: List["GuardNode"] = dataclasses.field(default_factory=list)

    def format(self, depth: int = 0, max_depth: int = 32) -> str:
        """Render this node and its children (to max_depth), indented by depth."""
        pad = " " * depth
        rendered = [
            f"{pad}[{self.type_name}] ({len(self.leaf_guards)} leaf guards, {len(self.children)} children)"
        ]
        rendered.extend(f"{pad}  LEAF: {g}" for g in self.leaf_guards)
        if depth < max_depth:
            for idx, child in enumerate(self.children):
                rendered.append(f"{pad}  child[{idx}]:")
                rendered.append(child.format(depth + 2, max_depth))
        elif self.children:
            # Depth cap reached: summarize instead of recursing.
            rendered.append(f"{pad}  ... ({len(self.children)} children omitted)")
        return "\n".join(rendered)
105
+
106
+
107
@dataclasses.dataclass
class SubgraphInfo:
    """One piecewise subgraph in the magi split pipeline."""

    name: str
    is_splitting_graph: bool = False
    readable_code: Optional[str] = None
    graph_module_code: Optional[str] = None
    fx_graph_tabular: Optional[str] = None
    inductor_code: Optional[str] = None

    def format(self) -> str:
        """Best available representation: inductor > readable > GM code > stub."""
        for candidate in (self.inductor_code, self.readable_code, self.graph_module_code):
            if candidate:
                return candidate
        tag = "splitting_op" if self.is_splitting_graph else "compiled"
        return f"# {self.name} ({tag})\n"
127
+
128
+
129
@dataclasses.dataclass
class CompiledFnInfo:
    """What __compiled_fn_xxx actually points to in the backend."""

    name: str
    backend: str  # "eager", "inductor", or "magi_compile"
    cudagraph_mode: Optional[str] = None  # "NONE", "PIECEWISE", "FULL" (magi_compile only)
    readable_code: Optional[str] = None
    graph_module_code: Optional[str] = None
    fx_graph_tabular: Optional[str] = None
    source_code: Optional[str] = None
    inductor_post_grad_graph: Optional[str] = None
    runnable_graph_str: Optional[str] = None
    cache_key: Optional[str] = None
    split_graph_readable: Optional[str] = None
    subgraph_infos: List["SubgraphInfo"] = dataclasses.field(default_factory=list)

    def format(self) -> str:
        """Full content for writing to file (compiled output)."""
        for candidate in (self.source_code, self.readable_code, self.graph_module_code):
            if candidate:
                return candidate
        return f"# {self.name} (backend={self.backend})\n"

    def format_summary(self) -> str:
        """Short summary for overview / full_code."""
        cg_part = f", cudagraph={self.cudagraph_mode}" if self.cudagraph_mode else ""
        lines = [f"{self.name} (backend={self.backend}{cg_part})"]
        if self.cache_key:
            lines.append(f"  cache_key: {self.cache_key}")
        if self.graph_module_code:
            lines.append("  GraphModule.code:")
            lines.extend(f"    {code_line}" for code_line in self.graph_module_code.strip().splitlines())
        if self.subgraph_infos:
            lines.append(f"  piecewise subgraphs: {len(self.subgraph_infos)}")
            for sg in self.subgraph_infos:
                kind = "splitting_op" if sg.is_splitting_graph else "compiled"
                lines.append(f"    {sg.name} ({kind})")
        return "\n".join(lines)
175
+
176
+
177
@dataclasses.dataclass
class GuardInfo:
    """Guard information for a CacheEntry."""

    tree: Optional[GuardNode] = None
    closure_vars: Optional[Dict[str, str]] = None

    def format(self) -> str:
        """Render the guard tree plus (at most 8) closure variables."""
        parts = []
        if self.tree:
            parts.append(self.tree.format())
        if self.closure_vars:
            parts.append("  closure_vars:")
            # Cap at 8 entries to keep the summary short.
            for key, val in list(self.closure_vars.items())[:8]:
                parts.append(f"    {key} = {val}")
        return "\n".join(parts)
193
+
194
+
195
@dataclasses.dataclass
class EntryInfo:
    """One CacheEntry in the linked list."""

    index: int
    dynamo_code: Optional[CodeType] = None
    decompiled_source: str = ""
    guard: Optional[GuardInfo] = None
    compiled_fns: List[CompiledFnInfo] = dataclasses.field(default_factory=list)
    resume_fns: List["FunctionInfo"] = dataclasses.field(default_factory=list)

    def format(self, indent: int = 0) -> str:
        """Render this entry's source, compiled fns, guards, and resume fns."""
        pad = " " * indent
        out = [f"{pad}entry[{self.index}]:"]
        if self.decompiled_source:
            out.append(f"{pad}  dynamo_code (decompiled):")
            out.extend(f"{pad}    {src_line}" for src_line in self.decompiled_source.splitlines())
        if self.compiled_fns:
            out.append(f"{pad}  compiled functions:")
            out.extend(cf.format_summary() for cf in self.compiled_fns)
        if self.guard:
            out.append(f"{pad}  guards:")
            out.append(self.guard.format())
        if self.resume_fns:
            out.append(f"{pad}  resume functions:")
            out.extend(rf.format(indent + 2) for rf in self.resume_fns)
        return "\n".join(out)
225
+
226
+
227
@dataclasses.dataclass
class FunctionInfo:
    """A compiled function and its CacheEntry chain."""

    name: str
    original_code: Optional[CodeType] = None
    original_source: str = ""
    entries: List[EntryInfo] = dataclasses.field(default_factory=list)

    def format(self, indent: int = 0) -> str:
        """Render the function header followed by each cache entry."""
        pad = " " * indent
        out = [f"{pad}{self.name}: {len(self.entries)} cache entries"]
        out.extend(entry.format(indent + 1) for entry in self.entries)
        return "\n".join(out)
pkgs/MagiCompiler/magi_compiler/magi_depyf/inspect/result.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """CaptureResult — structured data model for one compilation event."""
16
+
17
+ from __future__ import annotations
18
+
19
+ import dataclasses
20
+ import time
21
+ from types import CodeType
22
+ from typing import List, Optional
23
+
24
+
25
@dataclasses.dataclass
class CaptureResult:
    """Data captured from a single ``torch.compile`` bytecode event.

    - original_code: the user's original function code
    - dynamo_code: the code after Dynamo transformation (with __compiled_fn / __resume calls)
    - decompiled_source: dynamo_code decompiled back to Python source
    - fn_globals: the function's global namespace (for post-hoc introspection)
    """

    function_name: str
    original_code: CodeType
    dynamo_code: CodeType
    decompiled_source: str
    fn_globals: Optional[dict] = None
    guards: List[str] = dataclasses.field(default_factory=list)
    graph_source: Optional[str] = None
    timestamp: float = dataclasses.field(default_factory=time.time)

    def summary(self) -> str:
        """One-line overview: function, original code name, guard count, graph presence."""
        graph_flag = "yes" if self.graph_source else "no"
        return (
            f"[{self.function_name}] "
            f"original={self.original_code.co_name}, "
            f"guards={len(self.guards)}, "
            f"graph={graph_flag}"
        )