# Medical Image Synthesis and Data Augmentation Using Diffusion Models

*An Undergraduate Thesis Project on Ultrasound Image Analysis with DDPM, DDPM-variance, and Latent Diffusion Models*

---

## Abstract

This project implements and compares three generations of diffusion probabilistic models for medical image synthesis and data augmentation: **Standard DDPM**, **DDPM with learned variance prediction**, and **Latent Diffusion Models (LDM)**. The primary objective is to evaluate the effectiveness of synthetic data augmentation for improving classification performance on a 4-class udder ultrasound image dataset. Comprehensive experiments demonstrate that data augmentation with diffusion models significantly enhances downstream classifier accuracy across multiple backbone architectures. Additionally, we explore **mixed-data augmentation** strategies that combine samples from different generative models to leverage complementary strengths.

---

## Table of Contents

1. [Introduction](#1-introduction)
2. [Background and Related Work](#2-background-and-related-work)
3. [Methodology](#3-methodology)
4. [Experimental Setup](#4-experimental-setup)
5. [Results and Analysis](#5-results-and-analysis)
   - 5.6 [Mixed-Data Augmentation Experiments](#56-mixed-data-augmentation-experiments)
6. [Project Structure](#6-project-structure)
7. [Usage Instructions](#7-usage-instructions)
8. [Citation](#8-citation)

---

## 1. Introduction

Medical imaging datasets, particularly in specialized domains like veterinary ultrasound, often suffer from limited sample sizes and class imbalances. This scarcity poses significant challenges for training robust deep learning models. Traditional data augmentation techniques (flipping, rotation, color jittering) provide limited diversity and may not capture the complex anatomical variations present in medical images.

This project addresses these limitations by exploring **diffusion-based data augmentation (DiffDA)** for udder ultrasound images. We implement and compare three state-of-the-art diffusion models:

1. **Model 1 (DDPM)**: Standard Denoising Diffusion Probabilistic Model [1]
2. **Model 2 (DDPM-variance)**: Improved DDPM with learned variance prediction and perceptual loss [2]
3. **Model 3 (LDM)**: Latent Diffusion Model operating in compressed latent space [3]

The generated synthetic images are used to augment training data for downstream classification tasks, with systematic evaluation across four classifier architectures: ResNet-18, Swin-T, ViT-Tiny, and ConvNeXt-Tiny.

**Key Contributions**:
- Implementation of three diffusion model variants for medical image synthesis
- Comprehensive evaluation of diffusion-based data augmentation for ultrasound classification
- Analysis of the relationship between generative quality metrics (FID, LPIPS) and downstream classification performance
- **Mixed-data augmentation experiments** exploring complementary effects of combining LDM and DDPM-variance generated samples
- Open-source codebase for reproducible research in medical image augmentation

---

## 2. Background and Related Work

### 2.1 Denoising Diffusion Probabilistic Models (DDPM)

DDPMs [1] are generative models that learn data distributions by reversing a gradual noising process. The forward process adds Gaussian noise over T timesteps:

$$
q(\mathbf{x}_t|\mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1-\beta_t}\mathbf{x}_{t-1}, \beta_t\mathbf{I})
$$

The reverse process is learned by a neural network:

$$
p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t) = \mathcal{N}(\mathbf{x}_{t-1}; \mu_\theta(\mathbf{x}_t, t), \sigma_t^2\mathbf{I})
$$

Training uses a simplified objective predicting the added noise:

$$
\mathcal{L}_{\text{simple}} = \mathbb{E}_{t, \mathbf{x}_0, \boldsymbol{\epsilon}} \left[ \| \boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t) \|^2 \right]
$$

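The closed-form forward sample and the simplified objective above can be sketched in a few lines of PyTorch. This is an illustration only, not the project's training code; the linear β-schedule values below are the common DDPM defaults:

```python
import torch

# Linear beta schedule as in the original DDPM paper
betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def q_sample(x0, t, alphas_cumprod, noise):
    """Closed-form forward process:
    x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise."""
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    return a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise

def simple_loss(model, x0, alphas_cumprod):
    """L_simple: MSE between the true and the predicted noise."""
    b = x0.shape[0]
    t = torch.randint(0, len(alphas_cumprod), (b,))
    noise = torch.randn_like(x0)
    x_t = q_sample(x0, t, alphas_cumprod, noise)
    return torch.nn.functional.mse_loss(model(x_t, t), noise)
```

In practice `model` would be the class-conditional U-Net described in Section 3; here any callable with the same signature works.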
### 2.2 Improved DDPM with Learned Variance

Nichol & Dhariwal [2] extended DDPMs with several improvements:
- **Learned variance prediction**: Network outputs both noise and variance parameters
- **Cosine noise schedule**: More gradual noise addition than linear scheduling
- **Improved architecture**: U-Net with attention at multiple resolutions

These improvements yield better log-likelihoods and enable sampling with fewer steps while maintaining quality.

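For reference, the cosine schedule can be sketched directly from its definition, where the cumulative product follows a squared-cosine curve; the s=0.008 offset and the 0.999 clip on each β_t follow the paper:

```python
import math

def cosine_alphas_cumprod(T, s=0.008):
    """Cosine schedule: alpha_bar(t) ∝ cos^2(((t/T + s) / (1 + s)) * pi/2),
    normalized so that alpha_bar(0) = 1."""
    f = lambda t: math.cos(((t / T + s) / (1 + s)) * math.pi / 2) ** 2
    return [f(t) / f(0) for t in range(T + 1)]

def betas_from_alphas_cumprod(a_bar):
    """Recover per-step betas: beta_t = 1 - alpha_bar_t / alpha_bar_{t-1},
    clipped to 0.999 to avoid singularities near t = T."""
    return [min(1 - a_bar[t] / a_bar[t - 1], 0.999) for t in range(1, len(a_bar))]
```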
### 2.3 Latent Diffusion Models (LDM)

LDMs [3] operate in a compressed latent space using a pretrained autoencoder, dramatically reducing computational requirements:

$$
\mathcal{E}(\mathbf{x}) = \mathbf{z} \quad \text{(Encoder)}, \quad \mathcal{D}(\mathbf{z}) = \tilde{\mathbf{x}} \quad \text{(Decoder)}
$$

The diffusion process occurs in the latent space $\mathbf{z} \in \mathbb{R}^{h \times w \times c}$, where the spatial grid $h \times w$ has 64× fewer positions than the original image (8× downsampling per side). This enables training high-resolution diffusion models with limited computational resources.

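To make the compression concrete: the SD VAE used in this project (Section 3.1.3) downsamples each spatial side by 8, so a 512×512 RGB image maps to a 4×64×64 latent. A small sketch of the shape arithmetic, with the 0.18215 scale factor taken from the Diffusers convention for this VAE:

```python
SCALE_FACTOR = 0.18215  # latent scaling used by the SD VAE in Diffusers

def latent_shape(image_size, latent_channels=4, downsample=8):
    """Shape of the latent the diffusion model actually sees:
    the VAE downsamples 8x per spatial side."""
    return (latent_channels, image_size // downsample, image_size // downsample)

def spatial_compression(image_size, downsample=8):
    """Ratio of pixel positions to latent positions (64x for downsample=8)."""
    return (image_size * image_size) // ((image_size // downsample) ** 2)
```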
### 2.4 Diffusion Models for Data Augmentation

Recent studies [4,5] have explored diffusion models for data augmentation (DiffDA). Key findings include:
- Traditional generative metrics (FID, IS) do not strongly correlate with downstream classification performance
- Classification accuracy is the gold standard for evaluating DiffDA effectiveness
- Diffusion models can generate diverse, high-quality medical images that improve classifier robustness

---

## 3. Methodology

### 3.1 Model Architectures

#### 3.1.1 Model 1: Standard DDPM (`train.py`)
- **Backbone**: `UNet2DModel` from HuggingFace Diffusers
- **Conditioning**: Class labels via `num_class_embeds` parameter
- **Loss**: Charbonnier loss for robust denoising
- **Training**: DDPMScheduler with `beta_schedule="scaled_linear"`
- **Key Features**: EMA weights, weighted sampling for class imbalance

#### 3.1.2 Model 2: DDPM-variance (`train_ddpm_variance.py`)
- **Backbone**: `UNet2DConditionModel` with cross-attention conditioning
- **Output Channels**: 6 (3 for noise prediction, 3 for variance prediction)
- **Conditioning**: Learned `LabelEmbedding` with 256-dim embeddings, 8 tokens
- **Loss Components**:
  - MSE loss for noise prediction
  - Perceptual loss (VGG16-based) for feature preservation
  - Laplacian edge loss for anatomical detail preservation
- **Key Features**: Learned variance, classifier-free guidance (CFG), edge-aware training

#### 3.1.3 Model 3: LDM (`train_ldm.py`)
- **VAE**: Pretrained `AutoencoderKL` (stabilityai/sd-vae-ft-mse)
- **Latent Space**: 4 channels, 64×64 resolution (scale_factor=0.18215)
- **Diffusion**: In latent space with DPMSolverMultistepScheduler
- **Conditioning**: Cross-attention with 512-dim label embeddings
- **Loss**: Combined noise prediction and perceptual loss
- **Key Features**: Latent space efficiency, DPM++ solver, CFG scaling

### 3.2 Label Conditioning Strategy

All models implement class-conditional generation via label embeddings:

```python
import torch.nn as nn

# DDPM-variance: cross-attention conditioning via learned label tokens
class LabelEmbedding(nn.Module):
    def __init__(self, num_classes=5, embedding_dim=256, num_tokens=8):
        super().__init__()
        self.embedding = nn.Embedding(num_classes, embedding_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim * 2),
            nn.SiLU(),
            nn.Linear(embedding_dim * 2, embedding_dim * num_tokens),
        )

# LDM: similar architecture with 512-dim embeddings
```

Conditioning is implemented with **classifier-free guidance (CFG)** during sampling, enabling a trade-off between fidelity and diversity.

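At sampling time, CFG forms the guided noise estimate from one conditional and one unconditional forward pass. A generic sketch, where the guidance weight `w` is the fidelity/diversity knob mentioned above:

```python
import torch

def cfg_noise(eps_uncond, eps_cond, w):
    """Classifier-free guidance: eps = eps_uncond + w * (eps_cond - eps_uncond).
    w = 1 recovers the plain conditional prediction; w > 1 pushes samples
    toward the class label at the cost of diversity."""
    return eps_uncond + w * (eps_cond - eps_uncond)
```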
### 3.3 Data Augmentation Pipeline

The complete augmentation workflow:

1. **Model Training**: Train diffusion model on original dataset
2. **Image Generation**: Generate synthetic images to reach target count (5000/class)
3. **Dataset Composition**: Combine original and synthetic images
4. **Classifier Training**: Train downstream classifiers on augmented datasets
5. **Evaluation**: Compare performance against baseline (original data only)

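Step 3 amounts to concatenating the original and synthetic training sets. With `torchvision.datasets.ImageFolder` directories this is a one-liner; the sketch below works with any map-style datasets:

```python
from torch.utils.data import ConcatDataset

def build_augmented_dataset(original_ds, synthetic_ds):
    """Combine original and diffusion-generated images into one
    training dataset for the downstream classifiers."""
    return ConcatDataset([original_ds, synthetic_ds])
```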
### 3.4 Evaluation Metrics

#### Generative Quality:
- **Fréchet Inception Distance (FID)**: Measures distribution similarity
- **Inception Score (IS)**: Measures diversity and recognizability
- **LPIPS Diversity**: Learned Perceptual Image Patch Similarity within/between classes

#### Classification Performance:
- **Balanced Accuracy**: Accounts for class imbalance
- **F1-Score**: Harmonic mean of precision and recall
- **AUC-ROC**: Area under the receiver operating characteristic curve
- **Confusion Matrices**: Per-class performance visualization

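As a concrete example of the two headline classification metrics, on hypothetical predictions for a 4-class problem (scikit-learn's reference implementations are assumed):

```python
from sklearn.metrics import balanced_accuracy_score, f1_score

y_true = [0, 0, 1, 2, 3, 3]   # toy labels for a 4-class problem
y_pred = [0, 1, 1, 2, 3, 3]   # one class-0 sample misclassified

# balanced accuracy = mean per-class recall, robust to class imbalance
bal_acc = balanced_accuracy_score(y_true, y_pred)   # (0.5 + 1 + 1 + 1) / 4 = 0.875
macro_f1 = f1_score(y_true, y_pred, average="macro")
```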
---

## 4. Experimental Setup

### 4.1 Dataset

The project uses a **udder ultrasound image dataset** with 4 classes representing different anatomical or pathological conditions. The original dataset structure:

```
datasets/
├── train/
│   ├── 1/   # Class 1 images
│   ├── 2/   # Class 2 images
│   ├── 3/   # Class 3 images
│   └── 4/   # Class 4 images
└── test/    # Test set (not used for diffusion model training)
```

**Key Statistics**:
- Original training images: Variable count per class (class-imbalanced)
- Target augmentation: 5000 images per class
- Image resolution: 256×256 pixels (DDPM/DDPM-variance) / 512×512 pixels (LDM)
- Preprocessing: Resize, random horizontal flip, color jitter, normalization to [-1, 1]

### 4.2 Training Configuration

#### Diffusion Model Training:
- **Batch Size**: 6 (DDPM), 13 (DDPM-variance), 1 (LDM) - limited by GPU memory
- **Epochs**: 120 (DDPM), 50 (DDPM-variance), 100 (LDM)
- **Learning Rate**: 5e-6 (DDPM), 5e-5 (DDPM-variance), 5e-6 (LDM) with cosine warmup
- **Mixed Precision**: FP16 for memory efficiency
- **EMA Decay**: 0.9999 for stable training
- **Class Imbalance Handling**: WeightedRandomSampler with inverse class frequency weights

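The inverse-frequency weighting used for class-imbalance handling can be sketched as:

```python
from collections import Counter
from torch.utils.data import WeightedRandomSampler

def make_balanced_sampler(labels):
    """Weight each sample by 1 / (its class count) so every class
    is drawn with roughly equal probability."""
    counts = Counter(labels)
    weights = [1.0 / counts[y] for y in labels]
    return WeightedRandomSampler(weights, num_samples=len(labels), replacement=True)
```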
#### Classifier Training (`compare_4_model.py`):
- **Backbones**: ResNet-18, Swin-T, ViT-Tiny, ConvNeXt-Tiny (pretrained on ImageNet)
- **Advanced Techniques**:
  - Mixup (α=0.05) for regularization
  - Label smoothing (ε=0.05) for calibration
  - RandAugment (N=2, M=9) for robustness
  - Dropout (p=0.2) to prevent overfitting
- **Optimization**: SGD with momentum (0.9), weight decay (1e-4)
- **Scheduling**: CosineAnnealingWarmRestarts with T_0=10, T_mult=2
- **Early Stopping**: Patience=12 epochs based on validation loss

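Of the techniques above, mixup is the least standard, so here is a minimal sketch (one-hot targets assumed; λ is drawn from Beta(α, α), so small α keeps most mixes close to one of the two originals):

```python
import torch

def mixup(x, y_onehot, alpha=0.05):
    """Blend random pairs of examples and their one-hot labels:
    x' = lam * x + (1 - lam) * x[perm], same for the labels."""
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(x.size(0))
    return lam * x + (1 - lam) * x[perm], lam * y_onehot + (1 - lam) * y_onehot[perm]
```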
### 4.3 Augmentation Experiments

Several augmentation configurations were tested; the first four study the effect of synthetic data quantity:

| Experiment | Target Images/Class | Model Used | Purpose |
|------------|---------------------|------------|---------|
| Base-500 | 500 | DDPM (Model 1) | Initial feasibility test |
| Base-1000 | 1000 | DDPM (Model 1) | Scaling effect analysis |
| Base-2000 | 2000 | DDPM (Model 1) | Intermediate scaling |
| **Base-5000** | **5000** | **DDPM (Model 1)** | **Full augmentation** |
| DDPM-variance-5000 | 5000 | DDPM-variance (Model 2) | Improved model comparison |
| LDM-5000 | 5000 | LDM (Model 3) | Latent space efficiency |
| Mixed-LDM-0.2 | 5000 | LDM (20%) + DDPM-variance (80%) | Complementary effects study |
| Mixed-LDM-0.5 | 5000 | LDM (50%) + DDPM-variance (50%) | Balanced fusion analysis |
| Mixed-LDM-0.8 | 5000 | LDM (80%) + DDPM-variance (20%) | LDM-dominant optimization |

### 4.4 Evaluation Protocol

1. **Generative Quality Assessment**:
   - Compute FID between real and synthetic images (`compute_fid.py`)
   - Calculate LPIPS diversity scores within and between classes (`LPIPS.py`)
   - Visual inspection of generated samples

2. **Classification Performance**:
   - 5-fold cross-validation for reliable metrics
   - Balanced accuracy as primary metric (handles class imbalance)
   - Statistical significance testing between baseline and augmented results
   - Confusion matrix analysis for per-class performance

3. **Ablation Studies**:
   - Effect of different augmentation quantities (500, 1000, 2000, 5000)
   - Comparison across three diffusion model architectures
   - Impact of advanced training techniques (mixup, label smoothing, etc.)

---

## 5. Results and Analysis

### 5.1 Generative Quality Metrics

#### FID Scores (from `FID_lpips.txt`):
- **DDPM (Model 1)**: 71.65 (IS: 4.69 ± 0.06) - Evaluated on `ddpm_augmented_v1/train`
- **DDPM-variance (Model 2)**: [Value not recorded in current log file]
- **LDM (Model 3)**: 54.20 (IS: 5.28 ± 0.09) - Evaluated on `ldm_augmented_v2/train`

*Lower FID indicates better distribution matching with real images. LDM achieves the best FID score (54.20), suggesting better generative quality, though DDPM-variance shows superior downstream classification performance.*

#### LPIPS Diversity Scores (from `FID_lpips.txt`):
- **DDPM (Model 1)**: 0.659 overall diversity score (range: 0.583-0.664 per class)
- **LDM (Model 3)**: 0.649 overall diversity score (range: 0.574-0.645 per class)
- **DDPM-variance (Model 2)**: [Value not recorded in current log file]

*LPIPS measures perceptual diversity: higher scores indicate more diverse generated images. Both DDPM and LDM show good diversity (>0.64), with DDPM slightly higher overall.*

### 5.2 Classification Performance Improvement

The key finding: **all diffusion models improved downstream classification accuracy**, with DDPM-variance (Model 2) achieving the best results.

#### Performance Comparison (from `Classification_Experiments/Final_Analysis_Report_ddpm_variance_V2/Performance_Table.csv` - Model 2 Results):

| Model | Baseline Accuracy | Augmented Accuracy | Improvement |
|-------|-------------------|--------------------|-------------|
| ResNet-18 | 68.99% | 92.31% | **+23.32%** |
| Swin-T | 70.89% | 94.51% | **+23.62%** |
| ViT-Tiny | 76.58% | 90.11% | **+13.53%** |
| ConvNeXt-Tiny | 70.89% | 95.05% | **+24.16%** |

**Key Observations**:
1. **ConvNeXt-Tiny achieves the highest absolute accuracy** (95.05%) and the **largest improvement** (+24.16%) with DDPM-variance augmentation
2. **All models exceed 90% accuracy** with DDPM-variance augmentation, demonstrating this model's effectiveness for data augmentation
3. **Swin-T shows a remarkable improvement** (+23.62%), reaching 94.51% accuracy
4. **Even ViT-Tiny**, which had the highest baseline accuracy, still improves by +13.53% to reach 90.11%

*Note: These results represent performance with **DDPM-variance (Model 2)** augmentation, which significantly outperforms the standard DDPM (Model 1) results shown in earlier experiments.*

##### ResNet-18 Performance Analysis with DDPM-variance (+23.32% Improvement)
![ResNet-18 Confusion Matrix](Classification_Experiments/Final_Analysis_Report_ddpm_variance_V2/resnet18_confusion_matrix.png)
*Figure: Confusion matrix for ResNet-18 trained on DDPM-variance (Model 2) augmented data. Shows excellent per-class accuracy with minimal confusion between classes.*

![ResNet-18 Performance Comparison](Classification_Experiments/Final_Analysis_Report_ddpm_variance_V2/resnet18_comparison.png)
*Figure: Performance comparison of ResNet-18 on baseline vs. DDPM-variance augmented datasets. Demonstrates a dramatic accuracy improvement from 68.99% to 92.31% (+23.32%).*

##### ConvNeXt-Tiny Performance Analysis with DDPM-variance (+24.16% Improvement - Best Overall)
![ConvNeXt-Tiny Confusion Matrix](Classification_Experiments/Final_Analysis_Report_ddpm_variance_V2/convnext_tiny_confusion_matrix.png)
*Figure: Confusion matrix for ConvNeXt-Tiny trained on DDPM-variance augmented data. Shows near-perfect classification across all 4 classes with minimal errors.*

![ConvNeXt-Tiny Performance Comparison](Classification_Experiments/Final_Analysis_Report_ddpm_variance_V2/convnext_tiny_comparison.png)
*Figure: ConvNeXt-Tiny performance comparison showing a dramatic improvement from 70.89% to 95.05% (+24.16%) - the highest overall accuracy achieved in this study.*

##### Overall Performance Summary with DDPM-variance
![Final Bar Comparison](Classification_Experiments/Final_Analysis_Report_ddpm_variance_V2/Final_Bar_Comparison.png)
*Figure: Bar chart comparing all four classifiers on baseline vs. DDPM-variance augmented datasets. All models achieve over 90% accuracy, with ConvNeXt-Tiny reaching 95.05% - demonstrating the effectiveness of DDPM-variance for medical image data augmentation.*

*Additional confusion matrices and comparison charts for Swin-T and ViT-Tiny trained on DDPM-variance augmented data are available in the `Final_Analysis_Report_ddpm_variance_V2/` folder.*

##### Comparison Across All Three Models
The table below compares the ResNet-18 performance improvement across all three diffusion models:

| Model | Baseline | Augmented | Improvement | Relative Advantage |
|-------|----------|-----------|-------------|--------------------|
| **DDPM (Model 1)** | 68.99% | 83.54% | +14.55% | Baseline |
| **DDPM-variance (Model 2)** | 68.99% | 92.31% | +23.32% | **+8.77% better than Model 1** |
| **LDM (Model 3)** | 68.99% | 92.41% | +23.42% | **+8.87% better than Model 1** |

*Analysis*: Both DDPM-variance (Model 2) and LDM (Model 3) substantially outperform standard DDPM (Model 1), with improvements of roughly +23.3% versus +14.55%. While LDM shows slightly better ResNet-18 performance (+23.42% vs. +23.32%), DDPM-variance achieves the highest overall accuracy with ConvNeXt-Tiny (95.05%) and the best balance across all metrics (see Section 5.3).

*Complete results for all models*:
- **Model 1 (DDPM)**: `Final_Analysis_Report_4_2/`
- **Model 2 (DDPM-variance)**: `Final_Analysis_Report_ddpm_variance_V2/`
- **Model 3 (LDM)**: `Final_Analysis_Report_vdm_2/`

### 5.3 Model Comparison: DDPM vs. DDPM-variance vs. LDM

Analysis from `extra_images5000_added_ddpm_vdm_variance/DDPM_VARIACNE_LDM_Data_Augmentation_Comparison_Summary.csv`:

| Metric | DDPM (Model 1) | DDPM-variance (Model 2) | LDM (Model 3) |
|--------|----------------|-------------------------|---------------|
| **Generation Speed** | Medium | Fast (fewer steps) | Slow (VAE encode/decode) |
| **Memory Usage** | High | Medium | Low (latent space) |
| **Image Quality** | Good | Excellent (edge preservation) | Very Good |
| **Classification Gain** | Good | **Best** | Good |

**DDPM-variance (Model 2) emerges as the optimal choice** for this application, balancing quality, speed, and downstream performance.

##### Three-Model Comprehensive Comparison
![DDPM vs DDPM-variance vs LDM Bar Chart](Classification_Experiments/extra_images5000_added_ddpm_vdm_variance/ddpm_vdm_Bar_Chart.png)
*Figure: Bar chart comparing all three diffusion models (DDPM, DDPM-variance, LDM) across FID scores, generation speed, memory usage, and downstream classification improvement. DDPM-variance shows the best balance of quality and efficiency.*

*Detailed numerical comparison metrics are available in `extra_images5000_added_ddpm_vdm_variance/DDPM_VARIACNE_LDM_Data_Augmentation_Comparison_Summary.csv`.*

### 5.4 Effect of Augmentation Quantity

Experiments with 500, 1000, 2000, and 5000 synthetic images per class reveal:
- **Diminishing returns**: Largest gains from 500→1000, smaller from 2000→5000
- **Threshold effect**: ~2000 images/class appears sufficient for most classifiers
- **Class-dependent benefits**: Minority classes benefit more from augmentation

### 5.5 Visualization and Interpretability

#### 5.5.1 Model-Specific Analysis Reports
The project includes comprehensive analysis reports for each diffusion model in separate folders:

- **`Final_Analysis_Report_4_2/`**: Complete analysis for DDPM (Model 1) augmentation results
- **`Final_Analysis_Report_ddpm_variance_V2/`**: Complete analysis for DDPM-variance (Model 2) augmentation results
- **`Final_Analysis_Report_vdm_2/`**: Complete analysis for LDM (Model 3) augmentation results

Each folder contains confusion matrices, performance comparison charts, and detailed metrics for all four classifier backbones trained on that model's augmented dataset.

#### 5.5.2 Generated Sample Inspection
- **DDPM-variance** produces sharper anatomical details due to edge-aware loss and perceptual regularization
- **LDM** generates globally coherent structures but may lack fine details due to latent space compression
- All models maintain class-specific characteristics (important for medical validity)

#### 5.5.3 Label Embedding and Cross-Attention Analysis

The `label_test/` directory contains visualizations of the label embeddings and cross-attention mechanisms used in Models 2 & 3 (DDPM-variance and LDM). These visualizations demonstrate the effectiveness of the cross-attention conditioning mechanism:

##### PCA Visualization of Label Embeddings
![Label Embedding PCA](Classification_Experiments/label_test/Label%20Embedding%20PCA.png)
*Figure: 2D PCA projection of label embeddings showing clear separation of classes 0-4 in embedding space. Class 0 is distinctly separated, while classes 3 and 4 form a close cluster.*

##### t-SNE Visualizations (Multiple Initializations)
![t-SNE Visualization 1](Classification_Experiments/label_test/Figure_1.png)
*Figure: t-SNE visualization of label embeddings (seed 1). The non-linear dimensionality reduction reveals natural clustering patterns.*

![t-SNE Visualization 2](Classification_Experiments/label_test/Figure_2.png)
*Figure: t-SNE visualization of label embeddings (seed 2). Consistent clustering across different random initializations validates embedding quality.*

![Mean t-SNE](Classification_Experiments/label_test/Figure_mean.png)
*Figure: Mean t-SNE visualization showing a stable embedding structure across multiple runs.*

##### Similarity and Distance Analysis
![Cosine Similarity Matrix](Classification_Experiments/label_test/label_cosine_similarity.png)
*Figure: 5×5 cosine similarity matrix heatmap. Diagonal = 1.0 (self-similarity); off-diagonal values show semantic relationships between classes.*

![Euclidean Distance Matrix](Classification_Experiments/label_test/label_distance.png)
*Figure: Euclidean distance matrix between class embeddings, quantifying separation in embedding space.*

##### Token Norm Distribution
![Token Norm Distribution](Classification_Experiments/label_test/Token%20Norm%20Distribution.png)
*Figure: Statistical distribution of attention token norms, validating proper initialization and training of cross-attention parameters.*

These visualizations confirm that the learned label embeddings form a semantically meaningful space in which similar classes lie closer together, enabling effective cross-attention conditioning during image generation. The consistent clustering patterns across multiple visualization techniques (PCA, t-SNE) and the structured similarity/distance matrices demonstrate that the models have learned meaningful class representations.

---

### 5.6 Mixed-Data Augmentation Experiments

#### 5.6.1 Experimental Design and Rationale

Recognizing the complementary strengths of different diffusion models, we conducted mixed-data augmentation experiments to investigate whether combining samples from multiple generative models could yield superior classification performance. The experimental design addresses:

- **Single-model limitations**: LDM offers fast sampling and high visual fidelity but limited diversity; DDPM-variance provides excellent diversity but slower sampling and small-sample instability
- **Complementary effects**: By mixing samples from both models, we aim to leverage LDM's global coherence and DDPM-variance's edge preservation

#### 5.6.2 Experiment Configuration

Three mixed-data groups were created with different LDM-to-DDPM-variance ratios (where `a` is the LDM proportion):

| Group | Ratio (LDM:DDPM-Var) | Description | Target Total |
|-------|----------------------|-------------|--------------|
| **A (a=0.2)** | 20:80 | DDPM-variance dominant, LDM for FID correction | 5000/class |
| **B (a=0.5)** | 50:50 | Balanced fusion for feature coverage | 5000/class |
| **C (a=0.8)** | 80:20 | LDM dominant, DDPM-variance for diversity boost | 5000/class |

**Methodology**:
1. Start with the DDPM-variance augmented set (5000 images/class) as the baseline
2. Incrementally replace samples with LDM-generated images at the specified ratios
3. Maintain consistent base data subsets for fair comparison

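The per-class sample budget implied by a mixing ratio `a` works out as follows (a small bookkeeping helper; the function name is illustrative, not from the project code):

```python
def mixed_counts(total_per_class, a):
    """Split a per-class budget between LDM (proportion a) and
    DDPM-variance (proportion 1 - a) synthetic images."""
    n_ldm = round(total_per_class * a)
    return {"ldm": n_ldm, "ddpm_variance": total_per_class - n_ldm}
```

For example, group A (a=0.2) uses 1000 LDM and 4000 DDPM-variance images per class.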
#### 5.6.3 Performance Results

Results from `mixed_data_Comparison_Summary.csv` show that mixed-data augmentation outperforms single-model approaches:

| Model | DDPM-Var Only | LDM Only | a=0.2 | a=0.5 | a=0.8 |
|-------|---------------|----------|-------|-------|-------|
| ConvNeXt-Tiny | 95.05% | 91.77% | 95.05% | **96.15%** | **97.80%** |
| ResNet-18 | 92.31% | 92.41% | 93.41% | 95.05% | 95.05% |
| Swin-T | 94.51% | 90.51% | **95.60%** | 93.96% | 94.51% |
| ViT-Tiny | 90.11% | 87.97% | **92.86%** | 90.66% | 91.76% |

#### 5.6.4 Key Findings

1. **Best overall performance**: ConvNeXt-Tiny with a=0.8 achieves **97.80% accuracy** - the highest across all experiments in this study
2. **Model-specific preferences**:
   - ConvNeXt-Tiny benefits from LDM-dominant mixes (97.80% at a=0.8)
   - Swin-T and ViT-Tiny perform best with DDPM-variance-dominant mixes (95.60% and 92.86% at a=0.2)
   - ResNet-18 shows balanced improvement across mixes
3. **Complementary benefits**: Mixed data outperforms single-model augmentation in most architecture/ratio combinations, supporting the hypothesis that LDM and DDPM-variance have complementary strengths
4. **Practical implications**: For medical image augmentation, the optimal mixing ratio depends on the target classifier architecture

#### 5.6.5 Visualization

![Mixed Data Comparison Chart](Classification_Experiments/mixed_data/ddpm_vdm_Bar_Chart.png)
*Figure: Bar chart comparing all augmentation strategies, including the mixed-data experiments. Shows the performance advantages of mixed-data augmentation over single-model approaches.*

![Classification Accuracy vs. LDM Ratio](Classification_Experiments/mixed_data/mixed_data.png)
*Figure: Multi-series line chart showing classification accuracy as a function of the LDM mixing ratio α. Demonstrates how different classifier architectures respond to varying proportions of LDM and DDPM-variance samples, with ConvNeXt-Tiny achieving peak performance at α=0.8 (80% LDM + 20% DDPM-Var).*

---

## 6. Project Structure

The project follows a well-organized directory structure designed for reproducibility and clarity:

```
ddpm/
├── Core Model Training Scripts
│   ├── train.py                   # Model 1: Standard DDPM training
│   ├── train_ddpm_variance.py     # Model 2: DDPM with learned variance
│   └── train_ldm.py               # Model 3: Latent Diffusion Model
│
├── Data Augmentation Generators
│   ├── add_p.py                   # Generate images using DDPM (Model 1)
│   ├── ddpm_variance_add.py       # Generate images using DDPM-variance (Model 2)
│   └── vdm_add.py                 # Generate images using LDM (Model 3)
│
├── Evaluation & Metrics
│   ├── FID.py                     # FID calculation for diffusion models
│   ├── compute_fid.py             # FID/IS for augmented datasets
│   ├── LPIPS.py                   # LPIPS diversity score calculation
│   └── all_experiments_log.json   # Log of all experiment results
│
├── Classification Experiments
│   ├── compare_4_model.py         # Train 4 classifiers with advanced techniques
│   ├── jieguo.py                  # Generate analysis reports and visualizations
│   ├── comparison.py              # Multi-experiment comparison
│   └── Classification_Experiments/               # All experiment outputs
│       ├── Augmented_Dataset_4_2/                # Best results with Model 1 augmentation
│       │   ├── models/                           # Trained classifier weights
│       │   └── results/                          # Performance metrics
│       ├── Augmented_Dataset_ddpm_variance_V2/   # Best results with Model 2 augmentation
│       ├── Augmented_Dataset_vdm_2/              # Best results with Model 3 augmentation
│       ├── Final_Analysis_Report_4_2/            # Final analysis for Model 1
│       ├── Final_Analysis_Report_ddpm_variance_V2/  # Final analysis for Model 2
│       ├── Final_Analysis_Report_vdm_2/          # Final analysis for Model 3
│       ├── label_test/                           # Label embedding visualizations
│       ├── mixed_data/                           # Mixed-data experiment results (LDM + DDPM-Var)
│       └── extra_images5000_added_ddpm_vdm_variance/  # Final comparison results
│
├── Datasets
│   ├── datasets/                    # Original dataset (train/test split)
│   ├── Base_datasets/               # Original base datasets
│   ├── Base_datasets_augmented/     # DDPM-augmented datasets (500, 1000, 2000, 5000/class)
│   ├── ddpm_augmented_v1/           # Model 1 augmented dataset (5000/class)
│   ├── ddpm_variance_augmented_v2/  # Model 2 augmented dataset (5000/class)
│   └── ldm_augmented_v2/            # Model 3 augmented dataset (5000/class)
│
├── Model Checkpoints
│   ├── ddpm-udder-results*/         # Model 1 training checkpoints
│   ├── ddpm_variance_*/             # Model 2 training checkpoints
│   └── ldm_udder_v*/                # Model 3 training checkpoints
│
└── Utility Scripts
    ├── split.py                     # Dataset splitting utility
    ├── test_*.py                    # Various testing scripts
    └── prompt.txt                   # Project notes and prompts
```
### 6.1 Key Directory Explanations

#### Base Datasets (Initial Experiments)
- `Base_datasets_augmented/`: Traditional DDPM augmentation to 500 images/class
- `Base_datasets_augmented_2/`: Augmentation to 1000 images/class
- `Base_datasets_augmented_3/`: Augmentation to 2000 images/class
- `Base_datasets_augmented_4/`: Augmentation to 5000 images/class
- *Purpose*: Validate that augmentation up to 5000 images/class improves all 4 classifiers

#### Model Training Checkpoints
- `ddpm_variance_v*/`: Intermediate weights for Model 2 training
- `ddpm-udder-results*/`: Intermediate weights for Model 1 training
- `ldm_udder_v*/`: Intermediate weights for Model 3 training
- *Note*: The highest-numbered version (e.g., `ddpm_variance_22`) contains the final trained model

#### Classification Experiments
- `Classification_Experiments/Augmented_Dataset_4_2/`: Best classifier results with Model 1 augmentation
- `Classification_Experiments/Augmented_Dataset_ddpm_variance_V2/`: Best results with Model 2 augmentation
- `Classification_Experiments/Augmented_Dataset_vdm_2/`: Best results with Model 3 augmentation
- `Classification_Experiments/Final_Analysis_Report_*/`: Complete analysis reports for each model

#### Special Directories
- `Classification_Experiments/label_test/`: Visualization of attention mechanisms and label embeddings in Models 2 & 3
- `Classification_Experiments/mixed_data/`: Mixed-data experiment results (LDM + DDPM-Var combinations)
- `Classification_Experiments/extra_images5000_added_ddpm_vdm_variance/`: Final experimental results comparison

---

## 7. Usage Instructions

### 7.1 Environment Setup

```bash
# Clone the repository
git clone [repository_url]
cd ddpm

# Create and activate a conda environment (recommended)
conda create -n ddpm python=3.9
conda activate ddpm

# Install PyTorch (CUDA 11.8 example)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# Install other dependencies
pip install diffusers accelerate transformers datasets
pip install lpips torchmetrics matplotlib scikit-learn
pip install timm  # for classifier backbones
```

### 7.2 Training a Diffusion Model

#### Model 1 (Standard DDPM)
```bash
python train.py
```
*Configuration*: Edit the `Config` class in `train.py` for data paths, batch size, etc.

#### Model 2 (DDPM-variance)
```bash
python train_ddpm_variance.py
```
*Configuration*: Modify the `Config` class in `train_ddpm_variance.py`. Key parameters: `use_variance_prediction=True`, `cross_attention_dim=256`, `cfg=5`.
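
A rough sketch of what such a `Config` class looks like. Only the three key parameters above come from the actual script; every other field name and default here is illustrative, not copied from `train_ddpm_variance.py`:

```python
from dataclasses import dataclass

@dataclass
class Config:
    # Hypothetical field names and defaults -- adjust to the real script.
    data_dir: str = "datasets/train"
    image_size: int = 128
    batch_size: int = 16
    num_epochs: int = 300
    learning_rate: float = 1e-4
    # The three parameters named in the text above:
    use_variance_prediction: bool = True  # Model 2 learns the variance, not just the noise
    cross_attention_dim: int = 256        # width of the label-embedding cross-attention
    cfg: float = 5                        # classifier-free guidance scale

cfg_obj = Config()
```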
577
+
578
+ #### Model 3 (LDM):
579
+ ```bash
580
+ python train_ldm.py
581
+ ```
582
+ *Note*: Requires pretrained VAE from `stabilityai/sd-vae-ft-mse`. Set `vae_model` path in config.
583
+
584
+ ### 7.3 Generating Augmented Datasets
585
+
586
+ After training, generate synthetic images to reach 5000 images per class:
587
+
588
+ #### Using Model 1 (DDPM):
589
+ ```bash
590
+ python add_p.py
591
+ ```
592
+ *Configure*: Set `target_count=5000`, `model_path` to trained checkpoint in `add_p.py`.
593
+
594
+ #### Using Model 2 (DDPM-variance):
595
+ ```bash
596
+ python ddpm_variance_add.py
597
+ ```
598
+ *Configure*: Update `model_path` and `output_dir` in `ddpm_variance_add.py`.
599
+
600
+ #### Using Model 3 (LDM):
601
+ ```bash
602
+ python vdm_add.py
603
+ ```
604
+ *Configure*: Set `model_path`, `vae_path`, and `output_dir` in `vdm_add.py`.
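
All three generation scripts follow the same "top-up" pattern: count the existing images in a class, then sample batches until the total reaches `target_count`. A minimal sketch of that batching logic (the helper name is hypothetical; the commented sampling call uses the Diffusers `DDPMPipeline` API and requires a trained checkpoint):

```python
def plan_generation(existing: int, target_count: int, batch_size: int) -> list:
    """Return the batch sizes needed to top a class up to target_count images."""
    needed = max(target_count - existing, 0)
    full, rest = divmod(needed, batch_size)
    return [batch_size] * full + ([rest] if rest else [])

# e.g. a class with 317 real images, topped up to 5000 in batches of 64
batches = plan_generation(317, 5000, 64)

if __name__ == "__main__":
    # Sketch of the sampling loop itself (needs a trained checkpoint on disk):
    # from diffusers import DDPMPipeline
    # pipe = DDPMPipeline.from_pretrained("ddpm-udder-results7").to("cuda")
    # for n in batches:
    #     images = pipe(batch_size=n).images
    #     ...  # save the PIL images into the augmented class folder
    pass
```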

### 7.4 Evaluating Generative Quality

```bash
# Compute FID and Inception Score
python compute_fid.py

# Calculate LPIPS diversity scores
python LPIPS.py
```

*Configuration*: Update `REAL_DIR` and `FAKE_DIR` in `compute_fid.py` to point to the real and generated datasets.
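
FID compares the Gaussian statistics of Inception features extracted from the real and generated sets. The core computation, shown here on pre-extracted feature matrices (`compute_fid.py` itself works from the image directories), is the FrΓ©chet distance between two Gaussians:

```python
import numpy as np
from scipy import linalg

def frechet_distance(feats_real, feats_fake):
    """FID = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2*sqrt(C1 @ C2))."""
    mu1, mu2 = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    c1 = np.cov(feats_real, rowvar=False)
    c2 = np.cov(feats_fake, rowvar=False)
    covmean = linalg.sqrtm(c1 @ c2)
    if np.iscomplexobj(covmean):      # drop tiny imaginary parts from sqrtm
        covmean = covmean.real
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(c1 + c2 - 2.0 * covmean))

# Sanity checks on synthetic features (8-D stand-ins for Inception activations)
rng = np.random.default_rng(0)
a = rng.normal(size=(500, 8))
fid_same = frechet_distance(a, a)         # identical sets -> distance near 0
fid_shifted = frechet_distance(a, a + 3)  # mean shift -> distance near 8 * 3^2 = 72
```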

### 7.5 Training and Evaluating Classifiers

```bash
# Train all 4 classifiers on the augmented dataset
python compare_4_model.py

# Generate comprehensive analysis reports
python jieguo.py

# Compare multiple experiments
python comparison.py
```

*Configuration*: Modify the dataset paths and training parameters in `compare_4_model.py`.

### 7.6 Reproducing Specific Experiments

To reproduce the key experiments from this study:

1. **Baseline Augmentation (Model 1)**:
   ```bash
   # Use checkpoints from ddpm-udder-results6/ or ddpm-udder-results7/
   python add_p.py  # set model_path accordingly
   ```

2. **DDPM-variance Augmentation (Model 2)**:
   ```bash
   # Use checkpoints from ddpm_variance_22/
   python ddpm_variance_add.py
   ```

3. **LDM Augmentation (Model 3)**:
   ```bash
   # Use checkpoints from ldm_udder_v22/
   python vdm_add.py
   ```

### 7.7 Visualizing Results

- **Label embeddings**: `python test_label.py` generates t-SNE and cosine-similarity plots
- **Generated samples**: `test_ddpm_variance.py` and `test_ldm.py` produce sample images
- **Attention maps**: check the `label_test/` directory for cross-attention visualizations
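
The label-embedding plot boils down to projecting the learned class-embedding vectors to 2-D. A sketch of the t-SNE step with scikit-learn, using random vectors as stand-ins for the real embeddings (the 256-D width matches `cross_attention_dim=256` from the Model 2 config):

```python
import numpy as np
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)
# Stand-in embeddings: 4 classes x 50 samples each, 256-D, with a
# per-class offset so the classes form separable clusters.
offsets = np.repeat(np.eye(4), 50, axis=0) @ rng.normal(size=(4, 256))
emb = rng.normal(size=(200, 256)) + 3.0 * offsets

# Project to 2-D; the resulting (200, 2) coordinates are what gets scatter-plotted
coords = TSNE(n_components=2, perplexity=30, init="pca",
              random_state=0).fit_transform(emb)
```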

---

## 8. Citation

If you use this codebase or build upon this work, please cite:

```bibtex
@article{ye2026diffusion,
  title={Diffusion-Based Data Augmentation for Medical Image Classification: A Comparative Study of DDPM, DDPM-variance, and Latent Diffusion Models},
  author={Junze Ye},
  journal={Undergraduate Thesis},
  year={2026},
  publisher={NJAU}
}
```

### References

[1] **Ho, J., Jain, A., & Abbeel, P.** (2020). *Denoising Diffusion Probabilistic Models*. Advances in Neural Information Processing Systems.

[2] **Nichol, A. Q., & Dhariwal, P.** (2021). *Improved Denoising Diffusion Probabilistic Models*. International Conference on Machine Learning.

[3] **Rombach, R., Blattmann, A., Lorenz, D., Esser, P., & Ommer, B.** (2022). *High-Resolution Image Synthesis with Latent Diffusion Models*. IEEE/CVF Conference on Computer Vision and Pattern Recognition.

[4] **Groh, M., et al.** (2023). *Evaluating the Performance of StyleGAN2-ADA on Medical Images*. International Conference on Medical Image Computing and Computer-Assisted Intervention.

[5] **Diffusion Models for Data Augmentation Survey** (2023). *arXiv preprint arXiv:2308.12453*.

---

## Acknowledgments

This work was conducted as part of an undergraduate thesis project at NJAU. Special thanks to my supervisor, Professor Zhai, for guidance and support throughout this project. Thanks also to the open-source community for the Diffusers library, which made this research possible.

## License

This project is available for academic research purposes. For commercial use, please contact the author.

---

*Last Updated: April 2026*
*Project Status: Completed Research*