ItsMaxNorm committed on
Commit b4c3b65 · verified · 1 Parent(s): 5c1cabe

Fix inference snippet: use snapshot_download for subfolder loading
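
The fix named in the commit title (download only the wanted variant subfolder, then load the pipeline from the local path) can be wrapped in a small helper. This is a minimal sketch, not code from the repository: the names `VARIANTS`, `variant_pattern`, and `load_variant` are hypothetical, and the sketch assumes the three subfolder names listed in the README table are the only valid variants.

```python
import os

# Subfolder names taken from the README table; treating them as the full
# set of valid variants is an assumption of this sketch.
VARIANTS = ("scaled", "compact", "empty-positive")


def variant_pattern(variant: str) -> str:
    """Return the glob that restricts a snapshot to one variant subfolder."""
    if variant not in VARIANTS:
        raise ValueError(f"unknown variant {variant!r}; expected one of {VARIANTS}")
    return f"{variant}/*"


def load_variant(variant: str = "scaled", device: str = "cuda"):
    """Download one variant subfolder and load it as a full pipeline."""
    # Imported lazily so variant_pattern() works without these packages.
    import torch
    from diffusers import StableDiffusionPipeline
    from huggingface_hub import snapshot_download

    # Fetch only the files under the chosen subfolder.
    local_root = snapshot_download(
        "ItsMaxNorm/SafeDiffusion-R1",
        allow_patterns=variant_pattern(variant),
    )
    # Load the full pipeline from the local copy of that subfolder.
    pipe = StableDiffusionPipeline.from_pretrained(
        os.path.join(local_root, variant),
        torch_dtype=torch.float16,
    )
    return pipe.to(device)
```

Calling `load_variant("compact")` would then mirror the README snippet with the subfolder name swapped in both places, which is the step that is easy to get out of sync when editing the snippet by hand.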

Files changed (1)
  1. README.md +24 -16
README.md CHANGED
@@ -19,24 +19,31 @@ GRPO-based safety post-training for Stable Diffusion using a closed-form,
  CLIP-based **steering reward**. No separately trained safety classifier;
  no paired safe/unsafe image dataset.
 
- The repository contains three drop-in `StableDiffusionPipeline` variants
- (loadable via `subfolder=...`), each trained with a different anchor set.
+ The repository contains three full `StableDiffusionPipeline` variants
+ (each in its own subfolder), trained with different anchor sets.
 
- | Subfolder | Anchor set (safe + unsafe) | Notes |
+ | Subfolder | Anchors (safe + unsafe) | Notes |
  |---|---|---|
  | **`scaled`** | 25 + 20 | Main paper checkpoint (epoch 280). |
- | **`compact`** | 5 + 3 | Best MMA-Diffusion ASR (2.6%, epoch 300). |
+ | **`compact`** | 5 + 3 | Best MMA-Diffusion ASR (2.6 %, epoch 300). |
  | **`empty-positive`** | 0 + 3 | Ablation: only negative anchors. |
 
  ## Quick inference
 
  ```python
+ from huggingface_hub import snapshot_download
  from diffusers import StableDiffusionPipeline
- import torch
+ import os, torch
 
- pipe = StableDiffusionPipeline.from_pretrained(
+ # StableDiffusionPipeline.from_pretrained does not natively accept
+ # `subfolder=` for the FULL pipeline (only single components), so we
+ # snapshot the variant we want then load from the local path.
+ local_root = snapshot_download(
      "ItsMaxNorm/SafeDiffusion-R1",
-     subfolder="scaled", # or "compact" / "empty-positive"
+     allow_patterns="scaled/*", # or "compact/*" / "empty-positive/*"
+ )
+ pipe = StableDiffusionPipeline.from_pretrained(
+     os.path.join(local_root, "scaled"),
      torch_dtype=torch.float16,
  ).to("cuda")
 
@@ -46,13 +53,13 @@ img.save("out.png")
 
  ## Headline results (vs.\ SD-v1.4 baseline)
 
- | Benchmark | SD-v1.4 | SafeDiffusion-R1 (scaled) | Δ |
+ | Benchmark | SD-v1.4 | SafeDiffusion-R1 | Δ |
  |---|---|---|---|
- | I2P inappropriate-content rate | 48.9 % | **18.07 %** | −63 % |
- | NudeNet detections (I2P, 4 703 prompts) | 646 | **15** | **−97.7 %** |
- | GenEval compositional accuracy | 42.08 % | **47.83 %** | +5.75 pp |
- | MMA-Diffusion ASR (1 000-prompt benchmark) | 22.6 % | **2.6 %** (compact variant) | **8.7×** safer |
- | SneakyPrompt skip-rate (200 NSFW prompts) | 37 % | **89.5 %** | model resists most prompts before any attack |
+ | I2P inappropriate-content rate | 48.9 % | **18.07 %** (scaled) | −63 % |
+ | NudeNet detections (I2P, 4 703 prompts) | 646 | **15** (scaled) | **−97.7 %** |
+ | GenEval compositional accuracy | 42.08 % | **47.83 %** (scaled) | +5.75 pp |
+ | MMA-Diffusion ASR (1 000-prompt benchmark) | 22.6 % | **2.6 %** (compact) | **8.7×** safer |
+ | SneakyPrompt skip-rate (200 NSFW prompts) | 37 % | **89.5 %** (compact) | model resists most prompts before any attack |
 
  The safety gains generalise to **seven OOD harm categories** (hate,
  harassment, violence, self-harm, shocking, illegal-activity, sexual)
@@ -69,10 +76,11 @@ then nudges the UNet to satisfy this steered reward. Because `v_safe`
  is computed from a **frozen** CLIP encoder, the target is stationary —
  samples drift on-policy but the anchor they're regressed onto does not.
 
- ## Repository
+ ## Code, training, and evaluation
 
- Training code, evaluation scripts, ablation checkpoints, and the rebuttal
- results:
+ Training code, the steering reward, evaluation scripts (FID, CLIP-score,
+ NudeNet, Q16, LPIPS, style-loss), and the end-to-end eval wrapper that
+ works directly against this Hub release:
  **[https://github.com/MAXNORM8650/SafeDiffusion-R1](https://github.com/MAXNORM8650/SafeDiffusion-R1)**
 
  ## Citation
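
The Δ column of the headline-results table can be re-derived from the two value columns. The sketch below assumes how each Δ was computed, since the README does not say: a relative reduction for the I2P and NudeNet rows, a percentage-point difference for GenEval, and a baseline-to-model ratio for MMA-Diffusion. The helper name `relative_reduction` is hypothetical.

```python
def relative_reduction(baseline: float, value: float) -> float:
    """Percentage by which `value` is lower than `baseline`."""
    return 100.0 * (baseline - value) / baseline


# Inputs are the SD-v1.4 and SafeDiffusion-R1 columns of the table.
i2p_drop = round(relative_reduction(48.9, 18.07))     # 63   (table: −63 %)
nudenet_drop = round(relative_reduction(646, 15), 1)  # 97.7 (table: −97.7 %)
geneval_gain = round(47.83 - 42.08, 2)                # 5.75 (table: +5.75 pp)
mma_ratio = round(22.6 / 2.6, 1)                      # 8.7  (table: 8.7× safer)

print(i2p_drop, nudenet_drop, geneval_gain, mma_ratio)
```

Under these assumptions all four reported deltas are reproduced. The SneakyPrompt row has no numeric Δ in the table, so it is left out.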