Text-to-Video
SelfForcing
Jashan887 committed on
Commit 7029662 · verified · 1 Parent(s): 6319da8

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+vidprom_filtered_extended.txt filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,99 @@
---
library_name: self-forcing
license: apache-2.0
pipeline_tag: text-to-video
---
<p align="center"><h1 align="center">
Self Forcing: Bridging the Train-Test Gap in Autoregressive Video Diffusion
</h1>
</p>

<p align="center">
<h3 align="center"><a href="https://arxiv.org/abs/2506.08009">Paper</a> | <a href="https://self-forcing.github.io">Website</a> | <a href="https://huggingface.co/gdhe17/Self-Forcing/tree/main">Models (HuggingFace)</a> | <a href="https://github.com/guandeh17/Self-Forcing">Code</a></h3>
</p>

---

Self Forcing trains autoregressive video diffusion models by **simulating the inference process during training**, performing autoregressive rollout with KV caching. It resolves the train-test distribution mismatch and enables **real-time, streaming video generation on a single RTX 4090** while matching the quality of state-of-the-art diffusion models.
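The rollout-with-KV-caching idea can be sketched in a few lines of Python. This is a toy stand-in of our own (integers replace frames and a running sum replaces attention); only the per-step cache behavior mirrors the real model:

```python
# Toy sketch of autoregressive rollout with a KV cache (illustrative only;
# the real model caches transformer key/value tensors per frame chunk).

def step(new_token, kv_cache):
    """Generate the next token attending over all cached states.

    A "token" here is just an int and "attention" is a sum; the point is
    that only the new token's state is computed each step, while past
    states are reused from the cache.
    """
    kv_cache.append(new_token)       # cache the new state exactly once
    return sum(kv_cache) % 97        # "attend" over the cached history

def rollout(start_token, num_steps):
    kv_cache = []                    # grows by one entry per step
    token = start_token
    outputs = []
    for _ in range(num_steps):
        token = step(token, kv_cache)  # O(1) new work per step
        outputs.append(token)
    return outputs, len(kv_cache)

outs, cache_len = rollout(start_token=3, num_steps=5)
print(outs, cache_len)  # the cache holds one entry per generated step
```

Training on such rollouts (rather than on ground-truth context) is what closes the train-test gap: the model learns under the same cached, self-generated history it will see at inference time.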

---

## Requirements
We tested this repo on the following setup:
* Nvidia GPU with at least 24 GB memory (RTX 4090, A100, and H100 are tested).
* Linux operating system.
* 64 GB RAM.

Other hardware setups may also work but have not been tested.

## Installation
Create a conda environment and install dependencies:
```
conda create -n self_forcing python=3.10 -y
conda activate self_forcing
pip install -r requirements.txt
pip install flash-attn --no-build-isolation
python setup.py develop
```

## Quick Start
### Download checkpoints
```
huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir-use-symlinks False --local-dir wan_models/Wan2.1-T2V-1.3B
huggingface-cli download gdhe17/Self-Forcing checkpoints/self_forcing_dmd.pt --local-dir .
```

### GUI demo
```
python demo.py
```
Note:
* **Our model works better with long, detailed prompts** since it is trained with such prompts. We will integrate prompt extension into the codebase (similar to [Wan2.1](https://github.com/Wan-Video/Wan2.1/tree/main?tab=readme-ov-file#2-using-prompt-extention)) in the future. For now, we recommend using a third-party LLM (such as GPT-4o) to extend your prompt before providing it to the model.
* You may want to adjust the FPS so playback is smooth on your device.
* Speed can be improved by enabling `torch.compile`, [TAEHV-VAE](https://github.com/madebyollin/taehv/), or FP8 Linear layers, although the latter two may sacrifice quality. We recommend enabling `torch.compile` if possible, and TAEHV-VAE if further speedup is needed.
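Until prompt extension is integrated, the pre-processing step can be sketched as follows. The template text and function name are our own illustration (not part of this repo); the filled-in request would be sent to a third-party LLM and its answer used as the text-to-video prompt:

```python
# Hypothetical prompt-extension template (the repo does not ship one yet).
# The returned string is what you would send to an LLM such as GPT-4o;
# the LLM's rewritten paragraph then becomes the model's input prompt.

EXTENSION_TEMPLATE = (
    "Rewrite the following video prompt as a single detailed paragraph. "
    "Describe the subject, motion, camera, lighting, and style explicitly. "
    "Prompt: {prompt}"
)

def build_extension_request(prompt: str) -> str:
    """Build the instruction string for the prompt-extension LLM."""
    return EXTENSION_TEMPLATE.format(prompt=prompt.strip())

request = build_extension_request("a cat surfing a wave")
print(request)
```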

### CLI Inference
Example inference script using the chunk-wise autoregressive checkpoint trained with DMD:
```
python inference.py \
    --config_path configs/self_forcing_dmd.yaml \
    --output_folder videos/self_forcing_dmd \
    --checkpoint_path checkpoints/self_forcing_dmd.pt \
    --data_path prompts/MovieGenVideoBench_extended.txt \
    --use_ema
```
Other config files and their corresponding checkpoints can be found in the [configs](configs) folder and our [Hugging Face repo](https://huggingface.co/gdhe17/Self-Forcing/tree/main/checkpoints).

## Training
### Download text prompts and ODE-initialized checkpoint
```
huggingface-cli download gdhe17/Self-Forcing checkpoints/ode_init.pt --local-dir .
huggingface-cli download gdhe17/Self-Forcing vidprom_filtered_extended.txt --local-dir prompts
```
Note: Our training algorithm (except for the GAN version) is data-free (**no video data is needed**). For now, we directly provide the ODE initialization checkpoint; instructions for performing ODE initialization yourself (identical to the process described in the [CausVid](https://github.com/tianweiy/CausVid) repo) will be added in the future.

### Self Forcing Training with DMD
```
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
    --rdzv_backend=c10d \
    --rdzv_endpoint $MASTER_ADDR \
    train.py \
    --config_path configs/self_forcing_dmd.yaml \
    --logdir logs/self_forcing_dmd \
    --disable-wandb
```
Our training run uses 600 iterations and completes in under 2 hours on 64 H100 GPUs. With gradient accumulation, it should be possible to reproduce the results in under 16 hours on 8 H100 GPUs.
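The gradient-accumulation argument can be checked with a small sketch (our illustration, not the repo's trainer): averaging per-micro-batch gradients of a mean loss equals the gradient of one combined batch, so 8 GPUs taking 8 accumulation steps see the same effective batch as 64 GPUs taking one.

```python
# Gradient accumulation sketch: averaging gradients over k equal-sized
# micro-batches (with no optimizer step in between) reproduces the
# gradient of one k-times-larger batch.

def grad(w, batch):
    """Gradient of the mean squared loss 0.5*(w - x)^2 over a batch."""
    return sum(w - x for x in batch) / len(batch)

def accumulated_grad(w, micro_batches):
    """Average micro-batch gradients, as the optimizer step would see them."""
    g = 0.0
    for mb in micro_batches:                 # no parameter update in between
        g += grad(w, mb) / len(micro_batches)
    return g

w = 0.5
data = [0.1, 0.4, 0.9, 1.2, 0.3, 0.7, 1.0, 0.2]
micro = [data[i:i + 2] for i in range(0, len(data), 2)]  # 4 micro-batches

print(accumulated_grad(w, micro), grad(w, data))  # the two should match
```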

## Acknowledgements
This codebase is built on top of the open-source implementation of [CausVid](https://github.com/tianweiy/CausVid) by [Tianwei Yin](https://tianweiy.github.io/) and the [Wan2.1](https://github.com/Wan-Video/Wan2.1) repo.

## Citation
If you find this codebase useful for your research, please kindly cite our paper:
```
@article{huang2025selfforcing,
  title={Self Forcing: Bridging the Train-Test Gap in Autoregressive Video Diffusion},
  author={Huang, Xun and Li, Zhengqi and He, Guande and Zhou, Mingyuan and Shechtman, Eli},
  journal={arXiv preprint arXiv:2506.08009},
  year={2025}
}
```
checkpoints/ode_init.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b5396b8076ab3b920c9e4f4a2b52daa2c98c9983fb5e067ae5160fdf778dce21
+size 5676203690
checkpoints/self_forcing_10s.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c7885ec2df188ce05e63fc768aecf9a95b5672dbc64e82ea6e2a48751f67dc11
+size 11352514125
checkpoints/self_forcing_dmd.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a0413986d9734e02c09504e1520f5697ba6df731bb2f0f35577485e9cc8f56a3
+size 5676252553
checkpoints/self_forcing_gan.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4005372c17c4dfc27ceae8c951c69bc22256d0b67eede2da234b382fba1d8a2f
+size 5676252553
checkpoints/self_forcing_sid.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bb354a0af5775c6316a6805db8ace9b6093190d4d07f93fd83f3a88c5ad49b19
+size 5676252553
checkpoints/self_forcing_sid_v2.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b98050d304569c1880e46f9b05d375d8b1370882aa72f0d53b068f3f03909b40
+size 5676255365
vidprom_filtered_extended.txt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7896742f468bc8aef9e4547424d1ce0a951acdb2a82233790155401a99bf5aa5
+size 145875068