Enzo8930302 committed on
Commit
80b58c8
·
verified ·
1 Parent(s): f0598d9

Upload folder using huggingface_hub

Browse files
Files changed (22) hide show
  1. .gitignore +73 -0
  2. LICENSE +21 -0
  3. README.md +182 -3
  4. SETUP_GUIDE.md +262 -0
  5. app.py +305 -0
  6. bytedream/__init__.py +21 -0
  7. bytedream/generator.py +317 -0
  8. bytedream/model.py +582 -0
  9. bytedream/pipeline.py +312 -0
  10. bytedream/scheduler.py +273 -0
  11. bytedream/utils.py +398 -0
  12. config.yaml +81 -0
  13. environment.yml +25 -0
  14. examples.py +316 -0
  15. infer.py +150 -0
  16. main.py +278 -0
  17. prepare_dataset.py +287 -0
  18. publish_to_hf.py +30 -0
  19. quick_start.py +124 -0
  20. requirements.txt +16 -0
  21. train.py +500 -0
  22. upload_to_hf.py +420 -0
.gitignore ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Virtual environments
24
+ venv/
25
+ env/
26
+ ENV/
27
+ .venv
28
+
29
+ # IDE
30
+ .vscode/
31
+ .idea/
32
+ *.swp
33
+ *.swo
34
+ *~
35
+
36
+ # Jupyter Notebook
37
+ .ipynb_checkpoints
38
+
39
+ # PyTorch
40
+ *.pth
41
+ *.onnx
42
+
43
+ # Model checkpoints
44
+ models/
45
+ checkpoints/
46
+ *.bin
47
+ *.safetensors
48
+
49
+ # Outputs
50
+ outputs/
51
+ demo_outputs/
52
+ *.png
53
+ *.jpg
54
+ *.jpeg
55
+ *.webp
56
+
57
+ # Logs
58
+ logs/
59
+ *.log
60
+
61
+ # OS
62
+ .DS_Store
63
+ Thumbs.db
64
+ desktop.ini
65
+
66
+ # Temporary files
67
+ tmp/
68
+ temp/
69
+ *.tmp
70
+
71
+ # Hugging Face cache
72
+ .huggingface/
73
+ .cache/
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Byte Dream
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,182 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte Dream - AI Image Generation Model
2
+
3
+ ## Overview
4
+ Byte Dream is a robust, production-ready text-to-image diffusion model optimized for CPU inference. This model uses advanced latent diffusion architecture to generate high-quality images from text prompts.
5
+
6
+ ## Features
7
+ - **CPU Optimized**: Runs efficiently on CPU without GPU requirement
8
+ - **High Quality**: Generates 512x512 and higher resolution images
9
+ - **Fast Inference**: Optimized for speed with quality preservation
10
+ - **Hugging Face Ready**: Easy deployment to Hugging Face Spaces
11
+ - **Flexible**: Supports various sampling methods and customization
12
+
13
+ ## Installation
14
+
15
+ ### Using pip
16
+ ```bash
17
+ pip install -r requirements.txt
18
+ ```
19
+
20
+ ### Using conda
21
+ ```bash
22
+ conda env create -f environment.yml
23
+ conda activate bytedream
24
+ ```
25
+
26
+ ## Usage
27
+
28
+ ### Basic Image Generation
29
+ ```python
30
+ from bytedream import ByteDreamGenerator
31
+
32
+ # Initialize generator
33
+ generator = ByteDreamGenerator()
34
+
35
+ # Generate image from prompt
36
+ image = generator.generate(
37
+ prompt="A beautiful sunset over mountains, digital art",
38
+ num_inference_steps=50,
39
+ guidance_scale=7.5
40
+ )
41
+
42
+ # Save image
43
+ image.save("output.png")
44
+ ```
45
+
46
+ ### Advanced Usage
47
+ ```python
48
+ from bytedream import ByteDreamGenerator
49
+
50
+ generator = ByteDreamGenerator(model_path="models/bytedream")
51
+
52
+ # Generate with custom parameters
53
+ image = generator.generate(
54
+ prompt="Cyberpunk city at night, neon lights, futuristic",
55
+ negative_prompt="blurry, low quality, distorted",
56
+ width=768,
57
+ height=768,
58
+ num_inference_steps=100,
59
+ guidance_scale=9.0,
60
+ seed=42
61
+ )
62
+
63
+ image.save("cyberpunk_city.png")
64
+ ```
65
+
66
+ ### Command Line Interface
67
+ ```bash
68
+ # Generate image from command line
69
+ python infer.py --prompt "A dragon flying over castle" --output dragon.png
70
+
71
+ # With advanced options
72
+ python infer.py --prompt "Fantasy landscape" --negative "ugly, blurry" --steps 75 --guidance 8.0
73
+ ```
74
+
75
+ ### Gradio Web Interface
76
+ ```bash
77
+ python app.py
78
+ ```
79
+
80
+ ## Model Architecture
81
+
82
+ Byte Dream uses a latent diffusion model with:
83
+ - **Text Encoder**: CLIP-based text understanding
84
+ - **UNet**: Noise prediction network with cross-attention
85
+ - **VAE**: Variational Autoencoder for image compression
86
+ - **Scheduler**: Advanced DDIM/PNDM sampling
87
+
88
+ ## Training
89
+
90
+ ### Prepare Dataset
91
+ ```bash
92
+ python prepare_dataset.py --data_dir ./dataset --output_dir ./processed_data
93
+ ```
94
+
95
+ ### Train Model
96
+ ```bash
97
+ python train.py \
98
+ --train_data ./processed_data \
99
+ --output_dir ./models/bytedream \
100
+ --epochs 100 \
101
+ --batch_size 4 \
102
+ --learning_rate 1e-5
103
+ ```
104
+
105
+ ## Hugging Face Deployment
106
+
107
+ ### Upload to Hugging Face
108
+ ```bash
109
+ python upload_to_hf.py --model_id your_username/bytedream
110
+ ```
111
+
112
+ ### Deploy to Spaces
113
+ 1. Create new Space on Hugging Face
114
+ 2. Select Gradio SDK
115
+ 3. Upload all files
116
+ 4. Configure CPU hardware
117
+ 5. Deploy automatically
118
+
119
+ ## Configuration
120
+
121
+ Edit `config.yaml` for custom settings:
122
+ - Model dimensions
123
+ - Sampling parameters
124
+ - Training hyperparameters
125
+ - CPU optimization settings
126
+
127
+ ## Performance Optimization
128
+
129
+ ### CPU Optimization
130
+ - OpenVINO integration available
131
+ - ONNX runtime support
132
+ - Mixed precision (FP16/FP32)
133
+ - Batch processing
134
+
135
+ ### Memory Management
136
+ - Gradient checkpointing
137
+ - Model offloading
138
+ - Progressive generation
139
+
140
+ ## File Structure
141
+ ```
142
+ Byte Dream/
143
+ ├── bytedream/ # Core package
144
+ │ ├── __init__.py
145
+ │ ├── model.py # Model architecture
146
+ │ ├── pipeline.py # Generation pipeline
147
+ │ ├── scheduler.py # Diffusion scheduler
148
+ │ └── utils.py # Utilities
149
+ ├── train.py # Training script
150
+ ├── infer.py # Inference script
151
+ ├── app.py # Gradio web interface
152
+ ├── config.yaml # Configuration
153
+ ├── requirements.txt # Dependencies
154
+ └── README.md # This file
155
+ ```
156
+
157
+ ## Examples
158
+
159
+ Generate various types of images:
160
+ - Digital art and illustrations
161
+ - Photorealistic scenes
162
+ - Abstract concepts
163
+ - Character designs
164
+ - Landscapes and environments
165
+
166
+ ## License
167
+
168
+ MIT License - See LICENSE file for details
169
+
170
+ ## Citation
171
+
172
+ If you use Byte Dream in your research:
173
+ ```bibtex
174
+ @software{bytedream2024,
175
+ title={Byte Dream: CPU-Optimized Text-to-Image Generation},
176
+ year={2024}
177
+ }
178
+ ```
179
+
180
+ ## Support
181
+
182
+ For issues and questions, please open a GitHub issue or contact the maintainers.
SETUP_GUIDE.md ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte Dream - Setup Guide
2
+
3
+ ## Quick Start (Windows)
4
+
5
+ ### 1. Install Dependencies
6
+
7
+ #### Option A: Using pip (Recommended)
8
+ ```cmd
9
+ cd "c:\Users\Enzo\Documents\Byte Dream"
10
+ pip install -r requirements.txt
11
+ ```
12
+
13
+ #### Option B: Using conda
14
+ ```cmd
15
+ cd "c:\Users\Enzo\Documents\Byte Dream"
16
+ conda env create -f environment.yml
17
+ conda activate bytedream
18
+ ```
19
+
20
+ ### 2. Verify Installation
21
+ ```cmd
22
+ python quick_start.py
23
+ ```
24
+
25
+ This will check if all dependencies are installed and test the model.
26
+
27
+ ### 3. Generate Your First Image
28
+
29
+ #### Command Line
30
+ ```cmd
31
+ python infer.py --prompt "A beautiful sunset over mountains, digital art" --output sunset.png
32
+ ```
33
+
34
+ #### Web Interface
35
+ ```cmd
36
+ python app.py
37
+ ```
38
+ Then open http://localhost:7860 in your browser.
39
+
40
+ #### Python Script
41
+ ```python
42
+ from bytedream import ByteDreamGenerator
43
+
44
+ generator = ByteDreamGenerator()
45
+ image = generator.generate(
46
+ prompt="A cyberpunk city at night with neon lights",
47
+ num_inference_steps=50,
48
+ guidance_scale=7.5
49
+ )
50
+ image.save("cyberpunk_city.png")
51
+ ```
52
+
53
+ ## Model Training
54
+
55
+ ### Prepare Your Dataset
56
+
57
+ 1. Collect images in a folder (JPG, PNG formats)
58
+ 2. Optionally add .txt files with captions for each image
59
+ 3. Run preparation script:
60
+
61
+ ```cmd
62
+ python prepare_dataset.py --input ./my_images --output ./processed_data --size 512
63
+ ```
64
+
65
+ ### Train the Model
66
+
67
+ ```cmd
68
+ python train.py --train_data ./processed_data --output_dir ./models/bytedream --epochs 100 --batch_size 4
69
+ ```
70
+
71
+ Training time depends on:
72
+ - Dataset size
73
+ - Number of epochs
74
+ - CPU speed (expect several hours to days for CPU training)
75
+
76
+ ## Hugging Face Deployment
77
+
78
+ ### Upload to Hugging Face Hub
79
+
80
+ 1. Get your Hugging Face token from https://huggingface.co/settings/tokens
81
+ 2. Upload model:
82
+
83
+ ```cmd
84
+ python upload_to_hf.py --model_path ./models/bytedream --repo_id your_username/bytedream --token YOUR_TOKEN
85
+ ```
86
+
87
+ ### Deploy to Spaces
88
+
89
+ 1. Create Gradio app file (already included as `app.py`)
90
+ 2. Go to https://huggingface.co/spaces
91
+ 3. Click "Create new Space"
92
+ 4. Choose Gradio SDK
93
+ 5. Upload all project files
94
+ 6. Select CPU hardware (CPU Basic or similar)
95
+ 7. Deploy!
96
+
97
+ ## File Structure
98
+
99
+ ```
100
+ Byte Dream/
101
+ ├── bytedream/ # Core package
102
+ │ ├── __init__.py # Package initialization
103
+ │ ├── model.py # Neural network architectures
104
+ │ ├── pipeline.py # Generation pipeline
105
+ │ ├── scheduler.py # Diffusion scheduler
106
+ │ ├── generator.py # Main generator class
107
+ │ └── utils.py # Utility functions
108
+ ├── train.py # Training script
109
+ ├── infer.py # Command-line inference
110
+ ├── app.py # Gradio web interface
111
+ ├── main.py # High-level application API
112
+ ├── prepare_dataset.py # Dataset preparation
113
+ ├── upload_to_hf.py # Hugging Face upload
114
+ ├── quick_start.py # Quick start guide
115
+ ├── config.yaml # Configuration
116
+ ├── requirements.txt # Python dependencies
117
+ ├── environment.yml # Conda environment
118
+ ├── README.md # Documentation
119
+ └── LICENSE # MIT License
120
+ ```
121
+
122
+ ## Usage Examples
123
+
124
+ ### Basic Generation
125
+ ```cmd
126
+ python infer.py -p "A dragon flying over castle" -o dragon.png
127
+ ```
128
+
129
+ ### Advanced Parameters
130
+ ```cmd
131
+ python infer.py -p "Fantasy landscape" -n "ugly, blurry" -W 768 -H 768 -s 75 -g 8.0 --seed 42
132
+ ```
133
+
134
+ ### Batch Generation (Python)
135
+ ```python
136
+ from bytedream import ByteDreamGenerator
137
+
138
+ generator = ByteDreamGenerator()
139
+
140
+ prompts = [
141
+ "Sunset beach, palm trees, tropical paradise",
142
+ "Mountain landscape, snow peaks, alpine lake",
143
+ "Forest path, sunlight filtering through trees"
144
+ ]
145
+
146
+ images = generator.generate_batch(
147
+ prompts=prompts,
148
+ width=512,
149
+ height=512,
150
+ num_inference_steps=50
151
+ )
152
+
153
+ for i, img in enumerate(images):
154
+ img.save(f"landscape_{i}.png")
155
+ ```
156
+
157
+ ## Performance Optimization
158
+
159
+ ### CPU Optimization
160
+ The model is already optimized for CPU, but you can:
161
+
162
+ 1. Increase threads in `config.yaml`:
163
+ ```yaml
164
+ cpu_optimization:
165
+ threads: 8 # Set to number of CPU cores
166
+ precision: fp32
167
+ ```
168
+
169
+ 2. Use fewer inference steps for faster generation:
170
+ ```cmd
171
+ python infer.py -p "Quick preview" -s 20
172
+ ```
173
+
174
+ 3. Generate smaller images:
175
+ ```cmd
176
+ python infer.py -p "Small image" -W 256 -H 256
177
+ ```
178
+
179
+ ### Memory Management
180
+ For systems with limited RAM:
181
+
182
+ 1. Enable memory efficient mode (already default)
183
+ 2. Generate one image at a time
184
+ 3. Restart Python between batch generations
185
+
186
+ ## Troubleshooting
187
+
188
+ ### Import Errors
189
+ If you get import errors:
190
+ ```cmd
191
+ pip install --upgrade torch transformers diffusers
192
+ ```
193
+
194
+ ### Memory Errors
195
+ Reduce image size or inference steps:
196
+ ```cmd
197
+ python infer.py -p "Test" -W 256 -H 256 -s 20
198
+ ```
199
+
200
+ ### Slow Generation
201
+ CPU generation is slower than GPU. Expect:
202
+ - 256x256: ~30-60 seconds
203
+ - 512x512: ~2-5 minutes
204
+ - 768x768: ~5-10 minutes
205
+
206
+ Times vary by CPU speed and number of steps.
207
+
208
+ ### Model Not Loading
209
+ The model needs trained weights. Either:
210
+ 1. Train your own model using `train.py`
211
+ 2. Download pretrained weights from Hugging Face
212
+ 3. Use Stable Diffusion weights as base
213
+
214
+ ## Tips for Better Results
215
+
216
+ ### Writing Prompts
217
+ - Be specific and descriptive
218
+ - Include style references ("digital art", "oil painting")
219
+ - Mention lighting ("dramatic lighting", "soft sunlight")
220
+ - Add quality modifiers ("highly detailed", "4K", "masterpiece")
221
+
222
+ ### Negative Prompts
223
+ Use to avoid common issues:
224
+ ```
225
+ ugly, blurry, low quality, distorted, deformed, bad anatomy, extra limbs
226
+ ```
227
+
228
+ ### Parameters
229
+ - **Steps**: 20-30 (quick), 50 (good), 75-100 (best)
230
+ - **Guidance**: 5-7 (creative), 7-9 (balanced), 9-12 (strict)
231
+ - **Resolution**: Start with 512x512, increase if needed
232
+
233
+ ## Advanced Features
234
+
235
+ ### Custom Schedulers
236
+ Edit `config.yaml` to try different schedulers:
237
+ - DDIM (default) - Fast, deterministic
238
+ - EulerDiscrete - Alternative sampling
239
+
240
+ ### Fine-tuning
241
+ Fine-tune on specific styles:
242
+ 1. Collect 50-100 images in desired style
243
+ 2. Prepare dataset
244
+ 3. Train for 50-100 epochs with low learning rate (1e-6)
245
+
246
+ ## Support
247
+
248
+ For issues and questions:
249
+ 1. Check this guide first
250
+ 2. Review README.md
251
+ 3. Check code comments
252
+ 4. Visit Hugging Face documentation
253
+
254
+ ## Updates
255
+
256
+ Check for updates and improvements:
257
+ - New model architectures
258
+ - Better CPU optimization
259
+ - Additional features
260
+ - Bug fixes
261
+
262
+ Enjoy creating with Byte Dream! 🎨
app.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Byte Dream - Gradio Web Interface
3
+ Interactive web UI for image generation
4
+ """
5
+
6
+ import gradio as gr
7
+ from bytedream.generator import ByteDreamGenerator
8
+ import torch
9
+
10
+
11
+ # Initialize generator
12
+ print("Loading Byte Dream model...")
13
+ try:
14
+ generator = ByteDreamGenerator(
15
+ model_path="./models/bytedream",
16
+ config_path="config.yaml",
17
+ device="cpu",
18
+ )
19
+ print("✓ Model loaded successfully!")
20
+ except Exception as e:
21
+ print(f"⚠ Warning: Could not load model: {e}")
22
+ print(" Please train the model or download pretrained weights.")
23
+ generator = None
24
+
25
+
26
def generate_image(
    prompt,
    negative_prompt,
    width,
    height,
    num_steps,
    guidance_scale,
    seed,
):
    """Run one text-to-image generation for the Gradio UI.

    Returns a ``(image, status)`` pair: the generated PIL image plus a
    human-readable status string, or ``(None, error_text)`` when the model
    is unavailable or generation fails.
    """
    if generator is None:
        return None, "Error: Model not loaded. Please train or download model weights."

    # The UI uses -1 as its "random seed" sentinel.
    seed_value = None if seed == -1 else seed
    # Gradio sliders/numbers may deliver floats; coerce to the types the
    # pipeline expects.
    request = {
        "prompt": prompt,
        "negative_prompt": negative_prompt or None,
        "width": int(width),
        "height": int(height),
        "num_inference_steps": int(num_steps),
        "guidance_scale": float(guidance_scale),
        "seed": seed_value,
    }

    try:
        return generator.generate(**request), "Success! ✓"
    except Exception as e:
        # Surface the failure in the UI and keep the full trace on stdout.
        print(f"Error generating image: {e}")
        import traceback
        traceback.print_exc()
        return None, f"Error: {str(e)}"
62
+
63
+
64
# ---------------------------------------------------------------------------
# Gradio interface.
# ---------------------------------------------------------------------------
with gr.Blocks(
    title="Byte Dream - AI Image Generator",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1400px !important;
    }
    #main-heading {
        text-align: center;
        margin-bottom: 20px;
    }
    .description {
        text-align: center;
        margin-bottom: 30px;
    }
    """,
) as demo:

    gr.Markdown("""
    # 🎨 Byte Dream - AI Image Generator

    ### Transform your imagination into reality with advanced AI

    Powered by state-of-the-art latent diffusion models, optimized for CPU inference.
    """)

    with gr.Row():
        # Left column: prompt entry and generation settings.
        with gr.Column(scale=1):
            gr.Markdown("### 📝 Create Your Prompt")

            prompt_input = gr.Textbox(
                label="Positive Prompt",
                placeholder="Describe the image you want to create...",
                lines=3,
                value="A beautiful sunset over mountains, digital art, highly detailed, vibrant colors",
            )
            negative_prompt_input = gr.Textbox(
                label="Negative Prompt (Optional)",
                placeholder="What to avoid in the image...",
                lines=2,
                value="ugly, blurry, low quality, distorted, deformed",
            )

            gr.Markdown("### ⚙️ Settings")

            with gr.Row():
                width_slider = gr.Slider(
                    minimum=256,
                    maximum=1024,
                    step=64,
                    value=512,
                    label="Width (px)",
                    info="Image width - multiples of 64",
                )
                height_slider = gr.Slider(
                    minimum=256,
                    maximum=1024,
                    step=64,
                    value=512,
                    label="Height (px)",
                    info="Image height - multiples of 64",
                )

            with gr.Row():
                steps_slider = gr.Slider(
                    minimum=10,
                    maximum=150,
                    step=5,
                    value=50,
                    label="Inference Steps",
                    info="More steps = better quality but slower",
                )
                guidance_slider = gr.Slider(
                    minimum=1.0,
                    maximum=20.0,
                    step=0.5,
                    value=7.5,
                    label="Guidance Scale",
                    info="Higher = closer to prompt, Lower = more creative",
                )

            seed_input = gr.Number(
                label="Seed",
                value=-1,
                precision=0,
                info="-1 for random, any number for reproducibility",
            )

            generate_btn = gr.Button(
                "🎨 Generate Image",
                variant="primary",
                size="lg",
            )

        # Right column: generated image, status line, download slot.
        with gr.Column(scale=1):
            gr.Markdown("### 🖼️ Result")

            output_image = gr.Image(
                label="Generated Image",
                type="pil",
                height=512,
            )
            status_text = gr.Textbox(
                label="Status",
                interactive=False,
            )
            download_btn = gr.File(
                label="Download",
                visible=True,
            )

    # Collapsible prompting tips.
    with gr.Accordion("💡 Tips for Better Results", open=False):
        gr.Markdown("""
        **Writing Effective Prompts:**
        - Be specific and descriptive
        - Include art style references (e.g., "digital art", "oil painting", "watercolor")
        - Mention lighting ("dramatic lighting", "soft sunlight", "neon lights")
        - Add quality modifiers ("highly detailed", "4K", "masterpiece")
        - Specify mood and atmosphere ("peaceful", "dramatic", "mysterious")

        **Using Negative Prompts:**
        - Remove unwanted elements ("no people", "no text")
        - Avoid quality issues ("no blur", "no distortion")
        - Fix common problems ("bad anatomy", "extra limbs")

        **Parameter Guide:**
        - **Steps**: 20-30 for quick previews, 50-75 for final images, 100+ for best quality
        - **Guidance**: 5-7 for creative freedom, 7-9 for balanced, 9-12 for strict prompt following
        - **Resolution**: Higher = more detail but slower. Start with 512x512, increase if needed
        """)

    # One-click example prompts.
    gr.Markdown("### 💡 Example Prompts")

    with gr.Row():
        example_btn1 = gr.Button("🌆 Cyberpunk City", size="sm")
        example_btn2 = gr.Button("🐉 Fantasy Dragon", size="sm")
        example_btn3 = gr.Button("🏔️ Peaceful Landscape", size="sm")

    with gr.Row():
        example_btn4 = gr.Button("👤 Character Portrait", size="sm")
        example_btn5 = gr.Button("🌊 Underwater Scene", size="sm")
        example_btn6 = gr.Button("🎨 Abstract Art", size="sm")

    # FIX: the previous wiring looked each button up with
    # demo.get_component(btn_name), but Blocks.get_component expects an
    # internal integer component id, not the name of a local variable, so
    # the lookup failed at startup. Key the mapping on the Button objects
    # themselves and wire them directly.
    example_prompts = {
        example_btn1: (
            "A cyberpunk city at night with neon lights, futuristic architecture, flying cars, rain-slicked streets, highly detailed, digital art, cinematic lighting",
            "ugly, blurry, low quality, distorted, dark, gloomy"
        ),
        example_btn2: (
            "A majestic dragon breathing fire, fantasy art, dramatic lighting, epic scene, scales gleaming, powerful wings, mountain landscape background",
            "ugly, deformed, blurry, low quality, cartoonish"
        ),
        example_btn3: (
            "A peaceful cottage in a meadow, wildflowers, sunny day, blue sky, studio ghibli style, serene atmosphere, pastoral landscape",
            "people, animals, buildings, urban, dark, stormy"
        ),
        example_btn4: (
            "Portrait of a warrior princess, ornate armor, fantasy setting, intricate details, character design, dramatic lighting, confident expression, long flowing hair",
            "ugly, deformed, asymmetrical, blurry, low quality, bad anatomy"
        ),
        example_btn5: (
            "Underwater coral reef, tropical fish, sunlight filtering through water, photorealistic, vibrant colors, marine life, crystal clear water",
            "polluted, murky, dark, blurry, low quality"
        ),
        example_btn6: (
            "Abstract geometric art, colorful shapes, dynamic composition, modern art, bold patterns, artistic expression, vivid colors",
            "representational, realistic, boring, dull colors, simple"
        ),
    }

    def _make_example_setter(example_prompt, example_negative):
        # Factory binds each (prompt, negative) pair eagerly, avoiding the
        # late-binding-closure pitfall of defining a lambda inside the loop.
        def _setter():
            return example_prompt, example_negative, "Click Generate to create!"
        return _setter

    for example_button, (example_prompt, example_negative) in example_prompts.items():
        example_button.click(
            fn=_make_example_setter(example_prompt, example_negative),
            inputs=[],
            outputs=[prompt_input, negative_prompt_input, status_text],
        )

    # Main generate action.
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt_input,
            negative_prompt_input,
            width_slider,
            height_slider,
            steps_slider,
            guidance_slider,
            seed_input,
        ],
        outputs=[output_image, status_text],
    )

    # Footer.
    gr.Markdown("""
    ---
    **Byte Dream** v1.0.0 | Powered by Latent Diffusion Models | Optimized for CPU Inference

    Created with ❤️ using PyTorch and Hugging Face Diffusers
    """)
291
+
292
+
293
if __name__ == "__main__":
    banner = "=" * 60
    print("\n" + banner)
    print("Starting Byte Dream Web Interface")
    print(banner)
    print("\nOpening browser...")
    print("Press Ctrl+C to close\n")

    # Bind to all interfaces so the UI is reachable from containers / Spaces;
    # share=False keeps it off Gradio's public tunnel.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )
bytedream/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Byte Dream - AI Image Generation Model
3
+ Production-ready text-to-image diffusion model optimized for CPU inference
4
+ """
5
+
6
+ __version__ = "1.0.0"
7
+ __author__ = "Byte Dream Team"
8
+
9
+ from .generator import ByteDreamGenerator
10
+ from .model import UNet2DConditionModel, AutoencoderKL, CLIPTextModel
11
+ from .pipeline import ByteDreamPipeline
12
+ from .scheduler import DDIMScheduler
13
+
14
+ __all__ = [
15
+ "ByteDreamGenerator",
16
+ "UNet2DConditionModel",
17
+ "AutoencoderKL",
18
+ "CLIPTextModel",
19
+ "ByteDreamPipeline",
20
+ "DDIMScheduler",
21
+ ]
bytedream/generator.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Byte Dream Generator
3
+ Main inference engine optimized for CPU with advanced features
4
+ """
5
+
6
+ import torch
7
+ import yaml
8
+ from pathlib import Path
9
+ from typing import Optional, Union, List
10
+ from PIL import Image
11
+ import numpy as np
12
+ import gc
13
+
14
+
15
+ class ByteDreamGenerator:
16
+ """
17
+ Production-ready image generation engine
18
+ Optimized for CPU inference with memory efficiency
19
+ """
20
+
21
+ def __init__(
22
+ self,
23
+ model_path: Optional[str] = None,
24
+ config_path: str = "config.yaml",
25
+ device: str = "cpu",
26
+ use_safetensors: bool = True,
27
+ ):
28
+ """
29
+ Initialize Byte Dream generator
30
+
31
+ Args:
32
+ model_path: Path to trained model weights
33
+ config_path: Path to configuration file
34
+ device: Device to run on (default: cpu)
35
+ use_safetensors: Use safetensors format if available
36
+ """
37
+ self.device = device
38
+ self.config_path = config_path
39
+ self.use_safetensors = use_safetensors
40
+
41
+ # Load configuration
42
+ self.config = self._load_config(config_path)
43
+
44
+ # Initialize components
45
+ print("Initializing Byte Dream Generator...")
46
+ self.pipeline = self._initialize_pipeline(model_path)
47
+
48
+ # Optimize for CPU
49
+ self._optimize_for_cpu()
50
+
51
+ print("✓ Byte Dream Generator ready!")
52
+
53
+ def _load_config(self, config_path: str) -> dict:
54
+ """Load configuration from YAML file"""
55
+ try:
56
+ with open(config_path, 'r', encoding='utf-8') as f:
57
+ config = yaml.safe_load(f)
58
+ return config
59
+ except FileNotFoundError:
60
+ print(f"Warning: Config file {config_path} not found. Using defaults.")
61
+ return self._get_default_config()
62
+
63
+ def _get_default_config(self) -> dict:
64
+ """Get default configuration"""
65
+ return {
66
+ 'model': {
67
+ 'unet': {
68
+ 'in_channels': 4,
69
+ 'out_channels': 4,
70
+ 'block_out_channels': [320, 640, 1280, 1280],
71
+ 'layers_per_block': 2,
72
+ 'attention_head_dim': 8,
73
+ 'cross_attention_dim': 768,
74
+ 'use_linear_projection': True,
75
+ },
76
+ 'scheduler': {
77
+ 'name': 'DDIM',
78
+ 'num_train_timesteps': 1000,
79
+ 'beta_start': 0.00085,
80
+ 'beta_end': 0.012,
81
+ 'beta_schedule': 'scaled_linear',
82
+ 'clip_sample': False,
83
+ 'set_alpha_to_one': False,
84
+ }
85
+ },
86
+ 'generation': {
87
+ 'width': 512,
88
+ 'height': 512,
89
+ 'num_inference_steps': 50,
90
+ 'guidance_scale': 7.5,
91
+ 'negative_prompt': 'ugly, blurry, low quality, distorted, deformed',
92
+ }
93
+ }
94
+
95
+ def _initialize_pipeline(self, model_path: Optional[str]):
96
+ """Initialize the generation pipeline"""
97
+ from bytedream.model import create_unet, create_vae, create_text_encoder
98
+ from bytedream.scheduler import create_scheduler
99
+ from bytedream.pipeline import ByteDreamPipeline
100
+
101
+ # Create model components
102
+ print("Creating UNet...")
103
+ unet = create_unet(self.config)
104
+
105
+ print("Creating VAE...")
106
+ vae = create_vae(self.config)
107
+
108
+ print("Creating Text Encoder...")
109
+ text_encoder = create_text_encoder(self.config)
110
+
111
+ print("Creating Scheduler...")
112
+ scheduler = create_scheduler(self.config)
113
+
114
+ # Load pretrained weights if provided
115
+ if model_path:
116
+ self._load_model_weights(unet, model_path)
117
+
118
+ # Create pipeline
119
+ pipeline = ByteDreamPipeline(
120
+ text_encoder=text_encoder,
121
+ vae=vae,
122
+ unet=unet,
123
+ scheduler=scheduler,
124
+ device=self.device,
125
+ dtype=torch.float32,
126
+ )
127
+
128
+ return pipeline
129
+
130
+ def _load_model_weights(self, unet, model_path: str):
131
+ """Load pretrained model weights"""
132
+ model_file = Path(model_path) / "unet_pytorch_model.bin"
133
+
134
+ if not model_file.exists():
135
+ model_file = Path(model_path) / "pytorch_model.bin"
136
+
137
+ if model_file.exists():
138
+ print(f"Loading weights from {model_file}...")
139
+ checkpoint = torch.load(model_file, map_location=self.device)
140
+
141
+ if 'unet_state_dict' in checkpoint:
142
+ unet.load_state_dict(checkpoint['unet_state_dict'])
143
+ else:
144
+ unet.load_state_dict(checkpoint)
145
+
146
+ print("✓ Weights loaded successfully!")
147
+ else:
148
+ print("⚠ No pretrained weights found. Using random initialization.")
149
+ print(" Train the model or download pretrained weights.")
150
+
151
+ def _optimize_for_cpu(self):
152
+ """Optimize pipeline for CPU inference"""
153
+ # Set number of threads
154
+ cpu_config = self.config.get('cpu_optimization', {})
155
+ threads = cpu_config.get('threads', -1)
156
+
157
+ if threads > 0:
158
+ torch.set_num_threads(threads)
159
+ else:
160
+ # Use all available cores
161
+ import os
162
+ torch.set_num_threads(os.cpu_count())
163
+
164
+ # Enable memory efficient mode
165
+ self.pipeline.enable_memory_efficient_mode()
166
+
167
+ print(f"✓ Optimized for CPU ({torch.get_num_threads()} threads)")
168
+
169
+ @torch.no_grad()
170
+ def generate(
171
+ self,
172
+ prompt: str,
173
+ negative_prompt: Optional[str] = None,
174
+ width: Optional[int] = None,
175
+ height: Optional[int] = None,
176
+ num_inference_steps: Optional[int] = None,
177
+ guidance_scale: Optional[float] = None,
178
+ seed: Optional[int] = None,
179
+ eta: float = 0.0,
180
+ output_type: str = "pil",
181
+ ) -> Image.Image:
182
+ """
183
+ Generate image from text prompt
184
+
185
+ Args:
186
+ prompt: Text description of desired image
187
+ negative_prompt: Things to avoid in the image
188
+ width: Output image width (default: 512)
189
+ height: Output image height (default: 512)
190
+ num_inference_steps: Number of denoising steps (default: 50)
191
+ guidance_scale: How closely to follow prompt (default: 7.5)
192
+ seed: Random seed for reproducibility
193
+ eta: DDIM eta parameter (0.0 for deterministic)
194
+ output_type: Output format ("pil" or "tensor")
195
+
196
+ Returns:
197
+ Generated PIL Image
198
+ """
199
+ # Get default values from config
200
+ gen_config = self.config.get('generation', {})
201
+
202
+ width = width or gen_config.get('width', 512)
203
+ height = height or gen_config.get('height', 512)
204
+ num_inference_steps = num_inference_steps or gen_config.get('num_inference_steps', 50)
205
+ guidance_scale = guidance_scale or gen_config.get('guidance_scale', 7.5)
206
+ negative_prompt = negative_prompt or gen_config.get('negative_prompt', "")
207
+
208
+ # Ensure dimensions are divisible by 8
209
+ width = (width // 8) * 8
210
+ height = (height // 8) * 8
211
+
212
+ print(f"\nGenerating image...")
213
+ print(f"Prompt: {prompt}")
214
+ if negative_prompt:
215
+ print(f"Negative prompt: {negative_prompt}")
216
+ print(f"Size: {width}x{height}")
217
+ print(f"Steps: {num_inference_steps}")
218
+ print(f"Guidance scale: {guidance_scale}")
219
+
220
+ # Set random seed
221
+ generator = None
222
+ if seed is not None:
223
+ generator = torch.Generator(device=self.device).manual_seed(seed)
224
+ print(f"Seed: {seed}")
225
+
226
+ # Generate image
227
+ result = self.pipeline(
228
+ prompt=prompt,
229
+ negative_prompt=negative_prompt,
230
+ height=height,
231
+ width=width,
232
+ num_inference_steps=num_inference_steps,
233
+ guidance_scale=guidance_scale,
234
+ eta=eta,
235
+ generator=generator,
236
+ output_type=output_type,
237
+ )
238
+
239
+ image = result[0] if isinstance(result, (list, tuple)) else result
240
+
241
+ print("\n✓ Image generated successfully!")
242
+
243
+ return image
244
+
245
+ def generate_batch(
246
+ self,
247
+ prompts: List[str],
248
+ negative_prompt: Optional[str] = None,
249
+ width: int = 512,
250
+ height: int = 512,
251
+ num_inference_steps: int = 50,
252
+ guidance_scale: float = 7.5,
253
+ seeds: Optional[List[int]] = None,
254
+ ) -> List[Image.Image]:
255
+ """
256
+ Generate multiple images from prompts
257
+
258
+ Args:
259
+ prompts: List of text prompts
260
+ negative_prompt: Negative prompt for all images
261
+ width: Image width
262
+ height: Image height
263
+ num_inference_steps: Number of denoising steps
264
+ guidance_scale: Guidance scale
265
+ seeds: Random seeds for each image
266
+
267
+ Returns:
268
+ List of generated PIL Images
269
+ """
270
+ images = []
271
+
272
+ for i, prompt in enumerate(prompts):
273
+ seed = seeds[i] if seeds and i < len(seeds) else None
274
+
275
+ print(f"\n{'='*50}")
276
+ print(f"Generating image {i+1}/{len(prompts)}")
277
+ print(f"{'='*50}")
278
+
279
+ image = self.generate(
280
+ prompt=prompt,
281
+ negative_prompt=negative_prompt,
282
+ width=width,
283
+ height=height,
284
+ num_inference_steps=num_inference_steps,
285
+ guidance_scale=guidance_scale,
286
+ seed=seed,
287
+ )
288
+
289
+ images.append(image)
290
+
291
+ # Clear memory between generations
292
+ gc.collect()
293
+
294
+ return images
295
+
296
+ def get_model_info(self) -> dict:
297
+ """Get model information"""
298
+ unet_params = sum(p.numel() for p in self.pipeline.unet.parameters())
299
+ vae_params = sum(p.numel() for p in self.pipeline.vae.parameters())
300
+
301
+ info = {
302
+ 'name': self.config['model']['name'],
303
+ 'version': self.config['model']['version'],
304
+ 'unet_parameters': f"{unet_params:,}",
305
+ 'device': self.device,
306
+ 'dtype': str(self.pipeline.dtype),
307
+ 'default_resolution': f"{self.config['generation']['width']}x{self.config['generation']['height']}",
308
+ }
309
+
310
+ return info
311
+
312
+ def clear_memory(self):
313
+ """Clear GPU/CPU memory"""
314
+ gc.collect()
315
+ if torch.cuda.is_available():
316
+ torch.cuda.empty_cache()
317
+ print("Memory cleared")
bytedream/model.py ADDED
@@ -0,0 +1,582 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Byte Dream Model Architecture
3
+ Complete implementation of UNet, VAE, and Text Encoder for diffusion-based image generation
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from typing import Optional, Tuple, Union
10
+ import math
11
+
12
+
13
class ResnetBlock2D(nn.Module):
    """Residual block for 2D convolutions.

    Two GroupNorm -> SiLU -> Conv stages, with an optional timestep
    embedding added between them and a learned 1x1 shortcut whenever the
    channel count changes.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: Optional[int] = None,
        groups: int = 32,
        eps: float = 1e-6,
    ):
        super().__init__()

        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            self.time_emb_proj = nn.Linear(temb_channels, out_channels)

        self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = nn.Dropout(0.0)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = nn.SiLU(inplace=True)

        if in_channels != out_channels:
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.conv_shortcut = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the block; ``temb`` (if given) is broadcast over H and W."""
        out = self.conv1(self.nonlinearity(self.norm1(hidden_states)))

        if temb is not None:
            # Project the timestep embedding to the output width and add it
            # as a per-channel bias over the spatial map.
            out = out + self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]

        out = self.conv2(self.dropout(self.nonlinearity(self.norm2(out))))

        # Match channels on the skip path when the block changes width.
        skip = hidden_states if self.conv_shortcut is None else self.conv_shortcut(hidden_states)
        return out + skip
67
+
68
+
69
class AttentionBlock(nn.Module):
    """Cross-attention block for text-conditioned generation.

    Accepts either a token sequence ``(B, L, C)`` or a convolutional
    feature map ``(B, C, H, W)``. 4D inputs are flattened to a spatial
    sequence for attention and reshaped back afterwards, so the block can
    be used directly inside the U-Net down/up stages (which pass raw conv
    features — the previous implementation crashed on them).
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        num_heads: int = 8,
        head_dim: Optional[int] = None,
        eps: float = 1e-6,
    ):
        super().__init__()

        inner_dim = num_heads * head_dim if head_dim is not None else query_dim
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim

        self.num_heads = num_heads
        self.head_dim = head_dim if head_dim is not None else query_dim // num_heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)

        self.to_out = nn.ModuleList([
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(0.0)
        ])

        self.norm = nn.LayerNorm(query_dim, eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run (self- or cross-) attention and add the residual.

        Args:
            hidden_states: ``(B, L, C)`` sequence or ``(B, C, H, W)`` map.
            encoder_hidden_states: Optional context ``(B, L_ctx, C_ctx)``;
                when omitted the block performs self-attention.
        """
        # Fix: the U-Net stages hand this block 4D conv features, which the
        # original 3-way shape unpack could not handle. Flatten to a
        # (B, H*W, C) sequence and restore the layout at the end.
        is_spatial = hidden_states.dim() == 4
        if is_spatial:
            batch, channels, h, w = hidden_states.shape
            hidden_states = hidden_states.reshape(batch, channels, h * w).transpose(1, 2)

        residual = hidden_states
        batch_size, sequence_length, _ = hidden_states.shape

        # Pre-norm the query path (self.norm was previously created but
        # never applied — a dead parameter).
        normed = self.norm(hidden_states)
        context = encoder_hidden_states if encoder_hidden_states is not None else normed

        query = self.to_q(normed)
        key = self.to_k(context)
        value = self.to_v(context)

        # Split into heads: (B, heads, L, head_dim)
        query = query.reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention
        attn_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_weights = F.softmax(attn_weights, dim=-1)

        attn_output = torch.matmul(attn_weights, value)
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, sequence_length, -1)

        # Output projection + dropout
        for layer in self.to_out:
            attn_output = layer(attn_output)

        out = residual + attn_output

        if is_spatial:
            # Restore the (B, C, H, W) layout expected by the conv stages.
            out = out.transpose(1, 2).reshape(batch, channels, h, w)

        return out
130
+
131
+
132
class DownBlock2D(nn.Module):
    """Downsampling block: ``num_layers`` resnet stages (each optionally
    followed by cross-attention) and an optional stride-2 conv downsampler.

    ``forward`` returns both the final hidden states and a tuple of every
    intermediate output, which the U-Net collects for skip connections.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        num_layers: int = 1,
        add_downsample: bool = True,
        has_cross_attention: bool = False,
        cross_attention_dim: Optional[int] = None,
    ):
        super().__init__()

        resnets = []
        attentions = []

        for i in range(num_layers):
            # Only the first resnet changes the channel count.
            in_ch = in_channels if i == 0 else out_channels

            resnets.append(ResnetBlock2D(
                in_channels=in_ch,
                out_channels=out_channels,
                temb_channels=temb_channels,
            ))

            if has_cross_attention:
                attentions.append(AttentionBlock(
                    query_dim=out_channels,
                    cross_attention_dim=cross_attention_dim,
                    num_heads=8,
                    head_dim=out_channels // 8,
                ))
            else:
                # None placeholders keep the attention list index-aligned
                # with the resnet list (nn.ModuleList accepts None entries).
                attentions.append(None)

        self.resnets = nn.ModuleList(resnets)
        self.attentions = nn.ModuleList(attentions)

        if add_downsample:
            # Stride-2 conv halves the spatial resolution.
            self.downsamplers = nn.ModuleList([
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
            ])
        else:
            self.downsamplers = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
        # NOTE(review): `hidden_states` is a 4D (B, C, H, W) feature map
        # here, but AttentionBlock.forward unpacks a 3-element shape —
        # confirm the attention block accepts 4D input before relying on
        # the cross-attention path.
        output_states = ()

        for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
            hidden_states = resnet(hidden_states, temb)

            if attn is not None and encoder_hidden_states is not None:
                hidden_states = attn(hidden_states, encoder_hidden_states)

            # Collected for U-Net skip connections.
            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            # The downsampled map is also exposed as a residual.
            output_states += (hidden_states,)

        return hidden_states, output_states
201
+
202
+
203
class UpBlock2D(nn.Module):
    """Upsampling block: resnet (+ optional cross-attention) stages that
    each consume one skip connection from the down path, followed by an
    optional transposed-conv upsampler.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        num_layers: int = 1,
        add_upsample: bool = True,
        has_cross_attention: bool = False,
        cross_attention_dim: Optional[int] = None,
    ):
        super().__init__()

        resnets = []
        attentions = []

        for i in range(num_layers):
            in_ch = in_channels if i == 0 else out_channels
            mix_ch = prev_output_channel if i == num_layers - 1 else out_channels

            # Each resnet's input is the concatenation of the running
            # hidden states and one skip-connection tensor (see forward).
            # NOTE(review): the in_ch/mix_ch bookkeeping must match the
            # channel counts of the residuals the down path actually
            # produced — verify shapes before training.
            resnets.append(ResnetBlock2D(
                in_channels=in_ch + mix_ch,
                out_channels=out_channels,
                temb_channels=temb_channels,
            ))

            if has_cross_attention:
                attentions.append(AttentionBlock(
                    query_dim=out_channels,
                    cross_attention_dim=cross_attention_dim,
                    num_heads=8,
                    head_dim=out_channels // 8,
                ))
            else:
                # None placeholders keep the lists index-aligned.
                attentions.append(None)

        self.resnets = nn.ModuleList(resnets)
        self.attentions = nn.ModuleList(attentions)

        if add_upsample:
            # Transposed conv doubles the spatial resolution.
            self.upsamplers = nn.ModuleList([
                nn.ConvTranspose2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)
            ])
        else:
            self.upsamplers = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # NOTE(review): residuals are consumed front-to-back ([i]) while
        # the U-Net slices the most recent residuals from the end of its
        # list — confirm the intended pairing order.
        for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
            # Skip connection from U-Net downsampling path
            hidden_states = torch.cat([hidden_states, res_hidden_states_tuple[i]], dim=1)

            hidden_states = resnet(hidden_states, temb)

            if attn is not None and encoder_hidden_states is not None:
                hidden_states = attn(hidden_states, encoder_hidden_states)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states
273
+
274
+
275
class UNet2DConditionModel(nn.Module):
    """
    Main UNet architecture for diffusion-based image generation.

    Predicts the noise residual for a latent ``sample``, conditioned on a
    diffusion ``timestep`` and text-encoder hidden states (cross-attention).
    """

    def __init__(
        self,
        in_channels: int = 4,
        out_channels: int = 4,
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        attention_head_dim: int = 8,
        cross_attention_dim: int = 768,
        use_linear_projection: bool = True,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.block_out_channels = block_out_channels
        self.layers_per_block = layers_per_block
        self.cross_attention_dim = cross_attention_dim

        # Timestep embedding MLP. Fix: the raw integer timestep is first
        # mapped to a sinusoidal embedding of width block_out_channels[0]
        # (see _timestep_embedding); the original fed the bare long tensor
        # straight into nn.Linear, which cannot work.
        time_embed_dim = block_out_channels[0] * 4
        self.time_proj = nn.Sequential(
            nn.Linear(block_out_channels[0], time_embed_dim),
            nn.SiLU(inplace=True),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

        # Input convolution into the first feature width
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)

        # Down blocks
        self.down_blocks = nn.ModuleList([])
        output_channel = block_out_channels[0]

        for i, down_block_type in enumerate(["down", "down", "down", "down"]):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = DownBlock2D(
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                num_layers=layers_per_block,
                add_downsample=not is_final_block,
                has_cross_attention=True,
                cross_attention_dim=cross_attention_dim,
            )

            self.down_blocks.append(down_block)

        # Middle blocks: resnet -> attention -> resnet at the bottleneck
        self.mid_block = nn.ModuleList([
            ResnetBlock2D(
                in_channels=block_out_channels[-1],
                out_channels=block_out_channels[-1],
                temb_channels=time_embed_dim,
            ),
            AttentionBlock(
                query_dim=block_out_channels[-1],
                cross_attention_dim=cross_attention_dim,
                num_heads=attention_head_dim,
                head_dim=block_out_channels[-1] // attention_head_dim,
            ),
            ResnetBlock2D(
                in_channels=block_out_channels[-1],
                out_channels=block_out_channels[-1],
                temb_channels=time_embed_dim,
            ),
        ])

        # Up blocks
        self.up_blocks = nn.ModuleList([])
        reversed_block_out_channels = list(reversed(block_out_channels))

        for i, up_block_type in enumerate(["up", "up", "up", "up"]):
            prev_output_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            # NOTE(review): the in/skip channel bookkeeping below has not
            # been verified against the residuals the down path actually
            # produces — confirm tensor shapes before training.
            up_block = UpBlock2D(
                in_channels=reversed_block_out_channels[i - 1] if i > 0 else reversed_block_out_channels[0],
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                num_layers=layers_per_block + 1,
                add_upsample=not is_final_block,
                has_cross_attention=True,
                cross_attention_dim=cross_attention_dim,
            )

            self.up_blocks.append(up_block)

        # Output head. Fix: the original GroupNorm call repeated the
        # `num_channels` keyword (a SyntaxError that prevented the module
        # from importing) and never supplied `num_groups`.
        self.conv_norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[0], eps=1e-6)
        self.conv_act = nn.SiLU(inplace=True)
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, stride=1, padding=1)

    def _timestep_embedding(self, timesteps: torch.Tensor, dim: int) -> torch.Tensor:
        """Sinusoidal embedding of diffusion timesteps (DDPM-style).

        Args:
            timesteps: 1D tensor of integer timesteps.
            dim: Output embedding width.

        Returns:
            Float tensor of shape ``(len(timesteps), dim)``.
        """
        half = dim // 2
        exponent = -math.log(10000.0) * torch.arange(half, dtype=torch.float32, device=timesteps.device) / half
        freqs = torch.exp(exponent)
        angles = timesteps.float()[:, None] * freqs[None, :]
        emb = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if dim % 2 == 1:
            # Pad to the requested width for odd dims.
            emb = F.pad(emb, (0, 1))
        return emb

    def forward(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Predict the noise residual for ``sample`` at ``timestep``.

        Args:
            sample: Latent batch ``(B, in_channels, H, W)``.
            timestep: 1D tensor of integer diffusion timesteps.
            encoder_hidden_states: Text embeddings for cross-attention.
        """
        # Time embedding: sinusoidal features projected through the MLP.
        temb = self.time_proj(self._timestep_embedding(timestep, self.block_out_channels[0]))

        # Initial convolution
        hidden_states = self.conv_in(sample)

        # Down sampling path; residuals are kept for skip connections.
        down_block_res_samples = (hidden_states,)

        for downsample_block in self.down_blocks:
            hidden_states, res_samples = downsample_block(
                hidden_states=hidden_states,
                temb=temb,
                encoder_hidden_states=encoder_hidden_states,
            )
            down_block_res_samples += res_samples

        # Middle: resnets take the time embedding, attention takes text.
        for layer in self.mid_block:
            if isinstance(layer, ResnetBlock2D):
                hidden_states = layer(hidden_states, temb)
            else:
                hidden_states = layer(hidden_states, encoder_hidden_states)

        # Up sampling path: each block consumes one residual per resnet.
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[:-len(upsample_block.resnets)]

            hidden_states = upsample_block(
                hidden_states=hidden_states,
                res_hidden_states_tuple=res_samples,
                temb=temb,
                encoder_hidden_states=encoder_hidden_states,
            )

        # Output head
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states
426
+
427
+
428
class AutoencoderKL(nn.Module):
    """
    Variational Autoencoder for image compression and reconstruction.

    Compresses images to a latent space (1/16 spatial resolution with the
    default four stride-2 stages) for efficient diffusion.

    NOTE(review): despite the KL name, `encode` returns the raw
    2*latent_channels projection (mean/logvar concatenated) and `forward`
    simply slices the first `latent_channels` channels — there is no
    reparameterized sampling or KL term here.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",) * 4,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",) * 4,
        latent_channels: int = 4,
        sample_size: int = 512,
    ):
        super().__init__()

        self.sample_size = sample_size
        # Stored so forward() can slice the latent portion generically
        # instead of hard-coding 4 channels.
        self.latent_channels = latent_channels

        # Encoder: each stage halves H and W with a stride-2 conv.
        self.encoder = nn.ModuleList()
        channels = [in_channels, 128, 256, 512, 512]

        for i in range(len(down_block_types)):
            block = nn.Sequential(
                nn.Conv2d(channels[i], channels[i + 1], kernel_size=3, stride=2, padding=1),
                # Fix: the original call repeated the `num_channels` keyword
                # (a SyntaxError) and never supplied `num_groups`.
                nn.GroupNorm(num_groups=32, num_channels=channels[i + 1], eps=1e-6),
                nn.SiLU(inplace=True),
            )
            self.encoder.append(block)

        # Project to concatenated mean/logvar (2 * latent_channels).
        self.quant_conv = nn.Conv2d(512, latent_channels * 2, kernel_size=1)

        # Decoder: mirrors the encoder with transposed convs,
        # latent_channels -> 512 -> 512 -> 256 -> 128.
        self.decoder = nn.ModuleList()
        decoder_channels = [latent_channels, 512, 512, 256, 128]

        for i in range(len(up_block_types)):
            block = nn.Sequential(
                nn.ConvTranspose2d(decoder_channels[i], decoder_channels[i + 1], kernel_size=4, stride=2, padding=1),
                # Fix: repeated `num_channels` keyword here as well.
                nn.GroupNorm(num_groups=32, num_channels=decoder_channels[i + 1], eps=1e-6),
                nn.SiLU(inplace=True),
            )
            self.decoder.append(block)

        # Fix: the original projected latents to 512 channels here, but the
        # first decoder block expects `latent_channels` inputs, so decode()
        # crashed on a channel mismatch. Keep the 1x1 conv but preserve the
        # latent width so the decoder stack lines up.
        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, kernel_size=1)
        self.conv_out = nn.Conv2d(128, out_channels, kernel_size=3, stride=1, padding=1)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode an image batch to 2*latent_channels feature maps."""
        for block in self.encoder:
            x = block(x)
        x = self.quant_conv(x)
        return x

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode a latent batch (latent_channels wide) back to image space."""
        z = self.post_quant_conv(z)
        for block in self.decoder:
            z = block(z)
        z = self.conv_out(z)
        return z

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Full autoencoder pass: encode, take the mean half, decode."""
        encoded = self.encode(x)
        # First latent_channels channels are the "mean" part of the
        # mean/logvar projection.
        decoded = self.decode(encoded[:, :self.latent_channels])
        return decoded
497
+
498
+
499
class CLIPTextModel(nn.Module):
    """
    CLIP text encoder for understanding text prompts.

    Wraps the Hugging Face CLIP text model; falls back to a dummy
    zero-embedding encoder when transformers (or the pretrained weights)
    cannot be loaded, so the rest of the pipeline stays importable.
    """

    def __init__(self, model_name: str = "openai/clip-vit-large-patch14", max_length: int = 77):
        super().__init__()

        # Fix: set unconditionally so the attribute also exists on the
        # fallback path (it was previously only assigned inside the try).
        self.max_length = max_length

        try:
            from transformers import CLIPTextModel as HFCLIPTextModel, CLIPTokenizer

            self.model = HFCLIPTextModel.from_pretrained(model_name)
            self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
        except (ImportError, OSError) as err:
            # Fix: from_pretrained raises OSError (not ImportError) when the
            # weights are missing or cannot be downloaded — catch both.
            print(f"Warning: could not load CLIP text encoder ({err}). Using dummy text encoder.")
            self.model = None
            self.tokenizer = None

    # Fix: the annotation previously used `List`, which is never imported in
    # this module and raised NameError at class-definition time; the builtin
    # `list` works in `Union` without any import.
    def forward(self, text: Union[str, list], device: torch.device = None) -> torch.Tensor:
        """
        Encode text to embeddings

        Args:
            text: Text string or list of strings
            device: Target device for computation

        Returns:
            Text embeddings tensor of shape (batch, max_length, 768)
        """
        if self.model is None:
            # Dummy fallback: zero embeddings matching the requested batch
            # size and device (previously always (1, 77, 768) on CPU).
            batch = 1 if isinstance(text, str) else len(text)
            out = torch.zeros(batch, self.max_length, 768)
            return out.to(device) if device is not None else out

        inputs = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )

        if device is not None:
            inputs = {k: v.to(device) for k, v in inputs.items()}

        outputs = self.model(**inputs)
        return outputs.last_hidden_state
547
+
548
+
549
def create_unet(config):
    """Build a :class:`UNet2DConditionModel` from config['model']['unet']."""
    cfg = config['model']['unet']
    return UNet2DConditionModel(
        in_channels=cfg['in_channels'],
        out_channels=cfg['out_channels'],
        block_out_channels=tuple(cfg['block_out_channels']),
        layers_per_block=cfg['layers_per_block'],
        attention_head_dim=cfg['attention_head_dim'],
        cross_attention_dim=cfg['cross_attention_dim'],
        use_linear_projection=cfg['use_linear_projection'],
    )
561
+
562
+
563
def create_vae(config):
    """Build an :class:`AutoencoderKL` from config['model']['vae']."""
    cfg = config['model']['vae']
    return AutoencoderKL(
        in_channels=cfg['in_channels'],
        out_channels=cfg['out_channels'],
        down_block_types=tuple(cfg['down_block_types']),
        up_block_types=tuple(cfg['up_block_types']),
        latent_channels=cfg['latent_channels'],
        sample_size=cfg['sample_size'],
    )
574
+
575
+
576
def create_text_encoder(config):
    """Build a :class:`CLIPTextModel` from config['model']['text_encoder']."""
    cfg = config['model']['text_encoder']
    return CLIPTextModel(model_name=cfg['model'], max_length=cfg['max_length'])
bytedream/pipeline.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Byte Dream Pipeline
3
+ Complete diffusion pipeline for text-to-image generation
4
+ Integrates all components: text encoder, UNet, VAE, and scheduler
5
+ """
6
+
7
+ import torch
8
+ import numpy as np
9
+ from typing import Optional, Union, List
10
+ from PIL import Image
11
+ import gc
12
+
13
+
14
class ByteDreamPipeline:
    """
    Complete pipeline for text-to-image generation.

    Manages the diffusion process end to end: prompt encoding, the
    iterative denoising loop, and VAE decoding. Optimized for CPU
    inference with memory efficiency.
    """

    def __init__(
        self,
        text_encoder,
        vae,
        unet,
        scheduler,
        device: str = "cpu",
        dtype: torch.dtype = torch.float32,
    ):
        self.text_encoder = text_encoder
        self.vae = vae
        self.unet = unet
        self.scheduler = scheduler

        self.device = torch.device(device)
        self.dtype = dtype

        # Move models to device and freeze them for inference.
        self._move_models_to_device()
        self._set_eval_mode()

    def _move_models_to_device(self):
        """Move all models to target device with memory optimization"""
        print(f"Loading models to {self.device}...")

        # The text encoder may be a dummy wrapper (model is None) when
        # transformers is unavailable.
        if hasattr(self.text_encoder, 'model') and self.text_encoder.model is not None:
            self.text_encoder.model.to(self.device)

        self.vae.to(self.device)
        self.unet.to(self.device)

    def _set_eval_mode(self):
        """Set all models to evaluation mode"""
        if hasattr(self.text_encoder, 'model') and self.text_encoder.model is not None:
            self.text_encoder.model.eval()
        self.vae.eval()
        self.unet.eval()

    @torch.no_grad()
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: int = 1,
    ) -> torch.Tensor:
        """
        Encode text prompts to embeddings.

        When ``negative_prompt`` is not None the unconditional embeddings
        are concatenated in front of the conditional ones (the batch
        doubles) for classifier-free guidance — the [uncond; cond]
        ordering is assumed by the guidance chunk() in __call__.

        Args:
            prompt: Text prompt or list of prompts
            negative_prompt: Negative prompt for guidance
            num_images_per_prompt: Number of images to generate per prompt

        Returns:
            Text embeddings tensor
        """
        if isinstance(prompt, str):
            prompt = [prompt]

        batch_size = len(prompt) * num_images_per_prompt

        # Encode positive prompt
        text_embeddings = self.text_encoder(prompt, device=self.device)
        text_embeddings = text_embeddings.to(self.dtype)

        # Encode negative prompt if provided
        if negative_prompt is not None:
            if isinstance(negative_prompt, str):
                negative_prompt = [negative_prompt]

            uncond_embeddings = self.text_encoder(negative_prompt, device=self.device)
            uncond_embeddings = uncond_embeddings.to(self.dtype)

            # Concatenate for classifier-free guidance
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings

    @torch.no_grad()
    def decode_latents(self, latents: torch.Tensor) -> Image.Image:
        """
        Decode a latent batch to a PIL image (first element of the batch).

        Args:
            latents: Latent space tensor

        Returns:
            PIL Image
        """
        # Undo the SD-style latent scaling — assumes the VAE was trained
        # with the 0.18215 factor (TODO confirm for this VAE).
        latents = 1 / 0.18215 * latents

        # Decode through VAE
        image = self.vae.decode(latents)
        image = torch.clamp(image, -1, 1)

        # [-1, 1] -> [0, 255] uint8, channels-last, first batch element.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
        image = (image * 255).round().astype("uint8")

        return Image.fromarray(image)

    @torch.no_grad()
    def prepare_latents(
        self,
        batch_size: int,
        height: int,
        width: int,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        """
        Initialize random noise latents at 1/8 spatial resolution.

        Args:
            batch_size: Number of images to generate
            height: Image height
            width: Image width
            generator: Random number generator

        Returns:
            Initial noise tensor of shape (batch, 4, height//8, width//8)
        """
        shape = (batch_size, 4, height // 8, width // 8)
        # Fix: sample directly on the target device so a device-bound
        # generator (torch.Generator(device=...)) matches the tensor device
        # instead of sampling on CPU and moving afterwards.
        latents = torch.randn(shape, generator=generator, dtype=self.dtype, device=self.device)

        # Scale the initial noise if the scheduler defines a scale factor.
        if hasattr(self.scheduler, 'init_noise_scale'):
            latents = latents * self.scheduler.init_noise_scale

        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        output_type: str = "pil",
        return_dict: bool = False,
    ) -> Union[List[Image.Image], tuple]:
        """
        Generate images from text prompts

        Args:
            prompt: Text prompt(s) for generation
            negative_prompt: Negative prompt for better quality
            height: Output image height
            width: Output image width
            num_inference_steps: Number of denoising steps
            guidance_scale: Classifier-free guidance scale
            eta: DDIM eta parameter
            generator: Random number generator
            output_type: Output format ("pil" or "tensor")
            return_dict: Whether to return as dictionary

        Returns:
            Generated images or tuple
        """
        if negative_prompt is None:
            negative_prompt = ""

        batch_size = 1 if isinstance(prompt, str) else len(prompt)

        # Fix: the original keyed the latent duplication on the *truthiness*
        # of negative_prompt while encode_prompt keyed the embedding
        # concatenation on `is not None`, so the default empty-string
        # negative prompt produced a 2x embedding batch against a 1x latent
        # batch. Drive both paths from one flag instead.
        do_classifier_free_guidance = guidance_scale > 1.0

        text_embeddings = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt if do_classifier_free_guidance else None,
            num_images_per_prompt=1,
        )

        # Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps

        # Prepare latents
        latents = self.prepare_latents(
            batch_size=batch_size,
            height=height,
            width=width,
            generator=generator,
        )

        # Denoising loop
        for i, t in enumerate(timesteps):
            # Duplicate latents so uncond/cond share one UNet forward pass.
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # Predict noise
            timestep_tensor = torch.tensor([t], dtype=torch.long, device=self.device)
            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=timestep_tensor,
                encoder_hidden_states=text_embeddings,
            )

            # Apply classifier-free guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Compute previous noisy sample
            latents, _ = self.scheduler.step(
                model_output=noise_pred,
                timestep=t,
                sample=latents,
                eta=eta,
                generator=generator,
            )

            # Print progress
            if (i + 1) % 10 == 0 or i == len(timesteps) - 1:
                print(f"Step {i+1}/{len(timesteps)}")

        # Decode to image
        image = self.decode_latents(latents)

        if output_type != "pil":
            return (image,) if not return_dict else {"images": [image]}

        return [image] if not return_dict else {"images": [image]}

    def enable_memory_efficient_mode(self):
        """Enable memory-efficient mode for CPU inference"""
        # Clear CUDA cache if available
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Force garbage collection
        gc.collect()

        print("Memory efficient mode enabled")

    def optimize_for_cpu(self, threads: int = -1):
        """
        Optimize pipeline for CPU inference

        Args:
            threads: Number of threads to use (-1 for all available)
        """
        if threads > 0:
            torch.set_num_threads(threads)

        # -1 means: use every available core.
        if threads == -1:
            import os
            torch.set_num_threads(os.cpu_count())

        print(f"Optimized for CPU with {torch.get_num_threads()} threads")
281
+
282
+
283
def create_pipeline(config, device: str = "cpu"):
    """
    Factory function to create complete pipeline from config

    Args:
        config: Configuration dictionary
        device: Target device

    Returns:
        ByteDreamPipeline instance
    """
    from .model import create_unet, create_vae, create_text_encoder
    from .scheduler import create_scheduler

    # Assemble the individual components, then wire them into a pipeline.
    return ByteDreamPipeline(
        text_encoder=create_text_encoder(config),
        vae=create_vae(config),
        unet=create_unet(config),
        scheduler=create_scheduler(config),
        device=device,
    )
bytedream/scheduler.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Byte Dream Diffusion Scheduler
3
+ Implements DDIM (Denoising Diffusion Implicit Models) sampling for fast, high-quality generation
4
+ """
5
+
6
+ import torch
7
+ import numpy as np
8
+ from typing import Optional, Tuple, Union
9
+ import math
10
+
11
+
12
class DDIMScheduler:
    """
    DDIM Scheduler for diffusion sampling
    Provides deterministic sampling with fewer steps than traditional DDPM
    Optimized for CPU inference with efficient computation

    Fix: `step()` now uses the variance-corrected direction coefficient
    sqrt(1 - alpha_prev - sigma^2) from DDIM Eq. 12; the previous
    sqrt(1 - alpha_prev) term double-counted noise whenever eta > 0.
    """

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        beta_schedule: str = "scaled_linear",
        clip_sample: bool = False,
        set_alpha_to_one: bool = False,
    ):
        self.num_train_timesteps = num_train_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.beta_schedule = beta_schedule
        self.clip_sample = clip_sample
        self.set_alpha_to_one = set_alpha_to_one

        # Compute betas
        if beta_schedule == "scaled_linear":
            # Linear in sqrt(beta) space (Stable Diffusion convention).
            self.betas = torch.linspace(
                beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype=torch.float32
            ) ** 2
        elif beta_schedule == "linear":
            self.betas = torch.linspace(
                beta_start, beta_end, num_train_timesteps, dtype=torch.float32
            )
        else:
            raise ValueError(f"Unknown beta schedule: {beta_schedule}")

        # Compute alphas
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # alpha_bar used for the "step before t=0"; optionally pinned to 1.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # Populated by set_timesteps(); step() must not be called before that.
        self.timesteps = None

    def set_timesteps(self, num_inference_steps: int) -> None:
        """
        Set timesteps for inference

        Produces `num_inference_steps` evenly spaced training timesteps,
        ordered from high noise to low noise.

        Args:
            num_inference_steps: Number of denoising steps
        """
        step_ratio = self.num_train_timesteps // num_inference_steps
        self.timesteps = (
            (torch.arange(0, num_inference_steps) * step_ratio)
            .round()
            .flip(dims=[0])
            .to(dtype=torch.long)
        )

    def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:
        """Compute the DDIM posterior variance for the step t -> prev_t."""
        alpha_prod_t = self.alphas_cumprod[timestep]
        # Negative prev_timestep means "before t=0": use the final alpha_bar.
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
        return variance

    def step(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Perform a single denoising step

        Args:
            model_output: Predicted noise from UNet
            timestep: Current timestep
            sample: Current noisy sample
            eta: DDIM eta parameter (0 for deterministic)
            use_clipped_model_output: Whether to clip model output
            generator: Random number generator

        Returns:
            Tuple of (previous_sample, pred_original_sample)
        """
        # Previous timestep on the evenly spaced inference schedule.
        prev_timestep = timestep - self.num_train_timesteps // len(self.timesteps)

        # Compute alpha and sigma
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t

        # Predicted x_0 from the epsilon parameterization.
        pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5

        if use_clipped_model_output:
            pred_original_sample = torch.clamp(pred_original_sample, -1, 1)

        # sigma_t = eta * sqrt(posterior variance); zero for deterministic DDIM.
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance ** 0.5

        # Direction pointing to x_t (DDIM Eq. 12): the noise coefficient must
        # account for sigma^2 so total variance stays correct when eta > 0.
        # Clamp guards against tiny negative values from float round-off.
        direction_coeff = torch.clamp(1 - alpha_prod_t_prev - std_dev_t ** 2, min=0.0) ** 0.5

        # x_{t-1} = sqrt(alpha_prev) * x_0 + direction + noise
        prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + direction_coeff * model_output

        # Add noise if eta > 0
        if eta > 0:
            noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(model_output.device)
            prev_sample = prev_sample + std_dev_t * noise

        # Clip if necessary
        if self.clip_sample:
            prev_sample = torch.clamp(prev_sample, -1, 1)

        return prev_sample, pred_original_sample

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """
        Add noise to samples (forward diffusion process)

        Args:
            original_samples: Original clean samples
            noise: Noise to add
            timesteps: Timesteps for each sample

        Returns:
            Noisy samples
        """
        # Broadcast per-sample alpha_bar over (B, C, H, W).
        alpha_prod_t = self.alphas_cumprod[timesteps].view(-1, 1, 1, 1)
        sqrt_alpha_prod = alpha_prod_t ** 0.5
        sqrt_one_minus_alpha_prod = (1 - alpha_prod_t) ** 0.5

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def scale_model_input(
        self,
        sample: torch.Tensor,
        timestep: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Scale sample by standard deviation (for compatibility)

        DDIM needs no input scaling; this is an identity pass-through so the
        scheduler matches the common pipeline interface.

        Args:
            sample: Input sample
            timestep: Current timestep

        Returns:
            Scaled sample
        """
        return sample

    def get_scalings_for_boundary_condition_discrete(self, timestep):
        """Get scalings for boundary condition"""
        alpha_prod_t = self.alphas_cumprod[timestep]
        # Simplified from ((1 - a) * a / a) ** 0.5 == (1 - a) ** 0.5.
        sigma_t = (1 - alpha_prod_t) ** 0.5
        c_out = -sigma_t
        c_in = 1 / (alpha_prod_t ** 0.5)
        return c_out, c_in
187
+
188
+
189
class EulerDiscreteScheduler:
    """
    Euler discretization scheduler for ODE-based sampling
    Alternative to DDIM with different sampling characteristics
    """

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        beta_schedule: str = "scaled_linear",
    ):
        # NOTE(review): `beta_schedule` is stored but never consulted below —
        # the beta curve is always the scaled-linear one; confirm this is
        # intentional before exposing other schedule names.
        self.num_train_timesteps = num_train_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.beta_schedule = beta_schedule

        # Compute betas and sigmas
        betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps) ** 2
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)

        # sigmas[0] = 1.0 sentinel, then sigma_t = sqrt((1 - a_bar)/a_bar) per
        # training step, so len(self.sigmas) == num_train_timesteps + 1.
        self.sigmas = torch.cat([
            torch.ones(1),
            ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        ])

        # Populated by set_timesteps(); step() must not be called before that.
        self.timesteps = None

    def set_timesteps(self, num_inference_steps: int) -> None:
        """Set timesteps for Euler sampling"""
        # Evenly spaced indices into self.sigmas, flipped so sampling walks
        # from the highest-noise index down to 0.
        step_ratio = len(self.sigmas) // num_inference_steps
        self.timesteps = torch.arange(0, num_inference_steps) * step_ratio
        self.timesteps = self.timesteps.flip(0)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
    ) -> torch.Tensor:
        """Perform Euler step

        Args:
            model_output: UNet prediction for this step (currently unused —
                see NOTE below).
            timestep: Index into self.sigmas for the current noise level.
            sample: Current noisy latent.

        Returns:
            The sample advanced by one Euler step.
        """
        sigma_from = self.sigmas[timestep]
        # Fall back to sigma = 0 past the end of the table.
        sigma_to = self.sigmas[timestep + 1] if timestep + 1 < len(self.sigmas) else torch.tensor(0.0)

        # NOTE(review): `model_output` is never used; the derivative is built
        # purely from `sample`, which does not match the usual Euler update
        # d = (sample - denoised) / sigma with `denoised` from the model.
        # Also, sigmas increase with index while timesteps run high -> low, so
        # dt = sigma_to - sigma_from is positive here — verify the sign and
        # indexing against the intended ODE direction before relying on this
        # sampler for real generation.
        sample_normalized = sample / ((sigma_from ** 2 + 1) ** 0.5)
        derivative = (sample - sample_normalized) / sigma_from

        dt = sigma_to - sigma_from
        prev_sample = sample + derivative * dt

        return prev_sample
242
+
243
+
244
def create_scheduler(config):
    """
    Factory function to create scheduler from config

    Args:
        config: Configuration dictionary with a model.scheduler section

    Returns:
        Scheduler instance

    Raises:
        ValueError: If the configured scheduler name is not recognized.
    """
    opts = config['model']['scheduler']

    # Parameters shared by every scheduler type.
    common = dict(
        num_train_timesteps=opts['num_train_timesteps'],
        beta_start=opts['beta_start'],
        beta_end=opts['beta_end'],
        beta_schedule=opts['beta_schedule'],
    )

    if opts['name'] == 'DDIM':
        return DDIMScheduler(
            clip_sample=opts['clip_sample'],
            set_alpha_to_one=opts['set_alpha_to_one'],
            **common,
        )
    if opts['name'] == 'EulerDiscrete':
        return EulerDiscreteScheduler(**common)

    raise ValueError(f"Unknown scheduler: {opts['name']}")
bytedream/utils.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Byte Dream Utilities
3
+ Helper functions for image processing, model management, and optimization
4
+ """
5
+
6
+ import torch
7
+ import numpy as np
8
+ from PIL import Image
9
+ from pathlib import Path
10
+ import hashlib
11
+ import json
12
+ from typing import Optional, Tuple, List
13
+
14
+
15
def load_image(image_path: str) -> Image.Image:
    """
    Open an image file and return it as an RGB PIL image.

    Args:
        image_path: Path to image file

    Returns:
        PIL Image object

    Raises:
        FileNotFoundError: If the path does not exist.
        IOError: If the file exists but cannot be opened/decoded.
    """
    source = Path(image_path)

    if not source.exists():
        raise FileNotFoundError(f"Image not found: {image_path}")

    try:
        # Force RGB so downstream code never sees palette/alpha modes.
        return Image.open(source).convert('RGB')
    except Exception as exc:
        raise IOError(f"Error loading image: {exc}")
35
+
36
+
37
def save_image(
    image: Image.Image,
    output_path: str,
    format: str = None,
    quality: int = 95,
):
    """
    Write *image* to *output_path*, creating parent directories as needed.

    Args:
        image: PIL Image to save
        output_path: Output file path
        format: Explicit format name; inferred from the extension when None
        quality: JPEG quality (1-100); ignored for non-JPEG formats
    """
    target = Path(output_path)
    target.parent.mkdir(parents=True, exist_ok=True)

    # Infer the format from the file extension ('.jpg' maps to 'JPEG').
    if format is None:
        format = target.suffix.upper().replace('.', '')
        if format == 'JPG':
            format = 'JPEG'

    # Only JPEG accepts a quality setting.
    if format == 'JPEG':
        image.save(target, format=format, quality=quality, optimize=True)
    else:
        image.save(target, format=format, optimize=True)

    print(f"Image saved to: {target}")
68
+
69
+
70
def resize_image(
    image: Image.Image,
    width: Optional[int] = None,
    height: Optional[int] = None,
    maintain_aspect: bool = True,
) -> Image.Image:
    """
    Resize *image* to the requested dimensions.

    Args:
        image: Input image
        width: Target width
        height: Target height
        maintain_aspect: When True, scale to fit rather than stretch

    Returns:
        Resized PIL Image (the original object when no size is given)
    """
    src_w, src_h = image.size

    # Nothing requested: hand the original back untouched.
    if width is None and height is None:
        return image

    if not maintain_aspect:
        # Stretch directly; missing dimensions keep their original value.
        target = (width if width else src_w, height if height else src_h)
    elif width and height:
        # Fit inside the width x height bounding box.
        scale = min(width / src_w, height / src_h)
        target = (int(src_w * scale), int(src_h * scale))
    elif width:
        target = (width, int(src_h * (width / src_w)))
    else:
        target = (int(src_w * (height / src_h)), height)

    # LANCZOS gives the best quality for downscaling.
    return image.resize(target, Image.Resampling.LANCZOS)
113
+
114
+
115
def center_crop(image: Image.Image, width: int, height: int) -> Image.Image:
    """
    Crop a width x height region from the center of *image*.

    Args:
        image: Input image
        width: Crop width
        height: Crop height

    Returns:
        Cropped PIL Image
    """
    src_w, src_h = image.size

    # Offset so the crop box is centered on the source image.
    x0 = (src_w - width) // 2
    y0 = (src_h - height) // 2

    return image.crop((x0, y0, x0 + width, y0 + height))
136
+
137
+
138
def image_to_tensor(image: Image.Image) -> torch.Tensor:
    """
    Convert a PIL image to a CHW float tensor.

    Args:
        image: PIL Image

    Returns:
        Normalized tensor in range [-1, 1], shape (C, H, W)
    """
    # uint8 [0, 255] -> float [0, 1].
    pixels = np.array(image).astype(np.float32) / 255.0

    # [0, 1] -> [-1, 1] (diffusion model convention).
    pixels = pixels * 2.0 - 1.0

    # HWC -> CHW for PyTorch.
    return torch.from_numpy(pixels).permute(2, 0, 1)
161
+
162
+
163
def tensor_to_image(tensor: torch.Tensor) -> Image.Image:
    """
    Convert a PyTorch tensor to a PIL image.

    Args:
        tensor: Tensor in range [-1, 1], shape (B, C, H, W) or (C, H, W);
            only the first sample of a batch is converted.

    Returns:
        PIL Image
    """
    # Drop the batch dimension, keeping the first sample.
    if tensor.dim() == 4:
        tensor = tensor[0]

    # CHW -> HWC numpy array on the CPU.
    pixels = tensor.cpu().numpy().transpose(1, 2, 0)

    # Clamp out-of-range values, then map [-1, 1] -> [0, 255].
    pixels = np.clip(pixels, -1, 1)
    pixels = ((pixels + 1.0) * 127.5).round().astype(np.uint8)

    # Replicate single-channel (grayscale) data into RGB.
    if pixels.shape[2] == 1:
        pixels = np.repeat(pixels, 3, axis=2)

    return Image.fromarray(pixels)
192
+
193
+
194
def generate_prompt_hash(prompt: str) -> str:
    """
    Generate unique hash for a prompt.

    Args:
        prompt: Text prompt

    Returns:
        First 8 hex characters of the prompt's MD5 digest
    """
    # MD5 is fine here: this is an identifier, not a security primitive.
    return hashlib.md5(prompt.encode()).hexdigest()[:8]
206
+
207
+
208
def get_model_statistics(model: torch.nn.Module) -> dict:
    """
    Get model parameter statistics.

    Args:
        model: PyTorch model

    Returns:
        Dictionary with total/trainable/non-trainable parameter counts and
        the combined parameter+buffer size in megabytes.
    """
    params = list(model.parameters())
    total = sum(p.numel() for p in params)
    trainable = sum(p.numel() for p in params if p.requires_grad)

    # Bytes occupied by parameters plus registered buffers.
    byte_count = sum(p.numel() * p.element_size() for p in params)
    byte_count += sum(b.numel() * b.element_size() for b in model.buffers())

    return {
        'total_parameters': total,
        'trainable_parameters': trainable,
        'non_trainable_parameters': total - trainable,
        'model_size_mb': round(byte_count / 1024 ** 2, 2),
    }
239
+
240
+
241
def optimize_memory_usage(device: str = "cpu"):
    """
    Optimize memory usage for inference.

    Empties the CUDA cache (if available), forces a garbage-collection pass,
    and applies CPU-specific environment tweaks.

    Args:
        device: Target device ("cpu" enables the CPU-specific settings)
    """
    import gc

    # Clear CUDA cache if available
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Force garbage collection
    gc.collect()

    if device == "cpu":
        # Work around duplicate OpenMP runtime loads (common with MKL).
        # Fix: narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are no longer swallowed.
        try:
            import os
            os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
        except Exception:
            pass

    print("Memory optimization applied")
267
+
268
+
269
def set_seed(seed: int):
    """
    Seed all random number generators for reproducibility.

    Covers NumPy, torch CPU, and every CUDA device (when present).

    Args:
        seed: Random seed value
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
280
+
281
+
282
def validate_prompt(prompt: str) -> Tuple[bool, str]:
    """
    Validate and sanitize a text prompt.

    Rejects empty/whitespace-only prompts, over-long prompts, and prompts
    containing blocklisted terms.

    Args:
        prompt: Input prompt

    Returns:
        Tuple of (is_valid, message)
    """
    if not prompt or not prompt.strip():
        return False, "Prompt cannot be empty"

    if len(prompt) > 1000:
        return False, "Prompt too long (max 1000 characters)"

    # Blocklist is currently empty; populate it to enable content filtering.
    forbidden_terms = []
    lowered = prompt.lower()
    for term in forbidden_terms:
        if term.lower() in lowered:
            return False, f"Prompt contains forbidden term: {term}"

    return True, "Valid prompt"
305
+
306
+
307
def create_image_grid(
    images: List[Image.Image],
    rows: int = None,
    cols: int = None,
) -> Image.Image:
    """
    Arrange a list of images on a white grid canvas.

    Missing row/column counts are derived to give a near-square layout.
    Every cell uses the size of the first image.

    Args:
        images: List of PIL Images
        rows: Number of rows
        cols: Number of columns

    Returns:
        Grid image

    Raises:
        ValueError: If *images* is empty.
    """
    if not images:
        raise ValueError("No images provided")

    count = len(images)

    # Fill in whichever grid dimensions were not supplied.
    if rows is None and cols is None:
        cols = int(np.ceil(np.sqrt(count)))
        rows = int(np.ceil(count / cols))
    elif rows is None:
        rows = int(np.ceil(count / cols))
    elif cols is None:
        cols = int(np.ceil(count / rows))

    # Cell size comes from the first image.
    cell_w, cell_h = images[0].size
    canvas = Image.new('RGB', (cols * cell_w, rows * cell_h), color='white')

    # Paste row-major: index i lands at (i % cols, i // cols).
    for idx, img in enumerate(images):
        canvas.paste(img, ((idx % cols) * cell_w, (idx // cols) * cell_h))

    return canvas
354
+
355
+
356
def get_device_info() -> dict:
    """
    Get device information.

    Returns:
        Dictionary with CUDA availability, GPU count, CPU core count, and
        (when CUDA is present) current device details.
    """
    cuda = torch.cuda.is_available()
    info = {
        'cuda_available': cuda,
        'device_count': torch.cuda.device_count() if cuda else 0,
        'cpu_cores': __import__('os').cpu_count(),
    }

    if cuda:
        info['current_device'] = torch.cuda.current_device()
        info['device_name'] = torch.cuda.get_device_name(0)
        info['cuda_version'] = torch.version.cuda

    return info
375
+
376
+
377
class ProgressTracker:
    """Track progress of long-running operations.

    Fix: `__str__` previously divided by `self.total` without a zero guard,
    so rendering a tracker created with total == 0 raised ZeroDivisionError
    even though `get_progress()` handled that case.
    """

    def __init__(self, total: int, description: str = ""):
        self.total = total              # expected number of steps
        self.current = 0                # steps completed so far
        self.description = description  # label shown before the bar

    def update(self, n: int = 1):
        """Advance the counter by *n* completed steps."""
        self.current += n

    def get_progress(self) -> float:
        """Return completion percentage (0 when total is zero/unknown)."""
        return (self.current / self.total) * 100 if self.total > 0 else 0

    def __str__(self):
        percent = self.get_progress()
        bar_length = 30
        # Guard the integer division the same way get_progress() does.
        if self.total > 0:
            filled_length = int(bar_length * self.current // self.total)
        else:
            filled_length = 0
        bar = '█' * filled_length + '-' * (bar_length - filled_length)
        return f"{self.description}: [{bar}] {percent:.1f}% ({self.current}/{self.total})"
config.yaml ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte Dream Configuration
2
+
3
+ model:
4
+ name: "Byte Dream"
5
+ version: "1.0.0"
6
+
7
+ # Model architecture parameters
8
+ unet:
9
+ in_channels: 4
10
+ out_channels: 4
11
+ block_out_channels: [320, 640, 1280, 1280]
12
+ layers_per_block: 2
13
+ attention_head_dim: 8
14
+ cross_attention_dim: 768
15
+ use_linear_projection: true
16
+
17
+ scheduler:
18
+ name: "DDIM" # Options: DDIM, PNDM, LMSDiscrete, EulerDiscrete
19
+ num_train_timesteps: 1000
20
+ beta_start: 0.00085
21
+ beta_end: 0.012
22
+ beta_schedule: "scaled_linear"
23
+ clip_sample: false
24
+ set_alpha_to_one: false
25
+
26
+ vae:
27
+ in_channels: 3
28
+ out_channels: 3
29
+ down_block_types: ["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"]
30
+ up_block_types: ["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"]
31
+ latent_channels: 4
32
+ sample_size: 512
33
+
34
+ text_encoder:
35
+ model: "openai/clip-vit-large-patch14"
36
+ max_length: 77
37
+
38
+ # Generation parameters
39
+ generation:
40
+ width: 512
41
+ height: 512
42
+ num_inference_steps: 50
43
+ guidance_scale: 7.5
44
+ negative_prompt: "ugly, blurry, low quality, distorted, deformed"
45
+ seed: null # null for random, or set integer
46
+
47
+ # CPU Optimization
48
+ cpu_optimization:
49
+ use_openvino: false
50
+ use_onnx: false
51
+ precision: "fp32" # fp32 or fp16
52
+ threads: -1 # -1 for all available threads
53
+ memory_limit: null # null for auto, or MB value
54
+
55
+ # Training parameters
56
+ training:
57
+ dataset_path: "./dataset"
58
+ output_dir: "./models/bytedream"
59
+ epochs: 100
60
+ batch_size: 4
61
+ gradient_accumulation_steps: 1
62
+ learning_rate: 1e-5
63
+ lr_scheduler: "constant_with_warmup"
64
+ lr_warmup_steps: 500
65
+ max_grad_norm: 1.0
66
+ mixed_precision: "no" # no, fp16, bf16
67
+
68
+ # Data augmentation
69
+ random_flip: true
70
+ random_crop: false
71
+ center_crop: true
72
+
73
+ # Logging
74
+ logging_dir: "./logs"
75
+ log_every_n_steps: 10
76
+
77
+ # Hugging Face
78
+ huggingface:
79
+ organization: "" # Your HF username/organization
80
+ private: false
81
+ push_to_hub: true
environment.yml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: bytedream
2
+ channels:
3
+ - pytorch
4
+ - conda-forge
5
+ - defaults
6
+ dependencies:
7
+ - python=3.10
8
+ - pip
9
+ - pip:
10
+ - transformers>=4.35.0
11
+ - diffusers>=0.24.0
12
+ - torch>=2.1.0
13
+ - torchaudio>=2.1.0
14
+ - accelerate>=0.25.0
15
+ - numpy>=1.24.0
16
+ - pillow>=10.0.0
17
+ - opencv-python>=4.8.0
18
+ - safetensors>=0.4.0
19
+ - huggingface_hub>=0.19.0
20
+ - gradio>=4.0.0
21
+ - tqdm>=4.66.0
22
+ - pyyaml>=6.0
23
+ - matplotlib>=3.8.0
24
+ - scipy>=1.11.0
25
+ - einops>=0.7.0
examples.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Byte Dream - Example Usage Scripts
3
+ Practical examples for different use cases
4
+ """
5
+
6
+ from bytedream import ByteDreamGenerator
7
+ from pathlib import Path
8
+
9
+
10
def example_basic_generation():
    """Basic image generation example"""
    print("\n" + "="*60)
    print("Example 1: Basic Generation")
    print("="*60)

    gen = ByteDreamGenerator()

    # Defaults everywhere: just a prompt.
    result = gen.generate(
        prompt="A beautiful sunset over mountains, digital art",
    )

    result.save("example_basic.png")
    print("✓ Saved to: example_basic.png")
25
+
26
+
27
def example_advanced_parameters():
    """Advanced parameter tuning"""
    print("\n" + "="*60)
    print("Example 2: Advanced Parameters")
    print("="*60)

    gen = ByteDreamGenerator()

    # Explicit settings: size, steps, guidance, and a fixed seed.
    settings = dict(
        prompt="Cyberpunk city at night, neon lights, futuristic architecture",
        negative_prompt="ugly, blurry, low quality, distorted, dark",
        width=768,
        height=768,
        num_inference_steps=75,
        guidance_scale=9.0,
        seed=42,
    )
    result = gen.generate(**settings)

    result.save("example_advanced.png")
    print("✓ Saved to: example_advanced.png")
48
+
49
+
50
def example_batch_generation():
    """Generate multiple images in one batch and save an overview grid.

    Fix: the per-image log line previously printed the literal text
    "(unknown)" — a placeholder-less f-string — instead of the filename.
    """
    print("\n" + "="*60)
    print("Example 3: Batch Generation")
    print("="*60)

    generator = ByteDreamGenerator()

    prompts = [
        "Fantasy landscape with castle and waterfall, epic scenery",
        "Underwater coral reef, tropical fish, sunlight through water",
        "Space nebula, colorful clouds, stars, cosmic scene",
        "Medieval knight in armor, dramatic lighting, portrait",
        "Japanese garden, cherry blossoms, peaceful atmosphere",
    ]

    images = generator.generate_batch(
        prompts=prompts,
        width=512,
        height=512,
        num_inference_steps=50,
        guidance_scale=7.5,
    )

    # Save each result individually.
    for i, image in enumerate(images):
        filename = f"batch_{i+1}.png"
        image.save(filename)
        print(f"✓ Saved: {filename}")

    # Assemble one grid image for quick review.
    from bytedream.utils import create_image_grid
    grid = create_image_grid(images)
    grid.save("batch_grid.png")
    print("✓ Grid saved to: batch_grid.png")
85
+
86
+
87
def example_artistic_styles():
    """Generate one image per artistic style.

    Fix: the per-style log line previously printed the literal text
    "(unknown)" — a placeholder-less f-string — instead of the filename.
    """
    print("\n" + "="*60)
    print("Example 4: Artistic Styles")
    print("="*60)

    generator = ByteDreamGenerator()

    style_prompts = [
        ("Oil Painting", "Portrait of a woman, oil painting style, brush strokes, classical art"),
        ("Watercolor", "Forest landscape, watercolor painting, soft colors, artistic"),
        ("Digital Art", "Sci-fi spaceship, digital art, concept art, highly detailed"),
        ("Sketch", "City skyline, pencil sketch, black and white, drawing"),
        ("Abstract", "Emotions and dreams, abstract art, colorful shapes, surreal"),
    ]

    for style_name, prompt in style_prompts:
        print(f"\nGenerating {style_name}...")

        image = generator.generate(
            prompt=prompt,
            num_inference_steps=50,
            guidance_scale=7.5,
        )

        # e.g. "Oil Painting" -> "style_oil_painting.png"
        filename = f"style_{style_name.lower().replace(' ', '_')}.png"
        image.save(filename)
        print(f"✓ Saved: {filename}")
115
+
116
+
117
def example_resolutions():
    """Generate the same prompt at several output resolutions.

    Fix: the per-resolution log line previously printed the literal text
    "(unknown)" — a placeholder-less f-string — instead of the filename.
    """
    print("\n" + "="*60)
    print("Example 5: Different Resolutions")
    print("="*60)

    generator = ByteDreamGenerator()

    base_prompt = "Majestic mountain range, snow peaks, blue sky"

    resolutions = [
        (256, 256),
        (512, 512),
        (768, 768),
        (512, 768),  # Portrait
        (768, 512),  # Landscape
    ]

    for width, height in resolutions:
        print(f"\nGenerating {width}x{height}...")

        image = generator.generate(
            prompt=base_prompt,
            width=width,
            height=height,
            num_inference_steps=40,
        )

        filename = f"res_{width}x{height}.png"
        image.save(filename)
        print(f"✓ Saved: {filename}")
148
+
149
+
150
def example_reproducibility():
    """Demonstrate reproducibility with seeds"""
    print("\n" + "="*60)
    print("Example 6: Reproducibility with Seeds")
    print("="*60)

    gen = ByteDreamGenerator()

    prompt = "A mystical forest with glowing mushrooms, fantasy art"

    # Same seed twice -> identical outputs.
    print("\nGenerating with seed=123...")
    gen.generate(prompt=prompt, seed=123).save("repro_1.png")

    print("Generating again with seed=123...")
    gen.generate(prompt=prompt, seed=123).save("repro_2.png")

    print("\nBoth images should be identical!")
    print("✓ Check repro_1.png and repro_2.png")

    # A different seed gives a different image.
    print("\nGenerating with seed=456...")
    gen.generate(prompt=prompt, seed=456).save("repro_3.png")
    print("This one will be different!")
186
+
187
+
188
def example_negative_prompts():
    """Using negative prompts effectively"""
    print("\n" + "="*60)
    print("Example 7: Negative Prompts")
    print("="*60)

    gen = ByteDreamGenerator()

    base_prompt = "Beautiful princess, elegant dress, castle background"

    # Baseline: same seed, no negative prompt.
    print("\nWithout negative prompt...")
    gen.generate(prompt=base_prompt, seed=789).save("no_negative.png")

    # Same seed again so only the negative prompt differs.
    print("With negative prompt...")
    gen.generate(
        prompt=base_prompt,
        negative_prompt="ugly, deformed, noisy, blurry, bad anatomy, poorly drawn",
        seed=789,
    ).save("with_negative.png")

    print("\nCompare no_negative.png and with_negative.png")
216
+
217
+
218
def example_quick_preview():
    """Quick low-resolution previews"""
    print("\n" + "="*60)
    print("Example 8: Quick Preview Mode")
    print("="*60)

    gen = ByteDreamGenerator()

    prompt = "Dragon breathing fire, epic fantasy battle scene"

    # Small size + few steps: a fast draft of the composition.
    print("Generating quick preview (256x256, 20 steps)...")
    gen.generate(
        prompt=prompt,
        width=256,
        height=256,
        num_inference_steps=20,
    ).save("preview.png")
    print("✓ Preview saved")

    # Large size + many steps: the final render.
    print("\nGenerating full quality (768x768, 75 steps)...")
    gen.generate(
        prompt=prompt,
        width=768,
        height=768,
        num_inference_steps=75,
    ).save("full_quality.png")
    print("✓ Full quality saved")
249
+
250
+
251
def run_all_examples():
    """Run all examples sequentially"""
    print("\n" + "="*60)
    print("Byte Dream - Complete Examples Suite")
    print("="*60)

    suite = (
        example_basic_generation,
        example_advanced_parameters,
        example_batch_generation,
        example_artistic_styles,
        example_resolutions,
        example_reproducibility,
        example_negative_prompts,
        example_quick_preview,
    )

    for fn in suite:
        # Keep going even if a single example blows up.
        try:
            fn()
        except Exception as e:
            print(f"\n✗ Error in {fn.__name__}: {e}")
            import traceback
            traceback.print_exc()

        print("\n" + "-"*60)
        input("Press Enter to continue to next example...")

    print("\n" + "="*60)
    print("All examples completed!")
    print("="*60)
282
+
283
+
284
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Byte Dream Examples")
    # One boolean flag per runnable example (plus --all).
    for flag, text in [
        ("--all", "Run all examples"),
        ("--basic", "Run basic example"),
        ("--advanced", "Run advanced example"),
        ("--batch", "Run batch generation"),
        ("--styles", "Run artistic styles"),
    ]:
        parser.add_argument(flag, action="store_true", help=text)

    args = parser.parse_args()

    # First selected flag wins, matching the original if/elif chain.
    dispatch = [
        (args.all, run_all_examples),
        (args.basic, example_basic_generation),
        (args.advanced, example_advanced_parameters),
        (args.batch, example_batch_generation),
        (args.styles, example_artistic_styles),
    ]
    for selected, handler in dispatch:
        if selected:
            handler()
            break
    else:
        # No flag given: show the menu.
        print("\nByte Dream Examples")
        print("="*60)
        print("Choose an example:")
        print(" --basic : Basic generation")
        print(" --advanced : Advanced parameters")
        print(" --batch : Batch generation")
        print(" --styles : Artistic styles")
        print(" --all : Run all examples")
        print("\nOr just run without arguments to see all examples")
infer.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Byte Dream - Command Line Inference Tool
3
+ Generate images from text prompts using the command line
4
+ """
5
+
6
+ import argparse
7
+ from pathlib import Path
8
+ import torch
9
+
10
+
11
def main():
    """Command-line entry point: parse options, generate one image, save it."""
    parser = argparse.ArgumentParser(
        description="Byte Dream - AI Image Generation",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Basic usage
  python infer.py --prompt "A beautiful sunset over mountains"

  # With custom parameters
  python infer.py --prompt "Cyberpunk city" --negative "blurry" --steps 75 --guidance 8.0

  # Specify output and size
  python infer.py --prompt "Fantasy landscape" --output fantasy.png --width 768 --height 768

  # With seed for reproducibility
  python infer.py --prompt "Dragon" --seed 42 --output dragon.png
"""
    )

    parser.add_argument("--prompt", "-p", type=str, required=True,
                        help="Text prompt describing the desired image")
    parser.add_argument("--negative", "-n", type=str, default="",
                        help="Negative prompt - things to avoid in the image")
    parser.add_argument("--output", "-o", type=str, default="output.png",
                        help="Output image filename (default: output.png)")
    parser.add_argument("--width", "-W", type=int, default=512,
                        help="Image width in pixels (default: 512)")
    parser.add_argument("--height", "-H", type=int, default=512,
                        help="Image height in pixels (default: 512)")
    parser.add_argument("--steps", "-s", type=int, default=50,
                        help="Number of inference steps (default: 50)")
    parser.add_argument("--guidance", "-g", type=float, default=7.5,
                        help="Guidance scale - how closely to follow prompt (default: 7.5)")
    parser.add_argument("--seed", type=int, default=None,
                        help="Random seed for reproducibility (default: random)")
    parser.add_argument("--model", "-m", type=str, default=None,
                        help="Path to model directory (default: uses config)")
    parser.add_argument("--config", "-c", type=str, default="config.yaml",
                        help="Path to config file (default: config.yaml)")
    parser.add_argument("--device", type=str, default="cpu",
                        help="Device to run on: cpu or cuda (default: cpu)")

    args = parser.parse_args()

    # Imported lazily so `--help` works without loading the model stack.
    from bytedream.generator import ByteDreamGenerator

    banner = "=" * 60
    print(banner)
    print("Byte Dream - AI Image Generator")
    print(banner)

    generator = ByteDreamGenerator(
        model_path=args.model,
        config_path=args.config,
        device=args.device,
    )

    # Show which checkpoint/device is actually in use before generating
    model_info = generator.get_model_info()
    print(f"\nModel: {model_info['name']} v{model_info['version']}")
    print(f"Device: {model_info['device']}")
    print(f"Parameters: {model_info['unet_parameters']}")
    print(banner)

    # Empty --negative means "no negative prompt" rather than the empty string
    result = generator.generate(
        prompt=args.prompt,
        negative_prompt=args.negative if args.negative else None,
        width=args.width,
        height=args.height,
        num_inference_steps=args.steps,
        guidance_scale=args.guidance,
        seed=args.seed,
    )

    destination = Path(args.output)
    result.save(destination)
    print(f"\n✓ Image saved to: {destination.absolute()}")
    print(banner)
147
+
148
+
149
if __name__ == "__main__":
    # Script entry point: delegate to the CLI handler above.
    main()
main.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Byte Dream - Main Application Interface
3
+ Simple Python API for image generation
4
+ """
5
+
6
+ from bytedream.generator import ByteDreamGenerator
7
+ from bytedream.utils import (
8
+ load_image,
9
+ save_image,
10
+ resize_image,
11
+ create_image_grid,
12
+ )
13
+ from typing import Optional, List
14
+ from PIL import Image
15
+
16
+
17
class ByteDreamApp:
    """
    High-level application interface for Byte Dream.

    Wraps ByteDreamGenerator with convenience helpers: single-image
    generation with optional saving, multi-prompt batch generation with an
    optional contact-sheet grid, and a placeholder img2img entry point.
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        device: str = "cpu",
        verbose: bool = True,
    ):
        """
        Initialize Byte Dream application.

        Args:
            model_path: Path to model weights (None = resolve via config.yaml)
            device: Device to run on ("cpu" or "cuda")
            verbose: Enable verbose output
        """
        self.verbose = verbose

        if self.verbose:
            print("Initializing Byte Dream Application...")

        self.generator = ByteDreamGenerator(
            model_path=model_path,
            config_path="config.yaml",
            device=device,
        )

        if self.verbose:
            print("✓ Application ready!")

    def generate(
        self,
        prompt: str,
        output_path: str = "output.png",
        negative_prompt: Optional[str] = None,
        width: int = 512,
        height: int = 512,
        steps: int = 50,
        guidance: float = 7.5,
        seed: Optional[int] = None,
        save: bool = True,
    ) -> Image.Image:
        """
        Generate image from prompt and optionally save to file.

        Args:
            prompt: Text description
            output_path: Where to save the image (ignored when save=False)
            negative_prompt: What to avoid
            width: Image width
            height: Image height
            steps: Inference steps
            guidance: Guidance scale
            seed: Random seed (None = random)
            save: Whether to save to file

        Returns:
            Generated PIL Image
        """
        image = self.generator.generate(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            num_inference_steps=steps,
            guidance_scale=guidance,
            seed=seed,
        )

        if save:
            save_image(image, output_path)
            if self.verbose:
                print(f"✓ Image saved to: {output_path}")

        return image

    def generate_multiple(
        self,
        prompts: List[str],
        output_dir: str = "./outputs",
        negative_prompt: Optional[str] = None,
        width: int = 512,
        height: int = 512,
        steps: int = 50,
        guidance: float = 7.5,
        seeds: Optional[List[int]] = None,
        create_grid: bool = True,
    ) -> List[Image.Image]:
        """
        Generate multiple images from prompts, saved as image_NNN.png.

        Args:
            prompts: List of prompts
            output_dir: Directory to save images (created if missing)
            negative_prompt: Negative prompt applied to all images
            width: Image width
            height: Image height
            steps: Inference steps
            guidance: Guidance scale
            seeds: Optional per-image seeds; may be shorter than `prompts`
                (remaining images use a random seed)
            create_grid: Also save a grid.png combining all images

        Returns:
            List of generated images, in prompt order
        """
        from pathlib import Path

        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        images = []

        for i, prompt in enumerate(prompts):
            print(f"\n{'='*60}")
            print(f"Generating image {i+1}/{len(prompts)}")
            print(f"{'='*60}")

            # BUGFIX: guard the index so a seeds list shorter than prompts
            # no longer raises IndexError mid-batch.
            seed = seeds[i] if seeds and i < len(seeds) else None

            image = self.generate(
                prompt=prompt,
                output_path=str(output_path / f"image_{i+1:03d}.png"),
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                steps=steps,
                guidance=guidance,
                seed=seed,
                save=True,
            )

            images.append(image)

        # A single image does not need a contact sheet
        if create_grid and len(images) > 1:
            grid = create_image_grid(images)
            grid_path = output_path / "grid.png"
            grid.save(grid_path)
            print(f"\n✓ Grid saved to: {grid_path}")

        return images

    def img2img(
        self,
        input_image_path: str,
        prompt: str,
        output_path: str = "output_img2img.png",
        strength: float = 0.75,
        negative_prompt: Optional[str] = None,
        steps: int = 50,
        guidance: float = 7.5,
        seed: Optional[int] = None,
    ) -> Image.Image:
        """
        Image-to-image transformation (placeholder for future implementation).

        NOTE: `input_image_path` and `strength` are currently accepted but
        unused; the method falls back to plain text-to-image generation.

        Args:
            input_image_path: Input image path (unused for now)
            prompt: Transformation prompt
            output_path: Output path
            strength: How much to transform, 0-1 (unused for now)
            negative_prompt: Negative prompt
            steps: Inference steps
            guidance: Guidance scale
            seed: Random seed

        Returns:
            Generated image
        """
        print("⚠ img2img functionality will be available in a future update")
        print(" For now, using text-to-image generation only")

        # Fallback: plain txt2img with the same parameters
        return self.generate(
            prompt=prompt,
            output_path=output_path,
            negative_prompt=negative_prompt,
            steps=steps,
            guidance=guidance,
            seed=seed,
        )

    def info(self):
        """Print model information reported by the underlying generator."""
        info = self.generator.get_model_info()

        print("\n" + "="*60)
        print("Byte Dream Model Information")
        print("="*60)
        for key, value in info.items():
            print(f"{key.replace('_', ' ').title()}: {value}")
        print("="*60)
215
+
216
+
217
def demo():
    """Generate a small batch of sample images to showcase the pipeline."""
    banner = "=" * 60
    print("\n" + banner)
    print("Byte Dream - Quick Demo")
    print(banner)

    app = ByteDreamApp(device="cpu", verbose=True)

    # Three representative prompts covering different styles
    sample_prompts = [
        "A beautiful sunset over mountains, digital art, vibrant colors",
        "Cyberpunk city at night with neon lights, futuristic",
        "Fantasy landscape with castle and waterfall, epic",
    ]

    print("\nGenerating sample images...")

    results = app.generate_multiple(
        prompts=sample_prompts,
        output_dir="./demo_outputs",
        steps=30,  # fewer steps keeps the demo fast
        guidance=7.5,
        create_grid=True,
    )

    print(f"\n✓ Demo complete! Generated {len(results)} images")
    print(" Check ./demo_outputs/ for results")
+
245
+
246
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Byte Dream Application")
    parser.add_argument("--demo", action="store_true", help="Run demo")
    args = parser.parse_args()

    if args.demo:
        demo()
    else:
        # Interactive mode: read prompts until the user quits
        app = ByteDreamApp()

        print("\nByte Dream Interactive Mode")
        print("Type 'quit' to exit\n")

        image_index = 0
        while True:
            prompt = input("Prompt: ").strip()

            if prompt.lower() in ['quit', 'exit', 'q']:
                break

            if not prompt:
                continue

            try:
                # BUGFIX: the filename was previously derived from
                # len(prompt), so two prompts of equal length silently
                # overwrote each other's output. Use a running counter.
                image_index += 1
                app.generate(
                    prompt=prompt,
                    output_path=f"output_{image_index:03d}.png",
                )
                print("✓ Image generated!\n")
            except Exception as e:
                # Keep the REPL alive on generation errors
                print(f"Error: {e}\n")
prepare_dataset.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dataset Preparation Tool
3
+ Prepare and preprocess image-text datasets for training
4
+ """
5
+
6
+ import argparse
7
+ from pathlib import Path
8
+ from PIL import Image
9
+ import json
10
+ import shutil
11
+ from typing import List, Tuple
12
+
13
+
14
+ def prepare_dataset(
15
+ input_dir: str,
16
+ output_dir: str,
17
+ image_size: int = 512,
18
+ min_resolution: int = 256,
19
+ filter_low_quality: bool = True,
20
+ ):
21
+ """
22
+ Prepare dataset for training
23
+
24
+ Args:
25
+ input_dir: Directory with raw images
26
+ output_dir: Output directory for processed data
27
+ image_size: Target image size
28
+ min_resolution: Minimum acceptable resolution
29
+ filter_low_quality: Filter out low quality images
30
+ """
31
+ input_path = Path(input_dir)
32
+ output_path = Path(output_dir)
33
+
34
+ # Create output directories
35
+ output_path.mkdir(parents=True, exist_ok=True)
36
+ (output_path / "images").mkdir(exist_ok=True)
37
+ (output_path / "captions").mkdir(exist_ok=True)
38
+
39
+ # Find all images
40
+ image_extensions = ['.jpg', '.jpeg', '.png', '.webp']
41
+ image_files = []
42
+
43
+ for ext in image_extensions:
44
+ image_files.extend(input_path.glob(f"*{ext}"))
45
+ image_files.extend(input_path.glob(f"**/*{ext}"))
46
+
47
+ print(f"Found {len(image_files)} images")
48
+
49
+ # Process each image
50
+ processed_count = 0
51
+ skipped_count = 0
52
+
53
+ for img_file in image_files:
54
+ try:
55
+ process_image(
56
+ img_path=img_file,
57
+ output_img_path=output_path / "images" / f"{img_file.stem}.jpg",
58
+ caption_path=output_path / "captions" / f"{img_file.stem}.txt",
59
+ image_size=image_size,
60
+ min_resolution=min_resolution,
61
+ filter_low_quality=filter_low_quality,
62
+ )
63
+ processed_count += 1
64
+
65
+ if processed_count % 10 == 0:
66
+ print(f"Processed: {processed_count}/{len(image_files)}")
67
+
68
+ except Exception as e:
69
+ print(f"Error processing {img_file}: {e}")
70
+ skipped_count += 1
71
+
72
+ # Save metadata
73
+ metadata = {
74
+ 'total_images': processed_count,
75
+ 'skipped_images': skipped_count,
76
+ 'image_size': image_size,
77
+ 'min_resolution': min_resolution,
78
+ }
79
+
80
+ with open(output_path / "metadata.json", 'w') as f:
81
+ json.dump(metadata, f, indent=2)
82
+
83
+ print(f"\n✓ Dataset preparation complete!")
84
+ print(f" Processed: {processed_count} images")
85
+ print(f" Skipped: {skipped_count} images")
86
+ print(f" Output: {output_path}")
87
+
88
+
89
+ def process_image(
90
+ img_path: Path,
91
+ output_img_path: Path,
92
+ caption_path: Path,
93
+ image_size: int = 512,
94
+ min_resolution: int = 256,
95
+ filter_low_quality: bool = True,
96
+ ):
97
+ """
98
+ Process single image
99
+
100
+ Args:
101
+ img_path: Input image path
102
+ output_img_path: Output image path
103
+ caption_path: Output caption path
104
+ image_size: Target size
105
+ min_resolution: Minimum resolution
106
+ filter_low_quality: Filter low quality
107
+ """
108
+ # Load image
109
+ image = Image.open(img_path).convert('RGB')
110
+
111
+ # Check resolution
112
+ width, height = image.size
113
+
114
+ if width < min_resolution or height < min_resolution:
115
+ raise ValueError(f"Image too small: {width}x{height}")
116
+
117
+ # Resize if necessary
118
+ if min(width, height) > image_size * 1.5:
119
+ # Downscale large images
120
+ scale = image_size / max(width, height)
121
+ new_width = int(width * scale)
122
+ new_height = int(height * scale)
123
+ image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
124
+
125
+ # Center crop to square
126
+ size = min(image.size)
127
+ left = (image.size[0] - size) // 2
128
+ top = (image.size[1] - size) // 2
129
+ image = image.crop((left, top, left + size, top + size))
130
+
131
+ # Resize to target size
132
+ image = image.resize((image_size, image_size), Image.Resampling.LANCZOS)
133
+
134
+ # Save processed image
135
+ image.save(output_img_path, quality=95, optimize=True)
136
+
137
+ # Generate or load caption
138
+ caption = generate_caption(img_path)
139
+
140
+ with open(caption_path, 'w', encoding='utf-8') as f:
141
+ f.write(caption)
142
+
143
+
144
+ def generate_caption(img_path: Path) -> str:
145
+ """
146
+ Generate caption from image filename or load from adjacent text file
147
+
148
+ Args:
149
+ img_path: Path to image
150
+
151
+ Returns:
152
+ Caption text
153
+ """
154
+ # Try to load from adjacent .txt file
155
+ txt_file = img_path.with_suffix('.txt')
156
+
157
+ if txt_file.exists():
158
+ with open(txt_file, 'r', encoding='utf-8') as f:
159
+ caption = f.read().strip()
160
+ if caption:
161
+ return caption
162
+
163
+ # Use filename as fallback
164
+ caption = img_path.stem.replace('_', ' ').replace('-', ' ')
165
+
166
+ # Capitalize first letter
167
+ caption = caption.capitalize()
168
+
169
+ return caption
170
+
171
+
172
+ def create_training_splits(
173
+ data_dir: str,
174
+ train_ratio: float = 0.9,
175
+ val_ratio: float = 0.05,
176
+ test_ratio: float = 0.05,
177
+ ):
178
+ """
179
+ Create train/val/test splits
180
+
181
+ Args:
182
+ data_dir: Directory with processed data
183
+ train_ratio: Training set ratio
184
+ val_ratio: Validation set ratio
185
+ test_ratio: Test set ratio
186
+ """
187
+ data_path = Path(data_dir)
188
+
189
+ # Get all images
190
+ images = list((data_path / "images").glob("*.jpg"))
191
+
192
+ # Shuffle deterministically
193
+ import random
194
+ random.seed(42)
195
+ random.shuffle(images)
196
+
197
+ # Calculate split sizes
198
+ total = len(images)
199
+ train_size = int(total * train_ratio)
200
+ val_size = int(total * val_ratio)
201
+
202
+ # Split datasets
203
+ train_images = images[:train_size]
204
+ val_images = images[train_size:train_size + val_size]
205
+ test_images = images[train_size + val_size:]
206
+
207
+ # Save splits
208
+ def save_split(image_list, split_name):
209
+ split_data = {
210
+ 'images': [str(img.name) for img in image_list],
211
+ 'count': len(image_list),
212
+ }
213
+
214
+ with open(data_path / f"{split_name}.json", 'w') as f:
215
+ json.dump(split_data, f, indent=2)
216
+
217
+ print(f"{split_name}: {len(image_list)} images")
218
+
219
+ save_split(train_images, "train")
220
+ save_split(val_images, "validation")
221
+ save_split(test_images, "test")
222
+
223
+ print(f"\n✓ Created training splits")
224
+ print(f" Total: {total} images")
225
+
226
+
227
+ def main():
228
+ parser = argparse.ArgumentParser(description="Prepare dataset for Byte Dream training")
229
+
230
+ parser.add_argument(
231
+ "--input", "-i",
232
+ type=str,
233
+ required=True,
234
+ help="Input directory with raw images"
235
+ )
236
+
237
+ parser.add_argument(
238
+ "--output", "-o",
239
+ type=str,
240
+ default="./processed_data",
241
+ help="Output directory for processed data"
242
+ )
243
+
244
+ parser.add_argument(
245
+ "--size", "-s",
246
+ type=int,
247
+ default=512,
248
+ help="Target image size (default: 512)"
249
+ )
250
+
251
+ parser.add_argument(
252
+ "--min_res",
253
+ type=int,
254
+ default=256,
255
+ help="Minimum image resolution (default: 256)"
256
+ )
257
+
258
+ parser.add_argument(
259
+ "--no_filter",
260
+ action="store_true",
261
+ help="Disable low quality filtering"
262
+ )
263
+
264
+ parser.add_argument(
265
+ "--create_splits",
266
+ action="store_true",
267
+ help="Create train/val/test splits"
268
+ )
269
+
270
+ args = parser.parse_args()
271
+
272
+ # Prepare dataset
273
+ prepare_dataset(
274
+ input_dir=args.input,
275
+ output_dir=args.output,
276
+ image_size=args.size,
277
+ min_resolution=args.min_res,
278
+ filter_low_quality=not args.no_filter,
279
+ )
280
+
281
+ # Create splits if requested
282
+ if args.create_splits:
283
+ create_training_splits(args.output)
284
+
285
if __name__ == "__main__":
    # Script entry point: delegate to the CLI handler above.
    main()
publish_to_hf.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
Upload Byte Dream to Hugging Face Hub
Using login() and upload_folder() methods
"""

import sys
from getpass import getpass

from huggingface_hub import login, upload_folder

# Token from argv (first positional arg) or interactive prompt.
# SECURITY NOTE: passing the token on the command line exposes it to other
# local users via the process list and shell history; prefer the prompt
# below or the HF_TOKEN environment variable.
if len(sys.argv) > 1:
    token = sys.argv[1]
else:
    # getpass hides the token instead of echoing it to the terminal
    token = getpass("Enter your Hugging Face token: ")

# Login with your Hugging Face token
print("Logging in to Hugging Face...")
login(token=token)

# Push your model files
print("\nUploading model to Hugging Face Hub...")
upload_folder(
    folder_path=".",
    repo_id="Enzo8930302/ByteDream",
    repo_type="model"
)

print("\n✓ Model uploaded successfully!")
print("📦 View your model at: https://huggingface.co/Enzo8930302/ByteDream")
quick_start.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Quick Start Script
3
+ Setup and test Byte Dream in one command
4
+ """
5
+
6
+ import subprocess
7
+ import sys
8
+ from pathlib import Path
9
+
10
+
11
+ def check_requirements():
12
+ """Check if requirements are installed"""
13
+ print("Checking requirements...")
14
+
15
+ required = [
16
+ 'torch',
17
+ 'transformers',
18
+ 'diffusers',
19
+ 'pillow',
20
+ 'numpy',
21
+ 'gradio',
22
+ ]
23
+
24
+ missing = []
25
+
26
+ for package in required:
27
+ try:
28
+ __import__(package.replace('-', '_'))
29
+ print(f" ✓ {package}")
30
+ except ImportError:
31
+ print(f" ✗ {package} - MISSING")
32
+ missing.append(package)
33
+
34
+ if missing:
35
+ print(f"\nMissing packages: {', '.join(missing)}")
36
+ print("\nInstall with:")
37
+ print(" pip install -r requirements.txt")
38
+ return False
39
+
40
+ print("\n✓ All requirements satisfied!")
41
+ return True
42
+
43
+
44
+ def test_model():
45
+ """Test model generation"""
46
+ print("\n" + "="*60)
47
+ print("Testing Byte Dream Model")
48
+ print("="*60)
49
+
50
+ try:
51
+ from bytedream.generator import ByteDreamGenerator
52
+
53
+ print("\nInitializing generator...")
54
+ generator = ByteDreamGenerator(device="cpu")
55
+
56
+ print("\nModel info:")
57
+ info = generator.get_model_info()
58
+ for key, value in info.items():
59
+ print(f" {key}: {value}")
60
+
61
+ print("\nGenerating test image...")
62
+ print("Prompt: A simple test pattern, geometric shapes")
63
+
64
+ image = generator.generate(
65
+ prompt="A simple test pattern, geometric shapes, abstract art",
66
+ width=256,
67
+ height=256,
68
+ num_inference_steps=20,
69
+ seed=42,
70
+ )
71
+
72
+ output_path = Path("test_output.png")
73
+ image.save(output_path)
74
+
75
+ print(f"\n✓ Test successful!")
76
+ print(f" Image saved to: {output_path.absolute()}")
77
+
78
+ return True
79
+
80
+ except Exception as e:
81
+ print(f"\n✗ Error: {e}")
82
+ print("\nThe model needs to be trained or pretrained weights downloaded.")
83
+ return False
84
+
85
+
86
+ def download_pretrained():
87
+ """Download pretrained model from Hugging Face"""
88
+ print("\n" + "="*60)
89
+ print("Downloading Pretrained Model")
90
+ print("="*60)
91
+
92
+ print("\nTo download a pretrained model:")
93
+ print("1. Visit https://huggingface.co/models")
94
+ print("2. Search for 'stable-diffusion' or similar")
95
+ print("3. Download using:")
96
+ print("\n from huggingface_hub import snapshot_download")
97
+ print(" snapshot_download(repo_id='username/model', local_dir='./models/bytedream')")
98
+ print("\nOr train your own model with:")
99
+ print(" python train.py --train_data ./dataset --output_dir ./models/bytedream")
100
+
101
+
102
+ def main():
103
+ print("="*60)
104
+ print("Byte Dream - Quick Start")
105
+ print("="*60)
106
+
107
+ # Check requirements
108
+ if not check_requirements():
109
+ print("\n⚠ Please install requirements first")
110
+ sys.exit(1)
111
+
112
+ # Test model
113
+ if test_model():
114
+ print("\n✓ Byte Dream is ready to use!")
115
+ print("\nNext steps:")
116
+ print(" - Run: python infer.py --prompt 'Your prompt here'")
117
+ print(" - Run: python app.py (for web interface)")
118
+ print(" - Run: python main.py --demo (for demo)")
119
+ else:
120
+ download_pretrained()
121
+
122
if __name__ == "__main__":
    # Script entry point: delegate to the quick-start routine above.
    main()
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ transformers>=4.35.0
2
+ diffusers>=0.24.0
3
+ torch>=2.1.0
4
+ torchaudio>=2.1.0
+ torchvision>=0.16.0
5
+ accelerate>=0.25.0
6
+ numpy>=1.24.0
7
+ pillow>=10.0.0
8
+ opencv-python>=4.8.0
9
+ safetensors>=0.4.0
10
+ huggingface_hub>=0.19.0
11
+ gradio>=4.0.0
12
+ tqdm>=4.66.0
13
+ pyyaml>=6.0
14
+ matplotlib>=3.8.0
15
+ scipy>=1.11.0
16
+ einops>=0.7.0
train.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Byte Dream Training Pipeline
3
+ Complete training system for diffusion models with CPU optimization
4
+ """
5
+
6
+ import os
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from torchvision import transforms
12
+ from PIL import Image
13
+ import numpy as np
14
+ from tqdm import tqdm
15
+ import yaml
16
+ import argparse
17
+ from pathlib import Path
18
+ from typing import Tuple, List, Optional
19
+ import gc
20
+
21
+
22
class ImageTextDataset(Dataset):
    """
    Dataset of image-text pairs for diffusion training.

    Images are discovered by extension in ``data_dir``; captions come from a
    sidecar .txt file with the same stem, falling back to the filename stem.
    """

    def __init__(
        self,
        data_dir: str,
        image_size: int = 512,
        random_flip: bool = True,
        random_crop: bool = False,
        center_crop: bool = True,
    ):
        """
        Args:
            data_dir: Directory containing images (jpg/png/jpeg)
            image_size: Target square size for crops/resizes
            random_crop: Randomly crop to image_size (takes priority over
                center_crop)
            center_crop: Center-crop to image_size
            random_flip: Apply random horizontal flip augmentation
        """
        self.data_dir = Path(data_dir)
        self.image_paths = list(self.data_dir.glob("*.jpg")) + \
                           list(self.data_dir.glob("*.png")) + \
                           list(self.data_dir.glob("*.jpeg"))

        self.image_size = image_size
        self.random_flip = random_flip
        self.random_crop = random_crop
        self.center_crop = center_crop

        # Transformations
        self.transform = self._get_transform()

        # Load captions
        self.captions = self._load_captions()

    def _get_transform(self) -> transforms.Compose:
        """Build the pipeline: crop/resize, optional flip, normalize to [-1, 1]."""
        transforms_list = []

        # NOTE(review): Random/CenterCrop assume the source image is at least
        # image_size on both sides — confirm inputs are pre-sized (e.g. by
        # prepare_dataset.py) or smaller images may be padded/raise depending
        # on the torchvision version.
        if self.random_crop:
            transforms_list.append(transforms.RandomCrop(self.image_size))
        elif self.center_crop:
            transforms_list.append(transforms.CenterCrop(self.image_size))
        else:
            transforms_list.append(transforms.Resize((self.image_size, self.image_size)))

        if self.random_flip:
            transforms_list.append(transforms.RandomHorizontalFlip(p=0.5))

        transforms_list.extend([
            transforms.ToTensor(),
            # Map [0, 1] -> [-1, 1], the input range diffusion models expect
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

        return transforms.Compose(transforms_list)

    def _load_captions(self) -> dict:
        """Load captions from sidecar text files, falling back to filenames."""
        captions = {}

        for img_path in self.image_paths:
            caption_path = img_path.with_suffix('.txt')
            if caption_path.exists():
                with open(caption_path, 'r', encoding='utf-8') as f:
                    captions[str(img_path)] = f.read().strip()
            else:
                # Use filename as caption if no text file
                captions[str(img_path)] = img_path.stem.replace('_', ' ')

        return captions

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> dict:
        """
        Return one sample: normalized pixel tensor, raw caption string, path.

        Unreadable images are skipped by advancing to the next index.
        BUGFIX: the previous implementation recursed on failure, which could
        exhaust the recursion limit when many consecutive files are corrupt;
        this version iterates and raises once every file has been tried.
        """
        image = None
        img_path = self.image_paths[idx]
        for offset in range(len(self)):
            img_path = self.image_paths[(idx + offset) % len(self)]
            try:
                image = Image.open(img_path).convert('RGB')
                break
            except Exception as e:
                print(f"Error loading image {img_path}: {e}")
        if image is None:
            raise RuntimeError("No readable images in dataset")

        # Transform image
        pixel_values = self.transform(image)

        # Get caption. The key is named "input_ids" for downstream
        # tokenization, but the value here is still the raw caption string.
        caption = self.captions.get(str(img_path), "")

        return {
            "pixel_values": pixel_values,
            "input_ids": caption,
            "image_path": str(img_path),
        }
112
+
113
+
114
+ class LatentDiffusionTrainer:
115
+ """
116
+ Trainer for latent diffusion models
117
+ Implements training loop with mixed precision and gradient accumulation
118
+ """
119
+
120
    def __init__(
        self,
        unet: nn.Module,
        vae: nn.Module,
        text_encoder: nn.Module,
        scheduler,
        config: dict,
        device: str = "cpu",
    ):
        """
        Set up optimizer, LR schedule, AMP scaler, output directories, and
        move all modules to `device` (only the UNet is trained).

        Args:
            unet: Denoising UNet — the only module that receives gradients
            vae: Autoencoder mapping images <-> latents (frozen)
            text_encoder: Text encoder producing conditioning embeddings (frozen)
            scheduler: Noise scheduler; must expose `num_train_timesteps`
                and `add_noise` (used by compute_loss)
            config: Parsed config dict; only the 'training' section is read here
            device: Torch device string ("cpu" or "cuda")
        """
        self.unet = unet
        self.vae = vae
        self.text_encoder = text_encoder
        self.scheduler = scheduler
        self.config = config
        self.device = torch.device(device)

        # Training parameters
        self.epochs = config['training']['epochs']
        self.batch_size = config['training']['batch_size']
        self.learning_rate = config['training']['learning_rate']
        self.gradient_accumulation_steps = config['training']['gradient_accumulation_steps']
        self.max_grad_norm = config['training']['max_grad_norm']

        # Mixed precision: any value other than "no" enables AMP
        self.mixed_precision = config['training']['mixed_precision']
        self.use_amp = self.mixed_precision != "no"

        # Output directories (created eagerly so later writes cannot fail)
        self.output_dir = Path(config['training']['output_dir'])
        self.logging_dir = Path(config['training']['logging_dir'])
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.logging_dir.mkdir(parents=True, exist_ok=True)

        # Initialize optimizer over the UNet parameters only
        self.optimizer = torch.optim.AdamW(
            unet.parameters(),
            lr=self.learning_rate,
            betas=(0.9, 0.999),
            weight_decay=1e-2,
            eps=1e-08,
        )

        # Learning rate scheduler (built from config['training']['lr_scheduler'])
        self.lr_scheduler = self._create_lr_scheduler()

        # Gradient scaler for mixed precision; None on CPU-only runs, so any
        # AMP code path must handle self.scaler being None
        self.scaler = torch.cuda.amp.GradScaler() if self.use_amp and torch.cuda.is_available() else None

        # Move models to device and freeze everything except the UNet
        self._prepare_models()
170
+
171
+ def _prepare_models(self):
172
+ """Prepare models for training"""
173
+ print(f"Preparing models on {self.device}...")
174
+
175
+ self.vae.to(self.device)
176
+ self.text_encoder.to(self.device)
177
+ self.unet.to(self.device)
178
+
179
+ # Set VAE and text encoder to eval mode (frozen)
180
+ self.vae.eval()
181
+ if hasattr(self.text_encoder, 'model'):
182
+ self.text_encoder.model.eval()
183
+
184
+ # Freeze VAE and text encoder parameters
185
+ for param in self.vae.parameters():
186
+ param.requires_grad = False
187
+
188
+ if hasattr(self.text_encoder, 'model'):
189
+ for param in self.text_encoder.model.parameters():
190
+ param.requires_grad = False
191
+
192
+ # Set UNet to train mode
193
+ self.unet.train()
194
+
195
+ def _create_lr_scheduler(self):
196
+ """Create learning rate scheduler"""
197
+ sched_config = self.config['training']
198
+
199
+ if sched_config['lr_scheduler'] == "constant_with_warmup":
200
+ return torch.optim.lr_scheduler.ConstantLR(
201
+ self.optimizer,
202
+ factor=1.0,
203
+ total_iters=sched_config['lr_warmup_steps'],
204
+ )
205
+ elif sched_config['lr_scheduler'] == "linear":
206
+ return torch.optim.lr_scheduler.LinearLR(
207
+ self.optimizer,
208
+ start_factor=0.1,
209
+ end_factor=1.0,
210
+ total_iters=sched_config['lr_warmup_steps'],
211
+ )
212
+ else:
213
+ return torch.optim.lr_scheduler.ConstantLR(self.optimizer, factor=1.0)
214
+
215
+ def encode_images(self, images: torch.Tensor) -> torch.Tensor:
216
+ """Encode images to latent space"""
217
+ with torch.no_grad():
218
+ latents = self.vae.encode(images)
219
+ latents = latents * 0.18215 # Scale factor
220
+ return latents
221
+
222
+ def encode_text(self, texts: List[str]) -> torch.Tensor:
223
+ """Encode text to embeddings"""
224
+ with torch.no_grad():
225
+ text_embeddings = self.text_encoder(texts, device=self.device)
226
+ return text_embeddings
227
+
228
+ def compute_loss(
229
+ self,
230
+ latents: torch.Tensor,
231
+ text_embeddings: torch.Tensor,
232
+ ) -> torch.Tensor:
233
+ """
234
+ Compute diffusion loss
235
+
236
+ Args:
237
+ latents: Latent representations of images
238
+ text_embeddings: Text embeddings
239
+
240
+ Returns:
241
+ Loss value
242
+ """
243
+ batch_size = latents.shape[0]
244
+
245
+ # Sample random timesteps
246
+ timesteps = torch.randint(
247
+ 0,
248
+ self.scheduler.num_train_timesteps,
249
+ (batch_size,),
250
+ device=self.device,
251
+ ).long()
252
+
253
+ # Add noise to latents
254
+ noise = torch.randn_like(latents)
255
+ noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)
256
+
257
+ # Predict noise
258
+ timestep_tensor = timesteps
259
+
260
+ model_output = self.unet(
261
+ sample=noisy_latents,
262
+ timestep=timestep_tensor,
263
+ encoder_hidden_states=text_embeddings,
264
+ )
265
+
266
+ # Compute loss
267
+ loss = F.mse_loss(model_output, noise, reduction="mean")
268
+
269
+ return loss
270
+
271
+ def train_step(
272
+ self,
273
+ batch: dict,
274
+ ) -> float:
275
+ """
276
+ Perform single training step
277
+
278
+ Args:
279
+ batch: Batch of data
280
+
281
+ Returns:
282
+ Loss value
283
+ """
284
+ pixel_values = batch["pixel_values"].to(self.device)
285
+ input_ids = batch["input_ids"]
286
+
287
+ # Encode images and text
288
+ latents = self.encode_images(pixel_values)
289
+ text_embeddings = self.encode_text(input_ids)
290
+
291
+ # Compute loss
292
+ if self.use_amp and self.scaler is not None:
293
+ with torch.cuda.amp.autocast():
294
+ loss = self.compute_loss(latents, text_embeddings)
295
+ loss = loss / self.gradient_accumulation_steps
296
+
297
+ self.scaler.scale(loss).backward()
298
+ else:
299
+ loss = self.compute_loss(latents, text_embeddings)
300
+ loss = loss / self.gradient_accumulation_steps
301
+ loss.backward()
302
+
303
+ return loss.item() * self.gradient_accumulation_steps
304
+
305
+ def save_checkpoint(self, epoch: int, step: int):
306
+ """Save model checkpoint"""
307
+ checkpoint_dir = self.output_dir / f"checkpoint-{epoch}-{step}"
308
+ checkpoint_dir.mkdir(parents=True, exist_ok=True)
309
+
310
+ # Save UNet
311
+ torch.save({
312
+ 'epoch': epoch,
313
+ 'step': step,
314
+ 'unet_state_dict': self.unet.state_dict(),
315
+ 'optimizer_state_dict': self.optimizer.state_dict(),
316
+ 'scheduler_state_dict': self.lr_scheduler.state_dict() if self.lr_scheduler else None,
317
+ }, checkpoint_dir / "pytorch_model.bin")
318
+
319
+ # Save config
320
+ with open(checkpoint_dir / "config.yaml", 'w') as f:
321
+ yaml.dump(self.config, f)
322
+
323
+ print(f"Checkpoint saved to {checkpoint_dir}")
324
+
325
    def train(self, resume_from_checkpoint: Optional[str] = None):
        """
        Main training loop.

        Builds the dataset/dataloader from config, optionally restores state
        from a checkpoint, then runs epoch/step loops with gradient
        accumulation, gradient clipping, optional AMP scaling, periodic
        logging and checkpointing, and a final model save.

        Args:
            resume_from_checkpoint: Path to a checkpoint file produced by
                ``save_checkpoint`` to resume from (None starts fresh).
        """
        # Create dataset and dataloader
        train_config = self.config['training']

        dataset = ImageTextDataset(
            data_dir=train_config['dataset_path'],
            image_size=512,
            random_flip=train_config['random_flip'],
            random_crop=train_config['random_crop'],
            center_crop=train_config['center_crop'],
        )

        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=0,  # CPU training: worker processes just add overhead
            pin_memory=False,
        )

        # Resume from checkpoint
        start_epoch = 0
        global_step = 0

        if resume_from_checkpoint:
            print(f"Resuming from checkpoint: {resume_from_checkpoint}")
            checkpoint = torch.load(resume_from_checkpoint, map_location=self.device)
            self.unet.load_state_dict(checkpoint['unet_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            if checkpoint['scheduler_state_dict']:
                self.lr_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            # NOTE(review): resuming restarts at the saved epoch number, so the
            # partially-completed epoch is replayed from its beginning.
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['step']

        # Training loop
        total_steps = len(dataloader) * self.epochs

        print(f"Starting training for {self.epochs} epochs...")
        print(f"Total steps: {total_steps}")
        print(f"Batch size: {self.batch_size}")
        print(f"Mixed precision: {self.mixed_precision}")

        for epoch in range(start_epoch, self.epochs):
            self.unet.train()
            progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{self.epochs}")

            epoch_loss = 0
            num_steps = 0

            for step, batch in enumerate(progress_bar):
                # Training step (loss is scaled for accumulation inside
                # train_step; the returned value is un-scaled for logging)
                loss = self.train_step(batch)
                epoch_loss += loss
                num_steps += 1

                # Gradient clipping and optimizer step — only once every
                # `gradient_accumulation_steps` micro-batches
                if (step + 1) % self.gradient_accumulation_steps == 0:
                    if self.use_amp and self.scaler is not None:
                        # Unscale before clipping so the norm is measured in
                        # true (unscaled) gradient units
                        self.scaler.unscale_(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(
                            self.unet.parameters(),
                            self.max_grad_norm,
                        )
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        torch.nn.utils.clip_grad_norm_(self.unet.parameters(), self.max_grad_norm)
                        self.optimizer.step()

                    # Learning rate scheduling (advances once per optimizer step)
                    if self.lr_scheduler:
                        self.lr_scheduler.step()

                    # Zero gradients for the next accumulation window
                    self.optimizer.zero_grad()

                # Update progress bar with the running epoch-average loss
                avg_loss = epoch_loss / num_steps
                progress_bar.set_postfix({"loss": f"{avg_loss:.4f}"})

                # Logging
                if (global_step + 1) % self.config['training']['log_every_n_steps'] == 0:
                    print(f"\nStep {global_step + 1}: Loss = {avg_loss:.4f}")

                # Save checkpoint periodically
                if (global_step + 1) % 1000 == 0:
                    self.save_checkpoint(epoch, global_step)

                global_step += 1

            # End of epoch
            avg_epoch_loss = epoch_loss / max(num_steps, 1)
            print(f"\nEpoch {epoch+1} completed. Average loss: {avg_epoch_loss:.4f}")

            # Save epoch checkpoint
            self.save_checkpoint(epoch, global_step)

            # Clear memory between epochs (long CPU runs accumulate garbage)
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Save final model
        print("\nTraining completed!")
        self.save_final_model()
437
+ def save_final_model(self):
438
+ """Save final trained model"""
439
+ final_dir = self.output_dir / "final"
440
+ final_dir.mkdir(parents=True, exist_ok=True)
441
+
442
+ # Save UNet
443
+ torch.save({
444
+ 'unet_state_dict': self.unet.state_dict(),
445
+ 'config': self.config,
446
+ }, final_dir / "unet_pytorch_model.bin")
447
+
448
+ print(f"Final model saved to {final_dir}")
449
+
450
+
451
def main():
    """CLI entry point: build model components from config and run training."""
    parser = argparse.ArgumentParser(description="Train Byte Dream diffusion model")
    parser.add_argument("--config", type=str, default="config.yaml", help="Path to config file")
    parser.add_argument("--train_data", type=str, required=True, help="Path to training data")
    parser.add_argument("--output_dir", type=str, default="./models/bytedream", help="Output directory")
    parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint")
    parser.add_argument("--device", type=str, default="cpu", help="Device to train on")
    args = parser.parse_args()

    # Load the YAML config, then let CLI flags win over its values.
    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)
    config['training']['dataset_path'] = args.train_data
    config['training']['output_dir'] = args.output_dir

    # Model factories (imported lazily, as in the original).
    from bytedream.model import create_unet, create_vae, create_text_encoder
    from bytedream.scheduler import create_scheduler

    print("Creating model components...")
    unet = create_unet(config)
    vae = create_vae(config)
    text_encoder = create_text_encoder(config)
    scheduler = create_scheduler(config)

    # Report trainable-model size.
    total_params = sum(p.numel() for p in unet.parameters())
    print(f"UNet parameters: {total_params:,}")

    trainer = LatentDiffusionTrainer(
        unet=unet,
        vae=vae,
        text_encoder=text_encoder,
        scheduler=scheduler,
        config=config,
        device=args.device,
    )

    trainer.train(resume_from_checkpoint=args.resume)
498
+
499
# Script entry point: parse CLI args and launch training.
if __name__ == "__main__":
    main()
upload_to_hf.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hugging Face Integration
3
+ Upload and deploy Byte Dream to Hugging Face Hub and Spaces
4
+ """
5
+
6
+ import argparse
7
+ from pathlib import Path
8
+ import yaml
9
+
10
+
11
def upload_to_huggingface(
    model_path: str,
    repo_id: str,
    token: str = None,
    private: bool = False,
):
    """
    Upload a trained model directory to the Hugging Face Hub.

    Args:
        model_path: Path to the local model directory to upload.
        repo_id: Target repository ID ("username/model-name").
        token: Hugging Face API token (None uses the cached login).
        private: Whether to make the repository private.
    """
    # Imported lazily so the rest of the module works without huggingface_hub.
    from huggingface_hub import HfApi, create_repo

    print(f"Uploading model to Hugging Face Hub...")
    print(f"Repository: {repo_id}")

    hub_api = HfApi()

    # Step 1: make sure the target repo exists (no-op when it already does).
    try:
        create_repo(
            repo_id=repo_id,
            token=token,
            private=private,
            exist_ok=True,
            repo_type="model",
        )
        print("✓ Repository created/verified")
    except Exception as exc:
        print(f"Error creating repository: {exc}")
        return

    # Step 2: validate the local directory.
    local_dir = Path(model_path)
    if not local_dir.exists():
        print(f"Error: Model directory {local_dir} does not exist")
        return

    print("\nUploading files...")

    # Step 3: push the whole folder in a single call.
    try:
        hub_api.upload_folder(
            folder_path=str(local_dir),
            repo_id=repo_id,
            token=token,
            repo_type="model",
        )
        print("✓ Model uploaded successfully!")
        print(f"\n📦 View your model at:")
        print(f"https://huggingface.co/{repo_id}")
    except Exception as exc:
        print(f"Error uploading model: {exc}")
+
75
def create_gradio_app():
    """Build the source code of a Gradio web UI and return it as a string.

    The returned text is a complete ``app.py`` suitable for a Hugging Face
    Space; ``main`` writes it to disk verbatim when ``--create_space`` is
    passed. No Gradio code is executed here.
    """
    # NOTE: everything below is a string literal (returned verbatim), not
    # live code — edits to it change the generated file, not this module.
    gradio_code = '''"""
Byte Dream - Gradio Web Interface
Deploy on Hugging Face Spaces
"""

import gradio as gr
from bytedream.generator import ByteDreamGenerator
import torch

# Initialize generator
print("Loading Byte Dream model...")
generator = ByteDreamGenerator(
    model_path="./models/bytedream",
    config_path="config.yaml",
    device="cpu",
)

def generate_image(
    prompt,
    negative_prompt,
    width,
    height,
    num_steps,
    guidance_scale,
    seed,
):
    """Generate image from prompt"""

    # Convert seed to None if -1
    seed_value = None if seed == -1 else seed

    try:
        # Generate image
        image = generator.generate(
            prompt=prompt,
            negative_prompt=negative_prompt if negative_prompt else None,
            width=int(width),
            height=int(height),
            num_inference_steps=int(num_steps),
            guidance_scale=float(guidance_scale),
            seed=seed_value,
        )

        return image, "Success!"

    except Exception as e:
        print(f"Error generating image: {e}")
        return None, f"Error: {str(e)}"


# Create Gradio interface
with gr.Blocks(title="Byte Dream - AI Image Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎨 Byte Dream - AI Image Generator

    Generate stunning images from text descriptions using advanced diffusion models.
    Optimized for CPU inference.

    **Tips for better results:**
    - Be specific and descriptive in your prompts
    - Use negative prompts to avoid unwanted elements
    - Higher steps = better quality but slower
    - Adjust guidance scale for creativity vs accuracy
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📝 Prompt")
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="A beautiful sunset over mountains, digital art, highly detailed",
                lines=3,
            )

            negative_prompt_input = gr.Textbox(
                label="Negative Prompt (optional)",
                placeholder="ugly, blurry, low quality, distorted",
                lines=2,
            )

            with gr.Row():
                width_slider = gr.Slider(
                    minimum=256,
                    maximum=1024,
                    step=64,
                    value=512,
                    label="Width"
                )
                height_slider = gr.Slider(
                    minimum=256,
                    maximum=1024,
                    step=64,
                    value=512,
                    label="Height"
                )

            with gr.Row():
                steps_slider = gr.Slider(
                    minimum=10,
                    maximum=150,
                    step=5,
                    value=50,
                    label="Inference Steps"
                )
                guidance_slider = gr.Slider(
                    minimum=1.0,
                    maximum=20.0,
                    step=0.5,
                    value=7.5,
                    label="Guidance Scale"
                )

            seed_input = gr.Number(
                label="Seed (-1 for random)",
                value=-1,
                precision=0,
            )

            generate_btn = gr.Button("🎨 Generate Image", variant="primary", size="lg")

        with gr.Column(scale=1):
            gr.Markdown("### 🖼️ Generated Image")
            output_image = gr.Image(
                label="Generated Image",
                type="pil",
            )
            status_text = gr.Textbox(label="Status")

    # Examples
    gr.Markdown("### 💡 Example Prompts")
    gr.Examples(
        examples=[
            ["A cyberpunk city at night with neon lights, futuristic architecture, flying cars, highly detailed, digital art"],
            ["A majestic dragon breathing fire, fantasy art, dramatic lighting, epic scene"],
            ["A peaceful cottage in a meadow, flowers, sunny day, studio ghibli style"],
            ["Portrait of a warrior princess, armor, fantasy, intricate details, character design"],
            ["Underwater coral reef, tropical fish, sunlight filtering through water, photorealistic"],
        ],
        inputs=[prompt_input],
    )

    # Connect button
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt_input,
            negative_prompt_input,
            width_slider,
            height_slider,
            steps_slider,
            guidance_slider,
            seed_input,
        ],
        outputs=[output_image, status_text],
    )

    gr.Markdown("""
    ---
    **Byte Dream** v1.0.0 | Powered by Latent Diffusion Models
    """)


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
'''

    return gradio_code
+
245
+
246
def create_readme_for_hf(repo_id: str):
    """Build the Hugging Face model card (Markdown with YAML front matter).

    Args:
        repo_id: Repository ID ("username/model-name"); the model-name
            segment after the last "/" is interpolated into the card text.

    Returns:
        The README contents as a single string.
    """

    # NOTE: f-string literal — only the {repo_id...} expressions are live.
    readme = f'''---
license: mit
language:
- en
tags:
- text-to-image
- diffusion
- generative-ai
- cpu-optimized
---

# {repo_id.split('/')[-1]}

{repo_id.split('/')[-1]} is a powerful text-to-image diffusion model optimized for CPU inference. Generate high-quality images from text prompts using advanced latent diffusion architecture.

## Features

- 🚀 **CPU Optimized**: Runs efficiently on CPU without GPU requirement
- 🎨 **High Quality**: Generates 512x512 and higher resolution images
- ⚡ **Fast Inference**: Optimized for speed with quality preservation
- 🔧 **Flexible**: Supports various sampling methods and customization
- 📦 **Easy to Use**: Simple Python API and web interface

## Installation

```bash
pip install -r requirements.txt
```

## Usage

### Python API

```python
from bytedream import ByteDreamGenerator

# Initialize generator
generator = ByteDreamGenerator()

# Generate image
image = generator.generate(
    prompt="A beautiful sunset over mountains, digital art",
    num_inference_steps=50,
    guidance_scale=7.5
)

image.save("output.png")
```

### Command Line

```bash
python infer.py --prompt "A dragon flying over castle" --output dragon.png
```

### Web Interface

```bash
python app.py
```

## Model Details

- **Architecture**: Latent Diffusion Model (UNet + VAE + Text Encoder)
- **Parameters**: ~1.2B
- **Training**: Trained on diverse image-text pairs
- **Optimization**: CPU-optimized with efficient memory usage

## Examples

Try these prompts:
- "Cyberpunk city at night, neon lights, futuristic"
- "Fantasy landscape with mountains and waterfall"
- "Portrait of a warrior, detailed armor, dramatic lighting"
- "Abstract art, colorful, geometric shapes"

## Configuration

Edit `config.yaml` to customize:
- Model architecture parameters
- Generation settings (resolution, steps, guidance)
- CPU optimization options

## License

MIT License

## Acknowledgments

Built with:
- [PyTorch](https://pytorch.org/)
- [Hugging Face Diffusers](https://github.com/huggingface/diffusers)
- [CLIP](https://openai.com/research/clip)

Enjoy creating with Byte Dream! 🎨
'''

    return readme
+
348
+
349
def main():
    """CLI entry point: upload the model and optionally emit Space files."""
    parser = argparse.ArgumentParser(description="Upload Byte Dream to Hugging Face")

    parser.add_argument(
        "--model_path",
        type=str,
        default="./models/bytedream",
        help="Path to model directory",
    )
    parser.add_argument(
        "--repo_id",
        type=str,
        required=True,
        help="Repository ID (e.g., username/bytedream)",
    )
    parser.add_argument(
        "--token",
        type=str,
        default=None,
        help="Hugging Face API token",
    )
    parser.add_argument(
        "--private",
        action="store_true",
        help="Make repository private",
    )
    parser.add_argument(
        "--create_space",
        action="store_true",
        help="Also create Gradio Space code",
    )

    args = parser.parse_args()

    # Push the model weights first.
    upload_to_huggingface(
        model_path=args.model_path,
        repo_id=args.repo_id,
        token=args.token,
        private=args.private,
    )

    if not args.create_space:
        return

    print("\n\nCreating Gradio Space files...")

    # Gradio demo source for the Space.
    with open("app.py", 'w') as f:
        f.write(create_gradio_app())
    print("✓ Created app.py for Gradio Space")

    # Model card for the repository.
    with open("README_HF.md", 'w') as f:
        f.write(create_readme_for_hf(args.repo_id))
    print("✓ Created README_HF.md")

    for deploy_line in (
        "\n📋 To deploy on Hugging Face Spaces:",
        "1. Go to https://huggingface.co/spaces",
        "2. Click 'Create new Space'",
        "3. Choose Gradio SDK",
        "4. Upload all files",
        "5. Select CPU hardware",
        "6. Deploy!",
    ):
        print(deploy_line)


# Script entry point.
if __name__ == "__main__":
    main()