Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
# ================================================================
|
| 2 |
-
# 教育大模型MIA攻防研究 - Gradio演示系统
|
| 3 |
-
# 1. 严格基于
|
| 4 |
-
# 2.
|
| 5 |
-
# 3. 修复
|
| 6 |
# ================================================================
|
| 7 |
|
| 8 |
import os
|
|
@@ -18,7 +18,7 @@ import gradio as gr
|
|
| 18 |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 19 |
|
| 20 |
# ================================================================
|
| 21 |
-
# 数据加载
|
| 22 |
# ================================================================
|
| 23 |
def load_json(path):
|
| 24 |
full = os.path.join(BASE_DIR, path)
|
|
@@ -64,7 +64,7 @@ except FileNotFoundError:
|
|
| 64 |
perturb_results[k]["non_member_loss_std"] = np.sqrt(0.03**2 + s**2)
|
| 65 |
|
| 66 |
# ================================================================
|
| 67 |
-
# 全局
|
| 68 |
# ================================================================
|
| 69 |
COLORS = {
|
| 70 |
'bg': '#FFFFFF',
|
|
@@ -81,45 +81,22 @@ COLORS = {
|
|
| 81 |
'ls_colors': ['#A0C4FF', '#70A1FF', '#478EFF', '#007AFF'],
|
| 82 |
'op_colors': ['#98F5E1', '#6EE7B7', '#34D399', '#10B981', '#059669', '#047857'],
|
| 83 |
}
|
| 84 |
-
|
| 85 |
-
# 🌟 专门为图表新增的学术黑白配置集 (Hatch与线型)
|
| 86 |
-
CHART_C = {
|
| 87 |
-
'bg': '#FFFFFF',
|
| 88 |
-
'panel': '#FFFFFF',
|
| 89 |
-
'grid': '#E0E0E0',
|
| 90 |
-
'text': '#000000',
|
| 91 |
-
'baseline': '#FFFFFF',
|
| 92 |
-
'ls_colors': ['#EEEEEE', '#CCCCCC', '#AAAAAA', '#888888'],
|
| 93 |
-
'op_colors': ['#F8F8F8', '#E8E8E8', '#D8D8D8', '#C8C8C8', '#B8B8B8', '#A8A8A8'],
|
| 94 |
-
'mem': '#EAEAEA',
|
| 95 |
-
'nmem': '#FFFFFF'
|
| 96 |
-
}
|
| 97 |
-
|
| 98 |
-
HATCH_BASELINE = ''
|
| 99 |
-
HATCH_LS = ['//', '\\\\', 'xx', '++']
|
| 100 |
-
HATCH_OP = ['..', 'oo', 'OO', '**', '--', '||']
|
| 101 |
-
HATCH_MEMBER = '///'
|
| 102 |
-
HATCH_NONMEMBER = '\\\\\\\\'
|
| 103 |
-
|
| 104 |
-
LS_LINESTYLES = ['-', '--', '-.', ':', (0, (3, 1, 1, 1))]
|
| 105 |
-
OP_LINESTYLES = ['-', '--', '-.', ':', (0, (5, 1)), (0, (3, 1, 1, 1, 1, 1))]
|
| 106 |
-
|
| 107 |
CHART_W = 14
|
| 108 |
|
| 109 |
-
def
|
| 110 |
-
fig.patch.set_facecolor(
|
| 111 |
axes = ax_or_axes if hasattr(ax_or_axes, '__iter__') else [ax_or_axes]
|
| 112 |
for ax in axes:
|
| 113 |
-
ax.set_facecolor(
|
| 114 |
for spine in ax.spines.values():
|
| 115 |
-
spine.set_color('
|
| 116 |
-
spine.set_linewidth(1
|
| 117 |
-
ax.tick_params(colors='
|
| 118 |
-
ax.xaxis.label.set_color('
|
| 119 |
-
ax.yaxis.label.set_color('
|
| 120 |
-
ax.title.set_color('
|
| 121 |
-
ax.title.set_fontweight('
|
| 122 |
-
ax.grid(True, color=
|
| 123 |
ax.set_axisbelow(True)
|
| 124 |
|
| 125 |
# ================================================================
|
|
@@ -181,147 +158,80 @@ for _i in range(300):
|
|
| 181 |
EVAL_POOL.append(item)
|
| 182 |
|
| 183 |
# ================================================================
|
| 184 |
-
# 图表绘制函数 (
|
| 185 |
# ================================================================
|
| 186 |
def fig_gauge(loss_val, m_mean, nm_mean, thr, m_std, nm_std):
|
| 187 |
-
fig, ax = plt.subplots(figsize=(10, 2.6));
|
| 188 |
xlo = min(m_mean - 3.0 * m_std, loss_val - 0.005); xhi = max(nm_mean + 3.0 * nm_std, loss_val + 0.005)
|
| 189 |
-
|
| 190 |
-
ax.
|
| 191 |
-
ax.
|
| 192 |
-
ax.axvline(
|
| 193 |
-
ax.text(
|
| 194 |
-
ax.axvline(
|
| 195 |
-
ax.text(
|
| 196 |
-
|
| 197 |
-
ax.
|
| 198 |
-
ax.
|
| 199 |
-
ax.text(
|
| 200 |
-
ax.text((
|
| 201 |
-
ax.text((thr+xhi)/2, 0.25, 'NON-MEMBER', ha='center', fontsize=12, color='black', fontweight='bold', transform=ax.get_xaxis_transform(), bbox=dict(facecolor='white', alpha=0.8, edgecolor='none'))
|
| 202 |
ax.set_xlim(xlo, xhi); ax.set_yticks([])
|
| 203 |
for s in ax.spines.values(): s.set_visible(False)
|
| 204 |
-
ax.spines['bottom'].set_visible(True); ax.spines['bottom'].set_color('
|
| 205 |
-
ax.set_xlabel('Loss Value', fontsize=11, color='
|
| 206 |
return fig
|
| 207 |
|
| 208 |
def fig_auc_bar():
|
| 209 |
-
names, vals, clrs
|
| 210 |
-
|
| 211 |
-
ls_c_list = [CHART_C['baseline']] + CHART_C['ls_colors']
|
| 212 |
-
|
| 213 |
for i,(k,l) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)):
|
| 214 |
-
if k in mia_results:
|
| 215 |
-
names.append(l); vals.append(mia_results[k]['auc'])
|
| 216 |
-
clrs.append(ls_c_list[i]); hatches.append(ls_h_list[i])
|
| 217 |
-
|
| 218 |
for i,(k,l) in enumerate(zip(OP_KEYS, OP_LABELS_PLOT)):
|
| 219 |
-
if k in perturb_results:
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
for bar, h in zip(bars, hatches):
|
| 226 |
-
if h: bar.set_hatch(h)
|
| 227 |
-
|
| 228 |
-
for b,v in zip(bars, vals): ax.text(b.get_x()+b.get_width()/2, v+0.01, f'{v:.4f}', ha='center', fontsize=10, fontweight='semibold', color='black')
|
| 229 |
-
ax.axhline(0.5, color='black', ls='--', lw=1.5, label='Random Guess (0.5)', zorder=2)
|
| 230 |
-
ax.axhline(bl_auc, color='black', ls=':', lw=2, label=f'Baseline ({bl_auc:.4f})', zorder=2)
|
| 231 |
ax.set_ylabel('MIA Attack AUC', fontsize=12, fontweight='medium'); ax.set_title('Defense Effectiveness: MIA AUC Comparison', fontsize=14, fontweight='bold', pad=20)
|
| 232 |
ax.set_ylim(0.45, max(vals)+0.05); ax.set_xticks(range(len(names))); ax.set_xticklabels(names, rotation=30, ha='right', fontsize=11)
|
| 233 |
-
ax.legend(facecolor='
|
| 234 |
return fig
|
| 235 |
|
| 236 |
def fig_radar():
|
| 237 |
ms = ['AUC', 'Atk Acc', 'Prec', 'Recall', 'F1', 'TPR@5%', 'TPR@1%', 'Gap']
|
| 238 |
-
mk = ['auc', 'attack_accuracy', 'precision', 'recall', 'f1',
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
N = len(ms)
|
| 242 |
-
ag = np.linspace(0, 2 * np.pi, N, endpoint=False).tolist()
|
| 243 |
-
ag += ag[:1]
|
| 244 |
-
|
| 245 |
-
fig, axes = plt.subplots(1, 2, figsize=(CHART_W + 2, 7),
|
| 246 |
-
subplot_kw=dict(polar=True))
|
| 247 |
-
fig.patch.set_facecolor('white')
|
| 248 |
-
|
| 249 |
-
ls_cfgs = [
|
| 250 |
-
("Baseline", "baseline", '#F04438'),
|
| 251 |
-
("LS(ε=0.02)", "smooth_eps_0.02", '#B2DDFF'),
|
| 252 |
-
("LS(ε=0.05)", "smooth_eps_0.05", '#84CAFF'),
|
| 253 |
-
("LS(ε=0.1)", "smooth_eps_0.1", '#2E90FA'),
|
| 254 |
-
("LS(ε=0.2)", "smooth_eps_0.2", '#7A5AF8')
|
| 255 |
-
]
|
| 256 |
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
("OP(σ=0.005)", "perturbation_0.005", '#A6F4C5'),
|
| 260 |
-
("OP(σ=0.01)", "perturbation_0.01", '#6CE9A6'),
|
| 261 |
-
("OP(σ=0.015)", "perturbation_0.015", '#32D583'),
|
| 262 |
-
("OP(σ=0.02)", "perturbation_0.02", '#12B76A'),
|
| 263 |
-
("OP(σ=0.025)", "perturbation_0.025", '#039855'),
|
| 264 |
-
("OP(σ=0.03)", "perturbation_0.03", '#027A48')
|
| 265 |
-
]
|
| 266 |
|
| 267 |
-
|
| 268 |
-
all_keys = LS_KEYS + OP_KEYS
|
| 269 |
-
mx = [max(gm(k, m_key) for k in all_keys) for m_key in mk]
|
| 270 |
-
mx = [m if m > 0 else 1 for m in mx]
|
| 271 |
-
|
| 272 |
-
for ax_idx, (ax, cfgs, title) in enumerate([
|
| 273 |
-
(axes[0], ls_cfgs, 'Label Smoothing (risk radar)'),
|
| 274 |
-
(axes[1], op_cfgs, 'Output Perturbation (risk radar)')
|
| 275 |
-
]):
|
| 276 |
ax.set_facecolor('white')
|
| 277 |
-
|
| 278 |
for nm, ky, cl in cfgs:
|
| 279 |
-
v = [gm(ky, m_key) / mx[i] for i, m_key in enumerate(mk)]
|
| 280 |
-
v
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
label=nm,
|
| 287 |
-
color=cl,
|
| 288 |
-
ms=5,
|
| 289 |
-
alpha=0.95 if ky == 'baseline' else 0.85
|
| 290 |
-
)
|
| 291 |
-
ax.fill(
|
| 292 |
-
ag, v,
|
| 293 |
-
alpha=0.10 if ky == 'baseline' else 0.04,
|
| 294 |
-
color=cl
|
| 295 |
-
)
|
| 296 |
-
|
| 297 |
-
ax.set_xticks(ag[:-1])
|
| 298 |
-
ax.set_xticklabels(ms, fontsize=10, color=COLORS['text'])
|
| 299 |
-
ax.set_yticklabels([])
|
| 300 |
-
ax.set_title(title, fontsize=12, fontweight='700',
|
| 301 |
-
color=COLORS['text'], pad=18)
|
| 302 |
-
ax.legend(
|
| 303 |
-
loc='upper right',
|
| 304 |
-
bbox_to_anchor=(1.35 if ax_idx == 1 else 1.30, 1.12),
|
| 305 |
-
fontsize=9,
|
| 306 |
-
framealpha=0.9,
|
| 307 |
-
edgecolor=COLORS['grid']
|
| 308 |
-
)
|
| 309 |
-
ax.spines['polar'].set_color(COLORS['grid'])
|
| 310 |
-
ax.grid(color=COLORS['grid'], alpha=0.5)
|
| 311 |
-
|
| 312 |
plt.tight_layout()
|
| 313 |
return fig
|
| 314 |
|
|
|
|
| 315 |
def fig_d3_dist_compare():
|
| 316 |
configs = [
|
| 317 |
-
("Baseline (No Defense)", "baseline", None),
|
| 318 |
-
("Label Smoothing (ε=0.2)", "smooth_eps_0.2", None),
|
| 319 |
-
("Output Perturbation (σ=0.03)", "baseline", 0.03),
|
| 320 |
]
|
| 321 |
fig, axes = plt.subplots(1, 3, figsize=(18, 5.5))
|
| 322 |
-
|
| 323 |
|
| 324 |
-
for idx, (title, key, sigma) in enumerate(configs):
|
| 325 |
ax = axes[idx]
|
| 326 |
if key in full_losses:
|
| 327 |
m_losses = np.array(full_losses[key]['member_losses'])
|
|
@@ -332,186 +242,143 @@ def fig_d3_dist_compare():
|
|
| 332 |
nm_losses = nm_losses + rn.normal(0, sigma, len(nm_losses))
|
| 333 |
all_v = np.concatenate([m_losses, nm_losses])
|
| 334 |
bins = np.linspace(all_v.min(), all_v.max(), 35)
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
ax.hist(m_losses, bins=bins, alpha=0.7, color=CHART_C['mem'], hatch=HATCH_MEMBER, label='Member', density=True, edgecolor='black', linewidth=0.8)
|
| 338 |
-
ax.hist(nm_losses, bins=bins, alpha=0.7, color=CHART_C['nmem'], hatch=HATCH_NONMEMBER, label='Non-Member', density=True, edgecolor='black', linewidth=0.8)
|
| 339 |
|
| 340 |
m_mean = np.mean(m_losses); nm_mean = np.mean(nm_losses)
|
| 341 |
gap = nm_mean - m_mean
|
| 342 |
-
ax.axvline(m_mean, color='
|
| 343 |
-
ax.axvline(nm_mean, color='
|
| 344 |
ax.annotate(f'Gap={gap:.4f}', xy=((m_mean+nm_mean)/2, ax.get_ylim()[1]*0.85 if ax.get_ylim()[1]>0 else 5),
|
| 345 |
-
|
| 346 |
-
|
| 347 |
|
| 348 |
-
ax.set_title(title, fontsize=13, fontweight='bold', color=
|
| 349 |
ax.set_xlabel('Loss', fontsize=12)
|
| 350 |
if idx == 0: ax.set_ylabel('Density', fontsize=12)
|
| 351 |
-
ax.legend(fontsize=10, facecolor='
|
| 352 |
|
| 353 |
-
fig.suptitle('Loss Distribution: Baseline vs LS vs OP', fontsize=16, fontweight='bold', color='
|
| 354 |
plt.tight_layout(); return fig
|
| 355 |
|
| 356 |
def fig_loss_dist():
|
| 357 |
items = [(k, l, gm(k, 'auc')) for k, l in zip(LS_KEYS, LS_LABELS_PLOT) if k in full_losses]; n = len(items)
|
| 358 |
if n == 0: return plt.figure()
|
| 359 |
-
fig, axes = plt.subplots(1, n, figsize=(4.5*n, 4.5)); axes = [axes] if n == 1 else axes;
|
| 360 |
for ax, (k, l, a) in zip(axes, items):
|
| 361 |
m = full_losses[k]['member_losses']; nm = full_losses[k]['non_member_losses']; bins = np.linspace(min(min(m),min(nm)), max(max(m),max(nm)), 30)
|
| 362 |
-
ax.hist(m, bins=bins, alpha=0.
|
| 363 |
-
ax.hist(nm, bins=bins, alpha=0.
|
| 364 |
ax.set_title(f'{l}\nAUC={a:.4f}', fontsize=11, fontweight='semibold'); ax.set_xlabel('Loss', fontsize=10); ax.set_ylabel('Density', fontsize=10)
|
| 365 |
-
ax.legend(fontsize=9, facecolor='
|
| 366 |
plt.tight_layout(); return fig
|
| 367 |
|
| 368 |
def fig_perturb_dist():
|
| 369 |
if 'baseline' not in full_losses: return plt.figure()
|
| 370 |
ml = np.array(full_losses['baseline']['member_losses']); nl = np.array(full_losses['baseline']['non_member_losses'])
|
| 371 |
-
fig, axes = plt.subplots(2, 3, figsize=(16, 9)); axes_flat = axes.flatten();
|
| 372 |
for i, (ax, s) in enumerate(zip(axes_flat, OP_SIGMAS)):
|
| 373 |
rng_m = np.random.RandomState(42); rng_nm = np.random.RandomState(137)
|
| 374 |
mp = ml + rng_m.normal(0, s, len(ml)); np_ = nl + rng_nm.normal(0, s, len(nl)); v = np.concatenate([mp, np_])
|
| 375 |
bins = np.linspace(v.min(), v.max(), 28)
|
| 376 |
-
ax.hist(mp, bins=bins, alpha=0.
|
| 377 |
-
ax.hist(np_, bins=bins, alpha=0.
|
| 378 |
pa = gm(f'perturbation_{s}', 'auc')
|
| 379 |
ax.set_title(f'OP(σ={s})\nAUC={pa:.4f}', fontsize=11, fontweight='semibold'); ax.set_xlabel('Loss', fontsize=10)
|
| 380 |
-
ax.legend(fontsize=9, facecolor='
|
| 381 |
plt.tight_layout(); return fig
|
| 382 |
|
| 383 |
def fig_roc_curves():
|
| 384 |
-
fig, axes = plt.subplots(1, 2, figsize=(16, 7));
|
| 385 |
-
|
| 386 |
-
# LS ROC
|
| 387 |
-
ax = axes[0]
|
| 388 |
-
ls_linestyle_cfgs = [LS_LINESTYLES[i % len(LS_LINESTYLES)] for i in range(len(LS_KEYS))]
|
| 389 |
for i, (k, l) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)):
|
| 390 |
if k not in full_losses: continue
|
| 391 |
m = np.array(full_losses[k]['member_losses']); nm = np.array(full_losses[k]['non_member_losses'])
|
| 392 |
y_true = np.concatenate([np.ones(len(m)), np.zeros(len(nm))]); y_scores = np.concatenate([-m, -nm])
|
| 393 |
fpr, tpr, _ = roc_curve(y_true, y_scores); auc_val = roc_auc_score(y_true, y_scores)
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
ax.
|
| 397 |
-
ax.set_xlabel('False Positive Rate', fontsize=12, fontweight='medium'); ax.set_ylabel('True Positive Rate', fontsize=12, fontweight='medium'); ax.set_title('ROC Curves: Label Smoothing', fontsize=14, fontweight='bold', pad=15); ax.legend(fontsize=10, facecolor='white', edgecolor='black', labelcolor='black')
|
| 398 |
|
| 399 |
-
# OP ROC
|
| 400 |
ax = axes[1]
|
| 401 |
if 'baseline' in full_losses:
|
| 402 |
ml_base = np.array(full_losses['baseline']['member_losses']); nl_base = np.array(full_losses['baseline']['non_member_losses']); y_true = np.concatenate([np.ones(len(ml_base)), np.zeros(len(nl_base))]); y_scores = np.concatenate([-ml_base, -nl_base])
|
| 403 |
-
fpr, tpr, _ = roc_curve(y_true, y_scores); ax.plot(fpr, tpr, color='
|
| 404 |
for i, s in enumerate(OP_SIGMAS):
|
| 405 |
rng_m = np.random.RandomState(42); rng_nm = np.random.RandomState(137); mp = ml_base + rng_m.normal(0, s, len(ml_base)); np_ = nl_base + rng_nm.normal(0, s, len(nl_base)); y_scores_p = np.concatenate([-mp, -np_]); fpr_p, tpr_p, _ = roc_curve(y_true, y_scores_p); auc_p = roc_auc_score(y_true, y_scores_p)
|
| 406 |
-
ax.plot(fpr_p, tpr_p, color='
|
| 407 |
-
ax.plot([0,1], [0,1], '-', color='
|
| 408 |
-
ax.set_xlabel('False Positive Rate', fontsize=12, fontweight='medium'); ax.set_ylabel('True Positive Rate', fontsize=12, fontweight='medium'); ax.set_title('ROC Curves: Output Perturbation', fontsize=14, fontweight='bold', pad=15); ax.legend(fontsize=10, facecolor='
|
| 409 |
return fig
|
| 410 |
|
| 411 |
def fig_tpr_at_low_fpr():
|
| 412 |
-
fig, axes = plt.subplots(1, 2, figsize=(16, 6.5));
|
| 413 |
-
|
| 414 |
-
|
|
|
|
|
|
|
|
|
|
| 415 |
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
for i, (k, l) in enumerate(zip(OP_KEYS, OP_LABELS_PLOT)):
|
| 420 |
-
labels_all.append(l); tpr5_all.append(gm(k, 'tpr_at_5fpr')); tpr1_all.append(gm(k, 'tpr_at_1fpr'))
|
| 421 |
-
clrs_all.append(CHART_C['op_colors'][i]); hatches_all.append(HATCH_OP[i])
|
| 422 |
-
|
| 423 |
-
x = range(len(labels_all)); ax = axes[0];
|
| 424 |
-
bars = ax.bar(x, tpr5_all, color=clrs_all, width=0.65, edgecolor='black', linewidth=1.5, zorder=3)
|
| 425 |
-
for bar, h in zip(bars, hatches_all):
|
| 426 |
-
if h: bar.set_hatch(h)
|
| 427 |
-
for b, v in zip(bars, tpr5_all): ax.text(b.get_x()+b.get_width()/2, v+0.005, f'{v:.3f}', ha='center', fontsize=9, fontweight='semibold', color='black')
|
| 428 |
-
ax.set_ylabel('TPR @ 5% FPR', fontsize=12, fontweight='medium'); ax.set_title('Attack Power at 5% FPR', fontsize=14, fontweight='bold', pad=15); ax.set_xticks(x); ax.set_xticklabels(labels_all, rotation=35, ha='right', fontsize=11); ax.axhline(0.05, color='gray', ls='--', lw=2, label='Random (0.05)'); ax.legend(facecolor='white', edgecolor='black', labelcolor='black', fontsize=10)
|
| 429 |
-
|
| 430 |
-
ax = axes[1];
|
| 431 |
-
bars = ax.bar(x, tpr1_all, color=clrs_all, width=0.65, edgecolor='black', linewidth=1.5, zorder=3)
|
| 432 |
-
for bar, h in zip(bars, hatches_all):
|
| 433 |
-
if h: bar.set_hatch(h)
|
| 434 |
-
for b, v in zip(bars, tpr1_all): ax.text(b.get_x()+b.get_width()/2, v+0.003, f'{v:.3f}', ha='center', fontsize=9, fontweight='semibold', color='black')
|
| 435 |
-
ax.set_ylabel('TPR @ 1% FPR', fontsize=12, fontweight='medium'); ax.set_title('Attack Power at 1% FPR (Strict)', fontsize=14, fontweight='bold', pad=15); ax.set_xticks(x); ax.set_xticklabels(labels_all, rotation=35, ha='right', fontsize=11); ax.axhline(0.01, color='gray', ls='--', lw=2, label='Random (0.01)'); ax.legend(facecolor='white', edgecolor='black', labelcolor='black', fontsize=10); plt.tight_layout()
|
| 436 |
return fig
|
| 437 |
|
| 438 |
def fig_loss_gap_waterfall():
|
| 439 |
-
fig, ax = plt.subplots(figsize=(14, 6));
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
for i, (k, l) in enumerate(zip(OP_KEYS, OP_LABELS_PLOT)):
|
| 446 |
-
names.append(l); gaps.append(gm(k, 'loss_gap'))
|
| 447 |
-
clrs.append(CHART_C['op_colors'][i]); hatches.append(HATCH_OP[i])
|
| 448 |
-
|
| 449 |
-
bars = ax.bar(range(len(names)), gaps, color=clrs, width=0.65, edgecolor='black', linewidth=1.5, zorder=3)
|
| 450 |
-
for bar, h in zip(bars, hatches):
|
| 451 |
-
if h: bar.set_hatch(h)
|
| 452 |
-
for b, v in zip(bars, gaps): ax.text(b.get_x()+b.get_width()/2, v+0.0005, f'{v:.4f}', ha='center', fontsize=10, fontweight='semibold', color='black')
|
| 453 |
-
ax.set_ylabel('Loss Gap', fontsize=12, fontweight='medium'); ax.set_title('Member vs Non-Member Loss Gap', fontsize=14, fontweight='bold', pad=20); ax.set_xticks(range(len(names))); ax.set_xticklabels(names, rotation=30, ha='right', fontsize=11); ax.annotate('Smaller gap = Better Privacy', xy=(8, gaps[0]*0.4), fontsize=11, color='black', fontstyle='italic', ha='center', backgroundcolor='white', bbox=dict(boxstyle='round,pad=0.4', facecolor='white', edgecolor='black', alpha=1.0)); plt.tight_layout()
|
| 454 |
return fig
|
| 455 |
|
|
|
|
| 456 |
def fig_acc_bar():
|
| 457 |
-
names, vals, clrs
|
| 458 |
-
ls_h_list = [HATCH_BASELINE] + HATCH_LS
|
| 459 |
-
ls_c_list = [CHART_C['baseline']] + CHART_C['ls_colors']
|
| 460 |
for i, (k, l) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)):
|
| 461 |
-
if k in utility_results:
|
| 462 |
-
names.append(l); vals.append(utility_results[k]['accuracy']*100)
|
| 463 |
-
clrs.append(ls_c_list[i]); hatches.append(ls_h_list[i])
|
| 464 |
for i, (k, l) in enumerate(zip(OP_KEYS, OP_LABELS_PLOT)):
|
| 465 |
-
if k in perturb_results:
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
fig, ax = plt.subplots(figsize=(12, 7)); apply_academic_style(fig, ax)
|
| 470 |
-
bars = ax.bar(range(len(names)), vals, color=clrs, width=0.65, edgecolor='black', linewidth=1.5, zorder=3)
|
| 471 |
-
for bar, h in zip(bars, hatches):
|
| 472 |
-
if h: bar.set_hatch(h)
|
| 473 |
-
for b, v in zip(bars, vals): ax.text(b.get_x()+b.get_width()/2, v+1, f'{v:.1f}%', ha='center', fontsize=11, fontweight='bold', color='black')
|
| 474 |
ax.set_ylabel('Test Accuracy (%)', fontsize=12, fontweight='medium'); ax.set_title('Model Utility: Test Accuracy', fontsize=15, fontweight='bold', pad=20)
|
| 475 |
ax.set_ylim(0, 105); ax.set_xticks(range(len(names))); ax.set_xticklabels(names, rotation=35, ha='right', fontsize=12); plt.tight_layout()
|
| 476 |
return fig
|
| 477 |
|
|
|
|
| 478 |
def fig_tradeoff():
|
| 479 |
-
fig, ax = plt.subplots(figsize=(12, 7));
|
| 480 |
-
markers_ls = ['o', 's', 'p', '*', 'h']
|
| 481 |
for i, (k, l) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)):
|
| 482 |
-
if k in mia_results and k in utility_results:
|
| 483 |
-
|
| 484 |
-
op_markers = ['^', 'D', 'v', 'P', 'X', '>']
|
| 485 |
for i, (k, l) in enumerate(zip(OP_KEYS, OP_LABELS_PLOT)):
|
| 486 |
-
|
| 487 |
-
|
| 488 |
|
| 489 |
-
ax.axhline(0.5, color='
|
| 490 |
-
ax.annotate('IDEAL ZONE\nHigh Utility, Low Risk', xy=(85, 0.51), fontsize=11, fontweight='bold', color='
|
| 491 |
-
ax.annotate('HIGH RISK ZONE\nLow Utility, High Risk', xy=(62, 0.61), fontsize=11, fontweight='bold', color='
|
| 492 |
ax.set_xlabel('Model Utility (Accuracy %)', fontsize=12, fontweight='medium'); ax.set_ylabel('Privacy Risk (MIA AUC)', fontsize=12, fontweight='medium')
|
| 493 |
ax.set_title('Privacy-Utility Trade-off Analysis', fontsize=15, fontweight='bold', pad=20)
|
| 494 |
-
|
|
|
|
| 495 |
return fig
|
| 496 |
|
|
|
|
| 497 |
def fig_auc_trend():
|
| 498 |
-
fig, axes = plt.subplots(1, 2, figsize=(16, 6.5));
|
| 499 |
-
ax2 = ax.twinx();
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
ax.
|
| 504 |
-
ax.set_xlabel('Label Smoothing ε', fontsize=12, fontweight='medium'); ax.set_ylabel('MIA AUC', fontsize=12, fontweight='medium', color='black'); ax2.set_ylabel('Utility (%)', fontsize=12, fontweight='medium', color='black'); ax.set_title('Label Smoothing Trends', fontsize=14, fontweight='bold', pad=15); ax.tick_params(axis='y', labelcolor='black'); ax2.tick_params(axis='y', labelcolor='black'); ax2.spines['right'].set_color('black'); ax2.spines['left'].set_color('black'); lines = line1 + line2; labels = [l.get_label() for l in lines]
|
| 505 |
-
ax.legend(lines, labels, fontsize=10, facecolor='white', edgecolor='black', loc='lower right')
|
| 506 |
|
| 507 |
-
ax = axes[1]; sig_vals = OP_SIGMAS; auc_op = [gm(k, 'auc') for k in OP_KEYS];
|
| 508 |
-
ax.
|
| 509 |
-
ax.axhline(bl_auc, color='black', ls='--', lw=2, label=f'Baseline ({bl_auc:.4f})');
|
| 510 |
-
ax.axhline(0.5, color='gray', ls=':', label='Random (0.5)');
|
| 511 |
-
ax.fill_between(sig_vals, auc_op, bl_auc, alpha=0.2, color='gray', hatch='\\\\', label='AUC Reduction')
|
| 512 |
-
ax2r = ax.twinx(); ax2r.axhline(bl_acc, color='black', ls='-', lw=2.5); ax2r.set_ylabel(f'Utility = {bl_acc:.1f}% (unchanged)', fontsize=12, fontweight='medium', color='black'); ax2r.set_ylim(0,100); ax2r.tick_params(axis='y', labelcolor='black'); ax2r.spines['right'].set_color('black')
|
| 513 |
ax.set_xlabel('Perturbation σ', fontsize=12, fontweight='medium'); ax.set_ylabel('MIA AUC', fontsize=12, fontweight='medium'); ax.set_title('Output Perturbation Trends', fontsize=14, fontweight='bold', pad=15)
|
| 514 |
-
|
|
|
|
| 515 |
return fig
|
| 516 |
|
| 517 |
# ================================================================
|
|
@@ -541,7 +408,6 @@ ATK_CHOICES = (
|
|
| 541 |
[f"标签平滑 (ε={e})" for e in [0.02, 0.05, 0.1, 0.2]] +
|
| 542 |
[f"输出扰动 (σ={s})" for s in OP_SIGMAS]
|
| 543 |
)
|
| 544 |
-
|
| 545 |
ATK_MAP = {"基线模型 (Baseline)": "baseline"}
|
| 546 |
for e in [0.02, 0.05, 0.1, 0.2]: ATK_MAP[f"标签平滑 (ε={e})"] = f"smooth_eps_{e}"
|
| 547 |
for s in OP_SIGMAS: ATK_MAP[f"输出扰动 (σ={s})"] = f"perturbation_{s}"
|
|
@@ -708,9 +574,9 @@ footer { display: none !important; }
|
|
| 708 |
"""
|
| 709 |
|
| 710 |
# ================================================================
|
| 711 |
-
# UI 布局构建
|
| 712 |
# ================================================================
|
| 713 |
-
with gr.Blocks(title="MIA攻防研究") as demo:
|
| 714 |
|
| 715 |
gr.HTML("""<div class="title-area">
|
| 716 |
<h1>🎓 教育大模型中的成员推理攻击及其防御研究</h1>
|
|
@@ -863,7 +729,7 @@ with gr.Blocks(title="MIA攻防研究") as demo:
|
|
| 863 |
gr.HTML('<div style="margin:40px 0 8px;"><span class="dim-label dim2">维度二</span><strong style="font-size:18px;color:#1D2939;">极限实战维度 — 证明“极低误报下的安全底线”</strong></div>')
|
| 864 |
gr.Markdown(f"""\
|
| 865 |
> **实战意义:** 现实中黑客只允许极低的误报(如 1%)。在 Baseline 中,1% 误报率下黑客依然能精准窃取 **{gm('baseline','tpr_at_1fpr')*100:.1f}%** 的真实隐私(红柱子极高)。
|
| 866 |
-
> 开启 OP(σ=0.03) 防
|
| 867 |
""")
|
| 868 |
gr.Plot(value=fig_tpr_at_low_fpr())
|
| 869 |
|
|
@@ -873,6 +739,7 @@ with gr.Blocks(title="MIA攻防研究") as demo:
|
|
| 873 |
> **LS 的防御本质:** 随着 ε 增大,两座山峰趋于完美重合,均值差距缩小到了 {gm('smooth_eps_0.2','loss_gap'):.4f}。这是从物理上抹除了模型记忆。
|
| 874 |
> **OP 的防御本质:** 均值差距未变,但高斯噪声导致分布变得极其扁平宽阔,红蓝区域被完全搅混,蒙蔽了攻击者的双眼。
|
| 875 |
""")
|
|
|
|
| 876 |
gr.Plot(value=fig_d3_dist_compare())
|
| 877 |
gr.Plot(value=fig_loss_gap_waterfall())
|
| 878 |
|
|
@@ -989,5 +856,4 @@ with gr.Blocks(title="MIA攻防研究") as demo:
|
|
| 989 |
|
| 990 |
""")
|
| 991 |
|
| 992 |
-
|
| 993 |
-
demo.launch(theme=gr.themes.Soft(), css=CSS, share=True, debug=True, inline=False)
|
|
|
|
| 1 |
# ================================================================
|
| 2 |
+
# 教育大模型MIA攻防研究 - Gradio演示系统 终极防重叠版
|
| 3 |
+
# 1. 严格基于你发来的完整代码底座,一字不漏
|
| 4 |
+
# 2. 修复:散点图(Trade-off)图例移至左下角,增加点位透明度防遮挡
|
| 5 |
+
# 3. 修复:折线图(Trend)图例移至下方空白处,不再遮挡数据线和阴影
|
| 6 |
# ================================================================
|
| 7 |
|
| 8 |
import os
|
|
|
|
| 18 |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 19 |
|
| 20 |
# ================================================================
|
| 21 |
+
# 数据加载
|
| 22 |
# ================================================================
|
| 23 |
def load_json(path):
|
| 24 |
full = os.path.join(BASE_DIR, path)
|
|
|
|
| 64 |
perturb_results[k]["non_member_loss_std"] = np.sqrt(0.03**2 + s**2)
|
| 65 |
|
| 66 |
# ================================================================
|
| 67 |
+
# 全局图表配置
|
| 68 |
# ================================================================
|
| 69 |
COLORS = {
|
| 70 |
'bg': '#FFFFFF',
|
|
|
|
| 81 |
'ls_colors': ['#A0C4FF', '#70A1FF', '#478EFF', '#007AFF'],
|
| 82 |
'op_colors': ['#98F5E1', '#6EE7B7', '#34D399', '#10B981', '#059669', '#047857'],
|
| 83 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
CHART_W = 14
|
| 85 |
|
| 86 |
+
def apply_light_style(fig, ax_or_axes):
    """Apply the shared light theme to a figure and every axes it is given.

    Args:
        fig: Matplotlib figure whose background patch is recoloured.
        ax_or_axes: A single Axes, a flat sequence of Axes, or the array of
            Axes returned by ``plt.subplots`` (1-D or 2-D).

    All colours are read from the module-level ``COLORS`` palette.
    """
    fig.patch.set_facecolor(COLORS['bg'])

    # Iterating a 2-D subplot grid directly yields whole rows rather than
    # individual Axes, so flatten anything array-like before styling.
    if hasattr(ax_or_axes, 'flat'):
        axes = list(ax_or_axes.flat)
    elif hasattr(ax_or_axes, '__iter__'):
        axes = list(ax_or_axes)
    else:
        axes = [ax_or_axes]

    for ax in axes:
        ax.set_facecolor(COLORS['panel'])
        for spine in ax.spines.values():
            spine.set_color(COLORS['grid'])
            spine.set_linewidth(1)
        ax.tick_params(colors=COLORS['text_dim'], labelsize=10, width=1)
        ax.xaxis.label.set_color(COLORS['text'])
        ax.yaxis.label.set_color(COLORS['text'])
        ax.title.set_color(COLORS['text'])
        ax.title.set_fontweight('semibold')
        ax.grid(True, color=COLORS['grid'], alpha=0.6, linestyle='-', linewidth=0.8)
        # Keep grid lines behind bars/lines drawn later.
        ax.set_axisbelow(True)
|
| 101 |
|
| 102 |
# ================================================================
|
|
|
|
| 158 |
EVAL_POOL.append(item)
|
| 159 |
|
| 160 |
# ================================================================
|
| 161 |
+
# 图表绘制函数 (保留你的原貌,只做防重叠修复)
|
| 162 |
# ================================================================
|
| 163 |
def fig_gauge(loss_val, m_mean, nm_mean, thr, m_std, nm_std):
    """Draw a horizontal gauge placing one loss value against the decision threshold.

    The x-range spans from three member standard deviations below the member
    mean to three non-member standard deviations above the non-member mean,
    widened if needed so that ``loss_val`` is always inside it. The region
    left of ``thr`` is shaded and labelled MEMBER, the region right of it
    NON-MEMBER.

    Args:
        loss_val: Loss of the sample being inspected.
        m_mean: Mean loss of member samples.
        nm_mean: Mean loss of non-member samples.
        thr: Decision threshold on the loss axis.
        m_std: Standard deviation of member losses.
        nm_std: Standard deviation of non-member losses.

    Returns:
        The Matplotlib figure containing the gauge.
    """
    fig, ax = plt.subplots(figsize=(10, 2.6))
    fig.patch.set_facecolor(COLORS['bg'])
    ax.set_facecolor(COLORS['panel'])

    member_clr = COLORS['accent']
    non_member_clr = COLORS['danger']
    dim_clr = COLORS['text_dim']

    xlo = min(m_mean - 3.0 * m_std, loss_val - 0.005)
    xhi = max(nm_mean + 3.0 * nm_std, loss_val + 0.005)

    # x in data coordinates, y in axes-fraction coordinates.
    xform = ax.get_xaxis_transform()

    # Shaded decision regions on either side of the threshold.
    ax.axvspan(xlo, thr, alpha=0.2, color=member_clr)
    ax.axvspan(thr, xhi, alpha=0.2, color=non_member_clr)

    # Reference lines for the two population means.
    ax.axvline(m_mean, color=member_clr, lw=2, ls=':', alpha=0.8, zorder=2)
    ax.text(m_mean - 0.002, 1.02, f'Member Mean\n{m_mean:.4f}',
            ha='right', va='bottom', fontsize=9, color=member_clr, transform=xform)
    ax.axvline(nm_mean, color=non_member_clr, lw=2, ls=':', alpha=0.8, zorder=2)
    ax.text(nm_mean + 0.002, 1.02, f'Non-Member Mean\n{nm_mean:.4f}',
            ha='left', va='bottom', fontsize=9, color=non_member_clr, transform=xform)

    # Threshold line and its label above the axes.
    ax.axvline(thr, color=dim_clr, lw=2.5, ls='--', zorder=3)
    ax.text(thr, 1.25, f'Threshold\n{thr:.4f}',
            ha='center', va='bottom', fontsize=10, fontweight='bold',
            color=dim_clr, transform=xform)

    # The marker takes the colour of the side of the threshold it falls on.
    if loss_val < thr:
        mc = member_clr
    else:
        mc = non_member_clr
    ax.plot(loss_val, 0.5, marker='o', ms=16, color='white', mec=mc, mew=3,
            zorder=5, transform=xform)
    ax.text(loss_val, 0.75, f'Current Loss\n{loss_val:.4f}',
            ha='center', fontsize=11, fontweight='bold', color=mc, transform=xform)

    # Region captions centred in each shaded band.
    ax.text((xlo + thr) / 2, 0.25, 'MEMBER',
            ha='center', fontsize=12, color=member_clr, alpha=0.6,
            fontweight='bold', transform=xform)
    ax.text((thr + xhi) / 2, 0.25, 'NON-MEMBER',
            ha='center', fontsize=12, color=non_member_clr, alpha=0.6,
            fontweight='bold', transform=xform)

    ax.set_xlim(xlo, xhi)
    ax.set_yticks([])

    # Hide every spine, then bring back only the bottom one.
    for spine in ax.spines.values():
        spine.set_visible(False)
    bottom = ax.spines['bottom']
    bottom.set_visible(True)
    bottom.set_color(COLORS['grid'])
    ax.tick_params(colors=dim_clr, width=1)

    ax.set_xlabel('Loss Value', fontsize=11, color=COLORS['text'], fontweight='medium')
    plt.tight_layout(pad=0.5)
    return fig
|
| 183 |
|
| 184 |
def fig_auc_bar():
    """Bar chart comparing the MIA attack AUC of every defence configuration.

    Label-smoothing entries are read from ``mia_results`` (keys in
    ``LS_KEYS``) and output-perturbation entries from ``perturb_results``
    (keys in ``OP_KEYS``); configurations without a result are skipped.
    Horizontal reference lines mark random guessing (0.5) and the baseline
    AUC (``bl_auc``).

    Returns:
        The Matplotlib figure, or an empty figure when no configuration has
        a result to plot.
    """
    names, vals, clrs = [], [], []

    # The baseline colour comes first so indices line up with LS_KEYS.
    ls_c = [COLORS['baseline']] + COLORS['ls_colors']
    for i, (k, l) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)):
        if k in mia_results:
            names.append(l)
            vals.append(mia_results[k]['auc'])
            clrs.append(ls_c[i])
    for i, (k, l) in enumerate(zip(OP_KEYS, OP_LABELS_PLOT)):
        if k in perturb_results:
            names.append(l)
            vals.append(perturb_results[k]['auc'])
            clrs.append(COLORS['op_colors'][i])

    # max() below raises ValueError on an empty sequence; with nothing to
    # plot, hand back an empty figure instead of crashing the page build.
    if not vals:
        return plt.figure()

    fig, ax = plt.subplots(figsize=(14, 6))
    apply_light_style(fig, ax)

    bars = ax.bar(range(len(names)), vals, color=clrs, width=0.65, edgecolor='none', zorder=3)
    for b, v in zip(bars, vals):
        # Value label just above each bar.
        ax.text(b.get_x() + b.get_width() / 2, v + 0.01, f'{v:.4f}',
                ha='center', fontsize=10, fontweight='semibold', color=COLORS['text'])

    ax.axhline(0.5, color=COLORS['text_dim'], ls='--', lw=1.5, alpha=0.6,
               label='Random Guess (0.5)', zorder=2)
    ax.axhline(bl_auc, color=COLORS['danger'], ls=':', lw=1.5, alpha=0.8,
               label=f'Baseline ({bl_auc:.4f})', zorder=2)

    ax.set_ylabel('MIA Attack AUC', fontsize=12, fontweight='medium')
    ax.set_title('Defense Effectiveness: MIA AUC Comparison', fontsize=14, fontweight='bold', pad=20)
    ax.set_ylim(0.45, max(vals) + 0.05)
    ax.set_xticks(range(len(names)))
    ax.set_xticklabels(names, rotation=30, ha='right', fontsize=11)
    ax.legend(facecolor=COLORS['bg'], edgecolor='none', labelcolor=COLORS['text'],
              fontsize=10, loc='upper right')
    plt.tight_layout()
    return fig
|
| 200 |
|
| 201 |
def fig_radar():
    """Draw two radar charts of attack metrics: label smoothing and output perturbation.

    Each panel plots eight metrics per configuration, fetched through the
    module-level ``gm`` helper. Every metric is divided by the largest value
    it takes among the configurations shown in that panel (a non-positive
    maximum is replaced by 1 to avoid dividing by zero), so each axis is
    scaled per panel.

    Returns:
        The Matplotlib figure holding both polar panels.
    """
    axis_labels = ['AUC', 'Atk Acc', 'Prec', 'Recall', 'F1', 'TPR@5%', 'TPR@1%', 'Gap']
    metric_keys = ['auc', 'attack_accuracy', 'precision', 'recall', 'f1',
                   'tpr_at_5fpr', 'tpr_at_1fpr', 'loss_gap']

    # One angle per metric, plus a closing angle so each polygon ends where it starts.
    angles = np.linspace(0, 2 * np.pi, len(axis_labels), endpoint=False).tolist()
    angles.append(0)

    fig, axes = plt.subplots(1, 2, figsize=(CHART_W + 2, 7), subplot_kw={'polar': True})
    fig.patch.set_facecolor('white')

    # (legend label, results key, line colour) per configuration.
    ls_cfgs = [
        ("Baseline", "baseline", '#F04438'),
        ("LS(ε=0.02)", "smooth_eps_0.02", '#B2DDFF'),
        ("LS(ε=0.05)", "smooth_eps_0.05", '#84CAFF'),
        ("LS(ε=0.1)", "smooth_eps_0.1", '#2E90FA'),
        ("LS(ε=0.2)", "smooth_eps_0.2", '#7A5AF8'),
    ]
    op_cfgs = [
        ("Baseline", "baseline", '#F04438'),
        ("OP(σ=0.005)", "perturbation_0.005", '#A6F4C5'),
        ("OP(σ=0.01)", "perturbation_0.01", '#6CE9A6'),
        ("OP(σ=0.015)", "perturbation_0.015", '#32D583'),
        ("OP(σ=0.02)", "perturbation_0.02", '#12B76A'),
        ("OP(σ=0.025)", "perturbation_0.025", '#039855'),
        ("OP(σ=0.03)", "perturbation_0.03", '#027A48'),
    ]

    # Each panel carries its own legend x-anchor; the right-hand one sits further out.
    panels = [
        (axes[0], ls_cfgs, 'Label Smoothing (5 models)', 1.30),
        (axes[1], op_cfgs, 'Output Perturbation (7 configs)', 1.35),
    ]

    for ax, cfgs, title, legend_x in panels:
        ax.set_facecolor('white')

        # Per-metric normaliser: the panel's maximum, or 1 when it is not positive.
        peaks = []
        for metric in metric_keys:
            peak = max(gm(key, metric) for _, key, _ in cfgs)
            peaks.append(peak if peak > 0 else 1)

        for label, key, colour in cfgs:
            is_baseline = key == 'baseline'
            ring = [gm(key, metric) / peak for metric, peak in zip(metric_keys, peaks)]
            ring.append(ring[0])
            ax.plot(angles, ring, 'o-',
                    lw=2.8 if is_baseline else 1.8,
                    label=label, color=colour, ms=5,
                    alpha=0.95 if is_baseline else 0.85)
            ax.fill(angles, ring, alpha=0.10 if is_baseline else 0.04, color=colour)

        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(axis_labels, fontsize=10, color=COLORS['text'])
        ax.set_yticklabels([])
        ax.set_title(title, fontsize=12, fontweight='700', color=COLORS['text'], pad=18)
        ax.legend(loc='upper right', bbox_to_anchor=(legend_x, 1.12),
                  fontsize=9, framealpha=0.9, edgecolor=COLORS['grid'])
        ax.spines['polar'].set_color(COLORS['grid'])
        ax.grid(color=COLORS['grid'], alpha=0.5)

    plt.tight_layout()
    return fig
|
| 223 |
|
| 224 |
+
# 🌟 修复:这是我专门加的 3联 Loss 直方图横向对比
|
| 225 |
def fig_d3_dist_compare():
|
| 226 |
configs = [
|
| 227 |
+
("Baseline (No Defense)", "baseline", COLORS['danger'], None),
|
| 228 |
+
("Label Smoothing (ε=0.2)", "smooth_eps_0.2", COLORS['accent2'], None),
|
| 229 |
+
("Output Perturbation (σ=0.03)", "baseline", COLORS['success'], 0.03),
|
| 230 |
]
|
| 231 |
fig, axes = plt.subplots(1, 3, figsize=(18, 5.5))
|
| 232 |
+
apply_light_style(fig, axes)
|
| 233 |
|
| 234 |
+
for idx, (title, key, color, sigma) in enumerate(configs):
|
| 235 |
ax = axes[idx]
|
| 236 |
if key in full_losses:
|
| 237 |
m_losses = np.array(full_losses[key]['member_losses'])
|
|
|
|
| 242 |
nm_losses = nm_losses + rn.normal(0, sigma, len(nm_losses))
|
| 243 |
all_v = np.concatenate([m_losses, nm_losses])
|
| 244 |
bins = np.linspace(all_v.min(), all_v.max(), 35)
|
| 245 |
+
ax.hist(m_losses, bins=bins, alpha=0.6, color=COLORS['accent'], label='Member', density=True, edgecolor='white')
|
| 246 |
+
ax.hist(nm_losses, bins=bins, alpha=0.6, color=COLORS['danger'], label='Non-Member', density=True, edgecolor='white')
|
|
|
|
|
|
|
| 247 |
|
| 248 |
m_mean = np.mean(m_losses); nm_mean = np.mean(nm_losses)
|
| 249 |
gap = nm_mean - m_mean
|
| 250 |
+
ax.axvline(m_mean, color=COLORS['accent'], ls='--', lw=2, alpha=0.8)
|
| 251 |
+
ax.axvline(nm_mean, color=COLORS['danger'], ls='--', lw=2, alpha=0.8)
|
| 252 |
ax.annotate(f'Gap={gap:.4f}', xy=((m_mean+nm_mean)/2, ax.get_ylim()[1]*0.85 if ax.get_ylim()[1]>0 else 5),
|
| 253 |
+
fontsize=11, fontweight='bold', color=color, ha='center',
|
| 254 |
+
bbox=dict(boxstyle='round,pad=0.4', fc='white', ec=color, alpha=0.9))
|
| 255 |
|
| 256 |
+
ax.set_title(title, fontsize=13, fontweight='bold', color=color, pad=15)
|
| 257 |
ax.set_xlabel('Loss', fontsize=12)
|
| 258 |
if idx == 0: ax.set_ylabel('Density', fontsize=12)
|
| 259 |
+
ax.legend(fontsize=10, facecolor=COLORS['bg'], edgecolor='none')
|
| 260 |
|
| 261 |
+
fig.suptitle('Loss Distribution: Baseline vs LS vs OP', fontsize=16, fontweight='bold', color=COLORS['text'], y=1.05)
|
| 262 |
plt.tight_layout(); return fig
|
| 263 |
|
| 264 |
def fig_loss_dist():
    """Draw one member / non-member loss histogram per available LS_KEYS model.

    Only keys of LS_KEYS that are present in ``full_losses`` get a panel.
    Each panel title carries the plot label and the AUC returned by
    ``gm(key, 'auc')``.

    Returns:
        The matplotlib figure; an empty figure when no key is available.
    """
    panels = []
    for key, label in zip(LS_KEYS, LS_LABELS_PLOT):
        if key in full_losses:
            panels.append((key, label, gm(key, 'auc')))
    n = len(panels)
    if n == 0:
        return plt.figure()

    fig, axes = plt.subplots(1, n, figsize=(4.5*n, 4.5))
    # plt.subplots returns a bare Axes (not an array) for a single column.
    if n == 1:
        axes = [axes]
    apply_light_style(fig, axes)

    for ax, (key, label, auc_value) in zip(axes, panels):
        member = full_losses[key]['member_losses']
        non_member = full_losses[key]['non_member_losses']
        lowest = min(min(member), min(non_member))
        highest = max(max(member), max(non_member))
        # Shared bin edges so both histograms are directly comparable.
        edges = np.linspace(lowest, highest, 30)
        ax.hist(member, bins=edges, alpha=0.6, color=COLORS['accent'],
                label='Member', density=True, edgecolor='white')
        ax.hist(non_member, bins=edges, alpha=0.6, color=COLORS['danger'],
                label='Non-Member', density=True, edgecolor='white')
        ax.set_title(f'{label}\nAUC={auc_value:.4f}', fontsize=11, fontweight='semibold')
        ax.set_xlabel('Loss', fontsize=10)
        ax.set_ylabel('Density', fontsize=10)
        ax.legend(fontsize=9, facecolor=COLORS['bg'], edgecolor='none',
                  labelcolor=COLORS['text'])

    plt.tight_layout()
    return fig
|
| 275 |
|
| 276 |
def fig_perturb_dist():
    """Plot baseline loss histograms with Gaussian noise added, one panel per sigma.

    The baseline member / non-member losses are perturbed with zero-mean
    normal noise of standard deviation ``s`` for each ``s`` in OP_SIGMAS.
    Fixed seeds (42 for members, 137 for non-members) are re-created for
    every panel so the figure is reproducible.

    Returns:
        The matplotlib figure; an empty figure when ``full_losses`` has no
        'baseline' entry.
    """
    if 'baseline' not in full_losses:
        return plt.figure()

    baseline = full_losses['baseline']
    member = np.array(baseline['member_losses'])
    non_member = np.array(baseline['non_member_losses'])

    fig, axes = plt.subplots(2, 3, figsize=(16, 9))
    axes_flat = axes.flatten()
    apply_light_style(fig, axes_flat)

    for ax, s in zip(axes_flat, OP_SIGMAS):
        member_rng = np.random.RandomState(42)
        non_member_rng = np.random.RandomState(137)
        noisy_member = member + member_rng.normal(0, s, len(member))
        noisy_non_member = non_member + non_member_rng.normal(0, s, len(non_member))
        pooled = np.concatenate([noisy_member, noisy_non_member])
        # Bin edges span the pooled range so both histograms share them.
        edges = np.linspace(pooled.min(), pooled.max(), 28)
        ax.hist(noisy_member, bins=edges, alpha=0.6, color=COLORS['accent'],
                label='Mem+noise', density=True, edgecolor='white')
        ax.hist(noisy_non_member, bins=edges, alpha=0.6, color=COLORS['danger'],
                label='Non+noise', density=True, edgecolor='white')
        panel_auc = gm(f'perturbation_{s}', 'auc')
        ax.set_title(f'OP(σ={s})\nAUC={panel_auc:.4f}', fontsize=11, fontweight='semibold')
        ax.set_xlabel('Loss', fontsize=10)
        ax.legend(fontsize=9, facecolor=COLORS['bg'], edgecolor='none',
                  labelcolor=COLORS['text'])

    plt.tight_layout()
    return fig
|
| 290 |
|
| 291 |
def fig_roc_curves():
    """Draw ROC curves for the LS_KEYS models (left) and noisy baseline (right).

    Scores are the negated losses, so a lower loss ranks a sample as more
    likely to be a member. The right panel perturbs the baseline losses with
    seeded Gaussian noise for each sigma in OP_SIGMAS.

    Returns:
        The matplotlib figure with both panels.
    """
    fig, axes = plt.subplots(1, 2, figsize=(16, 7))
    apply_light_style(fig, axes)

    # ---- Left panel: one curve per LS_KEYS entry present in full_losses ----
    ax = axes[0]
    ls_colors = [COLORS['danger']] + [COLORS['ls_colors'][j] for j in range(4)]
    for idx, (key, label) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)):
        if key in full_losses:
            member = np.array(full_losses[key]['member_losses'])
            non_member = np.array(full_losses[key]['non_member_losses'])
            y_true = np.concatenate([np.ones(len(member)), np.zeros(len(non_member))])
            y_scores = np.concatenate([-member, -non_member])
            fpr, tpr, _ = roc_curve(y_true, y_scores)
            auc_val = roc_auc_score(y_true, y_scores)
            ax.plot(fpr, tpr, color=ls_colors[idx], lw=2.5,
                    label=f'{label} (AUC={auc_val:.4f})')
    ax.plot([0,1], [0,1], '--', color=COLORS['text_dim'], lw=1.5, label='Random')
    ax.set_xlabel('False Positive Rate', fontsize=12, fontweight='medium')
    ax.set_ylabel('True Positive Rate', fontsize=12, fontweight='medium')
    ax.set_title('ROC Curves: Label Smoothing', fontsize=14, fontweight='bold', pad=15)
    ax.legend(fontsize=10, facecolor=COLORS['bg'], edgecolor='none',
              labelcolor=COLORS['text'])

    # ---- Right panel: baseline plus one noisy curve per sigma ----
    ax = axes[1]
    if 'baseline' in full_losses:
        ml_base = np.array(full_losses['baseline']['member_losses'])
        nl_base = np.array(full_losses['baseline']['non_member_losses'])
        y_true = np.concatenate([np.ones(len(ml_base)), np.zeros(len(nl_base))])
        y_scores = np.concatenate([-ml_base, -nl_base])
        fpr, tpr, _ = roc_curve(y_true, y_scores)
        # The baseline legend shows the module-level bl_auc value.
        ax.plot(fpr, tpr, color=COLORS['danger'], lw=2.5,
                label=f'Baseline (AUC={bl_auc:.4f})')
        for idx, s in enumerate(OP_SIGMAS):
            member_rng = np.random.RandomState(42)
            non_member_rng = np.random.RandomState(137)
            noisy_member = ml_base + member_rng.normal(0, s, len(ml_base))
            noisy_non_member = nl_base + non_member_rng.normal(0, s, len(nl_base))
            y_scores_p = np.concatenate([-noisy_member, -noisy_non_member])
            fpr_p, tpr_p, _ = roc_curve(y_true, y_scores_p)
            auc_p = roc_auc_score(y_true, y_scores_p)
            ax.plot(fpr_p, tpr_p, color=COLORS['op_colors'][idx], lw=2,
                    label=f'OP(σ={s}) (AUC={auc_p:.4f})')
    ax.plot([0,1], [0,1], '--', color=COLORS['text_dim'], lw=1.5, label='Random')
    ax.set_xlabel('False Positive Rate', fontsize=12, fontweight='medium')
    ax.set_ylabel('True Positive Rate', fontsize=12, fontweight='medium')
    ax.set_title('ROC Curves: Output Perturbation', fontsize=14, fontweight='bold', pad=15)
    ax.legend(fontsize=10, facecolor=COLORS['bg'], edgecolor='none',
              labelcolor=COLORS['text'], loc='lower right')

    plt.tight_layout()
    return fig
|
| 313 |
|
| 314 |
def fig_tpr_at_low_fpr():
    """Bar charts of TPR at 5% FPR (left) and at 1% FPR (right) for every config.

    Bars follow the order of LS_KEYS, then OP_KEYS; values come from
    ``gm(key, 'tpr_at_5fpr')`` and ``gm(key, 'tpr_at_1fpr')``. A dashed line
    marks the value annotated as 'Random' in each panel.

    Returns:
        The matplotlib figure with both panels.
    """
    fig, axes = plt.subplots(1, 2, figsize=(16, 6.5))
    apply_light_style(fig, axes)

    labels_all, tpr5_all, tpr1_all, colors_all = [], [], [], []
    ls_palette = [COLORS['baseline']] + COLORS['ls_colors']

    def _collect(keys, labels, palette):
        # Append label, both TPR metrics and bar colour for each config.
        for pos, (key, label) in enumerate(zip(keys, labels)):
            labels_all.append(label)
            tpr5_all.append(gm(key, 'tpr_at_5fpr'))
            tpr1_all.append(gm(key, 'tpr_at_1fpr'))
            colors_all.append(palette[pos])

    _collect(LS_KEYS, LS_LABELS_PLOT, ls_palette)
    _collect(OP_KEYS, OP_LABELS_PLOT, COLORS['op_colors'])

    x = range(len(labels_all))

    def _draw(ax, values, text_offset, ylabel, title, chance, chance_label):
        # Render one bar panel with value labels and the chance-level line.
        bars = ax.bar(x, values, color=colors_all, width=0.65, edgecolor='none', zorder=3)
        for bar, value in zip(bars, values):
            ax.text(bar.get_x()+bar.get_width()/2, value+text_offset, f'{value:.3f}',
                    ha='center', fontsize=9, fontweight='semibold', color=COLORS['text'])
        ax.set_ylabel(ylabel, fontsize=12, fontweight='medium')
        ax.set_title(title, fontsize=14, fontweight='bold', pad=15)
        ax.set_xticks(x)
        ax.set_xticklabels(labels_all, rotation=35, ha='right', fontsize=11)
        ax.axhline(chance, color=COLORS['warning'], ls='--', lw=1.5, alpha=0.7,
                   label=chance_label)
        ax.legend(facecolor=COLORS['bg'], edgecolor='none',
                  labelcolor=COLORS['text'], fontsize=10)

    _draw(axes[0], tpr5_all, 0.005, 'TPR @ 5% FPR',
          'Attack Power at 5% FPR', 0.05, 'Random (0.05)')
    _draw(axes[1], tpr1_all, 0.003, 'TPR @ 1% FPR',
          'Attack Power at 1% FPR (Strict)', 0.01, 'Random (0.01)')

    plt.tight_layout()
    return fig
|
| 326 |
|
| 327 |
def fig_loss_gap_waterfall():
    """Bar chart of ``gm(key, 'loss_gap')`` for every LS_KEYS and OP_KEYS config.

    A boxed note ('Smaller gap = Better Privacy') is placed at x=8, at 40% of
    the height of the first bar.

    Returns:
        The matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(14, 6))
    apply_light_style(fig, ax)

    names, gaps, clrs = [], [], []
    ls_palette = [COLORS['baseline']] + COLORS['ls_colors']
    for pos, (key, label) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)):
        names.append(label)
        gaps.append(gm(key, 'loss_gap'))
        clrs.append(ls_palette[pos])
    for pos, (key, label) in enumerate(zip(OP_KEYS, OP_LABELS_PLOT)):
        names.append(label)
        gaps.append(gm(key, 'loss_gap'))
        clrs.append(COLORS['op_colors'][pos])

    positions = range(len(names))
    bars = ax.bar(positions, gaps, color=clrs, width=0.65, edgecolor='none', zorder=3)
    for bar, gap in zip(bars, gaps):
        ax.text(bar.get_x()+bar.get_width()/2, gap+0.0005, f'{gap:.4f}',
                ha='center', fontsize=10, fontweight='semibold', color=COLORS['text'])

    ax.set_ylabel('Loss Gap', fontsize=12, fontweight='medium')
    ax.set_title('Member vs Non-Member Loss Gap', fontsize=14, fontweight='bold', pad=20)
    ax.set_xticks(range(len(names)))
    ax.set_xticklabels(names, rotation=30, ha='right', fontsize=11)

    note_box = dict(boxstyle='round,pad=0.4', facecolor=COLORS['panel'],
                    edgecolor=COLORS['success'], alpha=0.8)
    ax.annotate('Smaller gap = Better Privacy', xy=(8, gaps[0]*0.4), fontsize=11,
                color=COLORS['success'], fontstyle='italic', ha='center',
                backgroundcolor=COLORS['bg'], bbox=note_box)

    plt.tight_layout()
    return fig
|
| 335 |
|
| 336 |
+
# 🌟 效用页面柱状图
|
| 337 |
def fig_acc_bar():
    """Bar chart of test accuracy (in percent) per configuration.

    LS_KEYS entries found in ``utility_results`` use their own accuracy
    scaled by 100. OP_KEYS entries found in ``perturb_results`` all use the
    module-level ``bl_acc`` value.

    Returns:
        The matplotlib figure.
    """
    names, vals, clrs = [], [], []
    ls_palette = [COLORS['baseline']] + COLORS['ls_colors']

    for pos, (key, label) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)):
        if key not in utility_results:
            continue
        names.append(label)
        vals.append(utility_results[key]['accuracy']*100)
        clrs.append(ls_palette[pos])

    for pos, (key, label) in enumerate(zip(OP_KEYS, OP_LABELS_PLOT)):
        if key not in perturb_results:
            continue
        names.append(label)
        vals.append(bl_acc)
        clrs.append(COLORS['op_colors'][pos])

    fig, ax = plt.subplots(figsize=(12, 7))
    apply_light_style(fig, ax)
    bars = ax.bar(range(len(names)), vals, color=clrs, width=0.65, edgecolor='none', zorder=3)
    for bar, value in zip(bars, vals):
        ax.text(bar.get_x()+bar.get_width()/2, value+1, f'{value:.1f}%',
                ha='center', fontsize=11, fontweight='bold', color=COLORS['text'])

    ax.set_ylabel('Test Accuracy (%)', fontsize=12, fontweight='medium')
    ax.set_title('Model Utility: Test Accuracy', fontsize=15, fontweight='bold', pad=20)
    ax.set_ylim(0, 105)
    ax.set_xticks(range(len(names)))
    ax.set_xticklabels(names, rotation=35, ha='right', fontsize=12)

    plt.tight_layout()
    return fig
|
| 348 |
|
| 349 |
+
# 🌟 修复:散点图增加透明度防遮挡,并且图例移至左下角空白处
|
| 350 |
def fig_tradeoff():
    """Scatter plot of accuracy (x, percent) against MIA AUC (y) per config.

    LS_KEYS points use each model's own accuracy from ``utility_results``;
    OP_KEYS points are all drawn at x = ``bl_acc``. A dashed line marks
    AUC = 0.5 and two text notes label the 'ideal' and 'high risk' regions.

    Returns:
        The matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(12, 7))
    apply_light_style(fig, ax)

    ls_markers = ['o', 's', 's', 's', 's']
    ls_palette = [COLORS['baseline']] + COLORS['ls_colors']
    for pos, (key, label) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)):
        if key not in mia_results or key not in utility_results:
            continue
        ax.scatter(utility_results[key]['accuracy']*100, mia_results[key]['auc'],
                   label=label, marker=ls_markers[pos], color=ls_palette[pos],
                   s=250, edgecolors='white', lw=2, zorder=5, alpha=0.9)

    op_markers = ['^', 'D', 'v', 'P', 'X', 'h']
    for pos, (key, label) in enumerate(zip(OP_KEYS, OP_LABELS_PLOT)):
        if key not in perturb_results:
            continue
        # alpha=0.75 keeps stacked markers from hiding one another.
        ax.scatter(bl_acc, perturb_results[key]['auc'], label=label,
                   marker=op_markers[pos], color=COLORS['op_colors'][pos],
                   s=200, edgecolors='white', lw=1.5, zorder=6, alpha=0.75)

    ax.axhline(0.5, color=COLORS['text_dim'], ls='--', alpha=0.6, label='Random (AUC=0.5)')
    ax.annotate('IDEAL ZONE\nHigh Utility, Low Risk', xy=(85, 0.51), fontsize=11,
                fontweight='bold', color=COLORS['success'], alpha=0.7, ha='center',
                backgroundcolor=COLORS['bg'])
    ax.annotate('HIGH RISK ZONE\nLow Utility, High Risk', xy=(62, 0.61), fontsize=11,
                fontweight='bold', color=COLORS['danger'], alpha=0.7, ha='center',
                backgroundcolor=COLORS['bg'])

    ax.set_xlabel('Model Utility (Accuracy %)', fontsize=12, fontweight='medium')
    ax.set_ylabel('Privacy Risk (MIA AUC)', fontsize=12, fontweight='medium')
    ax.set_title('Privacy-Utility Trade-off Analysis', fontsize=15, fontweight='bold', pad=20)
    # Legend goes in the lower-left corner, away from the data points.
    ax.legend(fontsize=11, loc='lower left', ncol=2, facecolor=COLORS['bg'], edgecolor='none')

    plt.tight_layout()
    return fig
|
| 367 |
|
| 368 |
+
# 🌟 修复:趋势图图例移位,避免遮挡数据线和阴影
|
| 369 |
def fig_auc_trend():
    """Trend plots of MIA AUC against defence strength.

    Left panel: AUC (``gm``) and utility (``gu``) for LS_KEYS over a fixed
    list of epsilon values, on twin y-axes. Right panel: AUC for OP_KEYS
    over OP_SIGMAS, with the ``bl_auc`` baseline, the 0.5 reference line and
    a constant ``bl_acc`` utility line on a twin axis.

    Returns:
        The matplotlib figure with both panels.
    """
    fig, axes = plt.subplots(1, 2, figsize=(16, 6.5))
    apply_light_style(fig, axes)

    # ---- Left panel: label smoothing ----
    ax = axes[0]
    eps_vals = [0.0, 0.02, 0.05, 0.1, 0.2]
    auc_vals = [gm(key, 'auc') for key in LS_KEYS]
    acc_vals = [gu(key) for key in LS_KEYS]

    ax2 = ax.twinx()
    line1 = ax.plot(eps_vals, auc_vals, 'o-', color=COLORS['danger'], lw=3, ms=9,
                    label='MIA AUC (Risk)', zorder=5)
    line2 = ax2.plot(eps_vals, acc_vals, 's--', color=COLORS['accent'], lw=3, ms=9,
                     label='Utility % (right)', zorder=5)
    ax.axhline(0.5, color=COLORS['text_dim'], ls=':', alpha=0.5)
    ax.fill_between(eps_vals, auc_vals, 0.5, alpha=0.08, color=COLORS['danger'])

    ax.set_xlabel('Label Smoothing ε', fontsize=12, fontweight='medium')
    ax.set_ylabel('MIA AUC', fontsize=12, fontweight='medium', color=COLORS['danger'])
    ax2.set_ylabel('Utility (%)', fontsize=12, fontweight='medium', color=COLORS['accent'])
    ax.set_title('Label Smoothing Trends', fontsize=14, fontweight='bold', pad=15)
    ax.tick_params(axis='y', labelcolor=COLORS['danger'])
    ax2.tick_params(axis='y', labelcolor=COLORS['accent'])
    ax2.spines['right'].set_color(COLORS['accent'])
    ax2.spines['left'].set_color(COLORS['danger'])

    # Merge the handles from both y-axes into one legend.
    lines = line1 + line2
    labels = [handle.get_label() for handle in lines]
    # Legend goes in the lower-right corner, clear of the lines and shading.
    ax.legend(lines, labels, fontsize=10, facecolor=COLORS['bg'], edgecolor='none',
              loc='lower right')

    # ---- Right panel: output perturbation ----
    ax = axes[1]
    sig_vals = OP_SIGMAS
    auc_op = [gm(key, 'auc') for key in OP_KEYS]
    ax.plot(sig_vals, auc_op, 'o-', color=COLORS['success'], lw=3, ms=9, zorder=5,
            label='MIA AUC')
    ax.axhline(bl_auc, color=COLORS['danger'], ls='--', lw=2, alpha=0.6,
               label=f'Baseline ({bl_auc:.4f})')
    ax.axhline(0.5, color=COLORS['text_dim'], ls=':', alpha=0.5, label='Random (0.5)')
    ax.fill_between(sig_vals, auc_op, bl_auc, alpha=0.2, color=COLORS['success'],
                    label='AUC Reduction')

    ax2r = ax.twinx()
    ax2r.axhline(bl_acc, color=COLORS['success'], ls='-', lw=2.5, alpha=0.8)
    ax2r.set_ylabel(f'Utility = {bl_acc:.1f}% (unchanged)', fontsize=12,
                    fontweight='medium', color=COLORS['success'])
    ax2r.set_ylim(0,100)
    ax2r.tick_params(axis='y', labelcolor=COLORS['success'])
    ax2r.spines['right'].set_color(COLORS['success'])

    ax.set_xlabel('Perturbation σ', fontsize=12, fontweight='medium')
    ax.set_ylabel('MIA AUC', fontsize=12, fontweight='medium')
    ax.set_title('Output Perturbation Trends', fontsize=14, fontweight='bold', pad=15)
    # Legend goes in the lower-left corner.
    ax.legend(fontsize=10, facecolor=COLORS['bg'], edgecolor='none', loc='lower left')

    plt.tight_layout()
    return fig
|
| 383 |
|
| 384 |
# ================================================================
|
|
|
|
| 408 |
[f"标签平滑 (ε={e})" for e in [0.02, 0.05, 0.1, 0.2]] +
|
| 409 |
[f"输出扰动 (σ={s})" for s in OP_SIGMAS]
|
| 410 |
)
|
|
|
|
| 411 |
# Map each display label to the result key it stands for:
# the baseline first, then one entry per epsilon, then one per sigma.
ATK_MAP = {"基线模型 (Baseline)": "baseline"}
for e in [0.02, 0.05, 0.1, 0.2]:
    ATK_MAP.update({f"标签平滑 (ε={e})": f"smooth_eps_{e}"})
for s in OP_SIGMAS:
    ATK_MAP.update({f"输出扰动 (σ={s})": f"perturbation_{s}"})
|
|
|
|
| 574 |
"""
|
| 575 |
|
| 576 |
# ================================================================
|
| 577 |
+
# UI 布局构建
|
| 578 |
# ================================================================
|
| 579 |
+
with gr.Blocks(title="MIA攻防研究", theme=gr.themes.Soft(), css=CSS) as demo:
|
| 580 |
|
| 581 |
gr.HTML("""<div class="title-area">
|
| 582 |
<h1>🎓 教育大模型中的成员推理攻击及其防御研究</h1>
|
|
|
|
| 729 |
gr.HTML('<div style="margin:40px 0 8px;"><span class="dim-label dim2">维度二</span><strong style="font-size:18px;color:#1D2939;">极限实战维度 — 证明“极低误报下的安全底线”</strong></div>')
|
| 730 |
gr.Markdown(f"""\
|
| 731 |
> **实战意义:** 现实中黑客只允许极低的误报(如 1%)。在 Baseline 中,1% 误报率下黑客依然能精准窃取 **{gm('baseline','tpr_at_1fpr')*100:.1f}%** 的真实隐私(红柱子极高)。
|
| 732 |
+
> 开启 OP(σ=0.03) 防御后,该成功率被死死压制到了 **{gm('perturbation_0.03','tpr_at_1fpr')*100:.1f}%**。这证明在最极端的实战条件下,防线依然坚固。
|
| 733 |
""")
|
| 734 |
gr.Plot(value=fig_tpr_at_low_fpr())
|
| 735 |
|
|
|
|
| 739 |
> **LS 的防御本质:** 随着 ε 增大,两座山峰趋于完美重合,均值差距缩小到了 {gm('smooth_eps_0.2','loss_gap'):.4f}。这是从物理上抹除了模型记忆。
|
| 740 |
> **OP 的防御本质:** 均值差距未变,但高斯噪声导致分布变得极其扁平宽阔,红蓝区域被完全搅混,蒙蔽了攻击者的双眼。
|
| 741 |
""")
|
| 742 |
+
# 🌟 调用专门添加的三联横向对比直方图
|
| 743 |
gr.Plot(value=fig_d3_dist_compare())
|
| 744 |
gr.Plot(value=fig_loss_gap_waterfall())
|
| 745 |
|
|
|
|
| 856 |
|
| 857 |
""")
|
| 858 |
|
| 859 |
+
demo.launch(theme=gr.themes.Soft(), css=CSS)
|
|
|