Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +19 -0
- Koala-36M-v1/.gitattributes +68 -0
- Koala-36M-v1/Koala_36M_1.csv +3 -0
- Koala-36M-v1/Koala_36M_10.csv +3 -0
- Koala-36M-v1/Koala_36M_2.csv +3 -0
- Koala-36M-v1/Koala_36M_3.csv +3 -0
- Koala-36M-v1/Koala_36M_4.csv +3 -0
- Koala-36M-v1/Koala_36M_5.csv +3 -0
- Koala-36M-v1/Koala_36M_6.csv +3 -0
- Koala-36M-v1/Koala_36M_7.csv +3 -0
- Koala-36M-v1/Koala_36M_8.csv +3 -0
- Koala-36M-v1/Koala_36M_9.csv +3 -0
- URSA-1.7B/.gitattributes +37 -0
- URSA-1.7B/.gitignore +55 -0
- URSA-1.7B/LICENSE +176 -0
- URSA-1.7B/README.md +117 -0
- URSA-1.7B/model_index.json +19 -0
- URSA-1.7B/scheduler/__scheduler__.py +17 -0
- URSA-1.7B/scheduler/scheduler_config.json +7 -0
- URSA-1.7B/tokenizer/tokenizer_config.json +239 -0
- URSA-1.7B/transformer/__transformer__.py +17 -0
- URSA-1.7B/transformer/config.json +13 -0
- URSA-1.7B/transformer/diffusion_pytorch_model.safetensors +3 -0
- URSA-1.7B/vae/__vae__.py +17 -0
- URSA-1.7B/vae/config.json +22 -0
- URSA/.flake8 +21 -0
- URSA/.gitignore +55 -0
- URSA/=4.57.1 +70 -0
- URSA/LICENSE +176 -0
- URSA/README.md +191 -0
- URSA/accelerate_configs/deepspeed_zero2.yaml +12 -0
- URSA/assets/sample_image.jpg +0 -0
- URSA/configs/distill_dimo.yaml +158 -0
- URSA/configs/onestep_dimo.yaml +111 -0
- URSA/configs/ursa_0.6b_fsq320.yaml +62 -0
- URSA/configs/ursa_0.6b_ibq1024.yaml +62 -0
- URSA/configs/ursa_1.7b_fsq320.yaml +62 -0
- URSA/configs/ursa_1.7b_ibq1024.yaml +62 -0
- URSA/diffnext/__init__.py +16 -0
- URSA/diffnext/__pycache__/__init__.cpython-312.pyc +0 -0
- URSA/diffnext/__pycache__/image_processor.cpython-312.pyc +0 -0
- URSA/diffnext/data/__init__.py +16 -0
- URSA/diffnext/data/flex_loaders.py +172 -0
- URSA/diffnext/data/flex_pipelines.py +63 -0
- URSA/diffnext/data/flex_transforms.py +66 -0
- URSA/diffnext/engine/__init__.py +16 -0
- URSA/diffnext/engine/__pycache__/__init__.cpython-312.pyc +0 -0
- URSA/diffnext/engine/__pycache__/engine_utils.cpython-312.pyc +0 -0
- URSA/diffnext/engine/__pycache__/lr_scheduler.cpython-312.pyc +0 -0
- URSA/diffnext/engine/engine_utils.py +109 -0
.gitattributes
CHANGED
|
@@ -132,3 +132,22 @@ URSA/outputs/eval_distill_v3_100steps_49frames/03_s2_a_hummingbird_hovers_in_fro
|
|
| 132 |
URSA/outputs/eval_distill_v3_100steps_49frames/03_s2_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_cfg.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 133 |
URSA/outputs/eval_distill_v3_100steps_49frames/00_s0_a_lone_grizzly_bear_walks_through_a_mist_teacher_50step_cfg.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 134 |
URSA/outputs/eval_distill_v3_100steps_49frames/00_s2_a_lone_grizzly_bear_walks_through_a_mist_teacher_50step_cfg.mp4 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
URSA/outputs/eval_distill_v3_100steps_49frames/03_s2_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_cfg.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 133 |
URSA/outputs/eval_distill_v3_100steps_49frames/00_s0_a_lone_grizzly_bear_walks_through_a_mist_teacher_50step_cfg.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 134 |
URSA/outputs/eval_distill_v3_100steps_49frames/00_s2_a_lone_grizzly_bear_walks_through_a_mist_teacher_50step_cfg.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 135 |
+
URSA/outputs/eval_distill_v3_100steps_49frames/01_s3_beautiful_fireworks_in_the_sky_with_red__teacher_50step_cfg.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 136 |
+
URSA/outputs/eval_distill_v3_100steps_49frames/03_s3_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_cfg.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 137 |
+
URSA/outputs/eval_distill_v3_100steps_49frames/01_s2_beautiful_fireworks_in_the_sky_with_red__teacher_50step_cfg.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 138 |
+
URSA/outputs/eval_distill_49frames/00_s1_a_lone_grizzly_bear_walks_through_a_mist_teacher_50step_cfg.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 139 |
+
URSA/outputs/eval_distill_49frames/00_s2_a_lone_grizzly_bear_walks_through_a_mist_student_1step_cfg.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 140 |
+
URSA/outputs/eval_distill_49frames/00_s0_a_lone_grizzly_bear_walks_through_a_mist_student_1step_cfg.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 141 |
+
URSA/outputs/eval_distill_49frames/00_s1_a_lone_grizzly_bear_walks_through_a_mist_student_1step_cfg.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 142 |
+
URSA/outputs/eval_distill_49frames/00_s0_a_lone_grizzly_bear_walks_through_a_mist_teacher_50step_cfg.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 143 |
+
URSA/outputs/eval_distill_v3_200steps_49frames/00_s0_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 144 |
+
Koala-36M-v1/Koala_36M_7.csv filter=lfs diff=lfs merge=lfs -text
|
| 145 |
+
Koala-36M-v1/Koala_36M_10.csv filter=lfs diff=lfs merge=lfs -text
|
| 146 |
+
Koala-36M-v1/Koala_36M_8.csv filter=lfs diff=lfs merge=lfs -text
|
| 147 |
+
Koala-36M-v1/Koala_36M_5.csv filter=lfs diff=lfs merge=lfs -text
|
| 148 |
+
Koala-36M-v1/Koala_36M_3.csv filter=lfs diff=lfs merge=lfs -text
|
| 149 |
+
Koala-36M-v1/Koala_36M_4.csv filter=lfs diff=lfs merge=lfs -text
|
| 150 |
+
Koala-36M-v1/Koala_36M_1.csv filter=lfs diff=lfs merge=lfs -text
|
| 151 |
+
Koala-36M-v1/Koala_36M_2.csv filter=lfs diff=lfs merge=lfs -text
|
| 152 |
+
Koala-36M-v1/Koala_36M_6.csv filter=lfs diff=lfs merge=lfs -text
|
| 153 |
+
Koala-36M-v1/Koala_36M_9.csv filter=lfs diff=lfs merge=lfs -text
|
Koala-36M-v1/.gitattributes
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.lz4 filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
# Audio files - uncompressed
|
| 38 |
+
*.pcm filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
*.sam filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
*.raw filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
# Audio files - compressed
|
| 42 |
+
*.aac filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
*.flac filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
*.ogg filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
*.wav filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
# Image files - uncompressed
|
| 48 |
+
*.bmp filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
*.tiff filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
# Image files - compressed
|
| 53 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
*.webp filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
# Video files - compressed
|
| 57 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
*.webm filter=lfs diff=lfs merge=lfs -text
|
| 59 |
+
Koala_36M_1.csv filter=lfs diff=lfs merge=lfs -text
|
| 60 |
+
Koala_36M_2.csv filter=lfs diff=lfs merge=lfs -text
|
| 61 |
+
Koala_36M_3.csv filter=lfs diff=lfs merge=lfs -text
|
| 62 |
+
Koala_36M_4.csv filter=lfs diff=lfs merge=lfs -text
|
| 63 |
+
Koala_36M_5.csv filter=lfs diff=lfs merge=lfs -text
|
| 64 |
+
Koala_36M_6.csv filter=lfs diff=lfs merge=lfs -text
|
| 65 |
+
Koala_36M_7.csv filter=lfs diff=lfs merge=lfs -text
|
| 66 |
+
Koala_36M_8.csv filter=lfs diff=lfs merge=lfs -text
|
| 67 |
+
Koala_36M_9.csv filter=lfs diff=lfs merge=lfs -text
|
| 68 |
+
Koala_36M_10.csv filter=lfs diff=lfs merge=lfs -text
|
Koala-36M-v1/Koala_36M_1.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5721d746552bcf48ca2c85d383eb3aee8a9d724cb8b498448e283e6c155b65f3
|
| 3 |
+
size 4889903599
|
Koala-36M-v1/Koala_36M_10.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3aa2590eb8302cf43106e7faf7ef36849fedeb6c5d5ca1ee214635f820adf807
|
| 3 |
+
size 4888525462
|
Koala-36M-v1/Koala_36M_2.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0da912f9903bcc06e077fd84e116f0497782743f35b4c1bfe06223e071720f2a
|
| 3 |
+
size 4889857219
|
Koala-36M-v1/Koala_36M_3.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e5b7cf12f398b9379ac4b6e65c4d5e3154be362513db781ede873d2ee485b112
|
| 3 |
+
size 4889283599
|
Koala-36M-v1/Koala_36M_4.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b75062281023cf982e885cabf79963482fd23e683cfb8e1c68d7ad6c1e363637
|
| 3 |
+
size 4889718227
|
Koala-36M-v1/Koala_36M_5.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bd185667807e084a760bf3708d5da284a8603a48581a15b9810bded0f7fb4f7c
|
| 3 |
+
size 4889216599
|
Koala-36M-v1/Koala_36M_6.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:69af3c329b77c8fe5fe2a3fd7b52ccce1f88f2649f4cc13e76ab27ecca5a5efa
|
| 3 |
+
size 4889541704
|
Koala-36M-v1/Koala_36M_7.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3f654daa45977d2c12db1c22fbca9ef5bb729ba37240b83d0ed0bd1ca8008175
|
| 3 |
+
size 4889367231
|
Koala-36M-v1/Koala_36M_8.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2d1b984b48a839619b82d1db10c27a518d89c66815be010feaee76816eb59ccd
|
| 3 |
+
size 4888856454
|
Koala-36M-v1/Koala_36M_9.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f62b8a588768377d49d92b9a4ea5eb9745537399b1fd1ccc556721edb96bc4ca
|
| 3 |
+
size 4889171948
|
URSA-1.7B/.gitattributes
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
. filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
URSA-1.7B/.gitignore
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Compiled Object files
|
| 2 |
+
*.slo
|
| 3 |
+
*.lo
|
| 4 |
+
*.o
|
| 5 |
+
*.cuo
|
| 6 |
+
|
| 7 |
+
# Compiled Dynamic libraries
|
| 8 |
+
*.so
|
| 9 |
+
*.dll
|
| 10 |
+
*.dylib
|
| 11 |
+
|
| 12 |
+
# Compiled Static libraries
|
| 13 |
+
*.lai
|
| 14 |
+
*.la
|
| 15 |
+
*.a
|
| 16 |
+
*.lib
|
| 17 |
+
|
| 18 |
+
# Compiled python
|
| 19 |
+
*.pyc
|
| 20 |
+
__pycache__
|
| 21 |
+
|
| 22 |
+
# Compiled MATLAB
|
| 23 |
+
*.mex*
|
| 24 |
+
|
| 25 |
+
# IPython notebook checkpoints
|
| 26 |
+
.ipynb_checkpoints
|
| 27 |
+
|
| 28 |
+
# Editor temporaries
|
| 29 |
+
*.swp
|
| 30 |
+
*~
|
| 31 |
+
|
| 32 |
+
# Sublime Text settings
|
| 33 |
+
*.sublime-workspace
|
| 34 |
+
*.sublime-project
|
| 35 |
+
|
| 36 |
+
# Eclipse Project settings
|
| 37 |
+
*.*project
|
| 38 |
+
.settings
|
| 39 |
+
|
| 40 |
+
# QtCreator files
|
| 41 |
+
*.user
|
| 42 |
+
|
| 43 |
+
# VSCode files
|
| 44 |
+
.vscode
|
| 45 |
+
|
| 46 |
+
# IDEA files
|
| 47 |
+
.idea
|
| 48 |
+
|
| 49 |
+
# OSX dir files
|
| 50 |
+
.DS_Store
|
| 51 |
+
|
| 52 |
+
# Android files
|
| 53 |
+
.gradle
|
| 54 |
+
*.iml
|
| 55 |
+
local.properties
|
URSA-1.7B/LICENSE
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
URSA-1.7B/README.md
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: diffusers
|
| 3 |
+
license: apache-2.0
|
| 4 |
+
license_link: https://huggingface.co/BAAI/URSA-1.7B-FSQ320/blob/main/LICENSE
|
| 5 |
+
pipeline_tag: text-to-video
|
| 6 |
+
base_model:
|
| 7 |
+
- Qwen/Qwen3-1.7B
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# URSA-1.7B-FSQ320 Model Card
|
| 11 |
+
|
| 12 |
+
## Model Details
|
| 13 |
+
- **Developed by:** BAAI
|
| 14 |
+
- **Model type:** Text-to-Video Generation Model
|
| 15 |
+
- **Model size:** 1.7B
|
| 16 |
+
- **Model precision:** torch.float16 (FP16)
|
| 17 |
+
- **Model resolution:** 512x320
|
| 18 |
+
- **Model paper:** [Uniform Discrete Diffusion with Metric Path for Video Generation](https://arxiv.org/abs/2510.24717)
|
| 19 |
+
- **Model family:** [BAAI-Vision-URSA](https://github.com/baaivision/URSA)
|
| 20 |
+
- **Model Tokenizer:** [Cosmos-Tokenize1-DV4x8x8-360p](https://huggingface.co/nvidia/Cosmos-Tokenize1-DV4x8x8-360p)
|
| 21 |
+
- **Model Description:** This is a model that can be used to generate and modify videos based on text prompts.
|
| 22 |
+
|
| 23 |
+
## Examples
|
| 24 |
+
|
| 25 |
+
Using the [🤗's Diffusers library](https://github.com/huggingface/diffusers) to run URSA in a simple and efficient manner.
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
pip install diffusers transformers accelerate imageio[ffmpeg]
|
| 29 |
+
pip install git+ssh://git@github.com/baaivision/URSA.git
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
Running the pipeline:
|
| 33 |
+
|
| 34 |
+
```python
|
| 35 |
+
import os, torch, numpy
|
| 36 |
+
from diffnext.pipelines import URSAPipeline
|
| 37 |
+
from diffnext.utils import export_to_video
|
| 38 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 39 |
+
|
| 40 |
+
model_id, height, width = "BAAI/URSA-1.7B-FSQ320", 320, 512
|
| 41 |
+
model_args = {"torch_dtype": torch.float16, "trust_remote_code": True}
|
| 42 |
+
pipe = URSAPipeline.from_pretrained(model_id, **model_args)
|
| 43 |
+
pipe = pipe.to(torch.device("cuda"))
|
| 44 |
+
|
| 45 |
+
text_prompt = "a lone grizzly bear walks through a misty forest at dawn, sunlight catching its fur."
|
| 46 |
+
negative_prompt = "worst quality, low quality, inconsistent motion, static, still, blurry, jittery, distorted, ugly"
|
| 47 |
+
|
| 48 |
+
# Text-to-Image
|
| 49 |
+
prompt = text_prompt
|
| 50 |
+
num_frames, num_inference_steps = 1, 25
|
| 51 |
+
image = pipe(**locals()).frames[0]
|
| 52 |
+
image.save("ursa.jpg")
|
| 53 |
+
|
| 54 |
+
# Image-to-Video
|
| 55 |
+
prompt = f"motion=9.0, {text_prompt}"
|
| 56 |
+
num_frames, num_inference_steps = 49, 50
|
| 57 |
+
video = pipe(**locals()).frames[0]
|
| 58 |
+
export_to_video(video, "ursa_1+48f.mp4", fps=12)
|
| 59 |
+
|
| 60 |
+
# Text-to-Video
|
| 61 |
+
image, video = None, None
|
| 62 |
+
prompt = f"motion=9.0, {text_prompt}"
|
| 63 |
+
num_frames, num_inference_steps = 49, 50
|
| 64 |
+
video = pipe(**locals()).frames[0]
|
| 65 |
+
export_to_video(video, "ursa_49f.mp4", fps=12)
|
| 66 |
+
|
| 67 |
+
# Video-to-Video
|
| 68 |
+
prompt = f"motion=5.0, {text_prompt}"
|
| 69 |
+
num_frames, num_inference_steps = 49, 50
|
| 70 |
+
num_cond_frames, cond_noise_scale = 13, 0.1
|
| 71 |
+
for i in range(12):
|
| 72 |
+
video, start_video = video[-num_cond_frames:], video
|
| 73 |
+
video = pipe(**locals()).frames[0]
|
| 74 |
+
video = numpy.concatenate([start_video, video[num_cond_frames:]])
|
| 75 |
+
export_to_video(video, "ursa_{}f.mp4".format(video.shape[0]), fps=12)
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
# Uses
|
| 79 |
+
|
| 80 |
+
## Direct Use
|
| 81 |
+
The model is intended for research purposes only. Possible research areas and tasks include
|
| 82 |
+
|
| 83 |
+
- Research on generative models.
|
| 84 |
+
- Applications in educational or creative tools.
|
| 85 |
+
- Generation of artworks and use in design and other artistic processes.
|
| 86 |
+
- Probing and understanding the limitations and biases of generative models.
|
| 87 |
+
- Safe deployment of models which have the potential to generate harmful content.
|
| 88 |
+
|
| 89 |
+
Excluded uses are described below.
|
| 90 |
+
|
| 91 |
+
#### Out-of-Scope Use
|
| 92 |
+
The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model.
|
| 93 |
+
|
| 94 |
+
#### Misuse and Malicious Use
|
| 95 |
+
Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to:
|
| 96 |
+
|
| 97 |
+
- Mis- and disinformation.
|
| 98 |
+
- Representations of egregious violence and gore.
|
| 99 |
+
- Impersonating individuals without their consent.
|
| 100 |
+
- Sexual content without consent of the people who might see it.
|
| 101 |
+
- Sharing of copyrighted or licensed material in violation of its terms of use.
|
| 102 |
+
- Intentionally promoting or propagating discriminatory content or harmful stereotypes.
|
| 103 |
+
- Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
|
| 104 |
+
- Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
|
| 105 |
+
|
| 106 |
+
## Limitations and Bias
|
| 107 |
+
|
| 108 |
+
### Limitations
|
| 109 |
+
|
| 110 |
+
- The autoencoding part of the model is lossy.
|
| 111 |
+
- The model cannot render complex legible text.
|
| 112 |
+
- The model does not achieve perfect photorealism.
|
| 113 |
+
- Fingers, hands, etc. in general may not be generated properly.
|
| 114 |
+
- The model was trained on a subset of the web datasets [LAION-5B](https://laion.ai/blog/laion-5b/) and [COYO-700M](https://github.com/kakaobrain/coyo-dataset), which contains adult, violent and sexual content.
|
| 115 |
+
|
| 116 |
+
### Bias
|
| 117 |
+
While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases.
|
URSA-1.7B/model_index.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "URSAPipeline",
|
| 3 |
+
"tokenizer": [
|
| 4 |
+
"transformers",
|
| 5 |
+
"Qwen2TokenizerFast"
|
| 6 |
+
],
|
| 7 |
+
"scheduler": [
|
| 8 |
+
"__scheduler__",
|
| 9 |
+
"KineticOptimalScheduler"
|
| 10 |
+
],
|
| 11 |
+
"transformer": [
|
| 12 |
+
"__transformer__",
|
| 13 |
+
"URSATransformer3DModel"
|
| 14 |
+
],
|
| 15 |
+
"vae": [
|
| 16 |
+
"__vae__",
|
| 17 |
+
"AutoencoderVQCosmos3D"
|
| 18 |
+
]
|
| 19 |
+
}
|
URSA-1.7B/scheduler/__scheduler__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
##############################################################################
|
| 15 |
+
"""Scheduler."""
|
| 16 |
+
|
| 17 |
+
from diffnext.schedulers.scheduling_dfm import KineticOptimalScheduler # noqa
|
URSA-1.7B/scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "KineticOptimalScheduler",
|
| 3 |
+
"alpha": 1.0,
|
| 4 |
+
"c": 5,
|
| 5 |
+
"eps": 1e-5,
|
| 6 |
+
"shift": 4.0
|
| 7 |
+
}
|
URSA-1.7B/tokenizer/tokenizer_config.json
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": false,
|
| 3 |
+
"add_prefix_space": false,
|
| 4 |
+
"added_tokens_decoder": {
|
| 5 |
+
"151643": {
|
| 6 |
+
"content": "<|endoftext|>",
|
| 7 |
+
"lstrip": false,
|
| 8 |
+
"normalized": false,
|
| 9 |
+
"rstrip": false,
|
| 10 |
+
"single_word": false,
|
| 11 |
+
"special": true
|
| 12 |
+
},
|
| 13 |
+
"151644": {
|
| 14 |
+
"content": "<|im_start|>",
|
| 15 |
+
"lstrip": false,
|
| 16 |
+
"normalized": false,
|
| 17 |
+
"rstrip": false,
|
| 18 |
+
"single_word": false,
|
| 19 |
+
"special": true
|
| 20 |
+
},
|
| 21 |
+
"151645": {
|
| 22 |
+
"content": "<|im_end|>",
|
| 23 |
+
"lstrip": false,
|
| 24 |
+
"normalized": false,
|
| 25 |
+
"rstrip": false,
|
| 26 |
+
"single_word": false,
|
| 27 |
+
"special": true
|
| 28 |
+
},
|
| 29 |
+
"151646": {
|
| 30 |
+
"content": "<|object_ref_start|>",
|
| 31 |
+
"lstrip": false,
|
| 32 |
+
"normalized": false,
|
| 33 |
+
"rstrip": false,
|
| 34 |
+
"single_word": false,
|
| 35 |
+
"special": true
|
| 36 |
+
},
|
| 37 |
+
"151647": {
|
| 38 |
+
"content": "<|object_ref_end|>",
|
| 39 |
+
"lstrip": false,
|
| 40 |
+
"normalized": false,
|
| 41 |
+
"rstrip": false,
|
| 42 |
+
"single_word": false,
|
| 43 |
+
"special": true
|
| 44 |
+
},
|
| 45 |
+
"151648": {
|
| 46 |
+
"content": "<|box_start|>",
|
| 47 |
+
"lstrip": false,
|
| 48 |
+
"normalized": false,
|
| 49 |
+
"rstrip": false,
|
| 50 |
+
"single_word": false,
|
| 51 |
+
"special": true
|
| 52 |
+
},
|
| 53 |
+
"151649": {
|
| 54 |
+
"content": "<|box_end|>",
|
| 55 |
+
"lstrip": false,
|
| 56 |
+
"normalized": false,
|
| 57 |
+
"rstrip": false,
|
| 58 |
+
"single_word": false,
|
| 59 |
+
"special": true
|
| 60 |
+
},
|
| 61 |
+
"151650": {
|
| 62 |
+
"content": "<|quad_start|>",
|
| 63 |
+
"lstrip": false,
|
| 64 |
+
"normalized": false,
|
| 65 |
+
"rstrip": false,
|
| 66 |
+
"single_word": false,
|
| 67 |
+
"special": true
|
| 68 |
+
},
|
| 69 |
+
"151651": {
|
| 70 |
+
"content": "<|quad_end|>",
|
| 71 |
+
"lstrip": false,
|
| 72 |
+
"normalized": false,
|
| 73 |
+
"rstrip": false,
|
| 74 |
+
"single_word": false,
|
| 75 |
+
"special": true
|
| 76 |
+
},
|
| 77 |
+
"151652": {
|
| 78 |
+
"content": "<|vision_start|>",
|
| 79 |
+
"lstrip": false,
|
| 80 |
+
"normalized": false,
|
| 81 |
+
"rstrip": false,
|
| 82 |
+
"single_word": false,
|
| 83 |
+
"special": true
|
| 84 |
+
},
|
| 85 |
+
"151653": {
|
| 86 |
+
"content": "<|vision_end|>",
|
| 87 |
+
"lstrip": false,
|
| 88 |
+
"normalized": false,
|
| 89 |
+
"rstrip": false,
|
| 90 |
+
"single_word": false,
|
| 91 |
+
"special": true
|
| 92 |
+
},
|
| 93 |
+
"151654": {
|
| 94 |
+
"content": "<|vision_pad|>",
|
| 95 |
+
"lstrip": false,
|
| 96 |
+
"normalized": false,
|
| 97 |
+
"rstrip": false,
|
| 98 |
+
"single_word": false,
|
| 99 |
+
"special": true
|
| 100 |
+
},
|
| 101 |
+
"151655": {
|
| 102 |
+
"content": "<|image_pad|>",
|
| 103 |
+
"lstrip": false,
|
| 104 |
+
"normalized": false,
|
| 105 |
+
"rstrip": false,
|
| 106 |
+
"single_word": false,
|
| 107 |
+
"special": true
|
| 108 |
+
},
|
| 109 |
+
"151656": {
|
| 110 |
+
"content": "<|video_pad|>",
|
| 111 |
+
"lstrip": false,
|
| 112 |
+
"normalized": false,
|
| 113 |
+
"rstrip": false,
|
| 114 |
+
"single_word": false,
|
| 115 |
+
"special": true
|
| 116 |
+
},
|
| 117 |
+
"151657": {
|
| 118 |
+
"content": "<tool_call>",
|
| 119 |
+
"lstrip": false,
|
| 120 |
+
"normalized": false,
|
| 121 |
+
"rstrip": false,
|
| 122 |
+
"single_word": false,
|
| 123 |
+
"special": false
|
| 124 |
+
},
|
| 125 |
+
"151658": {
|
| 126 |
+
"content": "</tool_call>",
|
| 127 |
+
"lstrip": false,
|
| 128 |
+
"normalized": false,
|
| 129 |
+
"rstrip": false,
|
| 130 |
+
"single_word": false,
|
| 131 |
+
"special": false
|
| 132 |
+
},
|
| 133 |
+
"151659": {
|
| 134 |
+
"content": "<|fim_prefix|>",
|
| 135 |
+
"lstrip": false,
|
| 136 |
+
"normalized": false,
|
| 137 |
+
"rstrip": false,
|
| 138 |
+
"single_word": false,
|
| 139 |
+
"special": false
|
| 140 |
+
},
|
| 141 |
+
"151660": {
|
| 142 |
+
"content": "<|fim_middle|>",
|
| 143 |
+
"lstrip": false,
|
| 144 |
+
"normalized": false,
|
| 145 |
+
"rstrip": false,
|
| 146 |
+
"single_word": false,
|
| 147 |
+
"special": false
|
| 148 |
+
},
|
| 149 |
+
"151661": {
|
| 150 |
+
"content": "<|fim_suffix|>",
|
| 151 |
+
"lstrip": false,
|
| 152 |
+
"normalized": false,
|
| 153 |
+
"rstrip": false,
|
| 154 |
+
"single_word": false,
|
| 155 |
+
"special": false
|
| 156 |
+
},
|
| 157 |
+
"151662": {
|
| 158 |
+
"content": "<|fim_pad|>",
|
| 159 |
+
"lstrip": false,
|
| 160 |
+
"normalized": false,
|
| 161 |
+
"rstrip": false,
|
| 162 |
+
"single_word": false,
|
| 163 |
+
"special": false
|
| 164 |
+
},
|
| 165 |
+
"151663": {
|
| 166 |
+
"content": "<|repo_name|>",
|
| 167 |
+
"lstrip": false,
|
| 168 |
+
"normalized": false,
|
| 169 |
+
"rstrip": false,
|
| 170 |
+
"single_word": false,
|
| 171 |
+
"special": false
|
| 172 |
+
},
|
| 173 |
+
"151664": {
|
| 174 |
+
"content": "<|file_sep|>",
|
| 175 |
+
"lstrip": false,
|
| 176 |
+
"normalized": false,
|
| 177 |
+
"rstrip": false,
|
| 178 |
+
"single_word": false,
|
| 179 |
+
"special": false
|
| 180 |
+
},
|
| 181 |
+
"151665": {
|
| 182 |
+
"content": "<tool_response>",
|
| 183 |
+
"lstrip": false,
|
| 184 |
+
"normalized": false,
|
| 185 |
+
"rstrip": false,
|
| 186 |
+
"single_word": false,
|
| 187 |
+
"special": false
|
| 188 |
+
},
|
| 189 |
+
"151666": {
|
| 190 |
+
"content": "</tool_response>",
|
| 191 |
+
"lstrip": false,
|
| 192 |
+
"normalized": false,
|
| 193 |
+
"rstrip": false,
|
| 194 |
+
"single_word": false,
|
| 195 |
+
"special": false
|
| 196 |
+
},
|
| 197 |
+
"151667": {
|
| 198 |
+
"content": "<think>",
|
| 199 |
+
"lstrip": false,
|
| 200 |
+
"normalized": false,
|
| 201 |
+
"rstrip": false,
|
| 202 |
+
"single_word": false,
|
| 203 |
+
"special": false
|
| 204 |
+
},
|
| 205 |
+
"151668": {
|
| 206 |
+
"content": "</think>",
|
| 207 |
+
"lstrip": false,
|
| 208 |
+
"normalized": false,
|
| 209 |
+
"rstrip": false,
|
| 210 |
+
"single_word": false,
|
| 211 |
+
"special": false
|
| 212 |
+
}
|
| 213 |
+
},
|
| 214 |
+
"additional_special_tokens": [
|
| 215 |
+
"<|im_start|>",
|
| 216 |
+
"<|im_end|>",
|
| 217 |
+
"<|object_ref_start|>",
|
| 218 |
+
"<|object_ref_end|>",
|
| 219 |
+
"<|box_start|>",
|
| 220 |
+
"<|box_end|>",
|
| 221 |
+
"<|quad_start|>",
|
| 222 |
+
"<|quad_end|>",
|
| 223 |
+
"<|vision_start|>",
|
| 224 |
+
"<|vision_end|>",
|
| 225 |
+
"<|vision_pad|>",
|
| 226 |
+
"<|image_pad|>",
|
| 227 |
+
"<|video_pad|>"
|
| 228 |
+
],
|
| 229 |
+
"bos_token": null,
|
| 230 |
+
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = 
content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}",
|
| 231 |
+
"clean_up_tokenization_spaces": false,
|
| 232 |
+
"eos_token": "<|im_end|>",
|
| 233 |
+
"errors": "replace",
|
| 234 |
+
"model_max_length": 131072,
|
| 235 |
+
"pad_token": "<|endoftext|>",
|
| 236 |
+
"split_special_tokens": false,
|
| 237 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
| 238 |
+
"unk_token": null
|
| 239 |
+
}
|
URSA-1.7B/transformer/__transformer__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
##############################################################################
|
| 15 |
+
"""Transformer model."""
|
| 16 |
+
|
| 17 |
+
from diffnext.models.transformers.transformer_ursa import URSATransformer3DModel # noqa
|
URSA-1.7B/transformer/config.json
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"hidden_size": 2048,
|
| 3 |
+
"intermediate_size": 6144,
|
| 4 |
+
"max_window_layers": 28,
|
| 5 |
+
"num_attention_heads": 16,
|
| 6 |
+
"num_key_value_heads": 8,
|
| 7 |
+
"num_hidden_layers": 28,
|
| 8 |
+
"rope_theta": 1000000,
|
| 9 |
+
"vocab_size": 215669,
|
| 10 |
+
"lm_vocab_size": 151669,
|
| 11 |
+
"lm_head_size": 64000,
|
| 12 |
+
"bov_token_id": 151652
|
| 13 |
+
}
|
URSA-1.7B/transformer/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d4a50d661919972cd5c8640ca3c9e5824945d105fb714a0de2f7610a4e7bebb8
|
| 3 |
+
size 3964379808
|
URSA-1.7B/vae/__vae__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
##############################################################################
|
| 15 |
+
"""VAE model."""
|
| 16 |
+
|
| 17 |
+
from diffnext.models.autoencoders.autoencoder_vq_cosmos3d import AutoencoderVQCosmos3D # noqa
|
URSA-1.7B/vae/config.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderVQCosmos3D",
|
| 3 |
+
"_quantizer_name": "FSQuantizer",
|
| 4 |
+
"in_channels": 3,
|
| 5 |
+
"latent_channels": 256,
|
| 6 |
+
"layers_per_block": 2,
|
| 7 |
+
"norm_num_groups": 1,
|
| 8 |
+
"out_channels": 3,
|
| 9 |
+
"sample_size": 1024,
|
| 10 |
+
"sample_frames": 49,
|
| 11 |
+
"num_vq_embeddings": 64000,
|
| 12 |
+
"vq_embed_dim": 6,
|
| 13 |
+
"patch_size": 2,
|
| 14 |
+
"temporal_stride": 4,
|
| 15 |
+
"spatial_stride": 8,
|
| 16 |
+
"block_out_channels": [
|
| 17 |
+
128,
|
| 18 |
+
256,
|
| 19 |
+
512,
|
| 20 |
+
512
|
| 21 |
+
]
|
| 22 |
+
}
|
URSA/.flake8
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[flake8]
|
| 2 |
+
max-line-length = 100
|
| 3 |
+
ignore =
|
| 4 |
+
# whitespace before ':' (conflicted with Black)
|
| 5 |
+
E203,
|
| 6 |
+
# ambiguous variable name
|
| 7 |
+
E741,
|
| 8 |
+
# ‘from module import *’ used; unable to detect undefined names
|
| 9 |
+
F403,
|
| 10 |
+
# name may be undefined, or defined from star imports: module
|
| 11 |
+
F405,
|
| 12 |
+
# redefinition of unused name from line N
|
| 13 |
+
F811,
|
| 14 |
+
# undefined name
|
| 15 |
+
F821,
|
| 16 |
+
# line break before binary operator
|
| 17 |
+
W503,
|
| 18 |
+
# line break after binary operator
|
| 19 |
+
W504
|
| 20 |
+
# module imported but unused
|
| 21 |
+
per-file-ignores = __init__.py: F401
|
URSA/.gitignore
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Compiled Object files
|
| 2 |
+
*.slo
|
| 3 |
+
*.lo
|
| 4 |
+
*.o
|
| 5 |
+
*.cuo
|
| 6 |
+
|
| 7 |
+
# Compiled Dynamic libraries
|
| 8 |
+
*.so
|
| 9 |
+
*.dll
|
| 10 |
+
*.dylib
|
| 11 |
+
|
| 12 |
+
# Compiled Static libraries
|
| 13 |
+
*.lai
|
| 14 |
+
*.la
|
| 15 |
+
*.a
|
| 16 |
+
*.lib
|
| 17 |
+
|
| 18 |
+
# Compiled python
|
| 19 |
+
*.pyc
|
| 20 |
+
__pycache__
|
| 21 |
+
|
| 22 |
+
# Compiled MATLAB
|
| 23 |
+
*.mex*
|
| 24 |
+
|
| 25 |
+
# IPython notebook checkpoints
|
| 26 |
+
.ipynb_checkpoints
|
| 27 |
+
|
| 28 |
+
# Editor temporaries
|
| 29 |
+
*.swp
|
| 30 |
+
*~
|
| 31 |
+
|
| 32 |
+
# Sublime Text settings
|
| 33 |
+
*.sublime-workspace
|
| 34 |
+
*.sublime-project
|
| 35 |
+
|
| 36 |
+
# Eclipse Project settings
|
| 37 |
+
*.*project
|
| 38 |
+
.settings
|
| 39 |
+
|
| 40 |
+
# QtCreator files
|
| 41 |
+
*.user
|
| 42 |
+
|
| 43 |
+
# VSCode files
|
| 44 |
+
.vscode
|
| 45 |
+
|
| 46 |
+
# IDEA files
|
| 47 |
+
.idea
|
| 48 |
+
|
| 49 |
+
# OSX dir files
|
| 50 |
+
.DS_Store
|
| 51 |
+
|
| 52 |
+
# Android files
|
| 53 |
+
.gradle
|
| 54 |
+
*.iml
|
| 55 |
+
local.properties
|
URSA/=4.57.1
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Requirement already satisfied: diffusers in /usr/local/lib/python3.12/dist-packages (0.36.0)
|
| 2 |
+
Requirement already satisfied: transformers in /usr/local/lib/python3.12/dist-packages (5.2.0)
|
| 3 |
+
Requirement already satisfied: accelerate in /usr/local/lib/python3.12/dist-packages (1.12.0)
|
| 4 |
+
Requirement already satisfied: imageio in /usr/local/lib/python3.12/dist-packages (2.37.2)
|
| 5 |
+
Requirement already satisfied: imageio-ffmpeg in /usr/local/lib/python3.12/dist-packages (0.6.0)
|
| 6 |
+
Requirement already satisfied: omegaconf in /usr/local/lib/python3.12/dist-packages (2.3.0)
|
| 7 |
+
Requirement already satisfied: wandb in /usr/local/lib/python3.12/dist-packages (0.25.0)
|
| 8 |
+
Requirement already satisfied: importlib_metadata in /usr/local/lib/python3.12/dist-packages/setuptools/_vendor (from diffusers) (8.0.0)
|
| 9 |
+
Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from diffusers) (3.17.0)
|
| 10 |
+
Requirement already satisfied: httpx<1.0.0 in /usr/local/lib/python3.12/dist-packages (from diffusers) (0.28.1)
|
| 11 |
+
Requirement already satisfied: huggingface-hub<2.0,>=0.34.0 in /usr/local/lib/python3.12/dist-packages (from diffusers) (1.3.0)
|
| 12 |
+
Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from diffusers) (1.26.4)
|
| 13 |
+
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from diffusers) (2024.11.6)
|
| 14 |
+
Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from diffusers) (2.32.3)
|
| 15 |
+
Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.12/dist-packages (from diffusers) (0.5.3)
|
| 16 |
+
Requirement already satisfied: Pillow in /usr/local/lib/python3.12/dist-packages (from diffusers) (11.1.0)
|
| 17 |
+
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (23.2)
|
| 18 |
+
Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from transformers) (6.0.2)
|
| 19 |
+
Requirement already satisfied: tokenizers<=0.23.0,>=0.22.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (0.22.2)
|
| 20 |
+
Requirement already satisfied: typer-slim in /usr/local/lib/python3.12/dist-packages (from transformers) (0.21.2)
|
| 21 |
+
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.12/dist-packages (from transformers) (4.67.1)
|
| 22 |
+
Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from accelerate) (7.0.0)
|
| 23 |
+
Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from accelerate) (2.9.0+cu128)
|
| 24 |
+
Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.12/dist-packages (from omegaconf) (4.9.3)
|
| 25 |
+
Requirement already satisfied: click>=8.0.1 in /usr/local/lib/python3.12/dist-packages (from wandb) (8.1.8)
|
| 26 |
+
Requirement already satisfied: gitpython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.12/dist-packages (from wandb) (3.1.46)
|
| 27 |
+
Requirement already satisfied: platformdirs in /usr/local/lib/python3.12/dist-packages (from wandb) (4.3.6)
|
| 28 |
+
Requirement already satisfied: protobuf!=4.21.0,!=5.28.0,<7,>=3.19.0 in /usr/local/lib/python3.12/dist-packages (from wandb) (4.24.4)
|
| 29 |
+
Requirement already satisfied: pydantic<3 in /usr/local/lib/python3.12/dist-packages (from wandb) (2.10.6)
|
| 30 |
+
Requirement already satisfied: sentry-sdk>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from wandb) (2.54.0)
|
| 31 |
+
Requirement already satisfied: typing-extensions<5,>=4.8 in /usr/local/lib/python3.12/dist-packages (from wandb) (4.12.2)
|
| 32 |
+
Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.12/dist-packages (from gitpython!=3.1.29,>=1.0.0->wandb) (4.0.12)
|
| 33 |
+
Requirement already satisfied: anyio in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->diffusers) (4.8.0)
|
| 34 |
+
Requirement already satisfied: certifi in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->diffusers) (2025.1.31)
|
| 35 |
+
Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->diffusers) (1.0.7)
|
| 36 |
+
Requirement already satisfied: idna in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->diffusers) (3.10)
|
| 37 |
+
Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.12/dist-packages (from httpcore==1.*->httpx<1.0.0->diffusers) (0.14.0)
|
| 38 |
+
Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<2.0,>=0.34.0->diffusers) (2025.2.0)
|
| 39 |
+
Requirement already satisfied: hf-xet<2.0.0,>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<2.0,>=0.34.0->diffusers) (1.3.2)
|
| 40 |
+
Requirement already satisfied: shellingham in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<2.0,>=0.34.0->diffusers) (1.5.4)
|
| 41 |
+
Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from pydantic<3->wandb) (0.7.0)
|
| 42 |
+
Requirement already satisfied: pydantic-core==2.27.2 in /usr/local/lib/python3.12/dist-packages (from pydantic<3->wandb) (2.27.2)
|
| 43 |
+
Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->diffusers) (3.4.1)
|
| 44 |
+
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->diffusers) (2.0.7)
|
| 45 |
+
Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (75.8.2)
|
| 46 |
+
Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (1.14.0)
|
| 47 |
+
Requirement already satisfied: networkx>=2.5.1 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (3.4.2)
|
| 48 |
+
Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (3.1.6)
|
| 49 |
+
Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.8.93 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.8.93)
|
| 50 |
+
Requirement already satisfied: nvidia-cuda-runtime-cu12==12.8.90 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.8.90)
|
| 51 |
+
Requirement already satisfied: nvidia-cuda-cupti-cu12==12.8.90 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.8.90)
|
| 52 |
+
Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (9.10.2.21)
|
| 53 |
+
Requirement already satisfied: nvidia-cublas-cu12==12.8.4.1 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.8.4.1)
|
| 54 |
+
Requirement already satisfied: nvidia-cufft-cu12==11.3.3.83 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (11.3.3.83)
|
| 55 |
+
Requirement already satisfied: nvidia-curand-cu12==10.3.9.90 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (10.3.9.90)
|
| 56 |
+
Requirement already satisfied: nvidia-cusolver-cu12==11.7.3.90 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (11.7.3.90)
|
| 57 |
+
Requirement already satisfied: nvidia-cusparse-cu12==12.5.8.93 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.5.8.93)
|
| 58 |
+
Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (0.7.1)
|
| 59 |
+
Requirement already satisfied: nvidia-nccl-cu12==2.27.5 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (2.27.5)
|
| 60 |
+
Requirement already satisfied: nvidia-nvshmem-cu12==3.3.20 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (3.3.20)
|
| 61 |
+
Requirement already satisfied: nvidia-nvtx-cu12==12.8.90 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.8.90)
|
| 62 |
+
Requirement already satisfied: nvidia-nvjitlink-cu12==12.8.93 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.8.93)
|
| 63 |
+
Requirement already satisfied: nvidia-cufile-cu12==1.13.1.3 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (1.13.1.3)
|
| 64 |
+
Requirement already satisfied: triton==3.5.0 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (3.5.0)
|
| 65 |
+
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.12/dist-packages/setuptools/_vendor (from importlib_metadata->diffusers) (3.19.2)
|
| 66 |
+
Requirement already satisfied: annotated-doc>=0.0.2 in /usr/local/lib/python3.12/dist-packages (from typer-slim->transformers) (0.0.4)
|
| 67 |
+
Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.12/dist-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb) (5.0.2)
|
| 68 |
+
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch>=2.0.0->accelerate) (1.3.0)
|
| 69 |
+
Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.12/dist-packages (from anyio->httpx<1.0.0->diffusers) (1.3.1)
|
| 70 |
+
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch>=2.0.0->accelerate) (3.0.2)
|
URSA/LICENSE
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
URSA/README.md
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
|
| 3 |
+
<img src="assets/logo.png" width="30%" alt="logo"/>
|
| 4 |
+
|
| 5 |
+
<h1>🐻 URSA: Uniform Discrete Diffusion with Metric Path<br>for Video Generation</h1>
|
| 6 |
+
|
| 7 |
+
<p align="center">
|
| 8 |
+
<a href="https://arxiv.org/abs/2510.24717"><img src="https://img.shields.io/badge/ArXiv-2510.24717-%23840707.svg" alt="ArXiv"></a>
|
| 9 |
+
<a href="https://huggingface.co/collections/BAAI/ursa"><img src="https://img.shields.io/badge/🤗 Weights-BAAI/URSA-rgb(166,109,59).svg" alt=""></a>
|
| 10 |
+
<a href="https://huggingface.co/spaces/BAAI/nova-d48w1024-osp480"><img src="https://img.shields.io/badge/🤗 Demo-TI2V-%26840707.svg" alt="TI2VDemo"></a>
|
| 11 |
+
<a href="http://bitterdhg.github.io/URSA_page"><img src="https://img.shields.io/badge/Project-URSA-%237CB4F7.svg" alt="Project"></a>
|
| 12 |
+
</p>
|
| 13 |
+
|
| 14 |
+
<p align="center">
|
| 15 |
+
|
| 16 |
+
[Haoge Deng](https://scholar.google.com/citations?user=S2sbvjgAAAAJ&hl)<sup>1,4*</sup>, [Ting Pan](https://scholar.google.com/citations?&user=qQv6YbsAAAAJ)<sup>2,4*</sup>, [Fan Zhang](https://scholar.google.com/citations?user=VsJ39HMAAAAJ)<sup>4*</sup>, [Yang Liu](https://scholar.google.com/citations?user=9JcQ2hwAAAAJ&hl)<sup>3,4*</sup>, [Zhuoyan Luo](https://scholar.google.com/citations?user=mKQhEsIAAAAJ&hl)<sup>4</sup>, [Yufeng Cui](https://scholar.google.com/citations?user=5Ydha2EAAAAJ&hl)<sup>4</sup>, [Wenxuan Wang](https://scholar.google.com/citations?user=75OyC-oAAAAJ&hl)<sup>4</sup><br>
|
| 17 |
+
[Chunhua Shen](https://scholar.google.com/citations?user=Ljk2BvIAAAAJ&hl)<sup>3</sup>, [Shiguang Shan](https://scholar.google.com/citations?user=Vkzd7MIAAAAJ&hl)<sup>2</sup>, [Zhaoxiang Zhang](https://scholar.google.com/citations?user=qxWfV6cAAAAJ&hl)<sup>1†</sup>, [Xinlong Wang](https://scholar.google.com/citations?user=DPz0DjYAAAAJ&hl)<sup>4†</sup><br>
|
| 18 |
+
|
| 19 |
+
[CASIA](http://english.ia.cas.cn)<sup>1</sup>, [CASICT](http://english.ict.cas.cn)<sup>2</sup>, [ZJU](https://www.zju.edu.cn/english)<sup>3</sup>, [BAAI](https://www.baai.ac.cn/en)<sup>4</sup><br>
|
| 20 |
+
<sup>*</sup> Equal Contribution, <sup>†</sup> Corresponding Author
|
| 21 |
+
<br><br><image src="assets/model_preview.gif"/>
|
| 22 |
+
<br><br><image src="assets/model_overview.png"/>
|
| 23 |
+
</div>
|
| 24 |
+
|
| 25 |
+
We present **URSA** (**U**niform disc**R**ete diffu**S**ion with metric p**A**th), a simple yet powerful framework that bridges the gap with continuous approaches. **URSA** formulates the video generation task as an iterative global refinement of discrete spatiotemporal tokens and scales efficiently to long video generation, requiring fewer inference steps. **URSA** enables multi-task video generation with asynchronous timestep scheduling strategy in one unified model.
|
| 26 |
+
|
| 27 |
+
## 🚀 News
|
| 28 |
+
- ```[Feb 2026]``` Accepted by ICLR 2026 [[OpenReview]](https://openreview.net/forum?id=GFU5yCbILk).
|
| 29 |
+
- ```[Jan 2026]``` Released [Training Guide](./docs/training.md).
|
| 30 |
+
- ```[Oct 2025]``` 🎉 URSA is part of [Emu3.5](https://github.com/baaivision/Emu3.5) as DiDA (Discrete Diffusion Adaptation)!
|
| 31 |
+
- ```[Oct 2025]``` Released <a href="https://huggingface.co/spaces/BAAI/nova-d48w1024-osp480"><b>TI2V</b></a> 🤗 Demo.
|
| 32 |
+
- ```[Oct 2025]``` Released [Paper](https://arxiv.org/abs/2510.24717) & [Project Page](http://bitterdhg.github.io/URSA_page) & [Evaluation Guide](./docs/evaluation.md).
|
| 33 |
+
|
| 34 |
+
## ✨Highlights
|
| 35 |
+
|
| 36 |
+
- 🥇 **Novel Approach**: Uniform Discrete Diffusion with Metric Path.
|
| 37 |
+
- 🥈 **SOTA Performance**: High efficiency with state-of-the-art T2I/T2V/I2V results.
|
| 38 |
+
- 🥉 **Unified Modeling**: Multi-task capabilities in a single unified model.
|
| 39 |
+
|
| 40 |
+
## 🗄️ Models
|
| 41 |
+
|
| 42 |
+
### 🖼️ Text to Image
|
| 43 |
+
|
| 44 |
+
| Model | Resolution | Data | Weight | GenEval | DPGBench |
|
| 45 |
+
|:-----:|:----------:|:----:|:------:|:-------:|:--------:|
|
| 46 |
+
| URSA-0.6B-IBQ1024 | 1024x1024 | 30M | [🤗 HF](https://huggingface.co/BAAI/URSA-0.6B-IBQ1024) \| [🤖 ModelScope](https://www.modelscope.cn/models/BAAI/URSA-0.6B-IBQ1024) | 0.79 | 85.6 |
|
| 47 |
+
| URSA-1.7B-IBQ1024 | 1024x1024 | 30M | [🤗 HF](https://huggingface.co/BAAI/URSA-1.7B-IBQ1024) \| [🤖 ModelScope](https://www.modelscope.cn/models/BAAI/URSA-1.7B-IBQ1024) | 0.80 | 86.0 |
|
| 48 |
+
|
| 49 |
+
### 🎬 Text to Video
|
| 50 |
+
|
| 51 |
+
| Model | Resolution | Data | Weight | VBench-T2V | VBench-I2V |
|
| 52 |
+
|:-----:|:----------:|:----:|:------:|:----------:|:----------:|
|
| 53 |
+
| URSA-0.6B-FSQ320 | 49x512x320 | 24M | [🤗 HF](https://huggingface.co/BAAI/URSA-0.6B-FSQ320) \| [🤖 ModelScope](https://www.modelscope.cn/models/BAAI/URSA-0.6B-FSQ320) | 81.4 | 86.0 |
|
| 54 |
+
| URSA-1.7B-FSQ320 | 49x512x320 | 24M | [🤗 HF](https://huggingface.co/BAAI/URSA-1.7B-FSQ320) \| [🤖 ModelScope](https://www.modelscope.cn/models/BAAI/URSA-1.7B-FSQ320) | 82.4 | 86.2 |
|
| 55 |
+
|
| 56 |
+
## 📖 Table of Contents
|
| 57 |
+
- [🔧 Installation](#installation)
|
| 58 |
+
- [🔥 Quick Start](#quick-start)
|
| 59 |
+
- [🖼️ Image Generation](#quickstart-image-generation)
|
| 60 |
+
- [🎬 Video Generation](#quickstart-video-generation)
|
| 61 |
+
- [💻 Gradio Demo](#gradio-demo)
|
| 62 |
+
- [💯 Evaluation](./docs/evaluation.md)
|
| 63 |
+
- [🤖 Training](./docs/training.md)
|
| 64 |
+
|
| 65 |
+
## 🔧 Installation
|
| 66 |
+
<a id="installation"></a>
|
| 67 |
+
|
| 68 |
+
Clone this repository to local disk and install:
|
| 69 |
+
```bash
|
| 70 |
+
pip install diffusers "transformers>=4.57.1" accelerate imageio imageio-ffmpeg omegaconf wandb
|
| 71 |
+
git clone https://github.com/baaivision/URSA.git
|
| 72 |
+
cd URSA && pip install .
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
## 🔥 Quick Start
|
| 76 |
+
<a id="quick-start"></a>
|
| 77 |
+
|
| 78 |
+
### 🖼️ Image Generation
|
| 79 |
+
<a id="quickstart-image-generation"></a>
|
| 80 |
+
|
| 81 |
+
```python
|
| 82 |
+
import torch
|
| 83 |
+
from diffnext.pipelines import URSAPipeline
|
| 84 |
+
|
| 85 |
+
model_id, height, width = "BAAI/URSA-1.7B-IBQ1024", 1024, 1024
|
| 86 |
+
model_args = {"torch_dtype": torch.float16, "trust_remote_code": True}
|
| 87 |
+
pipe = URSAPipeline.from_pretrained(model_id, **model_args)
|
| 88 |
+
pipe = pipe.to(torch.device("cuda"))
|
| 89 |
+
|
| 90 |
+
prompt = "The bear, calm and still, gazes upward as if lost in contemplation of the cosmos."
|
| 91 |
+
negative_prompt = "worst quality, low quality, inconsistent motion, static, still, blurry, jittery, distorted, ugly"
|
| 92 |
+
|
| 93 |
+
image = pipe(**locals()).frames[0]
|
| 94 |
+
image.save("ursa.jpg")
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
### 🎬 Video Generation
|
| 98 |
+
<a id="quickstart-video-generation"></a>
|
| 99 |
+
|
| 100 |
+
```python
|
| 101 |
+
import os, torch, numpy
|
| 102 |
+
from diffnext.pipelines import URSAPipeline
|
| 103 |
+
from diffnext.utils import export_to_video
|
| 104 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 105 |
+
|
| 106 |
+
model_id, height, width = "BAAI/URSA-1.7B-FSQ320", 320, 512
|
| 107 |
+
model_args = {"torch_dtype": torch.float16, "trust_remote_code": True}
|
| 108 |
+
pipe = URSAPipeline.from_pretrained(model_id, **model_args)
|
| 109 |
+
pipe = pipe.to(torch.device("cuda"))
|
| 110 |
+
|
| 111 |
+
text_prompt = "a lone grizzly bear walks through a misty forest at dawn, sunlight catching its fur."
|
| 112 |
+
negative_prompt = "worst quality, low quality, inconsistent motion, static, still, blurry, jittery, distorted, ugly"
|
| 113 |
+
|
| 114 |
+
# Text-to-Image
|
| 115 |
+
prompt = text_prompt
|
| 116 |
+
num_frames, num_inference_steps = 1, 25
|
| 117 |
+
image = pipe(**locals()).frames[0]
|
| 118 |
+
image.save("ursa.jpg")
|
| 119 |
+
|
| 120 |
+
# Image-to-Video
|
| 121 |
+
prompt = f"motion=9.0, {text_prompt}"
|
| 122 |
+
num_frames, num_inference_steps = 49, 50
|
| 123 |
+
video = pipe(**locals()).frames[0]
|
| 124 |
+
export_to_video(video, "ursa_1+48f.mp4", fps=12)
|
| 125 |
+
|
| 126 |
+
# Text-to-Video
|
| 127 |
+
image, video = None, None
|
| 128 |
+
prompt = f"motion=9.0, {text_prompt}"
|
| 129 |
+
num_frames, num_inference_steps = 49, 50
|
| 130 |
+
video = pipe(**locals()).frames[0]
|
| 131 |
+
export_to_video(video, "ursa_49f.mp4", fps=12)
|
| 132 |
+
|
| 133 |
+
# Video-to-Video
|
| 134 |
+
prompt = f"motion=5.0, {text_prompt}"
|
| 135 |
+
num_frames, num_inference_steps = 49, 50
|
| 136 |
+
num_cond_frames, cond_noise_scale = 13, 0.1
|
| 137 |
+
for i in range(12):
|
| 138 |
+
video, start_video = video[-num_cond_frames:], video
|
| 139 |
+
video = pipe(**locals()).frames[0]
|
| 140 |
+
video = numpy.concatenate([start_video, video[num_cond_frames:]])
|
| 141 |
+
export_to_video(video, "ursa_{}f.mp4".format(video.shape[0]), fps=12)
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
## 💻 Gradio Demo
|
| 145 |
+
<a id="gradio-demo"></a>
|
| 146 |
+
|
| 147 |
+
```bash
|
| 148 |
+
# Text-to-Image (T2I)
|
| 149 |
+
python scripts/app_ursa_t2i.py --model "BAAI/URSA-1.7B-IBQ1024" --device 0
|
| 150 |
+
|
| 151 |
+
# Text-to-Image-to-Video (TI2V)
|
| 152 |
+
python scripts/app_ursa_ti2v.py --model "BAAI/URSA-1.7B-FSQ320" --device 0
|
| 153 |
+
```
|
| 154 |
+
|
| 155 |
+
## 📋 Todo List
|
| 156 |
+
- [X] [Model Zoo](#model-zoo)
|
| 157 |
+
- [X] [Quick Start](#quick-start)
|
| 158 |
+
- [X] [Gradio Demo](#gradio-demo)
|
| 159 |
+
- [X] [Evaluation Guide](./docs/evaluation.md)
|
| 160 |
+
- [X] [Training Guide](./docs/training.md)
|
| 161 |
+
- [ ] 4B Model
|
| 162 |
+
|
| 163 |
+
## 📖 Citation
|
| 164 |
+
If you find this repository useful, please consider giving a star ⭐ and citation 🦖:
|
| 165 |
+
```
|
| 166 |
+
@article{deng2025ursa,
|
| 167 |
+
title={Uniform Discrete Diffusion with Metric Path for Video Generation},
|
| 168 |
+
author={Deng, Haoge and Pan, Ting and Zhang, Fan and Liu, Yang and Luo, Zhuoyan and Cui, Yufeng and Shen, Chunhua and Shan, Shiguang and Zhang, Zhaoxiang and Wang, Xinlong},
|
| 169 |
+
journal={arXiv preprint arXiv:2510.24717},
|
| 170 |
+
year={2025}
|
| 171 |
+
}
|
| 172 |
+
```
|
| 173 |
+
```
|
| 174 |
+
@article{deng2024nova,
|
| 175 |
+
title={Autoregressive Video Generation without Vector Quantization},
|
| 176 |
+
author={Deng, Haoge and Pan, Ting and Diao, Haiwen and Luo, Zhengxiong and Cui, Yufeng and Lu, Huchuan and Shan, Shiguang and Qi, Yonggang and Wang, Xinlong},
|
| 177 |
+
journal={arXiv preprint arXiv:2412.14169},
|
| 178 |
+
year={2024}
|
| 179 |
+
}
|
| 180 |
+
```
|
| 181 |
+
|
| 182 |
+
## 🤗 Acknowledgement
|
| 183 |
+
|
| 184 |
+
We thank the repositories:
|
| 185 |
+
- [NOVA](https://github.com/baaivision/NOVA). ✨NOVA is the predecessor of 🐻URSA.
|
| 186 |
+
- [FlowMatching](https://github.com/facebookresearch/flow_matching). This codebase systemically provides CFM and DFM implementations.
|
| 187 |
+
- [FUDOKI](https://github.com/fudoki-hku/FUDOKI). This codebase provides a naive multimodal DFM implementation.
|
| 188 |
+
- [CodeWithGPU](https://github.com/seetacloud/codewithgpu). CodeWithGPU library is the core of our data loading pipeline.
|
| 189 |
+
|
| 190 |
+
## License
|
| 191 |
+
Code and models are licensed under [Apache License 2.0](LICENSE).
|
URSA/accelerate_configs/deepspeed_zero2.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
distributed_type: DEEPSPEED
|
| 2 |
+
deepspeed_config:
|
| 3 |
+
deepspeed_multinode_launcher: standard
|
| 4 |
+
gradient_clipping: 0.0
|
| 5 |
+
zero_stage: 3 #2
|
| 6 |
+
offload_optimizer_device: cpu # Moves optimizer states to CPU RAM
|
| 7 |
+
offload_param_device: cpu # Moves model parameters to CPU RAM
|
| 8 |
+
zero3_init_flag: true # Initializes the model directly across GPUs to save CPU RAM
|
| 9 |
+
zero3_save_16bit_model: true # Consolidates weights into a single file when saving checkpoints
|
| 10 |
+
num_machines: 1
|
| 11 |
+
num_processes: 8
|
| 12 |
+
machine_rank: 0
|
URSA/assets/sample_image.jpg
ADDED
|
URSA/configs/distill_dimo.yaml
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================================
|
| 2 |
+
# URSA one-step distillation — DiMO-style distributed training config
|
| 3 |
+
# ============================================================================
|
| 4 |
+
# Verified native inference regime (from A/B testing — ground truth):
|
| 5 |
+
# height=320, width=512, num_frames=49, guidance_scale=7, teacher_steps=50.
|
| 6 |
+
# no_cfg (guidance_scale=1) does NOT produce valid output.
|
| 7 |
+
# All defaults below align to this verified regime.
|
| 8 |
+
#
|
| 9 |
+
# Launch (8-GPU, single node):
|
| 10 |
+
#
|
| 11 |
+
# accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml \
|
| 12 |
+
# --machine_rank 0 --num_machines 1 --num_processes 8 \
|
| 13 |
+
# scripts/train_distill_dimo.py \
|
| 14 |
+
# config="./configs/distill_dimo.yaml" \
|
| 15 |
+
# experiment.output_dir="./experiments/distill_dimo" \
|
| 16 |
+
# distill.teacher_ckpt="/path/to/URSA-1.7B-IBQ1024" \
|
| 17 |
+
# distill.prompt_source="/data/Koala_36M_*.csv"
|
| 18 |
+
#
|
| 19 |
+
# Smoke test (1 GPU, 50 steps — save student checkpoint):
|
| 20 |
+
#
|
| 21 |
+
# accelerate launch --num_processes 1 \
|
| 22 |
+
# scripts/train_distill_dimo.py \
|
| 23 |
+
# config="./configs/distill_dimo.yaml" \
|
| 24 |
+
# experiment.output_dir="./experiments/smoke" \
|
| 25 |
+
# distill.teacher_ckpt="/path/to/URSA-1.7B-IBQ1024" \
|
| 26 |
+
# distill.prompt_source="prompts.txt" \
|
| 27 |
+
# training.max_train_steps=50 \
|
| 28 |
+
# experiment.save_every=50
|
| 29 |
+
#
|
| 30 |
+
# Load student for 1-step inference (must use CFG=7, native geometry):
|
| 31 |
+
#
|
| 32 |
+
# pipe = URSAPipeline.from_pretrained("/path/to/URSA-1.7B-IBQ1024")
|
| 33 |
+
# state = torch.load("experiments/distill_dimo/checkpoints/final/student.pt")
|
| 34 |
+
# pipe.transformer.load_state_dict(state, strict=True)
|
| 35 |
+
# frames = pipe(prompt="...", num_inference_steps=1,
|
| 36 |
+
# height=320, width=512, num_frames=49,
|
| 37 |
+
# guidance_scale=7).frames
|
| 38 |
+
# ============================================================================
|
| 39 |
+
|
| 40 |
+
# ── Experiment bookkeeping ───────────────────────────────────────────────────
|
| 41 |
+
experiment:
|
| 42 |
+
name: distill_dimo
|
| 43 |
+
output_dir: ./experiments/distill_dimo
|
| 44 |
+
log_every: 10
|
| 45 |
+
save_every: 100
|
| 46 |
+
resume_iter: 0 # set to step number to resume
|
| 47 |
+
|
| 48 |
+
# ── Training (framework-level) ───────────────────────────────────────────────
|
| 49 |
+
training:
|
| 50 |
+
seed: 42
|
| 51 |
+
mixed_precision: bf16 # bf16 | fp16 | fp32
|
| 52 |
+
max_train_steps: 10000
|
| 53 |
+
gradient_accumulation_steps: 1 # Two-backward; keep =1 for distillation
|
| 54 |
+
|
| 55 |
+
# ── Distillation hyperparameters ─────────────────────────────────────────────
|
| 56 |
+
distill:
|
| 57 |
+
# ---- Paths ----------------------------------------------------------------
|
| 58 |
+
teacher_ckpt: /gfs/space/private/fengzl/World_Model/URSA-1.7B
|
| 59 |
+
prompt_source: /gfs/space/private/fengzl/World_Model/Koala-36M-v1 # glob, dir, .txt, or comma-list
|
| 60 |
+
|
| 61 |
+
# ---- Video geometry (verified native: 320×512×49) -------------------------
|
| 62 |
+
num_frames: 49
|
| 63 |
+
height: 320
|
| 64 |
+
width: 512
|
| 65 |
+
max_prompt_length: 320
|
| 66 |
+
|
| 67 |
+
# ---- Data -----------------------------------------------------------------
|
| 68 |
+
batch_size_per_gpu: 1 # effective global batch = batch_size_per_gpu × 8 GPUs
|
| 69 |
+
|
| 70 |
+
# # ---- Loss weights ---------------------------------------------------------
|
| 71 |
+
# lambda_kd: 0.5 # KL(z_T || z_S) weight
|
| 72 |
+
# lambda_pg: 1.0 # REINFORCE policy gradient weight
|
| 73 |
+
# lambda_ent: 0.01 # entropy bonus (λ_ent_eff × H) — set 0 for DiMO orig
|
| 74 |
+
# tau: 1.0 # student sampling temperature
|
| 75 |
+
# tau_kd: 1.0 # KD / Jeffrey softmax temperature
|
| 76 |
+
|
| 77 |
+
# # ---- Teacher CFG (aligned to verified working regime: CFG=7) ---------------
|
| 78 |
+
# # A/B testing confirmed: guidance_scale=1 (no_cfg) does NOT produce valid
|
| 79 |
+
# # output for this URSA checkpoint. The teacher KD target must use CFG=7.
|
| 80 |
+
# enable_teacher_cfg: true
|
| 81 |
+
# teacher_cfg_scale: 7.0 # s in z_guided = z_uncond + s*(z_cond-z_uncond)
|
| 82 |
+
# # Verified: CFG=7 is the official working value.
|
| 83 |
+
# teacher_cfg_prob: 1.0 # max fraction of samples using guided target
|
| 84 |
+
# teacher_cfg_warmup_steps: 2000 # linear warmup 0→teacher_cfg_prob
|
| 85 |
+
# teacher_cfg_trunc: 0.9 # when t≥trunc, scale falls to 1 (no guide)
|
| 86 |
+
# lambda_kd_uncond: 0.3 # weight for uncond-branch KD loss
|
| 87 |
+
# reward_use_guided: false # [RISKY] use guided logits for reward signal
|
| 88 |
+
|
| 89 |
+
# # ---- DiMO extensions -------------------------------------------------------
|
| 90 |
+
# fake_rounds: 1 # aux updates per student update (DiMO=2; try 2)
|
| 91 |
+
# use_surrogate_grad: false
|
| 92 |
+
# lambda_surr: 1.0
|
| 93 |
+
|
| 94 |
+
# ---- Loss weights ---------------------------------------------------------
|
| 95 |
+
  lambda_kd: 1.0                  # KL(z_T || z_S) weight (base knowledge-distillation weight; keep unchanged)
|
| 96 |
+
  lambda_pg: 1.0                  # [repurposed] now acts as lambda_bridge, controlling the strength of MSE pseudo-gradient injection
|
| 97 |
+
  lambda_ent: 0.0                 # [deprecated] the RL entropy bonus has been removed entirely; keep at 0.0
|
| 98 |
+
tau: 1.0 # student sampling temperature
|
| 99 |
+
tau_kd: 1.0 # KD softmax temperature
|
| 100 |
+
|
| 101 |
+
# ---- Teacher CFG (aligned to verified working regime: CFG=7) ---------------
|
| 102 |
+
enable_teacher_cfg: true
|
| 103 |
+
teacher_cfg_scale: 7.0
|
| 104 |
+
teacher_cfg_prob: 1.0
|
| 105 |
+
teacher_cfg_warmup_steps: 1000
|
| 106 |
+
teacher_cfg_trunc: 0.9
|
| 107 |
+
lambda_kd_uncond: 0.3
|
| 108 |
+
  # reward_use_guided: false <-- [delete this line] the reward computation has been removed
|
| 109 |
+
|
| 110 |
+
# ---- DiMO extensions -------------------------------------------------------
|
| 111 |
+
  fake_rounds: 2 #1 # number of aux iterations fitting fake tokens; if the aux bridge_loss plateaus, try increasing this
|
| 112 |
+
use_surrogate_grad: false
|
| 113 |
+
lambda_surr: 1.0
|
| 114 |
+
|
| 115 |
+
# ---- Stability -------------------------------------------------------------
|
| 116 |
+
t_curriculum_steps: 10000 # curriculum steps before uniform-t sampling
|
| 117 |
+
p_init_mix_ratio: 0.2 # fraction of batch from corrupted x_hat_prev
|
| 118 |
+
p_mix_corrupt_frac: 0.2 # token corruption rate in p_init mixing
|
| 119 |
+
collapse_warn_frac: 0.2 # warn if tok_entropy < frac × initial entropy
|
| 120 |
+
|
| 121 |
+
# ---- Aux initialisation ---------------------------------------------------
|
| 122 |
+
aux_noise_std: 1.0e-5 # tiny noise added to aux weights at init to break
|
| 123 |
+
# symmetry; set 0.0 to keep aux == student exactly
|
| 124 |
+
|
| 125 |
+
# ---- Gradient clipping ----------------------------------------------------
|
| 126 |
+
grad_clip: 1.0
|
| 127 |
+
|
| 128 |
+
# ── Student optimizer ────────────────────────────────────────────────────────
|
| 129 |
+
optimizer_student:
|
| 130 |
+
target: torch.optim.AdamW
|
| 131 |
+
params:
|
| 132 |
+
lr: 1.0e-5
|
| 133 |
+
betas: [0.9, 0.95]
|
| 134 |
+
weight_decay: 0.01
|
| 135 |
+
|
| 136 |
+
# ── Aux optimizer ────────────────────────────────────────────────────────────
|
| 137 |
+
optimizer_aux:
|
| 138 |
+
target: torch.optim.AdamW
|
| 139 |
+
params:
|
| 140 |
+
lr: 1.0e-5
|
| 141 |
+
betas: [0.9, 0.95]
|
| 142 |
+
weight_decay: 0.01
|
| 143 |
+
|
| 144 |
+
# ── LR scheduler (cosine, shared warmup/decay params for both opts) ──────────
|
| 145 |
+
lr_scheduler:
|
| 146 |
+
target: diffnext.engine.lr_scheduler.CosineLR
|
| 147 |
+
params:
|
| 148 |
+
lr_max: ${optimizer_student.params.lr}
|
| 149 |
+
lr_min: 1.0e-6
|
| 150 |
+
max_steps: ${training.max_train_steps}
|
| 151 |
+
warmup_steps: 500
|
| 152 |
+
|
| 153 |
+
# ── Prompt DataLoader ─────────────────────────────────────────────────────────
|
| 154 |
+
prompt_dataloader:
|
| 155 |
+
shuffle_files: true
|
| 156 |
+
shuffle_buffer: 50000 # in-memory shuffle buffer per shard; reduce if OOM
|
| 157 |
+
num_workers: 4 # CPU workers (no CUDA in workers)
|
| 158 |
+
caption_field: caption # CSV column name (Koala default)
|
URSA/configs/onestep_dimo.yaml
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================================
|
| 2 |
+
# URSA one-step distillation — DiMO-style training configuration
|
| 3 |
+
# ============================================================================
|
| 4 |
+
# Reference: train_onestep_ursa_dimo.py
|
| 5 |
+
#
|
| 6 |
+
# DiMO hyperparameter comparison (Meissonic vs. our URSA defaults)
|
| 7 |
+
# ---------------------------------------------------------------
|
| 8 |
+
# Param DiMO (Meissonic) URSA (this config) Risk / Note
|
| 9 |
+
# ─────────────────────────────────────────────────────────────────────────
|
| 10 |
+
# guidance_scale (CFG) 3.0 (true_cfg) 3.0 (teacher_cfg) ✅ aligned
|
| 11 |
+
# fake_rounds 2 1 ⚠ try 2 for aux stability
|
| 12 |
+
# fixed_ratio 0.5 (mask ratio) — N/A (different domain)
|
| 13 |
+
# distil_loss_type surrogate MSE optional surrogate ✅ toggle via use_surrogate_grad
|
| 14 |
+
# noise_emb_perturb True — ℹ️ not needed for VQ-based model
|
| 15 |
+
# cfg_prob 1.0 teacher_cfg_prob=1.0 ✅ aligned
|
| 16 |
+
# lambda_ent 0.0 (no ent reg) 0.01 ℹ️ our addition for stability
|
| 17 |
+
# ============================================================================
|
| 18 |
+
|
| 19 |
+
# ── Paths ────────────────────────────────────────────────────────────────────
|
| 20 |
+
teacher_ckpt: "/path/to/URSA"
|
| 21 |
+
prompt_file: "prompts.txt"
|
| 22 |
+
out_dir: "./outputs/dimo"
|
| 23 |
+
|
| 24 |
+
# ── Video geometry ───────────────────────────────────────────────────────────
|
| 25 |
+
num_frames: 17
|
| 26 |
+
height: 256
|
| 27 |
+
width: 256
|
| 28 |
+
max_prompt_length: 320
|
| 29 |
+
|
| 30 |
+
# ── Training ─────────────────────────────────────────────────────────────────
|
| 31 |
+
batch_size: 2 # reduce to 1 if enable_teacher_cfg uses too much VRAM
|
| 32 |
+
num_steps: 10000
|
| 33 |
+
lr_student: 1.0e-5
|
| 34 |
+
lr_aux: 1.0e-5
|
| 35 |
+
weight_decay: 0.01
|
| 36 |
+
grad_clip: 1.0
|
| 37 |
+
mixed_precision: "bf16"
|
| 38 |
+
seed: 42
|
| 39 |
+
log_every: 50
|
| 40 |
+
save_every: 1000
|
| 41 |
+
|
| 42 |
+
# ── Loss weights ─────────────────────────────────────────────────────────────
|
| 43 |
+
lambda_pg: 1.0
|
| 44 |
+
lambda_kd: 0.5
|
| 45 |
+
lambda_ent: 0.01 # entropy regularisation (0 → DiMO original; 0.01 → our default)
|
| 46 |
+
tau: 1.0 # student sampling temperature
|
| 47 |
+
tau_kd: 1.0 # KD softmax temperature
|
| 48 |
+
|
| 49 |
+
# ── Teacher CFG (DiMO true_cfg style) ────────────────────────────────────────
|
| 50 |
+
# Set enable_teacher_cfg: false to revert to the prior single-branch behavior.
|
| 51 |
+
# All other params in this block are ignored when enable_teacher_cfg=false.
|
| 52 |
+
enable_teacher_cfg: true
|
| 53 |
+
|
| 54 |
+
teacher_cfg_scale: 3.0 # s in z_guided = z_uncond + s*(z_cond - z_uncond)
|
| 55 |
+
# Matches DiMO true_cfg=3.0
|
| 56 |
+
|
| 57 |
+
teacher_cfg_prob: 1.0 # Probability of using guided target per batch (after warmup).
|
| 58 |
+
# 1.0 = always guided (DiMO default).
|
| 59 |
+
|
| 60 |
+
teacher_cfg_warmup_steps: 2000
|
| 61 |
+
# Ramp teacher_cfg_prob from 0 → teacher_cfg_prob over this many
|
| 62 |
+
# steps. Prevents instability at the start of training.
|
| 63 |
+
|
| 64 |
+
teacher_cfg_trunc: 0.9 # When t >= trunc, CFG scale falls to 1 (no guidance at high noise).
|
| 65 |
+
# Mirrors DiMO's guidance_trunc parameter.
|
| 66 |
+
|
| 67 |
+
lambda_kd_uncond: 0.3 # Weight for uncond-branch KD loss.
|
| 68 |
+
# Keeps the student uncond-capable for eval-time CFG.
|
| 69 |
+
|
| 70 |
+
reward_use_guided: false # [RISKY] Use guided teacher logits for REINFORCE reward.
|
| 71 |
+
# Default false: use non-guided cond (more stable).
|
| 72 |
+
|
| 73 |
+
# ── Eval / inference CFG ─────────────────────────────────────────────────────
|
| 74 |
+
eval_cfg_scale: 3.0 # guidance_scale used during evaluation
|
| 75 |
+
use_cfg_eval: false # Run eval with inference-time CFG (2× forward)
|
| 76 |
+
|
| 77 |
+
# ── DiMO extensions ──────────────────────────────────────────────────────────
|
| 78 |
+
use_surrogate_grad: false # DiMO surrogate MSE trick (zero-variance alternative to REINFORCE)
|
| 79 |
+
lambda_surr: 1.0
|
| 80 |
+
fake_rounds: 1 # Aux updates per generator update (DiMO uses 2; try 2 for aux stability)
|
| 81 |
+
|
| 82 |
+
# ── Stability ─────────────────────────────────────────────────────────────────
|
| 83 |
+
t_curriculum_steps: 10000 # Steps to use t-curriculum (biases t toward larger values)
|
| 84 |
+
p_mix_corrupt_frac: 0.2 # Fraction of tokens to corrupt in p_init mixing
|
| 85 |
+
p_init_mix_ratio: 0.2 # Fraction of batch drawn from corrupted x_hat_prev
|
| 86 |
+
collapse_warn_frac: 0.2 # Warn if tok_hist_entropy drops below this fraction of initial
|
| 87 |
+
|
| 88 |
+
# ── Debug ────────────────────────────────────────────────────────────────────
|
| 89 |
+
dry_run: false # Run 1 step, print diagnostics, exit
|
| 90 |
+
debug_dump: 0 # Dump token histogram + x_hat every N steps (0=off)
|
| 91 |
+
|
| 92 |
+
# ── Recommended quick-start commands ─────────────────────────────────────────
|
| 93 |
+
# # Smoke test (CFG enabled):
|
| 94 |
+
# python scripts/train_onestep_ursa_dimo.py \
|
| 95 |
+
# --teacher_ckpt /path/to/URSA --prompt_file prompts.txt \
|
| 96 |
+
# --enable_teacher_cfg --teacher_cfg_scale 3.0 \
|
| 97 |
+
# --num_frames 17 --height 256 --width 256 --dry_run
|
| 98 |
+
#
|
| 99 |
+
# # Full training (DiMO-aligned):
|
| 100 |
+
# python scripts/train_onestep_ursa_dimo.py \
|
| 101 |
+
# --teacher_ckpt /path/to/URSA --prompt_file prompts.txt \
|
| 102 |
+
# --enable_teacher_cfg --teacher_cfg_scale 3.0 \
|
| 103 |
+
# --batch_size 2 --num_steps 10000 --fake_rounds 2 \
|
| 104 |
+
# --out_dir ./outputs/dimo_cfg
|
| 105 |
+
#
|
| 106 |
+
# # Eval (compare 3 student modes vs teacher):
|
| 107 |
+
# python scripts/eval_onestep_ursa.py \
|
| 108 |
+
# --teacher_ckpt /path/to/URSA \
|
| 109 |
+
# --student_ckpt ./outputs/dimo_cfg/final/student.pt \
|
| 110 |
+
# --modes no_cfg cfg baked --eval_cfg_scale 3.0 \
|
| 111 |
+
# --out_dir ./outputs/eval
|
URSA/configs/ursa_0.6b_fsq320.yaml
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
wandb:
|
| 2 |
+
run_id: null
|
| 3 |
+
|
| 4 |
+
experiment:
|
| 5 |
+
project: ursa_0.6b_fsq320
|
| 6 |
+
log_every: 20
|
| 7 |
+
save_every: 5000
|
| 8 |
+
resume_from_checkpoint: latest
|
| 9 |
+
|
| 10 |
+
model:
|
| 11 |
+
name: "transformer"
|
| 12 |
+
gradient_checkpointing: 2 # 1: +mlp_ckpt 2: +qkv_ckpt 3: +layer_ckpt
|
| 13 |
+
async_timestep: true
|
| 14 |
+
tokenizer:
|
| 15 |
+
params:
|
| 16 |
+
max_length: 320
|
| 17 |
+
truncation: true
|
| 18 |
+
padding_side: left
|
| 19 |
+
padding: max_length
|
| 20 |
+
|
| 21 |
+
pipeline:
|
| 22 |
+
target: diffnext.pipelines.ursa.pipeline_train.URSATrainPipeline
|
| 23 |
+
paths:
|
| 24 |
+
pretrained_path: /path/to/URSA-0.6B-FSQ320
|
| 25 |
+
module_dict:
|
| 26 |
+
vae: ${pipeline.paths.pretrained_path}/vae
|
| 27 |
+
scheduler: ${pipeline.paths.pretrained_path}/scheduler
|
| 28 |
+
tokenizer: ${pipeline.paths.pretrained_path}/tokenizer
|
| 29 |
+
model_index: ${pipeline.paths.pretrained_path}/model_index.json
|
| 30 |
+
|
| 31 |
+
optimizer:
|
| 32 |
+
target: torch.optim.AdamW
|
| 33 |
+
param_groups: false
|
| 34 |
+
params:
|
| 35 |
+
lr: 0.00003
|
| 36 |
+
betas: [0.9, 0.95]
|
| 37 |
+
weight_decay: 0.05
|
| 38 |
+
fused: true
|
| 39 |
+
|
| 40 |
+
lr_scheduler:
|
| 41 |
+
target: diffnext.engine.lr_scheduler.CosineLR
|
| 42 |
+
params:
|
| 43 |
+
lr_max: ${optimizer.params.lr}
|
| 44 |
+
lr_min: 0.00001
|
| 45 |
+
max_steps: ${training.max_train_steps}
|
| 46 |
+
warmup_steps: 500
|
| 47 |
+
|
| 48 |
+
train_dataloader:
|
| 49 |
+
target: diffnext.data.flex_loaders.FeatureDataLoader
|
| 50 |
+
params:
|
| 51 |
+
dataset: /path/to/fsq320_dataset
|
| 52 |
+
batch_size: ${training.batch_size}
|
| 53 |
+
seed: ${training.seed}
|
| 54 |
+
num_workers: 4
|
| 55 |
+
shuffle: true
|
| 56 |
+
|
| 57 |
+
training:
|
| 58 |
+
gradient_accumulation_steps: 1
|
| 59 |
+
batch_size: 1 # * 256 = 256
|
| 60 |
+
max_train_steps: 20000
|
| 61 |
+
seed: 1337
|
| 62 |
+
mixed_precision: bf16
|
URSA/configs/ursa_0.6b_ibq1024.yaml
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
wandb:
|
| 2 |
+
run_id: null
|
| 3 |
+
|
| 4 |
+
experiment:
|
| 5 |
+
project: ursa_0.6b_ibq1024
|
| 6 |
+
log_every: 20
|
| 7 |
+
save_every: 5000
|
| 8 |
+
resume_from_checkpoint: latest
|
| 9 |
+
|
| 10 |
+
model:
|
| 11 |
+
name: "transformer"
|
| 12 |
+
gradient_checkpointing: 2 # 1: +mlp_ckpt 2: +qkv_ckpt 3: +layer_ckpt
|
| 13 |
+
async_timestep: false
|
| 14 |
+
tokenizer:
|
| 15 |
+
params:
|
| 16 |
+
max_length: 320
|
| 17 |
+
truncation: true
|
| 18 |
+
padding_side: left
|
| 19 |
+
padding: max_length
|
| 20 |
+
|
| 21 |
+
pipeline:
|
| 22 |
+
target: diffnext.pipelines.ursa.pipeline_train.URSATrainPipeline
|
| 23 |
+
paths:
|
| 24 |
+
pretrained_path: /path/to/URSA-0.6B-IBQ1024
|
| 25 |
+
module_dict:
|
| 26 |
+
vae: ${pipeline.paths.pretrained_path}/vae
|
| 27 |
+
scheduler: ${pipeline.paths.pretrained_path}/scheduler
|
| 28 |
+
tokenizer: ${pipeline.paths.pretrained_path}/tokenizer
|
| 29 |
+
model_index: ${pipeline.paths.pretrained_path}/model_index.json
|
| 30 |
+
|
| 31 |
+
optimizer:
|
| 32 |
+
target: torch.optim.AdamW
|
| 33 |
+
param_groups: false
|
| 34 |
+
params:
|
| 35 |
+
lr: 0.00003
|
| 36 |
+
betas: [0.9, 0.95]
|
| 37 |
+
weight_decay: 0.05
|
| 38 |
+
fused: true
|
| 39 |
+
|
| 40 |
+
lr_scheduler:
|
| 41 |
+
target: diffnext.engine.lr_scheduler.CosineLR
|
| 42 |
+
params:
|
| 43 |
+
lr_max: ${optimizer.params.lr}
|
| 44 |
+
lr_min: 0.00001
|
| 45 |
+
max_steps: ${training.max_train_steps}
|
| 46 |
+
warmup_steps: 500
|
| 47 |
+
|
| 48 |
+
train_dataloader:
|
| 49 |
+
target: diffnext.data.flex_loaders.FeatureDataLoader
|
| 50 |
+
params:
|
| 51 |
+
dataset: /path/to/ibq1024_dataset
|
| 52 |
+
batch_size: ${training.batch_size}
|
| 53 |
+
seed: ${training.seed}
|
| 54 |
+
num_workers: 4
|
| 55 |
+
shuffle: true
|
| 56 |
+
|
| 57 |
+
training:
|
| 58 |
+
gradient_accumulation_steps: 1
|
| 59 |
+
batch_size: 1 # * 512 = 512
|
| 60 |
+
max_train_steps: 120000
|
| 61 |
+
seed: 1337
|
| 62 |
+
mixed_precision: bf16
|
URSA/configs/ursa_1.7b_fsq320.yaml
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
wandb:
|
| 2 |
+
run_id: null
|
| 3 |
+
|
| 4 |
+
experiment:
|
| 5 |
+
project: ursa_1.7b_fsq320
|
| 6 |
+
log_every: 20
|
| 7 |
+
save_every: 5000
|
| 8 |
+
resume_from_checkpoint: latest
|
| 9 |
+
|
| 10 |
+
model:
|
| 11 |
+
name: "transformer"
|
| 12 |
+
gradient_checkpointing: 2 # 1: +mlp_ckpt 2: +qkv_ckpt 3: +layer_ckpt
|
| 13 |
+
async_timestep: true
|
| 14 |
+
tokenizer:
|
| 15 |
+
params:
|
| 16 |
+
max_length: 320
|
| 17 |
+
truncation: true
|
| 18 |
+
padding_side: left
|
| 19 |
+
padding: max_length
|
| 20 |
+
|
| 21 |
+
pipeline:
|
| 22 |
+
target: diffnext.pipelines.ursa.pipeline_train.URSATrainPipeline
|
| 23 |
+
paths:
|
| 24 |
+
pretrained_path: /path/to/URSA-1.7B-FSQ320
|
| 25 |
+
module_dict:
|
| 26 |
+
vae: ${pipeline.paths.pretrained_path}/vae
|
| 27 |
+
scheduler: ${pipeline.paths.pretrained_path}/scheduler
|
| 28 |
+
tokenizer: ${pipeline.paths.pretrained_path}/tokenizer
|
| 29 |
+
model_index: ${pipeline.paths.pretrained_path}/model_index.json
|
| 30 |
+
|
| 31 |
+
optimizer:
|
| 32 |
+
target: torch.optim.AdamW
|
| 33 |
+
param_groups: false
|
| 34 |
+
params:
|
| 35 |
+
lr: 0.00003
|
| 36 |
+
betas: [0.9, 0.95]
|
| 37 |
+
weight_decay: 0.05
|
| 38 |
+
fused: true
|
| 39 |
+
|
| 40 |
+
lr_scheduler:
|
| 41 |
+
target: diffnext.engine.lr_scheduler.CosineLR
|
| 42 |
+
params:
|
| 43 |
+
lr_max: ${optimizer.params.lr}
|
| 44 |
+
lr_min: 0.00001
|
| 45 |
+
max_steps: ${training.max_train_steps}
|
| 46 |
+
warmup_steps: 500
|
| 47 |
+
|
| 48 |
+
train_dataloader:
|
| 49 |
+
target: diffnext.data.flex_loaders.FeatureDataLoader
|
| 50 |
+
params:
|
| 51 |
+
dataset: /path/to/fsq320_dataset
|
| 52 |
+
batch_size: ${training.batch_size}
|
| 53 |
+
seed: ${training.seed}
|
| 54 |
+
num_workers: 4
|
| 55 |
+
shuffle: true
|
| 56 |
+
|
| 57 |
+
training:
|
| 58 |
+
gradient_accumulation_steps: 1
|
| 59 |
+
batch_size: 1 # * 256 = 256
|
| 60 |
+
max_train_steps: 20000
|
| 61 |
+
seed: 1337
|
| 62 |
+
mixed_precision: bf16
|
URSA/configs/ursa_1.7b_ibq1024.yaml
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
wandb:
|
| 2 |
+
run_id: null
|
| 3 |
+
|
| 4 |
+
experiment:
|
| 5 |
+
project: ursa_1.7b_ibq1024
|
| 6 |
+
log_every: 20
|
| 7 |
+
save_every: 5000
|
| 8 |
+
resume_from_checkpoint: latest
|
| 9 |
+
|
| 10 |
+
model:
|
| 11 |
+
name: "transformer"
|
| 12 |
+
gradient_checkpointing: 2 # 1: +mlp_ckpt 2: +qkv_ckpt 3: +layer_ckpt
|
| 13 |
+
async_timestep: false
|
| 14 |
+
tokenizer:
|
| 15 |
+
params:
|
| 16 |
+
max_length: 320
|
| 17 |
+
truncation: true
|
| 18 |
+
padding_side: left
|
| 19 |
+
padding: max_length
|
| 20 |
+
|
| 21 |
+
pipeline:
|
| 22 |
+
target: diffnext.pipelines.ursa.pipeline_train.URSATrainPipeline
|
| 23 |
+
paths:
|
| 24 |
+
pretrained_path: /path/to/URSA-1.7B-IBQ1024
|
| 25 |
+
module_dict:
|
| 26 |
+
vae: ${pipeline.paths.pretrained_path}/vae
|
| 27 |
+
scheduler: ${pipeline.paths.pretrained_path}/scheduler
|
| 28 |
+
tokenizer: ${pipeline.paths.pretrained_path}/tokenizer
|
| 29 |
+
model_index: ${pipeline.paths.pretrained_path}/model_index.json
|
| 30 |
+
|
| 31 |
+
optimizer:
|
| 32 |
+
target: torch.optim.AdamW
|
| 33 |
+
param_groups: false
|
| 34 |
+
params:
|
| 35 |
+
lr: 0.00003
|
| 36 |
+
betas: [0.9, 0.95]
|
| 37 |
+
weight_decay: 0.05
|
| 38 |
+
fused: true
|
| 39 |
+
|
| 40 |
+
lr_scheduler:
|
| 41 |
+
target: diffnext.engine.lr_scheduler.CosineLR
|
| 42 |
+
params:
|
| 43 |
+
lr_max: ${optimizer.params.lr}
|
| 44 |
+
lr_min: 0.00001
|
| 45 |
+
max_steps: ${training.max_train_steps}
|
| 46 |
+
warmup_steps: 500
|
| 47 |
+
|
| 48 |
+
train_dataloader:
|
| 49 |
+
target: diffnext.data.flex_loaders.FeatureDataLoader
|
| 50 |
+
params:
|
| 51 |
+
dataset: /path/to/ibq1024_dataset
|
| 52 |
+
batch_size: ${training.batch_size}
|
| 53 |
+
seed: ${training.seed}
|
| 54 |
+
num_workers: 4
|
| 55 |
+
shuffle: true
|
| 56 |
+
|
| 57 |
+
training:
|
| 58 |
+
gradient_accumulation_steps: 1
|
| 59 |
+
batch_size: 1 # * 512 = 512
|
| 60 |
+
max_train_steps: 120000
|
| 61 |
+
seed: 1337
|
| 62 |
+
mixed_precision: bf16
|
URSA/diffnext/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, esither express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
# ------------------------------------------------------------------------
|
| 16 |
+
"""DiffNext: A diffusers based library for autoregressive diffusion models."""
|
URSA/diffnext/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (250 Bytes). View file
|
|
|
URSA/diffnext/__pycache__/image_processor.cpython-312.pyc
ADDED
|
Binary file (5.04 kB). View file
|
|
|
URSA/diffnext/data/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
# ------------------------------------------------------------------------
|
| 16 |
+
"""Data components."""
|
URSA/diffnext/data/flex_loaders.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
# ------------------------------------------------------------------------
|
| 16 |
+
"""Flex data loaders."""
|
| 17 |
+
|
| 18 |
+
import collections
|
| 19 |
+
import multiprocessing as mp
|
| 20 |
+
import time
|
| 21 |
+
import threading
|
| 22 |
+
import queue
|
| 23 |
+
|
| 24 |
+
import codewithgpu
|
| 25 |
+
import numpy as np
|
| 26 |
+
|
| 27 |
+
from diffnext.data.flex_pipelines import FeatureWorker
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class BalancedQueues(object):
|
| 31 |
+
"""Balanced queues."""
|
| 32 |
+
|
| 33 |
+
def __init__(self, base_queue, num=1):
|
| 34 |
+
self.queues = [base_queue]
|
| 35 |
+
self.queues += [mp.Queue(base_queue._maxsize) for _ in range(num - 1)]
|
| 36 |
+
self.index = 0
|
| 37 |
+
|
| 38 |
+
def put(self, obj, block=True, timeout=None):
|
| 39 |
+
q = self.queues[self.index]
|
| 40 |
+
q.put(obj, block=block, timeout=timeout)
|
| 41 |
+
self.index = (self.index + 1) % len(self.queues)
|
| 42 |
+
|
| 43 |
+
def get(self, block=True, timeout=None):
|
| 44 |
+
q = self.queues[self.index]
|
| 45 |
+
obj = q.get(block=block, timeout=timeout)
|
| 46 |
+
self.index = (self.index + 1) % len(self.queues)
|
| 47 |
+
return obj
|
| 48 |
+
|
| 49 |
+
def get_n(self, num=1):
|
| 50 |
+
outputs = []
|
| 51 |
+
while len(outputs) < num:
|
| 52 |
+
obj = self.get()
|
| 53 |
+
if obj is not None:
|
| 54 |
+
outputs.append(obj)
|
| 55 |
+
return outputs
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class DataLoaderBase(threading.Thread):
|
| 59 |
+
"""Base class of data loader."""
|
| 60 |
+
|
| 61 |
+
def __init__(self, worker, **kwargs):
|
| 62 |
+
super().__init__(daemon=True)
|
| 63 |
+
self.seed = kwargs.pop("seed", 1337)
|
| 64 |
+
self.shuffle = kwargs.pop("shuffle", True)
|
| 65 |
+
self.shard_id = kwargs.get("shard_id", 0)
|
| 66 |
+
self.num_shards = kwargs.get("num_shards", 1)
|
| 67 |
+
self.batch_size = kwargs.get("batch_size", 1)
|
| 68 |
+
self.num_workers = kwargs.get("num_workers", 1)
|
| 69 |
+
self.queue_depth = kwargs.get("queue_depth", 2)
|
| 70 |
+
# Build queues.
|
| 71 |
+
self.reader_queue = mp.Queue(self.queue_depth * self.batch_size)
|
| 72 |
+
self.worker_queue = mp.Queue(self.queue_depth * self.batch_size)
|
| 73 |
+
self.batch_queue = queue.Queue(self.queue_depth)
|
| 74 |
+
self.reader_queue = BalancedQueues(self.reader_queue, self.num_workers)
|
| 75 |
+
self.worker_queue = BalancedQueues(self.worker_queue, self.num_workers)
|
| 76 |
+
# Build readers.
|
| 77 |
+
self.readers = [
|
| 78 |
+
codewithgpu.DatasetReader(
|
| 79 |
+
output_queue=self.reader_queue,
|
| 80 |
+
partition_id=self.shard_id,
|
| 81 |
+
num_partitions=self.num_shards,
|
| 82 |
+
seed=self.seed + self.shard_id,
|
| 83 |
+
shuffle=self.shuffle,
|
| 84 |
+
**kwargs,
|
| 85 |
+
)
|
| 86 |
+
]
|
| 87 |
+
self.readers[0].start()
|
| 88 |
+
time.sleep(0.1)
|
| 89 |
+
# Build workers.
|
| 90 |
+
self.workers = []
|
| 91 |
+
for i in range(self.num_workers):
|
| 92 |
+
p = worker()
|
| 93 |
+
p.seed = self.seed + i + self.shard_id * self.num_workers
|
| 94 |
+
p.reader_queue = self.reader_queue.queues[i]
|
| 95 |
+
p.worker_queue = self.worker_queue.queues[i]
|
| 96 |
+
p.start()
|
| 97 |
+
self.workers.append(p)
|
| 98 |
+
time.sleep(0.1)
|
| 99 |
+
|
| 100 |
+
# Register cleanup callbacks.
|
| 101 |
+
def cleanup():
|
| 102 |
+
def terminate(processes):
|
| 103 |
+
for p in processes:
|
| 104 |
+
p.terminate()
|
| 105 |
+
p.join()
|
| 106 |
+
|
| 107 |
+
terminate(self.workers)
|
| 108 |
+
terminate(self.readers)
|
| 109 |
+
|
| 110 |
+
import atexit
|
| 111 |
+
|
| 112 |
+
atexit.register(cleanup)
|
| 113 |
+
# Start batch prefetching.
|
| 114 |
+
self.start()
|
| 115 |
+
|
| 116 |
+
def next(self):
|
| 117 |
+
"""Return the next batch of data."""
|
| 118 |
+
return self.__next__()
|
| 119 |
+
|
| 120 |
+
def run(self):
|
| 121 |
+
"""Main loop."""
|
| 122 |
+
|
| 123 |
+
def __call__(self):
|
| 124 |
+
return self.next()
|
| 125 |
+
|
| 126 |
+
def __iter__(self):
|
| 127 |
+
"""Return the iterator self."""
|
| 128 |
+
return self
|
| 129 |
+
|
| 130 |
+
def __next__(self):
|
| 131 |
+
"""Return the next batch of data."""
|
| 132 |
+
return [self.batch_queue.get()]
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class DataLoader(DataLoaderBase):
|
| 136 |
+
"""Loader to return the batch of data."""
|
| 137 |
+
|
| 138 |
+
def __init__(self, dataset, worker, **kwargs):
|
| 139 |
+
kwargs.update({"path": dataset}) # Alias for codewithgpu.
|
| 140 |
+
self.contiguous = kwargs.pop("contiguous", True)
|
| 141 |
+
self.prefetch_count = kwargs.pop("prefetch_count", 50)
|
| 142 |
+
super().__init__(worker, **kwargs)
|
| 143 |
+
|
| 144 |
+
def run(self):
|
| 145 |
+
"""Main loop."""
|
| 146 |
+
prev_inputs = self.worker_queue.get_n(self.prefetch_count * self.batch_size)
|
| 147 |
+
next_inputs = []
|
| 148 |
+
while True:
|
| 149 |
+
# Use cached buffer for next N inputs.
|
| 150 |
+
if len(next_inputs) == 0:
|
| 151 |
+
next_inputs = prev_inputs
|
| 152 |
+
prev_inputs = []
|
| 153 |
+
# Collect the next batch.
|
| 154 |
+
outputs = collections.defaultdict(list)
|
| 155 |
+
for _ in range(self.batch_size):
|
| 156 |
+
inputs = next_inputs.pop(0)
|
| 157 |
+
for k, v in inputs.items():
|
| 158 |
+
outputs[k].extend(v)
|
| 159 |
+
prev_inputs += self.worker_queue.get_n(1)
|
| 160 |
+
# Stack batch data.
|
| 161 |
+
if self.contiguous:
|
| 162 |
+
if "latents" in outputs:
|
| 163 |
+
outputs["latents"] = np.stack(outputs["latents"])
|
| 164 |
+
# Send batch data to consumer.
|
| 165 |
+
self.batch_queue.put(outputs)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class FeatureDataLoader(DataLoader):
|
| 169 |
+
"""Loader to return the batch of data features."""
|
| 170 |
+
|
| 171 |
+
def __init__(self, dataset, **kwargs):
|
| 172 |
+
super().__init__(dataset, FeatureWorker, **kwargs)
|
URSA/diffnext/data/flex_pipelines.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
# ------------------------------------------------------------------------
|
| 16 |
+
"""Flex data pipelines."""
|
| 17 |
+
|
| 18 |
+
import multiprocessing
|
| 19 |
+
|
| 20 |
+
import cv2
|
| 21 |
+
import numpy.random as npr
|
| 22 |
+
|
| 23 |
+
from diffnext.data import flex_transforms
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Worker(multiprocessing.Process):
|
| 27 |
+
"""Base data worker."""
|
| 28 |
+
|
| 29 |
+
def __init__(self):
|
| 30 |
+
super().__init__(daemon=True)
|
| 31 |
+
self.seed = 1337
|
| 32 |
+
self.reader_queue = None
|
| 33 |
+
self.worker_queue = None
|
| 34 |
+
|
| 35 |
+
def run(self):
|
| 36 |
+
"""Run implementation."""
|
| 37 |
+
# Disable opencv threading and fix numpy random seed.
|
| 38 |
+
cv2.setNumThreads(1), npr.seed(self.seed)
|
| 39 |
+
while True: # Main loop.
|
| 40 |
+
self.worker_queue.put(self.get_outputs(self.reader_queue.get()))
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class FeaturePipe(object):
|
| 44 |
+
"""Pipeline to transform data features."""
|
| 45 |
+
|
| 46 |
+
def __init__(self):
|
| 47 |
+
super().__init__()
|
| 48 |
+
self.parse_latents = flex_transforms.ParseLatents()
|
| 49 |
+
self.parse_annotations = flex_transforms.ParseAnnotations()
|
| 50 |
+
|
| 51 |
+
def get_outputs(self, inputs):
|
| 52 |
+
"""Return the outputs."""
|
| 53 |
+
latents = self.parse_latents(inputs)
|
| 54 |
+
label, caption = self.parse_annotations(inputs)
|
| 55 |
+
outputs = {"latents": [latents]}
|
| 56 |
+
outputs.setdefault("prompt", [label]) if label is not None else None
|
| 57 |
+
outputs.setdefault("prompt", [caption]) if caption is not None else None
|
| 58 |
+
outputs.setdefault("motion", [inputs["flow"]]) if "flow" in inputs else None
|
| 59 |
+
return outputs
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class FeatureWorker(FeaturePipe, Worker):
|
| 63 |
+
"""Worker to transform data features."""
|
URSA/diffnext/data/flex_transforms.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
# ------------------------------------------------------------------------
|
| 16 |
+
"""Flex data transforms."""
|
| 17 |
+
|
| 18 |
+
import re
|
| 19 |
+
import numpy as np
|
| 20 |
+
import numpy.random as npr
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class Transform(object):
|
| 24 |
+
"""Base transform type."""
|
| 25 |
+
|
| 26 |
+
def filter_outputs(self, *outputs):
|
| 27 |
+
outputs = [x for x in outputs if x is not None]
|
| 28 |
+
return outputs if len(outputs) > 1 else outputs[0]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class ParseLatents(Transform):
|
| 32 |
+
"""Parse VQ or VAE latents."""
|
| 33 |
+
|
| 34 |
+
def __init__(self):
|
| 35 |
+
super().__init__()
|
| 36 |
+
|
| 37 |
+
def __call__(self, inputs):
|
| 38 |
+
for k, dtype in zip(("moments", "codes"), ("float16", "int32")):
|
| 39 |
+
if k in inputs:
|
| 40 |
+
return np.frombuffer(inputs[k], dtype).reshape(inputs["shape"])
|
| 41 |
+
raise ValueError("Missing latents in inputs.")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class ParseAnnotations(Transform):
    """Parse ground-truth annotations."""

    def __init__(self, short_prob=0.5):
        super().__init__()
        # Probability of swapping the full caption for its short form.
        self.short_prob = short_prob

    def __call__(self, inputs):
        """Return a ``(label, caption)`` pair from the raw annotation dict."""
        label = inputs.get("label", None)
        text = inputs.get("text", None)
        caption = inputs.get("caption", None)

        # Cached-embedding path: caption/text are dicts of packed float16 data.
        if caption and isinstance(caption, dict):
            caption = np.frombuffer(caption["data"], "float16").reshape(caption["shape"])
            has_text = text and isinstance(text, dict) and len(text["data"]) > 0
            if has_text and npr.rand() < 0.5:
                caption = np.frombuffer(text["data"], "float16").reshape(text["shape"])
            return label, caption

        if label is None:
            # Improved short caption: fall back to the caption's first sentence.
            first_sentence = re.match(r"^(.*?[.!?])\s+", caption)
            if not text:
                text = first_sentence.group(1) if first_sentence else caption
            if text and npr.rand() < self.short_prob:
                caption = text
        return label, caption
|
URSA/diffnext/engine/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
# ------------------------------------------------------------------------
|
| 16 |
+
"""Engine components."""
|
URSA/diffnext/engine/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (202 Bytes). View file
|
|
|
URSA/diffnext/engine/__pycache__/engine_utils.cpython-312.pyc
ADDED
|
Binary file (5.8 kB). View file
|
|
|
URSA/diffnext/engine/__pycache__/lr_scheduler.cpython-312.pyc
ADDED
|
Binary file (4.39 kB). View file
|
|
|
URSA/diffnext/engine/engine_utils.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
# ------------------------------------------------------------------------
|
| 16 |
+
"""Engine utilities."""
|
| 17 |
+
|
| 18 |
+
import collections
|
| 19 |
+
import pickle
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
from torch import nn
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def count_params(module, trainable=True, unit="M"):
    """Return the number of parameters, in millions ("M") or billions ("B")."""
    divisor = {"M": 1e6, "B": 1e9}[unit]
    total = 0
    for param in module.parameters():
        # Count frozen parameters only when trainable filtering is off.
        if param.requires_grad or not trainable:
            total += param.size().numel()
    return total / divisor
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def freeze_module(module, trainable=False):
    """Freeze (or unfreeze) all parameters of the given module in place."""
    # Keep train/eval mode consistent with the requested trainable state.
    module.train(mode=trainable)
    for p in module.parameters():
        p.requires_grad_(trainable)
    return module
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_device(index):
    """Create the available device object."""
    if torch.cuda.is_available():
        return torch.device("cuda", index)
    # Probe alternative accelerator backends (currently only MPS); older
    # torch builds may lack the backend attribute entirely.
    for backend in ("mps",):
        try:
            if getattr(torch.backends, backend).is_available():
                return torch.device(backend, index)
        except AttributeError:
            pass
    return torch.device("cpu")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def get_param_groups(model):
    """Separate parameters into groups."""
    seen = set()
    groups = collections.OrderedDict()
    lr_scale_getter = None  # Reserved hook for name-based LR scaling.
    norm_types = (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm, nn.LayerNorm)
    for mod_name, mod in model.named_modules():
        for par_name, par in mod.named_parameters(recurse=False):
            # Skip frozen parameters and parameters shared across modules.
            if not par.requires_grad or par in seen:
                continue
            seen.add(par)
            attrs = collections.OrderedDict()
            if lr_scale_getter:
                attrs["lr_scale"] = lr_scale_getter(f"{mod_name}.{par_name}")
            if hasattr(par, "lr_scale"):
                attrs["lr_scale"] = par.lr_scale
            # Normalization weights (and opt-outs) get no weight decay.
            if getattr(par, "no_weight_decay", False) or isinstance(mod, norm_types):
                attrs["weight_decay"] = 0
            key = "/".join("%s:%s" % kv for kv in attrs.items())
            if key not in groups:
                groups[key] = dict(attrs, params=[])
            groups[key]["params"].append(par)
    return list(groups.values())
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def load_weights(module, weights_file, prefix_removed="", strict=True):
    """Load a weights file into the given module.

    Args:
        module: The module to receive the weights.
        weights_file: Path to a ``.pkl`` (pickled dict of arrays) or a torch
            checkpoint. A falsy value makes this function a no-op.
        prefix_removed: Optional key prefix stripped before loading.
        strict: Whether to require an exact key match (see ``load_state_dict``).
    """
    if not weights_file:
        return
    if weights_file.endswith(".pkl"):
        with open(weights_file, "rb") as f:
            # NOTE(review): unpickling (and weights_only=False below) executes
            # arbitrary code — only load trusted checkpoint files.
            state_dict = pickle.load(f)
        for k, v in state_dict.items():
            state_dict[k] = torch.as_tensor(v)
    else:
        state_dict = torch.load(weights_file, map_location="cpu", weights_only=False)
    if prefix_removed:
        new_state_dict = type(state_dict)()
        for k in list(state_dict.keys()):
            if k.startswith(prefix_removed):
                # Strip only the leading prefix; str.replace would also drop
                # later occurrences of the prefix inside the key (a bug for
                # keys like "net.net.weight").
                new_state_dict[k[len(prefix_removed):]] = state_dict.pop(k)
        state_dict = new_state_dict
    module.load_state_dict(state_dict, strict=strict)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def manual_seed(seed, device_and_seed=None):
    """Set the cpu and device random seed."""
    torch.manual_seed(seed)
    if device_and_seed is None:
        return
    index, dev_seed = device_and_seed
    dev_type = get_device(index).type
    np.random.seed(dev_seed)
    # CUDA/MPS keep a separate RNG state that must be seeded explicitly.
    if dev_type in ("cuda", "mps"):
        getattr(torch, dev_type).manual_seed(dev_seed)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def synchronize_device(device):
    """Synchronize the computation of device."""
    # CPU work is synchronous already; only accelerators need an explicit sync.
    if device.type not in ("cuda", "mps"):
        return
    getattr(torch, device.type).synchronize(device)
|