# OmniAlpha: A Sequence-to-Sequence Framework for Unified Multi-Task RGBA Generation

<p align="center">
  <a href="https://github.com/Longin-Yu/OmniAlpha"><img src="https://img.shields.io/badge/GitHub-OmniAlpha-181717.svg?logo=github" alt="GitHub"></a>
  <a href="https://arxiv.org/abs/2511.20211"><img src="https://img.shields.io/badge/arXiv-2511.20211-b31b1b.svg" alt="arXiv"></a>
  <a href="https://huggingface.co/Longin-Yu/OmniAlpha"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-yellow" alt="Hugging Face"></a>
</p>

---

**This is the official repository for "[OmniAlpha: A Sequence-to-Sequence Framework for Unified Multi-Task RGBA Generation](https://arxiv.org/abs/2511.20211)".**

![examples](assets/examples_01.png)

---

## 📂 Project Structure

```
.
├── alpha/                            # Core package
│   ├── data.py                       # Dataset loading & preprocessing
│   ├── args.py                       # Argument definitions
│   ├── inplace.py                    # In-place operations
│   ├── pipelines/                    # Inference pipelines (Qwen-Image-Edit)
│   ├── vae/                          # AlphaVAE model & losses
│   ├── grpo/                         # GRPO (RL) training utilities
│   └── utils/                        # Utility functions
├── configs/                          # Configuration files
│   ├── datasets.*.jsonc              # Dataset configurations
│   ├── deepspeed/                    # DeepSpeed configs (ZeRO-1/3)
│   ├── experiments/                  # VAE experiment configs
│   └── accelerate.yaml               # Accelerate config
├── scripts/                          # Bash scripts for training/inference
│   ├── train_qwen_image.sh           # Single-node training (Accelerate)
│   ├── train_qwen_image_torchrun.sh  # Multi-node training (torchrun)
│   ├── vae_convert.sh                # VAE conversion script
│   ├── vae_train.sh                  # VAE fine-tuning script
│   ├── infer.sh                      # Inference script
│   ├── demo.sh                       # Gradio demo script
│   └── rl/                           # GRPO reinforcement learning scripts
├── tasks/                            # Python/Jupyter task scripts
│   ├── diffusion/                    # Diffusion training & inference
│   ├── vae/                          # VAE fine-tuning, conversion & inference
│   ├── rl/                           # GRPO RL training & preprocessing
│   └── demo/                         # Gradio demo application
└── pyproject.toml                    # Package definitions & dependencies
```

## 📦 Installation

### Step 1. Create a Conda Environment

```bash
conda create -n OmniAlpha python=3.10
conda activate OmniAlpha
```

### Step 2. Install OmniAlpha

First, clone this repository and `cd OmniAlpha`. Then:

```bash
# Install OmniAlpha and all dependencies
pip install -e .
```

## ⚙️ Environment Variables

All scripts use environment variables to specify model and data paths. Set these before running any script:

```bash
# Model paths
export PRETRAINED_MODEL="Qwen/Qwen-Image-Edit-2509"                # Hugging Face model ID or local path
export VAE_MODEL_PATH="/path/to/vae/checkpoint"                    # Path to AlphaVAE checkpoint
export LORA_PATH="/path/to/lora/pytorch_lora_weights.safetensors"  # Path to LoRA weights

# Data paths
export DATA_ROOT="/path/to/datasets"                               # Root directory for all datasets
```

If these variables are not set, the scripts fall back to placeholder paths, which you will then need to edit manually.
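
For reference, the fallback pattern the scripts rely on can be mimicked in Python; a minimal sketch (the placeholder defaults below are illustrative, not the scripts' actual values):

```python
import os

# Minimal sketch of the fallback behavior: read each path from the
# environment, falling back to a placeholder that must be edited manually.
PRETRAINED_MODEL = os.environ.get("PRETRAINED_MODEL", "Qwen/Qwen-Image-Edit-2509")
VAE_MODEL_PATH = os.environ.get("VAE_MODEL_PATH", "/path/to/vae/checkpoint")
LORA_PATH = os.environ.get("LORA_PATH", "/path/to/lora/pytorch_lora_weights.safetensors")
DATA_ROOT = os.environ.get("DATA_ROOT", "/path/to/datasets")

# Warn when a path is still a placeholder.
for name, value in [("VAE_MODEL_PATH", VAE_MODEL_PATH), ("LORA_PATH", LORA_PATH), ("DATA_ROOT", DATA_ROOT)]:
    if value.startswith("/path/to"):
        print(f"Warning: {name} is still a placeholder; set it before running.")
```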

## 📄 Data Preparation

> Please refer to `configs/datasets.demo.jsonc` for dataset configuration examples.
> Each dataset entry consists of two required fields:
>
> * `data_path`: Path to the JSONL annotation file.
> * `image_dir`: Root directory for the dataset images.

### Dataset Format

The annotation file (`data_path`) should be a JSONL file with the following structure. Both `input_images` and `output_images` must be **relative paths** within `image_dir`:

```jsonl
{"id": "case_0", "prompt": "Vintage camera next to a brown glass bottle.", "input_images": ["images_512/case_0/base.png"], "output_images": ["images_512/case_0/00.png"]}
{"id": "case_1", "prompt": "A vintage-style globe with a map of North and South America, mounted on a black stand.;Antique key with ornate design, attached to a chain.", "input_images": ["images_512/case_1/base.png"], "output_images": ["images_512/case_1/00.png", "images_512/case_1/01.png"]}
...
```
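
Before training, it can help to sanity-check an annotation file against this format. Below is a small standalone sketch (independent of the repo's own loader in `alpha/data.py`); the function name is ours:

```python
import json
from pathlib import Path

def check_annotations(data_path: str, image_dir: str) -> None:
    """Sanity-check a JSONL annotation file: required keys are present and
    every referenced image resolves to an existing file under image_dir."""
    root = Path(image_dir)
    with open(data_path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            record = json.loads(line)
            for key in ("id", "prompt", "input_images", "output_images"):
                assert key in record, f"line {lineno}: missing key {key!r}"
            for rel in record["input_images"] + record["output_images"]:
                assert not Path(rel).is_absolute(), f"line {lineno}: {rel!r} is not relative"
                assert (root / rel).is_file(), f"line {lineno}: {root / rel} not found"

check_annotations("/path/to/datasets/my_dataset/annotations.jsonl",
                  "/path/to/datasets/my_dataset")
```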

### Dataset Configuration

Create a `.jsonc` config file under `configs/` to define datasets and splits:

```jsonc
{
  "datasets": {
    "my_dataset": {
      "data_path": "/path/to/datasets/my_dataset/annotations.jsonl",
      "image_dir": "/path/to/datasets/my_dataset"
    }
  },
  "splits": {
    "train": [{"dataset": "my_dataset", "ends": -50}],
    "valid": [{"dataset": "my_dataset", "starts": -50}]
  }
}
```
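
The `starts`/`ends` fields appear to behave like Python slice bounds, so the config above would train on all but the last 50 records and validate on the last 50. A sketch of that assumed interpretation, not the repo's actual split logic:

```python
# Assumed semantics: each split entry slices the dataset's record list as
# records[starts:ends], with omitted bounds left open. Under that reading,
# "ends": -50 keeps all but the last 50 records; "starts": -50 keeps the last 50.
def apply_split(records: list, starts: int | None = None, ends: int | None = None) -> list:
    return records[starts:ends]

records = list(range(1000))               # stand-in for loaded JSONL records
train = apply_split(records, ends=-50)    # 950 records
valid = apply_split(records, starts=-50)  # 50 records
assert len(train) == 950 and len(valid) == 50
```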

## 🔽 Model Download

Pretrained model checkpoints are available on [Hugging Face](https://huggingface.co/Longin-Yu/OmniAlpha).
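
To fetch the checkpoints programmatically instead of cloning, `huggingface_hub`'s `snapshot_download` works; the `local_dir` below is an illustrative choice:

```python
from huggingface_hub import snapshot_download

# Mirrors the model repo locally; the resulting subfolders
# (e.g. vae_rgba/, sft/, rl/) can then be pointed to via
# VAE_MODEL_PATH and LORA_PATH.
local_dir = snapshot_download(
    repo_id="Longin-Yu/OmniAlpha",
    local_dir="/path/to/models/OmniAlpha",  # illustrative path
)
print(f"Checkpoints downloaded to {local_dir}")
```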

## 🚀 Inference

You can use the provided script to run inference with the pretrained models.

1. **Configure**: Set the environment variables (`PRETRAINED_MODEL`, `VAE_MODEL_PATH`, `LORA_PATH`) or edit `scripts/infer.sh` directly.
2. **Execute**:

```bash
bash scripts/infer.sh
```
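
For orientation, the setup the script performs can be approximated in plain `diffusers` terms. This is a hedged sketch only; `infer.sh` actually drives the repo's customized pipeline classes under `alpha/pipelines/` together with the AlphaVAE:

```python
import os
import torch
from diffusers import DiffusionPipeline

# Rough approximation of the setup; the repo's own pipeline classes
# (and the RGBA VAE swap) are what infer.sh really uses.
pipe = DiffusionPipeline.from_pretrained(
    os.environ["PRETRAINED_MODEL"],  # e.g. Qwen/Qwen-Image-Edit-2509
    torch_dtype=torch.bfloat16,
)
pipe.load_lora_weights(os.environ["LORA_PATH"])  # OmniAlpha LoRA weights
pipe.to("cuda")
```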

## 🎬 Demo

We provide a Gradio-based demo for interactive multi-task RGBA generation and editing.

### Supported Tasks

- `t2i` — Text-to-RGBA image generation
- `ObjectClear` — Object removal
- `automatting` — Automatic matting
- `refmatting` — Referential matting
- `layerdecompose` — Layer decomposition

### Execute

```bash
# Set model paths
export PRETRAINED_MODEL="Qwen/Qwen-Image-Edit-2509"
export VAE_MODEL_PATH="/path/to/models/OmniAlpha/rgba_vae"
export LORA_PATH="/path/to/models/OmniAlpha/lora/pytorch_lora_weights.safetensors"

# Launch demo
bash scripts/demo.sh
```

### Example Assets

Example images for the demo are located in `tasks/demo/omnialpha/`.

## 🏋️ Training

### AlphaVAE Fine-tuning

```bash
# Step 1: Convert the base VAE to RGBA format
bash scripts/vae_convert.sh

# Step 2: Fine-tune the AlphaVAE
bash scripts/vae_train.sh
```

### LoRA Training (Single-Node with Accelerate)

```bash
bash scripts/train_qwen_image.sh
```

### LoRA Training (Multi-Node with torchrun)

For distributed training across multiple nodes:

```bash
# Set distributed training variables
export MASTER_ADDR="your_master_ip"
export MASTER_PORT=29500
export NNODES=2
export NPROC_PER_NODE=8
export MACHINE_RANK=0       # 0 for the master node; 1, 2, ... for worker nodes
export VERSION="omnialpha"  # Matches configs/datasets.<VERSION>.jsonc

bash scripts/train_qwen_image_torchrun.sh
```
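
For orientation: `torchrun` consumes `NNODES`, `NPROC_PER_NODE`, and `MACHINE_RANK` as launcher arguments and derives each worker's `RANK`, `WORLD_SIZE`, and `LOCAL_RANK` from them, with `MASTER_ADDR`/`MASTER_PORT` used for rendezvous. Inside a worker process, initialization then typically looks like:

```python
import os
import torch.distributed as dist

# torchrun sets RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR, and MASTER_PORT
# for every spawned process; "env://" tells PyTorch to read them.
dist.init_process_group(backend="nccl", init_method="env://")
print(f"rank {dist.get_rank()} / {dist.get_world_size()} "
      f"(local rank {os.environ['LOCAL_RANK']})")
```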

### GRPO Reinforcement Learning

For RL-based fine-tuning:

```bash
# Run GRPO training (single-node)
bash scripts/rl/train_grpo.sh

# Or, for multi-node training:
bash scripts/rl/train_grpo_torchrun.sh
```

## 🔗 Contact

Feel free to reach out via email at longinyh@gmail.com. You can also open an issue if you have ideas to share or would like to contribute data for training future models.

## Citation

```bibtex
@article{yu2025omnialpha0,
  title   = {OmniAlpha: A Sequence-to-Sequence Framework for Unified Multi-Task RGBA Generation},
  author  = {Hao Yu and Jiabo Zhan and Zile Wang and Jinglin Wang and Huaisong Zhang and Hongyu Li and Xinrui Chen and Yongxian Wei and Chun Yuan},
  year    = {2025},
  journal = {arXiv preprint arXiv:2511.20211}
}

@misc{wang2025alphavaeunifiedendtoendrgba,
  title         = {AlphaVAE: Unified End-to-End RGBA Image Reconstruction and Generation with Alpha-Aware Representation Learning},
  author        = {Zile Wang and Hao Yu and Jiabo Zhan and Chun Yuan},
  year          = {2025},
  eprint        = {2507.09308},
  archivePrefix = {arXiv},
  primaryClass  = {cs.CV},
  url           = {https://arxiv.org/abs/2507.09308}
}
```

`rl/pytorch_lora_weights.safetensors` ADDED

```
version https://git-lfs.github.com/spec/v1
oid sha256:7c975ed2a989afe87f0fffd4176798a993cf6ce698a97559eb355e232d6baef5
size 1510080864
```

`sft/pytorch_lora_weights.safetensors` ADDED

```
version https://git-lfs.github.com/spec/v1
oid sha256:e0cc86bd930efbc53cbc6b807cb1a9eecb9b9e3640ba88d2df2bd9877d6d4b36
size 1510080864
```

`vae_rgba/config.json` ADDED

```json
{
  "_class_name": "AutoencoderKLQwenImageAlpha",
  "_diffusers_version": "0.36.0.dev0",
  "_name_or_path": "Qwen/Qwen-Image-Edit-2509/vae_rgba",
  "attn_scales": [],
  "base_dim": 96,
  "dim_mult": [1, 2, 4, 4],
  "dropout": 0.0,
  "in_channels": 4,
  "latents_mean": [
    -0.7571, -0.7089, -0.9113, 0.1075,
    -0.1745, 0.9653, -0.1517, 1.5508,
    0.4134, -0.0715, 0.5517, -0.3632,
    -0.1922, -0.9497, 0.2503, -0.2921
  ],
  "latents_std": [
    2.8184, 1.4541, 2.3275, 2.6558,
    1.2196, 1.7708, 2.6052, 2.0743,
    3.2687, 2.1526, 2.8652, 1.5579,
    1.6382, 1.1253, 2.8251, 1.916
  ],
  "num_res_blocks": 2,
  "out_channels": 4,
  "temperal_downsample": [false, true, true],
  "z_dim": 16
}
```

`vae_rgba/diffusion_pytorch_model.safetensors` ADDED

```
version https://git-lfs.github.com/spec/v1
oid sha256:b72ef5e35bab4a743d74d64d4fedd1dfdbe219dea5f57961f4abb5fa073ff0d5
size 253817336
```