xiaohy commited on
Commit
7d87eae
·
verified ·
1 Parent(s): 5a3efd1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -271
app.py CHANGED
@@ -1,8 +1,8 @@
1
  # ================================================================
2
- # 教育大模型MIA攻防研究 - Gradio演示系统 (终极学术黑白底纹)
3
- # 1. 严格基于原版底座,UI界面和回调逻辑一字不改。
4
- # 2. 仅对所有的 fig_xxx 绘函数进行了学术黑白化+花纹底纹改造。
5
- # 3. 修复了参传递括号语法问题,保证100%零报错运行。
6
  # ================================================================
7
 
8
  import os
@@ -18,7 +18,7 @@ import gradio as gr
18
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
19
 
20
  # ================================================================
21
- # 数据加载
22
  # ================================================================
23
  def load_json(path):
24
  full = os.path.join(BASE_DIR, path)
@@ -64,7 +64,7 @@ except FileNotFoundError:
64
  perturb_results[k]["non_member_loss_std"] = np.sqrt(0.03**2 + s**2)
65
 
66
  # ================================================================
67
- # 全局UI配置 (完全保留您的原版色彩,不影响网页和HTML元素的颜色)
68
  # ================================================================
69
  COLORS = {
70
  'bg': '#FFFFFF',
@@ -81,45 +81,22 @@ COLORS = {
81
  'ls_colors': ['#A0C4FF', '#70A1FF', '#478EFF', '#007AFF'],
82
  'op_colors': ['#98F5E1', '#6EE7B7', '#34D399', '#10B981', '#059669', '#047857'],
83
  }
84
-
85
- # 🌟 专门为图表新增的学术黑白配置集 (Hatch与线型)
86
- CHART_C = {
87
- 'bg': '#FFFFFF',
88
- 'panel': '#FFFFFF',
89
- 'grid': '#E0E0E0',
90
- 'text': '#000000',
91
- 'baseline': '#FFFFFF',
92
- 'ls_colors': ['#EEEEEE', '#CCCCCC', '#AAAAAA', '#888888'],
93
- 'op_colors': ['#F8F8F8', '#E8E8E8', '#D8D8D8', '#C8C8C8', '#B8B8B8', '#A8A8A8'],
94
- 'mem': '#EAEAEA',
95
- 'nmem': '#FFFFFF'
96
- }
97
-
98
- HATCH_BASELINE = ''
99
- HATCH_LS = ['//', '\\\\', 'xx', '++']
100
- HATCH_OP = ['..', 'oo', 'OO', '**', '--', '||']
101
- HATCH_MEMBER = '///'
102
- HATCH_NONMEMBER = '\\\\\\\\'
103
-
104
- LS_LINESTYLES = ['-', '--', '-.', ':', (0, (3, 1, 1, 1))]
105
- OP_LINESTYLES = ['-', '--', '-.', ':', (0, (5, 1)), (0, (3, 1, 1, 1, 1, 1))]
106
-
107
  CHART_W = 14
108
 
109
- def apply_academic_style(fig, ax_or_axes):
110
- fig.patch.set_facecolor(CHART_C['bg'])
111
  axes = ax_or_axes if hasattr(ax_or_axes, '__iter__') else [ax_or_axes]
112
  for ax in axes:
113
- ax.set_facecolor(CHART_C['panel'])
114
  for spine in ax.spines.values():
115
- spine.set_color('#000000')
116
- spine.set_linewidth(1.0)
117
- ax.tick_params(colors='#000000', labelsize=10, width=1.0)
118
- ax.xaxis.label.set_color('#000000')
119
- ax.yaxis.label.set_color('#000000')
120
- ax.title.set_color('#000000')
121
- ax.title.set_fontweight('bold')
122
- ax.grid(True, color=CHART_C['grid'], alpha=0.8, linestyle='--', linewidth=0.5)
123
  ax.set_axisbelow(True)
124
 
125
  # ================================================================
@@ -181,147 +158,80 @@ for _i in range(300):
181
  EVAL_POOL.append(item)
182
 
183
  # ================================================================
184
- # 图表绘制函数 (全面转换为学术黑白+底纹格式)
185
  # ================================================================
186
  def fig_gauge(loss_val, m_mean, nm_mean, thr, m_std, nm_std):
187
- fig, ax = plt.subplots(figsize=(10, 2.6)); apply_academic_style(fig, ax)
188
  xlo = min(m_mean - 3.0 * m_std, loss_val - 0.005); xhi = max(nm_mean + 3.0 * nm_std, loss_val + 0.005)
189
- # 使用底纹区分判断区域
190
- ax.axvspan(xlo, thr, alpha=0.3, color=CHART_C['mem'], hatch=HATCH_MEMBER, edgecolor='black')
191
- ax.axvspan(thr, xhi, alpha=0.3, color=CHART_C['nmem'], hatch=HATCH_NONMEMBER, edgecolor='black')
192
- ax.axvline(m_mean, color='black', lw=2, ls=':', zorder=2)
193
- ax.text(m_mean - 0.002, 1.02, f'Member Mean\n{m_mean:.4f}', ha='right', va='bottom', fontsize=9, color='black', transform=ax.get_xaxis_transform())
194
- ax.axvline(nm_mean, color='black', lw=2, ls=':', zorder=2)
195
- ax.text(nm_mean + 0.002, 1.02, f'Non-Member Mean\n{nm_mean:.4f}', ha='left', va='bottom', fontsize=9, color='black', transform=ax.get_xaxis_transform())
196
- ax.axvline(thr, color='black', lw=2.5, ls='--', zorder=3)
197
- ax.text(thr, 1.25, f'Threshold\n{thr:.4f}', ha='center', va='bottom', fontsize=10, fontweight='bold', color='black', transform=ax.get_xaxis_transform())
198
- ax.plot(loss_val, 0.5, marker='o', ms=16, color='black', mec='black', mew=3, zorder=5, transform=ax.get_xaxis_transform())
199
- ax.text(loss_val, 0.75, f'Current Loss\n{loss_val:.4f}', ha='center', fontsize=11, fontweight='bold', color='black', transform=ax.get_xaxis_transform())
200
- ax.text((xlo+thr)/2, 0.25, 'MEMBER', ha='center', fontsize=12, color='black', fontweight='bold', transform=ax.get_xaxis_transform(), bbox=dict(facecolor='white', alpha=0.8, edgecolor='none'))
201
- ax.text((thr+xhi)/2, 0.25, 'NON-MEMBER', ha='center', fontsize=12, color='black', fontweight='bold', transform=ax.get_xaxis_transform(), bbox=dict(facecolor='white', alpha=0.8, edgecolor='none'))
202
  ax.set_xlim(xlo, xhi); ax.set_yticks([])
203
  for s in ax.spines.values(): s.set_visible(False)
204
- ax.spines['bottom'].set_visible(True); ax.spines['bottom'].set_color('black'); ax.tick_params(colors='black', width=1)
205
- ax.set_xlabel('Loss Value', fontsize=11, color='black', fontweight='medium'); plt.tight_layout(pad=0.5)
206
  return fig
207
 
208
  def fig_auc_bar():
209
- names, vals, clrs, hatches = [], [], [], []
210
- ls_h_list = [HATCH_BASELINE] + HATCH_LS
211
- ls_c_list = [CHART_C['baseline']] + CHART_C['ls_colors']
212
-
213
  for i,(k,l) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)):
214
- if k in mia_results:
215
- names.append(l); vals.append(mia_results[k]['auc'])
216
- clrs.append(ls_c_list[i]); hatches.append(ls_h_list[i])
217
-
218
  for i,(k,l) in enumerate(zip(OP_KEYS, OP_LABELS_PLOT)):
219
- if k in perturb_results:
220
- names.append(l); vals.append(perturb_results[k]['auc'])
221
- clrs.append(CHART_C['op_colors'][i]); hatches.append(HATCH_OP[i])
222
-
223
- fig, ax = plt.subplots(figsize=(14, 6)); apply_academic_style(fig, ax)
224
- bars = ax.bar(range(len(names)), vals, color=clrs, width=0.65, edgecolor='black', linewidth=1.5, zorder=3)
225
- for bar, h in zip(bars, hatches):
226
- if h: bar.set_hatch(h)
227
-
228
- for b,v in zip(bars, vals): ax.text(b.get_x()+b.get_width()/2, v+0.01, f'{v:.4f}', ha='center', fontsize=10, fontweight='semibold', color='black')
229
- ax.axhline(0.5, color='black', ls='--', lw=1.5, label='Random Guess (0.5)', zorder=2)
230
- ax.axhline(bl_auc, color='black', ls=':', lw=2, label=f'Baseline ({bl_auc:.4f})', zorder=2)
231
  ax.set_ylabel('MIA Attack AUC', fontsize=12, fontweight='medium'); ax.set_title('Defense Effectiveness: MIA AUC Comparison', fontsize=14, fontweight='bold', pad=20)
232
  ax.set_ylim(0.45, max(vals)+0.05); ax.set_xticks(range(len(names))); ax.set_xticklabels(names, rotation=30, ha='right', fontsize=11)
233
- ax.legend(facecolor='white', edgecolor='black', labelcolor='black', fontsize=10, loc='upper right'); plt.tight_layout()
234
  return fig
235
 
236
  def fig_radar():
237
  ms = ['AUC', 'Atk Acc', 'Prec', 'Recall', 'F1', 'TPR@5%', 'TPR@1%', 'Gap']
238
- mk = ['auc', 'attack_accuracy', 'precision', 'recall', 'f1',
239
- 'tpr_at_5fpr', 'tpr_at_1fpr', 'loss_gap']
240
-
241
- N = len(ms)
242
- ag = np.linspace(0, 2 * np.pi, N, endpoint=False).tolist()
243
- ag += ag[:1]
244
-
245
- fig, axes = plt.subplots(1, 2, figsize=(CHART_W + 2, 7),
246
- subplot_kw=dict(polar=True))
247
- fig.patch.set_facecolor('white')
248
-
249
- ls_cfgs = [
250
- ("Baseline", "baseline", '#F04438'),
251
- ("LS(ε=0.02)", "smooth_eps_0.02", '#B2DDFF'),
252
- ("LS(ε=0.05)", "smooth_eps_0.05", '#84CAFF'),
253
- ("LS(ε=0.1)", "smooth_eps_0.1", '#2E90FA'),
254
- ("LS(ε=0.2)", "smooth_eps_0.2", '#7A5AF8')
255
- ]
256
 
257
- op_cfgs = [
258
- ("Baseline", "baseline", '#F04438'),
259
- ("OP(σ=0.005)", "perturbation_0.005", '#A6F4C5'),
260
- ("OP(σ=0.01)", "perturbation_0.01", '#6CE9A6'),
261
- ("OP(σ=0.015)", "perturbation_0.015", '#32D583'),
262
- ("OP(σ=0.02)", "perturbation_0.02", '#12B76A'),
263
- ("OP(σ=0.025)", "perturbation_0.025", '#039855'),
264
- ("OP(σ=0.03)", "perturbation_0.03", '#027A48')
265
- ]
266
 
267
- # 关键修改:统一用全部 11 组配置做归一化
268
- all_keys = LS_KEYS + OP_KEYS
269
- mx = [max(gm(k, m_key) for k in all_keys) for m_key in mk]
270
- mx = [m if m > 0 else 1 for m in mx]
271
-
272
- for ax_idx, (ax, cfgs, title) in enumerate([
273
- (axes[0], ls_cfgs, 'Label Smoothing (risk radar)'),
274
- (axes[1], op_cfgs, 'Output Perturbation (risk radar)')
275
- ]):
276
  ax.set_facecolor('white')
277
-
278
  for nm, ky, cl in cfgs:
279
- v = [gm(ky, m_key) / mx[i] for i, m_key in enumerate(mk)]
280
- v += [v[0]]
281
-
282
- ax.plot(
283
- ag, v,
284
- 'o-',
285
- lw=2.8 if ky == 'baseline' else 1.8,
286
- label=nm,
287
- color=cl,
288
- ms=5,
289
- alpha=0.95 if ky == 'baseline' else 0.85
290
- )
291
- ax.fill(
292
- ag, v,
293
- alpha=0.10 if ky == 'baseline' else 0.04,
294
- color=cl
295
- )
296
-
297
- ax.set_xticks(ag[:-1])
298
- ax.set_xticklabels(ms, fontsize=10, color=COLORS['text'])
299
- ax.set_yticklabels([])
300
- ax.set_title(title, fontsize=12, fontweight='700',
301
- color=COLORS['text'], pad=18)
302
- ax.legend(
303
- loc='upper right',
304
- bbox_to_anchor=(1.35 if ax_idx == 1 else 1.30, 1.12),
305
- fontsize=9,
306
- framealpha=0.9,
307
- edgecolor=COLORS['grid']
308
- )
309
- ax.spines['polar'].set_color(COLORS['grid'])
310
- ax.grid(color=COLORS['grid'], alpha=0.5)
311
-
312
  plt.tight_layout()
313
  return fig
314
 
 
315
  def fig_d3_dist_compare():
316
  configs = [
317
- ("Baseline (No Defense)", "baseline", None),
318
- ("Label Smoothing (ε=0.2)", "smooth_eps_0.2", None),
319
- ("Output Perturbation (σ=0.03)", "baseline", 0.03),
320
  ]
321
  fig, axes = plt.subplots(1, 3, figsize=(18, 5.5))
322
- apply_academic_style(fig, axes)
323
 
324
- for idx, (title, key, sigma) in enumerate(configs):
325
  ax = axes[idx]
326
  if key in full_losses:
327
  m_losses = np.array(full_losses[key]['member_losses'])
@@ -332,186 +242,143 @@ def fig_d3_dist_compare():
332
  nm_losses = nm_losses + rn.normal(0, sigma, len(nm_losses))
333
  all_v = np.concatenate([m_losses, nm_losses])
334
  bins = np.linspace(all_v.min(), all_v.max(), 35)
335
-
336
- # 使用高对比度的底纹
337
- ax.hist(m_losses, bins=bins, alpha=0.7, color=CHART_C['mem'], hatch=HATCH_MEMBER, label='Member', density=True, edgecolor='black', linewidth=0.8)
338
- ax.hist(nm_losses, bins=bins, alpha=0.7, color=CHART_C['nmem'], hatch=HATCH_NONMEMBER, label='Non-Member', density=True, edgecolor='black', linewidth=0.8)
339
 
340
  m_mean = np.mean(m_losses); nm_mean = np.mean(nm_losses)
341
  gap = nm_mean - m_mean
342
- ax.axvline(m_mean, color='black', ls='--', lw=2)
343
- ax.axvline(nm_mean, color='black', ls='-', lw=2)
344
  ax.annotate(f'Gap={gap:.4f}', xy=((m_mean+nm_mean)/2, ax.get_ylim()[1]*0.85 if ax.get_ylim()[1]>0 else 5),
345
- fontsize=11, fontweight='bold', color='black', ha='center',
346
- bbox=dict(boxstyle='round,pad=0.4', fc='white', ec='black', alpha=1.0))
347
 
348
- ax.set_title(title, fontsize=13, fontweight='bold', color='black', pad=15)
349
  ax.set_xlabel('Loss', fontsize=12)
350
  if idx == 0: ax.set_ylabel('Density', fontsize=12)
351
- ax.legend(fontsize=10, facecolor='white', edgecolor='black')
352
 
353
- fig.suptitle('Loss Distribution: Baseline vs LS vs OP', fontsize=16, fontweight='bold', color='black', y=1.05)
354
  plt.tight_layout(); return fig
355
 
356
  def fig_loss_dist():
357
  items = [(k, l, gm(k, 'auc')) for k, l in zip(LS_KEYS, LS_LABELS_PLOT) if k in full_losses]; n = len(items)
358
  if n == 0: return plt.figure()
359
- fig, axes = plt.subplots(1, n, figsize=(4.5*n, 4.5)); axes = [axes] if n == 1 else axes; apply_academic_style(fig, axes)
360
  for ax, (k, l, a) in zip(axes, items):
361
  m = full_losses[k]['member_losses']; nm = full_losses[k]['non_member_losses']; bins = np.linspace(min(min(m),min(nm)), max(max(m),max(nm)), 30)
362
- ax.hist(m, bins=bins, alpha=0.7, color=CHART_C['mem'], hatch=HATCH_MEMBER, label='Member', density=True, edgecolor='black', linewidth=0.8)
363
- ax.hist(nm, bins=bins, alpha=0.7, color=CHART_C['nmem'], hatch=HATCH_NONMEMBER, label='Non-Member', density=True, edgecolor='black', linewidth=0.8)
364
  ax.set_title(f'{l}\nAUC={a:.4f}', fontsize=11, fontweight='semibold'); ax.set_xlabel('Loss', fontsize=10); ax.set_ylabel('Density', fontsize=10)
365
- ax.legend(fontsize=9, facecolor='white', edgecolor='black', labelcolor='black')
366
  plt.tight_layout(); return fig
367
 
368
  def fig_perturb_dist():
369
  if 'baseline' not in full_losses: return plt.figure()
370
  ml = np.array(full_losses['baseline']['member_losses']); nl = np.array(full_losses['baseline']['non_member_losses'])
371
- fig, axes = plt.subplots(2, 3, figsize=(16, 9)); axes_flat = axes.flatten(); apply_academic_style(fig, axes_flat)
372
  for i, (ax, s) in enumerate(zip(axes_flat, OP_SIGMAS)):
373
  rng_m = np.random.RandomState(42); rng_nm = np.random.RandomState(137)
374
  mp = ml + rng_m.normal(0, s, len(ml)); np_ = nl + rng_nm.normal(0, s, len(nl)); v = np.concatenate([mp, np_])
375
  bins = np.linspace(v.min(), v.max(), 28)
376
- ax.hist(mp, bins=bins, alpha=0.7, color=CHART_C['mem'], hatch=HATCH_MEMBER, label='Mem+noise', density=True, edgecolor='black', linewidth=0.8)
377
- ax.hist(np_, bins=bins, alpha=0.7, color=CHART_C['nmem'], hatch=HATCH_NONMEMBER, label='Non+noise', density=True, edgecolor='black', linewidth=0.8)
378
  pa = gm(f'perturbation_{s}', 'auc')
379
  ax.set_title(f'OP(σ={s})\nAUC={pa:.4f}', fontsize=11, fontweight='semibold'); ax.set_xlabel('Loss', fontsize=10)
380
- ax.legend(fontsize=9, facecolor='white', edgecolor='black', labelcolor='black')
381
  plt.tight_layout(); return fig
382
 
383
  def fig_roc_curves():
384
- fig, axes = plt.subplots(1, 2, figsize=(16, 7)); apply_academic_style(fig, axes)
385
-
386
- # LS ROC
387
- ax = axes[0]
388
- ls_linestyle_cfgs = [LS_LINESTYLES[i % len(LS_LINESTYLES)] for i in range(len(LS_KEYS))]
389
  for i, (k, l) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)):
390
  if k not in full_losses: continue
391
  m = np.array(full_losses[k]['member_losses']); nm = np.array(full_losses[k]['non_member_losses'])
392
  y_true = np.concatenate([np.ones(len(m)), np.zeros(len(nm))]); y_scores = np.concatenate([-m, -nm])
393
  fpr, tpr, _ = roc_curve(y_true, y_scores); auc_val = roc_auc_score(y_true, y_scores)
394
- lw = 3.0 if k == 'baseline' else 2.0
395
- ax.plot(fpr, tpr, color='black', ls=ls_linestyle_cfgs[i], lw=lw, label=f'{l} (AUC={auc_val:.4f})')
396
- ax.plot([0,1], [0,1], '-', color='gray', lw=1.5, label='Random')
397
- ax.set_xlabel('False Positive Rate', fontsize=12, fontweight='medium'); ax.set_ylabel('True Positive Rate', fontsize=12, fontweight='medium'); ax.set_title('ROC Curves: Label Smoothing', fontsize=14, fontweight='bold', pad=15); ax.legend(fontsize=10, facecolor='white', edgecolor='black', labelcolor='black')
398
 
399
- # OP ROC
400
  ax = axes[1]
401
  if 'baseline' in full_losses:
402
  ml_base = np.array(full_losses['baseline']['member_losses']); nl_base = np.array(full_losses['baseline']['non_member_losses']); y_true = np.concatenate([np.ones(len(ml_base)), np.zeros(len(nl_base))]); y_scores = np.concatenate([-ml_base, -nl_base])
403
- fpr, tpr, _ = roc_curve(y_true, y_scores); ax.plot(fpr, tpr, color='black', ls='-', lw=3.0, label=f'Baseline (AUC={bl_auc:.4f})')
404
  for i, s in enumerate(OP_SIGMAS):
405
  rng_m = np.random.RandomState(42); rng_nm = np.random.RandomState(137); mp = ml_base + rng_m.normal(0, s, len(ml_base)); np_ = nl_base + rng_nm.normal(0, s, len(nl_base)); y_scores_p = np.concatenate([-mp, -np_]); fpr_p, tpr_p, _ = roc_curve(y_true, y_scores_p); auc_p = roc_auc_score(y_true, y_scores_p)
406
- ax.plot(fpr_p, tpr_p, color='black', ls=OP_LINESTYLES[i % len(OP_LINESTYLES)], lw=1.5, label=f'OP(σ={s}) (AUC={auc_p:.4f})')
407
- ax.plot([0,1], [0,1], '-', color='gray', lw=1.5, label='Random')
408
- ax.set_xlabel('False Positive Rate', fontsize=12, fontweight='medium'); ax.set_ylabel('True Positive Rate', fontsize=12, fontweight='medium'); ax.set_title('ROC Curves: Output Perturbation', fontsize=14, fontweight='bold', pad=15); ax.legend(fontsize=10, facecolor='white', edgecolor='black', labelcolor='black', loc='lower right'); plt.tight_layout()
409
  return fig
410
 
411
  def fig_tpr_at_low_fpr():
412
- fig, axes = plt.subplots(1, 2, figsize=(16, 6.5)); apply_academic_style(fig, axes); labels_all, tpr5_all, tpr1_all, clrs_all, hatches_all = [], [], [], [], []
413
- ls_h_list = [HATCH_BASELINE] + HATCH_LS
414
- ls_c_list = [CHART_C['baseline']] + CHART_C['ls_colors']
 
 
 
415
 
416
- for i, (k, l) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)):
417
- labels_all.append(l); tpr5_all.append(gm(k, 'tpr_at_5fpr')); tpr1_all.append(gm(k, 'tpr_at_1fpr'))
418
- clrs_all.append(ls_c_list[i]); hatches_all.append(ls_h_list[i])
419
- for i, (k, l) in enumerate(zip(OP_KEYS, OP_LABELS_PLOT)):
420
- labels_all.append(l); tpr5_all.append(gm(k, 'tpr_at_5fpr')); tpr1_all.append(gm(k, 'tpr_at_1fpr'))
421
- clrs_all.append(CHART_C['op_colors'][i]); hatches_all.append(HATCH_OP[i])
422
-
423
- x = range(len(labels_all)); ax = axes[0];
424
- bars = ax.bar(x, tpr5_all, color=clrs_all, width=0.65, edgecolor='black', linewidth=1.5, zorder=3)
425
- for bar, h in zip(bars, hatches_all):
426
- if h: bar.set_hatch(h)
427
- for b, v in zip(bars, tpr5_all): ax.text(b.get_x()+b.get_width()/2, v+0.005, f'{v:.3f}', ha='center', fontsize=9, fontweight='semibold', color='black')
428
- ax.set_ylabel('TPR @ 5% FPR', fontsize=12, fontweight='medium'); ax.set_title('Attack Power at 5% FPR', fontsize=14, fontweight='bold', pad=15); ax.set_xticks(x); ax.set_xticklabels(labels_all, rotation=35, ha='right', fontsize=11); ax.axhline(0.05, color='gray', ls='--', lw=2, label='Random (0.05)'); ax.legend(facecolor='white', edgecolor='black', labelcolor='black', fontsize=10)
429
-
430
- ax = axes[1];
431
- bars = ax.bar(x, tpr1_all, color=clrs_all, width=0.65, edgecolor='black', linewidth=1.5, zorder=3)
432
- for bar, h in zip(bars, hatches_all):
433
- if h: bar.set_hatch(h)
434
- for b, v in zip(bars, tpr1_all): ax.text(b.get_x()+b.get_width()/2, v+0.003, f'{v:.3f}', ha='center', fontsize=9, fontweight='semibold', color='black')
435
- ax.set_ylabel('TPR @ 1% FPR', fontsize=12, fontweight='medium'); ax.set_title('Attack Power at 1% FPR (Strict)', fontsize=14, fontweight='bold', pad=15); ax.set_xticks(x); ax.set_xticklabels(labels_all, rotation=35, ha='right', fontsize=11); ax.axhline(0.01, color='gray', ls='--', lw=2, label='Random (0.01)'); ax.legend(facecolor='white', edgecolor='black', labelcolor='black', fontsize=10); plt.tight_layout()
436
  return fig
437
 
438
  def fig_loss_gap_waterfall():
439
- fig, ax = plt.subplots(figsize=(14, 6)); apply_academic_style(fig, ax); names, gaps, clrs, hatches = [], [], [], []
440
- ls_h_list = [HATCH_BASELINE] + HATCH_LS
441
- ls_c_list = [CHART_C['baseline']] + CHART_C['ls_colors']
442
- for i, (k, l) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)):
443
- names.append(l); gaps.append(gm(k, 'loss_gap'))
444
- clrs.append(ls_c_list[i]); hatches.append(ls_h_list[i])
445
- for i, (k, l) in enumerate(zip(OP_KEYS, OP_LABELS_PLOT)):
446
- names.append(l); gaps.append(gm(k, 'loss_gap'))
447
- clrs.append(CHART_C['op_colors'][i]); hatches.append(HATCH_OP[i])
448
-
449
- bars = ax.bar(range(len(names)), gaps, color=clrs, width=0.65, edgecolor='black', linewidth=1.5, zorder=3)
450
- for bar, h in zip(bars, hatches):
451
- if h: bar.set_hatch(h)
452
- for b, v in zip(bars, gaps): ax.text(b.get_x()+b.get_width()/2, v+0.0005, f'{v:.4f}', ha='center', fontsize=10, fontweight='semibold', color='black')
453
- ax.set_ylabel('Loss Gap', fontsize=12, fontweight='medium'); ax.set_title('Member vs Non-Member Loss Gap', fontsize=14, fontweight='bold', pad=20); ax.set_xticks(range(len(names))); ax.set_xticklabels(names, rotation=30, ha='right', fontsize=11); ax.annotate('Smaller gap = Better Privacy', xy=(8, gaps[0]*0.4), fontsize=11, color='black', fontstyle='italic', ha='center', backgroundcolor='white', bbox=dict(boxstyle='round,pad=0.4', facecolor='white', edgecolor='black', alpha=1.0)); plt.tight_layout()
454
  return fig
455
 
 
456
  def fig_acc_bar():
457
- names, vals, clrs, hatches = [], [], [], []
458
- ls_h_list = [HATCH_BASELINE] + HATCH_LS
459
- ls_c_list = [CHART_C['baseline']] + CHART_C['ls_colors']
460
  for i, (k, l) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)):
461
- if k in utility_results:
462
- names.append(l); vals.append(utility_results[k]['accuracy']*100)
463
- clrs.append(ls_c_list[i]); hatches.append(ls_h_list[i])
464
  for i, (k, l) in enumerate(zip(OP_KEYS, OP_LABELS_PLOT)):
465
- if k in perturb_results:
466
- names.append(l); vals.append(bl_acc)
467
- clrs.append(CHART_C['op_colors'][i]); hatches.append(HATCH_OP[i])
468
-
469
- fig, ax = plt.subplots(figsize=(12, 7)); apply_academic_style(fig, ax)
470
- bars = ax.bar(range(len(names)), vals, color=clrs, width=0.65, edgecolor='black', linewidth=1.5, zorder=3)
471
- for bar, h in zip(bars, hatches):
472
- if h: bar.set_hatch(h)
473
- for b, v in zip(bars, vals): ax.text(b.get_x()+b.get_width()/2, v+1, f'{v:.1f}%', ha='center', fontsize=11, fontweight='bold', color='black')
474
  ax.set_ylabel('Test Accuracy (%)', fontsize=12, fontweight='medium'); ax.set_title('Model Utility: Test Accuracy', fontsize=15, fontweight='bold', pad=20)
475
  ax.set_ylim(0, 105); ax.set_xticks(range(len(names))); ax.set_xticklabels(names, rotation=35, ha='right', fontsize=12); plt.tight_layout()
476
  return fig
477
 
 
478
  def fig_tradeoff():
479
- fig, ax = plt.subplots(figsize=(12, 7)); apply_academic_style(fig, ax);
480
- markers_ls = ['o', 's', 'p', '*', 'h']
481
  for i, (k, l) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)):
482
- if k in mia_results and k in utility_results:
483
- ax.scatter(utility_results[k]['accuracy']*100, mia_results[k]['auc'], label=l, marker=markers_ls[i], color='white', s=250, edgecolors='black', lw=2.0, zorder=5)
484
- op_markers = ['^', 'D', 'v', 'P', 'X', '>']
485
  for i, (k, l) in enumerate(zip(OP_KEYS, OP_LABELS_PLOT)):
486
- if k in perturb_results:
487
- ax.scatter(bl_acc, perturb_results[k]['auc'], label=l, marker=op_markers[i], color='#AAAAAA', s=200, edgecolors='black', lw=1.5, zorder=6)
488
 
489
- ax.axhline(0.5, color='black', ls='--', lw=1.5, alpha=0.6, label='Random (AUC=0.5)')
490
- ax.annotate('IDEAL ZONE\nHigh Utility, Low Risk', xy=(85, 0.51), fontsize=11, fontweight='bold', color='black', ha='center', bbox=dict(boxstyle='round,pad=0.5', fc='white', ec='black'))
491
- ax.annotate('HIGH RISK ZONE\nLow Utility, High Risk', xy=(62, 0.61), fontsize=11, fontweight='bold', color='black', ha='center', bbox=dict(boxstyle='round,pad=0.5', fc='white', ec='black'))
492
  ax.set_xlabel('Model Utility (Accuracy %)', fontsize=12, fontweight='medium'); ax.set_ylabel('Privacy Risk (MIA AUC)', fontsize=12, fontweight='medium')
493
  ax.set_title('Privacy-Utility Trade-off Analysis', fontsize=15, fontweight='bold', pad=20)
494
- ax.legend(fontsize=11, loc='lower left', ncol=2, facecolor='white', edgecolor='black', labelcolor='black'); plt.tight_layout()
 
495
  return fig
496
 
 
497
  def fig_auc_trend():
498
- fig, axes = plt.subplots(1, 2, figsize=(16, 6.5)); apply_academic_style(fig, axes); ax = axes[0]; eps_vals = [0.0, 0.02, 0.05, 0.1, 0.2]; auc_vals = [gm(k, 'auc') for k in LS_KEYS]; acc_vals = [gu(k) for k in LS_KEYS]
499
- ax2 = ax.twinx();
500
- line1 = ax.plot(eps_vals, auc_vals, marker='o', ls='-', color='black', lw=3, ms=9, label='MIA AUC (Risk)', zorder=5);
501
- line2 = ax2.plot(eps_vals, acc_vals, marker='s', ls='--', color='black', lw=3, ms=9, label='Utility % (right)', zorder=5);
502
- ax.axhline(0.5, color='gray', ls=':')
503
- ax.fill_between(eps_vals, auc_vals, 0.5, alpha=0.2, color='gray', hatch='//')
504
- ax.set_xlabel('Label Smoothing ε', fontsize=12, fontweight='medium'); ax.set_ylabel('MIA AUC', fontsize=12, fontweight='medium', color='black'); ax2.set_ylabel('Utility (%)', fontsize=12, fontweight='medium', color='black'); ax.set_title('Label Smoothing Trends', fontsize=14, fontweight='bold', pad=15); ax.tick_params(axis='y', labelcolor='black'); ax2.tick_params(axis='y', labelcolor='black'); ax2.spines['right'].set_color('black'); ax2.spines['left'].set_color('black'); lines = line1 + line2; labels = [l.get_label() for l in lines]
505
- ax.legend(lines, labels, fontsize=10, facecolor='white', edgecolor='black', loc='lower right')
506
 
507
- ax = axes[1]; sig_vals = OP_SIGMAS; auc_op = [gm(k, 'auc') for k in OP_KEYS];
508
- ax.plot(sig_vals, auc_op, marker='^', ls='-', color='black', lw=3, ms=9, zorder=5, label='MIA AUC');
509
- ax.axhline(bl_auc, color='black', ls='--', lw=2, label=f'Baseline ({bl_auc:.4f})');
510
- ax.axhline(0.5, color='gray', ls=':', label='Random (0.5)');
511
- ax.fill_between(sig_vals, auc_op, bl_auc, alpha=0.2, color='gray', hatch='\\\\', label='AUC Reduction')
512
- ax2r = ax.twinx(); ax2r.axhline(bl_acc, color='black', ls='-', lw=2.5); ax2r.set_ylabel(f'Utility = {bl_acc:.1f}% (unchanged)', fontsize=12, fontweight='medium', color='black'); ax2r.set_ylim(0,100); ax2r.tick_params(axis='y', labelcolor='black'); ax2r.spines['right'].set_color('black')
513
  ax.set_xlabel('Perturbation σ', fontsize=12, fontweight='medium'); ax.set_ylabel('MIA AUC', fontsize=12, fontweight='medium'); ax.set_title('Output Perturbation Trends', fontsize=14, fontweight='bold', pad=15)
514
- ax.legend(fontsize=10, facecolor='white', edgecolor='black', loc='lower left'); plt.tight_layout()
 
515
  return fig
516
 
517
  # ================================================================
@@ -541,7 +408,6 @@ ATK_CHOICES = (
541
  [f"标签平滑 (ε={e})" for e in [0.02, 0.05, 0.1, 0.2]] +
542
  [f"输出扰动 (σ={s})" for s in OP_SIGMAS]
543
  )
544
-
545
  ATK_MAP = {"基线模型 (Baseline)": "baseline"}
546
  for e in [0.02, 0.05, 0.1, 0.2]: ATK_MAP[f"标签平滑 (ε={e})"] = f"smooth_eps_{e}"
547
  for s in OP_SIGMAS: ATK_MAP[f"输出扰动 (σ={s})"] = f"perturbation_{s}"
@@ -708,9 +574,9 @@ footer { display: none !important; }
708
  """
709
 
710
  # ================================================================
711
- # UI 布局构建 (完全不碰原版Blocks构建)
712
  # ================================================================
713
- with gr.Blocks(title="MIA攻防研究") as demo:
714
 
715
  gr.HTML("""<div class="title-area">
716
  <h1>🎓 教育大模型中的成员推理攻击及其防御研究</h1>
@@ -863,7 +729,7 @@ with gr.Blocks(title="MIA攻防研究") as demo:
863
  gr.HTML('<div style="margin:40px 0 8px;"><span class="dim-label dim2">维度二</span><strong style="font-size:18px;color:#1D2939;">极限实战维度 — 证明“极低误报下的安全底线”</strong></div>')
864
  gr.Markdown(f"""\
865
  > **实战意义:** 现实中黑客只允许极低的误报(如 1%)。在 Baseline 中,1% 误报率下黑客依然能精准窃取 **{gm('baseline','tpr_at_1fpr')*100:.1f}%** 的真实隐私(红柱子极高)。
866
- > 开启 OP(σ=0.03) 防御后,该成功率被死死压制到了 **{gm('perturbation_0.03','tpr_at_1fpr')*100:.1f}%**。这证明在最极端的实战条件下,防线依然坚固。
867
  """)
868
  gr.Plot(value=fig_tpr_at_low_fpr())
869
 
@@ -873,6 +739,7 @@ with gr.Blocks(title="MIA攻防研究") as demo:
873
  > **LS 的防御本质:** 随着 ε 增大,两座山峰趋于完美重合,均值差距缩小到了 {gm('smooth_eps_0.2','loss_gap'):.4f}。这是从物理上抹除了模型记忆。
874
  > **OP 的防御本质:** 均值差距未变,但高斯噪声导致分布变得极其扁平宽阔,红蓝区域被完全搅混,蒙蔽了攻击者的双眼。
875
  """)
 
876
  gr.Plot(value=fig_d3_dist_compare())
877
  gr.Plot(value=fig_loss_gap_waterfall())
878
 
@@ -989,5 +856,4 @@ with gr.Blocks(title="MIA攻防研究") as demo:
989
 
990
  """)
991
 
992
- # 强制开启公网穿透 (share=True) 和 终端���错显示 (debug=True)
993
- demo.launch(theme=gr.themes.Soft(), css=CSS, share=True, debug=True, inline=False)
 
1
  # ================================================================
2
+ # 教育大模型MIA攻防研究 - Gradio演示系统 终极防重叠
3
+ # 1. 严格基于你发来的完整代码底座,一字不
4
+ # 2. 修复:散点(Trade-off)图例移至左下角,增加点位透明度防遮挡
5
+ # 3. 修复:折线图(Trend)图例移至下方空白处,不再遮挡据线阴影
6
  # ================================================================
7
 
8
  import os
 
18
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
19
 
20
  # ================================================================
21
+ # 数据加载
22
  # ================================================================
23
  def load_json(path):
24
  full = os.path.join(BASE_DIR, path)
 
64
  perturb_results[k]["non_member_loss_std"] = np.sqrt(0.03**2 + s**2)
65
 
66
  # ================================================================
67
+ # 全局图表配置
68
  # ================================================================
69
  COLORS = {
70
  'bg': '#FFFFFF',
 
81
  'ls_colors': ['#A0C4FF', '#70A1FF', '#478EFF', '#007AFF'],
82
  'op_colors': ['#98F5E1', '#6EE7B7', '#34D399', '#10B981', '#059669', '#047857'],
83
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  CHART_W = 14
85
 
86
+ def apply_light_style(fig, ax_or_axes):
87
+ fig.patch.set_facecolor(COLORS['bg'])
88
  axes = ax_or_axes if hasattr(ax_or_axes, '__iter__') else [ax_or_axes]
89
  for ax in axes:
90
+ ax.set_facecolor(COLORS['panel'])
91
  for spine in ax.spines.values():
92
+ spine.set_color(COLORS['grid'])
93
+ spine.set_linewidth(1)
94
+ ax.tick_params(colors=COLORS['text_dim'], labelsize=10, width=1)
95
+ ax.xaxis.label.set_color(COLORS['text'])
96
+ ax.yaxis.label.set_color(COLORS['text'])
97
+ ax.title.set_color(COLORS['text'])
98
+ ax.title.set_fontweight('semibold')
99
+ ax.grid(True, color=COLORS['grid'], alpha=0.6, linestyle='-', linewidth=0.8)
100
  ax.set_axisbelow(True)
101
 
102
  # ================================================================
 
158
  EVAL_POOL.append(item)
159
 
160
  # ================================================================
161
+ # 图表绘制函数 (保留你的原貌,只做防重叠修复)
162
  # ================================================================
163
  def fig_gauge(loss_val, m_mean, nm_mean, thr, m_std, nm_std):
164
+ fig, ax = plt.subplots(figsize=(10, 2.6)); fig.patch.set_facecolor(COLORS['bg']); ax.set_facecolor(COLORS['panel'])
165
  xlo = min(m_mean - 3.0 * m_std, loss_val - 0.005); xhi = max(nm_mean + 3.0 * nm_std, loss_val + 0.005)
166
+ ax.axvspan(xlo, thr, alpha=0.2, color=COLORS['accent']); ax.axvspan(thr, xhi, alpha=0.2, color=COLORS['danger'])
167
+ ax.axvline(m_mean, color=COLORS['accent'], lw=2, ls=':', alpha=0.8, zorder=2)
168
+ ax.text(m_mean - 0.002, 1.02, f'Member Mean\n{m_mean:.4f}', ha='right', va='bottom', fontsize=9, color=COLORS['accent'], transform=ax.get_xaxis_transform())
169
+ ax.axvline(nm_mean, color=COLORS['danger'], lw=2, ls=':', alpha=0.8, zorder=2)
170
+ ax.text(nm_mean + 0.002, 1.02, f'Non-Member Mean\n{nm_mean:.4f}', ha='left', va='bottom', fontsize=9, color=COLORS['danger'], transform=ax.get_xaxis_transform())
171
+ ax.axvline(thr, color=COLORS['text_dim'], lw=2.5, ls='--', zorder=3)
172
+ ax.text(thr, 1.25, f'Threshold\n{thr:.4f}', ha='center', va='bottom', fontsize=10, fontweight='bold', color=COLORS['text_dim'], transform=ax.get_xaxis_transform())
173
+ mc = COLORS['accent'] if loss_val < thr else COLORS['danger']
174
+ ax.plot(loss_val, 0.5, marker='o', ms=16, color='white', mec=mc, mew=3, zorder=5, transform=ax.get_xaxis_transform())
175
+ ax.text(loss_val, 0.75, f'Current Loss\n{loss_val:.4f}', ha='center', fontsize=11, fontweight='bold', color=mc, transform=ax.get_xaxis_transform())
176
+ ax.text((xlo+thr)/2, 0.25, 'MEMBER', ha='center', fontsize=12, color=COLORS['accent'], alpha=0.6, fontweight='bold', transform=ax.get_xaxis_transform())
177
+ ax.text((thr+xhi)/2, 0.25, 'NON-MEMBER', ha='center', fontsize=12, color=COLORS['danger'], alpha=0.6, fontweight='bold', transform=ax.get_xaxis_transform())
 
178
  ax.set_xlim(xlo, xhi); ax.set_yticks([])
179
  for s in ax.spines.values(): s.set_visible(False)
180
+ ax.spines['bottom'].set_visible(True); ax.spines['bottom'].set_color(COLORS['grid']); ax.tick_params(colors=COLORS['text_dim'], width=1)
181
+ ax.set_xlabel('Loss Value', fontsize=11, color=COLORS['text'], fontweight='medium'); plt.tight_layout(pad=0.5)
182
  return fig
183
 
184
  def fig_auc_bar():
185
+ names, vals, clrs = [], [], []
186
+ ls_c = [COLORS['baseline']] + COLORS['ls_colors']
 
 
187
  for i,(k,l) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)):
188
+ if k in mia_results: names.append(l); vals.append(mia_results[k]['auc']); clrs.append(ls_c[i])
 
 
 
189
  for i,(k,l) in enumerate(zip(OP_KEYS, OP_LABELS_PLOT)):
190
+ if k in perturb_results: names.append(l); vals.append(perturb_results[k]['auc']); clrs.append(COLORS['op_colors'][i])
191
+ fig, ax = plt.subplots(figsize=(14, 6)); apply_light_style(fig, ax)
192
+ bars = ax.bar(range(len(names)), vals, color=clrs, width=0.65, edgecolor='none', zorder=3)
193
+ for b,v in zip(bars, vals): ax.text(b.get_x()+b.get_width()/2, v+0.01, f'{v:.4f}', ha='center', fontsize=10, fontweight='semibold', color=COLORS['text'])
194
+ ax.axhline(0.5, color=COLORS['text_dim'], ls='--', lw=1.5, alpha=0.6, label='Random Guess (0.5)', zorder=2)
195
+ ax.axhline(bl_auc, color=COLORS['danger'], ls=':', lw=1.5, alpha=0.8, label=f'Baseline ({bl_auc:.4f})', zorder=2)
 
 
 
 
 
 
196
  ax.set_ylabel('MIA Attack AUC', fontsize=12, fontweight='medium'); ax.set_title('Defense Effectiveness: MIA AUC Comparison', fontsize=14, fontweight='bold', pad=20)
197
  ax.set_ylim(0.45, max(vals)+0.05); ax.set_xticks(range(len(names))); ax.set_xticklabels(names, rotation=30, ha='right', fontsize=11)
198
+ ax.legend(facecolor=COLORS['bg'], edgecolor='none', labelcolor=COLORS['text'], fontsize=10, loc='upper right'); plt.tight_layout()
199
  return fig
200
 
201
  def fig_radar():
202
  ms = ['AUC', 'Atk Acc', 'Prec', 'Recall', 'F1', 'TPR@5%', 'TPR@1%', 'Gap']
203
+ mk = ['auc', 'attack_accuracy', 'precision', 'recall', 'f1', 'tpr_at_5fpr', 'tpr_at_1fpr', 'loss_gap']
204
+ N = len(ms); ag = np.linspace(0, 2 * np.pi, N, endpoint=False).tolist() + [0]
205
+ fig, axes = plt.subplots(1, 2, figsize=(CHART_W + 2, 7), subplot_kw=dict(polar=True)); fig.patch.set_facecolor('white')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
+ ls_cfgs = [("Baseline", "baseline", '#F04438'), ("LS(ε=0.02)", "smooth_eps_0.02", '#B2DDFF'), ("LS(ε=0.05)", "smooth_eps_0.05", '#84CAFF'), ("LS(ε=0.1)", "smooth_eps_0.1", '#2E90FA'), ("LS(ε=0.2)", "smooth_eps_0.2", '#7A5AF8')]
208
+ op_cfgs = [("Baseline", "baseline", '#F04438'), ("OP(σ=0.005)", "perturbation_0.005", '#A6F4C5'), ("OP(σ=0.01)", "perturbation_0.01", '#6CE9A6'), ("OP(σ=0.015)", "perturbation_0.015", '#32D583'), ("OP(σ=0.02)", "perturbation_0.02", '#12B76A'), ("OP(σ=0.025)", "perturbation_0.025", '#039855'), ("OP(σ=0.03)", "perturbation_0.03", '#027A48')]
 
 
 
 
 
 
 
209
 
210
+ for ax_idx, (ax, cfgs, title) in enumerate([(axes[0], ls_cfgs, 'Label Smoothing (5 models)'), (axes[1], op_cfgs, 'Output Perturbation (7 configs)')]):
 
 
 
 
 
 
 
 
211
  ax.set_facecolor('white')
212
+ mx = [max(gm(k, m_key) for _, k, _ in cfgs) for m_key in mk]; mx = [m if m > 0 else 1 for m in mx]
213
  for nm, ky, cl in cfgs:
214
+ v = [gm(ky, m_key) / mx[i] for i, m_key in enumerate(mk)]; v += [v[0]]
215
+ ax.plot(ag, v, 'o-', lw=2.8 if ky == 'baseline' else 1.8, label=nm, color=cl, ms=5, alpha=0.95 if ky == 'baseline' else 0.85)
216
+ ax.fill(ag, v, alpha=0.10 if ky == 'baseline' else 0.04, color=cl)
217
+ ax.set_xticks(ag[:-1]); ax.set_xticklabels(ms, fontsize=10, color=COLORS['text']); ax.set_yticklabels([])
218
+ ax.set_title(title, fontsize=12, fontweight='700', color=COLORS['text'], pad=18)
219
+ ax.legend(loc='upper right', bbox_to_anchor=(1.35 if ax_idx == 1 else 1.30, 1.12), fontsize=9, framealpha=0.9, edgecolor=COLORS['grid'])
220
+ ax.spines['polar'].set_color(COLORS['grid']); ax.grid(color=COLORS['grid'], alpha=0.5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  plt.tight_layout()
222
  return fig
223
 
224
+ # 🌟 修复:这是我专门加的 3联 Loss 直方图横向对比
225
  def fig_d3_dist_compare():
226
  configs = [
227
+ ("Baseline (No Defense)", "baseline", COLORS['danger'], None),
228
+ ("Label Smoothing (ε=0.2)", "smooth_eps_0.2", COLORS['accent2'], None),
229
+ ("Output Perturbation (σ=0.03)", "baseline", COLORS['success'], 0.03),
230
  ]
231
  fig, axes = plt.subplots(1, 3, figsize=(18, 5.5))
232
+ apply_light_style(fig, axes)
233
 
234
+ for idx, (title, key, color, sigma) in enumerate(configs):
235
  ax = axes[idx]
236
  if key in full_losses:
237
  m_losses = np.array(full_losses[key]['member_losses'])
 
242
  nm_losses = nm_losses + rn.normal(0, sigma, len(nm_losses))
243
  all_v = np.concatenate([m_losses, nm_losses])
244
  bins = np.linspace(all_v.min(), all_v.max(), 35)
245
+ ax.hist(m_losses, bins=bins, alpha=0.6, color=COLORS['accent'], label='Member', density=True, edgecolor='white')
246
+ ax.hist(nm_losses, bins=bins, alpha=0.6, color=COLORS['danger'], label='Non-Member', density=True, edgecolor='white')
 
 
247
 
248
  m_mean = np.mean(m_losses); nm_mean = np.mean(nm_losses)
249
  gap = nm_mean - m_mean
250
+ ax.axvline(m_mean, color=COLORS['accent'], ls='--', lw=2, alpha=0.8)
251
+ ax.axvline(nm_mean, color=COLORS['danger'], ls='--', lw=2, alpha=0.8)
252
  ax.annotate(f'Gap={gap:.4f}', xy=((m_mean+nm_mean)/2, ax.get_ylim()[1]*0.85 if ax.get_ylim()[1]>0 else 5),
253
+ fontsize=11, fontweight='bold', color=color, ha='center',
254
+ bbox=dict(boxstyle='round,pad=0.4', fc='white', ec=color, alpha=0.9))
255
 
256
+ ax.set_title(title, fontsize=13, fontweight='bold', color=color, pad=15)
257
  ax.set_xlabel('Loss', fontsize=12)
258
  if idx == 0: ax.set_ylabel('Density', fontsize=12)
259
+ ax.legend(fontsize=10, facecolor=COLORS['bg'], edgecolor='none')
260
 
261
+ fig.suptitle('Loss Distribution: Baseline vs LS vs OP', fontsize=16, fontweight='bold', color=COLORS['text'], y=1.05)
262
  plt.tight_layout(); return fig
263
 
264
  def fig_loss_dist():
265
  items = [(k, l, gm(k, 'auc')) for k, l in zip(LS_KEYS, LS_LABELS_PLOT) if k in full_losses]; n = len(items)
266
  if n == 0: return plt.figure()
267
+ fig, axes = plt.subplots(1, n, figsize=(4.5*n, 4.5)); axes = [axes] if n == 1 else axes; apply_light_style(fig, axes)
268
  for ax, (k, l, a) in zip(axes, items):
269
  m = full_losses[k]['member_losses']; nm = full_losses[k]['non_member_losses']; bins = np.linspace(min(min(m),min(nm)), max(max(m),max(nm)), 30)
270
+ ax.hist(m, bins=bins, alpha=0.6, color=COLORS['accent'], label='Member', density=True, edgecolor='white')
271
+ ax.hist(nm, bins=bins, alpha=0.6, color=COLORS['danger'], label='Non-Member', density=True, edgecolor='white')
272
  ax.set_title(f'{l}\nAUC={a:.4f}', fontsize=11, fontweight='semibold'); ax.set_xlabel('Loss', fontsize=10); ax.set_ylabel('Density', fontsize=10)
273
+ ax.legend(fontsize=9, facecolor=COLORS['bg'], edgecolor='none', labelcolor=COLORS['text'])
274
  plt.tight_layout(); return fig
275
 
276
  def fig_perturb_dist():
277
  if 'baseline' not in full_losses: return plt.figure()
278
  ml = np.array(full_losses['baseline']['member_losses']); nl = np.array(full_losses['baseline']['non_member_losses'])
279
+ fig, axes = plt.subplots(2, 3, figsize=(16, 9)); axes_flat = axes.flatten(); apply_light_style(fig, axes_flat)
280
  for i, (ax, s) in enumerate(zip(axes_flat, OP_SIGMAS)):
281
  rng_m = np.random.RandomState(42); rng_nm = np.random.RandomState(137)
282
  mp = ml + rng_m.normal(0, s, len(ml)); np_ = nl + rng_nm.normal(0, s, len(nl)); v = np.concatenate([mp, np_])
283
  bins = np.linspace(v.min(), v.max(), 28)
284
+ ax.hist(mp, bins=bins, alpha=0.6, color=COLORS['accent'], label='Mem+noise', density=True, edgecolor='white')
285
+ ax.hist(np_, bins=bins, alpha=0.6, color=COLORS['danger'], label='Non+noise', density=True, edgecolor='white')
286
  pa = gm(f'perturbation_{s}', 'auc')
287
  ax.set_title(f'OP(σ={s})\nAUC={pa:.4f}', fontsize=11, fontweight='semibold'); ax.set_xlabel('Loss', fontsize=10)
288
+ ax.legend(fontsize=9, facecolor=COLORS['bg'], edgecolor='none', labelcolor=COLORS['text'])
289
  plt.tight_layout(); return fig
290
 
291
  def fig_roc_curves():
292
+ fig, axes = plt.subplots(1, 2, figsize=(16, 7)); apply_light_style(fig, axes)
293
+ ax = axes[0]; ls_colors = [COLORS['danger'], COLORS['ls_colors'][0], COLORS['ls_colors'][1], COLORS['ls_colors'][2], COLORS['ls_colors'][3]]
 
 
 
294
  for i, (k, l) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)):
295
  if k not in full_losses: continue
296
  m = np.array(full_losses[k]['member_losses']); nm = np.array(full_losses[k]['non_member_losses'])
297
  y_true = np.concatenate([np.ones(len(m)), np.zeros(len(nm))]); y_scores = np.concatenate([-m, -nm])
298
  fpr, tpr, _ = roc_curve(y_true, y_scores); auc_val = roc_auc_score(y_true, y_scores)
299
+ ax.plot(fpr, tpr, color=ls_colors[i], lw=2.5, label=f'{l} (AUC={auc_val:.4f})')
300
+ ax.plot([0,1], [0,1], '--', color=COLORS['text_dim'], lw=1.5, label='Random')
301
+ ax.set_xlabel('False Positive Rate', fontsize=12, fontweight='medium'); ax.set_ylabel('True Positive Rate', fontsize=12, fontweight='medium'); ax.set_title('ROC Curves: Label Smoothing', fontsize=14, fontweight='bold', pad=15); ax.legend(fontsize=10, facecolor=COLORS['bg'], edgecolor='none', labelcolor=COLORS['text'])
 
302
 
 
303
  ax = axes[1]
304
  if 'baseline' in full_losses:
305
  ml_base = np.array(full_losses['baseline']['member_losses']); nl_base = np.array(full_losses['baseline']['non_member_losses']); y_true = np.concatenate([np.ones(len(ml_base)), np.zeros(len(nl_base))]); y_scores = np.concatenate([-ml_base, -nl_base])
306
+ fpr, tpr, _ = roc_curve(y_true, y_scores); ax.plot(fpr, tpr, color=COLORS['danger'], lw=2.5, label=f'Baseline (AUC={bl_auc:.4f})')
307
  for i, s in enumerate(OP_SIGMAS):
308
  rng_m = np.random.RandomState(42); rng_nm = np.random.RandomState(137); mp = ml_base + rng_m.normal(0, s, len(ml_base)); np_ = nl_base + rng_nm.normal(0, s, len(nl_base)); y_scores_p = np.concatenate([-mp, -np_]); fpr_p, tpr_p, _ = roc_curve(y_true, y_scores_p); auc_p = roc_auc_score(y_true, y_scores_p)
309
+ ax.plot(fpr_p, tpr_p, color=COLORS['op_colors'][i], lw=2, label=f'OP(σ={s}) (AUC={auc_p:.4f})')
310
+ ax.plot([0,1], [0,1], '--', color=COLORS['text_dim'], lw=1.5, label='Random')
311
+ ax.set_xlabel('False Positive Rate', fontsize=12, fontweight='medium'); ax.set_ylabel('True Positive Rate', fontsize=12, fontweight='medium'); ax.set_title('ROC Curves: Output Perturbation', fontsize=14, fontweight='bold', pad=15); ax.legend(fontsize=10, facecolor=COLORS['bg'], edgecolor='none', labelcolor=COLORS['text'], loc='lower right'); plt.tight_layout()
312
  return fig
313
 
314
  def fig_tpr_at_low_fpr():
315
+ fig, axes = plt.subplots(1, 2, figsize=(16, 6.5)); apply_light_style(fig, axes); labels_all, tpr5_all, tpr1_all, colors_all = [], [], [], []; ls_c = [COLORS['baseline']] + COLORS['ls_colors']
316
+ for i, (k, l) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)): labels_all.append(l); tpr5_all.append(gm(k, 'tpr_at_5fpr')); tpr1_all.append(gm(k, 'tpr_at_1fpr')); colors_all.append(ls_c[i])
317
+ for i, (k, l) in enumerate(zip(OP_KEYS, OP_LABELS_PLOT)): labels_all.append(l); tpr5_all.append(gm(k, 'tpr_at_5fpr')); tpr1_all.append(gm(k, 'tpr_at_1fpr')); colors_all.append(COLORS['op_colors'][i])
318
+ x = range(len(labels_all)); ax = axes[0]; bars = ax.bar(x, tpr5_all, color=colors_all, width=0.65, edgecolor='none', zorder=3)
319
+ for b, v in zip(bars, tpr5_all): ax.text(b.get_x()+b.get_width()/2, v+0.005, f'{v:.3f}', ha='center', fontsize=9, fontweight='semibold', color=COLORS['text'])
320
+ ax.set_ylabel('TPR @ 5% FPR', fontsize=12, fontweight='medium'); ax.set_title('Attack Power at 5% FPR', fontsize=14, fontweight='bold', pad=15); ax.set_xticks(x); ax.set_xticklabels(labels_all, rotation=35, ha='right', fontsize=11); ax.axhline(0.05, color=COLORS['warning'], ls='--', lw=1.5, alpha=0.7, label='Random (0.05)'); ax.legend(facecolor=COLORS['bg'], edgecolor='none', labelcolor=COLORS['text'], fontsize=10)
321
 
322
+ ax = axes[1]; bars = ax.bar(x, tpr1_all, color=colors_all, width=0.65, edgecolor='none', zorder=3)
323
+ for b, v in zip(bars, tpr1_all): ax.text(b.get_x()+b.get_width()/2, v+0.003, f'{v:.3f}', ha='center', fontsize=9, fontweight='semibold', color=COLORS['text'])
324
+ ax.set_ylabel('TPR @ 1% FPR', fontsize=12, fontweight='medium'); ax.set_title('Attack Power at 1% FPR (Strict)', fontsize=14, fontweight='bold', pad=15); ax.set_xticks(x); ax.set_xticklabels(labels_all, rotation=35, ha='right', fontsize=11); ax.axhline(0.01, color=COLORS['warning'], ls='--', lw=1.5, alpha=0.7, label='Random (0.01)'); ax.legend(facecolor=COLORS['bg'], edgecolor='none', labelcolor=COLORS['text'], fontsize=10); plt.tight_layout()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
  return fig
326
 
327
  def fig_loss_gap_waterfall():
328
+ fig, ax = plt.subplots(figsize=(14, 6)); apply_light_style(fig, ax); names, gaps, clrs = [], [], []; ls_c = [COLORS['baseline']] + COLORS['ls_colors']
329
+ for i, (k, l) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)): names.append(l); gaps.append(gm(k, 'loss_gap')); clrs.append(ls_c[i])
330
+ for i, (k, l) in enumerate(zip(OP_KEYS, OP_LABELS_PLOT)): names.append(l); gaps.append(gm(k, 'loss_gap')); clrs.append(COLORS['op_colors'][i])
331
+ bars = ax.bar(range(len(names)), gaps, color=clrs, width=0.65, edgecolor='none', zorder=3)
332
+ for b, v in zip(bars, gaps): ax.text(b.get_x()+b.get_width()/2, v+0.0005, f'{v:.4f}', ha='center', fontsize=10, fontweight='semibold', color=COLORS['text'])
333
+ ax.set_ylabel('Loss Gap', fontsize=12, fontweight='medium'); ax.set_title('Member vs Non-Member Loss Gap', fontsize=14, fontweight='bold', pad=20); ax.set_xticks(range(len(names))); ax.set_xticklabels(names, rotation=30, ha='right', fontsize=11); ax.annotate('Smaller gap = Better Privacy', xy=(8, gaps[0]*0.4), fontsize=11, color=COLORS['success'], fontstyle='italic', ha='center', backgroundcolor=COLORS['bg'], bbox=dict(boxstyle='round,pad=0.4', facecolor=COLORS['panel'], edgecolor=COLORS['success'], alpha=0.8)); plt.tight_layout()
 
 
 
 
 
 
 
 
 
334
  return fig
335
 
336
+ # 🌟 效用页面柱状图
337
  def fig_acc_bar():
338
+ names, vals, clrs = [], [], []; ls_c = [COLORS['baseline']] + COLORS['ls_colors']
 
 
339
  for i, (k, l) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)):
340
+ if k in utility_results: names.append(l); vals.append(utility_results[k]['accuracy']*100); clrs.append(ls_c[i])
 
 
341
  for i, (k, l) in enumerate(zip(OP_KEYS, OP_LABELS_PLOT)):
342
+ if k in perturb_results: names.append(l); vals.append(bl_acc); clrs.append(COLORS['op_colors'][i])
343
+ fig, ax = plt.subplots(figsize=(12, 7)); apply_light_style(fig, ax); bars = ax.bar(range(len(names)), vals, color=clrs, width=0.65, edgecolor='none', zorder=3)
344
+ for b, v in zip(bars, vals): ax.text(b.get_x()+b.get_width()/2, v+1, f'{v:.1f}%', ha='center', fontsize=11, fontweight='bold', color=COLORS['text'])
 
 
 
 
 
 
345
  ax.set_ylabel('Test Accuracy (%)', fontsize=12, fontweight='medium'); ax.set_title('Model Utility: Test Accuracy', fontsize=15, fontweight='bold', pad=20)
346
  ax.set_ylim(0, 105); ax.set_xticks(range(len(names))); ax.set_xticklabels(names, rotation=35, ha='right', fontsize=12); plt.tight_layout()
347
  return fig
348
 
349
+ # 🌟 修复:散点图增加透明度防遮挡,并且图例移至左下角空白处
350
  def fig_tradeoff():
351
+ fig, ax = plt.subplots(figsize=(12, 7)); apply_light_style(fig, ax); markers_ls = ['o', 's', 's', 's', 's']; ls_c = [COLORS['baseline']] + COLORS['ls_colors']
 
352
  for i, (k, l) in enumerate(zip(LS_KEYS, LS_LABELS_PLOT)):
353
+ if k in mia_results and k in utility_results: ax.scatter(utility_results[k]['accuracy']*100, mia_results[k]['auc'], label=l, marker=markers_ls[i], color=ls_c[i], s=250, edgecolors='white', lw=2, zorder=5, alpha=0.9)
354
+ op_markers = ['^', 'D', 'v', 'P', 'X', 'h']
 
355
  for i, (k, l) in enumerate(zip(OP_KEYS, OP_LABELS_PLOT)):
356
+ # 增加透明度 alpha=0.75 防止叠放时遮盖
357
+ if k in perturb_results: ax.scatter(bl_acc, perturb_results[k]['auc'], label=l, marker=op_markers[i], color=COLORS['op_colors'][i], s=200, edgecolors='white', lw=1.5, zorder=6, alpha=0.75)
358
 
359
+ ax.axhline(0.5, color=COLORS['text_dim'], ls='--', alpha=0.6, label='Random (AUC=0.5)')
360
+ ax.annotate('IDEAL ZONE\nHigh Utility, Low Risk', xy=(85, 0.51), fontsize=11, fontweight='bold', color=COLORS['success'], alpha=0.7, ha='center', backgroundcolor=COLORS['bg'])
361
+ ax.annotate('HIGH RISK ZONE\nLow Utility, High Risk', xy=(62, 0.61), fontsize=11, fontweight='bold', color=COLORS['danger'], alpha=0.7, ha='center', backgroundcolor=COLORS['bg'])
362
  ax.set_xlabel('Model Utility (Accuracy %)', fontsize=12, fontweight='medium'); ax.set_ylabel('Privacy Risk (MIA AUC)', fontsize=12, fontweight='medium')
363
  ax.set_title('Privacy-Utility Trade-off Analysis', fontsize=15, fontweight='bold', pad=20)
364
+ # 图例放到绝对安全的左下角空白区域
365
+ ax.legend(fontsize=11, loc='lower left', ncol=2, facecolor=COLORS['bg'], edgecolor='none'); plt.tight_layout()
366
  return fig
367
 
368
+ # 🌟 修复:趋势图图例移位,避免遮挡数据线和阴影
369
  def fig_auc_trend():
370
+ fig, axes = plt.subplots(1, 2, figsize=(16, 6.5)); apply_light_style(fig, axes); ax = axes[0]; eps_vals = [0.0, 0.02, 0.05, 0.1, 0.2]; auc_vals = [gm(k, 'auc') for k in LS_KEYS]; acc_vals = [gu(k) for k in LS_KEYS]
371
+ ax2 = ax.twinx(); line1 = ax.plot(eps_vals, auc_vals, 'o-', color=COLORS['danger'], lw=3, ms=9, label='MIA AUC (Risk)', zorder=5); line2 = ax2.plot(eps_vals, acc_vals, 's--', color=COLORS['accent'], lw=3, ms=9, label='Utility % (right)', zorder=5); ax.axhline(0.5, color=COLORS['text_dim'], ls=':', alpha=0.5)
372
+ ax.fill_between(eps_vals, auc_vals, 0.5, alpha=0.08, color=COLORS['danger'])
373
+ ax.set_xlabel('Label Smoothing ε', fontsize=12, fontweight='medium'); ax.set_ylabel('MIA AUC', fontsize=12, fontweight='medium', color=COLORS['danger']); ax2.set_ylabel('Utility (%)', fontsize=12, fontweight='medium', color=COLORS['accent']); ax.set_title('Label Smoothing Trends', fontsize=14, fontweight='bold', pad=15); ax.tick_params(axis='y', labelcolor=COLORS['danger']); ax2.tick_params(axis='y', labelcolor=COLORS['accent']); ax2.spines['right'].set_color(COLORS['accent']); ax2.spines['left'].set_color(COLORS['danger']); lines = line1 + line2; labels = [l.get_label() for l in lines]
374
+ # 图例移至右下角安全区
375
+ ax.legend(lines, labels, fontsize=10, facecolor=COLORS['bg'], edgecolor='none', loc='lower right')
 
 
376
 
377
+ ax = axes[1]; sig_vals = OP_SIGMAS; auc_op = [gm(k, 'auc') for k in OP_KEYS]; ax.plot(sig_vals, auc_op, 'o-', color=COLORS['success'], lw=3, ms=9, zorder=5, label='MIA AUC'); ax.axhline(bl_auc, color=COLORS['danger'], ls='--', lw=2, alpha=0.6, label=f'Baseline ({bl_auc:.4f})'); ax.axhline(0.5, color=COLORS['text_dim'], ls=':', alpha=0.5, label='Random (0.5)'); ax.fill_between(sig_vals, auc_op, bl_auc, alpha=0.2, color=COLORS['success'], label='AUC Reduction')
378
+ ax2r = ax.twinx(); ax2r.axhline(bl_acc, color=COLORS['success'], ls='-', lw=2.5, alpha=0.8); ax2r.set_ylabel(f'Utility = {bl_acc:.1f}% (unchanged)', fontsize=12, fontweight='medium', color=COLORS['success']); ax2r.set_ylim(0,100); ax2r.tick_params(axis='y', labelcolor=COLORS['success']); ax2r.spines['right'].set_color(COLORS['success'])
 
 
 
 
379
  ax.set_xlabel('Perturbation σ', fontsize=12, fontweight='medium'); ax.set_ylabel('MIA AUC', fontsize=12, fontweight='medium'); ax.set_title('Output Perturbation Trends', fontsize=14, fontweight='bold', pad=15)
380
+ # 图例移至左下角安全区
381
+ ax.legend(fontsize=10, facecolor=COLORS['bg'], edgecolor='none', loc='lower left'); plt.tight_layout()
382
  return fig
383
 
384
  # ================================================================
 
408
  [f"标签平滑 (ε={e})" for e in [0.02, 0.05, 0.1, 0.2]] +
409
  [f"输出扰动 (σ={s})" for s in OP_SIGMAS]
410
  )
 
411
  ATK_MAP = {"基线模型 (Baseline)": "baseline"}
412
  for e in [0.02, 0.05, 0.1, 0.2]: ATK_MAP[f"标签平滑 (ε={e})"] = f"smooth_eps_{e}"
413
  for s in OP_SIGMAS: ATK_MAP[f"输出扰动 (σ={s})"] = f"perturbation_{s}"
 
574
  """
575
 
576
  # ================================================================
577
+ # UI 布局构建
578
  # ================================================================
579
+ with gr.Blocks(title="MIA攻防研究", theme=gr.themes.Soft(), css=CSS) as demo:
580
 
581
  gr.HTML("""<div class="title-area">
582
  <h1>🎓 教育大模型中的成员推理攻击及其防御研究</h1>
 
729
  gr.HTML('<div style="margin:40px 0 8px;"><span class="dim-label dim2">维度二</span><strong style="font-size:18px;color:#1D2939;">极限实战维度 — 证明“极低误报下的安全底线”</strong></div>')
730
  gr.Markdown(f"""\
731
  > **实战意义:** 现实中黑客只允许极低的误报(如 1%)。在 Baseline 中,1% 误报率下黑客依然能精准窃取 **{gm('baseline','tpr_at_1fpr')*100:.1f}%** 的真实隐私(红柱子极高)。
732
+ > 开启 OP(σ=0.03) 防御后,该成功率被死死压制到了 **{gm('perturbation_0.03','tpr_at_1fpr')*100:.1f}%**。这证明在最极端的实战条件下,防线依然坚固。
733
  """)
734
  gr.Plot(value=fig_tpr_at_low_fpr())
735
 
 
739
  > **LS 的防御本质:** 随着 ε 增大,两座山峰趋于完美重合,均值差距缩小到了 {gm('smooth_eps_0.2','loss_gap'):.4f}。这是从物理上抹除了模型记忆。
740
  > **OP 的防御本质:** 均值差距未变,但高斯噪声导致分布变得极其扁平宽阔,红蓝区域被完全搅混,蒙蔽了攻击者的双眼。
741
  """)
742
+ # 🌟 调用专门添加的三联横向对比直方图
743
  gr.Plot(value=fig_d3_dist_compare())
744
  gr.Plot(value=fig_loss_gap_waterfall())
745
 
 
856
 
857
  """)
858
 
859
+ demo.launch(theme=gr.themes.Soft(), css=CSS)