JasonYinnnn commited on
Commit
afea36f
·
1 Parent(s): 3d533e5
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +408 -0
  2. .gitmodules +3 -0
  3. README.md +3 -3
  4. app.py +857 -4
  5. requirements.txt +33 -0
  6. scripts/grounding_sam.py +371 -0
  7. scripts/grounding_sam2.py +353 -0
  8. threeDFixer/__init__.py +11 -0
  9. threeDFixer/datasets/__init__.py +107 -0
  10. threeDFixer/datasets/utils.py +631 -0
  11. threeDFixer/models/__init__.py +123 -0
  12. threeDFixer/models/scene_sparse_structure_flow.py +334 -0
  13. threeDFixer/models/scene_structured_latent_flow.py +415 -0
  14. threeDFixer/models/sparse_elastic_mixin.py +29 -0
  15. threeDFixer/models/sparse_structure_flow.py +219 -0
  16. threeDFixer/models/sparse_structure_vae.py +325 -0
  17. threeDFixer/models/structured_latent_flow.py +295 -0
  18. threeDFixer/models/structured_latent_vae/__init__.py +9 -0
  19. threeDFixer/models/structured_latent_vae/base.py +122 -0
  20. threeDFixer/models/structured_latent_vae/decoder_gs.py +150 -0
  21. threeDFixer/models/structured_latent_vae/decoder_mesh.py +189 -0
  22. threeDFixer/models/structured_latent_vae/decoder_rf.py +118 -0
  23. threeDFixer/models/structured_latent_vae/encoder.py +93 -0
  24. threeDFixer/modules/attention/__init__.py +41 -0
  25. threeDFixer/modules/attention/full_attn.py +145 -0
  26. threeDFixer/modules/attention/modules.py +151 -0
  27. threeDFixer/modules/norm.py +30 -0
  28. threeDFixer/modules/sparse/__init__.py +102 -0
  29. threeDFixer/modules/sparse/attention/__init__.py +9 -0
  30. threeDFixer/modules/sparse/attention/full_attn.py +220 -0
  31. threeDFixer/modules/sparse/attention/modules.py +144 -0
  32. threeDFixer/modules/sparse/attention/serialized_attn.py +198 -0
  33. threeDFixer/modules/sparse/attention/windowed_attn.py +140 -0
  34. threeDFixer/modules/sparse/basic.py +464 -0
  35. threeDFixer/modules/sparse/conv/__init__.py +26 -0
  36. threeDFixer/modules/sparse/conv/conv_spconv.py +85 -0
  37. threeDFixer/modules/sparse/conv/conv_torchsparse.py +43 -0
  38. threeDFixer/modules/sparse/linear.py +20 -0
  39. threeDFixer/modules/sparse/nonlinearity.py +40 -0
  40. threeDFixer/modules/sparse/norm.py +63 -0
  41. threeDFixer/modules/sparse/spatial.py +115 -0
  42. threeDFixer/modules/sparse/transformer/__init__.py +7 -0
  43. threeDFixer/modules/sparse/transformer/blocks.py +156 -0
  44. threeDFixer/modules/sparse/transformer/modulated.py +304 -0
  45. threeDFixer/modules/spatial.py +53 -0
  46. threeDFixer/modules/transformer/__init__.py +2 -0
  47. threeDFixer/modules/transformer/blocks.py +187 -0
  48. threeDFixer/modules/transformer/modulated.py +289 -0
  49. threeDFixer/modules/utils.py +59 -0
  50. threeDFixer/moge/__init__.py +5 -0
.gitignore ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Ignore Visual Studio temporary files, build results, and
2
+ ## files generated by popular Visual Studio add-ons.
3
+ ##
4
+ ## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore
5
+
6
+ # User-specific files
7
+ *.rsuser
8
+ *.suo
9
+ *.user
10
+ *.userosscache
11
+ *.sln.docstates
12
+
13
+ # User-specific files (MonoDevelop/Xamarin Studio)
14
+ *.userprefs
15
+
16
+ # Mono auto generated files
17
+ mono_crash.*
18
+
19
+ # Build results
20
+ [Dd]ebug/
21
+ [Dd]ebugPublic/
22
+ [Rr]elease/
23
+ [Rr]eleases/
24
+ x64/
25
+ x86/
26
+ [Ww][Ii][Nn]32/
27
+ [Aa][Rr][Mm]/
28
+ [Aa][Rr][Mm]64/
29
+ bld/
30
+ [Bb]in/
31
+ [Oo]bj/
32
+ [Ll]og/
33
+ [Ll]ogs/
34
+
35
+ # Visual Studio 2015/2017 cache/options directory
36
+ .vs/
37
+ # Uncomment if you have tasks that create the project's static files in wwwroot
38
+ #wwwroot/
39
+
40
+ # Visual Studio 2017 auto generated files
41
+ Generated\ Files/
42
+
43
+ # MSTest test Results
44
+ [Tt]est[Rr]esult*/
45
+ [Bb]uild[Ll]og.*
46
+
47
+ # NUnit
48
+ *.VisualState.xml
49
+ TestResult.xml
50
+ nunit-*.xml
51
+
52
+ # Build Results of an ATL Project
53
+ [Dd]ebugPS/
54
+ [Rr]eleasePS/
55
+ dlldata.c
56
+
57
+ # Benchmark Results
58
+ BenchmarkDotNet.Artifacts/
59
+
60
+ # .NET Core
61
+ project.lock.json
62
+ project.fragment.lock.json
63
+ artifacts/
64
+
65
+ # ASP.NET Scaffolding
66
+ ScaffoldingReadMe.txt
67
+
68
+ # StyleCop
69
+ StyleCopReport.xml
70
+
71
+ # Files built by Visual Studio
72
+ *_i.c
73
+ *_p.c
74
+ *_h.h
75
+ *.ilk
76
+ *.meta
77
+ *.obj
78
+ *.iobj
79
+ *.pch
80
+ *.pdb
81
+ *.ipdb
82
+ *.pgc
83
+ *.pgd
84
+ *.rsp
85
+ *.sbr
86
+ *.tlb
87
+ *.tli
88
+ *.tlh
89
+ *.tmp
90
+ *.tmp_proj
91
+ *_wpftmp.csproj
92
+ *.log
93
+ *.tlog
94
+ *.vspscc
95
+ *.vssscc
96
+ .builds
97
+ *.pidb
98
+ *.svclog
99
+ *.scc
100
+
101
+ # Chutzpah Test files
102
+ _Chutzpah*
103
+
104
+ # Visual C++ cache files
105
+ ipch/
106
+ *.aps
107
+ *.ncb
108
+ *.opendb
109
+ *.opensdf
110
+ *.sdf
111
+ *.cachefile
112
+ *.VC.db
113
+ *.VC.VC.opendb
114
+
115
+ # Visual Studio profiler
116
+ *.psess
117
+ *.vsp
118
+ *.vspx
119
+ *.sap
120
+
121
+ # Visual Studio Trace Files
122
+ *.e2e
123
+
124
+ # TFS 2012 Local Workspace
125
+ $tf/
126
+
127
+ # Guidance Automation Toolkit
128
+ *.gpState
129
+
130
+ # ReSharper is a .NET coding add-in
131
+ _ReSharper*/
132
+ *.[Rr]e[Ss]harper
133
+ *.DotSettings.user
134
+
135
+ # TeamCity is a build add-in
136
+ _TeamCity*
137
+
138
+ # DotCover is a Code Coverage Tool
139
+ *.dotCover
140
+
141
+ # AxoCover is a Code Coverage Tool
142
+ .axoCover/*
143
+ !.axoCover/settings.json
144
+
145
+ # Coverlet is a free, cross platform Code Coverage Tool
146
+ coverage*.json
147
+ coverage*.xml
148
+ coverage*.info
149
+
150
+ # Visual Studio code coverage results
151
+ *.coverage
152
+ *.coveragexml
153
+
154
+ # NCrunch
155
+ _NCrunch_*
156
+ .*crunch*.local.xml
157
+ nCrunchTemp_*
158
+
159
+ # MightyMoose
160
+ *.mm.*
161
+ AutoTest.Net/
162
+
163
+ # Web workbench (sass)
164
+ .sass-cache/
165
+
166
+ # Installshield output folder
167
+ [Ee]xpress/
168
+
169
+ # DocProject is a documentation generator add-in
170
+ DocProject/buildhelp/
171
+ DocProject/Help/*.HxT
172
+ DocProject/Help/*.HxC
173
+ DocProject/Help/*.hhc
174
+ DocProject/Help/*.hhk
175
+ DocProject/Help/*.hhp
176
+ DocProject/Help/Html2
177
+ DocProject/Help/html
178
+
179
+ # Click-Once directory
180
+ publish/
181
+
182
+ # Publish Web Output
183
+ *.[Pp]ublish.xml
184
+ *.azurePubxml
185
+ # Note: Comment the next line if you want to checkin your web deploy settings,
186
+ # but database connection strings (with potential passwords) will be unencrypted
187
+ *.pubxml
188
+ *.publishproj
189
+
190
+ # Microsoft Azure Web App publish settings. Comment the next line if you want to
191
+ # checkin your Azure Web App publish settings, but sensitive information contained
192
+ # in these scripts will be unencrypted
193
+ PublishScripts/
194
+
195
+ # NuGet Packages
196
+ *.nupkg
197
+ # NuGet Symbol Packages
198
+ *.snupkg
199
+ # The packages folder can be ignored because of Package Restore
200
+ **/[Pp]ackages/*
201
+ # except build/, which is used as an MSBuild target.
202
+ !**/[Pp]ackages/build/
203
+ # Uncomment if necessary however generally it will be regenerated when needed
204
+ #!**/[Pp]ackages/repositories.config
205
+ # NuGet v3's project.json files produces more ignorable files
206
+ *.nuget.props
207
+ *.nuget.targets
208
+
209
+ # Microsoft Azure Build Output
210
+ csx/
211
+ *.build.csdef
212
+
213
+ # Microsoft Azure Emulator
214
+ ecf/
215
+ rcf/
216
+
217
+ # Windows Store app package directories and files
218
+ AppPackages/
219
+ BundleArtifacts/
220
+ Package.StoreAssociation.xml
221
+ _pkginfo.txt
222
+ *.appx
223
+ *.appxbundle
224
+ *.appxupload
225
+
226
+ # Visual Studio cache files
227
+ # files ending in .cache can be ignored
228
+ *.[Cc]ache
229
+ # but keep track of directories ending in .cache
230
+ !?*.[Cc]ache/
231
+
232
+ # Others
233
+ ClientBin/
234
+ ~$*
235
+ *~
236
+ *.dbmdl
237
+ *.dbproj.schemaview
238
+ *.jfm
239
+ *.pfx
240
+ *.publishsettings
241
+ orleans.codegen.cs
242
+
243
+ # Including strong name files can present a security risk
244
+ # (https://github.com/github/gitignore/pull/2483#issue-259490424)
245
+ #*.snk
246
+
247
+ # Since there are multiple workflows, uncomment next line to ignore bower_components
248
+ # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
249
+ #bower_components/
250
+
251
+ # RIA/Silverlight projects
252
+ Generated_Code/
253
+
254
+ # Backup & report files from converting an old project file
255
+ # to a newer Visual Studio version. Backup files are not needed,
256
+ # because we have git ;-)
257
+ _UpgradeReport_Files/
258
+ Backup*/
259
+ UpgradeLog*.XML
260
+ UpgradeLog*.htm
261
+ ServiceFabricBackup/
262
+ *.rptproj.bak
263
+
264
+ # SQL Server files
265
+ *.mdf
266
+ *.ldf
267
+ *.ndf
268
+
269
+ # Business Intelligence projects
270
+ *.rdl.data
271
+ *.bim.layout
272
+ *.bim_*.settings
273
+ *.rptproj.rsuser
274
+ *- [Bb]ackup.rdl
275
+ *- [Bb]ackup ([0-9]).rdl
276
+ *- [Bb]ackup ([0-9][0-9]).rdl
277
+
278
+ # Microsoft Fakes
279
+ FakesAssemblies/
280
+
281
+ # GhostDoc plugin setting file
282
+ *.GhostDoc.xml
283
+
284
+ # Node.js Tools for Visual Studio
285
+ .ntvs_analysis.dat
286
+ node_modules/
287
+
288
+ # Visual Studio 6 build log
289
+ *.plg
290
+
291
+ # Visual Studio 6 workspace options file
292
+ *.opt
293
+
294
+ # Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
295
+ *.vbw
296
+
297
+ # Visual Studio 6 auto-generated project file (contains which files were open etc.)
298
+ *.vbp
299
+
300
+ # Visual Studio 6 workspace and project file (working project files containing files to include in project)
301
+ *.dsw
302
+ *.dsp
303
+
304
+ # Visual Studio 6 technical files
305
+ *.ncb
306
+ *.aps
307
+
308
+ # Visual Studio LightSwitch build output
309
+ **/*.HTMLClient/GeneratedArtifacts
310
+ **/*.DesktopClient/GeneratedArtifacts
311
+ **/*.DesktopClient/ModelManifest.xml
312
+ **/*.Server/GeneratedArtifacts
313
+ **/*.Server/ModelManifest.xml
314
+ _Pvt_Extensions
315
+
316
+ # Paket dependency manager
317
+ .paket/paket.exe
318
+ paket-files/
319
+
320
+ # FAKE - F# Make
321
+ .fake/
322
+
323
+ # CodeRush personal settings
324
+ .cr/personal
325
+
326
+ # Python Tools for Visual Studio (PTVS)
327
+ __pycache__/
328
+ *.pyc
329
+
330
+ # Cake - Uncomment if you are using it
331
+ # tools/**
332
+ # !tools/packages.config
333
+
334
+ # Tabs Studio
335
+ *.tss
336
+
337
+ # Telerik's JustMock configuration file
338
+ *.jmconfig
339
+
340
+ # BizTalk build output
341
+ *.btp.cs
342
+ *.btm.cs
343
+ *.odx.cs
344
+ *.xsd.cs
345
+
346
+ # OpenCover UI analysis results
347
+ OpenCover/
348
+
349
+ # Azure Stream Analytics local run output
350
+ ASALocalRun/
351
+
352
+ # MSBuild Binary and Structured Log
353
+ *.binlog
354
+
355
+ # NVidia Nsight GPU debugger configuration file
356
+ *.nvuser
357
+
358
+ # MFractors (Xamarin productivity tool) working folder
359
+ .mfractor/
360
+
361
+ # Local History for Visual Studio
362
+ .localhistory/
363
+
364
+ # Visual Studio History (VSHistory) files
365
+ .vshistory/
366
+
367
+ # BeatPulse healthcheck temp database
368
+ healthchecksdb
369
+
370
+ # Backup folder for Package Reference Convert tool in Visual Studio 2017
371
+ MigrationBackup/
372
+
373
+ # Ionide (cross platform F# VS Code tools) working folder
374
+ .ionide/
375
+
376
+ # Fody - auto-generated XML schema
377
+ FodyWeavers.xsd
378
+
379
+ # VS Code files for those working on multiple tools
380
+ .vscode/*
381
+ !.vscode/settings.json
382
+ !.vscode/tasks.json
383
+ !.vscode/launch.json
384
+ !.vscode/extensions.json
385
+ *.code-workspace
386
+
387
+ # Local History for Visual Studio Code
388
+ .history/
389
+
390
+ # Windows Installer files from build outputs
391
+ *.cab
392
+ *.msi
393
+ *.msix
394
+ *.msm
395
+ *.msp
396
+
397
+ # JetBrains Rider
398
+ *.sln.iml
399
+
400
+ threeDFixer_weights
401
+ threeDFixer_weights/**
402
+
403
+ tmp
404
+ tmp/**
405
+
406
+ gradio_temp
407
+ gradio_temp/**
408
+
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "threeDFixer/representations/mesh/flexicubes"]
2
+ path = threeDFixer/representations/mesh/flexicubes
3
+ url = https://github.com/MaxtirError/FlexiCubes.git
README.md CHANGED
@@ -4,12 +4,12 @@ emoji: 🦀
4
  colorFrom: blue
5
  colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 6.10.0
8
- python_version: '3.12'
9
  app_file: app.py
10
  pinned: false
11
  license: apache-2.0
12
  short_description: Create 3D Scene from a single image via In-Place Completion.
13
  ---
14
 
15
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
4
  colorFrom: blue
5
  colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 4.44.1
8
+ python_version: '3.10'
9
  app_file: app.py
10
  pinned: false
11
  license: apache-2.0
12
  short_description: Create 3D Scene from a single image via In-Place Completion.
13
  ---
14
 
15
+ This is the interactive demo of [3D-Fixer](https://zx-yin.github.io/3dfixer/).
app.py CHANGED
@@ -1,7 +1,860 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
1
+ # SPDX-FileCopyrightText: 2026 Ze-Xin Yin, Robot labs of Horizon Robotics, and D-Robotics
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ # See the LICENSE file in the project root for full license information.
4
+
5
+ import os
6
+ os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.getcwd(), "gradio_temp")
7
+ os.makedirs(os.environ["GRADIO_TEMP_DIR"], exist_ok=True)
8
+ import uuid
9
+ from typing import Any, List, Optional, Union
10
+
11
+ import cv2
12
+ import torch
13
+ import numpy as np
14
+ from PIL import Image
15
+ import trimesh
16
+ import random
17
+ import imageio
18
+ from einops import repeat
19
+
20
+ from gradio_image_prompter import ImagePrompter
21
  import gradio as gr
22
 
23
+ from threeDFixer.pipelines import ThreeDFixerPipeline
24
+ from threeDFixer.datasets.utils import (
25
+ edge_mask_morph_gradient,
26
+ process_scene_image,
27
+ process_instance_image,
28
+ transform_vertices,
29
+ normalize_vertices,
30
+ project2ply
31
+ )
32
+ from threeDFixer.utils import render_utils, postprocessing_utils
33
+ from scripts.grounding_sam2 import plot_segmentation, segment
34
+ from sam2.build_sam import build_sam2
35
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
36
+ import copy
37
+
38
+ import shutil
39
+ import time
40
+ from concurrent.futures import ThreadPoolExecutor
41
+
42
# User-facing help text rendered at the top of the Gradio app.
# (Typos fixed: "islolated" -> "isolated", "each instances" -> "each instance".)
MARKDOWN = """
## Image to 3D Scene with [3D-Fixer](https://zx-yin.github.io/3dfixer/)
1. Upload an image, and draw bounding boxes for each instance by holding and dragging the mouse. Then click "Run Segmentation" to generate the segmentation result.
2. If you find the generated 3D scene satisfactory, download it by clicking the "Download scene GLB" button, and you can also download each isolated 3D instance.
3. In this implementation, we generate each instance one by one, and update the scene results at the "Generated GLB" area; besides, we display isolated instances below.
4. It may take a while for the first inference due to the usage of ```torch.compile```.
"""
MAX_SEED = np.iinfo(np.int32).max  # upper bound when randomizing seeds
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
EXAMPLE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets/example_data")
DTYPE = torch.float16
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
VALID_RATIO_THRESHOLD = 0.005  # skip instances covering < 0.5% of the image
CROP_SIZE = 518  # crop resolution used when preparing model inputs
# Per-session mutable state, populated by the Gradio callbacks below.
work_space = None            # current working directory under TMP_DIR
dpt_pack = None              # cached depth-estimation outputs (depth, K, c2w, ...)
generated_object_map = {}    # instance label -> exported GLB path
59
+
60
# Prepare models
## Grounding SAM
# NOTE(review): paths are relative to the process working directory — confirm
# the checkpoint and config are present in the deployed Space.
sam2_checkpoint = "./checkpoints/sam2-hiera-large/sam2_hiera_large.pt"
sam2_model_cfg = "configs/sam2/sam2_hiera_l.yaml"
sam2_predictor = SAM2ImagePredictor(
    build_sam2(sam2_model_cfg, sam2_checkpoint),
)

############## 3D-Fixer model
model_dir = 'HorizonRobotics/3D-Fixer'
pipeline = ThreeDFixerPipeline.from_pretrained(
    model_dir, compile=True
)
pipeline.cuda()
############## 3D-Fixer model

# Fixed 4x4 transform (negates X, swaps Y/Z) applied to every exported
# asset so the saved GLB/PLY files share one orientation convention.
rot = np.array([
    [-1.0, 0.0, 0.0, 0.0],
    [ 0.0, 0.0, 1.0, 0.0],
    [ 0.0, 1.0, 0.0, 0.0],
    [ 0.0, 0.0, 0.0, 1.0],
], dtype=np.float32)

# Fixed camera-to-world pose used when back-projecting depth to 3D points.
c2w = torch.tensor([
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, -1.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
], dtype=torch.float32, device=DEVICE)

# Export an (N, 3) point array with matching per-point colors to `fpath`.
save_projected_colored_pcd = lambda pts, pts_color, fpath: trimesh.PointCloud(pts.reshape(-1, 3), pts_color.reshape(-1, 3)).export(fpath)
91
+
92
# Gradio example rows. Each row is:
#   (image-prompter payload, segmentation-map path, seed, randomize_seed,
#    num_inference_steps, guidance_scale, cfg_interval_start,
#    cfg_interval_end, t_rescale)
EXAMPLES = [
    [
        {
            "image": "assets/example_data/scene1/rgb.png",
        },
        "assets/example_data/scene1/seg.png",
        1024,
        False,
        25, 5.5, 0.8, 1.0, 5.0
        # num_inference_steps, guidance_scale, cfg_interval_start, cfg_interval_end, t_rescale
    ],
    [
        {
            "image": "assets/example_data/scene2/rgb.png",
        },
        "assets/example_data/scene2/seg.png",
        1,
        False,
        25, 5.0, 0.8, 1.0, 5.0
    ],
    [
        {
            "image": "assets/example_data/scene3/rgb.png",
        },
        "assets/example_data/scene3/seg.png",
        1,
        False,
        25, 5.0, 0.8, 1.0, 5.0
    ],
    [
        {
            "image": "assets/example_data/scene4/rgb.png",
        },
        "assets/example_data/scene4/seg.png",
        42,
        False,
        25, 5.0, 0.8, 1.0, 5.0
    ],
    [
        {
            "image": "assets/example_data/scene5/rgb.png",
        },
        "assets/example_data/scene5/seg.png",
        1,
        False,
        25, 5.0, 0.8, 1.0, 5.0
    ],
    [
        {
            "image": "assets/example_data/scene6/rgb.png",
        },
        "assets/example_data/scene6/seg.png",
        1,
        False,
        25, 5.0, 0.8, 1.0, 5.0
    ]
]
149
+
150
def cleanup_tmp(tmp_root: str = "./tmp", expire_seconds: int = 3600) -> None:
    """Remove stale sub-directories under *tmp_root*.

    A sub-directory is stale when its modification time is older than
    ``expire_seconds``. Loose files directly under ``tmp_root`` are left
    untouched, and a failure to delete one entry is logged and skipped so a
    single bad directory cannot abort the sweep.

    Args:
        tmp_root: Root path of the temporary directory.
        expire_seconds: Expiry age in seconds; defaults to 3600 (one hour).
    """
    tmp_root = os.path.abspath(tmp_root)

    if not os.path.isdir(tmp_root):
        return

    now = time.time()

    for name in os.listdir(tmp_root):
        path = os.path.join(tmp_root, name)

        # Only clean sub-directories; never touch stray files.
        if not os.path.isdir(path):
            continue

        try:
            age = now - os.path.getmtime(path)
            if age > expire_seconds:
                shutil.rmtree(path, ignore_errors=False)
                print(f"[cleanup_tmp] removed old directory: {path}")
        except Exception as e:
            # Best-effort cleanup: report and continue with the next entry.
            print(f"[cleanup_tmp] failed to remove {path}: {e}")
181
+
182
@torch.no_grad()
def run_segmentation(
    image_prompts: Any,
    polygon_refinement: bool = True,
) -> Image.Image:
    """Segment user-drawn boxes with SAM2 and save the colorized label map.

    Args:
        image_prompts: ImagePrompter payload with an "image" (PIL image) and
            "points" — box prompts shaped [x1, y1, flag, x2, y2, flag].
        polygon_refinement: Whether the segmenter refines masks into polygons.

    Returns:
        The colorized segmentation map, also written to
        ``<work_space>/mask.png``.

    Raises:
        gr.Error: If no box prompts were drawn.
    """
    rgb_image = image_prompts["image"].convert("RGB")

    global work_space

    # pre-process the layers and get the xyxy boxes of each layer
    if len(image_prompts["points"]) == 0:
        # BUGFIX: gr.Error must be *raised* to surface in the Gradio UI;
        # previously it was constructed and discarded, silently returning None.
        raise gr.Error("No points provided for segmentation. Please add points to the image.")

    # Keep only the corner coordinates from each [x1, y1, flag, x2, y2, flag] prompt.
    boxes = [
        [
            [int(box[0]), int(box[1]), int(box[3]), int(box[4])]
            for box in image_prompts["points"]
        ]
    ]

    detections = segment(
        sam2_predictor,
        rgb_image,
        boxes=[boxes],
        polygon_refinement=polygon_refinement,
    )
    seg_map_pil = plot_segmentation(rgb_image, detections)

    torch.cuda.empty_cache()

    # Drop stale sessions, then create a fresh per-session workspace.
    cleanup_tmp(TMP_DIR, expire_seconds=3600)

    work_space = os.path.join(TMP_DIR, f"work_space_{uuid.uuid4()}")
    os.makedirs(work_space, exist_ok=True)
    seg_map_pil.save(os.path.join(work_space, 'mask.png'))

    return seg_map_pil
220
+
221
@torch.no_grad()
def run_depth_estimation(
    image_prompts: Any,
    seg_image: Union[str, Image.Image],
) -> str:
    """Estimate scene depth, cache camera/depth state, and export a point cloud.

    Runs the scene conditioning model on the uploaded image to obtain a depth
    map and normalized intrinsics, stores them (plus the foreground
    normalization translation/scale) in the module-level ``dpt_pack`` for the
    later generation step, and writes a normalized colored point cloud of the
    scene to ``<work_space>/scene_pcd.glb``.

    Args:
        image_prompts: ImagePrompter payload; only its "image" entry is used.
        seg_image: Instance segmentation map whose unique RGB colors label the
            foreground instances (pure black = background).
            NOTE(review): despite the ``Union[str, ...]`` annotation, a PIL
            image is required here (``.resize`` is called) — confirm callers.

    Returns:
        Path of the exported scene point-cloud file (changed annotation: the
        original declared ``Image.Image`` but always returned this path).
    """
    rgb_image = image_prompts["image"].convert("RGB")

    # Force a fixed model input resolution.
    rgb_image = rgb_image.resize((1024, 1024), Image.Resampling.LANCZOS)

    global dpt_pack
    global work_space
    if work_space is None:
        work_space = os.path.join(TMP_DIR, f"work_space_{uuid.uuid4()}")
        os.makedirs(work_space, exist_ok=True)
    global generated_object_map

    # A fresh depth estimate invalidates previously generated objects.
    generated_object_map = {}

    # Downscale so the longest side is at most 1024. (A no-op after the fixed
    # 1024x1024 resize above; kept for robustness should that resize change.)
    origin_W, origin_H = rgb_image.size
    if max(origin_H, origin_W) > 1024:
        factor = max(origin_H, origin_W) / 1024
        H = int(origin_H // factor)
        W = int(origin_W // factor)
        rgb_image = rgb_image.resize((W, H), Image.Resampling.LANCZOS)
    W, H = rgb_image.size

    input_image = np.array(rgb_image).astype(np.float32)
    input_image = torch.tensor(input_image / 255, dtype=torch.float32, device=DEVICE).permute(2, 0, 1)

    output = pipeline.models['scene_cond_model'].infer(input_image)
    depth = output['depth']
    intrinsics = output['intrinsics']

    # Mask out NaN/Inf depth predictions and zero them so downstream math is safe.
    invalid_mask = torch.logical_or(torch.isnan(depth), torch.isinf(depth))
    depth_mask = ~invalid_mask

    depth = torch.where(invalid_mask, 0.0, depth)
    # Denormalize intrinsics to pixel units; principal point at the image center.
    K = torch.from_numpy(
        np.array([
            [intrinsics[0, 0].item() * W, 0, 0.5*W],
            [0, intrinsics[1, 1].item() * H, 0.5*H],
            [0, 0, 1]
        ])
    ).to(dtype=torch.float32, device=DEVICE)

    dpt_pack = {
        'c2w': c2w,
        'K': K,
        'depth_mask': depth_mask,
        'depth': depth
    }

    # Collect instance colors BEFORE resizing so no label is lost, then resize
    # with NEAREST. BUGFIX: the previous LANCZOS resampling interpolates a
    # label map, blending colors at instance boundaries and breaking the exact
    # color-equality mask test below.
    instance_labels = np.unique(np.array(seg_image).reshape(-1, 3), axis=0)
    seg_image = seg_image.resize((W, H), Image.Resampling.NEAREST)
    seg_image = np.array(seg_image)

    # One boolean mask per non-background (non-black) instance color.
    mask_pack = []
    for instance_label in instance_labels:
        if (instance_label == np.array([0, 0, 0])).all():
            continue
        else:
            instance_mask = (seg_image.reshape(-1, 3) == instance_label).all(axis=-1).reshape(H, W)
            mask_pack.append(instance_mask)
    fg_mask = torch.from_numpy(np.stack(mask_pack).any(axis=0)).to(DEVICE)

    # Back-project the whole valid depth map to a colored point cloud.
    scene_est_depth_pts, scene_est_depth_pts_colors = \
        project2ply(depth_mask, depth, input_image, K, c2w)
    save_ply_path = os.path.join(work_space, "scene_pcd.glb")

    # Normalization (translation/scale) is computed from the foreground points
    # only, so generated objects later land in the same normalized frame.
    fg_depth_pts, _ = \
        project2ply(fg_mask, depth, input_image, K, c2w)
    _, trans, scale = normalize_vertices(fg_depth_pts)

    if trans.shape[0] == 1:
        trans = trans[0]

    dpt_pack.update(
        {
            "trans": trans,
            "scale": scale,
        }
    )

    trimesh.PointCloud(scene_est_depth_pts.reshape(-1, 3), scene_est_depth_pts_colors.reshape(-1, 3)).\
        apply_translation(-trans).apply_scale(1. / (scale + 1e-6)).\
        apply_transform(rot).export(save_ply_path)

    torch.cuda.empty_cache()

    return save_ply_path
311
+
312
+
313
def save_image(img, save_path):
    """Write a CHW float tensor with values in [0, 1] to *save_path* as an 8-bit image."""
    hwc = img.permute(1, 2, 0).detach().cpu().numpy()
    imageio.v3.imwrite(save_path, (hwc * 255.).astype(np.uint8))
316
+
317
def set_random_seed(seed):
    """Seed numpy, the stdlib RNG, and torch (including CUDA) for reproducibility."""
    # Same seeding order as before: numpy, random, torch, then CUDA devices.
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
323
+
324
def export_single_glb_from_outputs(
    outputs,
    fine_scale,
    fine_trans,
    coarse_scale,
    coarse_trans,
    trans,
    scale,
    rot,
    work_space,
    instance_name,
    run_id
):
    """Bake one generated instance into a GLB file and export it.

    The instance geometry is mapped back into scene coordinates via the
    fine/coarse scale-translation chain, then the exported mesh is normalized
    with (trans, scale) and reoriented with *rot* — the same transform used
    for the scene point cloud — so every export shares one frame.

    Returns:
        (instance_glb_path, glb): absolute path of the written file and the
        trimesh object itself (kept around for assembling the scene GLB).
    """
    def _to_scene_coords(x):
        # scale -> translate (fine), then scale -> translate (coarse).
        return transform_vertices(
            x,
            ops=["scale", "translation", "scale", "translation"],
            params=[fine_scale, fine_trans[None], coarse_scale, coarse_trans[None]],
        )

    # NOTE(review): to_glb is run with gradients enabled — presumably its
    # texture baking needs autograd; confirm before removing.
    with torch.enable_grad():
        glb = postprocessing_utils.to_glb(
            outputs["gaussian"][0],
            outputs["mesh"][0],
            simplify=0.95,
            texture_size=1024,
            transform_fn=_to_scene_coords,
            verbose=False
        )

    instance_glb_path = os.path.abspath(
        os.path.join(work_space, f"{run_id}_{instance_name}.glb")
    )

    normalized = glb.apply_translation(-trans).apply_scale(1.0 / (scale + 1e-6))
    normalized.apply_transform(rot).export(instance_glb_path)

    return instance_glb_path, glb
362
+
363
+
364
def export_scene_glb(trimeshes, work_space, scene_name):
    """Merge *trimeshes* into one scene GLB under *work_space*; return its absolute path."""
    destination = os.path.abspath(os.path.join(work_space, scene_name))
    trimesh.Scene(trimeshes).export(destination)
    return destination
369
+
370
+ @torch.no_grad()
371
+ def run_generation(
372
+ rgb_image: Any,
373
+ seg_image: Union[str, Image.Image],
374
+ seed: int,
375
+ randomize_seed: bool = False,
376
+ num_inference_steps: int = 50,
377
+ guidance_scale: float = 5.0,
378
+ cfg_interval_start: float = 0.5,
379
+ cfg_interval_end: float = 1.0,
380
+ t_rescale: float = 3.0,
381
+ ):
382
+ global dpt_pack
383
+ global work_space
384
+ global generated_object_map
385
+ generated_object_map = {}
386
+ run_id = str(uuid.uuid4())
387
+
388
+ if not isinstance(rgb_image, Image.Image) and "image" in rgb_image:
389
+ rgb_image = rgb_image["image"]
390
+
391
+ instance_labels = np.unique(np.array(seg_image).reshape(-1, 3), axis=0)
392
+ if randomize_seed:
393
+ seed = random.randint(0, MAX_SEED)
394
+ set_random_seed(seed)
395
+
396
+ H, W = dpt_pack['depth_mask'].shape
397
+ rgb_image = rgb_image.resize((W, H), Image.Resampling.LANCZOS)
398
+ seg_image = seg_image.resize((W, H), Image.Resampling.LANCZOS)
399
+
400
+ depth_mask = dpt_pack['depth_mask'].detach().cpu().numpy() > 0
401
+ seg_image = np.array(seg_image)
402
+
403
+ mask_pack = []
404
+ for instance_label in instance_labels:
405
+ if (instance_label == np.array([0, 0, 0])).all():
406
+ continue
407
+ instance_mask = (seg_image.reshape(-1, 3) == instance_label).all(axis=-1).reshape(H, W)
408
+ mask_pack.append(instance_mask)
409
+
410
+ erode_kernel_size = 7
411
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (erode_kernel_size, erode_kernel_size))
412
+ results = []
413
+ trimeshes = []
414
+
415
+ trans = dpt_pack['trans']
416
+ scale = dpt_pack['scale']
417
+
418
+ current_scene_path = None
419
+ pending_exports = []
420
+
421
+ def build_stream_html(status_text: str):
422
+ cards_html = "".join([
423
+ f"""
424
+ <div style="
425
+ width: 220px;
426
+ border: 1px solid #ddd;
427
+ border-radius: 10px;
428
+ padding: 8px;
429
+ background: white;
430
+ box-sizing: border-box;
431
+ ">
432
+ <div style="font-weight: 600; margin-bottom: 6px;">
433
+ {item["name"]}
434
+ </div>
435
+
436
+ <video
437
+ autoplay
438
+ muted
439
+ loop
440
+ playsinline
441
+ preload="metadata"
442
+ poster="/file={item['poster_path']}?v={run_id}"
443
+ style="
444
+ width: 100%;
445
+ border-radius: 8px;
446
+ display: block;
447
+ background: #f5f5f5;
448
+ "
449
+ >
450
+ <source src="/file={item['mp4_path']}?v={run_id}" type="video/mp4">
451
+ </video>
452
+
453
+ <div style="
454
+ margin-top: 6px;
455
+ font-size: 12px;
456
+ color: #666;
457
+ ">
458
+ Status: {item.get("status_text", "Unknown")}
459
+ </div>
460
+
461
+ <div style="
462
+ margin-top: 4px;
463
+ font-size: 13px;
464
+ color: #444;
465
+ word-break: break-all;
466
+ ">
467
+ {os.path.basename(item["glb_path"]) if item["glb_path"] is not None else "GLB not ready yet"}
468
+ </div>
469
+ </div>
470
+ """
471
+ for item in results
472
+ ])
473
+
474
+ return f"""
475
+ <div style="padding: 8px 0;">
476
+ <div style="font-weight: 700; margin-bottom: 8px;">Status: {status_text}</div>
477
+ <div style="font-weight: 700; margin-bottom: 12px;">Generated objects: {len(results)}</div>
478
+ <div style="display: flex; flex-wrap: wrap; gap: 12px; align-items: flex-start;">
479
+ {cards_html}
480
+ </div>
481
+ </div>
482
+ """
483
+
484
+ def build_selector_and_download_updates(default_latest: bool = True):
485
+ object_choices = [item["name"] for item in results if item["glb_path"] is not None]
486
+
487
+ if len(object_choices) == 0:
488
+ return (
489
+ gr.update(choices=[], value=None),
490
+ gr.update(value=None, interactive=False),
491
+ )
492
+
493
+ selected_value = object_choices[-1] if default_latest else object_choices[0]
494
+ selected_path = generated_object_map[selected_value]
495
+
496
+ return (
497
+ gr.update(choices=object_choices, value=selected_value),
498
+ gr.update(value=selected_path, interactive=True),
499
+ )
500
+
501
def flush_finished_exports(status_text: str):
    """Fold every finished GLB-export future into the running scene.

    For each completed future: record the instance GLB path on its
    result entry, append the mesh to the accumulated scene, and
    re-export the combined scene GLB so the viewer stays current.

    Returns a 5-tuple of UI values (scene path, stream HTML, scene
    download update, selector update, single-download update) when
    anything changed, otherwise None.
    """
    nonlocal current_scene_path, trimeshes, pending_exports

    any_update = False
    finished_items = []

    # Snapshot finished items first; removing from pending_exports while
    # iterating it would skip entries.
    for item in pending_exports:
        if item["future"].done():
            finished_items.append(item)

    for item in finished_items:
        pending_exports.remove(item)

        result_index = item["result_index"]
        object_label = item["object_label"]
        future = item["future"]

        try:
            instance_glb_path, glb = future.result()
        except Exception as e:
            # An export failure only marks this one object; generation of
            # the remaining objects continues.
            print(f"[export_glb][error] instance={item['instance_name']}: {e}")
            results[result_index]["status_text"] = "GLB export failed"
            any_update = True
            continue

        results[result_index]["glb_path"] = instance_glb_path
        results[result_index]["status_text"] = "GLB ready"
        generated_object_map[object_label] = instance_glb_path

        # Re-export the whole scene after every new object so each step
        # produces a distinct, viewable snapshot file.
        trimeshes.append(glb)
        current_scene_path = export_scene_glb(
            trimeshes=trimeshes,
            work_space=work_space,
            scene_name=f"{run_id}_scene_step_{len(trimeshes)}.glb",
        )
        any_update = True

    if any_update:
        selector_update, single_download_update = build_selector_and_download_updates(default_latest=True)
        return (
            current_scene_path,
            build_stream_html(status_text),
            gr.update(value=current_scene_path, interactive=(current_scene_path is not None)),
            selector_update,
            single_download_update,
        )

    return None
549
+
550
+ yield (
551
+ None,
552
+ build_stream_html("Generating..."),
553
+ gr.update(value=None, interactive=False),
554
+ gr.update(choices=[], value=None),
555
+ gr.update(value=None, interactive=False),
556
+ )
557
+
558
+ with ThreadPoolExecutor(max_workers=1) as executor:
559
+ for instance_name, object_mask in enumerate(mask_pack):
560
+ try:
561
+ flushed = flush_finished_exports("Generating...")
562
+ if flushed is not None:
563
+ yield flushed
564
+
565
+ est_depth = dpt_pack['depth'].to('cpu')
566
+ c2w = dpt_pack['c2w'].to('cpu')
567
+ K = dpt_pack['K'].to('cpu')
568
+
569
+ intrinsics = dpt_pack['K'].float().to(DEVICE)
570
+ extrinsics = copy.deepcopy(dpt_pack['c2w']).float().to(DEVICE)
571
+ extrinsics[:3, 1:3] *= -1
572
+
573
+ object_mask = object_mask > 0
574
+ instance_mask = np.logical_and(object_mask, depth_mask).astype(np.uint8)
575
+ valid_ratio = np.sum((instance_mask > 0).astype(np.float32)) / (H * W)
576
+ print(f'valid ratio of {instance_name}: {valid_ratio:.4f}')
577
+ if valid_ratio < VALID_RATIO_THRESHOLD:
578
+ continue
579
+
580
+ edge_mask = edge_mask_morph_gradient(instance_mask, kernel, 3)
581
+ fg_mask = (instance_mask > edge_mask).astype(np.uint8)
582
+ color_mask = fg_mask.astype(np.float32) + edge_mask.astype(np.float32) * 0.5
583
+
584
+ image = rgb_image
585
+ scene_image, scene_image_masked = process_scene_image(image, instance_mask, CROP_SIZE)
586
+ instance_image, instance_mask, instance_rays_o, instance_rays_d, instance_rays_c, \
587
+ instance_rays_t = process_instance_image(image, instance_mask, color_mask, est_depth, K, c2w, CROP_SIZE)
588
+
589
+ save_image(scene_image, os.path.join(work_space, f'input_scene_image_{instance_name}.png'))
590
+ save_image(scene_image_masked, os.path.join(work_space, f'input_scene_image_masked_{instance_name}.png'))
591
+ save_image(instance_image, os.path.join(work_space, f'input_instance_image_{instance_name}.png'))
592
+ save_image(
593
+ torch.cat([instance_image, instance_mask]),
594
+ os.path.join(work_space, f'input_instance_image_masked_{instance_name}.png')
595
+ )
596
+
597
+ pcd_points = (
598
+ instance_rays_o.to(DEVICE) +
599
+ instance_rays_d.to(DEVICE) * instance_rays_t[..., None].to(DEVICE)
600
+ ).detach().cpu().numpy()
601
+ pcd_colors = instance_rays_c
602
+
603
+ save_projected_colored_pcd(
604
+ pcd_points,
605
+ repeat(pcd_colors, 'n -> n c', c=3),
606
+ f"{work_space}/instance_est_depth_{instance_name}.ply"
607
+ )
608
+
609
+ outputs, coarse_trans, coarse_scale, fine_trans, fine_scale = pipeline.run(
610
+ torch.cat([instance_image, instance_mask]).to(DEVICE),
611
+ scene_image_masked=scene_image_masked.to(DEVICE),
612
+ seed=seed,
613
+ extrinsics=extrinsics.to(DEVICE),
614
+ intrinsics=intrinsics.to(DEVICE),
615
+ points=pcd_points,
616
+ points_mask=pcd_colors,
617
+ sparse_structure_sampler_params={
618
+ "steps": num_inference_steps,
619
+ "cfg_strength": guidance_scale,
620
+ "cfg_interval": [cfg_interval_start, cfg_interval_end],
621
+ "rescale_t": t_rescale
622
+ },
623
+ slat_sampler_params={
624
+ "steps": num_inference_steps,
625
+ "cfg_strength": guidance_scale,
626
+ "cfg_interval": [cfg_interval_start, cfg_interval_end],
627
+ "rescale_t": t_rescale
628
+ }
629
+ )
630
+
631
+ mp4_path = os.path.abspath(
632
+ os.path.join(work_space, f"{run_id}_instance_gs_fine_{instance_name}.mp4")
633
+ )
634
+ poster_path = os.path.abspath(
635
+ os.path.join(work_space, f"{run_id}_instance_gs_fine_{instance_name}.png")
636
+ )
637
+
638
+ video = render_utils.render_video(
639
+ outputs["gaussian"][0],
640
+ bg_color=(1.0, 1.0, 1.0)
641
+ )["color"]
642
+ imageio.mimsave(mp4_path, video, fps=30)
643
+ imageio.imwrite(poster_path, video[0])
644
+
645
+ object_label = f"Object {len(results) + 1}"
646
+ result_index = len(results)
647
+
648
+ results.append({
649
+ "name": object_label,
650
+ "mp4_path": mp4_path,
651
+ "poster_path": poster_path,
652
+ "glb_path": None,
653
+ "instance_index": instance_name,
654
+ "status_text": "Exporting GLB...",
655
+ })
656
+
657
+ # First update: surface the rendered video right away while keeping the current 3D scene unchanged
658
+ yield (
659
+ current_scene_path,
660
+ build_stream_html("Generating..."),
661
+ gr.update(value=current_scene_path, interactive=(current_scene_path is not None)),
662
+ gr.update(choices=[], value=None),
663
+ gr.update(value=None, interactive=False),
664
+ )
665
+
666
+ future = executor.submit(
667
+ export_single_glb_from_outputs,
668
+ outputs=outputs,
669
+ fine_scale=fine_scale,
670
+ fine_trans=fine_trans,
671
+ coarse_scale=coarse_scale,
672
+ coarse_trans=coarse_trans,
673
+ trans=trans,
674
+ scale=scale,
675
+ rot=rot,
676
+ work_space=work_space,
677
+ instance_name=instance_name,
678
+ run_id=run_id,
679
+ )
680
+
681
+ pending_exports.append({
682
+ "future": future,
683
+ "result_index": result_index,
684
+ "instance_name": instance_name,
685
+ "object_label": object_label,
686
+ })
687
+
688
+ flushed = flush_finished_exports("Generating...")
689
+ if flushed is not None:
690
+ yield flushed
691
+
692
+ except Exception as e:
693
+ print(e)
694
+
695
+ while len(pending_exports) > 0:
696
+ flushed = flush_finished_exports("Generating...")
697
+ if flushed is not None:
698
+ yield flushed
699
+ else:
700
+ time.sleep(0.2)
701
+
702
+ ready_items = [item for item in results if item["glb_path"] is not None]
703
+ if len(ready_items) > 0:
704
+ final_scene_path = export_scene_glb(
705
+ trimeshes=trimeshes,
706
+ work_space=work_space,
707
+ scene_name=f"{run_id}_scene_final.glb",
708
+ )
709
+
710
+ selector_update, single_download_update = build_selector_and_download_updates(default_latest=True)
711
+
712
+ yield (
713
+ final_scene_path,
714
+ build_stream_html("Finished"),
715
+ gr.update(value=final_scene_path, interactive=True),
716
+ selector_update,
717
+ single_download_update,
718
+ )
719
+ else:
720
+ yield (
721
+ None,
722
+ "<div style='padding: 8px 0;'><b>Status:</b> No valid object generated.</div>",
723
+ gr.update(value=None, interactive=False),
724
+ gr.update(choices=[], value=None),
725
+ gr.update(value=None, interactive=False),
726
+ )
727
+
728
def update_single_download(selected_name):
    """Point the single-GLB download button at the selected object's file.

    Disables the button when nothing is selected or the selection has no
    exported GLB registered in ``generated_object_map``.
    """
    global generated_object_map

    has_valid_choice = selected_name is not None and selected_name in generated_object_map
    if not has_valid_choice:
        return gr.update(value=None, interactive=False)

    return gr.update(value=generated_object_map[selected_name], interactive=True)
735
+
736
# Demo: three-step pipeline UI — (1) segmentation, (2) depth estimation,
# (3) per-object 3D generation streamed into a combined scene GLB.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)

    with gr.Column():
        with gr.Row():
            # Input image with point/box prompts for segmentation.
            image_prompts = ImagePrompter(label="Input Image", type="pil")
            seg_image = gr.Image(
                label="Segmentation Result", type="pil", format="png"
            )
            with gr.Column():
                with gr.Accordion("Segmentation Settings", open=True):
                    polygon_refinement = gr.Checkbox(label="Polygon Refinement", value=False)
                seg_button = gr.Button("Run Segmentation (step 1)")
                dpt_button = gr.Button("Run Depth estimation (step 2)", variant="primary")
        with gr.Row():
            dpt_model_output = gr.Model3D(label="Estimated depth map", interactive=False)
            model_output = gr.Model3D(label="Generated GLB", interactive=False)
            with gr.Column():
                with gr.Accordion("Generation Settings", open=True):
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=42,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=25,
                    )
                    with gr.Row():
                        # Classifier-free guidance is applied only inside
                        # the [start, end] fraction of sampling steps.
                        cfg_interval_start = gr.Slider(
                            label="CFG interval start",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            value=0.8,
                        )
                        cfg_interval_end = gr.Slider(
                            label="CFG interval end",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            value=1.0,
                        )
                    t_rescale = gr.Slider(
                        label="t rescale factor",
                        minimum=1.0,
                        maximum=5.0,
                        step=0.1,
                        value=5.0,
                    )
                    guidance_scale = gr.Slider(
                        label="CFG scale",
                        minimum=0.0,
                        maximum=10.0,
                        step=0.1,
                        value=5.0,
                    )
                # Disabled until depth estimation has produced its output.
                gen_button = gr.Button("Run Generation (step 3)", variant="primary", interactive=False)
                download_glb = gr.DownloadButton(label="Download scene GLB", interactive=False)
                with gr.Row():
                    object_selector = gr.Dropdown(label="Choose instance: ")
                    download_single_glb = gr.DownloadButton(label="Download single GLB", interactive=False)

        # Live HTML cards for each generated object (video + status).
        stream_output = gr.HTML(label="Generated Objects Stream")
        with gr.Row():
            # NOTE(review): run_generation yields 5-tuples (scene, stream HTML,
            # scene download, selector, single download) but only 3 outputs are
            # declared here — confirm this mismatch is harmless with
            # cache_examples=False (fn is not invoked on click by default).
            gr.Examples(
                examples=EXAMPLES,
                fn=run_generation,
                inputs=[image_prompts, seg_image, seed, randomize_seed, num_inference_steps, guidance_scale, cfg_interval_start, cfg_interval_end, t_rescale],
                outputs=[model_output, download_glb, seed],
                cache_examples=False,
            )

    # Step 1: segmentation; on completion unlock the depth-estimation button.
    seg_button.click(
        run_segmentation,
        inputs=[
            image_prompts,
            polygon_refinement,
        ],
        outputs=[seg_image],
    ).then(lambda: gr.Button(interactive=True), outputs=[dpt_button])

    # Step 2: depth estimation; on completion unlock the generation button.
    dpt_button.click(
        run_depth_estimation,
        inputs=[
            image_prompts,
            seg_image
        ],
        outputs=[dpt_model_output],
    ).then(lambda: gr.Button(interactive=True), outputs=[gen_button])

    # Step 3: streaming generation; outputs match run_generation's yields.
    gen_button.click(
        run_generation,
        inputs=[
            image_prompts,
            seg_image,
            seed,
            randomize_seed,
            num_inference_steps,
            guidance_scale,
            cfg_interval_start,
            cfg_interval_end,
            t_rescale
        ],
        outputs=[model_output,
                 stream_output,
                 download_glb,
                 object_selector,
                 download_single_glb],
    )

    # Re-point the single-object download whenever the selection changes.
    object_selector.change(
        update_single_download,
        inputs=[object_selector],
        outputs=[download_single_glb],
    )

demo.launch(allowed_paths=[TMP_DIR, EXAMPLE_DIR])
 
requirements.txt ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ packaging
2
+ wheel
3
+ pybind11
4
+ ninja
5
+ Cython
6
+ torch==2.4.0+cu118
7
+ torchvision==0.19.0+cu118
8
+ pillow
9
+ imageio
10
+ imageio-ffmpeg
11
+ tqdm
12
+ easydict
13
+ opencv-python-headless
14
+ scipy
15
+ rembg
16
+ onnxruntime
17
+ trimesh
18
+ open3d
19
+ xatlas
20
+ pyvista
21
+ pymeshfix
22
+ igraph
23
+ transformers
24
+ icecream
25
+ plyfile
26
+ pycocotools
27
+ shapely
28
+ git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8
29
+ flash-attn
30
+ kaolin==0.17.0
31
+ spconv-cu118
32
+ gradio==4.44.1
33
+ gradio_image_prompter
scripts/grounding_sam.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/VAST-AI-Research/MIDI-3D
2
+ # Original license: Apache-2.0 license
3
+ # Copyright (c) the MIDI-3D authors
4
+
5
+ import argparse
6
+ import os
7
+ import random
8
+ from dataclasses import dataclass
9
+ from typing import Any, Dict, List, Optional, Tuple, Union
10
+
11
+ import cv2
12
+ import numpy as np
13
+ import requests
14
+ import torch
15
+ from PIL import Image
16
+ from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
17
+
18
+
19
def create_palette():
    """Build a flat 768-value (256-entry) RGB palette for labels 0-23.

    Entries beyond label 23 are zero-filled, so unknown labels render
    as black.
    """
    # (R, G, B) triplets for labels 0-23, flattened below into the
    # PIL "P"-mode layout [R0, G0, B0, R1, G1, B1, ...].
    label_colors = [
        (0, 0, 0),        # Label 0 (black)
        (255, 0, 0),      # Label 1 (red)
        (0, 255, 0),      # Label 2 (green)
        (0, 0, 255),      # Label 3 (blue)
        (255, 255, 0),    # Label 4 (yellow)
        (255, 0, 255),    # Label 5 (magenta)
        (0, 255, 255),    # Label 6 (cyan)
        (128, 0, 0),      # Label 7 (dark red)
        (0, 128, 0),      # Label 8 (dark green)
        (0, 0, 128),      # Label 9 (dark blue)
        (128, 128, 0),    # Label 10
        (128, 0, 128),    # Label 11
        (0, 128, 128),    # Label 12
        (64, 0, 0),       # Label 13
        (0, 64, 0),       # Label 14
        (0, 0, 64),       # Label 15
        (64, 64, 0),      # Label 16
        (64, 0, 64),      # Label 17
        (0, 64, 64),      # Label 18
        (192, 192, 192),  # Label 19 (light gray)
        (128, 128, 128),  # Label 20 (gray)
        (255, 165, 0),    # Label 21 (orange)
        (75, 0, 130),     # Label 22 (indigo)
        (238, 130, 238),  # Label 23 (violet)
    ]
    flat = [channel for color in label_colors for channel in color]
    # Extend the palette to have 768 values (256 * 3)
    flat.extend([0] * (768 - len(flat)))
    return flat
98
+
99
+
100
# Module-level default palette shared by callers that don't need a copy.
PALETTE = create_palette()
101
+
102
+
103
+ # Result Utils
104
@dataclass
class BoundingBox:
    """Axis-aligned box in pixel coordinates (top-left / bottom-right)."""

    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @property
    def xyxy(self) -> List[float]:
        """Return the box as a flat [xmin, ymin, xmax, ymax] list."""
        return list((self.xmin, self.ymin, self.xmax, self.ymax))
114
+
115
+
116
@dataclass
class DetectionResult:
    """One detection: confidence score, text label, box, and optional mask."""

    score: Optional[float] = None
    label: Optional[str] = None
    box: Optional[BoundingBox] = None
    # Fixed annotation: np.ndarray is the array *type*; np.array is a
    # factory function and is not valid as a type hint.
    mask: Optional[np.ndarray] = None

    @classmethod
    def from_dict(cls, detection_dict: Dict) -> "DetectionResult":
        """Build a DetectionResult from a grounding-DINO pipeline result dict."""
        box = detection_dict["box"]
        return cls(
            score=detection_dict["score"],
            label=detection_dict["label"],
            box=BoundingBox(
                xmin=box["xmin"],
                ymin=box["ymin"],
                xmax=box["xmax"],
                ymax=box["ymax"],
            ),
        )
135
+
136
+
137
+ # Utils
138
def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
    """Return the vertex list of the largest external contour of a binary mask."""
    # External contours only; compressed point chains are sufficient here.
    contours, _ = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    # Keep only the biggest region; smaller blobs are treated as noise.
    biggest = max(contours, key=cv2.contourArea)

    # OpenCV returns shape (N, 1, 2); flatten to plain [x, y] pairs.
    return biggest.reshape(-1, 2).tolist()
151
+
152
+
153
def polygon_to_mask(
    polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]
) -> np.ndarray:
    """Rasterize a polygon into a filled uint8 mask.

    Args:
        polygon: (x, y) vertices of the polygon.
        image_shape: (height, width) of the output mask.

    Returns:
        uint8 mask with the polygon interior set to 255, elsewhere 0.
    """
    canvas = np.zeros(image_shape, dtype=np.uint8)
    vertices = np.array(polygon, dtype=np.int32)
    # Fill with white (255) directly on the canvas in place.
    cv2.fillPoly(canvas, [vertices], color=(255,))
    return canvas
176
+
177
+
178
def load_image(image_str: str) -> Image.Image:
    """Open an image from a URL or a local path and convert it to RGB."""
    if image_str.startswith("http"):
        # Stream the response body straight into PIL without buffering.
        raw = requests.get(image_str, stream=True).raw
        return Image.open(raw).convert("RGB")
    return Image.open(image_str).convert("RGB")
185
+
186
+
187
def get_boxes(results: List["DetectionResult"]) -> List[List[List[float]]]:
    """Collect xyxy boxes from detections, batched for the SAM processor.

    Fixed annotation: the parameter is a *list* of DetectionResult (the
    function iterates it), not a single DetectionResult. The extra outer
    list is the single-image batch dimension expected by
    ``processor(..., input_boxes=...)``.
    """
    return [[result.box.xyxy for result in results]]
194
+
195
+
196
def refine_masks(
    masks: torch.BoolTensor, polygon_refinement: bool = False
) -> List[np.ndarray]:
    """Convert (N, C, H, W) mask tensors into a list of uint8 HxW masks.

    Channels are averaged and thresholded at 0. With
    ``polygon_refinement`` each mask is replaced by its filled largest
    contour, removing holes and stray blobs.
    """
    channel_mean = masks.cpu().float().permute(0, 2, 3, 1).mean(axis=-1)
    binary = (channel_mean > 0).int().numpy().astype(np.uint8)
    refined = list(binary)

    if polygon_refinement:
        refined = [
            polygon_to_mask(mask_to_polygon(m), m.shape) for m in refined
        ]

    return refined
214
+
215
+
216
+ # Post-processing Utils
217
def generate_colored_segmentation(label_image):
    """Wrap a 2-D label array in a palettized ("P"-mode) PIL image."""
    paletted = Image.fromarray(label_image.astype(np.uint8), mode="P")
    # Attach the shared label palette so each label id maps to its color.
    paletted.putpalette(create_palette())
    return paletted
226
+
227
+
228
def plot_segmentation(image, detections):
    """Render detections as a palettized label map sized like ``image``.

    Pixel value i marks the i-th detection (1-based); later detections
    overwrite earlier ones where their masks overlap.
    """
    # PIL size is (W, H); the numpy map wants (H, W).
    label_map = np.zeros(image.size[::-1], dtype=np.uint8)
    for index, detection in enumerate(detections, start=1):
        label_map[detection.mask > 0] = index
    return generate_colored_segmentation(label_map)
235
+
236
+
237
+ # Grounded SAM
238
def prepare_model(
    device: str = "cuda",
    detector_id: Optional[str] = None,
    segmenter_id: Optional[str] = None,
):
    """Load the zero-shot detector and the SAM mask-generation model.

    Defaults to "IDEA-Research/grounding-dino-tiny" and
    "facebook/sam-vit-base" when ids are omitted; both models are
    placed on ``device``. Returns (object_detector, processor,
    segmentator).
    """
    detector_id = (
        detector_id if detector_id is not None else "IDEA-Research/grounding-dino-tiny"
    )
    object_detector = pipeline(
        model=detector_id, task="zero-shot-object-detection", device=device
    )

    segmenter_id = segmenter_id if segmenter_id is not None else "facebook/sam-vit-base"
    processor = AutoProcessor.from_pretrained(segmenter_id)
    segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)

    return object_detector, processor, segmentator
255
+
256
+
257
def detect(
    object_detector: Any,
    image: Image.Image,
    labels: List[str],
    threshold: float = 0.3,
) -> List[Dict[str, Any]]:
    """
    Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion.
    """
    # Grounding DINO prompts are expected to end with a period.
    prompts = [label if label.endswith(".") else label + "." for label in labels]

    raw_results = object_detector(image, candidate_labels=prompts, threshold=threshold)
    return [DetectionResult.from_dict(item) for item in raw_results]
272
+
273
+
274
def segment(
    processor: Any,
    segmentator: Any,
    image: Image.Image,
    boxes: Optional[List[List[List[float]]]] = None,
    detection_results: Optional[List[Dict[str, Any]]] = None,
    polygon_refinement: bool = False,
) -> List[DetectionResult]:
    """
    Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.

    Either ``boxes`` or ``detection_results`` must be supplied; when only
    detections are given, their boxes are extracted via ``get_boxes``.
    Masks are written back onto the detections in order.
    """
    if detection_results is None and boxes is None:
        raise ValueError(
            "Either detection_results or detection_boxes must be provided."
        )

    if boxes is None:
        boxes = get_boxes(detection_results)

    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(
        segmentator.device, segmentator.dtype
    )

    outputs = segmentator(**inputs)
    # post_process_masks returns one entry per image; this is a
    # single-image batch, hence the trailing [0].
    masks = processor.post_process_masks(
        masks=outputs.pred_masks,
        original_sizes=inputs.original_sizes,
        reshaped_input_sizes=inputs.reshaped_input_sizes,
    )[0]

    masks = refine_masks(masks, polygon_refinement)

    if detection_results is None:
        detection_results = [DetectionResult() for _ in masks]

    # NOTE(review): assumes detections and masks align one-to-one; zip
    # silently truncates if their lengths differ — confirm upstream.
    for detection_result, mask in zip(detection_results, masks):
        detection_result.mask = mask

    return detection_results
313
+
314
+
315
def grounded_segmentation(
    object_detector,
    processor,
    segmentator,
    image: Union[Image.Image, str],
    labels: Union[str, List[str]],
    threshold: float = 0.3,
    polygon_refinement: bool = False,
) -> Tuple[np.ndarray, List[DetectionResult], Image.Image]:
    """Detect ``labels`` in ``image``, segment each detection, render a map.

    Accepts an image path/URL or a PIL image, and labels either as a
    comma-separated string or a list. Returns the image as an array,
    the detections (with masks attached), and a palettized map.
    """
    # Normalize flexible inputs.
    if isinstance(image, str):
        image = load_image(image)
    if isinstance(labels, str):
        labels = labels.split(",")

    detections = detect(object_detector, image, labels, threshold)
    detections = segment(
        processor,
        segmentator,
        image,
        detection_results=detections,
        polygon_refinement=polygon_refinement,
    )

    segmentation_map = plot_segmentation(image, detections)
    return np.array(image), detections, segmentation_map
341
+
342
+
343
if __name__ == "__main__":
    # CLI entry point: segment the given labels in an image and save
    # the colored segmentation map to <output>/segmentation.png.
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", type=str, required=True)
    parser.add_argument("--labels", type=str, nargs="+", required=True)
    parser.add_argument("--output", type=str, default="./", help="Output directory")
    parser.add_argument("--threshold", type=float, default=0.3)
    parser.add_argument(
        "--detector_id", type=str, default="IDEA-Research/grounding-dino-base"
    )
    parser.add_argument("--segmenter_id", type=str, default="facebook/sam-vit-base")
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    object_detector, processor, segmentator = prepare_model(
        device=device, detector_id=args.detector_id, segmenter_id=args.segmenter_id
    )

    # Polygon refinement is always on for the CLI to clean up mask holes.
    image_array, detections, seg_map_pil = grounded_segmentation(
        object_detector,
        processor,
        segmentator,
        image=args.image,
        labels=args.labels,
        threshold=args.threshold,
        polygon_refinement=True,
    )

    os.makedirs(args.output, exist_ok=True)
    seg_map_pil.save(os.path.join(args.output, "segmentation.png"))
scripts/grounding_sam2.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/Mengmouxu/SceneGen
2
+ # Original license: MIT license
3
+ # Copyright (c) the SceneGen authors
4
+
5
+ import argparse
6
+ import os
7
+ import random
8
+ from dataclasses import dataclass
9
+ from typing import Any, Dict, List, Optional, Tuple, Union
10
+
11
+ import cv2
12
+ import numpy as np
13
+ import requests
14
+ import torch
15
+ from PIL import Image
16
+ from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
17
+ from contextlib import nullcontext
18
+
19
+
20
def create_palette():
    """Build a flat 768-value (256-entry) RGB palette for labels 0-23.

    Entries beyond label 23 are zero-filled, so unknown labels render
    as black.
    """
    # (R, G, B) triplets for labels 0-23, flattened below into the
    # PIL "P"-mode layout [R0, G0, B0, R1, G1, B1, ...].
    label_colors = [
        (0, 0, 0),        # Label 0 (black)
        (255, 0, 0),      # Label 1 (red)
        (0, 255, 0),      # Label 2 (green)
        (0, 0, 255),      # Label 3 (blue)
        (255, 255, 0),    # Label 4 (yellow)
        (255, 0, 255),    # Label 5 (magenta)
        (0, 255, 255),    # Label 6 (cyan)
        (128, 0, 0),      # Label 7 (dark red)
        (0, 128, 0),      # Label 8 (dark green)
        (0, 0, 128),      # Label 9 (dark blue)
        (128, 128, 0),    # Label 10
        (128, 0, 128),    # Label 11
        (0, 128, 128),    # Label 12
        (64, 0, 0),       # Label 13
        (0, 64, 0),       # Label 14
        (0, 0, 64),       # Label 15
        (64, 64, 0),      # Label 16
        (64, 0, 64),      # Label 17
        (0, 64, 64),      # Label 18
        (192, 192, 192),  # Label 19 (light gray)
        (128, 128, 128),  # Label 20 (gray)
        (255, 165, 0),    # Label 21 (orange)
        (75, 0, 130),     # Label 22 (indigo)
        (238, 130, 238),  # Label 23 (violet)
    ]
    flat = [channel for color in label_colors for channel in color]
    # Extend the palette to have 768 values (256 * 3)
    flat.extend([0] * (768 - len(flat)))
    return flat
99
+
100
+
101
# Module-level default palette shared by callers that don't need a copy.
PALETTE = create_palette()
102
+
103
+
104
+ # Result Utils
105
@dataclass
class BoundingBox:
    """Axis-aligned box in pixel coordinates (top-left / bottom-right)."""

    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @property
    def xyxy(self) -> List[float]:
        """Return the box as a flat [xmin, ymin, xmax, ymax] list."""
        return list((self.xmin, self.ymin, self.xmax, self.ymax))
115
+
116
+
117
@dataclass
class DetectionResult:
    """One detection: confidence score, text label, box, and optional mask."""

    score: Optional[float] = None
    label: Optional[str] = None
    box: Optional[BoundingBox] = None
    # Fixed annotation: np.ndarray is the array *type*; np.array is a
    # factory function and is not valid as a type hint.
    mask: Optional[np.ndarray] = None

    @classmethod
    def from_dict(cls, detection_dict: Dict) -> "DetectionResult":
        """Build a DetectionResult from a grounding-DINO pipeline result dict."""
        box = detection_dict["box"]
        return cls(
            score=detection_dict["score"],
            label=detection_dict["label"],
            box=BoundingBox(
                xmin=box["xmin"],
                ymin=box["ymin"],
                xmax=box["xmax"],
                ymax=box["ymax"],
            ),
        )
136
+
137
+
138
+ # Utils
139
def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
    """Return the vertex list of the largest external contour of a binary mask."""
    # External contours only; compressed point chains are sufficient here.
    contours, _ = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    # Keep only the biggest region; smaller blobs are treated as noise.
    biggest = max(contours, key=cv2.contourArea)

    # OpenCV returns shape (N, 1, 2); flatten to plain [x, y] pairs.
    return biggest.reshape(-1, 2).tolist()
152
+
153
+
154
def polygon_to_mask(
    polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]
) -> np.ndarray:
    """Rasterize a polygon into a filled uint8 mask.

    Args:
        polygon: (x, y) vertices of the polygon.
        image_shape: (height, width) of the output mask.

    Returns:
        uint8 mask with the polygon interior set to 255, elsewhere 0.
    """
    canvas = np.zeros(image_shape, dtype=np.uint8)
    vertices = np.array(polygon, dtype=np.int32)
    # Fill with white (255) directly on the canvas in place.
    cv2.fillPoly(canvas, [vertices], color=(255,))
    return canvas
177
+
178
+
179
def load_image(image_str: str) -> Image.Image:
    """Open an image from a URL or a local path and convert it to RGB."""
    if image_str.startswith("http"):
        # Stream the response body straight into PIL without buffering.
        raw = requests.get(image_str, stream=True).raw
        return Image.open(raw).convert("RGB")
    return Image.open(image_str).convert("RGB")
186
+
187
+
188
def get_boxes(results: List["DetectionResult"]) -> List[List[List[float]]]:
    """Collect xyxy boxes from detections, batched for the segmenter.

    Fixed annotation: the parameter is a *list* of DetectionResult (the
    function iterates it), not a single DetectionResult. The extra outer
    list is the single-image batch dimension, which the SAM2 `segment`
    helper flattens back out.
    """
    return [[result.box.xyxy for result in results]]
195
+
196
+
197
def refine_masks(
    masks: torch.BoolTensor, polygon_refinement: bool = False
) -> List[np.ndarray]:
    """Convert (N, C, H, W) mask tensors into a list of uint8 HxW masks.

    Channels are averaged and thresholded at 0. With
    ``polygon_refinement`` each mask is replaced by its filled largest
    contour, removing holes and stray blobs.
    """
    channel_mean = masks.cpu().float().permute(0, 2, 3, 1).mean(axis=-1)
    binary = (channel_mean > 0).int().numpy().astype(np.uint8)
    refined = list(binary)

    if polygon_refinement:
        refined = [
            polygon_to_mask(mask_to_polygon(m), m.shape) for m in refined
        ]

    return refined
215
+
216
+
217
+ # Post-processing Utils
218
def generate_colored_segmentation(label_image):
    """Wrap a 2-D label array in a palettized ("P"-mode) PIL image."""
    paletted = Image.fromarray(label_image.astype(np.uint8), mode="P")
    # Attach the shared label palette so each label id maps to its color.
    paletted.putpalette(create_palette())
    return paletted
227
+
228
+
229
def plot_segmentation(image, detections):
    """Render detections as a palettized label map sized like ``image``.

    Pixel value i marks the i-th detection (1-based); later detections
    overwrite earlier ones where their masks overlap.
    """
    # PIL size is (W, H); the numpy map wants (H, W).
    label_map = np.zeros(image.size[::-1], dtype=np.uint8)
    for index, detection in enumerate(detections, start=1):
        label_map[detection.mask > 0] = index
    return generate_colored_segmentation(label_map)
236
+
237
+
238
+ # Grounded SAM
239
def prepare_model(
    device: str = "cuda",
    detector_id: Optional[str] = None,
    segmenter_id: Optional[str] = None,
):
    """Load the zero-shot detector and the SAM mask-generation model.

    Defaults to "IDEA-Research/grounding-dino-tiny" and
    "facebook/sam-vit-base" when ids are omitted; both models are
    placed on ``device``. Returns (object_detector, processor,
    segmentator).
    """
    detector_id = (
        detector_id if detector_id is not None else "IDEA-Research/grounding-dino-tiny"
    )
    object_detector = pipeline(
        model=detector_id, task="zero-shot-object-detection", device=device
    )

    segmenter_id = segmenter_id if segmenter_id is not None else "facebook/sam-vit-base"
    processor = AutoProcessor.from_pretrained(segmenter_id)
    segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)

    return object_detector, processor, segmentator
256
+
257
+
258
def detect(
    object_detector: Any,
    image: Image.Image,
    labels: List[str],
    threshold: float = 0.3,
) -> List[Dict[str, Any]]:
    """
    Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion.
    """
    # Grounding DINO prompts are expected to end with a period.
    prompts = [label if label.endswith(".") else label + "." for label in labels]

    raw_results = object_detector(image, candidate_labels=prompts, threshold=threshold)
    return [DetectionResult.from_dict(item) for item in raw_results]
273
+
274
+
275
def segment(
    predictor: Any,
    image: Image.Image,
    boxes: Optional[List[List[List[float]]]] = None,
    detection_results: Optional[List[Dict[str, Any]]] = None,
    polygon_refinement: bool = False,
) -> List[DetectionResult]:
    """
    Use SAM2 predictor to generate masks given an image + a set of bounding boxes.

    Either ``boxes`` or ``detection_results`` must be supplied. One mask
    is predicted per box (multimask_output=False) and written back onto
    the detections in order.
    """

    if detection_results is None and boxes is None:
        raise ValueError("Either detection_results or detection_boxes must be provided.")

    # Build boxes from detections if not provided
    if boxes is None:
        boxes = get_boxes(detection_results)
        # Flatten potential [[...], ...] -> [...] (get_boxes adds a batch dim)
        if isinstance(boxes, list) and len(boxes) == 1 and isinstance(boxes[0], list):
            boxes = boxes[0]

    # Ensure image is a numpy RGB array (H, W, 3)
    if isinstance(image, Image.Image):
        np_image = np.array(image.convert("RGB"))
    else:
        np_image = np.array(image)

    # Resolve device: prefer predictor.device, then its wrapped model's
    # parameters, finally fall back to CUDA-if-available.
    device = getattr(predictor, "device", None)
    if device is None:
        model = getattr(predictor, "model", None)
        if model is not None:
            device = next(model.parameters()).device
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Prepare autocast context only for CUDA
    amp_ctx = torch.autocast("cuda", dtype=torch.bfloat16) if device.type == "cuda" else nullcontext()

    # Run predictor
    with torch.inference_mode():
        with amp_ctx:
            predictor.set_image(np_image)

            # Boxes to tensor
            boxes_t = torch.tensor(boxes, dtype=torch.float32, device=device)
            # Transform boxes if predictor exposes a transform like SAM/SAM2
            if hasattr(predictor, "transform") and hasattr(predictor.transform, "apply_boxes_torch"):
                boxes_in = predictor.transform.apply_boxes_torch(boxes_t, np_image.shape[:2])
            else:
                boxes_in = boxes_t

            # Predict masks for boxes; request single mask per box
            masks, scores, _ = predictor.predict(
                box=boxes_in,
                multimask_output=False
            )

    # Normalize masks to numpy [N, H, W] boolean
    if isinstance(masks, torch.Tensor):
        masks_np = masks.detach().cpu().numpy()
    else:
        masks_np = np.asarray(masks)

    if masks_np.ndim == 4 and masks_np.shape[1] == 1:
        masks_np = masks_np[:, 0]  # [N, 1, H, W] -> [N, H, W]
    masks_np = (masks_np > 0).astype(np.uint8)

    # Reuse refine_masks to optionally polygon-refine
    masks_torch = torch.from_numpy(masks_np).unsqueeze(1).to(torch.bool)  # [N,1,H,W]
    masks_list = refine_masks(masks_torch, polygon_refinement)

    if detection_results is None:
        detection_results = [DetectionResult() for _ in masks_list]

    # NOTE(review): assumes detections and masks align one-to-one; zip
    # silently truncates if their lengths differ — confirm upstream.
    for detection_result, mask in zip(detection_results, masks_list):
        detection_result.mask = mask

    return detection_results
threeDFixer/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ from . import models
7
+ from . import modules
8
+ from . import pipelines
9
+ from . import renderers
10
+ from . import representations
11
+ from . import utils
threeDFixer/datasets/__init__.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file is modified from TRELLIS:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+ # Modifications Copyright (c) 2026 Ze-Xin Yin, Robot labs of Horizon Robotics, and D-Robotics.
6
+
7
+ import importlib
8
+
9
+ __attributes = {
10
+ 'SparseStructure': 'sparse_structure',
11
+
12
+ 'SparseFeat2Render': 'sparse_feat2render',
13
+ 'SLat2Render':'structured_latent2render',
14
+ 'Slat2RenderGeo':'structured_latent2render',
15
+
16
+ 'SparseStructureLatent': 'sparse_structure_latent',
17
+ 'TextConditionedSparseStructureLatent': 'sparse_structure_latent',
18
+ 'ImageConditionedSparseStructureLatent': 'sparse_structure_latent',
19
+
20
+ 'SLat': 'structured_latent',
21
+ 'TextConditionedSLat': 'structured_latent',
22
+ 'ImageConditionedSLat': 'structured_latent',
23
+
24
+ 'ImageConditionedSparseStructureLatentRandRot': 'sparse_structure_latent_random_rot',
25
+ 'ImageConditionedSLatRandRot': 'structured_latent_random_rot',
26
+ 'SparseFeat2RenderRandRot': 'sparse_feat2render_random_rot',
27
+ 'Slat2RenderGeoRandRot': 'structured_latent2render_random_rot',
28
+
29
+ 'ObjectImageConditionedSparseStructureVoxel': 'scene_sparse_structure_latent_obj_pretrain',
30
+ 'SceneImageConditionedVoxel': 'scene_sparse_structure_latent',
31
+ 'SceneConditionedSLat': 'scene_structured_latent',
32
+ }
33
+
34
+ __submodules = []
35
+
36
+ __all__ = list(__attributes.keys()) + __submodules
37
+
38
def __getattr__(name):
    """Lazily resolve dataset classes/submodules on first access (PEP 562).

    Keeps `import threeDFixer.datasets` cheap: the heavy dataset modules are
    only imported when one of the names in `__attributes`/`__submodules` is
    actually requested.
    """
    # After the first resolution the object is cached in globals(), so this
    # hook is never hit again for the same name.
    if name not in globals():
        if name in __attributes:
            # Public class name -> defining submodule; import it and pull the
            # attribute out.
            module_name = __attributes[name]
            module = importlib.import_module(f".{module_name}", __name__)
            globals()[name] = getattr(module, name)
        elif name in __submodules:
            # Whole submodules are exposed as-is.
            module = importlib.import_module(f".{name}", __name__)
            globals()[name] = module
        else:
            raise AttributeError(f"module {__name__} has no attribute {name}")
    return globals()[name]
50
+
51
+
52
+ # For Pylance
53
+ if __name__ == '__main__':
54
+ from .sparse_structure import SparseStructure
55
+
56
+ from .sparse_feat2render import SparseFeat2Render
57
+ from .structured_latent2render import (
58
+ SLat2Render,
59
+ Slat2RenderGeo,
60
+ )
61
+
62
+ from .sparse_structure_latent import (
63
+ SparseStructureLatent,
64
+ TextConditionedSparseStructureLatent,
65
+ ImageConditionedSparseStructureLatent,
66
+ )
67
+
68
+ from .structured_latent import (
69
+ SLat,
70
+ TextConditionedSLat,
71
+ ImageConditionedSLat,
72
+ )
73
+
74
+ # rot mesh
75
+ from .sparse_structure_latent_random_rot import (
76
+ ImageConditionedSparseStructureLatentRandRot
77
+ )
78
+
79
+ # rot SLAT
80
+ from .structured_latent_random_rot import (
81
+ ImageConditionedSLatRandRot
82
+ )
83
+
84
+ # VAE gs dec
85
+ from .sparse_feat2render_random_rot import (
86
+ SparseFeat2RenderRandRot
87
+ )
88
+
89
+ # VAE mesh dec
90
+ from .structured_latent2render_random_rot import (
91
+ Slat2RenderGeoRandRot
92
+ )
93
+
94
+ # object-level pre-training
95
+ from .scene_sparse_structure_latent_obj_pretrain import (
96
+ ObjectImageConditionedSparseStructureVoxel
97
+ )
98
+
99
+ # scene-level training dataloader for stage 1
100
+ from .scene_sparse_structure_latent import (
101
+ SceneImageConditionedVoxel
102
+ )
103
+
104
+ # scene-level training dataloader for stage 2
105
+ from .scene_structured_latent import (
106
+ SceneConditionedSLat
107
+ )
threeDFixer/datasets/utils.py ADDED
@@ -0,0 +1,631 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: 2026 Ze-Xin Yin, Robot labs of Horizon Robotics, and D-Robotics
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ # See the LICENSE file in the project root for full license information.
4
+
5
+ import os
6
+ import json
7
+ import cv2
8
+ import torch
9
+ from PIL import Image
10
+ import imageio
11
+ import numpy as np
12
+ import open3d as o3d
13
+ from einops import rearrange
14
+
15
def voxelize_mesh(points, faces, clip_range_first=False, return_mask=True, resolution=64):
    """Voxelize a triangle mesh into a dense occupancy grid over [-0.5, 0.5]^3.

    Args:
        points: (N, 3) vertex positions, expected inside [-0.5, 0.5]^3.
        faces: (M, 3) triangle indices (array-like or an Open3D Vector3iVector).
        clip_range_first: clamp vertices into the unit cube before voxelizing.
        return_mask: also return the grid regrouped into 4x4x4 sub-blocks.
        resolution: grid side length (must be divisible by 4 when
            return_mask=True, because of the fixed 4x4x4 regrouping).

    Returns:
        ss: (1, R, R, R) long occupancy grid, or (ss, ss_mask) when
        return_mask=True, with ss_mask of shape (64, R/4, R/4, R/4).
    """
    if clip_range_first:
        points = np.clip(points, -0.5 + 1e-6, 0.5 - 1e-6)
    mesh = o3d.geometry.TriangleMesh()
    mesh.vertices = o3d.utility.Vector3dVector(points)
    # Use the portable o3d.utility alias: o3d.cuda.pybind.* only exists on
    # CUDA builds of Open3D and raised AttributeError on CPU-only installs.
    if isinstance(faces, o3d.utility.Vector3iVector):
        mesh.triangles = faces
    else:
        mesh.triangles = o3d.utility.Vector3iVector(faces)
    # Voxel size follows `resolution` (the original hard-coded 1/64 and 64,
    # silently ignoring the parameter); default resolution=64 is unchanged.
    voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(
        mesh, voxel_size=1 / resolution,
        min_bound=(-0.5, -0.5, -0.5), max_bound=(0.5, 0.5, 0.5))
    vertices = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()])
    assert np.all(vertices >= 0) and np.all(vertices < resolution), "Some vertices are out of bounds"
    # Voxel centers in [-0.5, 0.5)^3, then back to integer grid coordinates.
    vertices = (vertices + 0.5) / resolution - 0.5
    coords = ((torch.tensor(vertices) + 0.5) * resolution).int().contiguous()
    ss = torch.zeros(1, resolution, resolution, resolution, dtype=torch.long)
    ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
    if return_mask:
        # Regroup into 4x4x4 interleaved sub-grids (64 channels).
        ss_mask = rearrange(ss, 'c (x n1) (y n2) (z n3) -> (n1 n2 n3 c) x y z', n1=4, n2=4, n3=4).float()
        return ss, ss_mask
    else:
        return ss
36
+
37
def transform_vertices(vertices, ops, params):
    """Apply a sequence of named affine operations to vertices.

    Each entry of `ops` is either 'scale' (multiply by the matching param)
    or 'translation' (add the matching param), applied in order.
    """
    result = vertices
    for op_name, value in zip(ops, params):
        if op_name == 'scale':
            result = result * value
        elif op_name == 'translation':
            result = result + value
        else:
            raise NotImplementedError
    return result
46
+
47
def normalize_vertices(vertices, scale_factor=1.0):
    """Center vertices at the origin and scale the longest axis extent.

    With scale_factor=1.0 the result fits [-0.5, 0.5]; 2.0 fits [-0.25, 0.25].

    Returns:
        (normalized_vertices, translation, scale) so the transform can be
        inverted later.
    """
    lo = np.min(vertices, axis=0)
    hi = np.max(vertices, axis=0)
    trans_pos = (lo + hi)[None] / 2.0
    scale_pos = np.max(hi - lo) * scale_factor
    # Translate to the center, then divide by the (epsilon-guarded) extent.
    normalized = (vertices - trans_pos) * (1.0 / (scale_pos + 1e-6))
    return normalized, trans_pos, scale_pos
55
+
56
def renormalize_vertices(vertices, val_range=0.5, scale_factor=1.25):
    """Re-center/re-scale vertices only if any coordinate leaves ±val_range.

    In-range inputs are returned untouched; out-of-range inputs are centered
    and divided by scale_factor times the longest extent.
    """
    lo = np.min(vertices, axis=0)
    hi = np.max(vertices, axis=0)
    out_of_range = (lo < -val_range).any() or (hi > val_range).any()
    if out_of_range:
        center = (lo + hi)[None] / 2.0
        extent = np.max(hi - lo) * scale_factor
        vertices = (vertices - center) * (1.0 / (extent + 1e-6))
    return vertices
64
+
65
def rot_vertices(vertices, rot_angles, axis_list=['z']):
    """Rotate vertices about the origin by Euler angles applied in order.

    Each (angle, axis) pair from (rot_angles, axis_list) is applied as a
    single-axis rotation via Open3D, in sequence.
    """
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(vertices)
    for angle, axis in zip(rot_angles, axis_list):
        if axis == 'x':
            euler = (angle, 0, 0)
        elif axis == 'y':
            euler = (0, angle, 0)
        elif axis == 'z':
            euler = (0, 0, angle)
        else:
            raise NotImplementedError
        rotation = cloud.get_rotation_matrix_from_xyz(euler)
        cloud.rotate(rotation, center=(0., 0., 0.))
    return np.array(cloud.points)
86
+
87
+ def _rotmat_x(a: torch.Tensor) -> torch.Tensor:
88
+ # a: scalar tensor
89
+ ca, sa = torch.cos(a), torch.sin(a)
90
+ R = torch.stack([
91
+ torch.stack([torch.ones_like(a), torch.zeros_like(a), torch.zeros_like(a)]),
92
+ torch.stack([torch.zeros_like(a), ca, -sa]),
93
+ torch.stack([torch.zeros_like(a), sa, ca]),
94
+ ])
95
+ return R # [3,3]
96
+
97
+ def _rotmat_y(a: torch.Tensor) -> torch.Tensor:
98
+ ca, sa = torch.cos(a), torch.sin(a)
99
+ R = torch.stack([
100
+ torch.stack([ca, torch.zeros_like(a), sa]),
101
+ torch.stack([torch.zeros_like(a), torch.ones_like(a), torch.zeros_like(a)]),
102
+ torch.stack([-sa, torch.zeros_like(a), ca]),
103
+ ])
104
+ return R
105
+
106
+ def _rotmat_z(a: torch.Tensor) -> torch.Tensor:
107
+ ca, sa = torch.cos(a), torch.sin(a)
108
+ R = torch.stack([
109
+ torch.stack([ca, -sa, torch.zeros_like(a)]),
110
+ torch.stack([sa, ca, torch.zeros_like(a)]),
111
+ torch.stack([torch.zeros_like(a), torch.zeros_like(a), torch.ones_like(a)]),
112
+ ])
113
+ return R
114
+
115
def rot_vertices_torch(vertices, rot_angles, axis_list=('z',), center=(0.0, 0.0, 0.0)):
    """Torch port of `rot_vertices`: rotate points about `center`.

    Args:
        vertices: (N,3) numpy array or torch tensor.
        rot_angles: angles in radians, one per entry of axis_list.
        axis_list: axes ('x'/'y'/'z') applied in order.
        center: rotation pivot (defaults to the origin, matching Open3D use).

    Returns:
        torch.Tensor of shape (N,3).
    """
    pts = torch.as_tensor(vertices)
    device, dtype = pts.device, pts.dtype

    pivot = torch.tensor(center, device=device, dtype=dtype).view(1, 3)
    pts = pts - pivot  # move pivot to the origin

    # Apply rotations sequentially; Open3D treats points as row vectors,
    # so each step is v <- v @ R^T.
    builders = {'x': _rotmat_x, 'y': _rotmat_y, 'z': _rotmat_z}
    for angle, axis in zip(rot_angles, axis_list):
        make_rot = builders.get(axis)
        if make_rot is None:
            raise NotImplementedError(f"Unknown axis {axis}")
        rot = make_rot(torch.as_tensor(angle, device=device, dtype=dtype))
        pts = pts @ rot.T

    return pts + pivot
147
+
148
def get_instance_mask(instance_mask_path):
    """Decode a 16-bit instance-index image into per-pixel object ids.

    Ids were packed as id/100 * 65535 at render time (hard-coded cap of
    100 objects), so decoding rescales and rounds back to integer ids.

    Returns:
        (index_mask, instance_list): float id map and the unique uint8 ids.
    """
    raw = imageio.v3.imread(instance_mask_path)
    index_mask = np.rint(raw.astype(np.float32) / 65535 * 100.0)
    instance_list = np.unique(index_mask).astype(np.uint8)
    return index_mask, instance_list
153
+
154
def get_gt_depth(gt_depth_path, metadata):
    """Load a 16-bit normalized depth image and map it back to metric depth.

    The image stores depth normalized to [0,1]; metadata['depth'] carries
    the original min/max used for normalization.
    """
    normalized = imageio.v3.imread(gt_depth_path).astype(np.float32) / 65535.
    depth_min = metadata['depth']['min']
    depth_max = metadata['depth']['max']
    metric = normalized * (depth_max - depth_min) + depth_min
    return torch.from_numpy(metric).to(dtype=torch.float32)
159
+
160
def get_est_depth(est_depth_path):
    """Load an estimated-depth .npz (keys 'depth', 'mask').

    NaN/Inf depth pixels are zeroed and removed from the validity mask.

    Returns:
        (depth, mask): float32 torch depth and numpy boolean mask.
    """
    data = np.load(est_depth_path)
    depth = torch.from_numpy(data['depth']).to(dtype=torch.float32)
    bad = torch.logical_or(torch.isnan(depth), torch.isinf(depth))
    mask = np.logical_and(data['mask'], ~bad.detach().cpu().numpy())
    depth = torch.where(bad, 0.0, depth)
    return depth, mask
169
+
170
def get_mix_est_depth(est_depth_path, image_size):
    """Load estimated depth from heterogeneous depth estimators.

    The estimator is inferred from the file path ('MoGe', 'DAv2_',
    'ml-depth-pro', 'VGGT_1B'); each branch returns (depth, mask) where
    depth is a float32 torch tensor with NaN/Inf zeroed and mask is a numpy
    boolean validity mask. Paths matching none of the markers fall through
    and return None implicitly.

    Args:
        est_depth_path: path to the .npz produced by the estimator.
        image_size: target side length (only the VGGT branch resizes).
    """
    if 'MoGe' in est_depth_path:
        # MoGe ships its own validity mask alongside the depth.
        npz = np.load(est_depth_path)
        est_depth = npz['depth']
        est_depth_mask = npz['mask']
        est_depth = torch.from_numpy(est_depth).to(dtype=torch.float32)
        ivalid_mask = torch.logical_or(torch.isnan(est_depth), torch.isinf(est_depth))
        est_depth_mask = np.logical_and(est_depth_mask, ~ivalid_mask.detach().cpu().numpy())
        est_depth = torch.where(ivalid_mask, 0.0, est_depth)
        return est_depth, est_depth_mask
    elif 'DAv2_' in est_depth_path or 'ml-depth-pro' in est_depth_path:
        # These estimators provide no mask: derive one from finite pixels.
        npz = np.load(est_depth_path)
        est_depth = npz['depth']
        est_depth_mask = np.logical_not(np.logical_or(
            np.isnan(est_depth),
            np.isinf(est_depth),
        ))
        est_depth = torch.from_numpy(est_depth).to(dtype=torch.float32)
        ivalid_mask = torch.logical_or(torch.isnan(est_depth), torch.isinf(est_depth))
        est_depth_mask = np.logical_and(est_depth_mask, ~ivalid_mask.detach().cpu().numpy())
        est_depth = torch.where(ivalid_mask, 0.0, est_depth)
        return est_depth, est_depth_mask
    elif 'VGGT_1B' in est_depth_path:
        # VGGT: threshold its confidence map, then resize depth and mask to
        # image_size x image_size with NEAREST (depth is normalized to [0,1]
        # for the resize and de-normalized afterwards).
        npz = np.load(est_depth_path)
        est_depth = npz['depth']
        est_depth_mask = npz['depth_conf'] > 2.0  # confidence cutoff (hand-tuned)
        valid_depth_mask = np.logical_not(np.logical_or(
            np.isnan(est_depth),
            np.isinf(est_depth),
        ))
        est_depth_mask = np.logical_and(
            est_depth_mask,
            valid_depth_mask
        )
        est_depth = np.where(valid_depth_mask, est_depth, 0.0)

        depth_min, depth_max = np.min(est_depth), np.max(est_depth)
        est_depth = (est_depth - depth_min) / (depth_max - depth_min + 1e-6)
        est_depth = Image.fromarray(est_depth)
        est_depth = est_depth.resize((image_size, image_size), Image.Resampling.NEAREST)
        est_depth = torch.tensor(np.array(est_depth)).to(dtype=torch.float32)
        est_depth = est_depth * (depth_max - depth_min) + depth_min

        est_depth_mask = Image.fromarray(est_depth_mask.astype(np.float32))
        est_depth_mask = est_depth_mask.resize((image_size, image_size), Image.Resampling.NEAREST)
        est_depth_mask = np.array(est_depth_mask) > 0.5

        ivalid_mask = torch.logical_or(torch.isnan(est_depth), torch.isinf(est_depth))
        est_depth_mask = np.logical_and(est_depth_mask, ~ivalid_mask.detach().cpu().numpy())
        est_depth = torch.where(ivalid_mask, 0.0, est_depth)
        return est_depth, est_depth_mask
221
+
222
def lstsq_align_depth(est_depth, gt_depth, mask):
    """Scale-align estimated depth to ground truth via least squares.

    Solves for the single scalar s minimizing ||s * est - gt|| over the
    pixels selected by `mask`; returns est_depth * s. Falls back to s=1.0
    when the mask selects no pixels.
    """
    coords = torch.nonzero(mask)
    if coords.shape[0] == 0:
        return est_depth * 1.0
    gt_vals = gt_depth[coords[:, 0], coords[:, 1]]
    est_vals = est_depth[coords[:, 0], coords[:, 1]]
    solution = torch.linalg.lstsq(est_vals[None, :, None], gt_vals[None, :, None]).solution
    return est_depth * solution.item()
232
+
233
def get_cam_poses(frame_info, H, W):
    """Build pinhole intrinsics and camera-to-world pose from a frame entry.

    Args:
        frame_info: dict with 'camera_angle_x' (horizontal FoV, radians) and
            'transform_matrix' (4x4 camera-to-world), transforms.json style.
        H, W: image height and width in pixels.

    Returns:
        (K, c2w): float32 torch tensors, K of shape (3,3).
    """
    fov_x = float(frame_info['camera_angle_x'])
    # Focal length from the horizontal field of view.
    focal = 0.5 * W / np.tan(0.5 * fov_x)
    intrinsics = np.array([
        [focal, 0, 0.5 * W],
        [0, focal, 0.5 * H],
        [0, 0, 1],
    ])
    K = torch.from_numpy(intrinsics).float()
    c2w = torch.from_numpy(np.array(frame_info['transform_matrix'])).float()
    return K, c2w
244
+
245
def edge_mask_morph_gradient(mask, kernel, iterations=1):
    """Extract a binary edge band from a mask via the morphological gradient.

    Args:
        mask: HxW bool/uint8 foreground mask.
        kernel: OpenCV structuring element; a larger kernel (or more
            iterations) yields a thicker edge band.
        iterations: number of dilate/erode passes.

    Returns:
        uint8 edge mask with values in {0, 1}.
    """
    # Binarize to {0,1} regardless of input dtype.
    m = (mask.astype(np.uint8) > 0).astype(np.uint8)

    # Gradient = dilation - erosion; zero border padding keeps the image
    # boundary from being reported as an edge.
    dil = cv2.dilate(m, kernel, iterations=iterations, borderType=cv2.BORDER_CONSTANT, borderValue=0.0)
    ero = cv2.erode(m, kernel, iterations=iterations, borderType=cv2.BORDER_CONSTANT, borderValue=0.0)

    edge = (dil - ero)  # per-pixel 0 or 1 (dilation >= erosion for binary input)
    edge = (edge > 0).astype(np.uint8)
    return edge
259
+
260
def process_scene_image(image: Image.Image, instance_mask: np.ndarray, image_size: int,
                        resize_perturb: bool = False, resize_perturb_ratio: float = 0.0):
    """Resize a scene render and its instance-restricted alpha to a square.

    Args:
        image: PIL image, ideally RGBA; without an alpha channel the whole
            frame is treated as opaque.
        instance_mask: HxW mask ANDed into the alpha channel.
        image_size: output side length in pixels.
        resize_perturb: if True, with probability `resize_perturb_ratio`
            apply a down-then-up resampling (random resolution in
            [32, image_size)) as a blur/aliasing augmentation.
        resize_perturb_ratio: probability of applying that perturbation.

    Returns:
        (img_t, img4): RGB tensor (3,S,S) in [0,1] and RGBA tensor (4,S,S).
    """
    image_rgba = image
    try:
        alpha = np.array(image_rgba.getchannel("A")) > 0
    except ValueError:
        # No alpha channel: every pixel counts as foreground.
        alpha = np.ones_like(np.array(image_rgba.getchannel(0))) > 0
    alpha = np.logical_and(alpha, instance_mask).astype(np.uint8) * 255

    # LANCZOS for the color image, NEAREST for the (binary) alpha.
    image_resized = image_rgba.resize((image_size, image_size), Image.Resampling.LANCZOS).convert("RGB")
    alpha_resized = Image.fromarray(alpha, mode="L").resize((image_size, image_size), Image.Resampling.NEAREST)

    if resize_perturb and np.random.rand() < resize_perturb_ratio:
        rand_reso = np.random.randint(32, image_size)

        image_resized = image_resized.resize((rand_reso, rand_reso), Image.Resampling.LANCZOS)
        image_resized = image_resized.resize((image_size, image_size), Image.Resampling.LANCZOS)

        alpha_resized = alpha_resized.resize((rand_reso, rand_reso), Image.Resampling.NEAREST)
        alpha_resized = alpha_resized.resize((image_size, image_size), Image.Resampling.NEAREST)

    img_np = np.array(image_resized, dtype=np.uint8)
    img_t = torch.from_numpy(img_np).permute(2, 0, 1).float() / 255.0

    a_np = np.array(alpha_resized, dtype=np.uint8)
    a_t = torch.from_numpy(a_np).unsqueeze(0).float() / 255.0
    img4 = torch.cat([img_t, a_t], dim=0)  # (4,S,S)
    return img_t, img4
288
+
289
def get_rays(i, j, K, c2w):
    """Back-project pixel coordinates into world-space rays.

    Camera convention: x right, y up, looking down -z (OpenGL/NeRF style).

    Args:
        i, j: pixel column and row coordinate tensors (any shape).
        K: 3x3 intrinsics (indexable as K[r][c]).
        c2w: camera-to-world matrix; uses [:3,:3] and the last column.

    Returns:
        (rays_o, rays_d): origins and unnormalized directions, shape (...,3).
    """
    u = i.float() + 0.5  # sample at pixel centers
    v = j.float() + 0.5
    dirs = torch.stack([
        (u - K[0][2]) / K[0][0],
        -(v - K[1][2]) / K[1][1],
        -torch.ones_like(u),
    ], -1)
    # Rotate directions to world frame; equivalent to R @ dir per pixel.
    rays_d = torch.sum(dirs[..., None, :] * c2w[:3, :3], -1)
    # Every ray starts at the camera center.
    rays_o = c2w[:3, -1].expand(rays_d.shape)
    return rays_o, rays_d
298
+
299
def get_rays_fast(u: torch.Tensor, v: torch.Tensor, K: torch.Tensor, c2w: torch.Tensor):
    """Vectorized ray generation for flat pixel-coordinate tensors.

    Args:
        u, v: 1D pixel coordinates (integer or float dtype).
        K: (3,3) intrinsics (a larger matrix works; only the 3x3 block is read).
        c2w: (4,4) or (3,4) camera-to-world; uses [:3,:3] and [:3,3].

    Returns:
        (rays_o, rays_d): (N,3) origins and unnormalized directions.
    """
    # Shift to pixel centers and ensure float math.
    uu = u.to(dtype=torch.float32) + 0.5
    vv = v.to(dtype=torch.float32) + 0.5

    fx, fy = K[0, 0], K[1, 1]
    cx, cy = K[0, 2], K[1, 2]

    # Camera-frame directions (x right, y up, looking down -z).
    dirs = torch.stack([
        (uu - cx) / fx,
        -(vv - cy) / fy,
        -torch.ones_like(uu),
    ], dim=-1)

    # World frame via the row-vector convention: dirs @ R^T.
    rotation = c2w[:3, :3]
    rays_d = dirs @ rotation.T

    # Origin: camera center broadcast to (N,3).
    translation = c2w[:3, 3]
    rays_o = translation.expand_as(rays_d)
    return rays_o, rays_d
328
+
329
def process_instance_image(image: Image.Image, instance_mask: np.ndarray, color_mask: np.ndarray, depth_map: torch.Tensor,
                           K: torch.Tensor, c2w: torch.Tensor, image_size: int):
    """Crop one instance from a render and gather its per-pixel rays.

    The instance's alpha defines a square crop (bbox padded by 1.2x) that is
    resized to image_size; rays/depth/color are sampled at the ORIGINAL
    (uncropped) foreground pixels.

    Returns:
        img_t: (3,S,S) RGB crop in [0,1].
        a_t: (1,S,S) alpha crop in [0,1].
        rays_o, rays_d: rays for every original foreground pixel.
        rays_color: color_mask values at those pixels (float32).
        rays_t: depth_map values at those pixels.
    """
    image_rgba = image
    try:
        alpha = np.asarray(image_rgba.getchannel("A")) > 0
    except ValueError:
        # No alpha channel: treat the whole frame as opaque.
        alpha = np.ones_like(np.array(image_rgba.getchannel(0))) > 0
    alpha = np.logical_and(alpha, instance_mask).astype(np.uint8) * 255
    valid_mask = np.array(alpha).nonzero()  # (rows, cols) of foreground pixels

    # Square crop box around the foreground, padded by aug_size_ratio.
    bbox = [valid_mask[1].min(), valid_mask[0].min(), valid_mask[1].max(), valid_mask[0].max()]
    center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
    hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
    aug_size_ratio = 1.2
    aug_hsize = hsize * aug_size_ratio
    aug_center_offset = [0, 0]
    aug_center = [center[0] + aug_center_offset[0], center[1] + aug_center_offset[1]]
    aug_bbox = [int(aug_center[0] - aug_hsize), int(aug_center[1] - aug_hsize), int(aug_center[0] + aug_hsize), int(aug_center[1] + aug_hsize)]

    # Rays are generated at the original pixel grid, not the resized crop.
    i, j = torch.from_numpy(valid_mask[1]), torch.from_numpy(valid_mask[0])
    rays_o, rays_d = get_rays(i, j, K, c2w)
    rays_color = color_mask[valid_mask[0], valid_mask[1]].astype(np.float32)
    rays_t = depth_map[valid_mask[0], valid_mask[1]]

    image_resized = image_rgba.crop(aug_bbox).convert("RGB").resize((image_size, image_size), Image.Resampling.LANCZOS)
    alpha_resized = Image.fromarray(alpha, mode="L").crop(aug_bbox).resize((image_size, image_size), Image.Resampling.NEAREST)

    img_np = np.asarray(image_resized, dtype=np.uint8)
    img_t = torch.from_numpy(img_np).permute(2, 0, 1).float() / 255.0

    a_np = np.asarray(alpha_resized, dtype=np.uint8)
    a_t = torch.from_numpy(a_np).unsqueeze(0).float() / 255.0
    return img_t, a_t, rays_o, rays_d, rays_color, rays_t
362
+
363
def get_crop_area_rays(image: Image.Image, instance_mask: np.ndarray, K: torch.Tensor, c2w: torch.Tensor, image_size):
    """Generate a ray grid covering the (padded) bounding box of a mask.

    The alpha/instance mask is used ONLY to compute the crop box; rays are
    sampled on an image_size x image_size grid spanning that box.

    Returns:
        (rays_o, rays_d) for the sampled grid.
    """
    alpha = np.asarray(image.getchannel("A")) > 0
    if instance_mask is not None:
        alpha = np.logical_and(alpha, instance_mask).astype(np.float32)  # * 255
    else:
        alpha = alpha.astype(np.float32)
    valid_mask = np.array(alpha).nonzero()

    # Square box around the foreground, padded by aug_size_ratio.
    bbox = [valid_mask[1].min(), valid_mask[0].min(), valid_mask[1].max(), valid_mask[0].max()]
    center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
    hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
    aug_size_ratio = 1.2
    aug_hsize = hsize * aug_size_ratio
    aug_center_offset = [0, 0]
    aug_center = [center[0] + aug_center_offset[0], center[1] + aug_center_offset[1]]
    aug_bbox = [int(aug_center[0] - aug_hsize), int(aug_center[1] - aug_hsize), int(aug_center[0] + aug_hsize), int(aug_center[1] + aug_hsize)]

    # NOTE(review): relies on torch.meshgrid's default 'ij' indexing (emits a
    # deprecation warning on recent torch) — confirm the intended axis order.
    i, j = torch.meshgrid(
        torch.linspace(aug_bbox[0], aug_bbox[2]-1, steps=image_size),
        torch.linspace(aug_bbox[1], aug_bbox[3]-1, steps=image_size)
    )
    rays_o, rays_d = get_rays(i, j, K, c2w)
    return rays_o, rays_d
387
+
388
def process_instance_image_crop(image: Image.Image, instance_mask: np.ndarray, color_mask: np.ndarray,
                                depth_map: torch.Tensor,
                                gt_depth_map: torch.Tensor,
                                K: torch.Tensor, c2w: torch.Tensor, image_size: int,
                                edge_mask_morph_gradient_fn):
    """Crop an instance and resample image/alpha/depth/color into the crop.

    Unlike `process_instance_image`, rays and depth here live on the RESIZED
    crop grid (image_size x image_size), and an edge band around the alpha is
    down-weighted to 0.5 in `rays_color`.

    Args:
        edge_mask_morph_gradient_fn: callable producing a {0,1} edge band
            from a binary mask (e.g. a partial of edge_mask_morph_gradient).

    Returns:
        img_t (3,S,S), a_t (1,S,S), fg_mask (uint8, edge removed),
        rays_o/rays_d for the crop grid, rays_color (1.0 interior /
        0.5 edge / 0.0 background), rays_t (resampled depth tensor),
        valid_mask (nonzero indices of fg_mask), and the resized depth /
        GT depth / color-mask PIL images.
    """
    image_rgba = image
    alpha = np.asarray(image_rgba.getchannel("A")) > 0
    alpha = np.logical_and(alpha, instance_mask).astype(np.float32)  # * 255
    valid_mask = np.array(alpha).nonzero()

    # Square crop box around the foreground, padded by aug_size_ratio.
    bbox = [valid_mask[1].min(), valid_mask[0].min(), valid_mask[1].max(), valid_mask[0].max()]
    center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
    hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
    aug_size_ratio = 1.2
    aug_hsize = hsize * aug_size_ratio
    aug_center_offset = [0, 0]
    aug_center = [center[0] + aug_center_offset[0], center[1] + aug_center_offset[1]]
    aug_bbox = [int(aug_center[0] - aug_hsize), int(aug_center[1] - aug_hsize), int(aug_center[0] + aug_hsize), int(aug_center[1] + aug_hsize)]

    # Ray grid spanning the crop (see get_crop_area_rays for the same pattern).
    i, j = torch.meshgrid(
        torch.linspace(aug_bbox[0], aug_bbox[2]-1, steps=image_size),
        torch.linspace(aug_bbox[1], aug_bbox[3]-1, steps=image_size)
    )
    rays_o, rays_d = get_rays(i, j, K, c2w)

    # Crop + resize all modalities; NEAREST for masks/depth, LANCZOS for RGB.
    image_resized = image_rgba.crop(aug_bbox).convert("RGB").resize((image_size, image_size), Image.Resampling.LANCZOS)
    alpha_resized = Image.fromarray(alpha, mode="F").crop(aug_bbox).resize((image_size, image_size), Image.Resampling.NEAREST)
    depth_map_resized = Image.fromarray(depth_map.detach().cpu().numpy(), mode="F").crop(aug_bbox).resize((image_size, image_size), Image.Resampling.NEAREST)
    gt_depth_map_resized = Image.fromarray(gt_depth_map.detach().cpu().numpy(), mode="F").crop(aug_bbox).resize((image_size, image_size), Image.Resampling.NEAREST)
    color_mask_resized = Image.fromarray(color_mask.astype(np.float32), mode="F").crop(aug_bbox).resize((image_size, image_size), Image.Resampling.NEAREST)

    img_np = np.asarray(image_resized, dtype=np.uint8)
    img_t = torch.from_numpy(img_np).permute(2, 0, 1).float() / 255.0

    a_np = np.asarray(alpha_resized, dtype=np.float32).astype(dtype=np.uint8)

    # Split the alpha into interior vs edge band; edge pixels get weight 0.5.
    edge_mask = edge_mask_morph_gradient_fn((a_np > 0).astype(np.uint8))
    fg_mask = (a_np > edge_mask).astype(np.uint8)
    rays_color = fg_mask.astype(np.float32) + edge_mask.astype(np.float32) * 0.5

    valid_mask = fg_mask.nonzero()  # reassigned: now indices on the crop grid
    rays_t = torch.from_numpy(np.asarray(depth_map_resized).astype(np.float32))

    a_t = torch.from_numpy(a_np).unsqueeze(0).float()  # / 255.0
    return img_t, a_t, fg_mask, rays_o, rays_d, rays_color, rays_t, valid_mask, depth_map_resized, gt_depth_map_resized, color_mask_resized
433
+
434
def process_instance_image_only(image: Image.Image, instance_mask: np.ndarray, image_size: int):
    """Crop an instance's RGB + alpha only (no rays/depth sampling).

    The instance alpha defines a square crop (padded by 1.2x) which is
    resized to image_size x image_size.

    Returns:
        (img_t, a_t): (3,S,S) RGB in [0,1] and (1,S,S) alpha in [0,1].
    """
    image_rgba = image
    alpha = np.asarray(image_rgba.getchannel("A")) > 0
    alpha = np.logical_and(alpha, instance_mask).astype(np.uint8) * 255
    valid_mask = np.array(alpha).nonzero()

    # Square crop box around the foreground, padded by aug_size_ratio.
    bbox = [valid_mask[1].min(), valid_mask[0].min(), valid_mask[1].max(), valid_mask[0].max()]
    center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
    hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
    aug_size_ratio = 1.2
    aug_hsize = hsize * aug_size_ratio
    aug_center_offset = [0, 0]
    aug_center = [center[0] + aug_center_offset[0], center[1] + aug_center_offset[1]]
    aug_bbox = [int(aug_center[0] - aug_hsize), int(aug_center[1] - aug_hsize), int(aug_center[0] + aug_hsize), int(aug_center[1] + aug_hsize)]

    image_resized = image_rgba.crop(aug_bbox).convert("RGB").resize((image_size, image_size), Image.Resampling.LANCZOS)
    alpha_resized = Image.fromarray(alpha, mode="L").crop(aug_bbox).resize((image_size, image_size), Image.Resampling.NEAREST)

    img_np = np.asarray(image_resized, dtype=np.uint8)
    img_t = torch.from_numpy(img_np).permute(2, 0, 1).float() / 255.0

    a_np = np.asarray(alpha_resized, dtype=np.uint8)
    a_t = torch.from_numpy(a_np).unsqueeze(0).float() / 255.0
    return img_t, a_t
458
+
459
def crop_depth_image(depth_image, aug_bbox, image_size):
    """Crop a depth tensor to aug_bbox and NEAREST-resize it to a square.

    Args:
        depth_image: 2D torch depth tensor (any device).
        aug_bbox: (left, upper, right, lower) PIL crop box.
        image_size: output side length.

    Returns:
        float32 CPU tensor of shape (image_size, image_size).
    """
    depth_np = depth_image.cpu().numpy().astype(np.float32)
    pil_depth = Image.fromarray(depth_np, mode="F").crop(aug_bbox)
    pil_depth = pil_depth.resize((image_size, image_size), Image.Resampling.NEAREST)
    return torch.from_numpy(np.asarray(pil_depth, dtype=np.float32))
467
+
468
def proj_depth2pcd(mask, depth, image, rays_o, rays_d):
    """Unproject masked pixels into a colored point cloud.

    Args:
        mask: HxW tensor; nonzero entries select the pixels to unproject.
        depth: HxW depth values (ray parameter t per pixel).
        image: (C,H,W) color tensor.
        rays_o, rays_d: (H,W,3) per-pixel ray origins and directions.

    Returns:
        (points, colors): numpy arrays of shape (N,3) and (N,C).
    """
    coords = torch.nonzero(mask)
    rows = coords[:, 0].detach().cpu().numpy()
    cols = coords[:, 1].detach().cpu().numpy()

    pixel_depth = depth[rows, cols]
    pixel_color = image.detach().permute(1, 2, 0)[rows, cols]

    # point = origin + direction * t, per selected pixel.
    pixel_points = rays_o[rows, cols] + rays_d[rows, cols] * pixel_depth[:, None]
    return pixel_points.detach().cpu().numpy(), pixel_color.detach().cpu().numpy()
478
+
479
def vox2pts(ss, resolution=64):
    """Convert an occupancy grid (1,R,R,R) into voxel-center points.

    Returns an (N,3) numpy array of centers in [-0.5, 0.5)^3.
    """
    occupied = torch.nonzero(ss[0] > 0, as_tuple=False)
    centers = (occupied.float() + 0.5) / resolution - 0.5
    return centers.detach().cpu().numpy()
484
+
485
def voxelize_pcd(points, points_color=None, clip_range_first=False, return_mask=True, resolution=64):
    """Voxelize a point cloud into an occupancy grid over [-0.5, 0.5]^3.

    Args:
        points: (N,3) numpy positions inside the unit cube.
        points_color: optional per-point color values; when given, a
            per-voxel mean-color grid replaces occupancy in the returned
            mask. NOTE(review): index_add_ into a flat (R^3,) buffer
            implies a SCALAR color per point — confirm against callers.
        clip_range_first: clamp points into the unit cube first.
        return_mask: also return the grid regrouped into 4x4x4 sub-blocks
            (requires resolution divisible by 4).
        resolution: grid side length.

    Returns:
        ss (1,R,R,R) long occupancy grid, or (ss, ss_mask).
    """
    if clip_range_first:
        points = np.clip(points, -0.5 + 1e-6, 0.5 - 1e-6)
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud_within_bounds(pcd, voxel_size=1/resolution, min_bound=(-0.5, -0.5, -0.5), max_bound=(0.5, 0.5, 0.5))
    vertices = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()])
    assert np.all(vertices >= 0) and np.all(vertices < resolution), "Some vertices are out of bounds"
    # Voxel centers back to integer grid coordinates, then mark occupancy.
    vertices = (vertices + 0.5) / resolution - 0.5
    coords = ((torch.tensor(vertices) + 0.5) * resolution).int().contiguous()
    ss = torch.zeros(1, resolution, resolution, resolution, dtype=torch.long)
    ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1

    if points_color is not None:
        # Scatter-mean the colors into the grid: accumulate sums and counts
        # per linearized voxel index, then divide.
        points_t = torch.from_numpy(points).to(torch.float32)
        colors_t = torch.from_numpy(points_color).to(torch.float32)

        coords = torch.floor((points_t + 0.5) * resolution).to(torch.long)
        coords = torch.clamp(coords, 0, resolution - 1)
        ix, iy, iz = coords[:, 0], coords[:, 1], coords[:, 2]
        lin = ix * (resolution * resolution) + iy * resolution + iz  # linear index in [0, R^3)

        sum_color = torch.zeros((resolution * resolution * resolution), dtype=torch.float32)
        sum_color.index_add_(0, lin, colors_t)
        count = torch.zeros((resolution * resolution * resolution,), dtype=torch.long)
        ones = torch.ones_like(lin, dtype=torch.long)
        count.index_add_(0, lin, ones)

        count_f = count.to(torch.float32)
        mean_color = sum_color / torch.clamp(count_f, min=1.0)  # empty -> divide by 1 (still 0)
        color_mean = mean_color.view(resolution, resolution, resolution, 1).permute(3, 0, 1, 2).contiguous()
    if return_mask:
        # The mask packs 4x4x4 interleaved sub-grids; colors replace
        # occupancy when points_color was provided.
        ss_mask = rearrange(ss if points_color is None else color_mean, 'c (x n1) (y n2) (z n3) -> (n1 n2 n3 c) x y z', n1=4, n2=4, n3=4).float()
        return ss , ss_mask
    else:
        return ss
521
+
522
def voxelize_pcd_pt(points, points_color=None, clip_range_first=False, return_mask=True, resolution=64):
    """Torch-tensor variant of `voxelize_pcd` (keeps outputs on points.device).

    NaNs in points/colors are zeroed first. Occupancy is still computed by
    round-tripping through Open3D on CPU; only the scatter-mean color step
    stays on the input device. See `voxelize_pcd` for the shared semantics.
    NOTE(review): as in voxelize_pcd, index_add_ into a flat (R^3,) buffer
    implies a scalar color per point — confirm against callers.
    """
    points = torch.nan_to_num(points)
    points_color = torch.nan_to_num(points_color) if isinstance(points_color, torch.Tensor) else points_color
    device = points.device
    if clip_range_first:
        points = torch.clip(points, -0.5 + 1e-6, 0.5 - 1e-6)
    # Open3D requires CPU numpy input for the voxelization itself.
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points.detach().cpu().numpy())
    voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud_within_bounds(pcd, voxel_size=1/resolution, min_bound=(-0.5, -0.5, -0.5), max_bound=(0.5, 0.5, 0.5))
    vertices = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()])
    assert np.all(vertices >= 0) and np.all(vertices < resolution), "Some vertices are out of bounds"
    vertices = (vertices + 0.5) / resolution - 0.5
    coords = ((torch.tensor(vertices, device=device) + 0.5) * resolution).int().contiguous()
    ss = torch.zeros(1, resolution, resolution, resolution, dtype=torch.long, device=device)
    ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1

    if points_color is not None:
        # Scatter-mean colors per voxel on the original device.
        points_t = points.to(torch.float32)
        colors_t = points_color.to(torch.float32)

        coords = torch.floor((points_t + 0.5) * resolution).to(torch.long)
        coords = torch.clamp(coords, 0, resolution - 1)
        ix, iy, iz = coords[:, 0], coords[:, 1], coords[:, 2]
        lin = ix * (resolution * resolution) + iy * resolution + iz  # linear index in [0, R^3)

        sum_color = torch.zeros((resolution * resolution * resolution), dtype=torch.float32, device=device)
        sum_color.index_add_(0, lin, colors_t)
        count = torch.zeros((resolution * resolution * resolution,), dtype=torch.long, device=device)
        ones = torch.ones_like(lin, dtype=torch.long)
        count.index_add_(0, lin, ones)

        count_f = count.to(torch.float32)
        mean_color = sum_color / torch.clamp(count_f, min=1.0)  # empty -> divide by 1 (still 0)
        color_mean = mean_color.view(resolution, resolution, resolution, 1).permute(3, 0, 1, 2).contiguous()
    if return_mask:
        # 4x4x4 interleaved regrouping; colors replace occupancy when given.
        ss_mask = rearrange(ss if points_color is None else color_mean, 'c (x n1) (y n2) (z n3) -> (n1 n2 n3 c) x y z', n1=4, n2=4, n3=4).float()
        return ss , ss_mask
    else:
        return ss
561
+
562
def get_std_cond(root, instance, crop_size, return_mask=False):
    """Load a random conditioning view of an instance, cropped and masked.

    Looks under `<root>/renders_cond/<instance>` first, falling back to
    `<root>/renders/<instance>`. A random frame is picked, cropped to its
    alpha bounding box (padded by 1.2x), resized, and alpha-premultiplied.

    Returns:
        (3,S,S) RGB tensor in [0,1] with background zeroed, plus the
        (1,S,S) alpha tensor when return_mask=True.
    """
    image_root = os.path.join(root, 'renders_cond', instance)
    if os.path.exists(os.path.join(image_root, 'transforms.json')):
        with open(os.path.join(image_root, 'transforms.json')) as f:
            metadata = json.load(f)
    else:
        # Fallback location used by older renders.
        image_root = os.path.join(root, 'renders', instance)
        with open(os.path.join(image_root, 'transforms.json')) as f:
            metadata = json.load(f)
    n_views = len(metadata['frames'])
    view = np.random.randint(n_views)  # random conditioning view
    metadata = metadata['frames'][view]

    image_path = os.path.join(image_root, metadata['file_path'])
    image = Image.open(image_path)

    # Square crop box around the alpha foreground, padded by aug_size_ratio.
    alpha = np.array(image.getchannel(3))
    bbox = np.array(alpha).nonzero()
    bbox = [bbox[1].min(), bbox[0].min(), bbox[1].max(), bbox[0].max()]
    center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
    hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
    aug_size_ratio = 1.2
    aug_hsize = hsize * aug_size_ratio
    aug_center_offset = [0, 0]
    aug_center = [center[0] + aug_center_offset[0], center[1] + aug_center_offset[1]]
    aug_bbox = [int(aug_center[0] - aug_hsize), int(aug_center[1] - aug_hsize), int(aug_center[0] + aug_hsize), int(aug_center[1] + aug_hsize)]
    image = image.crop(aug_bbox)

    image = image.resize((crop_size, crop_size), Image.Resampling.LANCZOS)
    alpha = image.getchannel(3)
    image = image.convert('RGB')
    image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
    alpha = torch.tensor(np.array(alpha)).float() / 255.0
    # Premultiply: zero out the background.
    image = image * alpha.unsqueeze(0)
    if return_mask:
        return image, alpha.unsqueeze(0)
    else:
        return image
600
+
601
def map_rotated_slat2canonical_pose(vertices, rot_slat_info):
    """Undo a random-rotation augmentation, mapping vertices back to canonical pose.

    `rot_slat_info` records the augmentation: 'scale', 'translation' and the
    Euler angles 'rotate' applied (in x,y,z order) at augmentation time.
    Scale and translation are re-applied, then the INVERSE rotations are
    applied in reverse order (z, then y, then x) with negated angles.

    Returns:
        (N,3) numpy array of canonical-pose vertices.
    """
    vertices_scale = rot_slat_info['scale']
    vertices_trans = np.array(rot_slat_info['translation'])
    rand_rot = rot_slat_info['rotate']
    pcd = o3d.geometry.PointCloud()
    vertices = vertices * vertices_scale
    vertices = vertices + vertices_trans
    pcd.points = o3d.utility.Vector3dVector(vertices)
    # Inverse single-axis rotations (negated angles).
    R1 = pcd.get_rotation_matrix_from_xyz((-rand_rot[0], 0, 0))
    R2 = pcd.get_rotation_matrix_from_xyz((0, -rand_rot[1], 0))
    R3 = pcd.get_rotation_matrix_from_xyz((0, 0, -rand_rot[2]))
    # Applied z -> y -> x, the reverse of the forward x -> y -> z order.
    pcd.rotate(R3, center=(0., 0., 0.))
    pcd.rotate(R2, center=(0., 0., 0.))
    pcd.rotate(R1, center=(0., 0., 0.))
    vertices = np.asarray(pcd.points)

    return vertices
618
+
619
def project2ply(mask, depth, image, K, c2w):
    """Unproject masked pixels into a colored world-space point cloud.

    Rays are generated from intrinsics/pose via get_rays (which expects
    column, row ordering).

    Returns:
        (points, colors): numpy arrays of shape (N,3) and (N,C).
    """
    coords = torch.nonzero(mask)

    # get_rays takes (i=column, j=row).
    rays_o, rays_d = get_rays(coords[:, 1], coords[:, 0], K, c2w)

    rows = coords[:, 0].detach().cpu().numpy()
    cols = coords[:, 1].detach().cpu().numpy()
    pixel_depth = depth[rows, cols]
    pixel_color = image.detach().permute(1, 2, 0).cpu().numpy()[rows, cols]

    # point = origin + direction * t, per selected pixel.
    pixel_points = rays_o + rays_d * pixel_depth[:, None]
    return pixel_points.detach().cpu().numpy(), pixel_color
threeDFixer/models/__init__.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file is modified from TRELLIS:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+ # Modifications Copyright (c) 2026 Ze-Xin Yin, Robot labs of Horizon Robotics, and D-Robotics.
6
+
7
+ import importlib
8
+
9
# Lazy-import table: maps each public class name to the submodule (relative to
# this package) that defines it. Entries are resolved on first attribute
# access by __getattr__ below, keeping package import cheap.
__attributes = {
    'SparseStructureEncoder': 'sparse_structure_vae',
    'SparseStructureDecoder': 'sparse_structure_vae',

    'SparseStructureFlowModel': 'sparse_structure_flow',

    'SLatEncoder': 'structured_latent_vae',
    'SLatGaussianDecoder': 'structured_latent_vae',
    'SLatRadianceFieldDecoder': 'structured_latent_vae',
    'SLatMeshDecoder': 'structured_latent_vae',
    'ElasticSLatEncoder': 'structured_latent_vae',
    'ElasticSLatGaussianDecoder': 'structured_latent_vae',
    'ElasticSLatRadianceFieldDecoder': 'structured_latent_vae',
    'ElasticSLatMeshDecoder': 'structured_latent_vae',

    'SLatFlowModel': 'structured_latent_flow',
    'ElasticSLatFlowModel': 'structured_latent_flow',

    'SceneSLatFlowModel': 'scene_structured_latent_flow',
    'ElasticSceneSLatFlowModel': 'scene_structured_latent_flow',
    'SceneSparseStructureFlowModule': 'scene_sparse_structure_flow',
}

# Submodules re-exported as attributes of this package (currently none).
__submodules = []

__all__ = list(__attributes.keys()) + __submodules
35
+
36
def __getattr__(name):
    """Lazily resolve *name* on first access and cache it in this module's globals."""
    cache = globals()
    if name in cache:
        return cache[name]
    if name in __attributes:
        module = importlib.import_module(f".{__attributes[name]}", __name__)
        cache[name] = getattr(module, name)
    elif name in __submodules:
        cache[name] = importlib.import_module(f".{name}", __name__)
    else:
        raise AttributeError(f"module {__name__} has no attribute {name}")
    return cache[name]
48
+
49
+
50
def from_pretrained(path: str, **kwargs):
    """
    Load a model from a pretrained checkpoint.

    Args:
        path: The path to the checkpoint. Can be either local path or a Hugging Face model name.
            NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively.
        **kwargs: Additional arguments for the model constructor.

    Returns:
        The instantiated model with weights loaded; converted to fp16 when the
        model's declared dtype is float16.
    """
    # Imported lazily so that merely importing this package stays lightweight.
    import os
    import json
    import torch
    from safetensors.torch import load_file
    from ..utils.dist_utils import read_file_dist
    # A local checkpoint needs the .json config plus either a .safetensors or a raw .pt file.
    is_local = os.path.exists(f"{path}.json") and (os.path.exists(f"{path}.safetensors") or os.path.exists(f"{path}.pt"))

    if is_local:
        config_file = f"{path}.json"
        # Prefer safetensors when both formats are present.
        model_file = f"{path}.safetensors" if os.path.exists(f"{path}.safetensors") else f"{path}.pt"
    else:
        # Interpret path as '<org>/<repo>/<model_name...>' on the Hugging Face Hub.
        # NOTE(review): only the .safetensors format is fetched for remote checkpoints.
        from huggingface_hub import hf_hub_download
        path_parts = path.split('/')
        repo_id = f'{path_parts[0]}/{path_parts[1]}'
        model_name = '/'.join(path_parts[2:])
        config_file = hf_hub_download(repo_id, f"{model_name}.json")
        model_file = hf_hub_download(repo_id, f"{model_name}.safetensors")

    with open(config_file, 'r') as f:
        config = json.load(f)
    # config['name'] names one of the classes in __attributes; __getattr__ lazily imports it.
    model = __getattr__(config['name'])(**config['args'], **kwargs)
    if model_file.endswith(".safetensors"):
        model.load_state_dict(load_file(model_file))
    else:
        # weights_only=True restricts unpickling to tensors/containers for safety.
        model_ckpt = torch.load(read_file_dist(model_file), map_location='cpu', weights_only=True)
        model.load_state_dict(model_ckpt)
    if model.dtype == torch.float16:
        model.convert_to_fp16()

    return model
89
+
90
+
91
# For Pylance
# These imports never execute at runtime (the package is imported, not run as a
# script); they exist solely so static analyzers / IDEs can see the names that
# __getattr__ otherwise exports lazily.
if __name__ == '__main__':
    from .sparse_structure_vae import (
        SparseStructureEncoder,
        SparseStructureDecoder,
    )

    from .sparse_structure_flow import SparseStructureFlowModel

    from .structured_latent_vae import (
        SLatEncoder,
        SLatGaussianDecoder,
        SLatRadianceFieldDecoder,
        SLatMeshDecoder,
        ElasticSLatEncoder,
        ElasticSLatGaussianDecoder,
        ElasticSLatRadianceFieldDecoder,
        ElasticSLatMeshDecoder,
    )

    from .structured_latent_flow import (
        SLatFlowModel,
        ElasticSLatFlowModel,
    )

    from .scene_sparse_structure_flow import (
        SceneSparseStructureFlowModule
    )

    from .scene_structured_latent_flow import (
        SceneSLatFlowModel,
        ElasticSceneSLatFlowModel
    )
threeDFixer/models/scene_sparse_structure_flow.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file is modified from TRELLIS:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+ # Modifications Copyright (c) 2026 Ze-Xin Yin, Robot labs of Horizon Robotics, and D-Robotics.
6
+
7
+ from typing import *
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import numpy as np
12
+ from . import from_pretrained
13
+ from ..modules.utils import convert_module_to_f16, convert_module_to_f32
14
+ from ..modules.transformer import SceneModulatedTransformerCrossBlock
15
+ from ..modules.spatial import patchify, unpatchify
16
+ from .sparse_structure_flow import (
17
+ SparseStructureFlowModel,
18
+ TimestepEmbedder
19
+ )
20
+
21
def mean_flat(x):
    """
    Take the mean over all non-batch dimensions.
    """
    non_batch_dims = tuple(range(1, x.dim()))
    return x.mean(dim=non_batch_dims)
26
+
27
+ class SceneSparseStructureFlowModule(nn.Module):
28
+ def __init__(
29
+ self,
30
+ resolution: int,
31
+ in_channels: int,
32
+ model_channels: int,
33
+ cond_channels: int,
34
+ out_channels: int,
35
+ num_blocks: int,
36
+ num_heads: Optional[int] = None,
37
+ num_head_channels: Optional[int] = 64,
38
+ mlp_ratio: float = 4,
39
+ patch_size: int = 2,
40
+ pe_mode: Literal["ape", "rope"] = "ape",
41
+ use_fp16: bool = False,
42
+ use_checkpoint: bool = False,
43
+ share_mod: bool = False,
44
+ qk_rms_norm: bool = False,
45
+ qk_rms_norm_cross: bool = False,
46
+ pretrained_ss_flow_dit: str = None,
47
+ resume_ckpts: str = None,
48
+ ):
49
+ super().__init__()
50
+ self.resolution = resolution
51
+ self.in_channels = in_channels
52
+ self.model_channels = model_channels
53
+ self.cond_channels = cond_channels
54
+ self.out_channels = out_channels
55
+ self.num_blocks = num_blocks
56
+ self.num_heads = num_heads or model_channels // num_head_channels
57
+ self.mlp_ratio = mlp_ratio
58
+ self.patch_size = patch_size
59
+ self.pe_mode = pe_mode
60
+ self.use_fp16 = use_fp16
61
+ self.use_checkpoint = use_checkpoint
62
+ self.share_mod = share_mod
63
+ self.qk_rms_norm = qk_rms_norm
64
+ self.qk_rms_norm_cross = qk_rms_norm_cross
65
+ self.dtype = torch.float16 if use_fp16 else torch.float32
66
+
67
+ self.input_layer_vox_partial = nn.Linear(in_channels * patch_size**3, model_channels)
68
+ self.input_layer_mask_partial = nn.Linear(64, model_channels)
69
+
70
+ self.dpt_ratio_embedder = TimestepEmbedder(model_channels)
71
+
72
+ self.blocks = nn.ModuleList([
73
+ SceneModulatedTransformerCrossBlock(
74
+ model_channels,
75
+ cond_channels,
76
+ num_heads=self.num_heads,
77
+ mlp_ratio=self.mlp_ratio,
78
+ attn_mode='full',
79
+ use_checkpoint=self.use_checkpoint,
80
+ use_rope=(pe_mode == "rope"),
81
+ share_mod=share_mod,
82
+ qk_rms_norm=self.qk_rms_norm,
83
+ qk_rms_norm_cross=self.qk_rms_norm_cross,
84
+ )
85
+ for _ in range(num_blocks)
86
+ ])
87
+ self.control_path = nn.Sequential(*[
88
+ nn.Linear(model_channels, model_channels) for _ in range(num_blocks)
89
+ ])
90
+
91
+ self.neg_cache = {}
92
+ self.cond_vox_cache = None
93
+
94
+ self.initialize_weights()
95
+ if pretrained_ss_flow_dit is not None:
96
+ if pretrained_ss_flow_dit.endswith('.pt'):
97
+ print (f'loading pretrained weight: {pretrained_ss_flow_dit}')
98
+ model_ckpt = torch.load(pretrained_ss_flow_dit, map_location='cpu', weights_only=True)
99
+ self.input_layer_vox_partial.load_state_dict(
100
+ {k.replace('input_layer.', ''): model_ckpt[k] for k in filter(lambda x: 'input_layer' in x, model_ckpt.keys())}
101
+ )
102
+ self.dpt_ratio_embedder.load_state_dict(
103
+ {k.replace('t_embedder.', ''): model_ckpt[k] for k in filter(lambda x: 't_embedder' in x, model_ckpt.keys())}
104
+ )
105
+
106
+ for block_index, module in enumerate(self.blocks):
107
+ module: SceneModulatedTransformerCrossBlock
108
+ module.load_state_dict(
109
+ {k.replace(f'blocks.{block_index}', ''): model_ckpt[k] for k in filter(lambda x: f'blocks.{block_index}' in x, model_ckpt.keys())}, strict=False
110
+ )
111
+ module.norm4.load_state_dict(module.norm1.state_dict())
112
+ module.norm5.load_state_dict(module.norm2.state_dict())
113
+ module.self_attn_dpt_ratio.load_state_dict(module.self_attn.state_dict())
114
+ module.cross_attn_extra.load_state_dict(module.cross_attn.state_dict())
115
+ nn.init.constant_(module.self_attn_dpt_ratio.to_out.weight, 0)
116
+ if module.self_attn_dpt_ratio.to_out.bias is not None:
117
+ nn.init.constant_(module.self_attn_dpt_ratio.to_out.bias, 0)
118
+ nn.init.constant_(module.cross_attn_extra.to_out.weight, 0)
119
+ if module.cross_attn_extra.to_out.bias is not None:
120
+ nn.init.constant_(module.cross_attn_extra.to_out.bias, 0)
121
+ del model_ckpt
122
+ else:
123
+ print (f'loading pretrained weight: {pretrained_ss_flow_dit}')
124
+ pre_trained_models = from_pretrained(pretrained_ss_flow_dit)
125
+ pre_trained_models: SparseStructureFlowModel
126
+
127
+ self.input_layer_vox_partial.load_state_dict(pre_trained_models.input_layer.state_dict())
128
+ self.dpt_ratio_embedder.load_state_dict(pre_trained_models.t_embedder.state_dict())
129
+
130
+ for block_index, module in enumerate(self.blocks):
131
+ module: SceneModulatedTransformerCrossBlock
132
+ module.load_state_dict(pre_trained_models.blocks[block_index].state_dict(), strict=False)
133
+ module.norm4.load_state_dict(module.norm1.state_dict())
134
+ module.norm5.load_state_dict(module.norm2.state_dict())
135
+ module.self_attn_dpt_ratio.load_state_dict(module.self_attn.state_dict())
136
+ module.cross_attn_extra.load_state_dict(module.cross_attn.state_dict())
137
+ nn.init.constant_(module.self_attn_dpt_ratio.to_out.weight, 0)
138
+ if module.self_attn_dpt_ratio.to_out.bias is not None:
139
+ nn.init.constant_(module.self_attn_dpt_ratio.to_out.bias, 0)
140
+ nn.init.constant_(module.cross_attn_extra.to_out.weight, 0)
141
+ if module.cross_attn_extra.to_out.bias is not None:
142
+ nn.init.constant_(module.cross_attn_extra.to_out.bias, 0)
143
+ del pre_trained_models
144
+ if resume_ckpts is not None:
145
+ print (f'loading pretrained weight: {resume_ckpts}')
146
+ model_ckpt = torch.load(resume_ckpts, map_location='cpu', weights_only=True)
147
+ self.load_state_dict(model_ckpt, strict=False)
148
+ del model_ckpt
149
+ if use_fp16:
150
+ self.convert_to_fp16()
151
+
152
    def clear_neg_cache(self):
        """Reset ``neg_cache`` to an empty dict (its entries are populated elsewhere)."""
        self.neg_cache = {}
154
+
155
    def clear_cond_vox_cache(self):
        """Drop the cached condition-voxel tokens (populated elsewhere during inference)."""
        self.cond_vox_cache = None
157
+
158
    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        # Uses the first parameter as representative; assumes the module is not
        # sharded across devices.
        return next(self.parameters()).device
164
+
165
    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        # Only the transformer blocks and the control projections are converted;
        # the input linears and the depth-ratio embedder are left untouched here.
        self.blocks.apply(convert_module_to_f16)
        self.control_path.apply(convert_module_to_f16)
171
+
172
    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        # Symmetric counterpart of convert_to_fp16: restores the same two
        # module groups to full precision.
        self.blocks.apply(convert_module_to_f32)
        self.control_path.apply(convert_module_to_f32)
178
+
179
    def initialize_weights(self) -> None:
        """Initialize all weights: Xavier for linears, zeros for control/adaLN/mask layers."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Zero the control projections so the control branch initially adds
        # nothing to the frozen denoiser's residual stream.
        for block in self.control_path:
            nn.init.constant_(block.weight, 0)
            nn.init.constant_(block.bias, 0)

        # Zero-out adaLN modulation layers in DiT blocks:
        if self.share_mod:
            # NOTE(review): self.adaLN_modulation is not defined in __init__ of this
            # class — confirm that share_mod=True is actually supported here.
            nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
        else:
            for block in self.blocks:
                nn.init.constant_(block.adaLN_modulation_dpt[-1].weight, 0)
                nn.init.constant_(block.adaLN_modulation_dpt[-1].bias, 0)

        # Zero-out input layers:
        nn.init.constant_(self.input_layer_mask_partial.weight, 0)
        nn.init.constant_(self.input_layer_mask_partial.bias, 0)
204
+
205
+ def input_voxel(self, x, input_layer, pos_emb):
206
+ ########## voxel tokens
207
+ h = patchify(x, self.patch_size)
208
+ h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous()
209
+
210
+ h = input_layer(h)
211
+ h = h + pos_emb
212
+ ########## voxel tokens
213
+ return h
214
+
215
+ def input_mask(self, x, input_layer):
216
+ h = patchify(x, self.patch_size)
217
+ h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous()
218
+ h = input_layer(h)
219
+ return h
220
+
221
+ def forward(self, *args, **kwargs):
222
+ if kwargs.pop("w_align_loss", False):
223
+ return self._train_forward(*args, **kwargs, w_align_loss=True)
224
+ else:
225
+ return self._infer_forward(*args, **kwargs)
226
+
227
    def _train_forward(self, x: torch.Tensor, t: torch.Tensor, cond: Dict[str,torch.Tensor],
                       forzen_denoiser: SparseStructureFlowModel, est_depth_ratio: torch.Tensor,
                       w_align_loss: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Training forward: run the frozen denoiser while injecting control features
        from this module's blocks; optionally also compute an alignment loss
        against an uncontrolled pass driven by cond['std_cond_instance'].

        Args:
            x: Noisy voxel grid of shape (B, in_channels, R, R, R).
            t: Diffusion timesteps, embedded by the frozen denoiser.
            cond: Uses 'cond_partial_vox', 'cond_partial_vox_mask', 'cond_scene',
                'cond_instance', 'cond_instance_masked' and, when w_align_loss is
                True, 'std_cond_instance'.
            forzen_denoiser: The frozen SparseStructureFlowModel being controlled.
            est_depth_ratio: Scalar condition embedded via dpt_ratio_embedder.
            w_align_loss: When True, also return the per-sample alignment loss.

        Returns:
            The denoised grid, or (grid, align_loss) when w_align_loss is True.
        """
        assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
            f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"

        # Tokenize the noisy grid with the frozen input layer + positional embedding.
        h = self.input_voxel(x, forzen_denoiser.input_layer, forzen_denoiser.pos_emb[None])

        # Control tokens: partial voxel observation plus its mask embedding.
        cond_vox = self.input_voxel(cond['cond_partial_vox'], self.input_layer_vox_partial, forzen_denoiser.pos_emb[None]) + \
            self.input_mask(cond['cond_partial_vox_mask'], self.input_layer_mask_partial)

        cond_moge = cond['cond_scene']
        cond_dino = cond['cond_instance']
        cond_dino_masked = cond['cond_instance_masked']
        if w_align_loss:
            # Reference (uncontrolled) branch state, driven by 'std_cond_instance'.
            std_cond_dino = cond['std_cond_instance']
            std_cond_dino = std_cond_dino.type(self.dtype)
            std_h = h
            std_h = std_h.type(self.dtype)

        t_emb = forzen_denoiser.t_embedder(t)
        if self.share_mod:
            t_emb = forzen_denoiser.adaLN_modulation(t_emb)
        t_emb = t_emb.type(self.dtype)
        est_depth_ratio_emb = self.dpt_ratio_embedder(est_depth_ratio)
        est_depth_ratio_emb = est_depth_ratio_emb.type(self.dtype)
        # Cast everything to the torso dtype (fp16 when use_fp16).
        h = h.type(self.dtype)
        cond_control = cond_moge
        cond_control = cond_control.type(self.dtype)
        cond_vox = cond_vox.type(self.dtype)
        cond_dino = cond_dino.type(self.dtype)
        cond_dino_masked = cond_dino_masked.type(self.dtype)

        align_loss = 0.0
        acount = 0
        for block_index, frozen_block in enumerate(forzen_denoiser.blocks):
            h = frozen_block(h, t_emb, cond_dino_masked)
            if block_index < len(self.blocks):
                # Control branch: evolve cond_vox, then add its (zero-initialized)
                # projection into the frozen residual stream.
                cond_vox = self.blocks[block_index](cond_vox, t_emb, est_depth_ratio_emb, cond_dino, cond_control)
                ctrl_feats = self.control_path[block_index](cond_vox)
                h = h + ctrl_feats

            if w_align_loss:
                with torch.no_grad():
                    std_h = frozen_block(std_h, t_emb, std_cond_dino)
                acount += 1
                reference = std_h
                source = h

                # Negative cosine similarity between controlled and reference features
                # (per sample, averaged over tokens).
                z_tilde_j = torch.nn.functional.normalize(source, dim=-1, eps=1e-6)
                z_j = torch.nn.functional.normalize(reference, dim=-1, eps=1e-6)
                align_loss += mean_flat(-(z_j * z_tilde_j).sum(dim=-1))

        h = h.type(x.dtype)

        h = F.layer_norm(h, h.shape[-1:])
        h = forzen_denoiser.out_layer(h)

        # Tokens back to a dense voxel grid.
        h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3)
        h = unpatchify(h, self.patch_size).contiguous()

        if w_align_loss:
            return h, align_loss / acount
        else:
            return h
292
+
293
    def _infer_forward(self, x: torch.Tensor, t: torch.Tensor, cond: Dict[str,torch.Tensor],
                       forzen_denoiser: SparseStructureFlowModel, est_depth_ratio: torch.Tensor) -> torch.Tensor:
        """
        Inference forward: identical to the training path but without the
        reference branch / alignment loss. See _train_forward for argument details.
        """
        assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
            f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"

        # Tokenize the noisy grid and the partial-voxel control condition.
        h = self.input_voxel(x, forzen_denoiser.input_layer, forzen_denoiser.pos_emb[None])
        cond_vox = self.input_voxel(cond['cond_partial_vox'], self.input_layer_vox_partial, forzen_denoiser.pos_emb[None]) + \
            self.input_mask(cond['cond_partial_vox_mask'], self.input_layer_mask_partial)

        cond_moge = cond['cond_scene']
        cond_dino = cond['cond_instance']
        cond_dino_masked = cond['cond_instance_masked']

        t_emb = forzen_denoiser.t_embedder(t)
        if self.share_mod:
            t_emb = forzen_denoiser.adaLN_modulation(t_emb)
        t_emb = t_emb.type(self.dtype)
        est_depth_ratio_emb = self.dpt_ratio_embedder(est_depth_ratio)
        est_depth_ratio_emb = est_depth_ratio_emb.type(self.dtype)
        # Cast everything to the torso dtype (fp16 when use_fp16).
        h = h.type(self.dtype)
        cond_control = cond_moge
        cond_control = cond_control.type(self.dtype)
        cond_vox = cond_vox.type(self.dtype)
        cond_dino = cond_dino.type(self.dtype)
        cond_dino_masked = cond_dino_masked.type(self.dtype)

        for block_index, frozen_block in enumerate(forzen_denoiser.blocks):
            h = frozen_block(h, t_emb, cond_dino_masked)
            if block_index < len(self.blocks):
                # Control branch injection, mirroring _train_forward.
                cond_vox = self.blocks[block_index](cond_vox, t_emb, est_depth_ratio_emb, cond_dino, cond_control)
                ctrl_feats = self.control_path[block_index](cond_vox)
                h = h + ctrl_feats

        h = h.type(x.dtype)

        h = F.layer_norm(h, h.shape[-1:])
        h = forzen_denoiser.out_layer(h)

        # Tokens back to a dense voxel grid.
        h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3)
        h = unpatchify(h, self.patch_size).contiguous()

        return h
threeDFixer/models/scene_structured_latent_flow.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file is modified from TRELLIS:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+ # Modifications Copyright (c) 2026 Ze-Xin Yin, Robot labs of Horizon Robotics, and D-Robotics.
6
+
7
+ from typing import *
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import numpy as np
12
+ from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
13
+ from ..modules.transformer import AbsolutePositionEmbedder
14
+ from ..modules.norm import LayerNorm32
15
+ from ..modules import sparse as sp
16
+ from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock, ModulatedSceneSparseTransformerCrossBlock
17
+ from .sparse_structure_flow import TimestepEmbedder
18
+ from .scene_sparse_structure_flow import mean_flat
19
+ from .structured_latent_flow import SparseResBlock3d, SLatFlowModel
20
+ from .sparse_elastic_mixin import SparseTransformerElasticMixin
21
+ from . import from_pretrained
22
+
23
+ class SceneSLatFlowModel(nn.Module):
24
+ def __init__(
25
+ self,
26
+ resolution: int,
27
+ in_channels: int,
28
+ cond_slat_channels: int,
29
+ model_channels: int,
30
+ cond_channels: int,
31
+ out_channels: int,
32
+ num_blocks: int,
33
+ num_heads: Optional[int] = None,
34
+ num_head_channels: Optional[int] = 64,
35
+ mlp_ratio: float = 4,
36
+ patch_size: int = 2,
37
+ num_io_res_blocks: int = 2,
38
+ io_block_channels: List[int] = None,
39
+ pe_mode: Literal["ape", "rope"] = "ape",
40
+ use_fp16: bool = False,
41
+ use_checkpoint: bool = False,
42
+ use_skip_connection: bool = True,
43
+ share_mod: bool = False,
44
+ qk_rms_norm: bool = False,
45
+ qk_rms_norm_cross: bool = False,
46
+ pretrained_flow_dit: str = None,
47
+ ):
48
+ super().__init__()
49
+ self.resolution = resolution
50
+ self.in_channels = in_channels
51
+ self.cond_slat_channels = cond_slat_channels
52
+ self.model_channels = model_channels
53
+ self.cond_channels = cond_channels
54
+ self.out_channels = out_channels
55
+ self.num_blocks = num_blocks
56
+ self.num_heads = num_heads or model_channels // num_head_channels
57
+ self.mlp_ratio = mlp_ratio
58
+ self.patch_size = patch_size
59
+ self.num_io_res_blocks = num_io_res_blocks
60
+ self.io_block_channels = io_block_channels
61
+ self.pe_mode = pe_mode
62
+ self.use_fp16 = use_fp16
63
+ self.use_checkpoint = use_checkpoint
64
+ self.use_skip_connection = use_skip_connection
65
+ self.share_mod = share_mod
66
+ self.qk_rms_norm = qk_rms_norm
67
+ self.qk_rms_norm_cross = qk_rms_norm_cross
68
+ self.dtype = torch.float16 if use_fp16 else torch.float32
69
+
70
+ if self.io_block_channels is not None:
71
+ assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2"
72
+ assert np.log2(patch_size) == len(io_block_channels), "Number of IO ResBlocks must match the number of stages"
73
+
74
+ self.vis_ratio_embedder = TimestepEmbedder(model_channels)
75
+
76
+ self.input_layer = sp.SparseLinear(in_channels, model_channels if io_block_channels is None else io_block_channels[0])
77
+ self.input_layer_cond = sp.SparseLinear(cond_slat_channels, model_channels if io_block_channels is None else io_block_channels[0])
78
+
79
+ self.input_blocks = nn.ModuleList([])
80
+ if io_block_channels is not None:
81
+ for chs, next_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]):
82
+ self.input_blocks.extend([
83
+ SparseResBlock3d(
84
+ chs,
85
+ model_channels,
86
+ out_channels=chs,
87
+ )
88
+ for _ in range(num_io_res_blocks-1)
89
+ ])
90
+ self.input_blocks.append(
91
+ SparseResBlock3d(
92
+ chs,
93
+ model_channels,
94
+ out_channels=next_chs,
95
+ downsample=True,
96
+ )
97
+ )
98
+
99
+ self.blocks = nn.ModuleList([
100
+ ModulatedSceneSparseTransformerCrossBlock(
101
+ model_channels,
102
+ cond_channels,
103
+ num_heads=self.num_heads,
104
+ mlp_ratio=self.mlp_ratio,
105
+ attn_mode='full',
106
+ use_checkpoint=self.use_checkpoint,
107
+ use_rope=(pe_mode == "rope"),
108
+ share_mod=self.share_mod,
109
+ qk_rms_norm=self.qk_rms_norm,
110
+ qk_rms_norm_cross=self.qk_rms_norm_cross,
111
+ )
112
+ for _ in range(num_blocks)
113
+ ])
114
+
115
+ self.control_path = nn.Sequential(*[
116
+ sp.SparseLinear(model_channels, model_channels) for _ in range(num_blocks)
117
+ ])
118
+
119
+ self.initialize_weights()
120
+ if pretrained_flow_dit is not None:
121
+ if pretrained_flow_dit.endswith('.pt'):
122
+ print (f'loading pretrained weight: {pretrained_flow_dit}')
123
+ model_ckpt = torch.load(pretrained_flow_dit, map_location='cpu', weights_only=True)
124
+ self.input_layer.load_state_dict(
125
+ {k.replace('input_layer.', ''): model_ckpt[k] for k in filter(lambda x: 'input_layer' in x, model_ckpt.keys())}
126
+ )
127
+ self.vis_ratio_embedder.load_state_dict(
128
+ {k.replace('t_embedder.', ''): model_ckpt[k] for k in filter(lambda x: 't_embedder' in x, model_ckpt.keys())}
129
+ )
130
+ self.input_blocks.load_state_dict(
131
+ {k.replace('input_blocks.', ''): model_ckpt[k] for k in filter(lambda x: 'input_blocks' in x, model_ckpt.keys())}
132
+ )
133
+
134
+ for block_index, module in enumerate(self.blocks):
135
+ module: ModulatedSceneSparseTransformerCrossBlock
136
+ module.load_state_dict(
137
+ {k.replace(f'blocks.{block_index}', ''): model_ckpt[k] for k in filter(lambda x: f'blocks.{block_index}' in x, model_ckpt.keys())}, strict=False
138
+ )
139
+ module.norm4.load_state_dict(module.norm1.state_dict())
140
+ module.norm5.load_state_dict(module.norm2.state_dict())
141
+ module.self_attn_vis_ratio.load_state_dict(module.self_attn.state_dict())
142
+ module.cross_attn_extra.load_state_dict(module.cross_attn.state_dict())
143
+ nn.init.constant_(module.self_attn_vis_ratio.to_out.weight, 0)
144
+ if module.self_attn_vis_ratio.to_out.bias is not None:
145
+ nn.init.constant_(module.self_attn_vis_ratio.to_out.bias, 0)
146
+ nn.init.constant_(module.cross_attn_extra.to_out.weight, 0)
147
+ if module.cross_attn_extra.to_out.bias is not None:
148
+ nn.init.constant_(module.cross_attn_extra.to_out.bias, 0)
149
+ del model_ckpt
150
+ else:
151
+ print (f'loading pretrained weight: {pretrained_flow_dit}')
152
+ pre_trained_models = from_pretrained(pretrained_flow_dit)
153
+ pre_trained_models: SLatFlowModel
154
+
155
+ self.input_layer.load_state_dict(pre_trained_models.input_layer.state_dict())
156
+ self.vis_ratio_embedder.load_state_dict(pre_trained_models.t_embedder.state_dict())
157
+ self.input_blocks.load_state_dict(pre_trained_models.input_blocks.state_dict())
158
+
159
+ for block_index, module in enumerate(self.blocks):
160
+ module: ModulatedSceneSparseTransformerCrossBlock
161
+ module.load_state_dict(pre_trained_models.blocks[block_index].state_dict(), strict=False)
162
+ module.norm4.load_state_dict(module.norm1.state_dict())
163
+ module.norm5.load_state_dict(module.norm2.state_dict())
164
+ module.self_attn_vis_ratio.load_state_dict(module.self_attn.state_dict())
165
+ module.cross_attn_extra.load_state_dict(module.cross_attn.state_dict())
166
+ nn.init.constant_(module.self_attn_vis_ratio.to_out.weight, 0)
167
+ if module.self_attn_vis_ratio.to_out.bias is not None:
168
+ nn.init.constant_(module.self_attn_vis_ratio.to_out.bias, 0)
169
+ nn.init.constant_(module.cross_attn_extra.to_out.weight, 0)
170
+ if module.cross_attn_extra.to_out.bias is not None:
171
+ nn.init.constant_(module.cross_attn_extra.to_out.bias, 0)
172
+ del pre_trained_models
173
+ if use_fp16:
174
+ self.convert_to_fp16()
175
+
176
    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        # Uses the first parameter as representative; assumes the module is not
        # sharded across devices.
        return next(self.parameters()).device
182
+
183
    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        # The IO ResBlocks, transformer blocks, and control projections are
        # converted; the input linears and vis_ratio_embedder stay untouched here.
        self.input_blocks.apply(convert_module_to_f16)
        self.blocks.apply(convert_module_to_f16)
        self.control_path.apply(convert_module_to_f16)
190
+
191
+ def convert_to_fp32(self) -> None:
192
+ """
193
+ Convert the torso of the model to float32.
194
+ """
195
+ self.input_blocks.apply(convert_module_to_f16)
196
+ self.blocks.apply(convert_module_to_f32)
197
+ self.control_path.apply(convert_module_to_f32)
198
+
199
    def initialize_weights(self) -> None:
        """Initialize all weights: Xavier for linears, zeros for adaLN and control layers."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.vis_ratio_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.vis_ratio_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        if self.share_mod:
            # NOTE(review): self.adaLN_modulation is not defined in __init__ of this
            # class — confirm that share_mod=True is actually supported here.
            nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
        else:
            for block in self.blocks:
                nn.init.constant_(block.adaLN_modulation_vis[-1].weight, 0)
                nn.init.constant_(block.adaLN_modulation_vis[-1].bias, 0)

        # Zero the control projections so the control branch initially adds
        # nothing to the frozen denoiser's residual stream.
        for block in self.control_path:
            nn.init.constant_(block.weight, 0)
            nn.init.constant_(block.bias, 0)
224
+
225
+ def forward(self, *args, **kwargs):
226
+ stage = kwargs.pop('stage', None)
227
+ if stage == 'train':
228
+ return self._train_forward(*args, **kwargs)
229
+ elif stage == 'infer':
230
+ return self._infer_forward(*args, **kwargs)
231
+ elif stage == 'infer_std':
232
+ return self._infer_std_forward(*args, **kwargs)
233
+
234
    def _input_slat(self, x: sp.SparseTensor, emb: torch.Tensor,
                    input_layer: Callable, input_blocks: Callable,
                    pos_embedder: Callable, residual_h: Callable = None
                    ):
        """
        Project a sparse latent through an input layer and a stack of IO ResBlocks.

        Args:
            x: Input sparse tensor.
            emb: Conditioning embedding passed to each ResBlock.
            input_layer: Linear projection applied first (cast to the torso dtype).
            input_blocks: Iterable of (possibly downsampling) ResBlocks.
            pos_embedder: Optional absolute positional embedder (used when pe_mode == "ape").
            residual_h: Optional callable applied to the final features.

        Returns:
            (h, skips): the processed sparse tensor and the per-block feature
            tensors collected for skip connections.
        """
        h = input_layer(x).type(self.dtype)
        skips = []
        # pack with input blocks
        for block in input_blocks:
            h = block(h, emb)
            skips.append(h.feats)

        if self.pe_mode == "ape" and pos_embedder is not None:
            # Coordinates exclude column 0 (the batch index).
            h = h + pos_embedder(h.coords[:, 1:]).type(self.dtype)

        if residual_h is not None:
            h = residual_h(h)

        return h, skips
252
+
253
    def _train_forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: Dict[str,torch.Tensor], vis_ratio: torch.Tensor,
                       forzen_denoiser: SLatFlowModel) -> Tuple[sp.SparseTensor, torch.Tensor]:
        """
        Training forward: run the frozen SLat denoiser while injecting control
        features, and compute an alignment loss against a reference pass driven
        by cond['std_cond_instance'].

        Args:
            x: Noisy sparse latent.
            t: Diffusion timesteps, embedded by the frozen denoiser.
            cond: Uses 'cond_scene', 'cond_instance', 'cond_instance_masked',
                'std_cond_instance', and 'cond_voxel_feats'.
            vis_ratio: Scalar condition embedded via vis_ratio_embedder.
            forzen_denoiser: The frozen SLatFlowModel being controlled.

        Returns:
            (denoised sparse latent, alignment loss averaged over block/batch pairs).
        """

        t_emb = forzen_denoiser.t_embedder(t)
        if forzen_denoiser.share_mod:
            t_emb = forzen_denoiser.adaLN_modulation(t_emb)
        t_emb = t_emb.type(self.dtype)

        # moge feats and image mask
        cond_moge = cond['cond_scene']
        cond_dino = cond['cond_instance']
        cond_dino_masked = cond['cond_instance_masked']
        std_cond_dino = cond['std_cond_instance']
        # voxels with projected feats
        x_feat = cond['cond_voxel_feats']

        # Cast conditions to the torso dtype (fp16 when use_fp16).
        cond_control = cond_moge
        cond_control = cond_control.type(self.dtype)
        cond_dino_masked = cond_dino_masked.type(self.dtype)
        cond_dino = cond_dino.type(self.dtype)
        std_cond_dino = std_cond_dino.type(self.dtype)

        vis_ratio_emb = self.vis_ratio_embedder(vis_ratio)
        vis_ratio_emb = vis_ratio_emb.type(self.dtype)

        # input layer of frozen part
        h, skips = self._input_slat(x, t_emb, self.input_layer,
                                    forzen_denoiser.input_blocks,
                                    forzen_denoiser.pos_embedder if self.pe_mode == "ape" else None)
        # input layer of frozen part

        # condition branch
        ctrl_h, _ = self._input_slat(x_feat, vis_ratio_emb,
                                     self.input_layer_cond, self.input_blocks,
                                     forzen_denoiser.pos_embedder if self.pe_mode == "ape" else None)
        # condition branch

        std_h = h
        align_loss = 0.0
        acount = 0
        for block_index, block in enumerate(forzen_denoiser.blocks):
            h = block(h, t_emb, cond_dino_masked)
            if block_index < self.num_blocks:
                # Control branch: evolve ctrl_h and add its projection to the stream.
                ctrl_h = self.blocks[block_index](ctrl_h, t_emb, vis_ratio_emb, cond_dino, cond_control)
                h = h + self.control_path[block_index](ctrl_h)

            # Reference (uncontrolled) branch.
            # NOTE(review): unlike the dense-grid counterpart, this is NOT under
            # torch.no_grad() — confirm whether gradients should flow here.
            std_h = block(std_h, t_emb, std_cond_dino)

            std_h: sp.SparseTensor
            h: sp.SparseTensor
            # Negative cosine similarity between controlled and reference features,
            # accumulated per batch element.
            for batch_std_h, batch_h in zip(sp.sparse_unbind(std_h, dim=0), sp.sparse_unbind(h, dim=0)):
                acount += 1
                reference_feats = batch_std_h.feats
                source_feats = batch_h.feats
                z_tilde_j = torch.nn.functional.normalize(source_feats, dim=-1, eps=1e-6)
                z_j = torch.nn.functional.normalize(reference_feats, dim=-1, eps=1e-6)
                align_loss += mean_flat(-(z_j * z_tilde_j).sum(dim=-1))
        align_loss /= acount

        # unpack with output blocks
        for block, skip in zip(forzen_denoiser.out_blocks, reversed(skips)):
            if self.use_skip_connection:
                h = block(h.replace(torch.cat([h.feats, skip], dim=1)), t_emb)
            else:
                h = block(h, t_emb)

        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = forzen_denoiser.out_layer(h.type(x.dtype))
        return h, align_loss
322
+
323
+ def _infer_forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: Dict[str,torch.Tensor], vis_ratio: torch.Tensor,
324
+ forzen_denoiser: SLatFlowModel) -> sp.SparseTensor:
325
+
326
+ t_emb = forzen_denoiser.t_embedder(t)
327
+ if forzen_denoiser.share_mod:
328
+ t_emb = forzen_denoiser.adaLN_modulation(t_emb)
329
+ t_emb = t_emb.type(self.dtype)
330
+
331
+ # moge feats and image mask
332
+ cond_moge = cond['cond_scene']
333
+ cond_dino = cond['cond_instance']
334
+ cond_dino_masked = cond['cond_instance_masked']
335
+ # voxels with projected feats
336
+ x_feat = cond['cond_voxel_feats']
337
+
338
+ neg_infer = cond.pop("neg_infer", False)
339
+
340
+ cond_control = cond_moge
341
+ cond_control = cond_control.type(self.dtype)
342
+ cond_dino = cond_dino.type(self.dtype)
343
+ cond_dino_masked = cond_dino_masked.type(self.dtype)
344
+
345
+ vis_ratio_emb = self.vis_ratio_embedder(vis_ratio)
346
+ vis_ratio_emb = vis_ratio_emb.type(self.dtype)
347
+
348
+ # input layer of frozen part
349
+ h, skips = self._input_slat(x, t_emb, self.input_layer,
350
+ forzen_denoiser.input_blocks,
351
+ forzen_denoiser.pos_embedder if self.pe_mode == "ape" else None)
352
+ # input layer of frozen part
353
+
354
+ # condition branch
355
+ if not neg_infer:
356
+ ctrl_h, _ = self._input_slat(x_feat, vis_ratio_emb, self.input_layer_cond,
357
+ forzen_denoiser.input_blocks,
358
+ forzen_denoiser.pos_embedder if self.pe_mode == "ape" else None)
359
+ # condition branch
360
+
361
+ for block_index, block in enumerate(forzen_denoiser.blocks):
362
+ h = block(h, t_emb, cond_dino_masked)
363
+ if not neg_infer:
364
+ if block_index < self.num_blocks:
365
+ ctrl_h = self.blocks[block_index](ctrl_h, t_emb, vis_ratio_emb, cond_dino, cond_control)
366
+ h = h + self.control_path[block_index](ctrl_h)
367
+
368
+ # unpack with output blocks
369
+ for block, skip in zip(forzen_denoiser.out_blocks, reversed(skips)):
370
+ if self.use_skip_connection:
371
+ h = block(h.replace(torch.cat([h.feats, skip], dim=1)), t_emb)
372
+ else:
373
+ h = block(h, t_emb)
374
+
375
+ h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
376
+ h = forzen_denoiser.out_layer(h.type(x.dtype))
377
+ return h
378
+
379
+ def _infer_std_forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: Dict[str,torch.Tensor], vis_ratio: torch.Tensor,
380
+ forzen_denoiser: SLatFlowModel) -> sp.SparseTensor:
381
+
382
+ t_emb = forzen_denoiser.t_embedder(t)
383
+ if forzen_denoiser.share_mod:
384
+ t_emb = forzen_denoiser.adaLN_modulation(t_emb)
385
+ t_emb = t_emb.type(self.dtype)
386
+
387
+ cond_dino = cond['std_cond_instance']
388
+ cond_dino = cond_dino.type(self.dtype)
389
+
390
+ # input layer of frozen part
391
+ h, skips = self._input_slat(x, t_emb, forzen_denoiser.input_layer,
392
+ forzen_denoiser.input_blocks,
393
+ forzen_denoiser.pos_embedder if self.pe_mode == "ape" else None)
394
+ # input layer of frozen part
395
+
396
+ for block_index, block in enumerate(forzen_denoiser.blocks):
397
+ h = block(h, t_emb, cond_dino)
398
+
399
+ # unpack with output blocks
400
+ for block, skip in zip(forzen_denoiser.out_blocks, reversed(skips)):
401
+ if self.use_skip_connection:
402
+ h = block(h.replace(torch.cat([h.feats, skip], dim=1)), t_emb)
403
+ else:
404
+ h = block(h, t_emb)
405
+
406
+ h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
407
+ h = forzen_denoiser.out_layer(h.type(x.dtype))
408
+ return h
409
+
410
class ElasticSceneSLatFlowModel(SparseTransformerElasticMixin, SceneSLatFlowModel):
    """
    Scene SLat Flow Model with elastic memory management.

    Mixes SparseTransformerElasticMixin into SceneSLatFlowModel so gradient
    checkpointing can be enabled on a fraction of the transformer blocks,
    trading compute for VRAM during training. No new behavior is added here.
    """
    pass
threeDFixer/models/sparse_elastic_mixin.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ from contextlib import contextmanager
7
+ from typing import *
8
+ import math
9
+ from ..modules import sparse as sp
10
+ from ..utils.elastic_utils import ElasticModuleMixin
11
+
12
+
13
class SparseTransformerElasticMixin(ElasticModuleMixin):
    """
    Elastic-memory mixin for sparse transformer models.

    Reports the model's "input size" as the number of active voxels and
    temporarily enables gradient checkpointing on a leading fraction of
    ``self.blocks`` proportional to the requested memory reduction.
    """

    def _get_input_size(self, x: sp.SparseTensor, *args, **kwargs):
        # Memory scales with the number of sparse entries, not batch size.
        return x.feats.shape[0]

    @contextmanager
    def with_mem_ratio(self, mem_ratio=1.0):
        """
        Context manager that checkpoints enough blocks to approximate
        ``mem_ratio`` of full activation memory; yields the ratio actually
        achieved (quantized to whole blocks).
        """
        if mem_ratio == 1.0:
            yield 1.0
            return
        num_blocks = len(self.blocks)
        # At least one block is checkpointed once mem_ratio < 1.
        num_checkpoint_blocks = min(math.ceil((1 - mem_ratio) * num_blocks) + 1, num_blocks)
        exact_mem_ratio = 1 - (num_checkpoint_blocks - 1) / num_blocks
        for i in range(num_blocks):
            self.blocks[i].use_checkpoint = i < num_checkpoint_blocks
        # BUGFIX: restore flags in a finally block so an exception raised
        # inside the `with` body cannot leave checkpointing permanently on.
        try:
            yield exact_mem_ratio
        finally:
            for i in range(num_blocks):
                self.blocks[i].use_checkpoint = False
threeDFixer/models/sparse_structure_flow.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/microsoft/TRELLIS
2
+ # Original license: MIT
3
+ # Copyright (c) the TRELLIS authors
4
+ # Minor modifications by Ze-Xin Yin and Robot labs of Horizon Robotics, 2026.
5
+
6
+ from typing import *
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import numpy as np
11
+ from . import from_pretrained
12
+ from ..modules.utils import convert_module_to_f16, convert_module_to_f32
13
+ from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock
14
+ from ..modules.spatial import patchify, unpatchify
15
+
16
+
17
+ class TimestepEmbedder(nn.Module):
18
+ """
19
+ Embeds scalar timesteps into vector representations.
20
+ """
21
+ def __init__(self, hidden_size, frequency_embedding_size=256):
22
+ super().__init__()
23
+ self.mlp = nn.Sequential(
24
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
25
+ nn.SiLU(),
26
+ nn.Linear(hidden_size, hidden_size, bias=True),
27
+ )
28
+ self.frequency_embedding_size = frequency_embedding_size
29
+
30
+ @staticmethod
31
+ def timestep_embedding(t, dim, max_period=10000):
32
+ """
33
+ Create sinusoidal timestep embeddings.
34
+
35
+ Args:
36
+ t: a 1-D Tensor of N indices, one per batch element.
37
+ These may be fractional.
38
+ dim: the dimension of the output.
39
+ max_period: controls the minimum frequency of the embeddings.
40
+
41
+ Returns:
42
+ an (N, D) Tensor of positional embeddings.
43
+ """
44
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
45
+ half = dim // 2
46
+ freqs = torch.exp(
47
+ -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
48
+ ).to(device=t.device)
49
+ args = t[:, None].float() * freqs[None]
50
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
51
+ if dim % 2:
52
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
53
+ return embedding
54
+
55
+ def forward(self, t):
56
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
57
+ t_emb = self.mlp(t_freq)
58
+ return t_emb
59
+
60
+
61
class SparseStructureFlowModel(nn.Module):
    """
    Flow-matching transformer over dense patchified sparse-structure latents.

    The input volume (B, in_channels, R, R, R) is patchified into tokens,
    processed by timestep-modulated cross-attention transformer blocks
    conditioned on `cond`, and un-patchified back to a volume of
    out_channels. Optionally loads pretrained weights and runs the torso
    in fp16.
    """
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        model_channels: int,
        cond_channels: int,
        out_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        patch_size: int = 2,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        share_mod: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        pretrained_ss_flow_dit: str = None,  # .pt checkpoint path, or an id resolvable by from_pretrained
    ):
        super().__init__()
        self.resolution = resolution
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.cond_channels = cond_channels
        self.out_channels = out_channels
        self.num_blocks = num_blocks
        # If num_heads is not given, derive it from per-head channel width.
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio
        self.patch_size = patch_size
        self.pe_mode = pe_mode
        self.use_fp16 = use_fp16
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        self.qk_rms_norm = qk_rms_norm
        self.qk_rms_norm_cross = qk_rms_norm_cross
        self.dtype = torch.float16 if use_fp16 else torch.float32

        self.t_embedder = TimestepEmbedder(model_channels)
        # share_mod: one adaLN modulation shared across all blocks instead of
        # per-block modulation layers.
        if share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(model_channels, 6 * model_channels, bias=True)
            )

        if pe_mode == "ape":
            # Absolute positional embedding, precomputed once for the full
            # patch grid and stored as a (non-trainable) buffer.
            pos_embedder = AbsolutePositionEmbedder(model_channels, 3)
            coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution // patch_size] * 3], indexing='ij')
            coords = torch.stack(coords, dim=-1).reshape(-1, 3)
            pos_emb = pos_embedder(coords)
            self.register_buffer("pos_emb", pos_emb)

        # One patch (patch_size^3 voxels) -> one token.
        self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels)

        self.blocks = nn.ModuleList([
            ModulatedTransformerCrossBlock(
                model_channels,
                cond_channels,
                num_heads=self.num_heads,
                mlp_ratio=self.mlp_ratio,
                attn_mode='full',
                use_checkpoint=self.use_checkpoint,
                use_rope=(pe_mode == "rope"),
                share_mod=share_mod,
                qk_rms_norm=self.qk_rms_norm,
                qk_rms_norm_cross=self.qk_rms_norm_cross,
            )
            for _ in range(num_blocks)
        ])

        self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3)

        # Order matters: random init first, then (optionally) overwrite with
        # pretrained weights, then fp16 conversion.
        self.initialize_weights()
        if pretrained_ss_flow_dit is not None:
            if pretrained_ss_flow_dit.endswith('.pt'):
                print (f'loading pretrained weight: {pretrained_ss_flow_dit}')
                model_ckpt = torch.load(pretrained_ss_flow_dit, map_location='cpu', weights_only=True)
                self.load_state_dict(model_ckpt)
                del model_ckpt
            else:
                print (f'loading pretrained weight: {pretrained_ss_flow_dit}')
                pre_trained_models = from_pretrained(pretrained_ss_flow_dit)
                pre_trained_models: SparseStructureFlowModel
                self.load_state_dict(pre_trained_models.state_dict())
                del pre_trained_models
        if use_fp16:
            self.convert_to_fp16()

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        self.blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        self.blocks.apply(convert_module_to_f32)

    def initialize_weights(self) -> None:
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        if self.share_mod:
            nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
        else:
            for block in self.blocks:
                nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
                nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """
        Predict the flow for volume `x` at timesteps `t` under condition `cond`.
        Input and output are (B, C, R, R, R) volumes in x's dtype.
        """
        assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
            f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"

        # (B, C, R, R, R) -> (B, C*p^3, N) -> (B, N, C*p^3) token sequence.
        h = patchify(x, self.patch_size)
        h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous()

        h = self.input_layer(h)
        h = h + self.pos_emb[None]
        t_emb = self.t_embedder(t)
        if self.share_mod:
            t_emb = self.adaLN_modulation(t_emb)
        t_emb = t_emb.type(self.dtype)
        h = h.type(self.dtype)
        cond = cond.type(self.dtype)
        for block in self.blocks:
            h = block(h, t_emb, cond)
        # Back to input dtype before the (fp32) output head.
        h = h.type(x.dtype)
        h = F.layer_norm(h, h.shape[-1:])
        h = self.out_layer(h)

        # Tokens -> volume. Note: view here only *splits* the token dim, which
        # is stride-compatible even after permute.
        h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3)
        h = unpatchify(h, self.patch_size).contiguous()

        return h
threeDFixer/models/sparse_structure_vae.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/microsoft/TRELLIS
2
+ # Original license: MIT
3
+ # Copyright (c) the TRELLIS authors
4
+ # Minor modifications by Ze-Xin Yin and Robot labs of Horizon Robotics, 2026.
5
+
6
+ from typing import *
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from ..modules.norm import GroupNorm32, ChannelLayerNorm32
11
+ from ..modules.spatial import pixel_shuffle_3d
12
+ from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
13
+ from . import from_pretrained
14
+
15
+
16
def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module:
    """
    Build a normalization layer by name.

    "group" yields a 32-group GroupNorm32; "layer" yields a channel-wise
    ChannelLayerNorm32. Extra args/kwargs are forwarded to the layer.

    Raises:
        ValueError: if `norm_type` is neither "group" nor "layer".
    """
    if norm_type == "layer":
        return ChannelLayerNorm32(*args, **kwargs)
    if norm_type == "group":
        # GroupNorm32 fixes the group count at 32.
        return GroupNorm32(32, *args, **kwargs)
    raise ValueError(f"Invalid norm type {norm_type}")
26
+
27
+
28
class ResBlock3d(nn.Module):
    """
    Pre-activation 3D residual block: two (norm -> SiLU -> conv3x3) stages
    plus a skip path, with the second conv zero-initialized so the block
    starts as an identity-plus-skip mapping.
    """
    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        norm_type: Literal["group", "layer"] = "layer",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels

        self.norm1 = norm_layer(norm_type, channels)
        self.norm2 = norm_layer(norm_type, self.out_channels)
        self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
        # Zero-init keeps the residual branch silent at the start of training.
        self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1))
        # 1x1 conv only when the channel count changes; identity otherwise.
        self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both conv stages and add the (possibly projected) input."""
        out = self.conv1(F.silu(self.norm1(x)))
        out = self.conv2(F.silu(self.norm2(out)))
        return out + self.skip_connection(x)
54
+
55
+
56
class DownsampleBlock3d(nn.Module):
    """
    Halve every spatial dimension of a 5D volume, either with a strided
    2x2x2 convolution ("conv", may change channel count) or with average
    pooling ("avgpool", channel-preserving only).
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mode: Literal["conv", "avgpool"] = "conv",
    ):
        assert mode in ["conv", "avgpool"], f"Invalid mode {mode}"

        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        if mode == "conv":
            # kernel == stride == 2: non-overlapping learned downsampling.
            self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2)
        elif mode == "avgpool":
            assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return x downsampled by 2 along each spatial axis."""
        conv = getattr(self, "conv", None)
        if conv is None:
            return F.avg_pool3d(x, 2)
        return conv(x)
79
+
80
+
81
class UpsampleBlock3d(nn.Module):
    """
    Double every spatial dimension of a 5D volume, either with a conv +
    3D pixel-shuffle ("conv", may change channel count) or with nearest
    interpolation ("nearest", channel-preserving only).
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mode: Literal["conv", "nearest"] = "conv",
    ):
        assert mode in ["conv", "nearest"], f"Invalid mode {mode}"

        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        if mode == "conv":
            # 8x channels so a factor-2 3D pixel shuffle lands on out_channels.
            self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1)
        elif mode == "nearest":
            assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return x upsampled by 2 along each spatial axis."""
        if not hasattr(self, "conv"):
            return F.interpolate(x, scale_factor=2, mode="nearest")
        return pixel_shuffle_3d(self.conv(x), 2)
105
+
106
+
107
class SparseStructureEncoder(nn.Module):
    """
    Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3).

    Maps a dense occupancy volume to a Gaussian latent (mean/logvar), with
    optional reparameterized sampling.

    Args:
        in_channels (int): Channels of the input.
        latent_channels (int): Channels of the latent representation.
        num_res_blocks (int): Number of residual blocks at each resolution.
        channels (List[int]): Channels of the encoder blocks.
        num_res_blocks_middle (int): Number of residual blocks in the middle.
        norm_type (Literal["group", "layer"]): Type of normalization layer.
        use_fp16 (bool): Whether to use FP16.
        pretrained_ss_enc (str): Optional pretrained weights — a .pt
            checkpoint path, or an identifier resolvable by from_pretrained.
    """
    def __init__(
        self,
        in_channels: int,
        latent_channels: int,
        num_res_blocks: int,
        channels: List[int],
        num_res_blocks_middle: int = 2,
        norm_type: Literal["group", "layer"] = "layer",
        use_fp16: bool = False,
        pretrained_ss_enc: str = None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.latent_channels = latent_channels
        self.num_res_blocks = num_res_blocks
        self.channels = channels
        self.num_res_blocks_middle = num_res_blocks_middle
        self.norm_type = norm_type
        self.use_fp16 = use_fp16
        self.dtype = torch.float16 if use_fp16 else torch.float32

        self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1)

        # Per-resolution residual stacks, each stage followed by a 2x
        # downsample except the last.
        self.blocks = nn.ModuleList([])
        for i, ch in enumerate(channels):
            self.blocks.extend([
                ResBlock3d(ch, ch)
                for _ in range(num_res_blocks)
            ])
            if i < len(channels) - 1:
                self.blocks.append(
                    DownsampleBlock3d(ch, channels[i+1])
                )

        self.middle_block = nn.Sequential(*[
            ResBlock3d(channels[-1], channels[-1])
            for _ in range(num_res_blocks_middle)
        ])

        # Head emits 2*latent_channels: mean and logvar stacked on dim 1.
        self.out_layer = nn.Sequential(
            norm_layer(norm_type, channels[-1]),
            nn.SiLU(),
            nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1)
        )

        if pretrained_ss_enc is not None:
            if pretrained_ss_enc.endswith('.pt'):
                print (f'loading pretrained weight: {pretrained_ss_enc}')
                model_ckpt = torch.load(pretrained_ss_enc, map_location='cpu', weights_only=True)
                self.load_state_dict(model_ckpt)
                del model_ckpt
            else:
                print (f'loading pretrained weight: {pretrained_ss_enc}')
                pre_trained_models = from_pretrained(pretrained_ss_enc)
                pre_trained_models: SparseStructureEncoder
                self.load_state_dict(pre_trained_models.state_dict())
                del pre_trained_models
        if use_fp16:
            self.convert_to_fp16()

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        self.use_fp16 = True
        self.dtype = torch.float16
        self.blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        self.use_fp16 = False
        self.dtype = torch.float32
        self.blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    def forward(self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False) -> torch.Tensor:
        """
        Encode `x` to a latent.

        Args:
            x: input volume (channels-first 5D tensor).
            sample_posterior: if True, reparameterized sample; else the mean.
            return_raw: if True, also return (mean, logvar).
        """
        h = self.input_layer(x)
        # Torso runs in self.dtype (fp16 when converted); in/out layers stay
        # in the input dtype.
        h = h.type(self.dtype)

        for block in self.blocks:
            h = block(h)
        h = self.middle_block(h)

        h = h.type(x.dtype)
        h = self.out_layer(h)

        mean, logvar = h.chunk(2, dim=1)

        if sample_posterior:
            # Reparameterization trick: z = mu + sigma * eps.
            std = torch.exp(0.5 * logvar)
            z = mean + std * torch.randn_like(std)
        else:
            z = mean

        if return_raw:
            return z, mean, logvar
        return z
227
+
228
+
229
class SparseStructureDecoder(nn.Module):
    """
    Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3).

    Mirror of the encoder: maps a latent volume back to an occupancy volume,
    upsampling by 2x between channel stages.

    Args:
        out_channels (int): Channels of the output.
        latent_channels (int): Channels of the latent representation.
        num_res_blocks (int): Number of residual blocks at each resolution.
        channels (List[int]): Channels of the decoder blocks.
        num_res_blocks_middle (int): Number of residual blocks in the middle.
        norm_type (Literal["group", "layer"]): Type of normalization layer.
        use_fp16 (bool): Whether to use FP16.
    """
    def __init__(
        self,
        out_channels: int,
        latent_channels: int,
        num_res_blocks: int,
        channels: List[int],
        num_res_blocks_middle: int = 2,
        norm_type: Literal["group", "layer"] = "layer",
        use_fp16: bool = False,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.latent_channels = latent_channels
        self.num_res_blocks = num_res_blocks
        self.channels = channels
        self.num_res_blocks_middle = num_res_blocks_middle
        self.norm_type = norm_type
        self.use_fp16 = use_fp16
        self.dtype = torch.float16 if use_fp16 else torch.float32

        self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1)

        self.middle_block = nn.Sequential(*[
            ResBlock3d(channels[0], channels[0])
            for _ in range(num_res_blocks_middle)
        ])

        # Per-resolution residual stacks, each stage followed by a 2x
        # upsample except the last.
        self.blocks = nn.ModuleList([])
        for i, ch in enumerate(channels):
            self.blocks.extend([
                ResBlock3d(ch, ch)
                for _ in range(num_res_blocks)
            ])
            if i < len(channels) - 1:
                self.blocks.append(
                    UpsampleBlock3d(ch, channels[i+1])
                )

        self.out_layer = nn.Sequential(
            norm_layer(norm_type, channels[-1]),
            nn.SiLU(),
            nn.Conv3d(channels[-1], out_channels, 3, padding=1)
        )

        if use_fp16:
            self.convert_to_fp16()

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        self.use_fp16 = True
        self.dtype = torch.float16
        self.blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        self.use_fp16 = False
        self.dtype = torch.float32
        self.blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode latent volume `x` to an out_channels volume in x's dtype."""
        h = self.input_layer(x)

        # Torso runs in self.dtype; in/out layers stay in the input dtype.
        h = h.type(self.dtype)

        h = self.middle_block(h)
        for block in self.blocks:
            h = block(h)

        h = h.type(x.dtype)
        h = self.out_layer(h)
        return h
threeDFixer/models/structured_latent_flow.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from https://github.com/microsoft/TRELLIS
2
+ # Original license: MIT
3
+ # Copyright (c) the TRELLIS authors
4
+ # Minor modifications by Ze-Xin Yin and Robot labs of Horizon Robotics, 2026.
5
+
6
+ from typing import *
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import numpy as np
11
+ from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
12
+ from ..modules.transformer import AbsolutePositionEmbedder
13
+ from ..modules.norm import LayerNorm32
14
+ from ..modules import sparse as sp
15
+ from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock
16
+ from .sparse_structure_flow import TimestepEmbedder
17
+ from .sparse_elastic_mixin import SparseTransformerElasticMixin
18
+ from . import from_pretrained
19
+
20
+
21
class SparseResBlock3d(nn.Module):
    """
    Residual block on sparse 3D tensors with timestep-embedding modulation.

    Structure: optional 2x down/upsample, then norm -> SiLU -> conv,
    scale/shift modulation from the embedding (FiLM-style), norm -> SiLU ->
    zero-init conv, plus a (linear) skip connection.
    """
    def __init__(
        self,
        channels: int,
        emb_channels: int,
        out_channels: Optional[int] = None,
        downsample: bool = False,
        upsample: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.out_channels = out_channels or channels
        self.downsample = downsample
        self.upsample = upsample

        assert not (downsample and upsample), "Cannot downsample and upsample at the same time"

        # norm2 has no affine params: its scale/shift come from emb_layers.
        self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
        self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
        # Zero-init so the residual branch is silent at initialization.
        self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_channels, 2 * self.out_channels, bias=True),
        )
        self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
        self.updown = None
        if self.downsample:
            self.updown = sp.SparseDownsample(2)
        elif self.upsample:
            self.updown = sp.SparseUpsample(2)

    def _updown(self, x: sp.SparseTensor) -> sp.SparseTensor:
        # Apply the resolution change (if any) before both branches.
        if self.updown is not None:
            x = self.updown(x)
        return x

    def forward(self, x: sp.SparseTensor, emb: torch.Tensor) -> sp.SparseTensor:
        """
        Args:
            x: sparse input features.
            emb: per-sample embedding (e.g. timestep) producing scale/shift.
        """
        emb_out = self.emb_layers(emb).type(x.dtype)
        scale, shift = torch.chunk(emb_out, 2, dim=1)

        x = self._updown(x)
        h = x.replace(self.norm1(x.feats))
        h = h.replace(F.silu(h.feats))
        h = self.conv1(h)
        # Modulate the (affine-free) normalized features; SparseTensor
        # arithmetic presumably broadcasts scale/shift per sample — see the
        # sparse module for the exact semantics.
        h = h.replace(self.norm2(h.feats)) * (1 + scale) + shift
        h = h.replace(F.silu(h.feats))
        h = self.conv2(h)
        h = h + self.skip_connection(x)

        return h
73
+
74
+
75
class SLatFlowModel(nn.Module):
    """
    Flow-matching transformer over structured (sparse) latents.

    A U-Net-shaped model: optional sparse ResBlock input stages downsample
    by the patch size, a torso of timestep-modulated cross-attention
    transformer blocks conditions on `cond`, and mirrored output stages
    upsample back, optionally with skip connections. Supports loading
    pretrained weights and fp16 torso execution.
    """
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        model_channels: int,
        cond_channels: int,
        out_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        patch_size: int = 2,
        num_io_res_blocks: int = 2,
        io_block_channels: List[int] = None,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        use_skip_connection: bool = True,
        share_mod: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        pretrained_flow_dit: str = None,  # .pt checkpoint path, or an id resolvable by from_pretrained
    ):
        super().__init__()
        self.resolution = resolution
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.cond_channels = cond_channels
        self.out_channels = out_channels
        self.num_blocks = num_blocks
        # If num_heads is not given, derive it from per-head channel width.
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio
        self.patch_size = patch_size
        self.num_io_res_blocks = num_io_res_blocks
        self.io_block_channels = io_block_channels
        self.pe_mode = pe_mode
        self.use_fp16 = use_fp16
        self.use_checkpoint = use_checkpoint
        self.use_skip_connection = use_skip_connection
        self.share_mod = share_mod
        self.qk_rms_norm = qk_rms_norm
        self.qk_rms_norm_cross = qk_rms_norm_cross
        self.dtype = torch.float16 if use_fp16 else torch.float32

        if self.io_block_channels is not None:
            # One IO stage per factor-of-2 in patch_size.
            assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2"
            assert np.log2(patch_size) == len(io_block_channels), "Number of IO ResBlocks must match the number of stages"

        self.t_embedder = TimestepEmbedder(model_channels)
        # share_mod: one adaLN modulation shared across all blocks.
        if share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(model_channels, 6 * model_channels, bias=True)
            )

        if pe_mode == "ape":
            self.pos_embedder = AbsolutePositionEmbedder(model_channels)

        self.input_layer = sp.SparseLinear(in_channels, model_channels if io_block_channels is None else io_block_channels[0])

        # Downsampling input stages: (num_io_res_blocks-1) same-res blocks,
        # then one downsampling block per stage.
        self.input_blocks = nn.ModuleList([])
        if io_block_channels is not None:
            for chs, next_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]):
                self.input_blocks.extend([
                    SparseResBlock3d(
                        chs,
                        model_channels,
                        out_channels=chs,
                    )
                    for _ in range(num_io_res_blocks-1)
                ])
                self.input_blocks.append(
                    SparseResBlock3d(
                        chs,
                        model_channels,
                        out_channels=next_chs,
                        downsample=True,
                    )
                )

        self.blocks = nn.ModuleList([
            ModulatedSparseTransformerCrossBlock(
                model_channels,
                cond_channels,
                num_heads=self.num_heads,
                mlp_ratio=self.mlp_ratio,
                attn_mode='full',
                use_checkpoint=self.use_checkpoint,
                use_rope=(pe_mode == "rope"),
                share_mod=self.share_mod,
                qk_rms_norm=self.qk_rms_norm,
                qk_rms_norm_cross=self.qk_rms_norm_cross,
            )
            for _ in range(num_blocks)
        ])

        # Mirrored upsampling output stages; channel widths doubled when
        # skip connections concatenate encoder features.
        self.out_blocks = nn.ModuleList([])
        if io_block_channels is not None:
            for chs, prev_chs in zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))):
                self.out_blocks.append(
                    SparseResBlock3d(
                        prev_chs * 2 if self.use_skip_connection else prev_chs,
                        model_channels,
                        out_channels=chs,
                        upsample=True,
                    )
                )
                self.out_blocks.extend([
                    SparseResBlock3d(
                        chs * 2 if self.use_skip_connection else chs,
                        model_channels,
                        out_channels=chs,
                    )
                    for _ in range(num_io_res_blocks-1)
                ])

        self.out_layer = sp.SparseLinear(model_channels if io_block_channels is None else io_block_channels[0], out_channels)

        # Order matters: random init first, then (optionally) overwrite with
        # pretrained weights, then fp16 conversion.
        self.initialize_weights()
        if pretrained_flow_dit is not None:
            if pretrained_flow_dit.endswith('.pt'):
                print (f'loading pretrained weight: {pretrained_flow_dit}')
                model_ckpt = torch.load(pretrained_flow_dit, map_location='cpu', weights_only=True)
                self.load_state_dict(model_ckpt)
                del model_ckpt
            else:
                print (f'loading pretrained weight: {pretrained_flow_dit}')
                pre_trained_models: SLatFlowModel
                pre_trained_models = from_pretrained(pretrained_flow_dit)
                self.load_state_dict(pre_trained_models.state_dict())
                del pre_trained_models
        if use_fp16:
            self.convert_to_fp16()

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.blocks.apply(convert_module_to_f16)
        self.out_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.blocks.apply(convert_module_to_f32)
        self.out_blocks.apply(convert_module_to_f32)

    def initialize_weights(self) -> None:
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        if self.share_mod:
            nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
        else:
            for block in self.blocks:
                nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
                nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor) -> sp.SparseTensor:
        """
        Predict the flow for sparse latent `x` at timesteps `t` under `cond`.
        Returns a sparse tensor in x's dtype.
        """
        h = self.input_layer(x).type(self.dtype)
        t_emb = self.t_embedder(t)
        if self.share_mod:
            t_emb = self.adaLN_modulation(t_emb)
        t_emb = t_emb.type(self.dtype)
        cond = cond.type(self.dtype)

        skips = []
        # pack with input blocks
        for block in self.input_blocks:
            h = block(h, t_emb)
            skips.append(h.feats)

        if self.pe_mode == "ape":
            # coords[:, 0] is the batch index; positions use coords[:, 1:].
            h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype)
        for block in self.blocks:
            h = block(h, t_emb, cond)

        # unpack with output blocks
        for block, skip in zip(self.out_blocks, reversed(skips)):
            if self.use_skip_connection:
                h = block(h.replace(torch.cat([h.feats, skip], dim=1)), t_emb)
            else:
                h = block(h, t_emb)

        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.out_layer(h.type(x.dtype))
        return h
288
+
289
+
290
class ElasticSLatFlowModel(SparseTransformerElasticMixin, SLatFlowModel):
    """
    SLat Flow Model with elastic memory management.

    Mixes SparseTransformerElasticMixin into SLatFlowModel so gradient
    checkpointing can be enabled on a fraction of the transformer blocks,
    trading compute for VRAM. Used for training with low VRAM.
    """
    pass
threeDFixer/models/structured_latent_vae/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ from .encoder import SLatEncoder, ElasticSLatEncoder
7
+ from .decoder_gs import SLatGaussianDecoder, ElasticSLatGaussianDecoder
8
+ from .decoder_rf import SLatRadianceFieldDecoder, ElasticSLatRadianceFieldDecoder
9
+ from .decoder_mesh import SLatMeshDecoder, ElasticSLatMeshDecoder
threeDFixer/models/structured_latent_vae/base.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ from typing import *
7
+ import torch
8
+ import torch.nn as nn
9
+ from ...modules.utils import convert_module_to_f16, convert_module_to_f32
10
+ from ...modules import sparse as sp
11
+ from ...modules.transformer import AbsolutePositionEmbedder
12
+ from ...modules.sparse.transformer import SparseTransformerBlock
13
+
14
+
15
def block_attn_config(self):
    """
    Yield the per-block attention configuration of the model.

    For each of ``self.num_blocks`` transformer blocks, yields a tuple
    ``(attn_mode, window_size, shift_sequence, shift_window, serialize_mode)``
    suitable for constructing a ``SparseTransformerBlock``.

    Raises:
        ValueError: If ``self.attn_mode`` is not one of the supported modes.
    """
    for i in range(self.num_blocks):
        if self.attn_mode == "shift_window":
            # Alternate blocks shift the 3D window by 16 voxels on each axis.
            yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER
        elif self.attn_mode == "shift_sequence":
            # Alternate blocks shift the serialized sequence by half a window.
            yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER
        elif self.attn_mode == "shift_order":
            # Cycle through the four serialization orders across blocks.
            yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4]
        elif self.attn_mode == "full":
            yield "full", None, None, None, None
        elif self.attn_mode == "swin":
            # Swin-style windowed attention with alternating half-window shifts.
            yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None
        else:
            # Previously an unknown mode silently yielded nothing, producing a
            # model with zero transformer blocks; fail loudly instead.
            raise ValueError(f"Unknown attention mode: {self.attn_mode}")
30
+
31
+
32
class SparseTransformerBase(nn.Module):
    """
    Sparse Transformer without output layers.
    Serve as the base class for encoder and decoder.
    """
    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.num_blocks = num_blocks
        self.window_size = window_size
        # Derive the head count from the per-head width when not given explicitly.
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio
        self.attn_mode = attn_mode
        self.pe_mode = pe_mode
        self.use_fp16 = use_fp16
        self.use_checkpoint = use_checkpoint
        self.qk_rms_norm = qk_rms_norm
        self.dtype = torch.float16 if use_fp16 else torch.float32

        if pe_mode == "ape":
            self.pos_embedder = AbsolutePositionEmbedder(model_channels)

        self.input_layer = sp.SparseLinear(in_channels, model_channels)
        # block_attn_config yields one (attn_mode, window_size, shift_sequence,
        # shift_window, serialize_mode) tuple per block; the loop variables in
        # the comprehension below come from that generator and shadow the
        # constructor parameters of the same names on purpose.
        self.blocks = nn.ModuleList([
            SparseTransformerBlock(
                model_channels,
                num_heads=self.num_heads,
                mlp_ratio=self.mlp_ratio,
                attn_mode=attn_mode,
                window_size=window_size,
                shift_sequence=shift_sequence,
                shift_window=shift_window,
                serialize_mode=serialize_mode,
                use_checkpoint=self.use_checkpoint,
                use_rope=(pe_mode == "rope"),
                qk_rms_norm=self.qk_rms_norm,
            )
            for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self)
        ])

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        self.blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        self.blocks.apply(convert_module_to_f32)

    def initialize_weights(self) -> None:
        """Xavier-initialize every Linear layer; zero its bias if present."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """Embed the input, add positional information, and run all blocks."""
        h = self.input_layer(x)
        if self.pe_mode == "ape":
            # coords[:, 0] is the batch index; embed only the xyz coordinates.
            h = h + self.pos_embedder(x.coords[:, 1:])
        h = h.type(self.dtype)
        for block in self.blocks:
            h = block(h)
        return h
threeDFixer/models/structured_latent_vae/decoder_gs.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ from typing import *
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from ...modules import sparse as sp
11
+ from ...utils.random_utils import hammersley_sequence
12
+ from .base import SparseTransformerBase
13
+ from ...representations import Gaussian
14
+ from ..sparse_elastic_mixin import SparseTransformerElasticMixin
15
+ from .. import from_pretrained
16
+
17
+
18
class SLatGaussianDecoder(SparseTransformerBase):
    """
    Sparse transformer decoder mapping structured latents to per-voxel
    3D Gaussian splat parameters.
    """
    def __init__(
        self,
        resolution: int,
        model_channels: int,
        latent_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
        representation_config: dict = None,
        pretrained_gs_dec: str = None,
    ):
        super().__init__(
            in_channels=latent_channels,
            model_channels=model_channels,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
        )
        self.resolution = resolution
        self.rep_config = representation_config
        self._calc_layout()
        self.out_layer = sp.SparseLinear(model_channels, self.out_channels)
        self._build_perturbation()

        self.initialize_weights()
        if pretrained_gs_dec is not None:
            if pretrained_gs_dec.endswith('.pt'):
                # A '.pt' path is loaded as a raw state-dict checkpoint.
                print (f'loading pretrained weight: {pretrained_gs_dec}')
                model_ckpt = torch.load(pretrained_gs_dec, map_location='cpu', weights_only=True)
                self.load_state_dict(model_ckpt)
                del model_ckpt
            else:
                # Otherwise the string is treated as a pretrained-model id.
                print (f'loading pretrained weight: {pretrained_gs_dec}')
                pre_trained_models: SLatGaussianDecoder
                pre_trained_models = from_pretrained(pretrained_gs_dec)
                self.load_state_dict(pre_trained_models.state_dict())
                del pre_trained_models
        if use_fp16:
            self.convert_to_fp16()

    def initialize_weights(self) -> None:
        super().initialize_weights()
        # Zero-out output layers:
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def _build_perturbation(self) -> None:
        # Low-discrepancy (Hammersley) points in [-1, 1]^3, one per gaussian,
        # stored in pre-tanh space so the forward path can add them before
        # the tanh in to_representation.
        perturbation = [hammersley_sequence(3, i, self.rep_config['num_gaussians']) for i in range(self.rep_config['num_gaussians'])]
        perturbation = torch.tensor(perturbation).float() * 2 - 1
        perturbation = perturbation / self.rep_config['voxel_size']
        perturbation = torch.atanh(perturbation).to(self.device)
        self.register_buffer('offset_perturbation', perturbation)

    def _calc_layout(self) -> None:
        # Channel layout of the flat per-voxel output vector; each entry
        # records the tensor shape and its [start, end) slice in the channels.
        self.layout = {
            '_xyz' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3},
            '_features_dc' : {'shape': (self.rep_config['num_gaussians'], 1, 3), 'size': self.rep_config['num_gaussians'] * 3},
            '_scaling' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3},
            '_rotation' : {'shape': (self.rep_config['num_gaussians'], 4), 'size': self.rep_config['num_gaussians'] * 4},
            '_opacity' : {'shape': (self.rep_config['num_gaussians'], 1), 'size': self.rep_config['num_gaussians']},
        }
        start = 0
        for k, v in self.layout.items():
            v['range'] = (start, start + v['size'])
            start += v['size']
        self.out_channels = start

    def to_representation(self, x: sp.SparseTensor) -> List[Gaussian]:
        """
        Convert a batch of network outputs to 3D representations.

        Args:
            x: The [N x * x C] sparse tensor output by the network.

        Returns:
            list of representations
        """
        ret = []
        for i in range(x.shape[0]):
            representation = Gaussian(
                sh_degree=0,
                aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
                # NOTE: 'mininum_kernel_size' spelling follows the Gaussian API.
                mininum_kernel_size = self.rep_config['3d_filter_kernel_size'],
                scaling_bias = self.rep_config['scaling_bias'],
                opacity_bias = self.rep_config['opacity_bias'],
                scaling_activation = self.rep_config['scaling_activation']
            )
            # Voxel centers, normalized to [0, 1].
            xyz = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution
            for k, v in self.layout.items():
                if k == '_xyz':
                    offset = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape'])
                    offset = offset * self.rep_config['lr'][k]
                    if self.rep_config['perturb_offset']:
                        # Spread gaussians inside each voxel with the fixed
                        # offsets built in _build_perturbation.
                        offset = offset + self.offset_perturbation
                    # tanh bounds each gaussian to half a voxel of its center.
                    offset = torch.tanh(offset) / self.resolution * 0.5 * self.rep_config['voxel_size']
                    _xyz = xyz.unsqueeze(1) + offset
                    setattr(representation, k, _xyz.flatten(0, 1))
                else:
                    feats = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1)
                    feats = feats * self.rep_config['lr'][k]
                    setattr(representation, k, feats)
            ret.append(representation)
        return ret

    def forward(self, x: sp.SparseTensor) -> List[Gaussian]:
        """Decode a sparse latent into a list of Gaussian representations."""
        h = super().forward(x)
        h = h.type(x.dtype)
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.out_layer(h)
        return self.to_representation(h)
143
+
144
+
145
class ElasticSLatGaussianDecoder(SparseTransformerElasticMixin, SLatGaussianDecoder):
    """
    Slat VAE Gaussian decoder with elastic memory management.
    Used for training with low VRAM.
    """
    pass
threeDFixer/models/structured_latent_vae/decoder_mesh.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ from typing import *
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import numpy as np
11
+ from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
12
+ from ...modules import sparse as sp
13
+ from .base import SparseTransformerBase
14
+ from ...representations import MeshExtractResult
15
+ from ...representations.mesh import SparseFeatures2Mesh
16
+ from ..sparse_elastic_mixin import SparseTransformerElasticMixin
17
+ from .. import from_pretrained
18
+
19
+
20
class SparseSubdivideBlock3d(nn.Module):
    """
    A 3D subdivide block that can subdivide the sparse tensor.

    Args:
        channels: channels in the inputs and outputs.
        resolution: input grid resolution; the output is at twice this resolution.
        out_channels: if specified, the number of output channels.
        num_groups: the number of groups for the group norm.
    """
    def __init__(
        self,
        channels: int,
        resolution: int,
        out_channels: Optional[int] = None,
        num_groups: int = 32
    ):
        super().__init__()
        self.channels = channels
        self.resolution = resolution
        self.out_resolution = resolution * 2
        self.out_channels = out_channels or channels

        self.act_layers = nn.Sequential(
            sp.SparseGroupNorm32(num_groups, channels),
            sp.SparseSiLU()
        )

        self.sub = sp.SparseSubdivide()

        self.out_layers = nn.Sequential(
            sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"),
            sp.SparseGroupNorm32(num_groups, self.out_channels),
            sp.SparseSiLU(),
            # Zero-init the final conv so the block starts as a near-identity.
            zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        else:
            # 1x1 conv to match channel counts on the residual path.
            self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}")

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """
        Apply the block to a sparse tensor.

        Args:
            x: a sparse tensor of features at ``self.resolution``.
        Returns:
            a sparse tensor of outputs at ``self.out_resolution``.
        """
        h = self.act_layers(x)
        h = self.sub(h)
        # Subdivide the input too so the residual is at the same resolution.
        x = self.sub(x)
        h = self.out_layers(h)
        h = h + self.skip_connection(x)
        return h
76
+
77
+
78
class SLatMeshDecoder(SparseTransformerBase):
    """
    Sparse transformer decoder mapping structured latents to mesh-extraction
    features at 4x the latent resolution (two subdivision stages).
    """
    def __init__(
        self,
        resolution: int,
        model_channels: int,
        latent_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
        representation_config: dict = None,
        pretrained_mesh_dec: str = None,
    ):
        super().__init__(
            in_channels=latent_channels,
            model_channels=model_channels,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
        )
        self.resolution = resolution
        self.rep_config = representation_config
        self.mesh_extractor = SparseFeatures2Mesh(res=self.resolution*4, use_color=self.rep_config.get('use_color', False))
        self.out_channels = self.mesh_extractor.feats_channels
        # Two subdivision stages: resolution -> 2x -> 4x, shrinking channels.
        self.upsample = nn.ModuleList([
            SparseSubdivideBlock3d(
                channels=model_channels,
                resolution=resolution,
                out_channels=model_channels // 4
            ),
            SparseSubdivideBlock3d(
                channels=model_channels // 4,
                resolution=resolution * 2,
                out_channels=model_channels // 8
            )
        ])
        self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels)

        self.initialize_weights()
        if pretrained_mesh_dec is not None:
            print (f'loading pretrained weight: {pretrained_mesh_dec}')
            pre_trained_models: SLatMeshDecoder
            pre_trained_models = from_pretrained(pretrained_mesh_dec)
            self.load_state_dict(pre_trained_models.state_dict())
            del pre_trained_models
        if use_fp16:
            self.convert_to_fp16()

    def initialize_weights(self) -> None:
        super().initialize_weights()
        # Zero-out output layers:
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        super().convert_to_fp16()
        # The upsample blocks are part of the torso and must match dtype.
        self.upsample.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        super().convert_to_fp32()
        self.upsample.apply(convert_module_to_f32)

    def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
        """
        Convert a batch of network outputs to 3D representations.

        Args:
            x: The [N x * x C] sparse tensor output by the network.

        Returns:
            list of representations
        """
        ret = []
        for i in range(x.shape[0]):
            mesh = self.mesh_extractor(x[i], training=self.training)
            ret.append(mesh)
        return ret

    def forward(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
        """Decode a sparse latent into a list of extracted meshes."""
        h = super().forward(x)
        for block in self.upsample:
            h = block(h)
        h = h.type(x.dtype)
        h = self.out_layer(h)
        return self.to_representation(h)
182
+
183
+
184
class ElasticSLatMeshDecoder(SparseTransformerElasticMixin, SLatMeshDecoder):
    """
    Slat VAE Mesh decoder with elastic memory management.
    Used for training with low VRAM.
    """
    pass
threeDFixer/models/structured_latent_vae/decoder_rf.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ from typing import *
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import numpy as np
11
+ from ...modules import sparse as sp
12
+ from .base import SparseTransformerBase
13
+ from ...representations import Strivec
14
+ from ..sparse_elastic_mixin import SparseTransformerElasticMixin
15
+
16
+
17
class SLatRadianceFieldDecoder(SparseTransformerBase):
    """
    Sparse transformer decoder mapping structured latents to Strivec
    radiance-field representations.
    """
    def __init__(
        self,
        resolution: int,
        model_channels: int,
        latent_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
        representation_config: dict = None,
    ):
        super().__init__(
            in_channels=latent_channels,
            model_channels=model_channels,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
        )
        self.resolution = resolution
        self.rep_config = representation_config
        self._calc_layout()
        self.out_layer = sp.SparseLinear(model_channels, self.out_channels)

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

    def initialize_weights(self) -> None:
        super().initialize_weights()
        # Zero-out output layers:
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def _calc_layout(self) -> None:
        # Channel layout of the flat per-voxel output vector; each entry
        # records the tensor shape and its [start, end) slice in the channels.
        self.layout = {
            'trivec': {'shape': (self.rep_config['rank'], 3, self.rep_config['dim']), 'size': self.rep_config['rank'] * 3 * self.rep_config['dim']},
            'density': {'shape': (self.rep_config['rank'],), 'size': self.rep_config['rank']},
            'features_dc': {'shape': (self.rep_config['rank'], 1, 3), 'size': self.rep_config['rank'] * 3},
        }
        start = 0
        for k, v in self.layout.items():
            v['range'] = (start, start + v['size'])
            start += v['size']
        self.out_channels = start

    def to_representation(self, x: sp.SparseTensor) -> List[Strivec]:
        """
        Convert a batch of network outputs to 3D representations.

        Args:
            x: The [N x * x C] sparse tensor output by the network.

        Returns:
            list of representations
        """
        ret = []
        for i in range(x.shape[0]):
            # NOTE(review): device is hard-coded to 'cuda' here and below;
            # consider self.device instead — confirm with callers.
            representation = Strivec(
                sh_degree=0,
                resolution=self.resolution,
                aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
                rank=self.rep_config['rank'],
                dim=self.rep_config['dim'],
                device='cuda',
            )
            representation.density_shift = 0.0
            # Voxel centers, normalized to [0, 1].
            representation.position = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution
            representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda')
            for k, v in self.layout.items():
                setattr(representation, k, x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']))
            representation.trivec = representation.trivec + 1
            ret.append(representation)
        return ret

    def forward(self, x: sp.SparseTensor) -> List[Strivec]:
        """Decode a sparse latent into a list of Strivec representations."""
        h = super().forward(x)
        h = h.type(x.dtype)
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.out_layer(h)
        return self.to_representation(h)
111
+
112
+
113
class ElasticSLatRadianceFieldDecoder(SparseTransformerElasticMixin, SLatRadianceFieldDecoder):
    """
    Slat VAE Radiance Field Decoder with elastic memory management.
    Used for training with low VRAM.
    """
    pass
threeDFixer/models/structured_latent_vae/encoder.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ from typing import *
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from ...modules import sparse as sp
11
+ from .base import SparseTransformerBase
12
+ from ..sparse_elastic_mixin import SparseTransformerElasticMixin
13
+ from .. import from_pretrained
14
+
15
+
16
class SLatEncoder(SparseTransformerBase):
    """
    Sparse transformer VAE encoder producing a diagonal-Gaussian posterior
    over structured latents.
    """
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        model_channels: int,
        latent_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
        pretrained_slat_enc: str = None,
    ):
        super().__init__(
            in_channels=in_channels,
            model_channels=model_channels,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
        )
        self.resolution = resolution
        # Twice the latent width: first half is the mean, second half log-variance.
        self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels)

        self.initialize_weights()
        if pretrained_slat_enc is not None:
            print (f'loading pretrained weight: {pretrained_slat_enc}')
            pre_trained_models: SLatEncoder
            pre_trained_models = from_pretrained(pretrained_slat_enc)
            self.load_state_dict(pre_trained_models.state_dict())
            del pre_trained_models
        if use_fp16:
            self.convert_to_fp16()

    def initialize_weights(self) -> None:
        super().initialize_weights()
        # Zero-out output layers:
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def forward(self, x: sp.SparseTensor, sample_posterior: bool = True, return_raw: bool = False):
        """
        Encode a sparse input into a latent sample.

        Args:
            x: Sparse input tensor.
            sample_posterior: If True, draw a reparameterized sample; otherwise
                return the posterior mean.
            return_raw: If True, also return (mean, logvar) of the posterior.

        Returns:
            The latent sparse tensor, or ``(z, mean, logvar)`` when
            ``return_raw`` is set.
        """
        h = super().forward(x)
        h = h.type(x.dtype)
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.out_layer(h)

        # Sample from the posterior distribution
        mean, logvar = h.feats.chunk(2, dim=-1)
        if sample_posterior:
            # Reparameterization trick: z = mean + std * eps.
            std = torch.exp(0.5 * logvar)
            z = mean + std * torch.randn_like(std)
        else:
            z = mean
        z = h.replace(z)

        if return_raw:
            return z, mean, logvar
        else:
            return z
87
+
88
+
89
class ElasticSLatEncoder(SparseTransformerElasticMixin, SLatEncoder):
    """
    SLat VAE encoder with elastic memory management.
    Used for training with low VRAM.
    """
threeDFixer/modules/attention/__init__.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ from typing import *
7
+
8
BACKEND = 'flash_attn'
DEBUG = False


def __from_env():
    """Override BACKEND/DEBUG from the ATTN_BACKEND / ATTN_DEBUG env vars."""
    import os

    global BACKEND, DEBUG

    backend_override = os.environ.get('ATTN_BACKEND')
    debug_override = os.environ.get('ATTN_DEBUG')

    # Only accept a known backend name; anything else keeps the default.
    if backend_override in ('xformers', 'flash_attn', 'sdpa', 'naive'):
        BACKEND = backend_override
    # Any non-empty setting other than '1' disables debug mode.
    if debug_override is not None:
        DEBUG = debug_override == '1'

    print(f"[ATTENTION] Using backend: {BACKEND}")


__from_env()
29
+
30
+
31
def set_backend(backend: Literal['xformers', 'flash_attn', 'sdpa', 'naive']):
    """
    Select the attention backend at runtime.

    Args:
        backend: One of 'xformers', 'flash_attn', 'sdpa', or 'naive' — the
            same set of backends recognized by the ATTN_BACKEND env var.
            (The previous hint listed only two of the four valid values.)
    """
    global BACKEND
    BACKEND = backend
34
+
35
def set_debug(debug: bool):
    """Enable or disable attention debug mode (module-global DEBUG flag)."""
    global DEBUG
    DEBUG = debug
38
+
39
+
40
+ from .full_attn import *
41
+ from .modules import *
threeDFixer/modules/attention/full_attn.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ from typing import *
7
+ import torch
8
+ import math
9
+ from . import DEBUG, BACKEND
10
+
11
+ if BACKEND == 'xformers':
12
+ import xformers.ops as xops
13
+ elif BACKEND == 'flash_attn':
14
+ import flash_attn
15
+ elif BACKEND == 'sdpa':
16
+ from torch.nn.functional import scaled_dot_product_attention as sdpa
17
+ elif BACKEND == 'naive':
18
+ pass
19
+ else:
20
+ raise ValueError(f"Unknown attention backend: {BACKEND}")
21
+
22
+
23
+ __all__ = [
24
+ 'scaled_dot_product_attention',
25
+ ]
26
+
27
+
28
+ def _naive_sdpa(q, k, v):
29
+ """
30
+ Naive implementation of scaled dot product attention.
31
+ """
32
+ q = q.permute(0, 2, 1, 3) # [N, H, L, C]
33
+ k = k.permute(0, 2, 1, 3) # [N, H, L, C]
34
+ v = v.permute(0, 2, 1, 3) # [N, H, L, C]
35
+ scale_factor = 1 / math.sqrt(q.size(-1))
36
+ attn_weight = q @ k.transpose(-2, -1) * scale_factor
37
+ attn_weight = torch.softmax(attn_weight, dim=-1)
38
+ out = attn_weight @ v
39
+ out = out.permute(0, 2, 1, 3) # [N, L, H, C]
40
+ return out
41
+
42
+
43
# typing.overload stubs documenting the three accepted calling conventions;
# the runtime dispatch lives in the implementation that follows.
@overload
def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor:
    """
    Apply scaled dot product attention.

    Args:
        qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs.
    """
    ...

@overload
def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
    """
    Apply scaled dot product attention.

    Args:
        q (torch.Tensor): A [N, L, H, C] tensor containing Qs.
        kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs.
    """
    ...

@overload
def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """
    Apply scaled dot product attention.

    Args:
        q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs.
        k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks.
        v (torch.Tensor): A [N, L, H, Co] tensor containing Vs.

    Note:
        k and v are assumed to have the same coordinate map.
    """
    ...
78
+
79
def scaled_dot_product_attention(*args, **kwargs):
    """
    Dispatch scaled dot product attention to the module-level BACKEND.

    Accepts one of three calling conventions (see the overloads above):
    packed ``qkv``, ``q`` plus packed ``kv``, or separate ``q``/``k``/``v``.
    """
    # Map argument count to the expected keyword names for validation.
    arg_names_dict = {
        1: ['qkv'],
        2: ['q', 'kv'],
        3: ['q', 'k', 'v']
    }
    num_all_args = len(args) + len(kwargs)
    assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
    for key in arg_names_dict[num_all_args][len(args):]:
        assert key in kwargs, f"Missing argument {key}"

    if num_all_args == 1:
        qkv = args[0] if len(args) > 0 else kwargs['qkv']
        assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]"
        # NOTE(review): `device` is assigned in each branch but never used —
        # confirm before removing.
        device = qkv.device

    elif num_all_args == 2:
        q = args[0] if len(args) > 0 else kwargs['q']
        kv = args[1] if len(args) > 1 else kwargs['kv']
        assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
        assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
        assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
        device = q.device

    elif num_all_args == 3:
        q = args[0] if len(args) > 0 else kwargs['q']
        k = args[1] if len(args) > 1 else kwargs['k']
        v = args[2] if len(args) > 2 else kwargs['v']
        assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
        assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
        assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
        assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
        device = q.device

    if BACKEND == 'xformers':
        # xformers expects separate q/k/v; unpack the packed forms first.
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=2)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=2)
        out = xops.memory_efficient_attention(q, k, v)
    elif BACKEND == 'flash_attn':
        # flash_attn has dedicated entry points for each packing convention.
        if num_all_args == 1:
            out = flash_attn.flash_attn_qkvpacked_func(qkv)
        elif num_all_args == 2:
            out = flash_attn.flash_attn_kvpacked_func(q, kv)
        elif num_all_args == 3:
            out = flash_attn.flash_attn_func(q, k, v)
    elif BACKEND == 'sdpa':
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=2)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=2)
        # torch sdpa uses [N, H, L, C]; permute in and back out.
        q = q.permute(0, 2, 1, 3)   # [N, H, L, C]
        k = k.permute(0, 2, 1, 3)   # [N, H, L, C]
        v = v.permute(0, 2, 1, 3)   # [N, H, L, C]
        out = sdpa(q, k, v)         # [N, H, L, C]
        out = out.permute(0, 2, 1, 3)   # [N, L, H, C]
    elif BACKEND == 'naive':
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=2)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=2)
        out = _naive_sdpa(q, k, v)
    else:
        raise ValueError(f"Unknown attention module: {BACKEND}")

    return out
threeDFixer/modules/attention/modules.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ from typing import *
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from .full_attn import scaled_dot_product_attention
11
+
12
+
13
class MultiHeadRMSNorm(nn.Module):
    """Per-head RMS normalization with a learnable per-head gain.

    The last (channel) dimension of each head is L2-normalized, then
    rescaled by sqrt(dim) and a learned gamma of shape [heads, dim].
    Computation happens in float32; the result is cast back to the
    input dtype.
    """

    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32 for numerical stability under mixed precision.
        normed = F.normalize(x.float(), dim=-1)
        scaled = normed * self.gamma * self.scale
        return scaled.to(x.dtype)
21
+
22
+
23
class RotaryPositionEmbedder(nn.Module):
    """Rotary position embedding (RoPE) over (multi-axis) integer positions.

    Each pair of channels is treated as a complex number and rotated by a
    position-dependent phase. `hidden_size` channels are split across
    `in_channels` coordinate axes, `freq_dim` frequencies per axis; any
    remaining channel pairs are padded with the identity rotation (phase 0).
    """

    def __init__(self, hidden_size: int, in_channels: int = 3):
        super().__init__()
        assert hidden_size % 2 == 0, "Hidden size must be divisible by 2"
        self.hidden_size = hidden_size
        self.in_channels = in_channels
        self.freq_dim = hidden_size // in_channels // 2
        exponents = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
        self.freqs = 1.0 / (10000 ** exponents)

    def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
        # Keep the frequency table on the same device as the indices.
        self.freqs = self.freqs.to(indices.device)
        angles = torch.outer(indices, self.freqs)
        # Unit-magnitude complex numbers e^{i*angle}.
        return torch.polar(torch.ones_like(angles), angles)

    def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
        # View consecutive channel pairs as complex numbers and rotate.
        pairs = x.float().reshape(*x.shape[:-1], -1, 2)
        rotated = torch.view_as_complex(pairs) * phases
        flat = torch.view_as_real(rotated).reshape(*rotated.shape[:-1], -1)
        return flat.to(x.dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            q (sp.SparseTensor): [..., N, D] tensor of queries
            k (sp.SparseTensor): [..., N, D] tensor of keys
            indices (torch.Tensor): [..., N, C] tensor of spatial positions
        """
        if indices is None:
            # Fall back to sequential positions along the token axis.
            indices = torch.arange(q.shape[-2], device=q.device)
            if len(q.shape) > 2:
                indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,))

        phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
        pad = self.hidden_size // 2 - phases.shape[1]
        if pad > 0:
            # Pad the remaining channel pairs with the identity rotation.
            filler = torch.polar(
                torch.ones(*phases.shape[:-1], pad, device=phases.device),
                torch.zeros(*phases.shape[:-1], pad, device=phases.device),
            )
            phases = torch.cat([phases, filler], dim=-1)
        return self._rotary_embedding(q, phases), self._rotary_embedding(k, phases)
66
+
67
+
68
class MultiHeadAttention(nn.Module):
    """Dense multi-head attention with optional RoPE and QK RMS-normalization.

    Supports self-attention (fused QKV projection) and cross-attention
    (separate Q and fused KV projections). Only ``attn_mode="full"`` is
    implemented; ``"windowed"`` raises NotImplementedError.
    """

    def __init__(
        self,
        channels: int,
        num_heads: int,
        ctx_channels: Optional[int]=None,
        type: Literal["self", "cross"] = "self",
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        qkv_bias: bool = True,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
    ):
        super().__init__()
        assert channels % num_heads == 0
        assert type in ["self", "cross"], f"Invalid attention type: {type}"
        assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
        assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"

        if attn_mode == "windowed":
            raise NotImplementedError("Windowed attention is not yet implemented")

        self.channels = channels
        self.head_dim = channels // num_heads
        # Cross-attention context may have a different channel width than x.
        self.ctx_channels = ctx_channels if ctx_channels is not None else channels
        self.num_heads = num_heads
        self._type = type
        self.attn_mode = attn_mode
        # window_size / shift_window are stored but unused while "windowed"
        # mode is unimplemented.
        self.window_size = window_size
        self.shift_window = shift_window
        self.use_rope = use_rope
        self.qk_rms_norm = qk_rms_norm

        if self._type == "self":
            self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
        else:
            self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
            self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)

        if self.qk_rms_norm:
            self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
            self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)

        self.to_out = nn.Linear(channels, channels)

        if use_rope:
            self.rope = RotaryPositionEmbedder(channels)

    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run attention over x ([B, L, C]).

        Args:
            x: [B, L, C] input tokens (queries; also keys/values for self-attn).
            context: [B, Lkv, ctx_channels] context tokens; required for
                cross-attention, ignored for self-attention.
            indices: optional positions forwarded to RoPE (self-attn only).

        Returns:
            [B, L, C] attended and output-projected tokens.
        """
        B, L, C = x.shape
        if self._type == "self":
            qkv = self.to_qkv(x)
            qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
            if self.use_rope:
                # Rotate q and k only; v passes through unchanged.
                q, k, v = qkv.unbind(dim=2)
                q, k = self.rope(q, k, indices)
                qkv = torch.stack([q, k, v], dim=2)
            if self.attn_mode == "full":
                if self.qk_rms_norm:
                    q, k, v = qkv.unbind(dim=2)
                    q = self.q_rms_norm(q)
                    k = self.k_rms_norm(k)
                    h = scaled_dot_product_attention(q, k, v)
                else:
                    # Fused-packed path: backend unpacks qkv itself.
                    h = scaled_dot_product_attention(qkv)
            elif self.attn_mode == "windowed":
                raise NotImplementedError("Windowed attention is not yet implemented")
        else:
            Lkv = context.shape[1]
            q = self.to_q(x)
            kv = self.to_kv(context)
            q = q.reshape(B, L, self.num_heads, -1)
            kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
            if self.qk_rms_norm:
                q = self.q_rms_norm(q)
                k, v = kv.unbind(dim=2)
                k = self.k_rms_norm(k)
                h = scaled_dot_product_attention(q, k, v)
            else:
                h = scaled_dot_product_attention(q, kv)
        # Merge heads back into the channel dimension and project out.
        h = h.reshape(B, L, -1)
        h = self.to_out(h)
        return h
threeDFixer/modules/norm.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
class LayerNorm32(nn.LayerNorm):
    """LayerNorm evaluated in float32 regardless of the input dtype.

    The input is upcast to float32 for the normalization (numerical
    stability under mixed precision) and the output is cast back to the
    original dtype.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        original_dtype = x.dtype
        normed = super().forward(x.float())
        return normed.type(original_dtype)
13
+
14
+
15
class GroupNorm32(nn.GroupNorm):
    """
    A GroupNorm layer that converts to float32 before the forward pass.

    The normalization runs in float32 and the result is cast back to the
    input's original dtype.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        original_dtype = x.dtype
        normed = super().forward(x.float())
        return normed.type(original_dtype)
21
+
22
+
23
class ChannelLayerNorm32(LayerNorm32):
    """Float32 LayerNorm over the channel axis of channel-first tensors.

    Accepts [N, C, *spatial] input: the channel axis is moved to the end,
    LayerNorm32 is applied over it, and the axis is moved back.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ndim = x.dim()
        # [N, C, *spatial] -> [N, *spatial, C]
        channels_last = x.permute(0, *range(2, ndim), 1).contiguous()
        normed = super().forward(channels_last)
        # [N, *spatial, C] -> [N, C, *spatial]
        return normed.permute(0, ndim - 1, *range(1, ndim - 1)).contiguous()
30
+
threeDFixer/modules/sparse/__init__.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import *

# Module configuration; defaults below may be overridden from the
# environment by __from_env() at import time.
BACKEND = 'spconv'
DEBUG = False
ATTN = 'flash_attn'

def __from_env():
    """Apply SPARSE_BACKEND / SPARSE_DEBUG / SPARSE_ATTN_BACKEND overrides."""
    import os

    global BACKEND
    global DEBUG
    global ATTN

    backend_env = os.environ.get('SPARSE_BACKEND')
    debug_env = os.environ.get('SPARSE_DEBUG')
    attn_env = os.environ.get('SPARSE_ATTN_BACKEND')
    # Fall back to the generic variable when the sparse-specific one is unset.
    if attn_env is None:
        attn_env = os.environ.get('ATTN_BACKEND')

    # Unknown values are silently ignored (membership check excludes None).
    if backend_env in ('spconv', 'torchsparse'):
        BACKEND = backend_env
    if debug_env is not None:
        DEBUG = debug_env == '1'
    if attn_env in ('xformers', 'flash_attn'):
        ATTN = attn_env

    print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")


__from_env()
31
+
32
+
33
def set_backend(backend: Literal['spconv', 'torchsparse']) -> None:
    """Override the sparse-convolution backend at runtime (no validation)."""
    global BACKEND
    BACKEND = backend

def set_debug(debug: bool) -> None:
    """Enable or disable the extra runtime consistency checks (DEBUG flag)."""
    global DEBUG
    DEBUG = debug

def set_attn(attn: Literal['xformers', 'flash_attn']) -> None:
    """Override the attention implementation used by the sparse attention ops."""
    global ATTN
    ATTN = attn
44
+
45
+
46
import importlib

# Maps public attribute names to the submodule that defines them; used by
# __getattr__ below to import lazily on first access (PEP 562).
__attributes = {
    'SparseTensor': 'basic',
    'sparse_batch_broadcast': 'basic',
    'sparse_batch_op': 'basic',
    'sparse_cat': 'basic',
    'sparse_unbind': 'basic',
    'SparseGroupNorm': 'norm',
    'SparseLayerNorm': 'norm',
    'SparseGroupNorm32': 'norm',
    'SparseLayerNorm32': 'norm',
    'SparseReLU': 'nonlinearity',
    'SparseSiLU': 'nonlinearity',
    'SparseGELU': 'nonlinearity',
    'SparseActivation': 'nonlinearity',
    'SparseLinear': 'linear',
    'sparse_scaled_dot_product_attention': 'attention',
    'SerializeMode': 'attention',
    'sparse_serialized_scaled_dot_product_self_attention': 'attention',
    'sparse_windowed_scaled_dot_product_self_attention': 'attention',
    'SparseMultiHeadAttention': 'attention',
    'SparseConv3d': 'conv',
    'SparseInverseConv3d': 'conv',
    'SparseDownsample': 'spatial',
    'SparseUpsample': 'spatial',
    'SparseSubdivide' : 'spatial'
}

# Submodules exposed lazily as attributes of this package.
__submodules = ['transformer']

__all__ = list(__attributes.keys()) + __submodules

def __getattr__(name):
    """Lazily import and cache attributes/submodules on first access."""
    if name not in globals():
        if name in __attributes:
            module_name = __attributes[name]
            module = importlib.import_module(f".{module_name}", __name__)
            # Cache in globals() so subsequent lookups skip __getattr__.
            globals()[name] = getattr(module, name)
        elif name in __submodules:
            module = importlib.import_module(f".{name}", __name__)
            globals()[name] = module
        else:
            raise AttributeError(f"module {__name__} has no attribute {name}")
    return globals()[name]


# For Pylance: eager imports so static analysis sees the lazy names.
# This branch never runs when the package is imported normally.
if __name__ == '__main__':
    from .basic import *
    from .norm import *
    from .nonlinearity import *
    from .linear import *
    from .attention import *
    from .conv import *
    from .spatial import *
    # NOTE(review): bare `import transformer` differs from the relative
    # imports above — presumably `from . import transformer` was intended;
    # harmless in practice since this guard is analysis-only dead code.
    import transformer
threeDFixer/modules/sparse/attention/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ from .full_attn import *
7
+ from .serialized_attn import *
8
+ from .windowed_attn import *
9
+ from .modules import *
threeDFixer/modules/sparse/attention/full_attn.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ from typing import *
7
+ import torch
8
+ from .. import SparseTensor
9
+ from .. import DEBUG, ATTN
10
+
11
+ if ATTN == 'xformers':
12
+ import xformers.ops as xops
13
+ elif ATTN == 'flash_attn':
14
+ import flash_attn
15
+ else:
16
+ raise ValueError(f"Unknown attention module: {ATTN}")
17
+
18
+
19
__all__ = [
    'sparse_scaled_dot_product_attention',
]


@overload
def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        qkv (SparseTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (SparseTensor): A [N, *, H, C] sparse tensor containing Qs.
        kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (SparseTensor): A [N, L, H, C] dense tensor containing Qs.
        kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
        k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
        v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.

    Note:
        k and v are assumed to have the same coordinate map.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
        k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks.
        v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs.
        k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
        v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.
    """
    ...

def sparse_scaled_dot_product_attention(*args, **kwargs):
    """Dispatching implementation for the overloads above.

    Flattens sparse/dense q, k, v (or packed qkv / kv) into variable-length
    sequences and runs the configured backend (xformers block-diagonal mask
    or flash-attn varlen kernels). Returns a SparseTensor when q was sparse
    (via ``s.replace``), otherwise a dense [N, L, H, C] tensor.
    """
    arg_names_dict = {
        1: ['qkv'],
        2: ['q', 'kv'],
        3: ['q', 'k', 'v']
    }
    num_all_args = len(args) + len(kwargs)
    assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
    for key in arg_names_dict[num_all_args][len(args):]:
        assert key in kwargs, f"Missing argument {key}"

    if num_all_args == 1:
        # Packed self-attention: one sparse tensor holding Q, K and V.
        qkv = args[0] if len(args) > 0 else kwargs['qkv']
        assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}"
        assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
        device = qkv.device

        s = qkv
        # Per-batch token counts derived from the sparse layout slices.
        q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
        kv_seqlen = q_seqlen
        qkv = qkv.feats     # [T, 3, H, C]

    elif num_all_args == 2:
        # Cross-attention with packed KV; either side may be dense.
        q = args[0] if len(args) > 0 else kwargs['q']
        kv = args[1] if len(args) > 1 else kwargs['kv']
        assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \
               isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \
               f"Invalid types, got {type(q)} and {type(kv)}"
        assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
        device = q.device

        if isinstance(q, SparseTensor):
            assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]"
            s = q
            q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
            q = q.feats     # [T_Q, H, C]
        else:
            assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
            s = None
            N, L, H, C = q.shape
            q_seqlen = [L] * N
            q = q.reshape(N * L, H, C)   # [T_Q, H, C]

        if isinstance(kv, SparseTensor):
            assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
            kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
            kv = kv.feats     # [T_KV, 2, H, C]
        else:
            assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
            N, L, _, H, C = kv.shape
            kv_seqlen = [L] * N
            kv = kv.reshape(N * L, 2, H, C)   # [T_KV, 2, H, C]

    elif num_all_args == 3:
        # Fully unpacked q, k, v; k and v must be of the same kind.
        q = args[0] if len(args) > 0 else kwargs['q']
        k = args[1] if len(args) > 1 else kwargs['k']
        v = args[2] if len(args) > 2 else kwargs['v']
        assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \
               isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \
               f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}"
        assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
        device = q.device

        if isinstance(q, SparseTensor):
            assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]"
            s = q
            q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
            q = q.feats     # [T_Q, H, Ci]
        else:
            assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
            s = None
            N, L, H, CI = q.shape
            q_seqlen = [L] * N
            q = q.reshape(N * L, H, CI)  # [T_Q, H, Ci]

        if isinstance(k, SparseTensor):
            assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]"
            assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]"
            kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])]
            k = k.feats     # [T_KV, H, Ci]
            v = v.feats     # [T_KV, H, Co]
        else:
            assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
            assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
            N, L, H, CI, CO = *k.shape, v.shape[-1]
            kv_seqlen = [L] * N
            k = k.reshape(N * L, H, CI)     # [T_KV, H, Ci]
            v = v.reshape(N * L, H, CO)     # [T_KV, H, Co]

    if DEBUG:
        if s is not None:
            for i in range(s.shape[0]):
                assert (s.coords[s.layout[i]] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch"
        # NOTE(review): the two shape asserts below compare a torch.Size
        # (tuple) against a list, which is always False in Python, and the
        # tensors were flattened to [T, H, C] above — these checks look
        # broken/dead; confirm intended semantics before relying on DEBUG.
        if num_all_args in [2, 3]:
            assert q.shape[:2] == [1, sum(q_seqlen)], f"SparseScaledDotProductSelfAttention: q shape mismatch"
        if num_all_args == 3:
            assert k.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: k shape mismatch"
            assert v.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: v shape mismatch"

    if ATTN == 'xformers':
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=1)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=1)
        # xformers expects a batch dim; the block-diagonal mask encodes the
        # per-sample sequence boundaries within the single flattened batch.
        q = q.unsqueeze(0)
        k = k.unsqueeze(0)
        v = v.unsqueeze(0)
        mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
        out = xops.memory_efficient_attention(q, k, v, mask)[0]
    elif ATTN == 'flash_attn':
        # flash-attn varlen kernels take cumulative sequence-length offsets.
        cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
        if num_all_args in [2, 3]:
            cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
        if num_all_args == 1:
            out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
        elif num_all_args == 2:
            out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
        elif num_all_args == 3:
            out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
    else:
        raise ValueError(f"Unknown attention module: {ATTN}")

    if s is not None:
        # Sparse query: wrap the flat output back into the sparse structure.
        return s.replace(out)
    else:
        # Dense query: restore [N, L, H, C].
        return out.reshape(N, L, H, -1)
threeDFixer/modules/sparse/attention/modules.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ from typing import *
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from .. import SparseTensor
11
+ from .full_attn import sparse_scaled_dot_product_attention
12
+ from .serialized_attn import SerializeMode, sparse_serialized_scaled_dot_product_self_attention
13
+ from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention
14
+ from ...attention import RotaryPositionEmbedder
15
+
16
+
17
class SparseMultiHeadRMSNorm(nn.Module):
    """Per-head RMS normalization that accepts sparse or dense input.

    L2-normalizes the last dimension (the sparse tensor's feature channels,
    or the trailing axis of a dense tensor), then rescales by sqrt(dim) and
    a learned per-head gamma. Runs in float32 and casts back to the input
    dtype.
    """

    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, dim))

    def forward(self, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
        input_dtype = x.dtype
        x = x.float()
        if isinstance(x, SparseTensor):
            # Normalize the underlying feature tensor, keep the coords.
            normed = x.replace(F.normalize(x.feats, dim=-1))
        else:
            normed = F.normalize(x, dim=-1)
        return (normed * self.gamma * self.scale).to(input_dtype)
31
+
32
+
33
class SparseMultiHeadAttention(nn.Module):
    """Multi-head attention over SparseTensor (and mixed sparse/dense) input.

    Supports self-attention in "full", "serialized", or "windowed" modes and
    cross-attention in "full" mode only. Optional RoPE (self-attn only,
    using voxel coordinates) and QK RMS-normalization.
    """

    def __init__(
        self,
        channels: int,
        num_heads: int,
        ctx_channels: Optional[int] = None,
        type: Literal["self", "cross"] = "self",
        attn_mode: Literal["full", "serialized", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        qkv_bias: bool = True,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
    ):
        super().__init__()
        assert channels % num_heads == 0
        assert type in ["self", "cross"], f"Invalid attention type: {type}"
        assert attn_mode in ["full", "serialized", "windowed"], f"Invalid attention mode: {attn_mode}"
        assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
        assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention"
        self.channels = channels
        self.ctx_channels = ctx_channels if ctx_channels is not None else channels
        self.num_heads = num_heads
        self._type = type
        self.attn_mode = attn_mode
        self.window_size = window_size
        self.shift_sequence = shift_sequence
        self.shift_window = shift_window
        self.serialize_mode = serialize_mode
        self.use_rope = use_rope
        self.qk_rms_norm = qk_rms_norm

        if self._type == "self":
            self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
        else:
            self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
            self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)

        if self.qk_rms_norm:
            self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
            self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)

        self.to_out = nn.Linear(channels, channels)

        if use_rope:
            self.rope = RotaryPositionEmbedder(channels)

    @staticmethod
    def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
        # Apply a Linear to either a sparse tensor's features or a dense tensor.
        if isinstance(x, SparseTensor):
            return x.replace(module(x.feats))
        else:
            return module(x)

    @staticmethod
    def _reshape_chs(x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[SparseTensor, torch.Tensor]:
        # Reshape only the channel dims, preserving [N, L] for dense input.
        if isinstance(x, SparseTensor):
            return x.reshape(*shape)
        else:
            return x.reshape(*x.shape[:2], *shape)

    def _fused_pre(self, x: Union[SparseTensor, torch.Tensor], num_fused: int) -> Union[SparseTensor, torch.Tensor]:
        # Split a fused projection into [..., num_fused, H, C_head].
        if isinstance(x, SparseTensor):
            x_feats = x.feats.unsqueeze(0)
        else:
            x_feats = x
        x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
        return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats

    def _rope(self, qkv: SparseTensor) -> SparseTensor:
        # Rotate q and k by the voxel coordinates (coords[:, 1:] drops the
        # batch index column); v passes through unchanged.
        q, k, v = qkv.feats.unbind(dim=1)   # [T, H, C]
        q, k = self.rope(q, k, qkv.coords[:, 1:])
        qkv = qkv.replace(torch.stack([q, k, v], dim=1))
        return qkv

    def forward(self, x: Union[SparseTensor, torch.Tensor], context: Optional[Union[SparseTensor, torch.Tensor]] = None) -> Union[SparseTensor, torch.Tensor]:
        """Run attention; `context` is only used for cross-attention."""
        if self._type == "self":
            qkv = self._linear(self.to_qkv, x)
            qkv = self._fused_pre(qkv, num_fused=3)
            if self.use_rope:
                qkv = self._rope(qkv)
            if self.qk_rms_norm:
                q, k, v = qkv.unbind(dim=1)
                q = self.q_rms_norm(q)
                k = self.k_rms_norm(k)
                qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
            if self.attn_mode == "full":
                h = sparse_scaled_dot_product_attention(qkv)
            elif self.attn_mode == "serialized":
                h = sparse_serialized_scaled_dot_product_self_attention(
                    qkv, self.window_size, serialize_mode=self.serialize_mode, shift_sequence=self.shift_sequence, shift_window=self.shift_window
                )
            elif self.attn_mode == "windowed":
                h = sparse_windowed_scaled_dot_product_self_attention(
                    qkv, self.window_size, shift_window=self.shift_window
                )
        else:
            q = self._linear(self.to_q, x)
            q = self._reshape_chs(q, (self.num_heads, -1))
            kv = self._linear(self.to_kv, context)
            kv = self._fused_pre(kv, num_fused=2)
            if self.qk_rms_norm:
                q = self.q_rms_norm(q)
                k, v = kv.unbind(dim=1)
                k = self.k_rms_norm(k)
                kv = kv.replace(torch.stack([k.feats, v.feats], dim=1))
            h = sparse_scaled_dot_product_attention(q, kv)
        # Merge heads back into a flat channel dim and project out.
        h = self._reshape_chs(h, (-1,))
        h = self._linear(self.to_out, h)
        return h
threeDFixer/modules/sparse/attention/serialized_attn.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ from typing import *
7
+ from enum import Enum
8
+ import torch
9
+ import math
10
+ from .. import SparseTensor
11
+ from .. import DEBUG, ATTN
12
+
13
+ if ATTN == 'xformers':
14
+ import xformers.ops as xops
15
+ elif ATTN == 'flash_attn':
16
+ import flash_attn
17
+ else:
18
+ raise ValueError(f"Unknown attention module: {ATTN}")
19
+
20
+
21
__all__ = [
    'sparse_serialized_scaled_dot_product_self_attention',
]


class SerializeMode(Enum):
    """Space-filling-curve orderings used to serialize voxel coordinates."""
    Z_ORDER = 0
    Z_ORDER_TRANSPOSED = 1
    HILBERT = 2
    HILBERT_TRANSPOSED = 3


# Convenience list of all serialization modes, in enum-value order.
SerializeModes = [
    SerializeMode.Z_ORDER,
    SerializeMode.Z_ORDER_TRANSPOSED,
    SerializeMode.HILBERT,
    SerializeMode.HILBERT_TRANSPOSED
]
39
+
40
+
41
def calc_serialization(
    tensor: SparseTensor,
    window_size: int,
    serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
    shift_sequence: int = 0,
    shift_window: Tuple[int, int, int] = (0, 0, 0)
) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
    """
    Calculate serialization and partitioning for a set of coordinates.

    Args:
        tensor (SparseTensor): The input tensor.
        window_size (int): The window size to use.
        serialize_mode (SerializeMode): The serialization mode to use.
        shift_sequence (int): The shift of serialized sequence.
        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.

    Returns:
        (torch.Tensor): Forwards indices (gather original -> serialized order).
        (torch.Tensor): Backwards indices (scatter serialized -> original order).
        (List[int]): Sequence lengths, one per window.
        (List[int]): Batch index of each window.
    """
    fwd_indices = []
    bwd_indices = []
    seq_lens = []
    seq_batch_indices = []
    # Running offset into the (possibly padded) serialized sequence.
    offsets = [0]

    # Lazy import so the dependency is only required when serialization runs.
    if 'vox2seq' not in globals():
        import vox2seq

    # Serialize the input: encode each (shifted) voxel coordinate to a
    # 1-D code along the chosen space-filling curve.
    serialize_coords = tensor.coords[:, 1:].clone()
    serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3)
    if serialize_mode == SerializeMode.Z_ORDER:
        code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2])
    elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED:
        code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2])
    elif serialize_mode == SerializeMode.HILBERT:
        code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2])
    elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED:
        code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2])
    else:
        raise ValueError(f"Unknown serialize mode: {serialize_mode}")

    for bi, s in enumerate(tensor.layout):
        num_points = s.stop - s.start
        num_windows = (num_points + window_size - 1) // window_size
        # Fractional per-window size so points split as evenly as possible.
        valid_window_size = num_points / num_windows
        to_ordered = torch.argsort(code[s.start:s.stop])
        if num_windows == 1:
            # Whole batch fits in one window: no padding needed.
            fwd_indices.append(to_ordered)
            bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device)))
            fwd_indices[-1] += s.start
            bwd_indices[-1] += offsets[-1]
            seq_lens.append(num_points)
            seq_batch_indices.append(bi)
            offsets.append(offsets[-1] + seq_lens[-1])
        else:
            # Partition the input: each window is padded to exactly
            # window_size by wrapping around the ordered sequence; only the
            # [valid_start, valid_end) span contributes to bwd_index.
            offset = 0
            mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)]
            split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)]
            bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device)
            for i in range(num_windows):
                mid = mids[i]
                valid_start = split[i]
                valid_end = split[i + 1]
                padded_start = math.floor(mid - 0.5 * window_size)
                padded_end = padded_start + window_size
                # Modulo wraps negative/overflowing padded indices into range.
                fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points])
                offset += valid_start - padded_start
                bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device))
                offset += padded_end - valid_start
                fwd_indices[-1] += s.start
            seq_lens.extend([window_size] * num_windows)
            seq_batch_indices.extend([bi] * num_windows)
            bwd_indices.append(bwd_index + offsets[-1])
            offsets.append(offsets[-1] + num_windows * window_size)

    fwd_indices = torch.cat(fwd_indices)
    bwd_indices = torch.cat(bwd_indices)

    return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
124
+
125
+ def sparse_serialized_scaled_dot_product_self_attention(
126
+ qkv: SparseTensor,
127
+ window_size: int,
128
+ serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
129
+ shift_sequence: int = 0,
130
+ shift_window: Tuple[int, int, int] = (0, 0, 0)
131
+ ) -> SparseTensor:
132
+ """
133
+ Apply serialized scaled dot product self attention to a sparse tensor.
134
+
135
+ Args:
136
+ qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
137
+ window_size (int): The window size to use.
138
+ serialize_mode (SerializeMode): The serialization mode to use.
139
+ shift_sequence (int): The shift of serialized sequence.
140
+ shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
141
+ shift (int): The shift to use.
142
+ """
143
+ assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
144
+
145
+ serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}'
146
+ serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
147
+ if serialization_spatial_cache is None:
148
+ fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window)
149
+ qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
150
+ else:
151
+ fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache
152
+
153
+ M = fwd_indices.shape[0]
154
+ T = qkv.feats.shape[0]
155
+ H = qkv.feats.shape[2]
156
+ C = qkv.feats.shape[3]
157
+
158
+ qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
159
+
160
+ if DEBUG:
161
+ start = 0
162
+ qkv_coords = qkv.coords[fwd_indices]
163
+ for i in range(len(seq_lens)):
164
+ assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
165
+ start += seq_lens[i]
166
+
167
+ if all([seq_len == window_size for seq_len in seq_lens]):
168
+ B = len(seq_lens)
169
+ N = window_size
170
+ qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
171
+ if ATTN == 'xformers':
172
+ q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
173
+ out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
174
+ elif ATTN == 'flash_attn':
175
+ out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
176
+ else:
177
+ raise ValueError(f"Unknown attention module: {ATTN}")
178
+ out = out.reshape(B * N, H, C) # [M, H, C]
179
+ else:
180
+ if ATTN == 'xformers':
181
+ q, k, v = qkv_feats.unbind(dim=1) # [M, H, C]
182
+ q = q.unsqueeze(0) # [1, M, H, C]
183
+ k = k.unsqueeze(0) # [1, M, H, C]
184
+ v = v.unsqueeze(0) # [1, M, H, C]
185
+ mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
186
+ out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
187
+ elif ATTN == 'flash_attn':
188
+ cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
189
+ .to(qkv.device).int()
190
+ out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
191
+
192
+ out = out[bwd_indices] # [T, H, C]
193
+
194
+ if DEBUG:
195
+ qkv_coords = qkv_coords[bwd_indices]
196
+ assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"
197
+
198
+ return qkv.replace(out)
threeDFixer/modules/sparse/attention/windowed_attn.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ from typing import *
7
+ import torch
8
+ import math
9
+ from .. import SparseTensor
10
+ from .. import DEBUG, ATTN
11
+
12
+ if ATTN == 'xformers':
13
+ import xformers.ops as xops
14
+ elif ATTN == 'flash_attn':
15
+ import flash_attn
16
+ else:
17
+ raise ValueError(f"Unknown attention module: {ATTN}")
18
+
19
+
20
+ __all__ = [
21
+ 'sparse_windowed_scaled_dot_product_self_attention',
22
+ ]
23
+
24
+
25
def calc_window_partition(
    tensor: SparseTensor,
    window_size: Union[int, Tuple[int, ...]],
    shift_window: Union[int, Tuple[int, ...]] = 0
) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
    """
    Calculate serialization and partitioning for a set of coordinates.

    Voxels are grouped into axis-aligned windows of `window_size` (optionally
    shifted by `shift_window`), and a permutation is computed that makes each
    window's voxels contiguous.

    Args:
        tensor (SparseTensor): The input tensor; coords are [batch, x, y, z, ...].
        window_size (int): The window size to use (scalar is broadcast to all spatial dims).
        shift_window (Tuple[int, ...]): The shift of serialized coordinates.

    Returns:
        (torch.Tensor): Forward indices (gather with these to serialize into windows).
        (torch.Tensor): Backward indices (gather with these to restore original order).
        (List[int]): Sequence lengths, one per non-empty window.
        (List[int]): Sequence batch indices, one per non-empty window.
    """
    # Number of spatial dimensions (first coord column is the batch index).
    DIM = tensor.coords.shape[1] - 1
    # Broadcast scalar window/shift to one value per spatial dimension.
    shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
    window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
    shifted_coords = tensor.coords.clone().detach()
    shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)

    # Per-axis window counts, then mixed-radix offsets so that
    # (batch, wx, wy, wz) maps to a unique linear window id.
    MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist()
    NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
    OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]

    # Quantize spatial coords to window indices, then linearize (batch included
    # via OFFSET[0], so windows never straddle batches).
    shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
    shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)
    # Stable grouping: sorting by window id makes each window contiguous.
    fwd_indices = torch.argsort(shifted_indices)
    bwd_indices = torch.empty_like(fwd_indices)
    bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
    # Window populations; empty windows are dropped below.
    seq_lens = torch.bincount(shifted_indices)
    # Recover the batch index of each window id (ids are batch-major).
    seq_batch_indices = torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) // OFFSET[0]
    mask = seq_lens != 0
    seq_lens = seq_lens[mask].tolist()
    seq_batch_indices = seq_batch_indices[mask].tolist()

    return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
66
+
67
+
68
def sparse_windowed_scaled_dot_product_self_attention(
    qkv: SparseTensor,
    window_size: int,
    shift_window: Tuple[int, int, int] = (0, 0, 0)
) -> SparseTensor:
    """
    Apply windowed scaled dot product self attention to a sparse tensor.

    Args:
        qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
        window_size (int): The window size to use.
        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.

    Returns:
        SparseTensor: Attention output with the same coordinates and order as `qkv`.
    """
    assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"

    # The window partition depends only on coordinates, so cache it on the
    # tensor and reuse it for every layer sharing this window configuration.
    serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}'
    serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
    if serialization_spatial_cache is None:
        fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(qkv, window_size, shift_window)
        qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
    else:
        fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache

    M = fwd_indices.shape[0]    # total voxels after window serialization
    T = qkv.feats.shape[0]      # total voxels in original order
    H = qkv.feats.shape[2]      # number of attention heads
    C = qkv.feats.shape[3]      # channels per head

    qkv_feats = qkv.feats[fwd_indices]  # [M, 3, H, C]

    if DEBUG:
        start = 0
        qkv_coords = qkv.coords[fwd_indices]
        for i in range(len(seq_lens)):
            seq_coords = qkv_coords[start:start+seq_lens[i]]
            assert (seq_coords[:, 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
            assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \
                f"SparseWindowedScaledDotProductSelfAttention: window size exceeded"
            start += seq_lens[i]

    if all([seq_len == window_size for seq_len in seq_lens]):
        # Fast path: every window is full, so run ordinary batched attention.
        B = len(seq_lens)
        N = window_size
        qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
        if ATTN == 'xformers':
            q, k, v = qkv_feats.unbind(dim=2)                       # [B, N, H, C]
            out = xops.memory_efficient_attention(q, k, v)          # [B, N, H, C]
        elif ATTN == 'flash_attn':
            out = flash_attn.flash_attn_qkvpacked_func(qkv_feats)   # [B, N, H, C]
        else:
            raise ValueError(f"Unknown attention module: {ATTN}")
        out = out.reshape(B * N, H, C)                              # [M, H, C]
    else:
        # Ragged path: windows have different populations; use varlen kernels.
        if ATTN == 'xformers':
            q, k, v = qkv_feats.unbind(dim=1)   # [M, H, C]
            q = q.unsqueeze(0)                  # [1, M, H, C]
            k = k.unsqueeze(0)                  # [1, M, H, C]
            v = v.unsqueeze(0)                  # [1, M, H, C]
            mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
            out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
        elif ATTN == 'flash_attn':
            cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
                .to(qkv.device).int()
            out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
        else:
            # Fix: mirror the fast path's else so an unknown backend fails
            # loudly here instead of raising NameError on `out` below.
            raise ValueError(f"Unknown attention module: {ATTN}")

    out = out[bwd_indices]  # [T, H, C] -- scatter back to original voxel order

    if DEBUG:
        qkv_coords = qkv_coords[bwd_indices]
        assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"

    return qkv.replace(out)
threeDFixer/modules/sparse/basic.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ from typing import *
7
+ import torch
8
+ import torch.nn as nn
9
+ from . import BACKEND, DEBUG
10
+ SparseTensorData = None # Lazy import
11
+
12
+
13
+ __all__ = [
14
+ 'SparseTensor',
15
+ 'sparse_batch_broadcast',
16
+ 'sparse_batch_op',
17
+ 'sparse_cat',
18
+ 'sparse_unbind',
19
+ ]
20
+
21
+
22
class SparseTensor:
    """
    Sparse tensor with support for both torchsparse and spconv backends.

    Parameters:
    - feats (torch.Tensor): Features of the sparse tensor, one row per voxel.
    - coords (torch.Tensor): Integer coordinates of the sparse tensor; column 0
      is the batch index, the remaining columns are spatial coordinates.
    - shape (torch.Size): Shape of the sparse tensor.
    - layout (List[slice]): Layout of the sparse tensor for each batch.
    - data (SparseTensorData): Backend sparse tensor data used for convolution.

    NOTE:
        - Data corresponding to a same batch should be contiguous.
        - Coords should be in [0, 1023]
    """
    @overload
    def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...

    @overload
    def __init__(self, data, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...

    def __init__(self, *args, **kwargs):
        """Construct from (feats, coords) or wrap an existing backend tensor."""
        # Lazy import of sparse tensor backend: defer torchsparse/spconv import
        # until the first SparseTensor is actually created.
        global SparseTensorData
        if SparseTensorData is None:
            import importlib
            if BACKEND == 'torchsparse':
                SparseTensorData = importlib.import_module('torchsparse').SparseTensor
            elif BACKEND == 'spconv':
                SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor

        # Overload dispatch: method 0 = (feats, coords, ...), method 1 = (data, ...).
        # A tensor as the first positional argument selects method 0.
        method_id = 0
        if len(args) != 0:
            method_id = 0 if isinstance(args[0], torch.Tensor) else 1
        else:
            method_id = 1 if 'data' in kwargs else 0

        if method_id == 0:
            # Fill positional slots with None, then let keyword args override.
            feats, coords, shape, layout = args + (None,) * (4 - len(args))
            if 'feats' in kwargs:
                feats = kwargs['feats']
                del kwargs['feats']
            if 'coords' in kwargs:
                coords = kwargs['coords']
                del kwargs['coords']
            if 'shape' in kwargs:
                shape = kwargs['shape']
                del kwargs['shape']
            if 'layout' in kwargs:
                layout = kwargs['layout']
                del kwargs['layout']

            # Derive shape/layout from the data when the caller omits them.
            if shape is None:
                shape = self.__cal_shape(feats, coords)
            if layout is None:
                layout = self.__cal_layout(coords, shape[0])
            if BACKEND == 'torchsparse':
                self.data = SparseTensorData(feats, coords, **kwargs)
            elif BACKEND == 'spconv':
                # spconv needs an explicit spatial extent; use the tight bound.
                spatial_shape = list(coords.max(0)[0] + 1)[1:]
                self.data = SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape, shape[0], **kwargs)
                # Keep the original (possibly multi-dim) feature view; spconv
                # itself only stores the flattened copy passed above.
                self.data._features = feats
        elif method_id == 1:
            data, shape, layout = args + (None,) * (3 - len(args))
            if 'data' in kwargs:
                data = kwargs['data']
                del kwargs['data']
            if 'shape' in kwargs:
                shape = kwargs['shape']
                del kwargs['shape']
            if 'layout' in kwargs:
                layout = kwargs['layout']
                del kwargs['layout']

            self.data = data
            if shape is None:
                shape = self.__cal_shape(self.feats, self.coords)
            if layout is None:
                layout = self.__cal_layout(self.coords, shape[0])

        self._shape = shape
        self._layout = layout
        # _scale tracks downsampling applied by strided convolutions.
        self._scale = kwargs.get('scale', (1, 1, 1))
        # _spatial_cache holds per-scale auxiliary data (window partitions, etc.).
        self._spatial_cache = kwargs.get('spatial_cache', {})

        if DEBUG:
            # Sanity-check the invariants documented in the class docstring.
            try:
                assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}"
                assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}"
                assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}"
                for i in range(self.shape[0]):
                    assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous"
            except Exception as e:
                print('Debugging information:')
                print(f"- Shape: {self.shape}")
                print(f"- Layout: {self.layout}")
                print(f"- Scale: {self._scale}")
                print(f"- Coords: {self.coords}")
                raise e

    def __cal_shape(self, feats, coords):
        # Shape is [batch_size, *feature_dims]; batch size comes from the
        # largest batch index present in coords.
        shape = []
        shape.append(coords[:, 0].max().item() + 1)
        shape.extend([*feats.shape[1:]])
        return torch.Size(shape)

    def __cal_layout(self, coords, batch_size):
        # One contiguous row slice per batch (relies on rows being grouped
        # by batch index, as required by the class invariant).
        seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
        offset = torch.cumsum(seq_len, dim=0)
        layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)]
        return layout

    @property
    def shape(self) -> torch.Size:
        return self._shape

    def dim(self) -> int:
        return len(self.shape)

    @property
    def layout(self) -> List[slice]:
        return self._layout

    @property
    def feats(self) -> torch.Tensor:
        # Backend-agnostic access to the feature rows.
        if BACKEND == 'torchsparse':
            return self.data.F
        elif BACKEND == 'spconv':
            return self.data.features

    @feats.setter
    def feats(self, value: torch.Tensor):
        if BACKEND == 'torchsparse':
            self.data.F = value
        elif BACKEND == 'spconv':
            self.data.features = value

    @property
    def coords(self) -> torch.Tensor:
        # Backend-agnostic access to the [batch, x, y, z] coordinate rows.
        if BACKEND == 'torchsparse':
            return self.data.C
        elif BACKEND == 'spconv':
            return self.data.indices

    @coords.setter
    def coords(self, value: torch.Tensor):
        if BACKEND == 'torchsparse':
            self.data.C = value
        elif BACKEND == 'spconv':
            self.data.indices = value

    @property
    def dtype(self):
        return self.feats.dtype

    @property
    def device(self):
        return self.feats.device

    @overload
    def to(self, dtype: torch.dtype) -> 'SparseTensor': ...

    @overload
    def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': ...

    def to(self, *args, **kwargs) -> 'SparseTensor':
        """Move/cast like torch.Tensor.to; dtype only applies to feats."""
        device = None
        dtype = None
        if len(args) == 2:
            device, dtype = args
        elif len(args) == 1:
            # Single positional arg may be either a dtype or a device.
            if isinstance(args[0], torch.dtype):
                dtype = args[0]
            else:
                device = args[0]
        if 'dtype' in kwargs:
            assert dtype is None, "to() received multiple values for argument 'dtype'"
            dtype = kwargs['dtype']
        if 'device' in kwargs:
            assert device is None, "to() received multiple values for argument 'device'"
            device = kwargs['device']

        new_feats = self.feats.to(device=device, dtype=dtype)
        # Coords stay integer-typed; only their device follows.
        new_coords = self.coords.to(device=device)
        return self.replace(new_feats, new_coords)

    def type(self, dtype):
        # Cast features only; coordinates keep their integer dtype.
        new_feats = self.feats.type(dtype)
        return self.replace(new_feats)

    def cpu(self) -> 'SparseTensor':
        new_feats = self.feats.cpu()
        new_coords = self.coords.cpu()
        return self.replace(new_feats, new_coords)

    def cuda(self) -> 'SparseTensor':
        new_feats = self.feats.cuda()
        new_coords = self.coords.cuda()
        return self.replace(new_feats, new_coords)

    def half(self) -> 'SparseTensor':
        new_feats = self.feats.half()
        return self.replace(new_feats)

    def float(self) -> 'SparseTensor':
        new_feats = self.feats.float()
        return self.replace(new_feats)

    def detach(self) -> 'SparseTensor':
        new_coords = self.coords.detach()
        new_feats = self.feats.detach()
        return self.replace(new_feats, new_coords)

    def dense(self) -> torch.Tensor:
        # Both backends expose a dense() conversion on their data object.
        if BACKEND == 'torchsparse':
            return self.data.dense()
        elif BACKEND == 'spconv':
            return self.data.dense()

    def reshape(self, *shape) -> 'SparseTensor':
        # Reshape the per-voxel feature dims; the voxel dimension is fixed.
        new_feats = self.feats.reshape(self.feats.shape[0], *shape)
        return self.replace(new_feats)

    def unbind(self, dim: int) -> List['SparseTensor']:
        return sparse_unbind(self, dim)

    def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
        """Return a new SparseTensor sharing this one's structure but with new
        feats (and optionally coords); layout, scale and caches are carried over."""
        new_shape = [self.shape[0]]
        new_shape.extend(feats.shape[1:])
        if BACKEND == 'torchsparse':
            new_data = SparseTensorData(
                feats=feats,
                coords=self.data.coords if coords is None else coords,
                stride=self.data.stride,
                spatial_range=self.data.spatial_range,
            )
            # Share the kernel-map caches so convs don't recompute them.
            new_data._caches = self.data._caches
        elif BACKEND == 'spconv':
            new_data = SparseTensorData(
                self.data.features.reshape(self.data.features.shape[0], -1),
                self.data.indices,
                self.data.spatial_shape,
                self.data.batch_size,
                self.data.grid,
                self.data.voxel_num,
                self.data.indice_dict
            )
            new_data._features = feats
            # Carry over spconv bookkeeping/config state.
            new_data.benchmark = self.data.benchmark
            new_data.benchmark_record = self.data.benchmark_record
            new_data.thrust_allocator = self.data.thrust_allocator
            new_data._timer = self.data._timer
            new_data.force_algo = self.data.force_algo
            new_data.int8_scale = self.data.int8_scale
            if coords is not None:
                new_data.indices = coords
        new_tensor = SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache)
        return new_tensor

    @staticmethod
    def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor':
        """Build a fully-populated SparseTensor: every voxel in the inclusive
        box aabb = [x0, y0, z0, x1, y1, z1], replicated for N batches, with a
        constant C-channel feature `value`. dim is the pair (N, C)."""
        N, C = dim
        x = torch.arange(aabb[0], aabb[3] + 1)
        y = torch.arange(aabb[1], aabb[4] + 1)
        z = torch.arange(aabb[2], aabb[5] + 1)
        coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3)
        # Prepend the batch column and tile the grid once per batch.
        coords = torch.cat([
            torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
            coords.repeat(N, 1),
        ], dim=1).to(dtype=torch.int32, device=device)
        feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
        return SparseTensor(feats=feats, coords=coords)

    def __merge_sparse_cache(self, other: 'SparseTensor') -> dict:
        # Union of both operands' spatial caches; per-scale dicts are merged,
        # with `other` winning on key collisions inside a scale.
        new_cache = {}
        for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())):
            if k in self._spatial_cache:
                new_cache[k] = self._spatial_cache[k]
            if k in other._spatial_cache:
                if k not in new_cache:
                    new_cache[k] = other._spatial_cache[k]
                else:
                    new_cache[k].update(other._spatial_cache[k])
        return new_cache

    def __neg__(self) -> 'SparseTensor':
        return self.replace(-self.feats)

    def __elemwise__(self, other: Union[torch.Tensor, 'SparseTensor'], op: callable) -> 'SparseTensor':
        """Apply a binary op elementwise on feats, batch-broadcasting a dense
        operand when possible and merging caches when both are sparse."""
        if isinstance(other, torch.Tensor):
            try:
                other = torch.broadcast_to(other, self.shape)
                other = sparse_batch_broadcast(self, other)
            except:
                # Best-effort: if batch broadcasting fails, fall through and
                # let the op's own broadcasting rules apply (or raise).
                pass
        if isinstance(other, SparseTensor):
            other = other.feats
        new_feats = op(self.feats, other)
        new_tensor = self.replace(new_feats)
        if isinstance(other, SparseTensor):
            # NOTE(review): `other` was rebound to a plain tensor above, so this
            # branch appears unreachable -- confirm whether cache merging on
            # sparse-sparse ops is intended to run.
            new_tensor._spatial_cache = self.__merge_sparse_cache(other)
        return new_tensor

    def __add__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.add)

    def __radd__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.add)

    def __sub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.sub)

    def __rsub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, lambda x, y: torch.sub(y, x))

    def __mul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.mul)

    def __rmul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.mul)

    def __truediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.div)

    def __rtruediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, lambda x, y: torch.div(y, x))

    def __getitem__(self, idx):
        """Select batches (int, slice, bool mask, or index tensor) and return
        a new SparseTensor with batch indices renumbered from 0."""
        if isinstance(idx, int):
            idx = [idx]
        elif isinstance(idx, slice):
            idx = range(*idx.indices(self.shape[0]))
        elif isinstance(idx, torch.Tensor):
            if idx.dtype == torch.bool:
                assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
                idx = idx.nonzero().squeeze(1)
            elif idx.dtype in [torch.int32, torch.int64]:
                assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
            else:
                raise ValueError(f"Unknown index type: {idx.dtype}")
        else:
            raise ValueError(f"Unknown index type: {type(idx)}")

        coords = []
        feats = []
        for new_idx, old_idx in enumerate(idx):
            coords.append(self.coords[self.layout[old_idx]].clone())
            # Renumber the batch column so the result is densely indexed.
            coords[-1][:, 0] = new_idx
            feats.append(self.feats[self.layout[old_idx]])
        coords = torch.cat(coords, dim=0).contiguous()
        feats = torch.cat(feats, dim=0).contiguous()
        return SparseTensor(feats=feats, coords=coords)

    def register_spatial_cache(self, key, value) -> None:
        """
        Register a spatial cache.
        The spatial cache can be anything you want to cache.
        The registration and retrieval of the cache is based on current scale.
        """
        scale_key = str(self._scale)
        if scale_key not in self._spatial_cache:
            self._spatial_cache[scale_key] = {}
        self._spatial_cache[scale_key][key] = value

    def get_spatial_cache(self, key=None):
        """
        Get a spatial cache (or the whole per-scale dict when key is None).
        """
        scale_key = str(self._scale)
        cur_scale_cache = self._spatial_cache.get(scale_key, {})
        if key is None:
            return cur_scale_cache
        return cur_scale_cache.get(key, None)
395
+
396
+
397
def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor:
    """
    Broadcast a per-batch tensor across the rows of a sparse tensor.

    Args:
        input (SparseTensor): Sparse tensor whose batch layout drives the broadcast.
        other (torch.Tensor): Tensor indexed by batch; entry ``other[k]`` is copied
            to every row belonging to batch k.

    Returns:
        torch.Tensor: Tensor shaped like ``input.feats`` where each row holds its
        batch's entry of ``other``.
    """
    # (Fix: dropped the unused `coords` unpack and corrected the docstring,
    # which previously described sparse_batch_op's signature.)
    broadcasted = torch.zeros_like(input.feats)
    for k in range(input.shape[0]):
        # layout[k] is the contiguous row slice belonging to batch k.
        broadcasted[input.layout[k]] = other[k]
    return broadcasted
411
+
412
+
413
def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = torch.add) -> SparseTensor:
    """
    Broadcast a per-batch tensor over a sparse tensor's rows, then combine
    the two with ``op`` (default: addition).

    Args:
        input (SparseTensor): Sparse tensor supplying feats and batch layout.
        other (torch.Tensor): Per-batch tensor to broadcast.
        op (callable): Binary op applied as ``op(input.feats, broadcasted)``.
    """
    per_row = sparse_batch_broadcast(input, other)
    combined = op(input.feats, per_row)
    return input.replace(combined)
423
+
424
+
425
def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
    """
    Concatenate a list of sparse tensors.

    Args:
        inputs (List[SparseTensor]): List of sparse tensors to concatenate.
        dim (int): 0 concatenates along the batch dimension (batch indices are
            re-offset); any other value concatenates feature dimensions.
    """
    if dim != 0:
        # Feature-dim concat: coordinates are shared, only feats grow.
        stacked_feats = torch.cat([t.feats for t in inputs], dim=dim)
        return inputs[0].replace(stacked_feats)

    # Batch-dim concat: shift each tensor's batch column before stacking.
    shifted_coords = []
    batch_offset = 0
    for t in inputs:
        c = t.coords.clone()
        c[:, 0] += batch_offset
        shifted_coords.append(c)
        batch_offset += t.shape[0]
    all_coords = torch.cat(shifted_coords, dim=0)
    all_feats = torch.cat([t.feats for t in inputs], dim=0)
    return SparseTensor(
        coords=all_coords,
        feats=all_feats,
    )
450
+
451
+
452
def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
    """
    Unbind a sparse tensor along a dimension.

    Args:
        input (SparseTensor): Sparse tensor to unbind.
        dim (int): 0 splits by batch; any other value unbinds a feature dim.
    """
    if dim == 0:
        # Per-batch split: __getitem__ renumbers batch indices from 0.
        return [input[b] for b in range(input.shape[0])]
    per_slice_feats = input.feats.unbind(dim)
    return [input.replace(chunk) for chunk in per_slice_feats]
threeDFixer/modules/sparse/conv/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ from .. import BACKEND
7
+
8
+
9
+ SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native'
10
+
11
def __from_env():
    """
    Override SPCONV_ALGO from the ``SPCONV_ALGO`` environment variable.

    Only the recognized values 'auto', 'implicit_gemm' and 'native' are
    accepted; an unset or unrecognized variable keeps the default.
    """
    import os

    global SPCONV_ALGO
    env_spconv_algo = os.environ.get('SPCONV_ALGO')
    # Membership test alone suffices: None is never in the whitelist,
    # so the previous `is not None` guard was redundant.
    if env_spconv_algo in ('auto', 'implicit_gemm', 'native'):
        SPCONV_ALGO = env_spconv_algo
    print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}")
19
+
20
+
21
+ __from_env()
22
+
23
+ if BACKEND == 'torchsparse':
24
+ from .conv_torchsparse import *
25
+ elif BACKEND == 'spconv':
26
+ from .conv_spconv import *
threeDFixer/modules/sparse/conv/conv_spconv.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from .. import SparseTensor
9
+ from .. import DEBUG
10
+ from . import SPCONV_ALGO
11
+
12
class SparseConv3d(nn.Module):
    """
    Sparse 3D convolution on a SparseTensor (spconv backend).

    Uses a submanifold convolution (SubMConv3d) when stride is 1 and no
    padding is given -- output sites equal input sites -- and a regular
    SparseConv3d otherwise.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
        super(SparseConv3d, self).__init__()
        # Lazy import, matching the rest of the sparse package.
        if 'spconv' not in globals():
            import spconv.pytorch as spconv
        # Map the package-level algorithm choice onto spconv's enum.
        algo = None
        if SPCONV_ALGO == 'native':
            algo = spconv.ConvAlgo.Native
        elif SPCONV_ALGO == 'implicit_gemm':
            algo = spconv.ConvAlgo.MaskImplicitGemm
        if stride == 1 and (padding is None):
            self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo)
        else:
            self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo)
        self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
        self.padding = padding

    def forward(self, x: SparseTensor) -> SparseTensor:
        # Fix: the re-sort branch below constructs spconv.SparseConvTensor, but
        # `spconv` was previously only bound *locally* inside __init__, so that
        # branch could raise NameError. Bind it here the same lazy way.
        if 'spconv' not in globals():
            import spconv.pytorch as spconv
        spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None)
        new_data = self.conv(x.data)
        new_shape = [x.shape[0], self.conv.out_channels]
        # A spatially-changing conv invalidates the cached per-batch layout.
        new_layout = None if spatial_changed else x.layout

        if spatial_changed and (x.shape[0] != 1):
            # spconv with non-1 stride breaks the batch-contiguity of the
            # output tensor; restore it by sorting rows on the batch column.
            fwd = new_data.indices[:, 0].argsort()
            bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device))
            sorted_feats = new_data.features[fwd]
            sorted_coords = new_data.indices[fwd]
            unsorted_data = new_data
            new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size)  # type: ignore

        out = SparseTensor(
            new_data, shape=torch.Size(new_shape), layout=new_layout,
            scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]),
            spatial_cache=x._spatial_cache,
        )

        if spatial_changed and (x.shape[0] != 1):
            # Cache the pre-sort data and inverse permutation so a matching
            # SparseInverseConv3d can restore spconv's original ordering.
            out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data)
            out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd)

        return out
55
+
56
+
57
class SparseInverseConv3d(nn.Module):
    """
    Inverse (transposed) sparse 3D convolution on a SparseTensor (spconv
    backend). Undoes a strided SparseConv3d that used the same indice_key,
    restoring the row order the forward conv cached.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
        super(SparseInverseConv3d, self).__init__()
        # Lazy import, matching the rest of the sparse package.
        if 'spconv' not in globals():
            import spconv.pytorch as spconv
        self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key)
        self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)

    def forward(self, x: SparseTensor) -> SparseTensor:
        spatial_changed = any(s != 1 for s in self.stride)
        if spatial_changed:
            # Recover the original spconv order: the paired SparseConv3d sorted
            # its output by batch and cached the unsorted data plus the inverse
            # permutation under these keys (see SparseConv3d.forward).
            data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data')
            bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd')
            data = data.replace_feature(x.feats[bwd])
            if DEBUG:
                assert torch.equal(data.indices, x.coords[bwd]), 'Recover the original order failed'
        else:
            data = x.data

        new_data = self.conv(data)
        new_shape = [x.shape[0], self.conv.out_channels]
        # Upsampling invalidates the cached per-batch layout.
        new_layout = None if spatial_changed else x.layout
        out = SparseTensor(
            new_data, shape=torch.Size(new_shape), layout=new_layout,
            scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]),
            spatial_cache=x._spatial_cache,
        )
        return out
threeDFixer/modules/sparse/conv/conv_torchsparse.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from .. import SparseTensor
9
+
10
+
11
class SparseConv3d(nn.Module):
    """
    Sparse 3D convolution on a SparseTensor (torchsparse backend).

    Note: the `indice_key` parameter is accepted for signature parity with the
    spconv backend but is not used by torchsparse.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
        super(SparseConv3d, self).__init__()
        # Lazy import, matching the rest of the sparse package.
        if 'torchsparse' not in globals():
            import torchsparse
        self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias)

    def forward(self, x: SparseTensor) -> SparseTensor:
        out = self.conv(x.data)
        new_shape = [x.shape[0], self.conv.out_channels]
        # A strided conv changes voxel positions, so the cached per-batch
        # layout is only reusable when every stride component is 1.
        out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
        out._spatial_cache = x._spatial_cache
        # Track cumulative downsampling for spatial-cache lookups.
        out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)])
        return out
25
+
26
+
27
class SparseInverseConv3d(nn.Module):
    """
    Inverse (transposed) sparse 3D convolution on a SparseTensor
    (torchsparse backend). `indice_key` is accepted for signature parity
    with the spconv backend but is not used by torchsparse.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
        super(SparseInverseConv3d, self).__init__()
        # Lazy import, matching the rest of the sparse package.
        if 'torchsparse' not in globals():
            import torchsparse
        self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True)

    def forward(self, x: SparseTensor) -> SparseTensor:
        out = self.conv(x.data)
        new_shape = [x.shape[0], self.conv.out_channels]
        # Transposed convs with stride change voxel positions, so the cached
        # layout is only reusable when every stride component is 1.
        out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
        out._spatial_cache = x._spatial_cache
        # Upsampling divides the cumulative scale back down.
        out._scale = tuple([s // stride for s, stride in zip(x._scale, self.conv.stride)])
        return out
41
+
42
+
43
+
threeDFixer/modules/sparse/linear.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from . import SparseTensor
9
+
10
+ __all__ = [
11
+ 'SparseLinear'
12
+ ]
13
+
14
+
15
class SparseLinear(nn.Linear):
    """Linear layer applied to the features of a SparseTensor; coordinates
    and layout are preserved via ``replace``."""
    def __init__(self, in_features, out_features, bias=True):
        super(SparseLinear, self).__init__(in_features, out_features, bias)

    def forward(self, input: SparseTensor) -> SparseTensor:
        # Run the dense nn.Linear on the [N, C] feature rows and rewrap.
        return input.replace(super().forward(input.feats))
threeDFixer/modules/sparse/nonlinearity.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from . import SparseTensor
9
+
10
+ __all__ = [
11
+ 'SparseReLU',
12
+ 'SparseSiLU',
13
+ 'SparseGELU',
14
+ 'SparseActivation'
15
+ ]
16
+
17
+
18
class SparseReLU(nn.ReLU):
    """ReLU applied element-wise to the features of a SparseTensor."""

    def forward(self, input: SparseTensor) -> SparseTensor:
        activated = super().forward(input.feats)
        return input.replace(activated)
21
+
22
+
23
class SparseSiLU(nn.SiLU):
    """SiLU applied element-wise to the features of a SparseTensor."""

    def forward(self, input: SparseTensor) -> SparseTensor:
        activated = super().forward(input.feats)
        return input.replace(activated)
26
+
27
+
28
class SparseGELU(nn.GELU):
    """GELU applied element-wise to the features of a SparseTensor."""

    def forward(self, input: SparseTensor) -> SparseTensor:
        activated = super().forward(input.feats)
        return input.replace(activated)
31
+
32
+
33
class SparseActivation(nn.Module):
    """Adapter that lets an arbitrary dense activation act on SparseTensor features."""

    def __init__(self, activation: nn.Module):
        super().__init__()
        self.activation = activation

    def forward(self, input: SparseTensor) -> SparseTensor:
        activated = self.activation(input.feats)
        return input.replace(activated)
40
+
threeDFixer/modules/sparse/norm.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from . import SparseTensor
9
+ from . import DEBUG
10
+
11
+ __all__ = [
12
+ 'SparseGroupNorm',
13
+ 'SparseLayerNorm',
14
+ 'SparseGroupNorm32',
15
+ 'SparseLayerNorm32',
16
+ ]
17
+
18
+
19
class SparseGroupNorm(nn.GroupNorm):
    """
    GroupNorm over SparseTensor features, normalizing each batch item
    independently (GroupNorm statistics must not mix batch elements, so the
    variable-length rows of each item are normalized one item at a time).
    """
    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
        super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine)

    def forward(self, input: SparseTensor) -> SparseTensor:
        nfeats = torch.zeros_like(input.feats)
        # Iterate over batch items; input.layout[k] slices item k's rows.
        for k in range(input.shape[0]):
            if DEBUG:
                assert (input.coords[input.layout[k], 0] == k).all(), f"SparseGroupNorm: batch index mismatch"
            bfeats = input.feats[input.layout[k]]
            # (L, C) -> (1, C, L): GroupNorm expects (N, C, *spatial).
            bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
            bfeats = super().forward(bfeats)
            # (1, C, L) -> (L, C): restore the per-voxel feature layout.
            bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
            nfeats[input.layout[k]] = bfeats
        return input.replace(nfeats)
34
+
35
+
36
class SparseLayerNorm(nn.LayerNorm):
    """
    LayerNorm over SparseTensor features, applied per batch item.

    NOTE(review): after the permute/reshape each item is (1, C, L) and
    nn.LayerNorm(normalized_shape) normalizes the *last* dimension(s) — for
    normalized_shape == C this only type-checks when L == C. Kept as-is
    (mirrors the upstream TRELLIS implementation); confirm intended usage
    before relying on this class directly.
    """
    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine)

    def forward(self, input: SparseTensor) -> SparseTensor:
        nfeats = torch.zeros_like(input.feats)
        # Normalize each batch item's rows separately (layout slices item k).
        for k in range(input.shape[0]):
            bfeats = input.feats[input.layout[k]]
            # (L, C) -> (1, C, L) before handing to nn.LayerNorm.
            bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
            bfeats = super().forward(bfeats)
            # (1, C, L) -> (L, C): restore the per-voxel feature layout.
            bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
            nfeats[input.layout[k]] = bfeats
        return input.replace(nfeats)
49
+
50
+
51
class SparseGroupNorm32(SparseGroupNorm):
    """
    A GroupNorm layer that converts to float32 before the forward pass.
    """
    def forward(self, x: SparseTensor) -> SparseTensor:
        # Normalize in fp32 for numerical stability, then restore the dtype.
        normed = super().forward(x.float())
        return normed.type(x.dtype)
57
+
58
class SparseLayerNorm32(SparseLayerNorm):
    """
    A LayerNorm layer that converts to float32 before the forward pass.
    """
    def forward(self, x: SparseTensor) -> SparseTensor:
        # Normalize in fp32 for numerical stability, then restore the dtype.
        normed = super().forward(x.float())
        return normed.type(x.dtype)
threeDFixer/modules/sparse/spatial.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ from typing import *
7
+ import torch
8
+ import torch.nn as nn
9
+ from . import SparseTensor
10
+
11
+ __all__ = [
12
+ 'SparseDownsample',
13
+ 'SparseUpsample',
14
+ 'SparseSubdivide'
15
+ ]
16
+
17
+
18
class SparseDownsample(nn.Module):
    """
    Downsample a sparse tensor by a factor of `factor`.
    Implemented as average pooling.

    Voxels that collapse to the same coarse coordinate are pooled; the fine
    coordinates/layout/index map are cached on the output so a paired
    SparseUpsample can invert the operation.
    """
    def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]):
        super(SparseDownsample, self).__init__()
        self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor

    def forward(self, input: SparseTensor) -> SparseTensor:
        DIM = input.coords.shape[-1] - 1
        factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
        assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.'

        # Integer-divide spatial coords (coord[0] is the batch index).
        coord = list(input.coords.unbind(dim=-1))
        for i, f in enumerate(factor):
            coord[i+1] = coord[i+1] // f

        # Pack (batch, x, y, z) into a single scalar code so `unique` can
        # merge voxels that landed on the same coarse cell.
        MAX = [coord[i+1].max().item() + 1 for i in range(DIM)]
        OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
        code = sum([c * o for c, o in zip(coord, OFFSET)])
        code, idx = code.unique(return_inverse=True)

        # Average the features of merged voxels.
        # NOTE(review): the functional scatter_reduce defaults to
        # include_self=True, so the zero-initialized target participates in
        # the mean (sum / (n + 1) rather than sum / n). Kept as-is since it
        # mirrors upstream TRELLIS and trained weights depend on it — confirm
        # before "fixing".
        new_feats = torch.scatter_reduce(
            torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=input.feats.dtype),
            dim=0,
            index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]),
            src=input.feats,
            reduce='mean'
        )
        # Unpack the scalar codes back into (batch, x, y, z) coordinates.
        new_coords = torch.stack(
            [code // OFFSET[0]] +
            [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
            dim=-1
        )
        out = SparseTensor(new_feats, new_coords, input.shape,)
        out._scale = tuple([s // f for s, f in zip(input._scale, factor)])
        out._spatial_cache = input._spatial_cache

        # Cache the fine-resolution structure so SparseUpsample can restore it.
        out.register_spatial_cache(f'upsample_{factor}_coords', input.coords)
        out.register_spatial_cache(f'upsample_{factor}_layout', input.layout)
        out.register_spatial_cache(f'upsample_{factor}_idx', idx)

        return out
62
+
63
+
64
class SparseUpsample(nn.Module):
    """
    Upsample a sparse tensor by a factor of `factor`.
    Implemented as nearest neighbor interpolation, using the coordinate /
    layout / index cache registered by a preceding SparseDownsample.
    """
    def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]):
        super(SparseUpsample, self).__init__()
        self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor

    def forward(self, input: SparseTensor) -> SparseTensor:
        ndim = input.coords.shape[-1] - 1
        factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * ndim
        assert ndim == len(factor), 'Input coordinates must have the same dimension as the upsample factor.'

        # Fetch the fine-resolution structure cached by SparseDownsample.
        cached = [
            input.get_spatial_cache(f'upsample_{factor}_{key}')
            for key in ('coords', 'layout', 'idx')
        ]
        if any(item is None for item in cached):
            raise ValueError('Upsample cache not found. SparseUpsample must be paired with SparseDownsample.')
        new_coords, new_layout, idx = cached

        # Nearest-neighbor: every fine voxel copies its coarse parent's features.
        out = SparseTensor(input.feats[idx], new_coords, input.shape, new_layout)
        out._scale = tuple(s * f for s, f in zip(input._scale, factor))
        out._spatial_cache = input._spatial_cache
        return out
88
+
89
class SparseSubdivide(nn.Module):
    """
    Subdivide each active voxel of a sparse tensor into its 2^DIM children.

    Doubles the spatial resolution: every coordinate is multiplied by 2 and
    replicated at all 2^DIM corner offsets; each child inherits its parent's
    features unchanged.
    """
    def __init__(self):
        super(SparseSubdivide, self).__init__()

    def forward(self, input: SparseTensor) -> SparseTensor:
        DIM = input.coords.shape[-1] - 1
        # Enumerate the 2^DIM corner offsets of a unit cube, with a zero
        # prepended for the batch-index column.
        n_cube = torch.ones([2] * DIM, device=input.device, dtype=torch.int)
        n_coords = torch.nonzero(n_cube)
        n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1)
        factor = n_coords.shape[0]
        assert factor == 2 ** DIM
        # Scale up the spatial coordinates, then replicate every voxel at
        # each corner offset.
        new_coords = input.coords.clone()
        new_coords[:, 1:] *= 2
        new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype)

        # Copy each parent's features to all of its children.
        new_feats = input.feats.unsqueeze(1).expand(input.feats.shape[0], factor, *input.feats.shape[1:])
        out = SparseTensor(new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape)
        # Bug fix: `input._scale * 2` on a tuple *repeats* the tuple
        # ((1, 1, 1) -> (1, 1, 1, 1, 1, 1)) instead of doubling each
        # component; scale per-component, matching the bookkeeping in
        # SparseDownsample / SparseUpsample.
        out._scale = tuple(s * 2 for s in input._scale)
        out._spatial_cache = input._spatial_cache
        return out
115
+
threeDFixer/modules/sparse/transformer/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ from .blocks import *
7
+ from .modulated import *
threeDFixer/modules/sparse/transformer/blocks.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ from typing import *
7
+ import torch
8
+ import torch.nn as nn
9
+ from ..basic import SparseTensor
10
+ from ..linear import SparseLinear
11
+ from ..nonlinearity import SparseGELU
12
+ from ..attention import SparseMultiHeadAttention, SerializeMode
13
+ from ...norm import LayerNorm32
14
+
15
+
16
class SparseFeedForwardNet(nn.Module):
    """Two-layer MLP over SparseTensor features: expand by `mlp_ratio`, tanh-GELU, project back."""

    def __init__(self, channels: int, mlp_ratio: float = 4.0):
        super().__init__()
        hidden = int(channels * mlp_ratio)
        self.mlp = nn.Sequential(
            SparseLinear(channels, hidden),
            SparseGELU(approximate="tanh"),
            SparseLinear(hidden, channels),
        )

    def forward(self, x: SparseTensor) -> SparseTensor:
        return self.mlp(x)
+
28
+
29
class SparseTransformerBlock(nn.Module):
    """
    Sparse Transformer block (MSA + FFN).

    Pre-norm residual layout: x -> LN -> self-attn -> +x -> LN -> MLP -> +x.
    The attention neighbourhood (full / windowed / serialized variants) is
    selected by `attn_mode` and its companion options, which are forwarded
    verbatim to SparseMultiHeadAttention.
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qkv_bias: bool = True,
        ln_affine: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            window_size=window_size,
            shift_sequence=shift_sequence,
            shift_window=shift_window,
            serialize_mode=serialize_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )

    def _forward(self, x: SparseTensor) -> SparseTensor:
        # Pre-norm self-attention with residual connection.
        h = x.replace(self.norm1(x.feats))
        h = self.attn(h)
        x = x + h
        # Pre-norm feed-forward with residual connection.
        h = x.replace(self.norm2(x.feats))
        h = self.mlp(h)
        x = x + h
        return x

    def forward(self, x: SparseTensor) -> SparseTensor:
        # Optionally trade compute for memory via activation checkpointing.
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
        else:
            return self._forward(x)
+
85
+
86
class SparseTransformerCrossBlock(nn.Module):
    """
    Sparse Transformer cross-attention block (MSA + MCA + FFN).

    Pre-norm residual layout: self-attention over the sparse tokens, then
    cross-attention against a dense `context` sequence, then an MLP — each
    with its own LayerNorm and residual connection.
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        ln_affine: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.self_attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_sequence=shift_sequence,
            shift_window=shift_window,
            serialize_mode=serialize_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        # Cross-attention always attends over the full context sequence.
        self.cross_attn = SparseMultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )

    # Bug fix: the signature previously declared a stray `mod` parameter
    # (leftover from the modulated variant), while `forward` and the
    # checkpoint call pass only (x, context) — causing a TypeError on the
    # first call. It matches the upstream TRELLIS signature now.
    def _forward(self, x: SparseTensor, context: torch.Tensor):
        # Pre-norm self-attention with residual connection.
        h = x.replace(self.norm1(x.feats))
        h = self.self_attn(h)
        x = x + h
        # Pre-norm cross-attention against the conditioning context.
        h = x.replace(self.norm2(x.feats))
        h = self.cross_attn(h, context)
        x = x + h
        # Pre-norm feed-forward with residual connection.
        h = x.replace(self.norm3(x.feats))
        h = self.mlp(h)
        x = x + h
        return x

    def forward(self, x: SparseTensor, context: torch.Tensor):
        # Optionally trade compute for memory via activation checkpointing.
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
        else:
            return self._forward(x, context)
threeDFixer/modules/sparse/transformer/modulated.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file is modified from TRELLIS:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+ # Modifications Copyright (c) 2026 Ze-Xin Yin, Robot labs of Horizon Robotics, and D-Robotics.
6
+
7
+ from typing import *
8
+ import torch
9
+ import torch.nn as nn
10
+ from ..basic import SparseTensor
11
+ from ..attention import SparseMultiHeadAttention, SerializeMode
12
+ from ...norm import LayerNorm32
13
+ from .blocks import SparseFeedForwardNet
14
+
15
+
16
class ModulatedSparseTransformerBlock(nn.Module):
    """
    Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning.

    DiT-style adaLN: a conditioning vector `mod` produces six per-channel
    vectors (shift/scale/gate for the attention branch and for the MLP
    branch). With `share_mod=True` the six chunks are computed once outside
    the block and passed in pre-chunked; otherwise each block owns its own
    `adaLN_modulation` head.
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        # Affine-free norms: scale/shift come from the modulation instead.
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            window_size=window_size,
            shift_sequence=shift_sequence,
            shift_window=shift_window,
            serialize_mode=serialize_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )
        if not share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )

    def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
        # Split the modulation into (shift, scale, gate) for each branch.
        if self.share_mod:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
        # Attention branch: modulated pre-norm, then gated residual add.
        h = x.replace(self.norm1(x.feats))
        h = h * (1 + scale_msa) + shift_msa
        h = self.attn(h)
        h = h * gate_msa
        x = x + h
        # MLP branch: same modulate -> transform -> gate -> residual pattern.
        h = x.replace(self.norm2(x.feats))
        h = h * (1 + scale_mlp) + shift_mlp
        h = self.mlp(h)
        h = h * gate_mlp
        x = x + h
        return x

    def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
        # Optionally trade compute for memory via activation checkpointing.
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
        else:
            return self._forward(x, mod)
85
+
86
+
87
class ModulatedSparseTransformerCrossBlock(nn.Module):
    """
    Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive
    layer norm conditioning.

    The self-attention and MLP branches are modulated (DiT-style
    shift/scale/gate from `mod`); the cross-attention branch is a plain
    pre-norm residual — note norm2 keeps an affine LayerNorm since it gets
    no modulation.
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,

    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        # Affine norm for the unmodulated cross-attention branch.
        self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.self_attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_sequence=shift_sequence,
            shift_window=shift_window,
            serialize_mode=serialize_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        # Cross-attention always attends over the full context sequence.
        self.cross_attn = SparseMultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )
        if not share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )

    def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor:
        # Split the modulation into (shift, scale, gate) for MSA and MLP.
        if self.share_mod:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
        # Modulated self-attention branch with gated residual.
        h = x.replace(self.norm1(x.feats))
        h = h * (1 + scale_msa) + shift_msa
        h = self.self_attn(h)
        h = h * gate_msa
        x = x + h
        # Plain pre-norm cross-attention against the conditioning context.
        h = x.replace(self.norm2(x.feats))
        h = self.cross_attn(h, context)
        x = x + h
        # Modulated MLP branch with gated residual.
        h = x.replace(self.norm3(x.feats))
        h = h * (1 + scale_mlp) + shift_mlp
        h = self.mlp(h)
        h = h * gate_mlp
        x = x + h
        return x

    def forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor:
        # Optionally trade compute for memory via activation checkpointing.
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
        else:
            return self._forward(x, mod, context)
173
+
174
+
175
class ModulatedSceneSparseTransformerCrossBlock(nn.Module):
    """
    Scene-level variant of the modulated sparse cross block.

    Extends ModulatedSparseTransformerCrossBlock with two extra branches:
    a second modulated self-attention (`self_attn_vis_ratio`, conditioned on
    a separate `vis_mod` vector — presumably a visibility-ratio embedding;
    confirm against the caller) and a second plain cross-attention
    (`cross_attn_extra`) over an additional context sequence.
    Branch order: MSA -> MCA -> vis-MSA -> extra-MCA -> MLP.
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,

    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        # Affine norms (norm2 / norm5) feed the unmodulated cross branches.
        self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm4 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm5 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.self_attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_sequence=shift_sequence,
            shift_window=shift_window,
            serialize_mode=serialize_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.cross_attn = SparseMultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        # Second self-attention, modulated by the visibility conditioning.
        self.self_attn_vis_ratio = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_sequence=shift_sequence,
            shift_window=shift_window,
            serialize_mode=serialize_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        # Second cross-attention over the extra context sequence.
        self.cross_attn_extra = SparseMultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )
        if not share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )
            # NOTE(review): placed inside the share_mod guard to mirror
            # adaLN_modulation and the share_mod branch of _forward (the diff
            # rendering lost the original indentation) — confirm against the
            # released checkpoints' state_dict.
            self.adaLN_modulation_vis = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 3 * channels, bias=True)
            )

    def _forward(self, x: SparseTensor, mod: torch.Tensor, vis_mod: torch.Tensor, context: torch.Tensor, context_extra: torch.Tensor) -> SparseTensor:
        # Split the main modulation (6 chunks) and visibility modulation (3).
        if self.share_mod:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
            vis_shift_msa, vis_scale_msa, vis_gate_msa = vis_mod.chunk(3, dim=1)
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
            vis_shift_msa, vis_scale_msa, vis_gate_msa = self.adaLN_modulation_vis(vis_mod).chunk(3, dim=1)

        # Modulated self-attention branch with gated residual.
        h = x.replace(self.norm1(x.feats))
        h = h * (1 + scale_msa) + shift_msa
        h = self.self_attn(h)
        h = h * gate_msa
        x = x + h
        # Plain pre-norm cross-attention against the main context.
        h = x.replace(self.norm2(x.feats))
        h = self.cross_attn(h, context)
        x = x + h

        ####### self attn to integrate vis ratio
        h = x.replace(self.norm4(x.feats))
        h = h * (1 + vis_scale_msa) + vis_shift_msa
        h = self.self_attn_vis_ratio(h)
        h = h * vis_gate_msa
        x = x + h
        # cross attn for integrate extra info
        h = x.replace(self.norm5(x.feats))
        h = self.cross_attn_extra(h, context_extra)
        x = x + h
        #######

        # Modulated MLP branch with gated residual.
        h = x.replace(self.norm3(x.feats))
        h = h * (1 + scale_mlp) + shift_mlp
        h = self.mlp(h)
        h = h * gate_mlp
        x = x + h
        return x

    def forward(self, x: SparseTensor, mod: torch.Tensor, vis_mod: torch.Tensor, context: torch.Tensor, context_extra: torch.Tensor) -> SparseTensor:
        # Optionally trade compute for memory via activation checkpointing.
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, vis_mod, context, context_extra, use_reentrant=False)
        else:
            return self._forward(x, mod, vis_mod, context, context_extra)
threeDFixer/modules/spatial.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ import torch
7
+
8
+
9
def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
    """
    3D pixel shuffle: move a factor of `scale_factor**3` from the channel
    dimension onto the three spatial dimensions,
    (B, C, H, W, D) -> (B, C // s^3, H*s, W*s, D*s).
    """
    batch, channels, height, width, depth = x.shape
    s = scale_factor
    out_channels = channels // s ** 3
    # Split channels into the output channels plus one shuffle axis per
    # spatial dimension, then interleave each spatial axis with its shuffle axis.
    x = x.reshape(batch, out_channels, s, s, s, height, width, depth)
    x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)
    return x.reshape(batch, out_channels, height * s, width * s, depth * s)
19
+
20
+
21
def patchify(x: torch.Tensor, patch_size: int):
    """
    Patchify a tensor: fold `patch_size`-sized blocks of every spatial axis
    into the channel dimension.

    Args:
        x (torch.Tensor): (N, C, *spatial) tensor
        patch_size (int): Patch size

    Returns:
        (N, C * patch_size**DIM, *spatial // patch_size) tensor.
    """
    DIM = x.dim() - 2
    for d in range(2, DIM + 2):
        assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}"

    # Split every spatial axis into (size // p, p) pairs.
    split_shape = []
    for d in range(2, DIM + 2):
        split_shape.extend([x.shape[d] // patch_size, patch_size])
    x = x.reshape(*x.shape[:2], *split_shape)
    # Bring the per-patch axes next to the channels; coarse axes stay last.
    patch_axes = [2 * i + 3 for i in range(DIM)]
    coarse_axes = [2 * i + 2 for i in range(DIM)]
    x = x.permute(0, 1, *(patch_axes + coarse_axes))
    return x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:]))
37
+
38
+
39
def unpatchify(x: torch.Tensor, patch_size: int):
    """
    Unpatchify a tensor: inverse of `patchify`, unfolding
    `patch_size**DIM` channel groups back onto the spatial axes.

    Args:
        x (torch.Tensor): (N, C, *spatial) tensor
        patch_size (int): Patch size
    """
    DIM = x.dim() - 2
    assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}"

    # Split channels into the base channels plus one patch axis per spatial dim.
    x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:]))
    # Interleave (coarse, patch) axis pairs so each spatial dim reassembles.
    interleaved = []
    for i in range(DIM):
        interleaved.extend([2 + DIM + i, 2 + i])
    x = x.permute(0, 1, *interleaved)
    out_sizes = [x.shape[2 + 2 * i] * patch_size for i in range(DIM)]
    return x.reshape(x.shape[0], x.shape[1], *out_sizes)
threeDFixer/modules/transformer/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .blocks import *
2
+ from .modulated import *
threeDFixer/modules/transformer/blocks.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ from typing import *
7
+ import torch
8
+ import torch.nn as nn
9
+ from ..attention import MultiHeadAttention
10
+ from ..norm import LayerNorm32
11
+
12
+
13
class AbsolutePositionEmbedder(nn.Module):
    """
    Embeds spatial positions into vector representations using fixed
    sinusoidal frequencies (one sin/cos bank per input coordinate).
    """
    def __init__(self, channels: int, in_channels: int = 3):
        super().__init__()
        self.channels = channels
        self.in_channels = in_channels
        # Number of frequencies per coordinate (each yields a sin and a cos).
        self.freq_dim = channels // in_channels // 2
        bands = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
        self.freqs = 1.0 / (10000 ** bands)

    def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor:
        """
        Create sinusoidal position embeddings.

        Args:
            x: a 1-D Tensor of N indices

        Returns:
            an (N, D) Tensor of positional embeddings.
        """
        # Lazily move the frequency table onto the input's device.
        self.freqs = self.freqs.to(x.device)
        phases = torch.outer(x, self.freqs)
        return torch.cat([torch.sin(phases), torch.cos(phases)], dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): (N, D) tensor of spatial positions
        """
        N, D = x.shape
        assert D == self.in_channels, "Input dimension must match number of input channels"
        embed = self._sin_cos_embedding(x.reshape(-1)).reshape(N, -1)
        # Zero-pad when channels is not an exact multiple of
        # 2 * in_channels * freq_dim.
        if embed.shape[1] < self.channels:
            padding = torch.zeros(N, self.channels - embed.shape[1], device=embed.device)
            embed = torch.cat([embed, padding], dim=-1)
        return embed
+
53
+
54
+ class FeedForwardNet(nn.Module):
55
+ def __init__(self, channels: int, mlp_ratio: float = 4.0):
56
+ super().__init__()
57
+ self.mlp = nn.Sequential(
58
+ nn.Linear(channels, int(channels * mlp_ratio)),
59
+ nn.GELU(approximate="tanh"),
60
+ nn.Linear(int(channels * mlp_ratio), channels),
61
+ )
62
+
63
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
64
+ return self.mlp(x)
65
+
66
+
67
class TransformerBlock(nn.Module):
    """
    Transformer block (MSA + FFN).

    Pre-norm residual layout over dense tensors:
    x -> LN -> self-attn -> +x -> LN -> MLP -> +x.
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[int] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qkv_bias: bool = True,
        ln_affine: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.attn = MultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.mlp = FeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm self-attention with residual connection.
        h = self.norm1(x)
        h = self.attn(h)
        x = x + h
        # Pre-norm feed-forward with residual connection.
        h = self.norm2(x)
        h = self.mlp(h)
        x = x + h
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Optionally trade compute for memory via activation checkpointing.
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
        else:
            return self._forward(x)
+ return self._forward(x)
118
+
119
+
120
class TransformerCrossBlock(nn.Module):
    """
    Transformer cross-attention block (MSA + MCA + FFN).

    Pre-norm residual layout: self-attention over the input sequence, then
    cross-attention against a `context` sequence, then an MLP — each with
    its own LayerNorm and residual connection.
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        ln_affine: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.self_attn = MultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        # Cross-attention always attends over the full context sequence.
        self.cross_attn = MultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.mlp = FeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )

    def _forward(self, x: torch.Tensor, context: torch.Tensor):
        # Pre-norm self-attention with residual connection.
        h = self.norm1(x)
        h = self.self_attn(h)
        x = x + h
        # Pre-norm cross-attention against the conditioning context.
        h = self.norm2(x)
        h = self.cross_attn(h, context)
        x = x + h
        # Pre-norm feed-forward with residual connection.
        h = self.norm3(x)
        h = self.mlp(h)
        x = x + h
        return x

    def forward(self, x: torch.Tensor, context: torch.Tensor):
        # Optionally trade compute for memory via activation checkpointing.
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
        else:
            return self._forward(x, context)
187
+
threeDFixer/modules/transformer/modulated.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file is modified from TRELLIS:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+ # Modifications Copyright (c) 2026 Ze-Xin Yin, Robot labs of Horizon Robotics, and D-Robotics.
6
+
7
+ from typing import *
8
+ import torch
9
+ import torch.nn as nn
10
+ from ..attention import MultiHeadAttention
11
+ from ..norm import LayerNorm32
12
+ from .blocks import FeedForwardNet
13
+
14
+
15
class ModulatedTransformerBlock(nn.Module):
    """
    Transformer block (MSA + FFN) with adaptive layer norm conditioning.

    When ``share_mod`` is True the six adaLN terms are expected to be supplied
    pre-computed in ``mod``; otherwise a per-block SiLU+Linear head produces
    them from ``mod``.
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        # Affine-free norms: scale/shift come from the adaLN modulation instead.
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.attn = MultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.mlp = FeedForwardNet(channels, mlp_ratio=mlp_ratio)
        if not share_mod:
            # Per-block head emitting (shift, scale, gate) for both MSA and MLP.
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )

    def _forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
        # Resolve the six modulation tensors; mod is pre-chunked when shared.
        mod_params = mod if self.share_mod else self.adaLN_modulation(mod)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod_params.chunk(6, dim=1)
        # Attention branch: modulated pre-norm, gated residual.
        h = self.norm1(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        x = x + self.attn(h) * gate_msa.unsqueeze(1)
        # Feed-forward branch with the same modulation pattern.
        h = self.norm2(x) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        x = x + self.mlp(h) * gate_mlp.unsqueeze(1)
        return x

    def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
        if not self.use_checkpoint:
            return self._forward(x, mod)
        # Recompute activations during backward to save memory.
        return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
80
+
81
+
82
class ModulatedTransformerCrossBlock(nn.Module):
    """
    Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.

    Self-attention and the MLP are adaLN-modulated; the cross-attention
    sub-layer is left un-modulated and uses an affine LayerNorm instead.
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        # norm1/norm3 are affine-free (adaLN supplies scale/shift); norm2 feeds
        # the un-modulated cross-attention and keeps its own affine parameters.
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.self_attn = MultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        # Cross-attention always runs in full attention mode.
        self.cross_attn = MultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.mlp = FeedForwardNet(channels, mlp_ratio=mlp_ratio)
        if not share_mod:
            # Per-block head emitting (shift, scale, gate) for both MSA and MLP.
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )

    def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor):
        # Resolve the six modulation tensors; mod is pre-chunked when shared.
        mod_params = mod if self.share_mod else self.adaLN_modulation(mod)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod_params.chunk(6, dim=1)
        # Modulated self-attention with gated residual.
        h = self.norm1(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        x = x + self.self_attn(h) * gate_msa.unsqueeze(1)
        # Plain (un-modulated) cross-attention residual.
        x = x + self.cross_attn(self.norm2(x), context)
        # Modulated feed-forward with gated residual.
        h = self.norm3(x) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        x = x + self.mlp(h) * gate_mlp.unsqueeze(1)
        return x

    def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor):
        if not self.use_checkpoint:
            return self._forward(x, mod, context)
        # Recompute activations during backward to save memory.
        return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
163
+
164
+
165
class SceneModulatedTransformerCrossBlock(nn.Module):
    """
    Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.

    Compared to ModulatedTransformerCrossBlock, two extra sub-layers are
    inserted between the main cross-attention and the MLP:
      * a second self-attention (``self_attn_dpt_ratio``) driven by its own
        3-term adaLN modulation ``dpt_mod``, and
      * a second cross-attention (``cross_attn_extra``) over ``context_extra``.
    ("dpt" presumably refers to a depth-ratio conditioning signal — confirm
    with the calling model.)

    When ``share_mod`` is True, ``mod`` and ``dpt_mod`` are expected to already
    contain the chunked modulation terms (6*channels and 3*channels wide,
    respectively); otherwise per-block SiLU+Linear heads produce them.
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        # norm1/norm3/norm4 are affine-free (adaLN supplies scale/shift);
        # norm2/norm5 feed the two un-modulated cross-attentions and keep
        # their own affine parameters.
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm4 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm5 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.self_attn = MultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        # Main cross-attention; always full attention mode.
        self.cross_attn = MultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        # Second self-attention, modulated by dpt_mod in _forward.
        self.self_attn_dpt_ratio = MultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        # Second cross-attention over context_extra; always full attention mode.
        self.cross_attn_extra = MultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.mlp = FeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )
        if not share_mod:
            # Main adaLN head: (shift, scale, gate) for both MSA and MLP.
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )
            # Separate adaLN head for the dpt self-attention sub-layer.
            self.adaLN_modulation_dpt = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 3 * channels, bias=True)
            )

    def _forward(self, x: torch.Tensor, mod: torch.Tensor, dpt_mod: torch.Tensor, context: torch.Tensor, context_extra: torch.Tensor):
        # Resolve the modulation terms; mod/dpt_mod are pre-chunked when shared.
        if self.share_mod:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
            dpt_shift_msa, dpt_scale_msa, dpt_gate_msa = dpt_mod.chunk(3, dim=1)
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
            dpt_shift_msa, dpt_scale_msa, dpt_gate_msa = self.adaLN_modulation_dpt(dpt_mod).chunk(3, dim=1)


        # Modulated self-attention with gated residual.
        h = self.norm1(x)
        h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        h = self.self_attn(h)
        h = h * gate_msa.unsqueeze(1)
        x = x + h
        # Un-modulated main cross-attention residual.
        h = self.norm2(x)
        h = self.cross_attn(h, context)
        x = x + h

        ####### self attn to integrate dpt ratio (modulated by dpt_mod)
        h = self.norm4(x)
        h = h * (1 + dpt_scale_msa.unsqueeze(1)) + dpt_shift_msa.unsqueeze(1)
        h = self.self_attn_dpt_ratio(h)
        h = h * dpt_gate_msa.unsqueeze(1)
        x = x + h
        # cross attn to integrate the extra context (un-modulated)
        h = self.norm5(x)
        h = self.cross_attn_extra(h, context_extra)
        x = x + h
        #######

        # Modulated feed-forward with gated residual.
        h = self.norm3(x)
        h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        h = self.mlp(h)
        h = h * gate_mlp.unsqueeze(1)
        x = x + h
        return x

    def forward(self, x: torch.Tensor, mod: torch.Tensor, dpt_mod: torch.Tensor, context: torch.Tensor, context_extra: torch.Tensor):
        # Optionally recompute activations during backward to save memory.
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, dpt_mod, context, context_extra, use_reentrant=False)
        else:
            return self._forward(x, mod, dpt_mod, context, context_extra)
289
+
threeDFixer/modules/utils.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the TRELLIS project:
2
+ # https://github.com/microsoft/TRELLIS
3
+ # Original license: MIT
4
+ # Copyright (c) the TRELLIS authors
5
+
6
+ import torch.nn as nn
7
+ from ..modules import sparse as sp
8
+
9
# Module types whose parameters are converted by convert_module_to_f16/_f32:
# all dense and sparse conv/linear layers. Normalization layers are not listed,
# so the converters leave them untouched.
FP16_MODULES = (
    nn.Conv1d,
    nn.Conv2d,
    nn.Conv3d,
    nn.ConvTranspose1d,
    nn.ConvTranspose2d,
    nn.ConvTranspose3d,
    nn.Linear,
    sp.SparseConv3d,
    sp.SparseInverseConv3d,
    sp.SparseLinear,
)
21
+
22
def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.

    No-op for module types outside FP16_MODULES; conversion is in place.
    """
    if not isinstance(l, FP16_MODULES):
        return
    for param in l.parameters():
        param.data = param.data.half()
29
+
30
+
31
def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().

    No-op for module types outside FP16_MODULES; conversion is in place.
    """
    if not isinstance(l, FP16_MODULES):
        return
    for param in l.parameters():
        param.data = param.data.float()
38
+
39
+
40
def zero_module(module):
    """
    Zero out the parameters of a module, in place, and return the module.
    """
    for weight in module.parameters():
        # detach() so the in-place write is not tracked by autograd.
        weight.detach().zero_()
    return module
47
+
48
+
49
def scale_module(module, scale):
    """
    Scale the parameters of a module by `scale`, in place, and return the module.
    """
    for weight in module.parameters():
        # detach() so the in-place write is not tracked by autograd.
        weight.detach().mul_(scale)
    return module
56
+
57
+
58
def modulate(x, shift, scale):
    """Apply adaLN modulation: broadcast per-sample shift/scale over the token axis."""
    scale = scale.unsqueeze(1)
    shift = shift.unsqueeze(1)
    return x * (1 + scale) + shift
threeDFixer/moge/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copied from the MoGe project:
2
+ # https://github.com/microsoft/MoGe
3
+ # Original license: MIT
4
+ # Copyright (c) the MoGe authors
5
+