Upload folder using huggingface_hub
Browse files- .gitignore +73 -0
- LICENSE +21 -0
- README.md +182 -3
- SETUP_GUIDE.md +262 -0
- app.py +305 -0
- bytedream/__init__.py +21 -0
- bytedream/generator.py +317 -0
- bytedream/model.py +582 -0
- bytedream/pipeline.py +312 -0
- bytedream/scheduler.py +273 -0
- bytedream/utils.py +398 -0
- config.yaml +81 -0
- environment.yml +25 -0
- examples.py +316 -0
- infer.py +150 -0
- main.py +278 -0
- prepare_dataset.py +287 -0
- publish_to_hf.py +30 -0
- quick_start.py +124 -0
- requirements.txt +16 -0
- train.py +500 -0
- upload_to_hf.py +420 -0
.gitignore
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
build/
|
| 8 |
+
develop-eggs/
|
| 9 |
+
dist/
|
| 10 |
+
downloads/
|
| 11 |
+
eggs/
|
| 12 |
+
.eggs/
|
| 13 |
+
lib/
|
| 14 |
+
lib64/
|
| 15 |
+
parts/
|
| 16 |
+
sdist/
|
| 17 |
+
var/
|
| 18 |
+
wheels/
|
| 19 |
+
*.egg-info/
|
| 20 |
+
.installed.cfg
|
| 21 |
+
*.egg
|
| 22 |
+
|
| 23 |
+
# Virtual environments
|
| 24 |
+
venv/
|
| 25 |
+
env/
|
| 26 |
+
ENV/
|
| 27 |
+
.venv
|
| 28 |
+
|
| 29 |
+
# IDE
|
| 30 |
+
.vscode/
|
| 31 |
+
.idea/
|
| 32 |
+
*.swp
|
| 33 |
+
*.swo
|
| 34 |
+
*~
|
| 35 |
+
|
| 36 |
+
# Jupyter Notebook
|
| 37 |
+
.ipynb_checkpoints
|
| 38 |
+
|
| 39 |
+
# PyTorch
|
| 40 |
+
*.pth
|
| 41 |
+
*.onnx
|
| 42 |
+
|
| 43 |
+
# Model checkpoints
|
| 44 |
+
models/
|
| 45 |
+
checkpoints/
|
| 46 |
+
*.bin
|
| 47 |
+
*.safetensors
|
| 48 |
+
|
| 49 |
+
# Outputs
|
| 50 |
+
outputs/
|
| 51 |
+
demo_outputs/
|
| 52 |
+
*.png
|
| 53 |
+
*.jpg
|
| 54 |
+
*.jpeg
|
| 55 |
+
*.webp
|
| 56 |
+
|
| 57 |
+
# Logs
|
| 58 |
+
logs/
|
| 59 |
+
*.log
|
| 60 |
+
|
| 61 |
+
# OS
|
| 62 |
+
.DS_Store
|
| 63 |
+
Thumbs.db
|
| 64 |
+
desktop.ini
|
| 65 |
+
|
| 66 |
+
# Temporary files
|
| 67 |
+
tmp/
|
| 68 |
+
temp/
|
| 69 |
+
*.tmp
|
| 70 |
+
|
| 71 |
+
# Hugging Face cache
|
| 72 |
+
.huggingface/
|
| 73 |
+
.cache/
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 Byte Dream
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,3 +1,182 @@
|
|
| 1 |
-
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte Dream - AI Image Generation Model
|
| 2 |
+
|
| 3 |
+
## Overview
|
| 4 |
+
Byte Dream is a robust, production-ready text-to-image diffusion model optimized for CPU inference. This model uses advanced latent diffusion architecture to generate high-quality images from text prompts.
|
| 5 |
+
|
| 6 |
+
## Features
|
| 7 |
+
- **CPU Optimized**: Runs efficiently on CPU without GPU requirement
|
| 8 |
+
- **High Quality**: Generates 512x512 and higher resolution images
|
| 9 |
+
- **Fast Inference**: Optimized for speed with quality preservation
|
| 10 |
+
- **Hugging Face Ready**: Easy deployment to Hugging Face Spaces
|
| 11 |
+
- **Flexible**: Supports various sampling methods and customization
|
| 12 |
+
|
| 13 |
+
## Installation
|
| 14 |
+
|
| 15 |
+
### Using pip
|
| 16 |
+
```bash
|
| 17 |
+
pip install -r requirements.txt
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
### Using conda
|
| 21 |
+
```bash
|
| 22 |
+
conda env create -f environment.yml
|
| 23 |
+
conda activate bytedream
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
## Usage
|
| 27 |
+
|
| 28 |
+
### Basic Image Generation
|
| 29 |
+
```python
|
| 30 |
+
from bytedream import ByteDreamGenerator
|
| 31 |
+
|
| 32 |
+
# Initialize generator
|
| 33 |
+
generator = ByteDreamGenerator()
|
| 34 |
+
|
| 35 |
+
# Generate image from prompt
|
| 36 |
+
image = generator.generate(
|
| 37 |
+
prompt="A beautiful sunset over mountains, digital art",
|
| 38 |
+
num_inference_steps=50,
|
| 39 |
+
guidance_scale=7.5
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# Save image
|
| 43 |
+
image.save("output.png")
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
### Advanced Usage
|
| 47 |
+
```python
|
| 48 |
+
from bytedream import ByteDreamGenerator
|
| 49 |
+
|
| 50 |
+
generator = ByteDreamGenerator(model_path="models/bytedream")
|
| 51 |
+
|
| 52 |
+
# Generate with custom parameters
|
| 53 |
+
image = generator.generate(
|
| 54 |
+
prompt="Cyberpunk city at night, neon lights, futuristic",
|
| 55 |
+
negative_prompt="blurry, low quality, distorted",
|
| 56 |
+
width=768,
|
| 57 |
+
height=768,
|
| 58 |
+
num_inference_steps=100,
|
| 59 |
+
guidance_scale=9.0,
|
| 60 |
+
seed=42
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
image.save("cyberpunk_city.png")
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
### Command Line Interface
|
| 67 |
+
```bash
|
| 68 |
+
# Generate image from command line
|
| 69 |
+
python infer.py --prompt "A dragon flying over castle" --output dragon.png
|
| 70 |
+
|
| 71 |
+
# With advanced options
|
| 72 |
+
python infer.py --prompt "Fantasy landscape" --negative "ugly, blurry" --steps 75 --guidance 8.0
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
### Gradio Web Interface
|
| 76 |
+
```bash
|
| 77 |
+
python app.py
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
## Model Architecture
|
| 81 |
+
|
| 82 |
+
Byte Dream uses a latent diffusion model with:
|
| 83 |
+
- **Text Encoder**: CLIP-based text understanding
|
| 84 |
+
- **UNet**: Noise prediction network with cross-attention
|
| 85 |
+
- **VAE**: Variational Autoencoder for image compression
|
| 86 |
+
- **Scheduler**: Advanced DDIM/PNDM sampling
|
| 87 |
+
|
| 88 |
+
## Training
|
| 89 |
+
|
| 90 |
+
### Prepare Dataset
|
| 91 |
+
```bash
|
| 92 |
+
python prepare_dataset.py --data_dir ./dataset --output_dir ./processed_data
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
### Train Model
|
| 96 |
+
```bash
|
| 97 |
+
python train.py \
|
| 98 |
+
--train_data ./processed_data \
|
| 99 |
+
--output_dir ./models/bytedream \
|
| 100 |
+
--epochs 100 \
|
| 101 |
+
--batch_size 4 \
|
| 102 |
+
--learning_rate 1e-5
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
## Hugging Face Deployment
|
| 106 |
+
|
| 107 |
+
### Upload to Hugging Face
|
| 108 |
+
```bash
|
| 109 |
+
python upload_to_hf.py --model_id your_username/bytedream
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
### Deploy to Spaces
|
| 113 |
+
1. Create new Space on Hugging Face
|
| 114 |
+
2. Select Gradio SDK
|
| 115 |
+
3. Upload all files
|
| 116 |
+
4. Configure CPU hardware
|
| 117 |
+
5. Deploy automatically
|
| 118 |
+
|
| 119 |
+
## Configuration
|
| 120 |
+
|
| 121 |
+
Edit `config.yaml` for custom settings:
|
| 122 |
+
- Model dimensions
|
| 123 |
+
- Sampling parameters
|
| 124 |
+
- Training hyperparameters
|
| 125 |
+
- CPU optimization settings
|
| 126 |
+
|
| 127 |
+
## Performance Optimization
|
| 128 |
+
|
| 129 |
+
### CPU Optimization
|
| 130 |
+
- OpenVINO integration available
|
| 131 |
+
- ONNX runtime support
|
| 132 |
+
- Mixed precision (FP16/FP32)
|
| 133 |
+
- Batch processing
|
| 134 |
+
|
| 135 |
+
### Memory Management
|
| 136 |
+
- Gradient checkpointing
|
| 137 |
+
- Model offloading
|
| 138 |
+
- Progressive generation
|
| 139 |
+
|
| 140 |
+
## File Structure
|
| 141 |
+
```
|
| 142 |
+
Byte Dream/
|
| 143 |
+
├── bytedream/ # Core package
|
| 144 |
+
│ ├── __init__.py
|
| 145 |
+
│ ├── model.py # Model architecture
|
| 146 |
+
│ ├── pipeline.py # Generation pipeline
|
| 147 |
+
│ ├── scheduler.py # Diffusion scheduler
|
| 148 |
+
│ └── utils.py # Utilities
|
| 149 |
+
├── train.py # Training script
|
| 150 |
+
├── infer.py # Inference script
|
| 151 |
+
├── app.py # Gradio web interface
|
| 152 |
+
├── config.yaml # Configuration
|
| 153 |
+
├── requirements.txt # Dependencies
|
| 154 |
+
└── README.md # This file
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
## Examples
|
| 158 |
+
|
| 159 |
+
Generate various types of images:
|
| 160 |
+
- Digital art and illustrations
|
| 161 |
+
- Photorealistic scenes
|
| 162 |
+
- Abstract concepts
|
| 163 |
+
- Character designs
|
| 164 |
+
- Landscapes and environments
|
| 165 |
+
|
| 166 |
+
## License
|
| 167 |
+
|
| 168 |
+
MIT License - See LICENSE file for details
|
| 169 |
+
|
| 170 |
+
## Citation
|
| 171 |
+
|
| 172 |
+
If you use Byte Dream in your research:
|
| 173 |
+
```bibtex
|
| 174 |
+
@software{bytedream2024,
  title={Byte Dream: CPU-Optimized Text-to-Image Generation},
  author={Byte Dream Team},
  year={2024}
}
|
| 178 |
+
```
|
| 179 |
+
|
| 180 |
+
## Support
|
| 181 |
+
|
| 182 |
+
For issues and questions, please open a GitHub issue or contact the maintainers.
|
SETUP_GUIDE.md
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte Dream - Setup Guide
|
| 2 |
+
|
| 3 |
+
## Quick Start (Windows)
|
| 4 |
+
|
| 5 |
+
### 1. Install Dependencies
|
| 6 |
+
|
| 7 |
+
#### Option A: Using pip (Recommended)
|
| 8 |
+
```cmd
|
| 9 |
+
cd "c:\Users\Enzo\Documents\Byte Dream"
|
| 10 |
+
pip install -r requirements.txt
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
#### Option B: Using conda
|
| 14 |
+
```cmd
|
| 15 |
+
cd "c:\Users\Enzo\Documents\Byte Dream"
|
| 16 |
+
conda env create -f environment.yml
|
| 17 |
+
conda activate bytedream
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
### 2. Verify Installation
|
| 21 |
+
```cmd
|
| 22 |
+
python quick_start.py
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
This will check if all dependencies are installed and test the model.
|
| 26 |
+
|
| 27 |
+
### 3. Generate Your First Image
|
| 28 |
+
|
| 29 |
+
#### Command Line
|
| 30 |
+
```cmd
|
| 31 |
+
python infer.py --prompt "A beautiful sunset over mountains, digital art" --output sunset.png
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
#### Web Interface
|
| 35 |
+
```cmd
|
| 36 |
+
python app.py
|
| 37 |
+
```
|
| 38 |
+
Then open http://localhost:7860 in your browser.
|
| 39 |
+
|
| 40 |
+
#### Python Script
|
| 41 |
+
```python
|
| 42 |
+
from bytedream import ByteDreamGenerator
|
| 43 |
+
|
| 44 |
+
generator = ByteDreamGenerator()
|
| 45 |
+
image = generator.generate(
|
| 46 |
+
prompt="A cyberpunk city at night with neon lights",
|
| 47 |
+
num_inference_steps=50,
|
| 48 |
+
guidance_scale=7.5
|
| 49 |
+
)
|
| 50 |
+
image.save("cyberpunk_city.png")
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
## Model Training
|
| 54 |
+
|
| 55 |
+
### Prepare Your Dataset
|
| 56 |
+
|
| 57 |
+
1. Collect images in a folder (JPG, PNG formats)
|
| 58 |
+
2. Optionally add .txt files with captions for each image
|
| 59 |
+
3. Run preparation script:
|
| 60 |
+
|
| 61 |
+
```cmd
|
| 62 |
+
python prepare_dataset.py --input ./my_images --output ./processed_data --size 512
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
### Train the Model
|
| 66 |
+
|
| 67 |
+
```cmd
|
| 68 |
+
python train.py --train_data ./processed_data --output_dir ./models/bytedream --epochs 100 --batch_size 4
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
Training time depends on:
|
| 72 |
+
- Dataset size
|
| 73 |
+
- Number of epochs
|
| 74 |
+
- CPU speed (expect several hours to days for CPU training)
|
| 75 |
+
|
| 76 |
+
## Hugging Face Deployment
|
| 77 |
+
|
| 78 |
+
### Upload to Hugging Face Hub
|
| 79 |
+
|
| 80 |
+
1. Get your Hugging Face token from https://huggingface.co/settings/tokens
|
| 81 |
+
2. Upload model:
|
| 82 |
+
|
| 83 |
+
```cmd
|
| 84 |
+
python upload_to_hf.py --model_path ./models/bytedream --repo_id your_username/bytedream --token YOUR_TOKEN
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
### Deploy to Spaces
|
| 88 |
+
|
| 89 |
+
1. Create Gradio app file (already included as `app.py`)
|
| 90 |
+
2. Go to https://huggingface.co/spaces
|
| 91 |
+
3. Click "Create new Space"
|
| 92 |
+
4. Choose Gradio SDK
|
| 93 |
+
5. Upload all project files
|
| 94 |
+
6. Select CPU hardware (CPU Basic or similar)
|
| 95 |
+
7. Deploy!
|
| 96 |
+
|
| 97 |
+
## File Structure
|
| 98 |
+
|
| 99 |
+
```
|
| 100 |
+
Byte Dream/
|
| 101 |
+
├── bytedream/ # Core package
|
| 102 |
+
│ ├── __init__.py # Package initialization
|
| 103 |
+
│ ├── model.py # Neural network architectures
|
| 104 |
+
│ ├── pipeline.py # Generation pipeline
|
| 105 |
+
│ ├── scheduler.py # Diffusion scheduler
|
| 106 |
+
│ ├── generator.py # Main generator class
|
| 107 |
+
│ └── utils.py # Utility functions
|
| 108 |
+
├── train.py # Training script
|
| 109 |
+
├── infer.py # Command-line inference
|
| 110 |
+
├── app.py # Gradio web interface
|
| 111 |
+
├── main.py # High-level application API
|
| 112 |
+
├── prepare_dataset.py # Dataset preparation
|
| 113 |
+
├── upload_to_hf.py # Hugging Face upload
|
| 114 |
+
├── quick_start.py # Quick start guide
|
| 115 |
+
├── config.yaml # Configuration
|
| 116 |
+
├── requirements.txt # Python dependencies
|
| 117 |
+
├── environment.yml # Conda environment
|
| 118 |
+
├── README.md # Documentation
|
| 119 |
+
└── LICENSE # MIT License
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
## Usage Examples
|
| 123 |
+
|
| 124 |
+
### Basic Generation
|
| 125 |
+
```cmd
|
| 126 |
+
python infer.py -p "A dragon flying over castle" -o dragon.png
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
### Advanced Parameters
|
| 130 |
+
```cmd
|
| 131 |
+
python infer.py -p "Fantasy landscape" -n "ugly, blurry" -W 768 -H 768 -s 75 -g 8.0 --seed 42
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
### Batch Generation (Python)
|
| 135 |
+
```python
|
| 136 |
+
from bytedream import ByteDreamGenerator
|
| 137 |
+
|
| 138 |
+
generator = ByteDreamGenerator()
|
| 139 |
+
|
| 140 |
+
prompts = [
|
| 141 |
+
"Sunset beach, palm trees, tropical paradise",
|
| 142 |
+
"Mountain landscape, snow peaks, alpine lake",
|
| 143 |
+
"Forest path, sunlight filtering through trees"
|
| 144 |
+
]
|
| 145 |
+
|
| 146 |
+
images = generator.generate_batch(
|
| 147 |
+
prompts=prompts,
|
| 148 |
+
width=512,
|
| 149 |
+
height=512,
|
| 150 |
+
num_inference_steps=50
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
for i, img in enumerate(images):
|
| 154 |
+
img.save(f"landscape_{i}.png")
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
## Performance Optimization
|
| 158 |
+
|
| 159 |
+
### CPU Optimization
|
| 160 |
+
The model is already optimized for CPU, but you can:
|
| 161 |
+
|
| 162 |
+
1. Increase threads in `config.yaml`:
|
| 163 |
+
```yaml
|
| 164 |
+
cpu_optimization:
|
| 165 |
+
threads: 8 # Set to number of CPU cores
|
| 166 |
+
precision: fp32
|
| 167 |
+
```
|
| 168 |
+
|
| 169 |
+
2. Use fewer inference steps for faster generation:
|
| 170 |
+
```cmd
|
| 171 |
+
python infer.py -p "Quick preview" -s 20
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
3. Generate smaller images:
|
| 175 |
+
```cmd
|
| 176 |
+
python infer.py -p "Small image" -W 256 -H 256
|
| 177 |
+
```
|
| 178 |
+
|
| 179 |
+
### Memory Management
|
| 180 |
+
For systems with limited RAM:
|
| 181 |
+
|
| 182 |
+
1. Enable memory efficient mode (already default)
|
| 183 |
+
2. Generate one image at a time
|
| 184 |
+
3. Restart Python between batch generations
|
| 185 |
+
|
| 186 |
+
## Troubleshooting
|
| 187 |
+
|
| 188 |
+
### Import Errors
|
| 189 |
+
If you get import errors:
|
| 190 |
+
```cmd
|
| 191 |
+
pip install --upgrade torch transformers diffusers
|
| 192 |
+
```
|
| 193 |
+
|
| 194 |
+
### Memory Errors
|
| 195 |
+
Reduce image size or inference steps:
|
| 196 |
+
```cmd
|
| 197 |
+
python infer.py -p "Test" -W 256 -H 256 -s 20
|
| 198 |
+
```
|
| 199 |
+
|
| 200 |
+
### Slow Generation
|
| 201 |
+
CPU generation is slower than GPU. Expect:
|
| 202 |
+
- 256x256: ~30-60 seconds
|
| 203 |
+
- 512x512: ~2-5 minutes
|
| 204 |
+
- 768x768: ~5-10 minutes
|
| 205 |
+
|
| 206 |
+
Times vary by CPU speed and number of steps.
|
| 207 |
+
|
| 208 |
+
### Model Not Loading
|
| 209 |
+
The model needs trained weights. Either:
|
| 210 |
+
1. Train your own model using `train.py`
|
| 211 |
+
2. Download pretrained weights from Hugging Face
|
| 212 |
+
3. Use Stable Diffusion weights as base
|
| 213 |
+
|
| 214 |
+
## Tips for Better Results
|
| 215 |
+
|
| 216 |
+
### Writing Prompts
|
| 217 |
+
- Be specific and descriptive
|
| 218 |
+
- Include style references ("digital art", "oil painting")
|
| 219 |
+
- Mention lighting ("dramatic lighting", "soft sunlight")
|
| 220 |
+
- Add quality modifiers ("highly detailed", "4K", "masterpiece")
|
| 221 |
+
|
| 222 |
+
### Negative Prompts
|
| 223 |
+
Use to avoid common issues:
|
| 224 |
+
```
|
| 225 |
+
ugly, blurry, low quality, distorted, deformed, bad anatomy, extra limbs
|
| 226 |
+
```
|
| 227 |
+
|
| 228 |
+
### Parameters
|
| 229 |
+
- **Steps**: 20-30 (quick), 50 (good), 75-100 (best)
|
| 230 |
+
- **Guidance**: 5-7 (creative), 7-9 (balanced), 9-12 (strict)
|
| 231 |
+
- **Resolution**: Start with 512x512, increase if needed
|
| 232 |
+
|
| 233 |
+
## Advanced Features
|
| 234 |
+
|
| 235 |
+
### Custom Schedulers
|
| 236 |
+
Edit `config.yaml` to try different schedulers:
|
| 237 |
+
- DDIM (default) - Fast, deterministic
|
| 238 |
+
- EulerDiscrete - Alternative sampling
|
| 239 |
+
|
| 240 |
+
### Fine-tuning
|
| 241 |
+
Fine-tune on specific styles:
|
| 242 |
+
1. Collect 50-100 images in desired style
|
| 243 |
+
2. Prepare dataset
|
| 244 |
+
3. Train for 50-100 epochs with low learning rate (1e-6)
|
| 245 |
+
|
| 246 |
+
## Support
|
| 247 |
+
|
| 248 |
+
For issues and questions:
|
| 249 |
+
1. Check this guide first
|
| 250 |
+
2. Review README.md
|
| 251 |
+
3. Check code comments
|
| 252 |
+
4. Visit Hugging Face documentation
|
| 253 |
+
|
| 254 |
+
## Updates
|
| 255 |
+
|
| 256 |
+
Check for updates and improvements:
|
| 257 |
+
- New model architectures
|
| 258 |
+
- Better CPU optimization
|
| 259 |
+
- Additional features
|
| 260 |
+
- Bug fixes
|
| 261 |
+
|
| 262 |
+
Enjoy creating with Byte Dream! 🎨
|
app.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Byte Dream - Gradio Web Interface
|
| 3 |
+
Interactive web UI for image generation
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import gradio as gr
|
| 7 |
+
from bytedream.generator import ByteDreamGenerator
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Initialize generator
|
| 12 |
+
print("Loading Byte Dream model...")
|
| 13 |
+
try:
|
| 14 |
+
generator = ByteDreamGenerator(
|
| 15 |
+
model_path="./models/bytedream",
|
| 16 |
+
config_path="config.yaml",
|
| 17 |
+
device="cpu",
|
| 18 |
+
)
|
| 19 |
+
print("✓ Model loaded successfully!")
|
| 20 |
+
except Exception as e:
|
| 21 |
+
print(f"⚠ Warning: Could not load model: {e}")
|
| 22 |
+
print(" Please train the model or download pretrained weights.")
|
| 23 |
+
generator = None
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def generate_image(
    prompt,
    negative_prompt,
    width,
    height,
    num_steps,
    guidance_scale,
    seed,
):
    """Run one text-to-image generation and report the outcome.

    Returns a ``(image_or_None, status_message)`` pair matching the Gradio
    ``(Image, Textbox)`` output components.
    """
    # Guard clause: the module-level generator is None when model weights
    # could not be loaded at startup.
    if generator is None:
        return None, "Error: Model not loaded. Please train or download model weights."

    # The UI uses -1 as the sentinel for "random seed".
    chosen_seed = seed if seed != -1 else None

    try:
        result = generator.generate(
            prompt=prompt,
            # An empty negative prompt means "none" to the pipeline.
            negative_prompt=negative_prompt or None,
            width=int(width),
            height=int(height),
            num_inference_steps=int(num_steps),
            guidance_scale=float(guidance_scale),
            seed=chosen_seed,
        )
    except Exception as e:
        # Surface the failure in the UI status box, and dump the traceback
        # to the console for debugging.
        print(f"Error generating image: {e}")
        import traceback
        traceback.print_exc()
        return None, f"Error: {str(e)}"

    return result, "Success! ✓"
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# Create Gradio interface.  Everything declared inside this `with` block is
# registered on `demo`; event wiring happens at the bottom of the block.
with gr.Blocks(
    title="Byte Dream - AI Image Generator",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1400px !important;
    }
    #main-heading {
        text-align: center;
        margin-bottom: 20px;
    }
    .description {
        text-align: center;
        margin-bottom: 30px;
    }
    """
) as demo:

    gr.Markdown("""
    # 🎨 Byte Dream - AI Image Generator

    ### Transform your imagination into reality with advanced AI

    Powered by state-of-the-art latent diffusion models, optimized for CPU inference.
    """)

    with gr.Row():
        # Left column: prompt entry and generation settings.
        with gr.Column(scale=1):
            gr.Markdown("### 📝 Create Your Prompt")

            prompt_input = gr.Textbox(
                label="Positive Prompt",
                placeholder="Describe the image you want to create...",
                lines=3,
                value="A beautiful sunset over mountains, digital art, highly detailed, vibrant colors",
            )

            negative_prompt_input = gr.Textbox(
                label="Negative Prompt (Optional)",
                placeholder="What to avoid in the image...",
                lines=2,
                value="ugly, blurry, low quality, distorted, deformed",
            )

            gr.Markdown("### ⚙️ Settings")

            with gr.Row():
                # Width/height step is 64 because the latent VAE works on
                # 64-px multiples.
                width_slider = gr.Slider(
                    minimum=256,
                    maximum=1024,
                    step=64,
                    value=512,
                    label="Width (px)",
                    info="Image width - multiples of 64"
                )
                height_slider = gr.Slider(
                    minimum=256,
                    maximum=1024,
                    step=64,
                    value=512,
                    label="Height (px)",
                    info="Image height - multiples of 64"
                )

            with gr.Row():
                steps_slider = gr.Slider(
                    minimum=10,
                    maximum=150,
                    step=5,
                    value=50,
                    label="Inference Steps",
                    info="More steps = better quality but slower"
                )
                guidance_slider = gr.Slider(
                    minimum=1.0,
                    maximum=20.0,
                    step=0.5,
                    value=7.5,
                    label="Guidance Scale",
                    info="Higher = closer to prompt, Lower = more creative"
                )

            seed_input = gr.Number(
                label="Seed",
                value=-1,
                precision=0,
                info="-1 for random, any number for reproducibility",
            )

            generate_btn = gr.Button(
                "🎨 Generate Image",
                variant="primary",
                size="lg",
            )

        # Right column: generated image, status line, and download slot.
        with gr.Column(scale=1):
            gr.Markdown("### 🖼️ Result")

            output_image = gr.Image(
                label="Generated Image",
                type="pil",
                height=512,
            )
            status_text = gr.Textbox(
                label="Status",
                interactive=False,
            )

            download_btn = gr.File(
                label="Download",
                visible=True,
            )

    # Tips section
    with gr.Accordion("💡 Tips for Better Results", open=False):
        gr.Markdown("""
        **Writing Effective Prompts:**
        - Be specific and descriptive
        - Include art style references (e.g., "digital art", "oil painting", "watercolor")
        - Mention lighting ("dramatic lighting", "soft sunlight", "neon lights")
        - Add quality modifiers ("highly detailed", "4K", "masterpiece")
        - Specify mood and atmosphere ("peaceful", "dramatic", "mysterious")

        **Using Negative Prompts:**
        - Remove unwanted elements ("no people", "no text")
        - Avoid quality issues ("no blur", "no distortion")
        - Fix common problems ("bad anatomy", "extra limbs")

        **Parameter Guide:**
        - **Steps**: 20-30 for quick previews, 50-75 for final images, 100+ for best quality
        - **Guidance**: 5-7 for creative freedom, 7-9 for balanced, 9-12 for strict prompt following
        - **Resolution**: Higher = more detail but slower. Start with 512x512, increase if needed
        """)

    # Examples section
    gr.Markdown("### 💡 Example Prompts")

    with gr.Row():
        example_btn1 = gr.Button(
            "🌆 Cyberpunk City",
            size="sm",
        )
        example_btn2 = gr.Button(
            "🐉 Fantasy Dragon",
            size="sm",
        )
        example_btn3 = gr.Button(
            "🏔️ Peaceful Landscape",
            size="sm",
        )

    with gr.Row():
        example_btn4 = gr.Button(
            "👤 Character Portrait",
            size="sm",
        )
        example_btn5 = gr.Button(
            "🌊 Underwater Scene",
            size="sm",
        )
        example_btn6 = gr.Button(
            "🎨 Abstract Art",
            size="sm",
        )

    # Example (prompt, negative_prompt) pairs, keyed by the button COMPONENT
    # itself.  BUG FIX: this dict was previously keyed by the local variable
    # NAME ("example_btn1", ...) and resolved via demo.get_component(name),
    # which fails at startup — gr.Blocks has no string lookup for local
    # variables.  Keying by the component lets us wire each click directly.
    example_prompts = {
        example_btn1: (
            "A cyberpunk city at night with neon lights, futuristic architecture, flying cars, rain-slicked streets, highly detailed, digital art, cinematic lighting",
            "ugly, blurry, low quality, distorted, dark, gloomy"
        ),
        example_btn2: (
            "A majestic dragon breathing fire, fantasy art, dramatic lighting, epic scene, scales gleaming, powerful wings, mountain landscape background",
            "ugly, deformed, blurry, low quality, cartoonish"
        ),
        example_btn3: (
            "A peaceful cottage in a meadow, wildflowers, sunny day, blue sky, studio ghibli style, serene atmosphere, pastoral landscape",
            "people, animals, buildings, urban, dark, stormy"
        ),
        example_btn4: (
            "Portrait of a warrior princess, ornate armor, fantasy setting, intricate details, character design, dramatic lighting, confident expression, long flowing hair",
            "ugly, deformed, asymmetrical, blurry, low quality, bad anatomy"
        ),
        example_btn5: (
            "Underwater coral reef, tropical fish, sunlight filtering through water, photorealistic, vibrant colors, marine life, crystal clear water",
            "polluted, murky, dark, blurry, low quality"
        ),
        example_btn6: (
            "Abstract geometric art, colorful shapes, dynamic composition, modern art, bold patterns, artistic expression, vivid colors",
            "representational, realistic, boring, dull colors, simple"
        ),
    }

    # Connect example buttons
    def set_example(prompt, negative):
        """Return values for the (prompt, negative, status) output trio."""
        return prompt, negative, "Click Generate to create!"

    for example_button, (example_prompt, example_negative) in example_prompts.items():
        # Bind the pair as lambda defaults: a plain closure over the loop
        # variables would late-bind and make every button use the last pair.
        example_button.click(
            fn=lambda p=example_prompt, n=example_negative: set_example(p, n),
            inputs=[],
            outputs=[prompt_input, negative_prompt_input, status_text],
        )

    # Connect generate button
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt_input,
            negative_prompt_input,
            width_slider,
            height_slider,
            steps_slider,
            guidance_slider,
            seed_input,
        ],
        outputs=[output_image, status_text],
    )

    # Footer
    gr.Markdown("""
    ---
    **Byte Dream** v1.0.0 | Powered by Latent Diffusion Models | Optimized for CPU Inference

    Created with ❤️ using PyTorch and Hugging Face Diffusers
    """)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
# Script entry point: start the Gradio web UI built above (bound to `demo`).
if __name__ == "__main__":
    print("\n" + "="*60)
    print("Starting Byte Dream Web Interface")
    print("="*60)
    print("\nOpening browser...")
    print("Press Ctrl+C to close\n")

    # NOTE(review): binds to all network interfaces (0.0.0.0) on port 7860
    # with public sharing disabled — confirm this exposure is intended for
    # the target deployment (it is the usual setting for HF Spaces/Docker).
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )
|
bytedream/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Byte Dream - AI Image Generation Model
|
| 3 |
+
Production-ready text-to-image diffusion model optimized for CPU inference
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
__version__ = "1.0.0"
|
| 7 |
+
__author__ = "Byte Dream Team"
|
| 8 |
+
|
| 9 |
+
from .generator import ByteDreamGenerator
|
| 10 |
+
from .model import UNet2DConditionModel, AutoencoderKL, CLIPTextModel
|
| 11 |
+
from .pipeline import ByteDreamPipeline
|
| 12 |
+
from .scheduler import DDIMScheduler
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
"ByteDreamGenerator",
|
| 16 |
+
"UNet2DConditionModel",
|
| 17 |
+
"AutoencoderKL",
|
| 18 |
+
"CLIPTextModel",
|
| 19 |
+
"ByteDreamPipeline",
|
| 20 |
+
"DDIMScheduler",
|
| 21 |
+
]
|
bytedream/generator.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Byte Dream Generator
|
| 3 |
+
Main inference engine optimized for CPU with advanced features
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import yaml
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Optional, Union, List
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import numpy as np
|
| 12 |
+
import gc
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ByteDreamGenerator:
    """
    Production-ready image generation engine.

    Wraps model construction, optional weight loading, CPU tuning, and the
    sampling loop behind a single object. Defaults come from ``config.yaml``
    (or a built-in fallback) and can be overridden per call.
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        config_path: str = "config.yaml",
        device: str = "cpu",
        use_safetensors: bool = True,
    ):
        """
        Initialize the Byte Dream generator.

        Args:
            model_path: Directory containing trained model weights (optional).
            config_path: Path to the YAML configuration file.
            device: Device to run on (default: "cpu").
            use_safetensors: Use safetensors format if available.
        """
        self.device = device
        self.config_path = config_path
        self.use_safetensors = use_safetensors

        # Configuration drives both model shape and generation defaults.
        self.config = self._load_config(config_path)

        print("Initializing Byte Dream Generator...")
        self.pipeline = self._initialize_pipeline(model_path)

        # Thread-count / memory tuning for CPU inference.
        self._optimize_for_cpu()

        print("✓ Byte Dream Generator ready!")

    def _load_config(self, config_path: str) -> dict:
        """Load configuration from a YAML file, falling back to defaults."""
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f)
            return config
        except FileNotFoundError:
            print(f"Warning: Config file {config_path} not found. Using defaults.")
            return self._get_default_config()

    def _get_default_config(self) -> dict:
        """Built-in configuration used when no config file is present.

        BUG FIX: now includes ``model.name`` / ``model.version`` so that
        ``get_model_info()`` does not raise KeyError when the generator is
        running on this default config.
        """
        return {
            'model': {
                'name': 'Byte Dream',
                'version': '1.0.0',
                'unet': {
                    'in_channels': 4,
                    'out_channels': 4,
                    'block_out_channels': [320, 640, 1280, 1280],
                    'layers_per_block': 2,
                    'attention_head_dim': 8,
                    'cross_attention_dim': 768,
                    'use_linear_projection': True,
                },
                'scheduler': {
                    'name': 'DDIM',
                    'num_train_timesteps': 1000,
                    'beta_start': 0.00085,
                    'beta_end': 0.012,
                    'beta_schedule': 'scaled_linear',
                    'clip_sample': False,
                    'set_alpha_to_one': False,
                }
            },
            'generation': {
                'width': 512,
                'height': 512,
                'num_inference_steps': 50,
                'guidance_scale': 7.5,
                'negative_prompt': 'ugly, blurry, low quality, distorted, deformed',
            }
        }

    def _initialize_pipeline(self, model_path: Optional[str]):
        """Create UNet, VAE, text encoder and scheduler, and wire them into a pipeline."""
        # Local imports keep module import cheap and avoid circular imports.
        from bytedream.model import create_unet, create_vae, create_text_encoder
        from bytedream.scheduler import create_scheduler
        from bytedream.pipeline import ByteDreamPipeline

        print("Creating UNet...")
        unet = create_unet(self.config)

        print("Creating VAE...")
        vae = create_vae(self.config)

        print("Creating Text Encoder...")
        text_encoder = create_text_encoder(self.config)

        print("Creating Scheduler...")
        scheduler = create_scheduler(self.config)

        # Load pretrained weights if a checkpoint directory was provided.
        if model_path:
            self._load_model_weights(unet, model_path)

        pipeline = ByteDreamPipeline(
            text_encoder=text_encoder,
            vae=vae,
            unet=unet,
            scheduler=scheduler,
            device=self.device,
            dtype=torch.float32,
        )

        return pipeline

    def _load_model_weights(self, unet, model_path: str):
        """Load pretrained UNet weights from a checkpoint directory, if present."""
        model_file = Path(model_path) / "unet_pytorch_model.bin"

        if not model_file.exists():
            model_file = Path(model_path) / "pytorch_model.bin"

        if model_file.exists():
            print(f"Loading weights from {model_file}...")
            # NOTE(review): torch.load unpickles arbitrary objects — only load
            # checkpoints from trusted sources.
            checkpoint = torch.load(model_file, map_location=self.device)

            # Support both wrapped ({'unet_state_dict': ...}) and raw state dicts.
            if 'unet_state_dict' in checkpoint:
                unet.load_state_dict(checkpoint['unet_state_dict'])
            else:
                unet.load_state_dict(checkpoint)

            print("✓ Weights loaded successfully!")
        else:
            print("⚠ No pretrained weights found. Using random initialization.")
            print("  Train the model or download pretrained weights.")

    def _optimize_for_cpu(self):
        """Optimize pipeline for CPU inference (thread count + memory mode)."""
        cpu_config = self.config.get('cpu_optimization', {})
        threads = cpu_config.get('threads', -1)

        if threads > 0:
            torch.set_num_threads(threads)
        else:
            # Non-positive means "use all available cores".
            import os
            torch.set_num_threads(os.cpu_count())

        self.pipeline.enable_memory_efficient_mode()

        print(f"✓ Optimized for CPU ({torch.get_num_threads()} threads)")

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
        num_inference_steps: Optional[int] = None,
        guidance_scale: Optional[float] = None,
        seed: Optional[int] = None,
        eta: float = 0.0,
        output_type: str = "pil",
    ) -> Image.Image:
        """
        Generate an image from a text prompt.

        Args:
            prompt: Text description of the desired image.
            negative_prompt: Things to avoid in the image. Passing "" now
                explicitly disables the config default (see bug fix below).
            width: Output image width (default from config, 512).
            height: Output image height (default from config, 512).
            num_inference_steps: Number of denoising steps (default 50).
            guidance_scale: How closely to follow the prompt (default 7.5).
            seed: Random seed for reproducibility.
            eta: DDIM eta parameter (0.0 for deterministic sampling).
            output_type: Output format ("pil" or "tensor").

        Returns:
            Generated PIL Image.
        """
        gen_config = self.config.get('generation', {})

        # BUG FIX: the previous `x or default` fallbacks treated legitimate
        # falsy values (guidance_scale=0.0, negative_prompt="") as "unset".
        # Only `None` now means "use the config default".
        if width is None:
            width = gen_config.get('width', 512)
        if height is None:
            height = gen_config.get('height', 512)
        if num_inference_steps is None:
            num_inference_steps = gen_config.get('num_inference_steps', 50)
        if guidance_scale is None:
            guidance_scale = gen_config.get('guidance_scale', 7.5)
        if negative_prompt is None:
            negative_prompt = gen_config.get('negative_prompt', "")

        # Latent-space models require dimensions divisible by 8 (VAE stride).
        width = (width // 8) * 8
        height = (height // 8) * 8

        print(f"\nGenerating image...")
        print(f"Prompt: {prompt}")
        if negative_prompt:
            print(f"Negative prompt: {negative_prompt}")
        print(f"Size: {width}x{height}")
        print(f"Steps: {num_inference_steps}")
        print(f"Guidance scale: {guidance_scale}")

        # Seeded generator gives reproducible sampling noise.
        generator = None
        if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(seed)
            print(f"Seed: {seed}")

        result = self.pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            eta=eta,
            generator=generator,
            output_type=output_type,
        )

        # Pipeline may return a bare image or a sequence of images.
        image = result[0] if isinstance(result, (list, tuple)) else result

        print("\n✓ Image generated successfully!")

        return image

    def generate_batch(
        self,
        prompts: List[str],
        negative_prompt: Optional[str] = None,
        width: int = 512,
        height: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        seeds: Optional[List[int]] = None,
    ) -> List[Image.Image]:
        """
        Generate multiple images, one per prompt.

        Args:
            prompts: List of text prompts.
            negative_prompt: Negative prompt applied to all images.
            width: Image width.
            height: Image height.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Guidance scale.
            seeds: Optional per-image random seeds (missing entries -> unseeded).

        Returns:
            List of generated PIL Images, in prompt order.
        """
        images = []

        for i, prompt in enumerate(prompts):
            seed = seeds[i] if seeds and i < len(seeds) else None

            print(f"\n{'='*50}")
            print(f"Generating image {i+1}/{len(prompts)}")
            print(f"{'='*50}")

            image = self.generate(
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                seed=seed,
            )

            images.append(image)

            # Release per-image temporaries before the next generation.
            gc.collect()

        return images

    def get_model_info(self) -> dict:
        """Return a summary dict describing the loaded model and defaults.

        BUG FIX: no longer KeyErrors when 'model.name'/'model.version' or
        generation defaults are absent from the config; also reports the
        previously-computed-but-dropped VAE parameter count.
        """
        unet_params = sum(p.numel() for p in self.pipeline.unet.parameters())
        vae_params = sum(p.numel() for p in self.pipeline.vae.parameters())

        model_cfg = self.config.get('model', {})
        gen_cfg = self.config.get('generation', {})

        info = {
            'name': model_cfg.get('name', 'Byte Dream'),
            'version': model_cfg.get('version', '1.0.0'),
            'unet_parameters': f"{unet_params:,}",
            'vae_parameters': f"{vae_params:,}",
            'device': self.device,
            'dtype': str(self.pipeline.dtype),
            'default_resolution': f"{gen_cfg.get('width', 512)}x{gen_cfg.get('height', 512)}",
        }

        return info

    def clear_memory(self):
        """Force garbage collection and release cached CUDA memory if any."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("Memory cleared")
|
bytedream/model.py
ADDED
|
@@ -0,0 +1,582 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Byte Dream Model Architecture
|
| 3 |
+
Complete implementation of UNet, VAE, and Text Encoder for diffusion-based image generation
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from typing import Optional, Tuple, Union
|
| 10 |
+
import math
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ResnetBlock2D(nn.Module):
    """Residual block for 2D feature maps with optional timestep conditioning.

    Structure: GroupNorm -> SiLU -> Conv, add projected time embedding,
    GroupNorm -> SiLU -> Dropout -> Conv, then a residual connection
    (1x1-projected when channel counts differ).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: Optional[int] = None,
        groups: int = 32,
        eps: float = 1e-6,
    ):
        super().__init__()

        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # Projects the (shared) timestep embedding into this block's channels.
        if temb_channels is not None:
            self.time_emb_proj = nn.Linear(temb_channels, out_channels)

        self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = nn.Dropout(0.0)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # BUG FIX: was nn.SiLU(inplace=True). The same activation is applied
        # to `temb` in forward(); `temb` is shared by every block in the
        # U-Net, so an in-place SiLU silently mutated the caller's time
        # embedding on each forward pass (and corrupted it for later blocks).
        self.nonlinearity = nn.SiLU()

        if in_channels != out_channels:
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.conv_shortcut = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the block; `temb` (B, temb_channels) broadcasts over H x W."""
        x = hidden_states

        x = self.norm1(x)
        x = self.nonlinearity(x)
        x = self.conv1(x)

        if temb is not None:
            # (B, C) -> (B, C, 1, 1) so it broadcasts across spatial dims.
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
            x = x + temb

        x = self.norm2(x)
        x = self.nonlinearity(x)
        x = self.dropout(x)
        x = self.conv2(x)

        # Channel-matching shortcut when in/out channels differ.
        if self.conv_shortcut is not None:
            hidden_states = self.conv_shortcut(hidden_states)

        return x + hidden_states
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class AttentionBlock(nn.Module):
|
| 70 |
+
"""Cross-attention block for text-conditioned generation"""
|
| 71 |
+
|
| 72 |
+
def __init__(
|
| 73 |
+
self,
|
| 74 |
+
query_dim: int,
|
| 75 |
+
cross_attention_dim: Optional[int] = None,
|
| 76 |
+
num_heads: int = 8,
|
| 77 |
+
head_dim: Optional[int] = None,
|
| 78 |
+
eps: float = 1e-6,
|
| 79 |
+
):
|
| 80 |
+
super().__init__()
|
| 81 |
+
|
| 82 |
+
inner_dim = num_heads * head_dim if head_dim is not None else query_dim
|
| 83 |
+
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
| 84 |
+
|
| 85 |
+
self.num_heads = num_heads
|
| 86 |
+
self.head_dim = head_dim if head_dim is not None else query_dim // num_heads
|
| 87 |
+
|
| 88 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
| 89 |
+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
|
| 90 |
+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)
|
| 91 |
+
|
| 92 |
+
self.to_out = nn.ModuleList([
|
| 93 |
+
nn.Linear(inner_dim, query_dim),
|
| 94 |
+
nn.Dropout(0.0)
|
| 95 |
+
])
|
| 96 |
+
|
| 97 |
+
self.norm = nn.LayerNorm(query_dim, eps=eps)
|
| 98 |
+
|
| 99 |
+
def forward(
|
| 100 |
+
self,
|
| 101 |
+
hidden_states: torch.Tensor,
|
| 102 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 103 |
+
) -> torch.Tensor:
|
| 104 |
+
residual = hidden_states
|
| 105 |
+
|
| 106 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
| 107 |
+
|
| 108 |
+
query = self.to_q(hidden_states)
|
| 109 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
| 110 |
+
key = self.to_k(encoder_hidden_states)
|
| 111 |
+
value = self.to_v(encoder_hidden_states)
|
| 112 |
+
|
| 113 |
+
# Multi-head attention
|
| 114 |
+
query = query.reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
|
| 115 |
+
key = key.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
| 116 |
+
value = value.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
| 117 |
+
|
| 118 |
+
# Scaled dot-product attention
|
| 119 |
+
attn_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
|
| 120 |
+
attn_weights = F.softmax(attn_weights, dim=-1)
|
| 121 |
+
|
| 122 |
+
attn_output = torch.matmul(attn_weights, value)
|
| 123 |
+
attn_output = attn_output.transpose(1, 2).reshape(batch_size, sequence_length, -1)
|
| 124 |
+
|
| 125 |
+
# Output projection
|
| 126 |
+
for layer in self.to_out:
|
| 127 |
+
attn_output = layer(attn_output)
|
| 128 |
+
|
| 129 |
+
return residual + attn_output
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class DownBlock2D(nn.Module):
    """Downsampling block: a stack of ResNet layers, each optionally followed
    by cross-attention, with an optional strided-conv downsampler at the end.

    forward() returns the final hidden states plus a tuple of intermediate
    outputs used as U-Net skip connections.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        num_layers: int = 1,
        add_downsample: bool = True,
        has_cross_attention: bool = False,
        cross_attention_dim: Optional[int] = None,
    ):
        super().__init__()

        resnet_layers = []
        attn_layers = []

        for layer_idx in range(num_layers):
            # Only the first layer sees the incoming channel count.
            resnet_layers.append(
                ResnetBlock2D(
                    in_channels=in_channels if layer_idx == 0 else out_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                )
            )
            attn_layers.append(
                AttentionBlock(
                    query_dim=out_channels,
                    cross_attention_dim=cross_attention_dim,
                    num_heads=8,
                    head_dim=out_channels // 8,
                )
                if has_cross_attention
                else None
            )

        self.resnets = nn.ModuleList(resnet_layers)
        self.attentions = nn.ModuleList(attn_layers)

        # Strided 3x3 conv halves the spatial resolution when enabled.
        self.downsamplers = (
            nn.ModuleList(
                [nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)]
            )
            if add_downsample
            else None
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run all layers; collects every intermediate output for skip use."""
        skip_outputs = ()

        for resnet, attention in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb)
            # Attention only fires when the layer has it AND text context exists.
            if attention is not None and encoder_hidden_states is not None:
                hidden_states = attention(hidden_states, encoder_hidden_states)
            skip_outputs = skip_outputs + (hidden_states,)

        if self.downsamplers is not None:
            for down in self.downsamplers:
                hidden_states = down(hidden_states)
            skip_outputs = skip_outputs + (hidden_states,)

        return hidden_states, skip_outputs
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class UpBlock2D(nn.Module):
    """Upsampling block: ResNet layers that consume U-Net skip connections
    (channel-concatenated), each optionally followed by cross-attention,
    with an optional transposed-conv upsampler at the end.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        num_layers: int = 1,
        add_upsample: bool = True,
        has_cross_attention: bool = False,
        cross_attention_dim: Optional[int] = None,
    ):
        super().__init__()

        resnet_layers = []
        attn_layers = []

        for layer_idx in range(num_layers):
            # Input channels = this block's running width + the width of the
            # skip tensor concatenated at this layer.
            base_ch = in_channels if layer_idx == 0 else out_channels
            skip_ch = prev_output_channel if layer_idx == num_layers - 1 else out_channels

            resnet_layers.append(
                ResnetBlock2D(
                    in_channels=base_ch + skip_ch,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                )
            )
            attn_layers.append(
                AttentionBlock(
                    query_dim=out_channels,
                    cross_attention_dim=cross_attention_dim,
                    num_heads=8,
                    head_dim=out_channels // 8,
                )
                if has_cross_attention
                else None
            )

        self.resnets = nn.ModuleList(resnet_layers)
        self.attentions = nn.ModuleList(attn_layers)

        # Transposed conv doubles the spatial resolution when enabled.
        self.upsamplers = (
            nn.ModuleList(
                [nn.ConvTranspose2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)]
            )
            if add_upsample
            else None
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run all layers, fusing one skip tensor per layer along channels."""
        for layer_idx, (resnet, attention) in enumerate(zip(self.resnets, self.attentions)):
            # Skip connection from the U-Net downsampling path.
            hidden_states = torch.cat([hidden_states, res_hidden_states_tuple[layer_idx]], dim=1)

            hidden_states = resnet(hidden_states, temb)

            if attention is not None and encoder_hidden_states is not None:
                hidden_states = attention(hidden_states, encoder_hidden_states)

        if self.upsamplers is not None:
            for up in self.upsamplers:
                hidden_states = up(hidden_states)

        return hidden_states
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class UNet2DConditionModel(nn.Module):
|
| 276 |
+
"""
|
| 277 |
+
Main UNet architecture for diffusion-based image generation
|
| 278 |
+
Handles noise prediction conditioned on text embeddings
|
| 279 |
+
"""
|
| 280 |
+
|
| 281 |
+
def __init__(
|
| 282 |
+
self,
|
| 283 |
+
in_channels: int = 4,
|
| 284 |
+
out_channels: int = 4,
|
| 285 |
+
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
| 286 |
+
layers_per_block: int = 2,
|
| 287 |
+
attention_head_dim: int = 8,
|
| 288 |
+
cross_attention_dim: int = 768,
|
| 289 |
+
use_linear_projection: bool = True,
|
| 290 |
+
):
|
| 291 |
+
super().__init__()
|
| 292 |
+
|
| 293 |
+
self.in_channels = in_channels
|
| 294 |
+
self.block_out_channels = block_out_channels
|
| 295 |
+
self.layers_per_block = layers_per_block
|
| 296 |
+
self.cross_attention_dim = cross_attention_dim
|
| 297 |
+
|
| 298 |
+
# Time embedding
|
| 299 |
+
time_embed_dim = block_out_channels[0] * 4
|
| 300 |
+
self.time_proj = nn.Sequential(
|
| 301 |
+
nn.Linear(block_out_channels[0], time_embed_dim),
|
| 302 |
+
nn.SiLU(inplace=True),
|
| 303 |
+
nn.Linear(time_embed_dim, time_embed_dim),
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
# Input convolution
|
| 307 |
+
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
| 308 |
+
|
| 309 |
+
# Down blocks
|
| 310 |
+
self.down_blocks = nn.ModuleList([])
|
| 311 |
+
output_channel = block_out_channels[0]
|
| 312 |
+
|
| 313 |
+
for i, down_block_type in enumerate(["down", "down", "down", "down"]):
|
| 314 |
+
input_channel = output_channel
|
| 315 |
+
output_channel = block_out_channels[i]
|
| 316 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 317 |
+
|
| 318 |
+
down_block = DownBlock2D(
|
| 319 |
+
in_channels=input_channel,
|
| 320 |
+
out_channels=output_channel,
|
| 321 |
+
temb_channels=time_embed_dim,
|
| 322 |
+
num_layers=layers_per_block,
|
| 323 |
+
add_downsample=not is_final_block,
|
| 324 |
+
has_cross_attention=True,
|
| 325 |
+
cross_attention_dim=cross_attention_dim,
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
self.down_blocks.append(down_block)
|
| 329 |
+
|
| 330 |
+
# Middle blocks
|
| 331 |
+
self.mid_block = nn.ModuleList([
|
| 332 |
+
ResnetBlock2D(
|
| 333 |
+
in_channels=block_out_channels[-1],
|
| 334 |
+
out_channels=block_out_channels[-1],
|
| 335 |
+
temb_channels=time_embed_dim,
|
| 336 |
+
),
|
| 337 |
+
AttentionBlock(
|
| 338 |
+
query_dim=block_out_channels[-1],
|
| 339 |
+
cross_attention_dim=cross_attention_dim,
|
| 340 |
+
num_heads=attention_head_dim,
|
| 341 |
+
head_dim=block_out_channels[-1] // attention_head_dim,
|
| 342 |
+
),
|
| 343 |
+
ResnetBlock2D(
|
| 344 |
+
in_channels=block_out_channels[-1],
|
| 345 |
+
out_channels=block_out_channels[-1],
|
| 346 |
+
temb_channels=time_embed_dim,
|
| 347 |
+
),
|
| 348 |
+
])
|
| 349 |
+
|
| 350 |
+
# Up blocks
|
| 351 |
+
self.up_blocks = nn.ModuleList([])
|
| 352 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
| 353 |
+
|
| 354 |
+
for i, up_block_type in enumerate(["up", "up", "up", "up"]):
|
| 355 |
+
prev_output_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
| 356 |
+
output_channel = reversed_block_out_channels[i]
|
| 357 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 358 |
+
|
| 359 |
+
up_block = UpBlock2D(
|
| 360 |
+
in_channels=reversed_block_out_channels[i - 1] if i > 0 else reversed_block_out_channels[0],
|
| 361 |
+
out_channels=output_channel,
|
| 362 |
+
prev_output_channel=prev_output_channel,
|
| 363 |
+
temb_channels=time_embed_dim,
|
| 364 |
+
num_layers=layers_per_block + 1,
|
| 365 |
+
add_upsample=not is_final_block,
|
| 366 |
+
has_cross_attention=True,
|
| 367 |
+
cross_attention_dim=cross_attention_dim,
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
self.up_blocks.append(up_block)
|
| 371 |
+
|
| 372 |
+
# Output
|
| 373 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_channels=block_out_channels[0], eps=1e-6)
|
| 374 |
+
self.conv_act = nn.SiLU(inplace=True)
|
| 375 |
+
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, stride=1, padding=1)
|
| 376 |
+
|
| 377 |
+
def forward(
    self,
    sample: torch.Tensor,
    timestep: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
) -> torch.Tensor:
    """
    Run one denoising prediction through the UNet.

    Args:
        sample: Noisy latent batch to denoise.
        timestep: Diffusion timestep(s) passed to the time projection.
        encoder_hidden_states: Text embeddings used for cross-attention
            conditioning in the down, mid and up blocks.

    Returns:
        The model prediction tensor produced by the final convolution.
    """
    # Time embedding: the projected timestep is used directly as `temb`
    # (no additional MLP is applied in this forward pass).
    timesteps_proj = self.time_proj(timestep)
    temb = timesteps_proj

    # Initial convolution into the base channel width.
    hidden_states = self.conv_in(sample)

    # Down-sampling path: collect every intermediate activation so the
    # up path can consume them as skip connections.
    down_block_res_samples = (hidden_states,)

    for downsample_block in self.down_blocks:
        hidden_states, res_samples = downsample_block(
            hidden_states=hidden_states,
            temb=temb,
            encoder_hidden_states=encoder_hidden_states,
        )
        down_block_res_samples += res_samples

    # Middle: resnet blocks receive the time embedding, while the
    # attention block receives the text embeddings.
    for layer in self.mid_block:
        if isinstance(layer, ResnetBlock2D):
            hidden_states = layer(hidden_states, temb)
        else:
            hidden_states = layer(hidden_states, encoder_hidden_states)

    # Up-sampling path: pop skip residuals from the end of the stack,
    # one group per up block (group size = that block's resnet count).
    for upsample_block in self.up_blocks:
        res_samples = down_block_res_samples[-len(upsample_block.resnets):]
        down_block_res_samples = down_block_res_samples[:-len(upsample_block.resnets)]

        hidden_states = upsample_block(
            hidden_states=hidden_states,
            res_hidden_states_tuple=res_samples,
            temb=temb,
            encoder_hidden_states=encoder_hidden_states,
        )

    # Output head: normalize, activate, project to output channels.
    hidden_states = self.conv_norm_out(hidden_states)
    hidden_states = self.conv_act(hidden_states)
    hidden_states = self.conv_out(hidden_states)

    return hidden_states
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
class AutoencoderKL(nn.Module):
    """
    Convolutional autoencoder for image <-> latent compression.

    The encoder halves the spatial resolution once per stage (16x total
    downscale with the default four stages) and projects to
    ``2 * latent_channels`` feature maps; the decoder mirrors this with
    transposed convolutions back to pixel space.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",) * 4,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",) * 4,
        latent_channels: int = 4,
        sample_size: int = 512,
    ):
        """
        Args:
            in_channels: Channels of the input image (3 for RGB).
            out_channels: Channels of the reconstructed image.
            down_block_types: One entry per encoder stage (only the count is used).
            up_block_types: One entry per decoder stage (only the count is used).
            latent_channels: Channels of the latent representation.
            sample_size: Nominal image size (stored, not enforced).
        """
        super().__init__()

        self.sample_size = sample_size
        self.latent_channels = latent_channels

        # Encoder: stride-2 conv stages, each halving H and W.
        self.encoder = nn.ModuleList()
        channels = [in_channels, 128, 256, 512, 512]

        for i in range(len(down_block_types)):
            self.encoder.append(nn.Sequential(
                nn.Conv2d(channels[i], channels[i + 1], kernel_size=3, stride=2, padding=1),
                # BUGFIX: the original repeated the `num_channels` keyword
                # (a SyntaxError); GroupNorm takes (num_groups, num_channels).
                nn.GroupNorm(num_groups=32, num_channels=channels[i + 1], eps=1e-6),
                nn.SiLU(inplace=True),
            ))

        # Project encoder features to 2 * latent_channels maps
        # (two halves of the latent code; forward() keeps the first half).
        self.quant_conv = nn.Conv2d(512, latent_channels * 2, kernel_size=1)

        # Decoder: transposed-conv stages, each doubling H and W.
        # BUGFIX: the first decoder stage must accept the 512 channels
        # produced by post_quant_conv, not `latent_channels` as originally
        # written (that mismatch made decode() fail at runtime).
        self.decoder = nn.ModuleList()
        decoder_channels = [512, 512, 512, 256, 128]

        for i in range(len(up_block_types)):
            self.decoder.append(nn.Sequential(
                nn.ConvTranspose2d(decoder_channels[i], decoder_channels[i + 1], kernel_size=4, stride=2, padding=1),
                nn.GroupNorm(num_groups=32, num_channels=decoder_channels[i + 1], eps=1e-6),
                nn.SiLU(inplace=True),
            ))

        self.post_quant_conv = nn.Conv2d(latent_channels, 512, kernel_size=1)
        self.conv_out = nn.Conv2d(128, out_channels, kernel_size=3, stride=1, padding=1)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode an image batch to a (B, 2*latent_channels, H/16, W/16) tensor."""
        for block in self.encoder:
            x = block(x)
        return self.quant_conv(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode a (B, latent_channels, h, w) latent batch back to images."""
        z = self.post_quant_conv(z)
        for block in self.decoder:
            z = block(z)
        return self.conv_out(z)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode then decode, keeping the first `latent_channels` channels."""
        encoded = self.encode(x)
        # Generalized from the hard-coded `[:, :4]` of the original.
        return self.decode(encoded[:, :self.latent_channels])
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
class CLIPTextModel(nn.Module):
    """
    CLIP text encoder wrapper.

    Encodes prompts into per-token embeddings for conditioning. When the
    `transformers` package is not importable, it degrades to a stub that
    returns a fixed zero embedding.
    """

    def __init__(self, model_name: str = "openai/clip-vit-large-patch14", max_length: int = 77):
        """
        Args:
            model_name: Hugging Face model id to load.
            max_length: Token sequence length used for padding/truncation.
        """
        super().__init__()

        try:
            from transformers import CLIPTextModel as HFCLIPTextModel, CLIPTokenizer

            self.model = HFCLIPTextModel.from_pretrained(model_name)
            self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
            self.max_length = max_length
        except ImportError:
            print("Warning: transformers not installed. Using dummy text encoder.")
            self.model = None
            self.tokenizer = None

    def forward(self, text: Union[str, List[str]], device: torch.device = None) -> torch.Tensor:
        """
        Encode text into embeddings.

        Args:
            text: A prompt string or list of prompts.
            device: Optional device to move the tokenized inputs to.

        Returns:
            The encoder's last hidden state, or a (1, 77, 768) zero tensor
            when running with the stub encoder.
        """
        if self.model is None:
            # Stub path: fixed-size zero embedding.
            return torch.zeros(1, 77, 768)

        tokenized = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )

        if device is not None:
            tokenized = {name: tensor.to(device) for name, tensor in tokenized.items()}

        return self.model(**tokenized).last_hidden_state
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
def create_unet(config):
    """Build a UNet2DConditionModel from the `model.unet` section of *config*."""
    cfg = config['model']['unet']
    kwargs = {key: cfg[key] for key in (
        'in_channels',
        'out_channels',
        'layers_per_block',
        'attention_head_dim',
        'cross_attention_dim',
        'use_linear_projection',
    )}
    kwargs['block_out_channels'] = tuple(cfg['block_out_channels'])
    return UNet2DConditionModel(**kwargs)
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
def create_vae(config):
    """Build an AutoencoderKL from the `model.vae` section of *config*."""
    cfg = config['model']['vae']
    kwargs = {
        'in_channels': cfg['in_channels'],
        'out_channels': cfg['out_channels'],
        'down_block_types': tuple(cfg['down_block_types']),
        'up_block_types': tuple(cfg['up_block_types']),
        'latent_channels': cfg['latent_channels'],
        'sample_size': cfg['sample_size'],
    }
    return AutoencoderKL(**kwargs)
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
def create_text_encoder(config):
    """Build a CLIPTextModel from the `model.text_encoder` section of *config*."""
    text_cfg = config['model']['text_encoder']
    return CLIPTextModel(model_name=text_cfg['model'], max_length=text_cfg['max_length'])
|
bytedream/pipeline.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Byte Dream Pipeline
|
| 3 |
+
Complete diffusion pipeline for text-to-image generation
|
| 4 |
+
Integrates all components: text encoder, UNet, VAE, and scheduler
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
from typing import Optional, Union, List
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import gc
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ByteDreamPipeline:
    """
    Complete text-to-image diffusion pipeline.

    Wires together the text encoder, UNet denoiser, VAE decoder and noise
    scheduler, and runs the full prompt -> latent denoising -> image flow.
    Intended for CPU inference (float32 by default).
    """

    def __init__(
        self,
        text_encoder,
        vae,
        unet,
        scheduler,
        device: str = "cpu",
        dtype: torch.dtype = torch.float32,
    ):
        """
        Args:
            text_encoder: Prompt encoder (callable returning embeddings;
                may expose a `.model` attribute that is moved to `device`).
            vae: Autoencoder providing `decode`.
            unet: Conditional noise-prediction model.
            scheduler: Scheduler with `set_timesteps`/`scale_model_input`/`step`.
            device: Target device string, e.g. "cpu".
            dtype: Tensor dtype used for latents and embeddings.
        """
        self.text_encoder = text_encoder
        self.vae = vae
        self.unet = unet
        self.scheduler = scheduler

        self.device = torch.device(device)
        self.dtype = dtype

        # Move models to device, then freeze them for inference.
        self._move_models_to_device()
        self._set_eval_mode()

    def _move_models_to_device(self):
        """Move all models to the target device."""
        print(f"Loading models to {self.device}...")

        if hasattr(self.text_encoder, 'model') and self.text_encoder.model is not None:
            self.text_encoder.model.to(self.device)

        self.vae.to(self.device)
        self.unet.to(self.device)

    def _set_eval_mode(self):
        """Set all models to evaluation mode."""
        if hasattr(self.text_encoder, 'model') and self.text_encoder.model is not None:
            self.text_encoder.model.eval()
        self.vae.eval()
        self.unet.eval()

    @torch.no_grad()
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: int = 1,
    ) -> torch.Tensor:
        """
        Encode text prompts to embeddings.

        Args:
            prompt: Text prompt or list of prompts.
            negative_prompt: Negative prompt; when not None, unconditional
                embeddings are concatenated in front of the conditional
                ones for classifier-free guidance.
            num_images_per_prompt: Accepted for API compatibility.
                NOTE(review): currently unused — confirm intended.

        Returns:
            Text embeddings tensor (doubled batch when negative_prompt
            is not None).
        """
        if isinstance(prompt, str):
            prompt = [prompt]

        text_embeddings = self.text_encoder(prompt, device=self.device)
        text_embeddings = text_embeddings.to(self.dtype)

        if negative_prompt is not None:
            if isinstance(negative_prompt, str):
                negative_prompt = [negative_prompt]

            uncond_embeddings = self.text_encoder(negative_prompt, device=self.device)
            uncond_embeddings = uncond_embeddings.to(self.dtype)

            # Concatenate [uncond, cond] for classifier-free guidance.
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings

    @torch.no_grad()
    def decode_latents(self, latents: torch.Tensor) -> "Image.Image":
        """
        Decode a latent batch to a PIL image (first batch element only).

        Args:
            latents: Latent space tensor.

        Returns:
            PIL Image.
        """
        # Undo the SD latent scaling factor before decoding.
        latents = 1 / 0.18215 * latents

        image = self.vae.decode(latents)
        image = torch.clamp(image, -1, 1)

        # Map [-1, 1] -> [0, 255] uint8, channel-last, first image only.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
        image = (image * 255).round().astype("uint8")

        return Image.fromarray(image)

    @torch.no_grad()
    def prepare_latents(
        self,
        batch_size: int,
        height: int,
        width: int,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        """
        Initialize random noise latents.

        Args:
            batch_size: Number of images to generate.
            height: Image height in pixels.
            width: Image width in pixels.
            generator: Optional random number generator.

        Returns:
            Initial noise tensor of shape (batch, 4, height/8, width/8).
        """
        # 4 latent channels, 8x spatial downscale relative to pixel space.
        shape = (batch_size, 4, height // 8, width // 8)
        latents = torch.randn(shape, generator=generator, dtype=self.dtype)
        latents = latents.to(self.device)

        # Scale initial noise only when the scheduler defines a factor.
        if hasattr(self.scheduler, 'init_noise_scale'):
            latents = latents * self.scheduler.init_noise_scale

        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        output_type: str = "pil",
        return_dict: bool = False,
    ) -> Union[List["Image.Image"], tuple]:
        """
        Generate images from text prompts.

        Args:
            prompt: Text prompt(s) for generation.
            negative_prompt: Negative prompt (defaults to "").
            height: Output image height.
            width: Output image width.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale.
            eta: DDIM eta parameter.
            generator: Random number generator.
            output_type: Output format ("pil" or "tensor").
            return_dict: Whether to return as dictionary.

        Returns:
            Generated images or tuple.
        """
        if negative_prompt is None:
            negative_prompt = ""

        # BUGFIX: the original gated guidance on `if negative_prompt:`,
        # which is False for the default "" while encode_prompt (checking
        # `is not None`) still produced doubled embeddings — a batch-size
        # mismatch inside the UNet. Derive one flag and use it everywhere.
        do_cfg = negative_prompt is not None

        batch_size = 1 if isinstance(prompt, str) else len(prompt)

        text_embeddings = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt if do_cfg else None,
            num_images_per_prompt=1,
        )

        # Prepare timesteps and initial noise.
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps

        latents = self.prepare_latents(
            batch_size=batch_size,
            height=height,
            width=width,
            generator=generator,
        )

        # Denoising loop.
        for i, t in enumerate(timesteps):
            # Expand latents to match the doubled embeddings under CFG.
            latent_model_input = torch.cat([latents] * 2) if do_cfg else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            timestep_tensor = torch.tensor([t], dtype=torch.long, device=self.device)
            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=timestep_tensor,
                encoder_hidden_states=text_embeddings,
            )

            # Classifier-free guidance: blend uncond/cond predictions.
            if do_cfg:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents, _ = self.scheduler.step(
                model_output=noise_pred,
                timestep=t,
                sample=latents,
                eta=eta,
                generator=generator,
            )

            if (i + 1) % 10 == 0 or i == len(timesteps) - 1:
                print(f"Step {i+1}/{len(timesteps)}")

        image = self.decode_latents(latents)

        if output_type != "pil":
            return (image,) if not return_dict else {"images": [image]}

        return [image] if not return_dict else {"images": [image]}

    def enable_memory_efficient_mode(self):
        """Free cached memory (CUDA cache if present, then Python GC)."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        gc.collect()

        print("Memory efficient mode enabled")

    def optimize_for_cpu(self, threads: int = -1):
        """
        Optimize pipeline for CPU inference.

        Args:
            threads: Number of threads to use (-1 for all available).
        """
        if threads > 0:
            torch.set_num_threads(threads)

        if threads == -1:
            import os
            torch.set_num_threads(os.cpu_count())

        print(f"Optimized for CPU with {torch.get_num_threads()} threads")
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def create_pipeline(config, device: str = "cpu"):
    """
    Assemble a complete ByteDreamPipeline from a configuration dictionary.

    Args:
        config: Configuration dictionary (with a `model` section).
        device: Target device string.

    Returns:
        ByteDreamPipeline instance.
    """
    from .model import create_unet, create_vae, create_text_encoder
    from .scheduler import create_scheduler

    # Build every component, then hand them to the pipeline in one shot.
    return ByteDreamPipeline(
        text_encoder=create_text_encoder(config),
        vae=create_vae(config),
        unet=create_unet(config),
        scheduler=create_scheduler(config),
        device=device,
    )
|
bytedream/scheduler.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Byte Dream Diffusion Scheduler
|
| 3 |
+
Implements DDIM (Denoising Diffusion Implicit Models) sampling for fast, high-quality generation
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
from typing import Optional, Tuple, Union
|
| 9 |
+
import math
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DDIMScheduler:
    """
    DDIM (Denoising Diffusion Implicit Models) sampler.

    Deterministic when eta == 0 and partially stochastic otherwise; runs
    with far fewer inference steps than the training schedule length.
    """

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        beta_schedule: str = "scaled_linear",
        clip_sample: bool = False,
        set_alpha_to_one: bool = False,
    ):
        """
        Args:
            num_train_timesteps: Length of the training noise schedule.
            beta_start: First beta of the schedule.
            beta_end: Last beta of the schedule.
            beta_schedule: "scaled_linear" (linear in sqrt-beta space) or "linear".
            clip_sample: Clamp each step's output to [-1, 1].
            set_alpha_to_one: Use alpha_bar = 1 for the step before t = 0.

        Raises:
            ValueError: For an unknown `beta_schedule`.
        """
        self.num_train_timesteps = num_train_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.beta_schedule = beta_schedule
        self.clip_sample = clip_sample
        self.set_alpha_to_one = set_alpha_to_one

        # Noise schedule.
        if beta_schedule == "scaled_linear":
            # Linear in sqrt(beta) space.
            self.betas = torch.linspace(
                beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype=torch.float32
            ) ** 2
        elif beta_schedule == "linear":
            self.betas = torch.linspace(
                beta_start, beta_end, num_train_timesteps, dtype=torch.float32
            )
        else:
            raise ValueError(f"Unknown beta schedule: {beta_schedule}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # Populated by set_timesteps().
        self.timesteps = None

    def set_timesteps(self, num_inference_steps: int) -> None:
        """
        Select the evenly spaced training timesteps used for inference,
        ordered from most to least noisy.

        BUGFIX: the original called `.round()` on an integer (Long) tensor,
        which torch does not implement for integer dtypes; the products are
        already exact integers, so no rounding is needed.

        Args:
            num_inference_steps: Number of denoising steps.
        """
        step_ratio = self.num_train_timesteps // num_inference_steps
        self.timesteps = (
            torch.arange(0, num_inference_steps, dtype=torch.long) * step_ratio
        ).flip(dims=[0])

    def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:
        """Posterior variance between `timestep` and `prev_timestep`."""
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        return (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Perform a single denoising step.

        Args:
            model_output: Predicted noise from the UNet.
            timestep: Current timestep.
            sample: Current noisy sample.
            eta: DDIM eta parameter (0 for deterministic sampling).
            use_clipped_model_output: Clamp the predicted x0 to [-1, 1].
            generator: Random number generator for the eta > 0 noise.

        Returns:
            Tuple of (previous_sample, pred_original_sample).
        """
        # The stride between inference timesteps.
        prev_timestep = timestep - self.num_train_timesteps // len(self.timesteps)

        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t

        # Predicted x0 from the epsilon parameterization.
        pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5

        if use_clipped_model_output:
            pred_original_sample = torch.clamp(pred_original_sample, -1, 1)

        # Coefficient of the noise term pointing toward x_t.
        # NOTE(review): with eta > 0 this omits the sigma^2 correction
        # (sqrt(1 - alpha_prev - sigma^2)) from the DDIM paper — confirm
        # whether that approximation is intended.
        model_output_direction = (1 - alpha_prod_t_prev) ** 0.5

        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance ** 0.5

        # x_{t-1} = sqrt(alpha_prev) * x0 + direction term.
        pred_sample_direction = alpha_prod_t_prev ** 0.5
        prev_sample = pred_sample_direction * pred_original_sample + model_output_direction * model_output

        # Stochastic component for eta > 0.
        if eta > 0:
            noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(model_output.device)
            prev_sample = prev_sample + std_dev_t * noise

        if self.clip_sample:
            prev_sample = torch.clamp(prev_sample, -1, 1)

        return prev_sample, pred_original_sample

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """
        Add noise to samples (forward diffusion process).

        Args:
            original_samples: Clean samples (B, C, H, W).
            noise: Noise tensor with the same shape.
            timesteps: Per-sample timestep indices.

        Returns:
            Noisy samples: sqrt(a_bar) * x0 + sqrt(1 - a_bar) * noise.
        """
        alpha_prod_t = self.alphas_cumprod[timesteps].view(-1, 1, 1, 1)
        sqrt_alpha_prod = alpha_prod_t ** 0.5
        sqrt_one_minus_alpha_prod = (1 - alpha_prod_t) ** 0.5

        return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise

    def scale_model_input(
        self,
        sample: torch.Tensor,
        timestep: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Identity scaling (present for scheduler-API compatibility).

        Args:
            sample: Input sample.
            timestep: Current timestep (unused).

        Returns:
            The sample, unchanged.
        """
        return sample

    def get_scalings_for_boundary_condition_discrete(self, timestep):
        """Return (c_out, c_in) scalings derived from alpha_bar(timestep)."""
        alpha_prod_t = self.alphas_cumprod[timestep]
        # (1 - a) * a / a simplifies to (1 - a).
        sigma_t = ((1 - alpha_prod_t) * alpha_prod_t / (alpha_prod_t)) ** 0.5
        c_out = -sigma_t
        c_in = 1 / (alpha_prod_t ** 0.5)
        return c_out, c_in
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class EulerDiscreteScheduler:
    """
    Euler-discretization sampler.

    Alternative to DDIM built on a sigma (noise level) parameterization of
    the same beta schedule.
    """

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        beta_schedule: str = "scaled_linear",
    ):
        """
        Args:
            num_train_timesteps: Length of the training noise schedule.
            beta_start: First beta of the schedule.
            beta_end: Last beta of the schedule.
            beta_schedule: Stored only; the sigmas below always use the
                scaled-linear form.
        """
        self.num_train_timesteps = num_train_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.beta_schedule = beta_schedule

        # Scaled-linear betas -> cumulative alphas -> sigma levels.
        sqrt_betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps)
        cum_alphas = torch.cumprod(1.0 - sqrt_betas ** 2, dim=0)

        # NOTE(review): a leading sigma of 1.0 is prepended before the
        # schedule-derived values — confirm this boundary value is intended.
        self.sigmas = torch.cat([
            torch.ones(1),
            ((1 - cum_alphas) / cum_alphas) ** 0.5,
        ])

        # Populated by set_timesteps().
        self.timesteps = None

    def set_timesteps(self, num_inference_steps: int) -> None:
        """Pick evenly spaced sigma indices, highest index first."""
        stride = len(self.sigmas) // num_inference_steps
        indices = torch.arange(0, num_inference_steps) * stride
        self.timesteps = indices.flip(0)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
    ) -> torch.Tensor:
        """
        Advance one Euler step from sigmas[timestep] toward sigmas[timestep + 1].

        NOTE(review): model_output is currently unused — the derivative is
        computed from the sample alone; confirm this is intended.
        """
        sigma = self.sigmas[timestep]
        next_sigma = self.sigmas[timestep + 1] if timestep + 1 < len(self.sigmas) else torch.tensor(0.0)

        # Derivative of the sample with respect to sigma.
        denoised = sample / ((sigma ** 2 + 1) ** 0.5)
        slope = (sample - denoised) / sigma

        return sample + slope * (next_sigma - sigma)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def create_scheduler(config):
    """
    Factory function to create scheduler from config

    Args:
        config: Configuration dictionary (reads config['model']['scheduler'])

    Returns:
        Scheduler instance

    Raises:
        ValueError: If the configured scheduler name is not supported.
    """
    sched_config = config['model']['scheduler']
    name = sched_config['name']

    # Keyword arguments every supported scheduler accepts.
    common_kwargs = dict(
        num_train_timesteps=sched_config['num_train_timesteps'],
        beta_start=sched_config['beta_start'],
        beta_end=sched_config['beta_end'],
        beta_schedule=sched_config['beta_schedule'],
    )

    if name == 'DDIM':
        return DDIMScheduler(
            clip_sample=sched_config['clip_sample'],
            set_alpha_to_one=sched_config['set_alpha_to_one'],
            **common_kwargs,
        )
    if name == 'EulerDiscrete':
        return EulerDiscreteScheduler(**common_kwargs)
    raise ValueError(f"Unknown scheduler: {name}")
|
bytedream/utils.py
ADDED
|
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Byte Dream Utilities
|
| 3 |
+
Helper functions for image processing, model management, and optimization
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import hashlib
|
| 11 |
+
import json
|
| 12 |
+
from typing import Optional, Tuple, List
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def load_image(image_path: str) -> Image.Image:
    """
    Load an image from disk and force it into RGB mode.

    Args:
        image_path: Path to image file

    Returns:
        PIL Image object in RGB mode

    Raises:
        FileNotFoundError: If the path does not exist.
        IOError: If the file exists but cannot be opened/decoded.
    """
    source = Path(image_path)

    if not source.exists():
        raise FileNotFoundError(f"Image not found: {image_path}")

    try:
        return Image.open(source).convert('RGB')
    except Exception as e:
        raise IOError(f"Error loading image: {e}")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def save_image(
    image: Image.Image,
    output_path: str,
    format: str = None,
    quality: int = 95,
):
    """
    Write *image* to *output_path*, creating parent directories as needed.

    Args:
        image: PIL Image to save
        output_path: Output file path
        format: Image format (PNG, JPEG, ...); inferred from the file
            extension when omitted
        quality: JPEG quality (1-100); ignored for non-JPEG formats
    """
    path = Path(output_path)
    path.parent.mkdir(parents=True, exist_ok=True)

    # Infer the format from the extension when the caller did not pass one.
    if format is None:
        format = path.suffix.upper().replace('.', '')
        if format == 'JPG':
            format = 'JPEG'

    # Only JPEG takes a quality setting; everything else just gets optimize.
    extra = {'quality': quality} if format == 'JPEG' else {}
    image.save(path, format=format, optimize=True, **extra)

    print(f"Image saved to: {path}")
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def resize_image(
    image: Image.Image,
    width: Optional[int] = None,
    height: Optional[int] = None,
    maintain_aspect: bool = True,
) -> Image.Image:
    """
    Resize an image, optionally preserving its aspect ratio.

    Args:
        image: Input image
        width: Target width (None to derive it / keep original)
        height: Target height (None to derive it / keep original)
        maintain_aspect: Fit within (width, height) keeping aspect ratio

    Returns:
        Resized PIL Image (the original object if no target is given)
    """
    src_w, src_h = image.size

    if width is None and height is None:
        return image

    if not maintain_aspect:
        # Use targets as-is, falling back to the original dimension.
        target = (width if width else src_w, height if height else src_h)
        return image.resize(target, Image.Resampling.LANCZOS)

    if width and height:
        # Largest scale that fits inside the requested bounding box.
        scale = min(width / src_w, height / src_h)
        target = (int(src_w * scale), int(src_h * scale))
    elif width:
        target = (width, int(src_h * (width / src_w)))
    else:
        target = (int(src_w * (height / src_h)), height)

    return image.resize(target, Image.Resampling.LANCZOS)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def center_crop(image: Image.Image, width: int, height: int) -> Image.Image:
    """
    Crop a centered (width x height) region out of *image*.

    Args:
        image: Input image
        width: Crop width
        height: Crop height

    Returns:
        Cropped PIL Image

    Note:
        NOTE(review): if the image is smaller than the crop box, the box
        coordinates go negative and PIL pads the out-of-range area --
        confirm that behavior is acceptable for callers.
    """
    src_w, src_h = image.size

    left = (src_w - width) // 2
    top = (src_h - height) // 2

    return image.crop((left, top, left + width, top + height))
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def image_to_tensor(image: Image.Image) -> torch.Tensor:
    """
    Convert a PIL Image to a normalized PyTorch tensor.

    Args:
        image: PIL Image (H, W, C)

    Returns:
        Float tensor of shape (C, H, W) scaled to the range [-1, 1]
    """
    # uint8 [0, 255] -> float32 [0, 1]
    pixels = np.array(image).astype(np.float32) / 255.0

    # [0, 1] -> [-1, 1], the range the diffusion model works in.
    pixels = pixels * 2.0 - 1.0

    # HWC -> CHW, the layout PyTorch models expect.
    return torch.from_numpy(pixels).permute(2, 0, 1)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def tensor_to_image(tensor: torch.Tensor) -> Image.Image:
    """
    Convert a PyTorch tensor back into a PIL Image.

    Args:
        tensor: Tensor in range [-1, 1], shape (B, C, H, W) or (C, H, W);
            only the first batch element is used

    Returns:
        PIL Image (grayscale inputs are replicated to RGB)
    """
    # Drop the batch dimension if present.
    if tensor.dim() == 4:
        tensor = tensor[0]

    # Move to host memory and go CHW -> HWC.
    pixels = tensor.cpu().numpy().transpose(1, 2, 0)

    # Clamp out-of-range values, then map [-1, 1] -> [0, 255].
    pixels = np.clip(pixels, -1, 1)
    pixels = ((pixels + 1.0) * 127.5).round().astype(np.uint8)

    # Single-channel -> RGB by channel replication.
    if pixels.shape[2] == 1:
        pixels = np.repeat(pixels, 3, axis=2)

    return Image.fromarray(pixels)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def generate_prompt_hash(prompt: str) -> str:
    """
    Derive a short, deterministic identifier for a prompt.

    Args:
        prompt: Text prompt

    Returns:
        First 8 hex characters of the prompt's MD5 digest (MD5 is used only
        as a cheap fingerprint here, not for anything security-sensitive)
    """
    digest = hashlib.md5(prompt.encode()).hexdigest()
    return digest[:8]
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def get_model_statistics(model: torch.nn.Module) -> dict:
    """
    Summarize a model's parameter counts and in-memory size.

    Args:
        model: PyTorch model

    Returns:
        Dict with total/trainable/non-trainable parameter counts and the
        combined parameter+buffer size in megabytes (rounded to 2 decimals)
    """
    total = 0
    trainable = 0
    byte_count = 0

    # Single pass over parameters collects counts and size together.
    for param in model.parameters():
        n = param.numel()
        total += n
        if param.requires_grad:
            trainable += n
        byte_count += n * param.element_size()

    # Buffers (e.g. running stats) count toward size but not parameters.
    for buf in model.buffers():
        byte_count += buf.numel() * buf.element_size()

    return {
        'total_parameters': total,
        'trainable_parameters': trainable,
        'non_trainable_parameters': total - trainable,
        'model_size_mb': round(byte_count / 1024 ** 2, 2),
    }
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def optimize_memory_usage(device: str = "cpu"):
    """
    Free cached memory and apply CPU-specific environment tweaks.

    Args:
        device: Target device; "cpu" enables the OpenMP workaround below
    """
    import gc
    import os

    # Release cached CUDA allocations back to the driver, if applicable.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Reclaim unreachable Python objects.
    gc.collect()

    if device == "cpu":
        # Works around "duplicate OpenMP runtime" crashes seen with some
        # MKL/PyTorch builds. FIX: the original wrapped this assignment in a
        # bare ``except: pass`` -- setting an environment variable cannot
        # raise, so that handler was dead code hiding nothing but intent.
        os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

    print("Memory optimization applied")
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def set_seed(seed: int):
    """
    Seed every RNG used for generation so runs are reproducible.

    Args:
        seed: Random seed value
    """
    # NumPy RNG first (independent of the torch seeds).
    np.random.seed(seed)
    # Torch CPU RNG, plus all CUDA devices when available.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def validate_prompt(prompt: str) -> Tuple[bool, str]:
    """
    Check a prompt against basic acceptance rules.

    Args:
        prompt: Input prompt

    Returns:
        Tuple of (is_valid, message)
    """
    if not prompt or not prompt.strip():
        return False, "Prompt cannot be empty"

    if len(prompt) > 1000:
        return False, "Prompt too long (max 1000 characters)"

    # Blocklist is currently empty; populate to reject unwanted content.
    forbidden_terms = []
    lowered = prompt.lower()
    for term in forbidden_terms:
        if term.lower() in lowered:
            return False, f"Prompt contains forbidden term: {term}"

    return True, "Valid prompt"
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def create_image_grid(
    images: List[Image.Image],
    rows: int = None,
    cols: int = None,
) -> Image.Image:
    """
    Tile a list of images into a single grid image.

    Args:
        images: List of PIL Images (cell size is taken from the first image;
            differing sizes would overlap or leave gaps -- TODO confirm all
            callers pass uniformly sized images)
        rows: Number of rows (derived from ``cols``/count when None)
        cols: Number of columns (derived from ``rows``/count when None)

    Returns:
        Grid image on a white background

    Raises:
        ValueError: If ``images`` is empty.
    """
    if not images:
        raise ValueError("No images provided")

    count = len(images)

    # Fill in missing grid dimensions; default to a near-square layout.
    if rows is None and cols is None:
        cols = int(np.ceil(np.sqrt(count)))
        rows = int(np.ceil(count / cols))
    elif rows is None:
        rows = int(np.ceil(count / cols))
    elif cols is None:
        cols = int(np.ceil(count / rows))

    # Cell size comes from the first image.
    cell_w, cell_h = images[0].size

    canvas = Image.new('RGB', (cols * cell_w, rows * cell_h), color='white')

    # Row-major placement of each tile.
    for index, tile in enumerate(images):
        canvas.paste(tile, ((index % cols) * cell_w, (index // cols) * cell_h))

    return canvas
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def get_device_info() -> dict:
    """
    Collect basic information about available compute devices.

    Returns:
        Dict with CUDA availability, device count, CPU core count, and --
        when CUDA is present -- current device index, its name, and the
        CUDA version string
    """
    import os  # local import keeps this file's top-level imports unchanged

    cuda_ok = torch.cuda.is_available()
    info = {
        'cuda_available': cuda_ok,
        'device_count': torch.cuda.device_count() if cuda_ok else 0,
        # IDIOM FIX: was an opaque __import__('os').cpu_count() call.
        'cpu_cores': os.cpu_count(),
    }

    if cuda_ok:
        info['current_device'] = torch.cuda.current_device()
        info['device_name'] = torch.cuda.get_device_name(0)
        info['cuda_version'] = torch.version.cuda

    return info
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
class ProgressTracker:
    """Track progress of long-running operations and render a text bar."""

    def __init__(self, total: int, description: str = ""):
        # total: expected number of work units (may legitimately be 0)
        self.total = total
        # current: units completed so far
        self.current = 0
        self.description = description

    def update(self, n: int = 1):
        """Advance progress by *n* units."""
        self.current += n

    def get_progress(self) -> float:
        """Return completion as a percentage (0 when total is 0)."""
        return (self.current / self.total) * 100 if self.total > 0 else 0

    def __str__(self):
        percent = self.get_progress()
        bar_length = 30
        # BUG FIX: the original divided by self.total unconditionally here,
        # raising ZeroDivisionError for total == 0 even though
        # get_progress() already guards that case.
        if self.total > 0:
            filled_length = int(bar_length * self.current // self.total)
        else:
            filled_length = 0
        bar = '█' * filled_length + '-' * (bar_length - filled_length)
        return f"{self.description}: [{bar}] {percent:.1f}% ({self.current}/{self.total})"
|
config.yaml
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte Dream Configuration
|
| 2 |
+
|
| 3 |
+
model:
|
| 4 |
+
name: "Byte Dream"
|
| 5 |
+
version: "1.0.0"
|
| 6 |
+
|
| 7 |
+
# Model architecture parameters
|
| 8 |
+
unet:
|
| 9 |
+
in_channels: 4
|
| 10 |
+
out_channels: 4
|
| 11 |
+
block_out_channels: [320, 640, 1280, 1280]
|
| 12 |
+
layers_per_block: 2
|
| 13 |
+
attention_head_dim: 8
|
| 14 |
+
cross_attention_dim: 768
|
| 15 |
+
use_linear_projection: true
|
| 16 |
+
|
| 17 |
+
scheduler:
|
| 18 |
+
name: "DDIM" # Options: DDIM, PNDM, LMSDiscrete, EulerDiscrete
|
| 19 |
+
num_train_timesteps: 1000
|
| 20 |
+
beta_start: 0.00085
|
| 21 |
+
beta_end: 0.012
|
| 22 |
+
beta_schedule: "scaled_linear"
|
| 23 |
+
clip_sample: false
|
| 24 |
+
set_alpha_to_one: false
|
| 25 |
+
|
| 26 |
+
vae:
|
| 27 |
+
in_channels: 3
|
| 28 |
+
out_channels: 3
|
| 29 |
+
down_block_types: ["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"]
|
| 30 |
+
up_block_types: ["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"]
|
| 31 |
+
latent_channels: 4
|
| 32 |
+
sample_size: 512
|
| 33 |
+
|
| 34 |
+
text_encoder:
|
| 35 |
+
model: "openai/clip-vit-large-patch14"
|
| 36 |
+
max_length: 77
|
| 37 |
+
|
| 38 |
+
# Generation parameters
|
| 39 |
+
generation:
|
| 40 |
+
width: 512
|
| 41 |
+
height: 512
|
| 42 |
+
num_inference_steps: 50
|
| 43 |
+
guidance_scale: 7.5
|
| 44 |
+
negative_prompt: "ugly, blurry, low quality, distorted, deformed"
|
| 45 |
+
seed: null # null for random, or set integer
|
| 46 |
+
|
| 47 |
+
# CPU Optimization
|
| 48 |
+
cpu_optimization:
|
| 49 |
+
use_openvino: false
|
| 50 |
+
use_onnx: false
|
| 51 |
+
precision: "fp32" # fp32 or fp16
|
| 52 |
+
threads: -1 # -1 for all available threads
|
| 53 |
+
memory_limit: null # null for auto, or MB value
|
| 54 |
+
|
| 55 |
+
# Training parameters
|
| 56 |
+
training:
|
| 57 |
+
dataset_path: "./dataset"
|
| 58 |
+
output_dir: "./models/bytedream"
|
| 59 |
+
epochs: 100
|
| 60 |
+
batch_size: 4
|
| 61 |
+
gradient_accumulation_steps: 1
|
| 62 |
+
learning_rate: 1e-5
|
| 63 |
+
lr_scheduler: "constant_with_warmup"
|
| 64 |
+
lr_warmup_steps: 500
|
| 65 |
+
max_grad_norm: 1.0
|
| 66 |
+
mixed_precision: "no" # no, fp16, bf16
|
| 67 |
+
|
| 68 |
+
# Data augmentation
|
| 69 |
+
random_flip: true
|
| 70 |
+
random_crop: false
|
| 71 |
+
center_crop: true
|
| 72 |
+
|
| 73 |
+
# Logging
|
| 74 |
+
logging_dir: "./logs"
|
| 75 |
+
log_every_n_steps: 10
|
| 76 |
+
|
| 77 |
+
# Hugging Face
|
| 78 |
+
huggingface:
|
| 79 |
+
organization: "" # Your HF username/organization
|
| 80 |
+
private: false
|
| 81 |
+
push_to_hub: true
|
environment.yml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: bytedream
|
| 2 |
+
channels:
|
| 3 |
+
- pytorch
|
| 4 |
+
- conda-forge
|
| 5 |
+
- defaults
|
| 6 |
+
dependencies:
|
| 7 |
+
- python=3.10
|
| 8 |
+
- pip
|
| 9 |
+
- pip:
|
| 10 |
+
- transformers>=4.35.0
|
| 11 |
+
- diffusers>=0.24.0
|
| 12 |
+
- torch>=2.1.0
|
| 13 |
+
- torchaudio>=2.1.0
|
| 14 |
+
- accelerate>=0.25.0
|
| 15 |
+
- numpy>=1.24.0
|
| 16 |
+
- pillow>=10.0.0
|
| 17 |
+
- opencv-python>=4.8.0
|
| 18 |
+
- safetensors>=0.4.0
|
| 19 |
+
- huggingface_hub>=0.19.0
|
| 20 |
+
- gradio>=4.0.0
|
| 21 |
+
- tqdm>=4.66.0
|
| 22 |
+
- pyyaml>=6.0
|
| 23 |
+
- matplotlib>=3.8.0
|
| 24 |
+
- scipy>=1.11.0
|
| 25 |
+
- einops>=0.7.0
|
examples.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Byte Dream - Example Usage Scripts
|
| 3 |
+
Practical examples for different use cases
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from bytedream import ByteDreamGenerator
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def example_basic_generation():
    """Minimal end-to-end generation from a single prompt."""
    banner = "=" * 60
    print("\n" + banner)
    print("Example 1: Basic Generation")
    print(banner)

    generator = ByteDreamGenerator()

    # Defaults everywhere; only the prompt is supplied.
    image = generator.generate(
        prompt="A beautiful sunset over mountains, digital art",
    )

    image.save("example_basic.png")
    print("✓ Saved to: example_basic.png")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def example_advanced_parameters():
    """Demonstrate tuning resolution, steps, guidance and seeding."""
    banner = "=" * 60
    print("\n" + banner)
    print("Example 2: Advanced Parameters")
    print(banner)

    generator = ByteDreamGenerator()

    # Everything dialed up from the defaults; seed pinned for repeatability.
    settings = dict(
        prompt="Cyberpunk city at night, neon lights, futuristic architecture",
        negative_prompt="ugly, blurry, low quality, distorted, dark",
        width=768,
        height=768,
        num_inference_steps=75,
        guidance_scale=9.0,
        seed=42,
    )
    image = generator.generate(**settings)

    image.save("example_advanced.png")
    print("✓ Saved to: example_advanced.png")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def example_batch_generation():
    """Generate several prompts in one batch and save a contact-sheet grid."""
    print("\n" + "="*60)
    print("Example 3: Batch Generation")
    print("="*60)

    generator = ByteDreamGenerator()

    prompts = [
        "Fantasy landscape with castle and waterfall, epic scenery",
        "Underwater coral reef, tropical fish, sunlight through water",
        "Space nebula, colorful clouds, stars, cosmic scene",
        "Medieval knight in armor, dramatic lighting, portrait",
        "Japanese garden, cherry blossoms, peaceful atmosphere",
    ]

    images = generator.generate_batch(
        prompts=prompts,
        width=512,
        height=512,
        num_inference_steps=50,
        guidance_scale=7.5,
    )

    # Save each image individually.
    for i, (prompt, image) in enumerate(zip(prompts, images)):
        filename = f"batch_{i+1}.png"
        image.save(filename)
        # FIX: the saved-file message previously did not interpolate the
        # filename, so every line printed the same placeholder text.
        print(f"✓ Saved: {filename}")

    # Assemble all results into one grid image.
    from bytedream.utils import create_image_grid
    grid = create_image_grid(images)
    grid.save("batch_grid.png")
    print("✓ Grid saved to: batch_grid.png")
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def example_artistic_styles():
    """Render the same pipeline across several named art styles."""
    print("\n" + "="*60)
    print("Example 4: Artistic Styles")
    print("="*60)

    generator = ByteDreamGenerator()

    style_prompts = [
        ("Oil Painting", "Portrait of a woman, oil painting style, brush strokes, classical art"),
        ("Watercolor", "Forest landscape, watercolor painting, soft colors, artistic"),
        ("Digital Art", "Sci-fi spaceship, digital art, concept art, highly detailed"),
        ("Sketch", "City skyline, pencil sketch, black and white, drawing"),
        ("Abstract", "Emotions and dreams, abstract art, colorful shapes, surreal"),
    ]

    for style_name, prompt in style_prompts:
        print(f"\nGenerating {style_name}...")

        image = generator.generate(
            prompt=prompt,
            num_inference_steps=50,
            guidance_scale=7.5,
        )

        filename = f"style_{style_name.lower().replace(' ', '_')}.png"
        image.save(filename)
        # FIX: the saved-file message previously did not interpolate the
        # filename, printing the same placeholder for every style.
        print(f"✓ Saved: {filename}")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def example_resolutions():
    """Render one prompt at several square and non-square resolutions."""
    print("\n" + "="*60)
    print("Example 5: Different Resolutions")
    print("="*60)

    generator = ByteDreamGenerator()

    base_prompt = "Majestic mountain range, snow peaks, blue sky"

    resolutions = [
        (256, 256),
        (512, 512),
        (768, 768),
        (512, 768),  # Portrait
        (768, 512),  # Landscape
    ]

    for width, height in resolutions:
        print(f"\nGenerating {width}x{height}...")

        image = generator.generate(
            prompt=base_prompt,
            width=width,
            height=height,
            num_inference_steps=40,
        )

        filename = f"res_{width}x{height}.png"
        image.save(filename)
        # FIX: the saved-file message previously did not interpolate the
        # filename, printing the same placeholder for every resolution.
        print(f"✓ Saved: {filename}")
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def example_reproducibility():
    """Show that a fixed seed reproduces the exact same image."""
    banner = "=" * 60
    print("\n" + banner)
    print("Example 6: Reproducibility with Seeds")
    print(banner)

    generator = ByteDreamGenerator()

    prompt = "A mystical forest with glowing mushrooms, fantasy art"

    # Same seed twice -> identical outputs.
    print("\nGenerating with seed=123...")
    generator.generate(prompt=prompt, seed=123).save("repro_1.png")

    print("Generating again with seed=123...")
    generator.generate(prompt=prompt, seed=123).save("repro_2.png")

    print("\nBoth images should be identical!")
    print("✓ Check repro_1.png and repro_2.png")

    # A different seed -> a different image.
    print("\nGenerating with seed=456...")
    generator.generate(prompt=prompt, seed=456).save("repro_3.png")
    print("This one will be different!")
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def example_negative_prompts():
    """Contrast output with and without a negative prompt (same seed)."""
    banner = "=" * 60
    print("\n" + banner)
    print("Example 7: Negative Prompts")
    print(banner)

    generator = ByteDreamGenerator()

    base_prompt = "Beautiful princess, elegant dress, castle background"

    # Baseline: no negative prompt, fixed seed so the runs are comparable.
    print("\nWithout negative prompt...")
    generator.generate(prompt=base_prompt, seed=789).save("no_negative.png")

    # Same seed, now steering away from common failure modes.
    print("With negative prompt...")
    generator.generate(
        prompt=base_prompt,
        negative_prompt="ugly, deformed, noisy, blurry, bad anatomy, poorly drawn",
        seed=789,
    ).save("with_negative.png")

    print("\nCompare no_negative.png and with_negative.png")
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def example_quick_preview():
    """Fast low-res draft first, then a slow high-quality render."""
    banner = "=" * 60
    print("\n" + banner)
    print("Example 8: Quick Preview Mode")
    print(banner)

    generator = ByteDreamGenerator()

    prompt = "Dragon breathing fire, epic fantasy battle scene"

    # Cheap draft to check composition before committing to a long render.
    print("Generating quick preview (256x256, 20 steps)...")
    generator.generate(
        prompt=prompt,
        width=256,
        height=256,
        num_inference_steps=20,
    ).save("preview.png")
    print("✓ Preview saved")

    # Full-size, high-step render of the same prompt.
    print("\nGenerating full quality (768x768, 75 steps)...")
    generator.generate(
        prompt=prompt,
        width=768,
        height=768,
        num_inference_steps=75,
    ).save("full_quality.png")
    print("✓ Full quality saved")
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def run_all_examples():
    """Run every example in order, pausing for Enter between them."""
    sep = "=" * 60
    print("\n" + sep)
    print("Byte Dream - Complete Examples Suite")
    print(sep)

    all_examples = (
        example_basic_generation,
        example_advanced_parameters,
        example_batch_generation,
        example_artistic_styles,
        example_resolutions,
        example_reproducibility,
        example_negative_prompts,
        example_quick_preview,
    )

    for fn in all_examples:
        # Keep going if one example fails; report its traceback and continue.
        try:
            fn()
        except Exception as e:
            print(f"\n✗ Error in {fn.__name__}: {e}")
            import traceback
            traceback.print_exc()

        print("\n" + "-" * 60)
        input("Press Enter to continue to next example...")

    print("\n" + sep)
    print("All examples completed!")
    print(sep)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Byte Dream Examples")
    # Each flag is a store-true switch selecting one example (or all).
    for flag, text in (
        ("--all", "Run all examples"),
        ("--basic", "Run basic example"),
        ("--advanced", "Run advanced example"),
        ("--batch", "Run batch generation"),
        ("--styles", "Run artistic styles"),
    ):
        parser.add_argument(flag, action="store_true", help=text)

    args = parser.parse_args()

    # First matching flag wins; with no flags, print the menu.
    if args.all:
        run_all_examples()
    elif args.basic:
        example_basic_generation()
    elif args.advanced:
        example_advanced_parameters()
    elif args.batch:
        example_batch_generation()
    elif args.styles:
        example_artistic_styles()
    else:
        print("\nByte Dream Examples")
        print("="*60)
        print("Choose an example:")
        print("  --basic    : Basic generation")
        print("  --advanced : Advanced parameters")
        print("  --batch    : Batch generation")
        print("  --styles   : Artistic styles")
        print("  --all      : Run all examples")
        print("\nOr just run without arguments to see all examples")
|
infer.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Byte Dream - Command Line Inference Tool
|
| 3 |
+
Generate images from text prompts using the command line
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def main():
    """Command-line entry point for single-image text-to-image generation.

    Parses CLI options, constructs a ByteDreamGenerator, prints its
    metadata, generates one image from the prompt, and saves it to disk.
    """
    parser = argparse.ArgumentParser(
        description="Byte Dream - AI Image Generation",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Basic usage
  python infer.py --prompt "A beautiful sunset over mountains"

  # With custom parameters
  python infer.py --prompt "Cyberpunk city" --negative "blurry" --steps 75 --guidance 8.0

  # Specify output and size
  python infer.py --prompt "Fantasy landscape" --output fantasy.png --width 768 --height 768

  # With seed for reproducibility
  python infer.py --prompt "Dragon" --seed 42 --output dragon.png
"""
    )

    # The prompt is the only mandatory argument.
    parser.add_argument(
        "--prompt", "-p",
        type=str,
        required=True,
        help="Text prompt describing the desired image"
    )

    parser.add_argument(
        "--negative", "-n",
        type=str,
        default="",
        help="Negative prompt - things to avoid in the image"
    )

    parser.add_argument(
        "--output", "-o",
        type=str,
        default="output.png",
        help="Output image filename (default: output.png)"
    )

    parser.add_argument(
        "--width", "-W",
        type=int,
        default=512,
        help="Image width in pixels (default: 512)"
    )

    parser.add_argument(
        "--height", "-H",
        type=int,
        default=512,
        help="Image height in pixels (default: 512)"
    )

    parser.add_argument(
        "--steps", "-s",
        type=int,
        default=50,
        help="Number of inference steps (default: 50)"
    )

    parser.add_argument(
        "--guidance", "-g",
        type=float,
        default=7.5,
        help="Guidance scale - how closely to follow prompt (default: 7.5)"
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducibility (default: random)"
    )

    parser.add_argument(
        "--model", "-m",
        type=str,
        default=None,
        help="Path to model directory (default: uses config)"
    )

    parser.add_argument(
        "--config", "-c",
        type=str,
        default="config.yaml",
        help="Path to config file (default: config.yaml)"
    )

    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to run on: cpu or cuda (default: cpu)"
    )

    args = parser.parse_args()

    # Import generator (deferred until after argument parsing)
    from bytedream.generator import ByteDreamGenerator

    # Initialize generator
    print("="*60)
    print("Byte Dream - AI Image Generator")
    print("="*60)

    generator = ByteDreamGenerator(
        model_path=args.model,
        config_path=args.config,
        device=args.device,
    )

    # Print model info.  NOTE(review): assumes get_model_info() returns a
    # dict with 'name', 'version', 'device', 'unet_parameters' keys —
    # confirm against bytedream/generator.py.
    info = generator.get_model_info()
    print(f"\nModel: {info['name']} v{info['version']}")
    print(f"Device: {info['device']}")
    print(f"Parameters: {info['unet_parameters']}")
    print("="*60)

    # Generate image; an empty --negative string is normalized to None.
    image = generator.generate(
        prompt=args.prompt,
        negative_prompt=args.negative if args.negative else None,
        width=args.width,
        height=args.height,
        num_inference_steps=args.steps,
        guidance_scale=args.guidance,
        seed=args.seed,
    )

    # Save image
    output_path = Path(args.output)
    image.save(output_path)
    print(f"\n✓ Image saved to: {output_path.absolute()}")
    print("="*60)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
if __name__ == "__main__":
|
| 150 |
+
main()
|
main.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Byte Dream - Main Application Interface
|
| 3 |
+
Simple Python API for image generation
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from bytedream.generator import ByteDreamGenerator
|
| 7 |
+
from bytedream.utils import (
|
| 8 |
+
load_image,
|
| 9 |
+
save_image,
|
| 10 |
+
resize_image,
|
| 11 |
+
create_image_grid,
|
| 12 |
+
)
|
| 13 |
+
from typing import Optional, List
|
| 14 |
+
from PIL import Image
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ByteDreamApp:
    """
    High-level application interface for Byte Dream.
    Simplifies common tasks like image generation and batch processing.
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        device: str = "cpu",
        verbose: bool = True,
    ):
        """
        Initialize Byte Dream application

        Args:
            model_path: Path to model weights
            device: Device to run on
            verbose: Enable verbose output
        """
        self.verbose = verbose

        if self.verbose:
            print("Initializing Byte Dream Application...")

        self.generator = ByteDreamGenerator(
            model_path=model_path,
            config_path="config.yaml",
            device=device,
        )

        if self.verbose:
            print("✓ Application ready!")

    def generate(
        self,
        prompt: str,
        output_path: str = "output.png",
        negative_prompt: Optional[str] = None,
        width: int = 512,
        height: int = 512,
        steps: int = 50,
        guidance: float = 7.5,
        seed: Optional[int] = None,
        save: bool = True,
    ) -> Image.Image:
        """
        Generate image from prompt and optionally save to file

        Args:
            prompt: Text description
            output_path: Where to save the image
            negative_prompt: What to avoid
            width: Image width
            height: Image height
            steps: Inference steps
            guidance: Guidance scale
            seed: Random seed
            save: Whether to save to file

        Returns:
            Generated PIL Image
        """
        image = self.generator.generate(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            num_inference_steps=steps,
            guidance_scale=guidance,
            seed=seed,
        )

        if save:
            save_image(image, output_path)
            if self.verbose:
                print(f"✓ Image saved to: {output_path}")

        return image

    def generate_multiple(
        self,
        prompts: List[str],
        output_dir: str = "./outputs",
        negative_prompt: Optional[str] = None,
        width: int = 512,
        height: int = 512,
        steps: int = 50,
        guidance: float = 7.5,
        seeds: Optional[List[int]] = None,
        create_grid: bool = True,
    ) -> List[Image.Image]:
        """
        Generate multiple images from prompts

        Args:
            prompts: List of prompts
            output_dir: Directory to save images
            negative_prompt: Negative prompt for all
            width: Image width
            height: Image height
            steps: Inference steps
            guidance: Guidance scale
            seeds: Seeds for each image; may be shorter than `prompts`
                (remaining images use a random seed)
            create_grid: Create grid of all images

        Returns:
            List of generated images
        """
        from pathlib import Path

        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        images = []

        for i, prompt in enumerate(prompts):
            print(f"\n{'='*60}")
            print(f"Generating image {i+1}/{len(prompts)}")
            print(f"{'='*60}")

            # Guard the index: a seeds list shorter than prompts previously
            # raised IndexError partway through the batch.
            seed = seeds[i] if seeds and i < len(seeds) else None

            image = self.generate(
                prompt=prompt,
                output_path=str(output_path / f"image_{i+1:03d}.png"),
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                steps=steps,
                guidance=guidance,
                seed=seed,
                save=True,
            )

            images.append(image)

        # A grid is only meaningful for two or more images.
        if create_grid and len(images) > 1:
            grid = create_image_grid(images)
            grid_path = output_path / "grid.png"
            grid.save(grid_path)
            print(f"\n✓ Grid saved to: {grid_path}")

        return images

    def img2img(
        self,
        input_image_path: str,
        prompt: str,
        output_path: str = "output_img2img.png",
        strength: float = 0.75,
        negative_prompt: Optional[str] = None,
        steps: int = 50,
        guidance: float = 7.5,
        seed: Optional[int] = None,
    ) -> Image.Image:
        """
        Image-to-image transformation (placeholder for future implementation)

        NOTE: `input_image_path` and `strength` are currently ignored — the
        method falls back to plain text-to-image generation.

        Args:
            input_image_path: Input image path (unused for now)
            prompt: Transformation prompt
            output_path: Output path
            strength: How much to transform (0-1, unused for now)
            negative_prompt: Negative prompt
            steps: Inference steps
            guidance: Guidance scale
            seed: Random seed

        Returns:
            Transformed image
        """
        print("⚠ img2img functionality will be available in a future update")
        print(" For now, using text-to-image generation only")

        return self.generate(
            prompt=prompt,
            output_path=output_path,
            negative_prompt=negative_prompt,
            steps=steps,
            guidance=guidance,
            seed=seed,
        )

    def info(self):
        """Print model information reported by the underlying generator."""
        info = self.generator.get_model_info()

        print("\n" + "="*60)
        print("Byte Dream Model Information")
        print("="*60)
        for key, value in info.items():
            print(f"{key.replace('_', ' ').title()}: {value}")
        print("="*60)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def demo():
    """Generate a small batch of sample images to showcase the pipeline."""
    banner = "=" * 60
    print("\n" + banner)
    print("Byte Dream - Quick Demo")
    print(banner)

    app = ByteDreamApp(device="cpu", verbose=True)

    sample_prompts = [
        "A beautiful sunset over mountains, digital art, vibrant colors",
        "Cyberpunk city at night with neon lights, futuristic",
        "Fantasy landscape with castle and waterfall, epic",
    ]

    print("\nGenerating sample images...")

    images = app.generate_multiple(
        prompts=sample_prompts,
        output_dir="./demo_outputs",
        steps=30,  # fewer steps keeps the demo quick
        guidance=7.5,
        create_grid=True,
    )

    print(f"\n✓ Demo complete! Generated {len(images)} images")
    print(" Check ./demo_outputs/ for results")
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Byte Dream Application")
    parser.add_argument("--demo", action="store_true", help="Run demo")
    args = parser.parse_args()

    if args.demo:
        demo()
    else:
        # Interactive mode
        app = ByteDreamApp()

        print("\nByte Dream Interactive Mode")
        print("Type 'quit' to exit\n")

        # Name outputs with a sequential counter.  The previous scheme used
        # len(prompt) in the filename, which silently overwrote earlier
        # images whenever two prompts had the same length.
        image_count = 0
        while True:
            prompt = input("Prompt: ").strip()

            if prompt.lower() in ['quit', 'exit', 'q']:
                break

            if not prompt:
                continue

            try:
                image_count += 1
                image = app.generate(
                    prompt=prompt,
                    output_path=f"output_{image_count:03d}.png",
                )
                print("✓ Image generated!\n")
            except Exception as e:
                print(f"Error: {e}\n")
|
prepare_dataset.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dataset Preparation Tool
|
| 3 |
+
Prepare and preprocess image-text datasets for training
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import json
|
| 10 |
+
import shutil
|
| 11 |
+
from typing import List, Tuple
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def prepare_dataset(
    input_dir: str,
    output_dir: str,
    image_size: int = 512,
    min_resolution: int = 256,
    filter_low_quality: bool = True,
):
    """
    Prepare dataset for training: resize/crop every image under
    `input_dir`, write JPEG + caption pairs under `output_dir`, and save a
    metadata.json summary.

    Args:
        input_dir: Directory with raw images (searched recursively)
        output_dir: Output directory for processed data
        image_size: Target image size (square, pixels)
        min_resolution: Minimum acceptable resolution
        filter_low_quality: Filter out low quality images (passed through
            to process_image)
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)

    # Create output directories
    output_path.mkdir(parents=True, exist_ok=True)
    (output_path / "images").mkdir(exist_ok=True)
    (output_path / "captions").mkdir(exist_ok=True)

    # Find all images.  pathlib's "**/*ext" already matches files in the
    # top-level directory (** matches zero directories), so the previous
    # additional "*ext" glob produced duplicates that were processed twice.
    # Deduplicate via a set and sort for a deterministic order.
    image_extensions = ['.jpg', '.jpeg', '.png', '.webp']
    image_files = sorted(
        {f for ext in image_extensions for f in input_path.glob(f"**/*{ext}")}
    )

    print(f"Found {len(image_files)} images")

    # Process each image, counting successes and failures
    processed_count = 0
    skipped_count = 0

    for img_file in image_files:
        try:
            process_image(
                img_path=img_file,
                output_img_path=output_path / "images" / f"{img_file.stem}.jpg",
                caption_path=output_path / "captions" / f"{img_file.stem}.txt",
                image_size=image_size,
                min_resolution=min_resolution,
                filter_low_quality=filter_low_quality,
            )
            processed_count += 1

            if processed_count % 10 == 0:
                print(f"Processed: {processed_count}/{len(image_files)}")

        except Exception as e:
            # Best-effort: a bad image is reported and skipped, not fatal
            print(f"Error processing {img_file}: {e}")
            skipped_count += 1

    # Save metadata summary alongside the processed data
    metadata = {
        'total_images': processed_count,
        'skipped_images': skipped_count,
        'image_size': image_size,
        'min_resolution': min_resolution,
    }

    with open(output_path / "metadata.json", 'w') as f:
        json.dump(metadata, f, indent=2)

    print(f"\n✓ Dataset preparation complete!")
    print(f" Processed: {processed_count} images")
    print(f" Skipped: {skipped_count} images")
    print(f" Output: {output_path}")
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def process_image(
    img_path: Path,
    output_img_path: Path,
    caption_path: Path,
    image_size: int = 512,
    min_resolution: int = 256,
    filter_low_quality: bool = True,
):
    """
    Process a single image: validate resolution, pre-shrink oversized
    images, center-crop to a square, resize to `image_size`, save as JPEG,
    and write a caption text file.

    Args:
        img_path: Input image path
        output_img_path: Output image path
        caption_path: Output caption path
        image_size: Target size (square, pixels)
        min_resolution: Minimum acceptable width/height
        filter_low_quality: Reserved flag — no quality filter is currently
            implemented (TODO)

    Raises:
        ValueError: If the image is smaller than `min_resolution` on
            either side.
    """
    # Load image and normalize to RGB (drops alpha, handles palette images)
    image = Image.open(img_path).convert('RGB')

    width, height = image.size

    if width < min_resolution or height < min_resolution:
        raise ValueError(f"Image too small: {width}x{height}")

    # Pre-shrink very large images before cropping.  Scale by the SHORTER
    # side so it lands exactly at image_size: the previous version scaled
    # by the longer side, which pushed the short side below image_size and
    # forced a quality-losing upscale in the final resize.
    if min(width, height) > image_size * 1.5:
        scale = image_size / min(width, height)
        new_width = int(width * scale)
        new_height = int(height * scale)
        image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

    # Center crop to square
    size = min(image.size)
    left = (image.size[0] - size) // 2
    top = (image.size[1] - size) // 2
    image = image.crop((left, top, left + size, top + size))

    # Resize to target size (a no-op when the pre-shrink already hit it)
    image = image.resize((image_size, image_size), Image.Resampling.LANCZOS)

    # Save processed image
    image.save(output_img_path, quality=95, optimize=True)

    # Generate or load caption (adjacent .txt file, else the filename)
    caption = generate_caption(img_path)

    with open(caption_path, 'w', encoding='utf-8') as f:
        f.write(caption)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def generate_caption(img_path: Path) -> str:
    """
    Produce a caption for an image: prefer a non-empty sibling `.txt`
    file; otherwise derive one from the filename.

    Args:
        img_path: Path to image

    Returns:
        Caption text
    """
    # A sibling text file with the same stem wins, if it has content.
    sidecar = img_path.with_suffix('.txt')
    if sidecar.exists():
        with open(sidecar, 'r', encoding='utf-8') as fh:
            text = fh.read().strip()
        if text:
            return text

    # Fallback: turn the filename into words and capitalize the result.
    words = img_path.stem.replace('_', ' ').replace('-', ' ')
    return words.capitalize()
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def create_training_splits(
    data_dir: str,
    train_ratio: float = 0.9,
    val_ratio: float = 0.05,
    test_ratio: float = 0.05,
):
    """
    Create train/val/test split files (train.json / validation.json /
    test.json) over the processed images in `data_dir`.

    The test set receives whatever remains after the train and validation
    slices, so `test_ratio` is only used to validate that the three ratios
    sum to 1.

    Args:
        data_dir: Directory with processed data (expects an images/ subdir)
        train_ratio: Training set ratio
        val_ratio: Validation set ratio
        test_ratio: Test set ratio

    Raises:
        ValueError: If the three ratios do not sum to 1.
    """
    # Previously test_ratio was silently ignored; at least catch
    # inconsistent ratio sets instead of producing surprising splits.
    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
        raise ValueError(
            f"Split ratios must sum to 1, got "
            f"{train_ratio + val_ratio + test_ratio}"
        )

    data_path = Path(data_dir)

    # Get all processed images
    images = list((data_path / "images").glob("*.jpg"))

    # Shuffle deterministically so splits are reproducible
    import random
    random.seed(42)
    random.shuffle(images)

    # Calculate split sizes; test takes the remainder
    total = len(images)
    train_size = int(total * train_ratio)
    val_size = int(total * val_ratio)

    train_images = images[:train_size]
    val_images = images[train_size:train_size + val_size]
    test_images = images[train_size + val_size:]

    def save_split(image_list, split_name):
        # One JSON file per split: just the filenames plus a count.
        split_data = {
            'images': [str(img.name) for img in image_list],
            'count': len(image_list),
        }

        with open(data_path / f"{split_name}.json", 'w') as f:
            json.dump(split_data, f, indent=2)

        print(f"{split_name}: {len(image_list)} images")

    save_split(train_images, "train")
    save_split(val_images, "validation")
    save_split(test_images, "test")

    print(f"\n✓ Created training splits")
    print(f" Total: {total} images")
| 226 |
+
|
| 227 |
+
def main():
    """Command-line entry point: preprocess a raw image folder and
    optionally create train/val/test split files."""
    parser = argparse.ArgumentParser(description="Prepare dataset for Byte Dream training")

    # (flags, add_argument kwargs) — kept in a table so the option set is
    # easy to scan and extend.
    arg_specs = [
        (("--input", "-i"),
         dict(type=str, required=True, help="Input directory with raw images")),
        (("--output", "-o"),
         dict(type=str, default="./processed_data", help="Output directory for processed data")),
        (("--size", "-s"),
         dict(type=int, default=512, help="Target image size (default: 512)")),
        (("--min_res",),
         dict(type=int, default=256, help="Minimum image resolution (default: 256)")),
        (("--no_filter",),
         dict(action="store_true", help="Disable low quality filtering")),
        (("--create_splits",),
         dict(action="store_true", help="Create train/val/test splits")),
    ]
    for flags, kwargs in arg_specs:
        parser.add_argument(*flags, **kwargs)

    args = parser.parse_args()

    # Run the preprocessing pass.
    prepare_dataset(
        input_dir=args.input,
        output_dir=args.output,
        image_size=args.size,
        min_resolution=args.min_res,
        filter_low_quality=not args.no_filter,
    )

    # Optionally derive split files from the processed output.
    if args.create_splits:
        create_training_splits(args.output)
| 284 |
+
|
| 285 |
+
|
| 286 |
+
if __name__ == "__main__":
|
| 287 |
+
main()
|
publish_to_hf.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Upload Byte Dream to Hugging Face Hub
|
| 3 |
+
Using login() and upload_folder() methods
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from huggingface_hub import login, upload_folder
|
| 8 |
+
|
| 9 |
+
# Get token from command line argument or prompt
|
| 10 |
+
import sys
|
| 11 |
+
|
| 12 |
+
if len(sys.argv) > 1:
|
| 13 |
+
token = sys.argv[1]
|
| 14 |
+
else:
|
| 15 |
+
token = input("Enter your Hugging Face token: ")
|
| 16 |
+
|
| 17 |
+
# Login with your Hugging Face token
|
| 18 |
+
print("Logging in to Hugging Face...")
|
| 19 |
+
login(token=token)
|
| 20 |
+
|
| 21 |
+
# Push your model files
|
| 22 |
+
print("\nUploading model to Hugging Face Hub...")
|
| 23 |
+
upload_folder(
|
| 24 |
+
folder_path=".",
|
| 25 |
+
repo_id="Enzo8930302/ByteDream",
|
| 26 |
+
repo_type="model"
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
print("\n✓ Model uploaded successfully!")
|
| 30 |
+
print("📦 View your model at: https://huggingface.co/Enzo8930302/ByteDream")
|
quick_start.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Quick Start Script
|
| 3 |
+
Setup and test Byte Dream in one command
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import subprocess
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def check_requirements():
    """Verify that every required third-party package can be imported.

    Returns:
        True when all packages import cleanly; False otherwise, after
        printing which packages are missing and how to install them.
    """
    print("Checking requirements...")

    # Map pip distribution name -> importable module name.  Some packages
    # (notably Pillow) install under a different name than they are
    # imported by; the previous __import__('pillow') always failed even
    # when Pillow was installed, because its module is 'PIL'.
    required = {
        'torch': 'torch',
        'transformers': 'transformers',
        'diffusers': 'diffusers',
        'pillow': 'PIL',
        'numpy': 'numpy',
        'gradio': 'gradio',
    }

    missing = []

    for package, module in required.items():
        try:
            __import__(module)
            print(f" ✓ {package}")
        except ImportError:
            print(f" ✗ {package} - MISSING")
            missing.append(package)

    if missing:
        print(f"\nMissing packages: {', '.join(missing)}")
        print("\nInstall with:")
        print(" pip install -r requirements.txt")
        return False

    print("\n✓ All requirements satisfied!")
    return True
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def test_model():
    """Smoke-test the Byte Dream model by generating a single image.

    Returns True on success; on any failure prints the error and returns
    False so the caller can fall back to download instructions.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("Testing Byte Dream Model")
    print(banner)

    try:
        from bytedream.generator import ByteDreamGenerator

        print("\nInitializing generator...")
        gen = ByteDreamGenerator(device="cpu")

        # Show what was loaded before committing to a generation run.
        print("\nModel info:")
        for key, value in gen.get_model_info().items():
            print(f"  {key}: {value}")

        print("\nGenerating test image...")
        print("Prompt: A simple test pattern, geometric shapes")

        # Small resolution + fixed seed keeps this quick and reproducible.
        result = gen.generate(
            prompt="A simple test pattern, geometric shapes, abstract art",
            width=256,
            height=256,
            num_inference_steps=20,
            seed=42,
        )

        destination = Path("test_output.png")
        result.save(destination)

        print(f"\n✓ Test successful!")
        print(f"  Image saved to: {destination.absolute()}")
        return True

    except Exception as e:
        print(f"\n✗ Error: {e}")
        print("\nThe model needs to be trained or pretrained weights downloaded.")
        return False
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def download_pretrained():
    """Print instructions for obtaining pretrained Byte Dream weights.

    Purely informational — performs no download itself.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("Downloading Pretrained Model")
    print(banner)

    for line in (
        "\nTo download a pretrained model:",
        "1. Visit https://huggingface.co/models",
        "2. Search for 'stable-diffusion' or similar",
        "3. Download using:",
        "\n  from huggingface_hub import snapshot_download",
        "  snapshot_download(repo_id='username/model', local_dir='./models/bytedream')",
        "\nOr train your own model with:",
        "  python train.py --train_data ./dataset --output_dir ./models/bytedream",
    ):
        print(line)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def main():
    """Quick-start entry point: verify dependencies, then smoke-test the model."""
    banner = "=" * 60
    print(banner)
    print("Byte Dream - Quick Start")
    print(banner)

    # Bail out early when dependencies are missing — nothing else can work.
    if not check_requirements():
        print("\n⚠ Please install requirements first")
        sys.exit(1)

    if not test_model():
        # Generation failed: point the user at pretrained weights / training.
        download_pretrained()
        return

    print("\n✓ Byte Dream is ready to use!")
    print("\nNext steps:")
    print("  - Run: python infer.py --prompt 'Your prompt here'")
    print("  - Run: python app.py (for web interface)")
    print("  - Run: python main.py --demo (for demo)")


if __name__ == "__main__":
    main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
transformers>=4.35.0
|
| 2 |
+
diffusers>=0.24.0
|
| 3 |
+
torch>=2.1.0
|
| 4 |
+
torchaudio>=2.1.0
|
| 5 |
+
accelerate>=0.25.0
|
| 6 |
+
numpy>=1.24.0
|
| 7 |
+
pillow>=10.0.0
|
| 8 |
+
opencv-python>=4.8.0
|
| 9 |
+
safetensors>=0.4.0
|
| 10 |
+
huggingface_hub>=0.19.0
|
| 11 |
+
gradio>=4.0.0
|
| 12 |
+
tqdm>=4.66.0
|
| 13 |
+
pyyaml>=6.0
|
| 14 |
+
matplotlib>=3.8.0
|
| 15 |
+
scipy>=1.11.0
|
| 16 |
+
einops>=0.7.0
|
train.py
ADDED
|
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Byte Dream Training Pipeline
|
| 3 |
+
Complete training system for diffusion models with CPU optimization
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from torch.utils.data import Dataset, DataLoader
|
| 11 |
+
from torchvision import transforms
|
| 12 |
+
from PIL import Image
|
| 13 |
+
import numpy as np
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
import yaml
|
| 16 |
+
import argparse
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Tuple, List, Optional
|
| 19 |
+
import gc
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ImageTextDataset(Dataset):
    """
    Dataset for image-text pairs.

    Each image may have a sibling ``.txt`` file holding its caption;
    otherwise the filename (underscores replaced with spaces) is used.
    Supports flip/crop augmentations for better generalization.
    """

    def __init__(
        self,
        data_dir: str,
        image_size: int = 512,
        random_flip: bool = True,
        random_crop: bool = False,
        center_crop: bool = True,
    ):
        """
        Args:
            data_dir: Directory containing .jpg/.jpeg/.png images.
            image_size: Side length of the square output images.
            random_flip: Randomly mirror images horizontally (p=0.5).
            random_crop: Take a random image_size crop (takes precedence
                over center_crop).
            center_crop: Take a center image_size crop.
        """
        self.data_dir = Path(data_dir)
        self.image_paths = list(self.data_dir.glob("*.jpg")) + \
                           list(self.data_dir.glob("*.png")) + \
                           list(self.data_dir.glob("*.jpeg"))

        self.image_size = image_size
        self.random_flip = random_flip
        self.random_crop = random_crop
        self.center_crop = center_crop

        # Transformations
        self.transform = self._get_transform()

        # Load captions
        self.captions = self._load_captions()

    def _get_transform(self) -> transforms.Compose:
        """Get image transformation pipeline."""
        transforms_list = []

        if self.random_crop or self.center_crop:
            # Resize the shorter side to image_size first so an image_size x
            # image_size crop always exists. Without this, crops of images
            # smaller than image_size would be zero-padded (or fail on old
            # torchvision), corrupting training data.
            transforms_list.append(transforms.Resize(self.image_size))

        if self.random_crop:
            transforms_list.append(transforms.RandomCrop(self.image_size))
        elif self.center_crop:
            transforms_list.append(transforms.CenterCrop(self.image_size))
        else:
            # No cropping: stretch to the target square directly.
            transforms_list.append(transforms.Resize((self.image_size, self.image_size)))

        if self.random_flip:
            transforms_list.append(transforms.RandomHorizontalFlip(p=0.5))

        transforms_list.extend([
            transforms.ToTensor(),
            # Map [0, 1] pixels to [-1, 1], the range the VAE expects.
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

        return transforms.Compose(transforms_list)

    def _load_captions(self) -> dict:
        """Load captions from sibling .txt files; fall back to the filename."""
        captions = {}

        for img_path in self.image_paths:
            caption_path = img_path.with_suffix('.txt')
            if caption_path.exists():
                with open(caption_path, 'r', encoding='utf-8') as f:
                    captions[str(img_path)] = f.read().strip()
            else:
                # Use filename as caption if no text file
                captions[str(img_path)] = img_path.stem.replace('_', ' ')

        return captions

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> dict:
        """Return {"pixel_values": tensor, "input_ids": caption str,
        "image_path": str} for the given index."""
        img_path = self.image_paths[idx]

        # Load image
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            # Skip unreadable files by advancing to the next index.
            # NOTE(review): if every image is unreadable this recurses without
            # bound — acceptable for curated training sets, but worth a guard
            # if the dataset comes from untrusted sources.
            print(f"Error loading image {img_path}: {e}")
            return self.__getitem__((idx + 1) % len(self))

        # Transform image
        pixel_values = self.transform(image)

        # Get caption
        caption = self.captions.get(str(img_path), "")

        return {
            "pixel_values": pixel_values,
            "input_ids": caption,
            "image_path": str(img_path),
        }
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class LatentDiffusionTrainer:
    """
    Trainer for latent diffusion models.

    Implements the training loop with mixed precision and gradient
    accumulation. Only the UNet is optimized; the VAE and text encoder are
    frozen and used purely as fixed feature extractors.
    """

    def __init__(
        self,
        unet: nn.Module,
        vae: nn.Module,
        text_encoder: nn.Module,
        scheduler,
        config: dict,
        device: str = "cpu",
    ):
        """
        Args:
            unet: Denoising UNet — the only trainable component.
            vae: Autoencoder mapping images <-> latents (frozen).
            text_encoder: Text encoder producing prompt embeddings (frozen).
            scheduler: Noise scheduler exposing ``num_train_timesteps`` and
                ``add_noise``.
            config: Parsed config dict; the ``training`` section is read.
            device: Torch device string, e.g. "cpu" or "cuda".
        """
        self.unet = unet
        self.vae = vae
        self.text_encoder = text_encoder
        self.scheduler = scheduler
        self.config = config
        self.device = torch.device(device)

        # Training parameters
        self.epochs = config['training']['epochs']
        self.batch_size = config['training']['batch_size']
        self.learning_rate = config['training']['learning_rate']
        self.gradient_accumulation_steps = config['training']['gradient_accumulation_steps']
        self.max_grad_norm = config['training']['max_grad_norm']

        # Mixed precision ("no" disables AMP entirely)
        self.mixed_precision = config['training']['mixed_precision']
        self.use_amp = self.mixed_precision != "no"

        # Output directories (created eagerly so later saves cannot fail on
        # a missing parent)
        self.output_dir = Path(config['training']['output_dir'])
        self.logging_dir = Path(config['training']['logging_dir'])
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.logging_dir.mkdir(parents=True, exist_ok=True)

        # Initialize optimizer over UNet parameters only — VAE and text
        # encoder are frozen in _prepare_models()
        self.optimizer = torch.optim.AdamW(
            unet.parameters(),
            lr=self.learning_rate,
            betas=(0.9, 0.999),
            weight_decay=1e-2,
            eps=1e-08,
        )

        # Learning rate scheduler
        self.lr_scheduler = self._create_lr_scheduler()

        # Gradient scaler for mixed precision — CUDA only, so it stays None
        # for CPU training even when use_amp is True
        self.scaler = torch.cuda.amp.GradScaler() if self.use_amp and torch.cuda.is_available() else None

        # Move models to device
        self._prepare_models()

    def _prepare_models(self):
        """Move models to the target device and freeze everything but the UNet."""
        print(f"Preparing models on {self.device}...")

        self.vae.to(self.device)
        self.text_encoder.to(self.device)
        self.unet.to(self.device)

        # Set VAE and text encoder to eval mode (frozen)
        self.vae.eval()
        # The text encoder may wrap an inner HF model under .model —
        # TODO confirm against bytedream.model.create_text_encoder
        if hasattr(self.text_encoder, 'model'):
            self.text_encoder.model.eval()

        # Freeze VAE and text encoder parameters so the optimizer (which only
        # holds UNet params anyway) and autograd skip them entirely
        for param in self.vae.parameters():
            param.requires_grad = False

        if hasattr(self.text_encoder, 'model'):
            for param in self.text_encoder.model.parameters():
                param.requires_grad = False

        # Set UNet to train mode
        self.unet.train()

    def _create_lr_scheduler(self):
        """Create the learning-rate scheduler named in config['training']['lr_scheduler'].

        Unknown names fall back to a constant schedule.
        """
        sched_config = self.config['training']

        if sched_config['lr_scheduler'] == "constant_with_warmup":
            return torch.optim.lr_scheduler.ConstantLR(
                self.optimizer,
                factor=1.0,
                total_iters=sched_config['lr_warmup_steps'],
            )
        elif sched_config['lr_scheduler'] == "linear":
            # Ramp from 10% to 100% of the base LR over the warmup steps
            return torch.optim.lr_scheduler.LinearLR(
                self.optimizer,
                start_factor=0.1,
                end_factor=1.0,
                total_iters=sched_config['lr_warmup_steps'],
            )
        else:
            return torch.optim.lr_scheduler.ConstantLR(self.optimizer, factor=1.0)

    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        """Encode a batch of pixel images into scaled VAE latents (no grad)."""
        with torch.no_grad():
            latents = self.vae.encode(images)
            # 0.18215 is the standard SD latent scale so latents have ~unit
            # variance for the diffusion process
            latents = latents * 0.18215  # Scale factor
        return latents

    def encode_text(self, texts: List[str]) -> torch.Tensor:
        """Encode a batch of caption strings into embeddings (no grad)."""
        with torch.no_grad():
            text_embeddings = self.text_encoder(texts, device=self.device)
        return text_embeddings

    def compute_loss(
        self,
        latents: torch.Tensor,
        text_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the standard epsilon-prediction diffusion loss:
        MSE between the UNet's predicted noise and the injected noise.

        Args:
            latents: Latent representations of images
            text_embeddings: Text embeddings

        Returns:
            Scalar loss tensor
        """
        batch_size = latents.shape[0]

        # Sample random timesteps — one per sample in the batch
        timesteps = torch.randint(
            0,
            self.scheduler.num_train_timesteps,
            (batch_size,),
            device=self.device,
        ).long()

        # Add noise to latents according to the forward diffusion schedule
        noise = torch.randn_like(latents)
        noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)

        # Predict noise
        timestep_tensor = timesteps

        model_output = self.unet(
            sample=noisy_latents,
            timestep=timestep_tensor,
            encoder_hidden_states=text_embeddings,
        )

        # Compute loss (epsilon objective: model predicts the noise itself)
        loss = F.mse_loss(model_output, noise, reduction="mean")

        return loss

    def train_step(
        self,
        batch: dict,
    ) -> float:
        """
        Perform a single forward/backward pass (no optimizer step — that
        happens in train() once gradient_accumulation_steps have accumulated).

        Args:
            batch: Batch dict with "pixel_values" and "input_ids" (captions)

        Returns:
            Unscaled loss value for logging
        """
        pixel_values = batch["pixel_values"].to(self.device)
        input_ids = batch["input_ids"]

        # Encode images and text (both under no_grad — frozen components)
        latents = self.encode_images(pixel_values)
        text_embeddings = self.encode_text(input_ids)

        # Compute loss; divide by accumulation steps so accumulated gradients
        # average rather than sum
        if self.use_amp and self.scaler is not None:
            with torch.cuda.amp.autocast():
                loss = self.compute_loss(latents, text_embeddings)
                loss = loss / self.gradient_accumulation_steps

            self.scaler.scale(loss).backward()
        else:
            loss = self.compute_loss(latents, text_embeddings)
            loss = loss / self.gradient_accumulation_steps
            loss.backward()

        # Undo the accumulation scaling so the reported value is the true loss
        return loss.item() * self.gradient_accumulation_steps

    def save_checkpoint(self, epoch: int, step: int):
        """Save a resumable checkpoint (UNet + optimizer + LR scheduler state)."""
        checkpoint_dir = self.output_dir / f"checkpoint-{epoch}-{step}"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Save UNet
        torch.save({
            'epoch': epoch,
            'step': step,
            'unet_state_dict': self.unet.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.lr_scheduler.state_dict() if self.lr_scheduler else None,
        }, checkpoint_dir / "pytorch_model.bin")

        # Save config alongside the weights so the checkpoint is self-describing
        with open(checkpoint_dir / "config.yaml", 'w') as f:
            yaml.dump(self.config, f)

        print(f"Checkpoint saved to {checkpoint_dir}")

    def train(self, resume_from_checkpoint: Optional[str] = None):
        """
        Main training loop: builds the dataset/dataloader, optionally resumes
        from a checkpoint, then runs epochs with gradient accumulation,
        periodic checkpointing, and a final model save.

        Args:
            resume_from_checkpoint: Path to checkpoint to resume from
        """
        # Create dataset and dataloader
        train_config = self.config['training']

        dataset = ImageTextDataset(
            data_dir=train_config['dataset_path'],
            image_size=512,
            random_flip=train_config['random_flip'],
            random_crop=train_config['random_crop'],
            center_crop=train_config['center_crop'],
        )

        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=0,  # CPU training
            pin_memory=False,
        )

        # Resume from checkpoint
        start_epoch = 0
        global_step = 0

        if resume_from_checkpoint:
            print(f"Resuming from checkpoint: {resume_from_checkpoint}")
            checkpoint = torch.load(resume_from_checkpoint, map_location=self.device)
            self.unet.load_state_dict(checkpoint['unet_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            if checkpoint['scheduler_state_dict']:
                self.lr_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            # NOTE(review): resuming restarts the saved epoch from its
            # beginning rather than mid-epoch
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['step']

        # Training loop
        total_steps = len(dataloader) * self.epochs

        print(f"Starting training for {self.epochs} epochs...")
        print(f"Total steps: {total_steps}")
        print(f"Batch size: {self.batch_size}")
        print(f"Mixed precision: {self.mixed_precision}")

        for epoch in range(start_epoch, self.epochs):
            self.unet.train()
            progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{self.epochs}")

            epoch_loss = 0
            num_steps = 0

            for step, batch in enumerate(progress_bar):
                # Training step (forward + backward; gradients accumulate)
                loss = self.train_step(batch)
                epoch_loss += loss
                num_steps += 1

                # Gradient clipping and optimizer step — only every
                # gradient_accumulation_steps micro-batches
                if (step + 1) % self.gradient_accumulation_steps == 0:
                    if self.use_amp and self.scaler is not None:
                        # Unscale before clipping so the norm is measured in
                        # true gradient units
                        self.scaler.unscale_(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(
                            self.unet.parameters(),
                            self.max_grad_norm,
                        )
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        torch.nn.utils.clip_grad_norm_(self.unet.parameters(), self.max_grad_norm)
                        self.optimizer.step()

                    # Learning rate scheduling
                    if self.lr_scheduler:
                        self.lr_scheduler.step()

                    # Zero gradients for the next accumulation window
                    self.optimizer.zero_grad()

                    # Update progress bar with the running epoch average
                    avg_loss = epoch_loss / num_steps
                    progress_bar.set_postfix({"loss": f"{avg_loss:.4f}"})

                    # Logging
                    if (global_step + 1) % self.config['training']['log_every_n_steps'] == 0:
                        print(f"\nStep {global_step + 1}: Loss = {avg_loss:.4f}")

                    # Save checkpoint periodically
                    if (global_step + 1) % 1000 == 0:
                        self.save_checkpoint(epoch, global_step)

                    # global_step counts optimizer steps, not micro-batches
                    global_step += 1

            # End of epoch (max() guards against an empty dataloader)
            avg_epoch_loss = epoch_loss / max(num_steps, 1)
            print(f"\nEpoch {epoch+1} completed. Average loss: {avg_epoch_loss:.4f}")

            # Save epoch checkpoint
            self.save_checkpoint(epoch, global_step)

            # Clear memory between epochs — matters for long CPU runs
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Save final model
        print("\nTraining completed!")
        self.save_final_model()

    def save_final_model(self):
        """Save the final trained UNet weights plus the config used."""
        final_dir = self.output_dir / "final"
        final_dir.mkdir(parents=True, exist_ok=True)

        # Save UNet
        torch.save({
            'unet_state_dict': self.unet.state_dict(),
            'config': self.config,
        }, final_dir / "unet_pytorch_model.bin")

        print(f"Final model saved to {final_dir}")
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def main():
    """CLI entry point: parse arguments, build components, run training."""
    parser = argparse.ArgumentParser(description="Train Byte Dream diffusion model")
    parser.add_argument("--config", type=str, default="config.yaml", help="Path to config file")
    parser.add_argument("--train_data", type=str, required=True, help="Path to training data")
    parser.add_argument("--output_dir", type=str, default="./models/bytedream", help="Output directory")
    parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint")
    parser.add_argument("--device", type=str, default="cpu", help="Device to train on")
    args = parser.parse_args()

    # Load config
    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)

    # Command-line flags take precedence over the YAML file.
    config['training']['dataset_path'] = args.train_data
    config['training']['output_dir'] = args.output_dir

    # Imported here so argument/config errors surface before the heavy
    # model-building imports run.
    from bytedream.model import create_unet, create_vae, create_text_encoder
    from bytedream.scheduler import create_scheduler

    print("Creating model components...")
    unet = create_unet(config)
    vae = create_vae(config)
    text_encoder = create_text_encoder(config)
    scheduler = create_scheduler(config)

    # Report model size before training starts.
    param_count = sum(p.numel() for p in unet.parameters())
    print(f"UNet parameters: {param_count:,}")

    trainer = LatentDiffusionTrainer(
        unet=unet,
        vae=vae,
        text_encoder=text_encoder,
        scheduler=scheduler,
        config=config,
        device=args.device,
    )

    trainer.train(resume_from_checkpoint=args.resume)


if __name__ == "__main__":
    main()
|
upload_to_hf.py
ADDED
|
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hugging Face Integration
|
| 3 |
+
Upload and deploy Byte Dream to Hugging Face Hub and Spaces
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import yaml
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def upload_to_huggingface(
    model_path: str,
    repo_id: str,
    token: str = None,
    private: bool = False,
):
    """
    Upload a trained model directory to the Hugging Face Hub.

    The local directory is validated *before* the remote repository is
    created, so a mistyped path no longer leaves an empty repository on
    the Hub.

    Args:
        model_path: Path to model directory
        repo_id: Repository ID (username/model-name)
        token: Hugging Face API token (None uses the cached login)
        private: Whether to make repository private
    """
    from huggingface_hub import HfApi, create_repo

    # Validate the local side first — never touch the remote for a
    # non-existent model.
    model_dir = Path(model_path)
    if not model_dir.exists():
        print(f"Error: Model directory {model_dir} does not exist")
        return

    print("Uploading model to Hugging Face Hub...")
    print(f"Repository: {repo_id}")

    # Initialize API
    api = HfApi()

    # Create repository (idempotent thanks to exist_ok=True)
    try:
        create_repo(
            repo_id=repo_id,
            token=token,
            private=private,
            exist_ok=True,
            repo_type="model",
        )
        print("✓ Repository created/verified")
    except Exception as e:
        print(f"Error creating repository: {e}")
        return

    print("\nUploading files...")

    try:
        # Upload entire directory
        api.upload_folder(
            folder_path=str(model_dir),
            repo_id=repo_id,
            token=token,
            repo_type="model",
        )
        print("✓ Model uploaded successfully!")

        # Print repository URL
        print(f"\n📦 View your model at:")
        print(f"https://huggingface.co/{repo_id}")

    except Exception as e:
        print(f"Error uploading model: {e}")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def create_gradio_app():
    """Return the full source code of a Gradio demo app as a string.

    The returned text is a complete, standalone ``app.py`` suitable for
    deployment on Hugging Face Spaces: it loads a ByteDreamGenerator on CPU,
    defines a ``generate_image`` callback, and builds a ``gr.Blocks`` UI
    with prompt/negative-prompt inputs, size/steps/guidance/seed controls,
    example prompts, and an output image + status box.

    Returns:
        str: Python source code for the Gradio Space app (written to disk
        by the caller, e.g. ``main`` with ``--create_space``).
    """
    # NOTE: everything inside the triple-quoted literal below is the
    # generated app's source, not this module's code — do not edit it
    # expecting to change this script's behavior.
    gradio_code = '''"""
Byte Dream - Gradio Web Interface
Deploy on Hugging Face Spaces
"""

import gradio as gr
from bytedream.generator import ByteDreamGenerator
import torch

# Initialize generator
print("Loading Byte Dream model...")
generator = ByteDreamGenerator(
    model_path="./models/bytedream",
    config_path="config.yaml",
    device="cpu",
)

def generate_image(
    prompt,
    negative_prompt,
    width,
    height,
    num_steps,
    guidance_scale,
    seed,
):
    """Generate image from prompt"""

    # Convert seed to None if -1
    seed_value = None if seed == -1 else seed

    try:
        # Generate image
        image = generator.generate(
            prompt=prompt,
            negative_prompt=negative_prompt if negative_prompt else None,
            width=int(width),
            height=int(height),
            num_inference_steps=int(num_steps),
            guidance_scale=float(guidance_scale),
            seed=seed_value,
        )

        return image, "Success!"

    except Exception as e:
        print(f"Error generating image: {e}")
        return None, f"Error: {str(e)}"


# Create Gradio interface
with gr.Blocks(title="Byte Dream - AI Image Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎨 Byte Dream - AI Image Generator

    Generate stunning images from text descriptions using advanced diffusion models.
    Optimized for CPU inference.

    **Tips for better results:**
    - Be specific and descriptive in your prompts
    - Use negative prompts to avoid unwanted elements
    - Higher steps = better quality but slower
    - Adjust guidance scale for creativity vs accuracy
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📝 Prompt")
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="A beautiful sunset over mountains, digital art, highly detailed",
                lines=3,
            )

            negative_prompt_input = gr.Textbox(
                label="Negative Prompt (optional)",
                placeholder="ugly, blurry, low quality, distorted",
                lines=2,
            )

            with gr.Row():
                width_slider = gr.Slider(
                    minimum=256,
                    maximum=1024,
                    step=64,
                    value=512,
                    label="Width"
                )
                height_slider = gr.Slider(
                    minimum=256,
                    maximum=1024,
                    step=64,
                    value=512,
                    label="Height"
                )

            with gr.Row():
                steps_slider = gr.Slider(
                    minimum=10,
                    maximum=150,
                    step=5,
                    value=50,
                    label="Inference Steps"
                )
                guidance_slider = gr.Slider(
                    minimum=1.0,
                    maximum=20.0,
                    step=0.5,
                    value=7.5,
                    label="Guidance Scale"
                )

            seed_input = gr.Number(
                label="Seed (-1 for random)",
                value=-1,
                precision=0,
            )

            generate_btn = gr.Button("🎨 Generate Image", variant="primary", size="lg")

        with gr.Column(scale=1):
            gr.Markdown("### 🖼️ Generated Image")
            output_image = gr.Image(
                label="Generated Image",
                type="pil",
            )
            status_text = gr.Textbox(label="Status")

    # Examples
    gr.Markdown("### 💡 Example Prompts")
    gr.Examples(
        examples=[
            ["A cyberpunk city at night with neon lights, futuristic architecture, flying cars, highly detailed, digital art"],
            ["A majestic dragon breathing fire, fantasy art, dramatic lighting, epic scene"],
            ["A peaceful cottage in a meadow, flowers, sunny day, studio ghibli style"],
            ["Portrait of a warrior princess, armor, fantasy, intricate details, character design"],
            ["Underwater coral reef, tropical fish, sunlight filtering through water, photorealistic"],
        ],
        inputs=[prompt_input],
    )

    # Connect button
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt_input,
            negative_prompt_input,
            width_slider,
            height_slider,
            steps_slider,
            guidance_slider,
            seed_input,
        ],
        outputs=[output_image, status_text],
    )

    gr.Markdown("""
    ---
    **Byte Dream** v1.0.0 | Powered by Latent Diffusion Models
    """)


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
'''

    # Hand the source text back to the caller; nothing is written to disk here.
    return gradio_code
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def create_readme_for_hf(repo_id: str):
    """Build the README (model card) text for a Hugging Face repository.

    Args:
        repo_id: Full repository ID such as ``"username/bytedream"``. The
            segment after the final ``/`` is used as the model's display name.

    Returns:
        str: Complete README contents, starting with YAML front matter
        (license, language, tags) followed by Markdown documentation.
    """
    # Derive the display name once instead of re-splitting inside the literal.
    model_name = repo_id.split('/')[-1]

    return f'''---
license: mit
language:
- en
tags:
- text-to-image
- diffusion
- generative-ai
- cpu-optimized
---

# {model_name}

{model_name} is a powerful text-to-image diffusion model optimized for CPU inference. Generate high-quality images from text prompts using advanced latent diffusion architecture.

## Features

- 🚀 **CPU Optimized**: Runs efficiently on CPU without GPU requirement
- 🎨 **High Quality**: Generates 512x512 and higher resolution images
- ⚡ **Fast Inference**: Optimized for speed with quality preservation
- 🔧 **Flexible**: Supports various sampling methods and customization
- 📦 **Easy to Use**: Simple Python API and web interface

## Installation

```bash
pip install -r requirements.txt
```

## Usage

### Python API

```python
from bytedream import ByteDreamGenerator

# Initialize generator
generator = ByteDreamGenerator()

# Generate image
image = generator.generate(
    prompt="A beautiful sunset over mountains, digital art",
    num_inference_steps=50,
    guidance_scale=7.5
)

image.save("output.png")
```

### Command Line

```bash
python infer.py --prompt "A dragon flying over castle" --output dragon.png
```

### Web Interface

```bash
python app.py
```

## Model Details

- **Architecture**: Latent Diffusion Model (UNet + VAE + Text Encoder)
- **Parameters**: ~1.2B
- **Training**: Trained on diverse image-text pairs
- **Optimization**: CPU-optimized with efficient memory usage

## Examples

Try these prompts:
- "Cyberpunk city at night, neon lights, futuristic"
- "Fantasy landscape with mountains and waterfall"
- "Portrait of a warrior, detailed armor, dramatic lighting"
- "Abstract art, colorful, geometric shapes"

## Configuration

Edit `config.yaml` to customize:
- Model architecture parameters
- Generation settings (resolution, steps, guidance)
- CPU optimization options

## License

MIT License

## Acknowledgments

Built with:
- [PyTorch](https://pytorch.org/)
- [Hugging Face Diffusers](https://github.com/huggingface/diffusers)
- [CLIP](https://openai.com/research/clip)

Enjoy creating with Byte Dream! 🎨
'''
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def main():
    """Command-line entry point.

    Uploads a local Byte Dream model directory to the Hugging Face Hub and,
    when ``--create_space`` is passed, also writes the Gradio app
    (``app.py``) and model card (``README_HF.md``) needed to deploy a
    Hugging Face Space.
    """
    parser = argparse.ArgumentParser(description="Upload Byte Dream to Hugging Face")

    parser.add_argument(
        "--model_path",
        type=str,
        default="./models/bytedream",
        help="Path to model directory"
    )

    parser.add_argument(
        "--repo_id",
        type=str,
        required=True,
        help="Repository ID (e.g., username/bytedream)"
    )

    parser.add_argument(
        "--token",
        type=str,
        default=None,
        help="Hugging Face API token"
    )

    parser.add_argument(
        "--private",
        action="store_true",
        help="Make repository private"
    )

    parser.add_argument(
        "--create_space",
        action="store_true",
        help="Also create Gradio Space code"
    )

    args = parser.parse_args()

    # Upload model
    upload_to_huggingface(
        model_path=args.model_path,
        repo_id=args.repo_id,
        token=args.token,
        private=args.private,
    )

    # Create Space files if requested
    if args.create_space:
        print("\n\nCreating Gradio Space files...")

        # Write UTF-8 explicitly: the generated files contain emoji and
        # box-drawing characters, which raise UnicodeEncodeError under a
        # non-UTF-8 default locale (e.g. cp1252 on Windows).
        with open("app.py", 'w', encoding="utf-8") as f:
            f.write(create_gradio_app())
        print("✓ Created app.py for Gradio Space")

        # Save README (kept as README_HF.md so the project's own README
        # is not overwritten)
        readme = create_readme_for_hf(args.repo_id)
        with open("README_HF.md", 'w', encoding="utf-8") as f:
            f.write(readme)
        print("✓ Created README_HF.md")

        print("\n📋 To deploy on Hugging Face Spaces:")
        print("1. Go to https://huggingface.co/spaces")
        print("2. Click 'Create new Space'")
        print("3. Choose Gradio SDK")
        print("4. Upload all files")
        print("5. Select CPU hardware")
        print("6. Deploy!")
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
# Script entry point: run the CLI only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|