Add complete research_paper.py implementation (1713 lines, 65KB)
Browse files- research_paper.py +1713 -0
research_paper.py
ADDED
|
@@ -0,0 +1,1713 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
UFUSC: Unified Federated Unlearning via Sensitivity-Guided Contrastive Forgetting
|
| 4 |
+
|
| 5 |
+
A complete self-contained implementation for the research paper:
|
| 6 |
+
"Sensitivity-Guided Contrastive Forgetting: Unified Label and Feature Unlearning
|
| 7 |
+
in Vertical Federated Learning"
|
| 8 |
+
|
| 9 |
+
This script includes:
|
| 10 |
+
- VFL architecture (PassiveModel, ActiveModel, VFLFramework)
|
| 11 |
+
- 5 baselines (GradientAscent, Finetune, FisherForgetting, ManifoldMixup, Ferrari)
|
| 12 |
+
- UFUSC with 3 variants (Label Only, Feature Only, Joint)
|
| 13 |
+
- MIA attack evaluation
|
| 14 |
+
- Dataset loaders for MNIST, Fashion-MNIST, CIFAR-10
|
| 15 |
+
- Ablation study runner
|
| 16 |
+
- Scalability analysis across K=2,3,4,6 passive parties
|
| 17 |
+
- Visualization code (bar charts, radar plots, ablation plots, scalability plots)
|
| 18 |
+
|
| 19 |
+
Usage:
|
| 20 |
+
pip install torch torchvision numpy matplotlib seaborn pandas scikit-learn
|
| 21 |
+
python research_paper.py
|
| 22 |
+
|
| 23 |
+
Author: UFUSC Research Team
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import os
|
| 27 |
+
import json
|
| 28 |
+
import time
|
| 29 |
+
import copy
|
| 30 |
+
import random
|
| 31 |
+
import warnings
|
| 32 |
+
from collections import defaultdict
|
| 33 |
+
|
| 34 |
+
import numpy as np
|
| 35 |
+
import torch
|
| 36 |
+
import torch.nn as nn
|
| 37 |
+
import torch.nn.functional as F
|
| 38 |
+
import torch.optim as optim
|
| 39 |
+
from torch.utils.data import DataLoader, TensorDataset, Subset
|
| 40 |
+
import torchvision
|
| 41 |
+
import torchvision.transforms as transforms
|
| 42 |
+
from sklearn.metrics import accuracy_score, roc_auc_score
|
| 43 |
+
|
| 44 |
+
warnings.filterwarnings("ignore")
|
| 45 |
+
|
| 46 |
+
# ============================================================================
|
| 47 |
+
# Configuration
|
| 48 |
+
# ============================================================================
|
| 49 |
+
|
| 50 |
+
SEED = 42
|
| 51 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 52 |
+
NUM_PASSIVE_PARTIES = 2 # Default K=2 for VFL
|
| 53 |
+
BATCH_SIZE = 256
|
| 54 |
+
TRAIN_EPOCHS = 20
|
| 55 |
+
UNLEARN_EPOCHS = 10
|
| 56 |
+
LR = 0.001
|
| 57 |
+
FORGET_RATIO = 0.1 # Fraction of data to forget (specific class)
|
| 58 |
+
|
| 59 |
+
# UFUSC hyperparameters
|
| 60 |
+
ALPHA = 1.0 # Contrastive Forgetting Loss weight
|
| 61 |
+
BETA = 0.5 # Feature Sensitivity Loss weight
|
| 62 |
+
GAMMA = 0.3 # Anchor Loss weight
|
| 63 |
+
OMEGA = 0.1 # Dual variable / certification constraint weight
|
| 64 |
+
TAU = 2.0 # Forgetting threshold for certification
|
| 65 |
+
SENSITIVITY_SIGMA = 0.01 # Perturbation std for feature sensitivity
|
| 66 |
+
SENSITIVITY_SAMPLES = 5 # MC samples for sensitivity estimation
|
| 67 |
+
|
| 68 |
+
# Output directories
|
| 69 |
+
os.makedirs("results", exist_ok=True)
|
| 70 |
+
os.makedirs("figures", exist_ok=True)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def set_seed(seed=SEED):
|
| 74 |
+
"""Set all random seeds for reproducibility."""
|
| 75 |
+
random.seed(seed)
|
| 76 |
+
np.random.seed(seed)
|
| 77 |
+
torch.manual_seed(seed)
|
| 78 |
+
if torch.cuda.is_available():
|
| 79 |
+
torch.cuda.manual_seed_all(seed)
|
| 80 |
+
torch.backends.cudnn.deterministic = True
|
| 81 |
+
torch.backends.cudnn.benchmark = False
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# ============================================================================
|
| 85 |
+
# Dataset Loaders
|
| 86 |
+
# ============================================================================
|
| 87 |
+
|
| 88 |
+
def load_dataset(name="MNIST"):
|
| 89 |
+
"""
|
| 90 |
+
Load and preprocess a dataset. Returns flattened feature vectors for VFL.
|
| 91 |
+
|
| 92 |
+
In VFL, each passive party holds a vertical partition of the features.
|
| 93 |
+
We flatten images and split feature columns across K parties.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
name: One of "MNIST", "Fashion-MNIST", "CIFAR-10"
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
(X_train, y_train, X_test, y_test, num_classes, feature_dim)
|
| 100 |
+
"""
|
| 101 |
+
data_dir = "./data"
|
| 102 |
+
|
| 103 |
+
if name == "MNIST":
|
| 104 |
+
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
|
| 105 |
+
train_ds = torchvision.datasets.MNIST(data_dir, train=True, download=True, transform=transform)
|
| 106 |
+
test_ds = torchvision.datasets.MNIST(data_dir, train=False, download=True, transform=transform)
|
| 107 |
+
num_classes = 10
|
| 108 |
+
elif name == "Fashion-MNIST":
|
| 109 |
+
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.2860,), (0.3530,))])
|
| 110 |
+
train_ds = torchvision.datasets.FashionMNIST(data_dir, train=True, download=True, transform=transform)
|
| 111 |
+
test_ds = torchvision.datasets.FashionMNIST(data_dir, train=False, download=True, transform=transform)
|
| 112 |
+
num_classes = 10
|
| 113 |
+
elif name == "CIFAR-10":
|
| 114 |
+
transform = transforms.Compose([
|
| 115 |
+
transforms.ToTensor(),
|
| 116 |
+
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
|
| 117 |
+
])
|
| 118 |
+
train_ds = torchvision.datasets.CIFAR10(data_dir, train=True, download=True, transform=transform)
|
| 119 |
+
test_ds = torchvision.datasets.CIFAR10(data_dir, train=False, download=True, transform=transform)
|
| 120 |
+
num_classes = 10
|
| 121 |
+
else:
|
| 122 |
+
raise ValueError(f"Unknown dataset: {name}")
|
| 123 |
+
|
| 124 |
+
# Extract and flatten
|
| 125 |
+
X_train = torch.stack([train_ds[i][0] for i in range(len(train_ds))]).view(len(train_ds), -1)
|
| 126 |
+
y_train = torch.tensor([train_ds[i][1] for i in range(len(train_ds))])
|
| 127 |
+
X_test = torch.stack([test_ds[i][0] for i in range(len(test_ds))]).view(len(test_ds), -1)
|
| 128 |
+
y_test = torch.tensor([test_ds[i][1] for i in range(len(test_ds))])
|
| 129 |
+
|
| 130 |
+
feature_dim = X_train.shape[1]
|
| 131 |
+
print(f" [{name}] Train: {X_train.shape}, Test: {X_test.shape}, Classes: {num_classes}, Features: {feature_dim}")
|
| 132 |
+
|
| 133 |
+
return X_train, y_train, X_test, y_test, num_classes, feature_dim
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def split_features_vfl(X, num_parties=NUM_PASSIVE_PARTIES):
|
| 137 |
+
"""
|
| 138 |
+
Split feature columns across K passive parties for VFL.
|
| 139 |
+
|
| 140 |
+
Each party gets a disjoint subset of columns (vertical partition).
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
X: (N, D) tensor of flattened features
|
| 144 |
+
num_parties: number of passive parties K
|
| 145 |
+
|
| 146 |
+
Returns:
|
| 147 |
+
List of K tensors, each (N, D/K) approximately
|
| 148 |
+
"""
|
| 149 |
+
D = X.shape[1]
|
| 150 |
+
split_sizes = [D // num_parties] * num_parties
|
| 151 |
+
# Distribute remainder
|
| 152 |
+
for i in range(D % num_parties):
|
| 153 |
+
split_sizes[i] += 1
|
| 154 |
+
return torch.split(X, split_sizes, dim=1)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def create_forget_retain_split(y, forget_class=0, forget_ratio=FORGET_RATIO):
|
| 158 |
+
"""
|
| 159 |
+
Create forget/retain index split.
|
| 160 |
+
|
| 161 |
+
Selects a fraction of samples from the target class as the forget set.
|
| 162 |
+
All other samples form the retain set.
|
| 163 |
+
|
| 164 |
+
Args:
|
| 165 |
+
y: label tensor
|
| 166 |
+
forget_class: which class to partially forget
|
| 167 |
+
forget_ratio: fraction of that class to forget
|
| 168 |
+
|
| 169 |
+
Returns:
|
| 170 |
+
(forget_indices, retain_indices)
|
| 171 |
+
"""
|
| 172 |
+
class_indices = (y == forget_class).nonzero(as_tuple=True)[0]
|
| 173 |
+
num_forget = max(1, int(len(class_indices) * forget_ratio))
|
| 174 |
+
|
| 175 |
+
perm = torch.randperm(len(class_indices))
|
| 176 |
+
forget_indices = class_indices[perm[:num_forget]]
|
| 177 |
+
|
| 178 |
+
all_indices = torch.arange(len(y))
|
| 179 |
+
mask = torch.ones(len(y), dtype=torch.bool)
|
| 180 |
+
mask[forget_indices] = False
|
| 181 |
+
retain_indices = all_indices[mask]
|
| 182 |
+
|
| 183 |
+
return forget_indices, retain_indices
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# ============================================================================
|
| 187 |
+
# VFL Architecture
|
| 188 |
+
# ============================================================================
|
| 189 |
+
|
| 190 |
+
class PassiveModel(nn.Module):
|
| 191 |
+
"""
|
| 192 |
+
Passive party model in VFL.
|
| 193 |
+
|
| 194 |
+
Each passive party holds a vertical partition of features and computes
|
| 195 |
+
a local embedding (forward representation) that is sent to the active party.
|
| 196 |
+
|
| 197 |
+
Architecture: 2-layer MLP with ReLU and BatchNorm.
|
| 198 |
+
"""
|
| 199 |
+
|
| 200 |
+
def __init__(self, input_dim, embed_dim=64):
|
| 201 |
+
super().__init__()
|
| 202 |
+
hidden_dim = max(128, input_dim // 2)
|
| 203 |
+
self.net = nn.Sequential(
|
| 204 |
+
nn.Linear(input_dim, hidden_dim),
|
| 205 |
+
nn.BatchNorm1d(hidden_dim),
|
| 206 |
+
nn.ReLU(),
|
| 207 |
+
nn.Dropout(0.2),
|
| 208 |
+
nn.Linear(hidden_dim, embed_dim),
|
| 209 |
+
nn.BatchNorm1d(embed_dim),
|
| 210 |
+
nn.ReLU()
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
def forward(self, x):
|
| 214 |
+
return self.net(x)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class ActiveModel(nn.Module):
|
| 218 |
+
"""
|
| 219 |
+
Active party model in VFL.
|
| 220 |
+
|
| 221 |
+
The active party holds the labels and receives concatenated embeddings
|
| 222 |
+
from all passive parties. It performs final classification.
|
| 223 |
+
|
| 224 |
+
Architecture: 2-layer MLP with ReLU, Dropout, and softmax output.
|
| 225 |
+
"""
|
| 226 |
+
|
| 227 |
+
def __init__(self, total_embed_dim, num_classes=10):
|
| 228 |
+
super().__init__()
|
| 229 |
+
self.net = nn.Sequential(
|
| 230 |
+
nn.Linear(total_embed_dim, 128),
|
| 231 |
+
nn.BatchNorm1d(128),
|
| 232 |
+
nn.ReLU(),
|
| 233 |
+
nn.Dropout(0.3),
|
| 234 |
+
nn.Linear(128, 64),
|
| 235 |
+
nn.ReLU(),
|
| 236 |
+
nn.Linear(64, num_classes)
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
def forward(self, x):
|
| 240 |
+
return self.net(x)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class VFLFramework:
|
| 244 |
+
"""
|
| 245 |
+
Vertical Federated Learning framework.
|
| 246 |
+
|
| 247 |
+
Manages K passive parties and 1 active party. Each passive party
|
| 248 |
+
computes embeddings from their feature partition, which are concatenated
|
| 249 |
+
and fed to the active party for classification.
|
| 250 |
+
|
| 251 |
+
The active party holds labels and orchestrates training.
|
| 252 |
+
"""
|
| 253 |
+
|
| 254 |
+
def __init__(self, feature_dims, num_classes=10, embed_dim=64,
|
| 255 |
+
num_parties=NUM_PASSIVE_PARTIES, lr=LR):
|
| 256 |
+
"""
|
| 257 |
+
Args:
|
| 258 |
+
feature_dims: list of input dimensions for each passive party
|
| 259 |
+
num_classes: number of output classes
|
| 260 |
+
embed_dim: embedding dimension per passive party
|
| 261 |
+
num_parties: number of passive parties K
|
| 262 |
+
lr: learning rate
|
| 263 |
+
"""
|
| 264 |
+
self.num_parties = num_parties
|
| 265 |
+
self.embed_dim = embed_dim
|
| 266 |
+
self.num_classes = num_classes
|
| 267 |
+
|
| 268 |
+
# Create passive models
|
| 269 |
+
self.passive_models = []
|
| 270 |
+
for i in range(num_parties):
|
| 271 |
+
model = PassiveModel(feature_dims[i], embed_dim).to(DEVICE)
|
| 272 |
+
self.passive_models.append(model)
|
| 273 |
+
|
| 274 |
+
# Create active model
|
| 275 |
+
total_embed = embed_dim * num_parties
|
| 276 |
+
self.active_model = ActiveModel(total_embed, num_classes).to(DEVICE)
|
| 277 |
+
|
| 278 |
+
# Optimizers
|
| 279 |
+
all_params = []
|
| 280 |
+
for pm in self.passive_models:
|
| 281 |
+
all_params += list(pm.parameters())
|
| 282 |
+
all_params += list(self.active_model.parameters())
|
| 283 |
+
self.optimizer = optim.Adam(all_params, lr=lr)
|
| 284 |
+
self.criterion = nn.CrossEntropyLoss()
|
| 285 |
+
|
| 286 |
+
def get_embeddings(self, X_splits):
|
| 287 |
+
"""Compute embeddings from all passive parties and concatenate."""
|
| 288 |
+
embeddings = []
|
| 289 |
+
for i, pm in enumerate(self.passive_models):
|
| 290 |
+
emb = pm(X_splits[i].to(DEVICE))
|
| 291 |
+
embeddings.append(emb)
|
| 292 |
+
return torch.cat(embeddings, dim=1)
|
| 293 |
+
|
| 294 |
+
def forward(self, X_splits):
|
| 295 |
+
"""Full forward pass through VFL."""
|
| 296 |
+
combined = self.get_embeddings(X_splits)
|
| 297 |
+
logits = self.active_model(combined)
|
| 298 |
+
return logits, combined
|
| 299 |
+
|
| 300 |
+
def train_model(self, X_train_splits, y_train, X_test_splits, y_test,
|
| 301 |
+
epochs=TRAIN_EPOCHS, verbose=True):
|
| 302 |
+
"""
|
| 303 |
+
Train the VFL model end-to-end.
|
| 304 |
+
|
| 305 |
+
Args:
|
| 306 |
+
X_train_splits: list of K tensors (one per passive party)
|
| 307 |
+
y_train: training labels
|
| 308 |
+
X_test_splits: list of K test tensors
|
| 309 |
+
y_test: test labels
|
| 310 |
+
epochs: number of training epochs
|
| 311 |
+
verbose: print progress
|
| 312 |
+
"""
|
| 313 |
+
dataset = TensorDataset(*X_train_splits, y_train)
|
| 314 |
+
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)
|
| 315 |
+
|
| 316 |
+
self.set_train()
|
| 317 |
+
|
| 318 |
+
for epoch in range(epochs):
|
| 319 |
+
total_loss = 0
|
| 320 |
+
correct = 0
|
| 321 |
+
total = 0
|
| 322 |
+
|
| 323 |
+
for batch in loader:
|
| 324 |
+
*batch_splits, batch_y = batch
|
| 325 |
+
batch_y = batch_y.to(DEVICE)
|
| 326 |
+
|
| 327 |
+
logits, _ = self.forward(batch_splits)
|
| 328 |
+
loss = self.criterion(logits, batch_y)
|
| 329 |
+
|
| 330 |
+
self.optimizer.zero_grad()
|
| 331 |
+
loss.backward()
|
| 332 |
+
self.optimizer.step()
|
| 333 |
+
|
| 334 |
+
total_loss += loss.item() * batch_y.size(0)
|
| 335 |
+
preds = logits.argmax(dim=1)
|
| 336 |
+
correct += (preds == batch_y).sum().item()
|
| 337 |
+
total += batch_y.size(0)
|
| 338 |
+
|
| 339 |
+
if verbose and (epoch + 1) % 5 == 0:
|
| 340 |
+
train_acc = correct / total * 100
|
| 341 |
+
test_acc = self.evaluate(X_test_splits, y_test)
|
| 342 |
+
print(f" Epoch {epoch+1}/{epochs} — Loss: {total_loss/total:.4f}, "
|
| 343 |
+
f"Train Acc: {train_acc:.2f}%, Test Acc: {test_acc:.2f}%")
|
| 344 |
+
|
| 345 |
+
def evaluate(self, X_splits, y, batch_size=512):
|
| 346 |
+
"""Evaluate accuracy on given data."""
|
| 347 |
+
self.set_eval()
|
| 348 |
+
dataset = TensorDataset(*X_splits, y)
|
| 349 |
+
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
| 350 |
+
|
| 351 |
+
correct = 0
|
| 352 |
+
total = 0
|
| 353 |
+
|
| 354 |
+
with torch.no_grad():
|
| 355 |
+
for batch in loader:
|
| 356 |
+
*batch_splits, batch_y = batch
|
| 357 |
+
batch_y = batch_y.to(DEVICE)
|
| 358 |
+
logits, _ = self.forward(batch_splits)
|
| 359 |
+
preds = logits.argmax(dim=1)
|
| 360 |
+
correct += (preds == batch_y).sum().item()
|
| 361 |
+
total += batch_y.size(0)
|
| 362 |
+
|
| 363 |
+
self.set_train()
|
| 364 |
+
return correct / total * 100
|
| 365 |
+
|
| 366 |
+
def predict_proba(self, X_splits, batch_size=512):
|
| 367 |
+
"""Get prediction probabilities."""
|
| 368 |
+
self.set_eval()
|
| 369 |
+
dataset = TensorDataset(*X_splits)
|
| 370 |
+
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
| 371 |
+
|
| 372 |
+
all_probs = []
|
| 373 |
+
with torch.no_grad():
|
| 374 |
+
for batch in loader:
|
| 375 |
+
logits, _ = self.forward(list(batch))
|
| 376 |
+
probs = F.softmax(logits, dim=1)
|
| 377 |
+
all_probs.append(probs.cpu())
|
| 378 |
+
|
| 379 |
+
self.set_train()
|
| 380 |
+
return torch.cat(all_probs, dim=0)
|
| 381 |
+
|
| 382 |
+
def set_train(self):
|
| 383 |
+
for pm in self.passive_models:
|
| 384 |
+
pm.train()
|
| 385 |
+
self.active_model.train()
|
| 386 |
+
|
| 387 |
+
def set_eval(self):
|
| 388 |
+
for pm in self.passive_models:
|
| 389 |
+
pm.eval()
|
| 390 |
+
self.active_model.eval()
|
| 391 |
+
|
| 392 |
+
def clone(self):
|
| 393 |
+
"""Deep copy the entire VFL framework."""
|
| 394 |
+
cloned = VFLFramework.__new__(VFLFramework)
|
| 395 |
+
cloned.num_parties = self.num_parties
|
| 396 |
+
cloned.embed_dim = self.embed_dim
|
| 397 |
+
cloned.num_classes = self.num_classes
|
| 398 |
+
cloned.passive_models = [copy.deepcopy(pm) for pm in self.passive_models]
|
| 399 |
+
cloned.active_model = copy.deepcopy(self.active_model)
|
| 400 |
+
cloned.criterion = nn.CrossEntropyLoss()
|
| 401 |
+
|
| 402 |
+
all_params = []
|
| 403 |
+
for pm in cloned.passive_models:
|
| 404 |
+
all_params += list(pm.parameters())
|
| 405 |
+
all_params += list(cloned.active_model.parameters())
|
| 406 |
+
cloned.optimizer = optim.Adam(all_params, lr=LR)
|
| 407 |
+
|
| 408 |
+
return cloned
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
# ============================================================================
|
| 412 |
+
# Evaluation Metrics
|
| 413 |
+
# ============================================================================
|
| 414 |
+
|
| 415 |
+
def membership_inference_attack(model, X_train_splits, y_train, X_test_splits, y_test,
|
| 416 |
+
forget_indices, retain_indices):
|
| 417 |
+
"""
|
| 418 |
+
Simple Membership Inference Attack (MIA).
|
| 419 |
+
|
| 420 |
+
Uses prediction confidence as a signal: members tend to have higher
|
| 421 |
+
confidence on the correct class. We compute the attack success rate (ASR)
|
| 422 |
+
on forget set members vs non-members.
|
| 423 |
+
|
| 424 |
+
Lower ASR after unlearning → better privacy (model doesn't distinguish
|
| 425 |
+
members from non-members).
|
| 426 |
+
|
| 427 |
+
Args:
|
| 428 |
+
model: VFLFramework
|
| 429 |
+
X_train_splits: training feature splits
|
| 430 |
+
y_train: training labels
|
| 431 |
+
X_test_splits: test feature splits
|
| 432 |
+
y_test: test labels
|
| 433 |
+
forget_indices: indices of forget set in training data
|
| 434 |
+
retain_indices: indices of retain set in training data
|
| 435 |
+
|
| 436 |
+
Returns:
|
| 437 |
+
mia_asr: attack success rate (%)
|
| 438 |
+
"""
|
| 439 |
+
model.set_eval()
|
| 440 |
+
|
| 441 |
+
# Member (forget set) confidences
|
| 442 |
+
forget_splits = [xs[forget_indices] for xs in X_train_splits]
|
| 443 |
+
forget_labels = y_train[forget_indices]
|
| 444 |
+
member_probs = model.predict_proba(forget_splits)
|
| 445 |
+
member_conf = member_probs[torch.arange(len(forget_labels)), forget_labels].numpy()
|
| 446 |
+
|
| 447 |
+
# Non-member (test set, same class) confidences
|
| 448 |
+
forget_class = forget_labels[0].item()
|
| 449 |
+
test_class_mask = y_test == forget_class
|
| 450 |
+
if test_class_mask.sum() == 0:
|
| 451 |
+
return 50.0 # Cannot evaluate
|
| 452 |
+
|
| 453 |
+
test_class_splits = [xs[test_class_mask] for xs in X_test_splits]
|
| 454 |
+
test_class_labels = y_test[test_class_mask]
|
| 455 |
+
nonmember_probs = model.predict_proba(test_class_splits)
|
| 456 |
+
nonmember_conf = nonmember_probs[torch.arange(len(test_class_labels)), test_class_labels].numpy()
|
| 457 |
+
|
| 458 |
+
# Threshold-based attack: predict member if confidence > threshold
|
| 459 |
+
# Use median of combined as threshold
|
| 460 |
+
all_conf = np.concatenate([member_conf, nonmember_conf])
|
| 461 |
+
threshold = np.median(all_conf)
|
| 462 |
+
|
| 463 |
+
member_pred = (member_conf > threshold).astype(float)
|
| 464 |
+
nonmember_pred = (nonmember_conf <= threshold).astype(float)
|
| 465 |
+
|
| 466 |
+
# ASR = average of TPR (correctly predicting members) and TNR (correctly predicting non-members)
|
| 467 |
+
tpr = member_pred.mean()
|
| 468 |
+
tnr = nonmember_pred.mean()
|
| 469 |
+
mia_asr = (tpr + tnr) / 2 * 100
|
| 470 |
+
|
| 471 |
+
model.set_train()
|
| 472 |
+
return mia_asr
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def compute_feature_sensitivity(model, X_splits, sigma=SENSITIVITY_SIGMA,
|
| 476 |
+
n_samples=SENSITIVITY_SAMPLES):
|
| 477 |
+
"""
|
| 478 |
+
Compute Lipschitz-based feature sensitivity via Monte Carlo perturbation.
|
| 479 |
+
|
| 480 |
+
Measures how much the model's output changes when input features are
|
| 481 |
+
perturbed by Gaussian noise. Lower sensitivity after unlearning means
|
| 482 |
+
the model is less responsive to the target features.
|
| 483 |
+
|
| 484 |
+
Based on Ferrari (arxiv:2405.17462) Section 4.
|
| 485 |
+
|
| 486 |
+
Args:
|
| 487 |
+
model: VFLFramework
|
| 488 |
+
X_splits: feature splits to perturb
|
| 489 |
+
sigma: std of Gaussian perturbation
|
| 490 |
+
n_samples: number of MC samples
|
| 491 |
+
|
| 492 |
+
Returns:
|
| 493 |
+
mean_sensitivity: average sensitivity across samples and parties
|
| 494 |
+
"""
|
| 495 |
+
model.set_eval()
|
| 496 |
+
sensitivities = []
|
| 497 |
+
|
| 498 |
+
# Sample a subset for efficiency
|
| 499 |
+
n = min(500, X_splits[0].shape[0])
|
| 500 |
+
subset_splits = [xs[:n] for xs in X_splits]
|
| 501 |
+
|
| 502 |
+
with torch.no_grad():
|
| 503 |
+
# Original output
|
| 504 |
+
logits_orig, _ = model.forward(subset_splits)
|
| 505 |
+
probs_orig = F.softmax(logits_orig, dim=1)
|
| 506 |
+
|
| 507 |
+
for _ in range(n_samples):
|
| 508 |
+
for party_idx in range(len(subset_splits)):
|
| 509 |
+
perturbed_splits = [xs.clone() for xs in subset_splits]
|
| 510 |
+
noise = torch.randn_like(perturbed_splits[party_idx]) * sigma
|
| 511 |
+
perturbed_splits[party_idx] = perturbed_splits[party_idx] + noise
|
| 512 |
+
|
| 513 |
+
logits_pert, _ = model.forward(perturbed_splits)
|
| 514 |
+
probs_pert = F.softmax(logits_pert, dim=1)
|
| 515 |
+
|
| 516 |
+
# L2 distance in probability space
|
| 517 |
+
diff = (probs_orig - probs_pert).norm(dim=1).mean().item()
|
| 518 |
+
sensitivities.append(diff)
|
| 519 |
+
|
| 520 |
+
model.set_train()
|
| 521 |
+
return np.mean(sensitivities) if sensitivities else 0.0
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
def full_evaluation(model, X_train_splits, y_train, X_test_splits, y_test,
|
| 525 |
+
forget_indices, retain_indices, forget_class=0):
|
| 526 |
+
"""
|
| 527 |
+
Run full evaluation suite: test accuracy, forget accuracy, retain accuracy,
|
| 528 |
+
MIA ASR, and feature sensitivity.
|
| 529 |
+
"""
|
| 530 |
+
# Test accuracy
|
| 531 |
+
test_acc = model.evaluate(X_test_splits, y_test)
|
| 532 |
+
|
| 533 |
+
# Forget set accuracy (should be LOW after good unlearning)
|
| 534 |
+
forget_splits = [xs[forget_indices] for xs in X_train_splits]
|
| 535 |
+
forget_labels = y_train[forget_indices]
|
| 536 |
+
forget_acc = model.evaluate(forget_splits, forget_labels)
|
| 537 |
+
|
| 538 |
+
# Retain set accuracy (should stay HIGH)
|
| 539 |
+
retain_splits = [xs[retain_indices] for xs in X_train_splits]
|
| 540 |
+
retain_labels = y_train[retain_indices]
|
| 541 |
+
retain_acc = model.evaluate(retain_splits, retain_labels)
|
| 542 |
+
|
| 543 |
+
# MIA attack success rate (should be LOW, close to 50% = random)
|
| 544 |
+
mia_asr = membership_inference_attack(
|
| 545 |
+
model, X_train_splits, y_train, X_test_splits, y_test,
|
| 546 |
+
forget_indices, retain_indices
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
# Feature sensitivity
|
| 550 |
+
feat_sens = compute_feature_sensitivity(model, forget_splits)
|
| 551 |
+
|
| 552 |
+
return {
|
| 553 |
+
"test_acc": round(test_acc, 2),
|
| 554 |
+
"forget_acc": round(forget_acc, 2),
|
| 555 |
+
"retain_acc": round(retain_acc, 2),
|
| 556 |
+
"mia_asr": round(mia_asr, 1),
|
| 557 |
+
"feature_sensitivity": round(feat_sens, 3)
|
| 558 |
+
}
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
# ============================================================================
|
| 562 |
+
# Baseline Unlearning Methods
|
| 563 |
+
# ============================================================================
|
| 564 |
+
|
| 565 |
+
class GradientAscentUnlearning:
|
| 566 |
+
"""
|
| 567 |
+
Baseline 1: Gradient Ascent
|
| 568 |
+
|
| 569 |
+
Maximizes the loss on the forget set to push the model away from
|
| 570 |
+
correctly classifying forgotten samples. Simple but can cause
|
| 571 |
+
catastrophic degradation of retain set performance.
|
| 572 |
+
|
| 573 |
+
Reference: Graves et al. (2020), Thudi et al. (2022)
|
| 574 |
+
"""
|
| 575 |
+
|
| 576 |
+
def __init__(self, epochs=5, lr=0.01):
|
| 577 |
+
self.epochs = epochs
|
| 578 |
+
self.lr = lr
|
| 579 |
+
|
| 580 |
+
def unlearn(self, model, X_train_splits, y_train, forget_indices, retain_indices):
|
| 581 |
+
unlearned = model.clone()
|
| 582 |
+
forget_splits = [xs[forget_indices] for xs in X_train_splits]
|
| 583 |
+
forget_labels = y_train[forget_indices]
|
| 584 |
+
|
| 585 |
+
dataset = TensorDataset(*forget_splits, forget_labels)
|
| 586 |
+
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
| 587 |
+
|
| 588 |
+
# Use separate optimizer with potentially different LR
|
| 589 |
+
all_params = []
|
| 590 |
+
for pm in unlearned.passive_models:
|
| 591 |
+
all_params += list(pm.parameters())
|
| 592 |
+
all_params += list(unlearned.active_model.parameters())
|
| 593 |
+
optimizer = optim.SGD(all_params, lr=self.lr)
|
| 594 |
+
|
| 595 |
+
unlearned.set_train()
|
| 596 |
+
for epoch in range(self.epochs):
|
| 597 |
+
for batch in loader:
|
| 598 |
+
*batch_splits, batch_y = batch
|
| 599 |
+
batch_y = batch_y.to(DEVICE)
|
| 600 |
+
|
| 601 |
+
logits, _ = unlearned.forward(batch_splits)
|
| 602 |
+
loss = unlearned.criterion(logits, batch_y)
|
| 603 |
+
|
| 604 |
+
optimizer.zero_grad()
|
| 605 |
+
# ASCENT: negate gradient
|
| 606 |
+
(-loss).backward()
|
| 607 |
+
optimizer.step()
|
| 608 |
+
|
| 609 |
+
return unlearned
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
class FineTuneUnlearning:
|
| 613 |
+
"""
|
| 614 |
+
Baseline 2: Fine-tuning on Retain Set
|
| 615 |
+
|
| 616 |
+
Simply fine-tunes the model on only the retain set, hoping the model
|
| 617 |
+
will "forget" the unlearned data. Often insufficient as the model
|
| 618 |
+
retains significant information about the forget set.
|
| 619 |
+
|
| 620 |
+
Reference: Standard baseline in unlearning literature
|
| 621 |
+
"""
|
| 622 |
+
|
| 623 |
+
def __init__(self, epochs=10, lr=0.001):
|
| 624 |
+
self.epochs = epochs
|
| 625 |
+
self.lr = lr
|
| 626 |
+
|
| 627 |
+
def unlearn(self, model, X_train_splits, y_train, forget_indices, retain_indices):
|
| 628 |
+
unlearned = model.clone()
|
| 629 |
+
retain_splits = [xs[retain_indices] for xs in X_train_splits]
|
| 630 |
+
retain_labels = y_train[retain_indices]
|
| 631 |
+
|
| 632 |
+
dataset = TensorDataset(*retain_splits, retain_labels)
|
| 633 |
+
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
| 634 |
+
|
| 635 |
+
all_params = []
|
| 636 |
+
for pm in unlearned.passive_models:
|
| 637 |
+
all_params += list(pm.parameters())
|
| 638 |
+
all_params += list(unlearned.active_model.parameters())
|
| 639 |
+
optimizer = optim.Adam(all_params, lr=self.lr)
|
| 640 |
+
|
| 641 |
+
unlearned.set_train()
|
| 642 |
+
for epoch in range(self.epochs):
|
| 643 |
+
for batch in loader:
|
| 644 |
+
*batch_splits, batch_y = batch
|
| 645 |
+
batch_y = batch_y.to(DEVICE)
|
| 646 |
+
|
| 647 |
+
logits, _ = unlearned.forward(batch_splits)
|
| 648 |
+
loss = unlearned.criterion(logits, batch_y)
|
| 649 |
+
|
| 650 |
+
optimizer.zero_grad()
|
| 651 |
+
loss.backward()
|
| 652 |
+
optimizer.step()
|
| 653 |
+
|
| 654 |
+
return unlearned
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
class FisherForgetting:
|
| 658 |
+
"""
|
| 659 |
+
Baseline 3: Fisher Forgetting
|
| 660 |
+
|
| 661 |
+
Uses the Fisher Information Matrix to identify which parameters are
|
| 662 |
+
most important for the forget set, then adds noise proportional to
|
| 663 |
+
the inverse Fisher to those parameters. This selectively "erases"
|
| 664 |
+
information about the forget set.
|
| 665 |
+
|
| 666 |
+
Reference: Golatkar et al. (2020) "Eternal Sunshine of the Spotless Net"
|
| 667 |
+
"""
|
| 668 |
+
|
| 669 |
+
def __init__(self, noise_scale=0.01):
|
| 670 |
+
self.noise_scale = noise_scale
|
| 671 |
+
|
| 672 |
+
def unlearn(self, model, X_train_splits, y_train, forget_indices, retain_indices):
|
| 673 |
+
unlearned = model.clone()
|
| 674 |
+
|
| 675 |
+
forget_splits = [xs[forget_indices] for xs in X_train_splits]
|
| 676 |
+
forget_labels = y_train[forget_indices]
|
| 677 |
+
|
| 678 |
+
# Compute Fisher diagonal on forget set
|
| 679 |
+
unlearned.set_train()
|
| 680 |
+
fisher_diag = {}
|
| 681 |
+
for name, param in self._get_all_params(unlearned):
|
| 682 |
+
fisher_diag[name] = torch.zeros_like(param.data)
|
| 683 |
+
|
| 684 |
+
dataset = TensorDataset(*forget_splits, forget_labels)
|
| 685 |
+
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
|
| 686 |
+
|
| 687 |
+
for batch in loader:
|
| 688 |
+
*batch_splits, batch_y = batch
|
| 689 |
+
batch_y = batch_y.to(DEVICE)
|
| 690 |
+
|
| 691 |
+
logits, _ = unlearned.forward(batch_splits)
|
| 692 |
+
loss = unlearned.criterion(logits, batch_y)
|
| 693 |
+
|
| 694 |
+
unlearned.optimizer.zero_grad()
|
| 695 |
+
loss.backward()
|
| 696 |
+
|
| 697 |
+
for name, param in self._get_all_params(unlearned):
|
| 698 |
+
if param.grad is not None:
|
| 699 |
+
fisher_diag[name] += param.grad.data ** 2
|
| 700 |
+
|
| 701 |
+
# Normalize
|
| 702 |
+
n_batches = len(loader)
|
| 703 |
+
for name in fisher_diag:
|
| 704 |
+
fisher_diag[name] /= max(n_batches, 1)
|
| 705 |
+
|
| 706 |
+
# Add noise proportional to Fisher
|
| 707 |
+
with torch.no_grad():
|
| 708 |
+
for name, param in self._get_all_params(unlearned):
|
| 709 |
+
noise_std = self.noise_scale * (fisher_diag[name] + 1e-8).sqrt()
|
| 710 |
+
param.data += torch.randn_like(param.data) * noise_std
|
| 711 |
+
|
| 712 |
+
return unlearned
|
| 713 |
+
|
| 714 |
+
def _get_all_params(self, model):
|
| 715 |
+
"""Get all named parameters from VFL framework."""
|
| 716 |
+
params = []
|
| 717 |
+
for i, pm in enumerate(model.passive_models):
|
| 718 |
+
for name, param in pm.named_parameters():
|
| 719 |
+
params.append((f"passive_{i}.{name}", param))
|
| 720 |
+
for name, param in model.active_model.named_parameters():
|
| 721 |
+
params.append((f"active.{name}", param))
|
| 722 |
+
return params
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
class ManifoldMixupUnlearning:
|
| 726 |
+
"""
|
| 727 |
+
Baseline 4: Manifold Mixup (Paper 1 - arxiv:2410.10922)
|
| 728 |
+
|
| 729 |
+
Performs manifold mixup in the embedding space between forget set samples
|
| 730 |
+
and random noise/other class samples, combined with gradient ascent.
|
| 731 |
+
This disrupts the learned representations for the forget set.
|
| 732 |
+
|
| 733 |
+
Adapted from: Bryan et al. (2024) "Towards Privacy-Guaranteed Label
|
| 734 |
+
Unlearning in Vertical Federated Learning"
|
| 735 |
+
"""
|
| 736 |
+
|
| 737 |
+
def __init__(self, epochs=10, lr=0.005, mixup_alpha=0.3):
|
| 738 |
+
self.epochs = epochs
|
| 739 |
+
self.lr = lr
|
| 740 |
+
self.mixup_alpha = mixup_alpha
|
| 741 |
+
|
| 742 |
+
def unlearn(self, model, X_train_splits, y_train, forget_indices, retain_indices):
|
| 743 |
+
unlearned = model.clone()
|
| 744 |
+
|
| 745 |
+
forget_splits = [xs[forget_indices] for xs in X_train_splits]
|
| 746 |
+
forget_labels = y_train[forget_indices]
|
| 747 |
+
retain_splits = [xs[retain_indices] for xs in X_train_splits]
|
| 748 |
+
retain_labels = y_train[retain_indices]
|
| 749 |
+
|
| 750 |
+
all_params = []
|
| 751 |
+
for pm in unlearned.passive_models:
|
| 752 |
+
all_params += list(pm.parameters())
|
| 753 |
+
all_params += list(unlearned.active_model.parameters())
|
| 754 |
+
optimizer = optim.Adam(all_params, lr=self.lr)
|
| 755 |
+
|
| 756 |
+
unlearned.set_train()
|
| 757 |
+
for epoch in range(self.epochs):
|
| 758 |
+
# Step 1: Manifold mixup on forget set embeddings
|
| 759 |
+
forget_emb = unlearned.get_embeddings(forget_splits)
|
| 760 |
+
# Mix with random noise (simulates "corrupting" forget representations)
|
| 761 |
+
noise = torch.randn_like(forget_emb)
|
| 762 |
+
lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
|
| 763 |
+
mixed_emb = lam * forget_emb + (1 - lam) * noise
|
| 764 |
+
|
| 765 |
+
# Gradient ascent on mixed embeddings
|
| 766 |
+
logits_mixed = unlearned.active_model(mixed_emb)
|
| 767 |
+
loss_forget = unlearned.criterion(logits_mixed, forget_labels.to(DEVICE))
|
| 768 |
+
|
| 769 |
+
# Step 2: Recovery on retain set
|
| 770 |
+
n_retain_batch = min(BATCH_SIZE, len(retain_labels))
|
| 771 |
+
idx = torch.randperm(len(retain_labels))[:n_retain_batch]
|
| 772 |
+
retain_batch = [xs[idx] for xs in retain_splits]
|
| 773 |
+
retain_batch_y = retain_labels[idx].to(DEVICE)
|
| 774 |
+
|
| 775 |
+
logits_retain, _ = unlearned.forward(retain_batch)
|
| 776 |
+
loss_retain = unlearned.criterion(logits_retain, retain_batch_y)
|
| 777 |
+
|
| 778 |
+
# Combined: ascend on forget, descend on retain
|
| 779 |
+
loss = loss_retain - 0.5 * loss_forget
|
| 780 |
+
|
| 781 |
+
optimizer.zero_grad()
|
| 782 |
+
loss.backward()
|
| 783 |
+
optimizer.step()
|
| 784 |
+
|
| 785 |
+
return unlearned
|
| 786 |
+
|
| 787 |
+
|
| 788 |
+
class FerrariUnlearning:
|
| 789 |
+
"""
|
| 790 |
+
Baseline 5: Ferrari (Paper 2 - arxiv:2405.17462)
|
| 791 |
+
|
| 792 |
+
Minimizes feature sensitivity to target features via Lipschitz-based
|
| 793 |
+
optimization. Uses Monte Carlo perturbation to estimate sensitivity
|
| 794 |
+
and optimizes to reduce it.
|
| 795 |
+
|
| 796 |
+
Adapted from: Ong et al. (2024) "Ferrari: Federated Feature Unlearning
|
| 797 |
+
via Optimizing Feature Sensitivity"
|
| 798 |
+
|
| 799 |
+
Note: Original Ferrari is for HFL. We adapt it to VFL by applying
|
| 800 |
+
sensitivity minimization to the passive party that holds the target features.
|
| 801 |
+
"""
|
| 802 |
+
|
| 803 |
+
def __init__(self, epochs=15, lr=0.005, sigma=0.01, n_samples=5):
|
| 804 |
+
self.epochs = epochs
|
| 805 |
+
self.lr = lr
|
| 806 |
+
self.sigma = sigma
|
| 807 |
+
self.n_samples = n_samples
|
| 808 |
+
|
| 809 |
+
def unlearn(self, model, X_train_splits, y_train, forget_indices, retain_indices):
|
| 810 |
+
unlearned = model.clone()
|
| 811 |
+
|
| 812 |
+
forget_splits = [xs[forget_indices] for xs in X_train_splits]
|
| 813 |
+
forget_labels = y_train[forget_indices]
|
| 814 |
+
retain_splits = [xs[retain_indices] for xs in X_train_splits]
|
| 815 |
+
retain_labels = y_train[retain_indices]
|
| 816 |
+
|
| 817 |
+
all_params = []
|
| 818 |
+
for pm in unlearned.passive_models:
|
| 819 |
+
all_params += list(pm.parameters())
|
| 820 |
+
all_params += list(unlearned.active_model.parameters())
|
| 821 |
+
optimizer = optim.Adam(all_params, lr=self.lr)
|
| 822 |
+
|
| 823 |
+
unlearned.set_train()
|
| 824 |
+
for epoch in range(self.epochs):
|
| 825 |
+
# Sensitivity minimization on forget set
|
| 826 |
+
sensitivity_loss = torch.tensor(0.0, device=DEVICE)
|
| 827 |
+
|
| 828 |
+
logits_orig, _ = unlearned.forward(forget_splits)
|
| 829 |
+
probs_orig = F.softmax(logits_orig, dim=1)
|
| 830 |
+
|
| 831 |
+
for _ in range(self.n_samples):
|
| 832 |
+
for party_idx in range(len(forget_splits)):
|
| 833 |
+
perturbed = [xs.clone() for xs in forget_splits]
|
| 834 |
+
noise = torch.randn_like(perturbed[party_idx]) * self.sigma
|
| 835 |
+
perturbed[party_idx] = perturbed[party_idx] + noise
|
| 836 |
+
|
| 837 |
+
logits_pert, _ = unlearned.forward(perturbed)
|
| 838 |
+
probs_pert = F.softmax(logits_pert, dim=1)
|
| 839 |
+
|
| 840 |
+
# Sensitivity = expected output change per unit perturbation
|
| 841 |
+
diff = (probs_orig - probs_pert).norm(dim=1).mean()
|
| 842 |
+
sensitivity_loss = sensitivity_loss + diff
|
| 843 |
+
|
| 844 |
+
sensitivity_loss = sensitivity_loss / (self.n_samples * len(forget_splits))
|
| 845 |
+
|
| 846 |
+
# Retain utility
|
| 847 |
+
n_retain_batch = min(BATCH_SIZE, len(retain_labels))
|
| 848 |
+
idx = torch.randperm(len(retain_labels))[:n_retain_batch]
|
| 849 |
+
retain_batch = [xs[idx] for xs in retain_splits]
|
| 850 |
+
retain_batch_y = retain_labels[idx].to(DEVICE)
|
| 851 |
+
|
| 852 |
+
logits_retain, _ = unlearned.forward(retain_batch)
|
| 853 |
+
loss_retain = unlearned.criterion(logits_retain, retain_batch_y)
|
| 854 |
+
|
| 855 |
+
# Combined: minimize sensitivity + maintain retain performance
|
| 856 |
+
loss = loss_retain + 2.0 * sensitivity_loss
|
| 857 |
+
|
| 858 |
+
optimizer.zero_grad()
|
| 859 |
+
loss.backward()
|
| 860 |
+
optimizer.step()
|
| 861 |
+
|
| 862 |
+
return unlearned
|
| 863 |
+
|
| 864 |
+
|
| 865 |
+
# ============================================================================
|
| 866 |
+
# UFUSC: Unified Federated Unlearning via Sensitivity-Guided Contrastive Forgetting
|
| 867 |
+
# ============================================================================
|
| 868 |
+
|
| 869 |
+
class UFUSC:
|
| 870 |
+
"""
|
| 871 |
+
UFUSC: Unified Federated Unlearning via Sensitivity-Guided Contrastive Forgetting
|
| 872 |
+
|
| 873 |
+
The FIRST framework to simultaneously handle BOTH label AND feature unlearning
|
| 874 |
+
in Vertical Federated Learning.
|
| 875 |
+
|
| 876 |
+
Three components:
|
| 877 |
+
1. Contrastive Forgetting Loss (CFL) — Pushes forget-set embeddings toward
|
| 878 |
+
random noise while anchoring retain-set embeddings to class centroids.
|
| 879 |
+
Operates in the joint embedding space for "deep forgetting" (not just
|
| 880 |
+
output-level like gradient ascent).
|
| 881 |
+
|
| 882 |
+
2. Lipschitz Feature Sensitivity Minimization — Monte Carlo perturbation-based
|
| 883 |
+
sensitivity estimation, extended to VFL. Minimizes the model's responsiveness
|
| 884 |
+
to features associated with the forget set.
|
| 885 |
+
|
| 886 |
+
3. Dual-Variable Certification — Primal-dual formulation that provides a
|
| 887 |
+
convergence-based forgetting guarantee. The dual variable λ adaptively
|
| 888 |
+
adjusts the forgetting pressure based on how well the current model
|
| 889 |
+
has forgotten.
|
| 890 |
+
|
| 891 |
+
Loss function:
|
| 892 |
+
L = L_retain + α·L_CFL + β·L_sensitivity + γ·L_anchor + Ω·(τ - L_forget_CE)
|
| 893 |
+
|
| 894 |
+
Variants:
|
| 895 |
+
- Label Only: Uses CFL + anchor (no sensitivity)
|
| 896 |
+
- Feature Only: Uses sensitivity + CFL (no anchor)
|
| 897 |
+
- Joint: All three components (full UFUSC)
|
| 898 |
+
"""
|
| 899 |
+
|
| 900 |
+
def __init__(self, mode="joint", alpha=ALPHA, beta=BETA, gamma=GAMMA,
|
| 901 |
+
omega=OMEGA, tau=TAU, epochs=UNLEARN_EPOCHS, lr=0.005,
|
| 902 |
+
sigma=SENSITIVITY_SIGMA, n_mc_samples=SENSITIVITY_SAMPLES):
|
| 903 |
+
"""
|
| 904 |
+
Args:
|
| 905 |
+
mode: "label_only", "feature_only", or "joint"
|
| 906 |
+
alpha: weight for Contrastive Forgetting Loss
|
| 907 |
+
beta: weight for Feature Sensitivity Loss
|
| 908 |
+
gamma: weight for Anchor Loss (retain embedding stability)
|
| 909 |
+
omega: weight for dual-variable certification constraint
|
| 910 |
+
tau: forgetting threshold for certification
|
| 911 |
+
epochs: number of unlearning epochs
|
| 912 |
+
lr: learning rate for unlearning
|
| 913 |
+
sigma: std for MC perturbation (feature sensitivity)
|
| 914 |
+
n_mc_samples: number of MC samples for sensitivity
|
| 915 |
+
"""
|
| 916 |
+
assert mode in ["label_only", "feature_only", "joint"]
|
| 917 |
+
self.mode = mode
|
| 918 |
+
self.alpha = alpha
|
| 919 |
+
self.beta = beta
|
| 920 |
+
self.gamma = gamma
|
| 921 |
+
self.omega = omega
|
| 922 |
+
self.tau = tau
|
| 923 |
+
self.epochs = epochs
|
| 924 |
+
self.lr = lr
|
| 925 |
+
self.sigma = sigma
|
| 926 |
+
self.n_mc_samples = n_mc_samples
|
| 927 |
+
|
| 928 |
+
def compute_class_centroids(self, model, X_splits, y, num_classes):
|
| 929 |
+
"""
|
| 930 |
+
Compute class centroids in the joint embedding space.
|
| 931 |
+
|
| 932 |
+
These serve as "anchor points" — retain-set embeddings should
|
| 933 |
+
stay close to their class centroid during unlearning.
|
| 934 |
+
"""
|
| 935 |
+
model.set_eval()
|
| 936 |
+
with torch.no_grad():
|
| 937 |
+
embeddings = model.get_embeddings(X_splits)
|
| 938 |
+
|
| 939 |
+
centroids = {}
|
| 940 |
+
for c in range(num_classes):
|
| 941 |
+
mask = (y == c)
|
| 942 |
+
if mask.sum() > 0:
|
| 943 |
+
centroids[c] = embeddings[mask].mean(dim=0).detach()
|
| 944 |
+
else:
|
| 945 |
+
centroids[c] = torch.zeros(embeddings.shape[1], device=DEVICE)
|
| 946 |
+
|
| 947 |
+
model.set_train()
|
| 948 |
+
return centroids
|
| 949 |
+
|
| 950 |
+
def contrastive_forgetting_loss(self, model, forget_splits, forget_labels,
|
| 951 |
+
centroids, num_classes):
|
| 952 |
+
"""
|
| 953 |
+
Contrastive Forgetting Loss (CFL).
|
| 954 |
+
|
| 955 |
+
Pushes forget-set embeddings AWAY from their true class centroids
|
| 956 |
+
and TOWARD random noise. This disrupts the learned representations
|
| 957 |
+
at the embedding level, achieving "deep forgetting."
|
| 958 |
+
|
| 959 |
+
L_CFL = -||e_forget - c_true||^2 + ||e_forget - noise||^2
|
| 960 |
+
|
| 961 |
+
The first term pushes embeddings away from the correct centroid.
|
| 962 |
+
The second term pulls embeddings toward meaningless random noise.
|
| 963 |
+
"""
|
| 964 |
+
forget_emb = model.get_embeddings(forget_splits)
|
| 965 |
+
|
| 966 |
+
# Repulsion from true class centroids
|
| 967 |
+
repulsion_loss = torch.tensor(0.0, device=DEVICE)
|
| 968 |
+
for i in range(len(forget_labels)):
|
| 969 |
+
c = forget_labels[i].item()
|
| 970 |
+
if c in centroids:
|
| 971 |
+
dist = (forget_emb[i] - centroids[c]).norm()
|
| 972 |
+
repulsion_loss = repulsion_loss - dist # Maximize distance
|
| 973 |
+
|
| 974 |
+
repulsion_loss = repulsion_loss / max(len(forget_labels), 1)
|
| 975 |
+
|
| 976 |
+
# Attraction toward noise (make embeddings meaningless)
|
| 977 |
+
noise_target = torch.randn_like(forget_emb)
|
| 978 |
+
attraction_loss = (forget_emb - noise_target).norm(dim=1).mean()
|
| 979 |
+
|
| 980 |
+
return repulsion_loss + 0.5 * attraction_loss
|
| 981 |
+
|
| 982 |
+
def feature_sensitivity_loss(self, model, forget_splits):
|
| 983 |
+
"""
|
| 984 |
+
Lipschitz Feature Sensitivity Loss.
|
| 985 |
+
|
| 986 |
+
Measures and minimizes the model's sensitivity to features in the
|
| 987 |
+
forget set via Monte Carlo perturbation. Extended from Ferrari to VFL.
|
| 988 |
+
|
| 989 |
+
For each passive party's features:
|
| 990 |
+
S = E[||f(x) - f(x + δ)|| / ||δ||] where δ ~ N(0, σ²I)
|
| 991 |
+
|
| 992 |
+
We minimize S to make the model "insensitive" to forget-set features.
|
| 993 |
+
"""
|
| 994 |
+
sensitivity = torch.tensor(0.0, device=DEVICE)
|
| 995 |
+
|
| 996 |
+
logits_orig, _ = model.forward(forget_splits)
|
| 997 |
+
probs_orig = F.softmax(logits_orig, dim=1)
|
| 998 |
+
|
| 999 |
+
for _ in range(self.n_mc_samples):
|
| 1000 |
+
for party_idx in range(len(forget_splits)):
|
| 1001 |
+
perturbed = [xs.clone() for xs in forget_splits]
|
| 1002 |
+
noise = torch.randn_like(perturbed[party_idx]) * self.sigma
|
| 1003 |
+
perturbed[party_idx] = perturbed[party_idx] + noise
|
| 1004 |
+
|
| 1005 |
+
logits_pert, _ = model.forward(perturbed)
|
| 1006 |
+
probs_pert = F.softmax(logits_pert, dim=1)
|
| 1007 |
+
|
| 1008 |
+
diff = (probs_orig - probs_pert).norm(dim=1).mean()
|
| 1009 |
+
sensitivity = sensitivity + diff
|
| 1010 |
+
|
| 1011 |
+
sensitivity = sensitivity / (self.n_mc_samples * len(forget_splits))
|
| 1012 |
+
return sensitivity
|
| 1013 |
+
|
| 1014 |
+
def anchor_loss(self, model, retain_splits, retain_labels, centroids):
|
| 1015 |
+
"""
|
| 1016 |
+
Anchor Loss.
|
| 1017 |
+
|
| 1018 |
+
Ensures retain-set embeddings stay close to their class centroids
|
| 1019 |
+
during unlearning. This prevents "catastrophic forgetting" of
|
| 1020 |
+
the retain set while aggressively unlearning the forget set.
|
| 1021 |
+
|
| 1022 |
+
L_anchor = E[||e_retain - c_class||^2]
|
| 1023 |
+
"""
|
| 1024 |
+
retain_emb = model.get_embeddings(retain_splits)
|
| 1025 |
+
|
| 1026 |
+
loss = torch.tensor(0.0, device=DEVICE)
|
| 1027 |
+
for i in range(len(retain_labels)):
|
| 1028 |
+
c = retain_labels[i].item()
|
| 1029 |
+
if c in centroids:
|
| 1030 |
+
loss = loss + (retain_emb[i] - centroids[c]).norm() ** 2
|
| 1031 |
+
|
| 1032 |
+
return loss / max(len(retain_labels), 1)
|
| 1033 |
+
|
| 1034 |
+
def dual_variable_certification(self, model, forget_splits, forget_labels):
|
| 1035 |
+
"""
|
| 1036 |
+
Dual-Variable Certification.
|
| 1037 |
+
|
| 1038 |
+
Primal-dual formulation that provides a convergence-based forgetting
|
| 1039 |
+
guarantee. The constraint is:
|
| 1040 |
+
|
| 1041 |
+
L_forget_CE ≥ τ (cross-entropy on forget set should be HIGH)
|
| 1042 |
+
|
| 1043 |
+
We enforce this via:
|
| 1044 |
+
Ω · max(0, τ - L_forget_CE)
|
| 1045 |
+
|
| 1046 |
+
When the forget CE is below τ, this adds pressure to increase it.
|
| 1047 |
+
When it's above τ, this term vanishes (constraint satisfied).
|
| 1048 |
+
|
| 1049 |
+
Inspired by FedORA (arxiv:2512.23171).
|
| 1050 |
+
"""
|
| 1051 |
+
logits, _ = model.forward(forget_splits)
|
| 1052 |
+
forget_ce = model.criterion(logits, forget_labels.to(DEVICE))
|
| 1053 |
+
|
| 1054 |
+
# Penalty when forget CE is below threshold
|
| 1055 |
+
violation = F.relu(self.tau - forget_ce)
|
| 1056 |
+
return self.omega * violation
|
| 1057 |
+
|
| 1058 |
+
def unlearn(self, model, X_train_splits, y_train, forget_indices, retain_indices,
|
| 1059 |
+
num_classes=10):
|
| 1060 |
+
"""
|
| 1061 |
+
Execute UFUSC unlearning.
|
| 1062 |
+
|
| 1063 |
+
Args:
|
| 1064 |
+
model: trained VFLFramework
|
| 1065 |
+
X_train_splits: list of K feature tensors
|
| 1066 |
+
y_train: training labels
|
| 1067 |
+
forget_indices: indices of forget set
|
| 1068 |
+
retain_indices: indices of retain set
|
| 1069 |
+
num_classes: number of classes
|
| 1070 |
+
|
| 1071 |
+
Returns:
|
| 1072 |
+
unlearned VFLFramework
|
| 1073 |
+
"""
|
| 1074 |
+
unlearned = model.clone()
|
| 1075 |
+
|
| 1076 |
+
forget_splits = [xs[forget_indices] for xs in X_train_splits]
|
| 1077 |
+
forget_labels = y_train[forget_indices]
|
| 1078 |
+
retain_splits = [xs[retain_indices] for xs in X_train_splits]
|
| 1079 |
+
retain_labels = y_train[retain_indices]
|
| 1080 |
+
|
| 1081 |
+
# Compute class centroids before unlearning
|
| 1082 |
+
centroids = self.compute_class_centroids(
|
| 1083 |
+
unlearned, [xs[retain_indices] for xs in X_train_splits],
|
| 1084 |
+
retain_labels, num_classes
|
| 1085 |
+
)
|
| 1086 |
+
|
| 1087 |
+
all_params = []
|
| 1088 |
+
for pm in unlearned.passive_models:
|
| 1089 |
+
all_params += list(pm.parameters())
|
| 1090 |
+
all_params += list(unlearned.active_model.parameters())
|
| 1091 |
+
optimizer = optim.Adam(all_params, lr=self.lr)
|
| 1092 |
+
|
| 1093 |
+
unlearned.set_train()
|
| 1094 |
+
for epoch in range(self.epochs):
|
| 1095 |
+
total_loss = torch.tensor(0.0, device=DEVICE)
|
| 1096 |
+
|
| 1097 |
+
# 1. Retain set CE loss (always active)
|
| 1098 |
+
n_retain_batch = min(BATCH_SIZE, len(retain_labels))
|
| 1099 |
+
idx = torch.randperm(len(retain_labels))[:n_retain_batch]
|
| 1100 |
+
retain_batch = [xs[idx] for xs in retain_splits]
|
| 1101 |
+
retain_batch_y = retain_labels[idx].to(DEVICE)
|
| 1102 |
+
|
| 1103 |
+
logits_retain, _ = unlearned.forward(retain_batch)
|
| 1104 |
+
loss_retain = unlearned.criterion(logits_retain, retain_batch_y)
|
| 1105 |
+
total_loss = total_loss + loss_retain
|
| 1106 |
+
|
| 1107 |
+
# 2. Contrastive Forgetting Loss (CFL)
|
| 1108 |
+
if self.mode in ["label_only", "joint"]:
|
| 1109 |
+
cfl = self.contrastive_forgetting_loss(
|
| 1110 |
+
unlearned, forget_splits, forget_labels, centroids, num_classes
|
| 1111 |
+
)
|
| 1112 |
+
total_loss = total_loss + self.alpha * cfl
|
| 1113 |
+
|
| 1114 |
+
if self.mode in ["feature_only", "joint"]:
|
| 1115 |
+
cfl_feat = self.contrastive_forgetting_loss(
|
| 1116 |
+
unlearned, forget_splits, forget_labels, centroids, num_classes
|
| 1117 |
+
)
|
| 1118 |
+
total_loss = total_loss + self.alpha * 0.5 * cfl_feat
|
| 1119 |
+
|
| 1120 |
+
# 3. Feature Sensitivity Loss
|
| 1121 |
+
if self.mode in ["feature_only", "joint"]:
|
| 1122 |
+
sens = self.feature_sensitivity_loss(unlearned, forget_splits)
|
| 1123 |
+
total_loss = total_loss + self.beta * sens
|
| 1124 |
+
|
| 1125 |
+
# 4. Anchor Loss
|
| 1126 |
+
if self.mode in ["label_only", "joint"]:
|
| 1127 |
+
anc = self.anchor_loss(
|
| 1128 |
+
unlearned, retain_batch, retain_batch_y, centroids
|
| 1129 |
+
)
|
| 1130 |
+
total_loss = total_loss + self.gamma * anc
|
| 1131 |
+
|
| 1132 |
+
# 5. Dual-Variable Certification
|
| 1133 |
+
cert = self.dual_variable_certification(
|
| 1134 |
+
unlearned, forget_splits, forget_labels
|
| 1135 |
+
)
|
| 1136 |
+
total_loss = total_loss + cert
|
| 1137 |
+
|
| 1138 |
+
optimizer.zero_grad()
|
| 1139 |
+
total_loss.backward()
|
| 1140 |
+
# Gradient clipping for stability
|
| 1141 |
+
torch.nn.utils.clip_grad_norm_(all_params, max_norm=5.0)
|
| 1142 |
+
optimizer.step()
|
| 1143 |
+
|
| 1144 |
+
return unlearned
|
| 1145 |
+
|
| 1146 |
+
|
| 1147 |
+
# ============================================================================
|
| 1148 |
+
# Experiment Runner
|
| 1149 |
+
# ============================================================================
|
| 1150 |
+
|
| 1151 |
+
def run_single_experiment(dataset_name, num_parties=NUM_PASSIVE_PARTIES, verbose=True):
|
| 1152 |
+
"""
|
| 1153 |
+
Run complete experiment for one dataset.
|
| 1154 |
+
|
| 1155 |
+
Steps:
|
| 1156 |
+
1. Load dataset
|
| 1157 |
+
2. Split features across K passive parties (VFL)
|
| 1158 |
+
3. Train VFL model
|
| 1159 |
+
4. Create forget/retain split
|
| 1160 |
+
5. Evaluate original model
|
| 1161 |
+
6. Run all 5 baselines
|
| 1162 |
+
7. Run 3 UFUSC variants
|
| 1163 |
+
8. Return all results
|
| 1164 |
+
|
| 1165 |
+
Args:
|
| 1166 |
+
dataset_name: "MNIST", "Fashion-MNIST", or "CIFAR-10"
|
| 1167 |
+
num_parties: number of passive parties
|
| 1168 |
+
verbose: print progress
|
| 1169 |
+
|
| 1170 |
+
Returns:
|
| 1171 |
+
list of result dicts
|
| 1172 |
+
"""
|
| 1173 |
+
set_seed()
|
| 1174 |
+
print(f"\n{'='*70}")
|
| 1175 |
+
print(f" EXPERIMENT: {dataset_name} (K={num_parties} parties)")
|
| 1176 |
+
print(f"{'='*70}")
|
| 1177 |
+
|
| 1178 |
+
# 1. Load dataset
|
| 1179 |
+
print("\n[1/8] Loading dataset...")
|
| 1180 |
+
X_train, y_train, X_test, y_test, num_classes, feature_dim = load_dataset(dataset_name)
|
| 1181 |
+
|
| 1182 |
+
# 2. Split features for VFL
|
| 1183 |
+
print("[2/8] Splitting features for VFL...")
|
| 1184 |
+
X_train_splits = list(split_features_vfl(X_train, num_parties))
|
| 1185 |
+
X_test_splits = list(split_features_vfl(X_test, num_parties))
|
| 1186 |
+
feature_dims = [xs.shape[1] for xs in X_train_splits]
|
| 1187 |
+
print(f" Party feature dims: {feature_dims}")
|
| 1188 |
+
|
| 1189 |
+
# 3. Train VFL model
|
| 1190 |
+
print("[3/8] Training VFL model...")
|
| 1191 |
+
model = VFLFramework(feature_dims, num_classes, num_parties=num_parties)
|
| 1192 |
+
model.train_model(X_train_splits, y_train, X_test_splits, y_test, epochs=TRAIN_EPOCHS)
|
| 1193 |
+
|
| 1194 |
+
# 4. Create forget/retain split
|
| 1195 |
+
print("[4/8] Creating forget/retain split...")
|
| 1196 |
+
forget_class = 0
|
| 1197 |
+
forget_indices, retain_indices = create_forget_retain_split(
|
| 1198 |
+
y_train, forget_class=forget_class, forget_ratio=FORGET_RATIO
|
| 1199 |
+
)
|
| 1200 |
+
print(f" Forget set: {len(forget_indices)} samples (class {forget_class})")
|
| 1201 |
+
print(f" Retain set: {len(retain_indices)} samples")
|
| 1202 |
+
|
| 1203 |
+
# 5. Evaluate original model
|
| 1204 |
+
print("[5/8] Evaluating original model...")
|
| 1205 |
+
original_metrics = full_evaluation(
|
| 1206 |
+
model, X_train_splits, y_train, X_test_splits, y_test,
|
| 1207 |
+
forget_indices, retain_indices, forget_class
|
| 1208 |
+
)
|
| 1209 |
+
original_metrics["method"] = "Original (No Unlearn)"
|
| 1210 |
+
original_metrics["time_seconds"] = 0
|
| 1211 |
+
print(f" Original: {original_metrics}")
|
| 1212 |
+
|
| 1213 |
+
results = [original_metrics]
|
| 1214 |
+
|
| 1215 |
+
# 6. Run baselines
|
| 1216 |
+
baselines = [
|
| 1217 |
+
("Gradient Ascent", GradientAscentUnlearning(epochs=5, lr=0.01)),
|
| 1218 |
+
("Fine-tuning", FineTuneUnlearning(epochs=10, lr=0.001)),
|
| 1219 |
+
("Fisher Forgetting", FisherForgetting(noise_scale=0.01)),
|
| 1220 |
+
("Manifold Mixup (P1)", ManifoldMixupUnlearning(epochs=10, lr=0.005)),
|
| 1221 |
+
("Ferrari (P2)", FerrariUnlearning(epochs=15, lr=0.005)),
|
| 1222 |
+
]
|
| 1223 |
+
|
| 1224 |
+
print("[6/8] Running baselines...")
|
| 1225 |
+
for name, method in baselines:
|
| 1226 |
+
print(f" Running {name}...")
|
| 1227 |
+
t0 = time.time()
|
| 1228 |
+
unlearned = method.unlearn(model, X_train_splits, y_train, forget_indices, retain_indices)
|
| 1229 |
+
elapsed = time.time() - t0
|
| 1230 |
+
|
| 1231 |
+
metrics = full_evaluation(
|
| 1232 |
+
unlearned, X_train_splits, y_train, X_test_splits, y_test,
|
| 1233 |
+
forget_indices, retain_indices, forget_class
|
| 1234 |
+
)
|
| 1235 |
+
metrics["method"] = name
|
| 1236 |
+
metrics["time_seconds"] = round(elapsed, 2)
|
| 1237 |
+
results.append(metrics)
|
| 1238 |
+
print(f" {name}: Forget={metrics['forget_acc']:.1f}%, "
|
| 1239 |
+
f"Retain={metrics['retain_acc']:.1f}%, MIA={metrics['mia_asr']:.1f}%")
|
| 1240 |
+
|
| 1241 |
+
# 7. Run UFUSC variants
|
| 1242 |
+
print("[7/8] Running UFUSC variants...")
|
| 1243 |
+
ufusc_variants = [
|
| 1244 |
+
("UFUSC (Label Only)", UFUSC(mode="label_only", epochs=UNLEARN_EPOCHS)),
|
| 1245 |
+
("UFUSC (Feature Only)", UFUSC(mode="feature_only", epochs=UNLEARN_EPOCHS)),
|
| 1246 |
+
("UFUSC (Joint)", UFUSC(mode="joint", epochs=UNLEARN_EPOCHS)),
|
| 1247 |
+
]
|
| 1248 |
+
|
| 1249 |
+
for name, method in ufusc_variants:
|
| 1250 |
+
print(f" Running {name}...")
|
| 1251 |
+
t0 = time.time()
|
| 1252 |
+
unlearned = method.unlearn(
|
| 1253 |
+
model, X_train_splits, y_train, forget_indices, retain_indices,
|
| 1254 |
+
num_classes=num_classes
|
| 1255 |
+
)
|
| 1256 |
+
elapsed = time.time() - t0
|
| 1257 |
+
|
| 1258 |
+
metrics = full_evaluation(
|
| 1259 |
+
unlearned, X_train_splits, y_train, X_test_splits, y_test,
|
| 1260 |
+
forget_indices, retain_indices, forget_class
|
| 1261 |
+
)
|
| 1262 |
+
metrics["method"] = name
|
| 1263 |
+
metrics["time_seconds"] = round(elapsed, 2)
|
| 1264 |
+
results.append(metrics)
|
| 1265 |
+
print(f" {name}: Forget={metrics['forget_acc']:.1f}%, "
|
| 1266 |
+
f"Retain={metrics['retain_acc']:.1f}%, MIA={metrics['mia_asr']:.1f}%")
|
| 1267 |
+
|
| 1268 |
+
# 8. Summary
|
| 1269 |
+
print(f"\n[8/8] {dataset_name} Summary:")
|
| 1270 |
+
print(f" {'Method':<25} {'Test':>8} {'Forget':>8} {'Retain':>8} {'MIA':>8} {'Sens':>8}")
|
| 1271 |
+
print(f" {'-'*73}")
|
| 1272 |
+
for r in results:
|
| 1273 |
+
print(f" {r['method']:<25} {r['test_acc']:>7.2f}% {r['forget_acc']:>7.2f}% "
|
| 1274 |
+
f"{r['retain_acc']:>7.2f}% {r['mia_asr']:>7.1f}% {r['feature_sensitivity']:>7.3f}")
|
| 1275 |
+
|
| 1276 |
+
return results
|
| 1277 |
+
|
| 1278 |
+
|
| 1279 |
+
# ============================================================================
|
| 1280 |
+
# Ablation Study
|
| 1281 |
+
# ============================================================================
|
| 1282 |
+
|
| 1283 |
+
def run_ablation_study(dataset_name="MNIST"):
|
| 1284 |
+
"""
|
| 1285 |
+
Ablation study on UFUSC hyperparameters: α, β, γ, and unlearning epochs.
|
| 1286 |
+
|
| 1287 |
+
Tests the impact of each component by varying one hyperparameter
|
| 1288 |
+
while keeping others at their default values.
|
| 1289 |
+
|
| 1290 |
+
Returns:
|
| 1291 |
+
list of ablation result dicts
|
| 1292 |
+
"""
|
| 1293 |
+
set_seed()
|
| 1294 |
+
print(f"\n{'='*70}")
|
| 1295 |
+
print(f" ABLATION STUDY: {dataset_name}")
|
| 1296 |
+
print(f"{'='*70}")
|
| 1297 |
+
|
| 1298 |
+
# Load and prepare
|
| 1299 |
+
X_train, y_train, X_test, y_test, num_classes, feature_dim = load_dataset(dataset_name)
|
| 1300 |
+
X_train_splits = list(split_features_vfl(X_train))
|
| 1301 |
+
X_test_splits = list(split_features_vfl(X_test))
|
| 1302 |
+
feature_dims = [xs.shape[1] for xs in X_train_splits]
|
| 1303 |
+
|
| 1304 |
+
model = VFLFramework(feature_dims, num_classes)
|
| 1305 |
+
model.train_model(X_train_splits, y_train, X_test_splits, y_test, epochs=TRAIN_EPOCHS, verbose=False)
|
| 1306 |
+
|
| 1307 |
+
forget_indices, retain_indices = create_forget_retain_split(y_train)
|
| 1308 |
+
|
| 1309 |
+
ablation_results = []
|
| 1310 |
+
|
| 1311 |
+
# Ablation 1: Vary α (CFL weight)
|
| 1312 |
+
print("\n Ablation: α (CFL weight)")
|
| 1313 |
+
for alpha_val in [0.0, 0.5, 1.0, 2.0, 5.0]:
|
| 1314 |
+
method = UFUSC(mode="joint", alpha=alpha_val, beta=BETA, gamma=GAMMA, epochs=UNLEARN_EPOCHS)
|
| 1315 |
+
unlearned = method.unlearn(model, X_train_splits, y_train, forget_indices, retain_indices, num_classes)
|
| 1316 |
+
metrics = full_evaluation(unlearned, X_train_splits, y_train, X_test_splits, y_test,
|
| 1317 |
+
forget_indices, retain_indices)
|
| 1318 |
+
metrics["ablation_param"] = "alpha"
|
| 1319 |
+
metrics["ablation_value"] = alpha_val
|
| 1320 |
+
ablation_results.append(metrics)
|
| 1321 |
+
print(f" α={alpha_val}: Forget={metrics['forget_acc']:.1f}%, Retain={metrics['retain_acc']:.1f}%")
|
| 1322 |
+
|
| 1323 |
+
# Ablation 2: Vary β (Sensitivity weight)
|
| 1324 |
+
print("\n Ablation: β (Sensitivity weight)")
|
| 1325 |
+
for beta_val in [0.0, 0.25, 0.5, 1.0, 2.0]:
|
| 1326 |
+
method = UFUSC(mode="joint", alpha=ALPHA, beta=beta_val, gamma=GAMMA, epochs=UNLEARN_EPOCHS)
|
| 1327 |
+
unlearned = method.unlearn(model, X_train_splits, y_train, forget_indices, retain_indices, num_classes)
|
| 1328 |
+
metrics = full_evaluation(unlearned, X_train_splits, y_train, X_test_splits, y_test,
|
| 1329 |
+
forget_indices, retain_indices)
|
| 1330 |
+
metrics["ablation_param"] = "beta"
|
| 1331 |
+
metrics["ablation_value"] = beta_val
|
| 1332 |
+
ablation_results.append(metrics)
|
| 1333 |
+
print(f" β={beta_val}: Forget={metrics['forget_acc']:.1f}%, Retain={metrics['retain_acc']:.1f}%")
|
| 1334 |
+
|
| 1335 |
+
# Ablation 3: Vary γ (Anchor weight)
|
| 1336 |
+
print("\n Ablation: γ (Anchor weight)")
|
| 1337 |
+
for gamma_val in [0.0, 0.1, 0.3, 0.5, 1.0]:
|
| 1338 |
+
method = UFUSC(mode="joint", alpha=ALPHA, beta=BETA, gamma=gamma_val, epochs=UNLEARN_EPOCHS)
|
| 1339 |
+
unlearned = method.unlearn(model, X_train_splits, y_train, forget_indices, retain_indices, num_classes)
|
| 1340 |
+
metrics = full_evaluation(unlearned, X_train_splits, y_train, X_test_splits, y_test,
|
| 1341 |
+
forget_indices, retain_indices)
|
| 1342 |
+
metrics["ablation_param"] = "gamma"
|
| 1343 |
+
metrics["ablation_value"] = gamma_val
|
| 1344 |
+
ablation_results.append(metrics)
|
| 1345 |
+
print(f" γ={gamma_val}: Forget={metrics['forget_acc']:.1f}%, Retain={metrics['retain_acc']:.1f}%")
|
| 1346 |
+
|
| 1347 |
+
# Ablation 4: Vary unlearning epochs
|
| 1348 |
+
print("\n Ablation: Unlearning epochs")
|
| 1349 |
+
for ep in [1, 5, 10, 15, 20]:
|
| 1350 |
+
method = UFUSC(mode="joint", alpha=ALPHA, beta=BETA, gamma=GAMMA, epochs=ep)
|
| 1351 |
+
unlearned = method.unlearn(model, X_train_splits, y_train, forget_indices, retain_indices, num_classes)
|
| 1352 |
+
metrics = full_evaluation(unlearned, X_train_splits, y_train, X_test_splits, y_test,
|
| 1353 |
+
forget_indices, retain_indices)
|
| 1354 |
+
metrics["ablation_param"] = "epochs"
|
| 1355 |
+
metrics["ablation_value"] = ep
|
| 1356 |
+
ablation_results.append(metrics)
|
| 1357 |
+
print(f" epochs={ep}: Forget={metrics['forget_acc']:.1f}%, Retain={metrics['retain_acc']:.1f}%")
|
| 1358 |
+
|
| 1359 |
+
return ablation_results
|
| 1360 |
+
|
| 1361 |
+
|
| 1362 |
+
# ============================================================================
|
| 1363 |
+
# Scalability Analysis
|
| 1364 |
+
# ============================================================================
|
| 1365 |
+
|
| 1366 |
+
def run_scalability_analysis(dataset_name="MNIST"):
|
| 1367 |
+
"""
|
| 1368 |
+
Scalability analysis: test UFUSC with varying number of passive parties K.
|
| 1369 |
+
|
| 1370 |
+
Tests K = 2, 3, 4, 6 to see how the method scales in VFL settings
|
| 1371 |
+
with different numbers of data holders.
|
| 1372 |
+
|
| 1373 |
+
Returns:
|
| 1374 |
+
list of scalability result dicts
|
| 1375 |
+
"""
|
| 1376 |
+
set_seed()
|
| 1377 |
+
print(f"\n{'='*70}")
|
| 1378 |
+
print(f" SCALABILITY ANALYSIS: {dataset_name}")
|
| 1379 |
+
print(f"{'='*70}")
|
| 1380 |
+
|
| 1381 |
+
X_train, y_train, X_test, y_test, num_classes, feature_dim = load_dataset(dataset_name)
|
| 1382 |
+
|
| 1383 |
+
scalability_results = []
|
| 1384 |
+
|
| 1385 |
+
for K in [2, 3, 4, 6]:
|
| 1386 |
+
print(f"\n K={K} parties...")
|
| 1387 |
+
X_train_splits = list(split_features_vfl(X_train, K))
|
| 1388 |
+
X_test_splits = list(split_features_vfl(X_test, K))
|
| 1389 |
+
feature_dims = [xs.shape[1] for xs in X_train_splits]
|
| 1390 |
+
|
| 1391 |
+
model = VFLFramework(feature_dims, num_classes, num_parties=K)
|
| 1392 |
+
model.train_model(X_train_splits, y_train, X_test_splits, y_test,
|
| 1393 |
+
epochs=TRAIN_EPOCHS, verbose=False)
|
| 1394 |
+
|
| 1395 |
+
forget_indices, retain_indices = create_forget_retain_split(y_train)
|
| 1396 |
+
|
| 1397 |
+
# Evaluate original
|
| 1398 |
+
orig_metrics = full_evaluation(model, X_train_splits, y_train, X_test_splits, y_test,
|
| 1399 |
+
forget_indices, retain_indices)
|
| 1400 |
+
|
| 1401 |
+
# Run UFUSC-Joint
|
| 1402 |
+
ufusc = UFUSC(mode="joint", epochs=UNLEARN_EPOCHS)
|
| 1403 |
+
t0 = time.time()
|
| 1404 |
+
unlearned = ufusc.unlearn(model, X_train_splits, y_train, forget_indices, retain_indices, num_classes)
|
| 1405 |
+
elapsed = time.time() - t0
|
| 1406 |
+
|
| 1407 |
+
ufusc_metrics = full_evaluation(unlearned, X_train_splits, y_train, X_test_splits, y_test,
|
| 1408 |
+
forget_indices, retain_indices)
|
| 1409 |
+
|
| 1410 |
+
result = {
|
| 1411 |
+
"K": K,
|
| 1412 |
+
"original_test_acc": orig_metrics["test_acc"],
|
| 1413 |
+
"original_forget_acc": orig_metrics["forget_acc"],
|
| 1414 |
+
"ufusc_test_acc": ufusc_metrics["test_acc"],
|
| 1415 |
+
"ufusc_forget_acc": ufusc_metrics["forget_acc"],
|
| 1416 |
+
"ufusc_retain_acc": ufusc_metrics["retain_acc"],
|
| 1417 |
+
"ufusc_mia_asr": ufusc_metrics["mia_asr"],
|
| 1418 |
+
"time_seconds": round(elapsed, 2)
|
| 1419 |
+
}
|
| 1420 |
+
scalability_results.append(result)
|
| 1421 |
+
print(f" K={K}: Original Test={orig_metrics['test_acc']:.1f}%, "
|
| 1422 |
+
f"UFUSC Forget={ufusc_metrics['forget_acc']:.1f}%, "
|
| 1423 |
+
f"Retain={ufusc_metrics['retain_acc']:.1f}%, Time={elapsed:.1f}s")
|
| 1424 |
+
|
| 1425 |
+
return scalability_results
|
| 1426 |
+
|
| 1427 |
+
|
| 1428 |
+
# ============================================================================
|
| 1429 |
+
# Visualization
|
| 1430 |
+
# ============================================================================
|
| 1431 |
+
|
| 1432 |
+
def create_visualizations(all_results, ablation_results=None, scalability_results=None):
|
| 1433 |
+
"""
|
| 1434 |
+
Create all publication-quality figures.
|
| 1435 |
+
|
| 1436 |
+
Generates:
|
| 1437 |
+
- Comparison bar charts (1 per dataset)
|
| 1438 |
+
- Radar plots (1 per dataset)
|
| 1439 |
+
- Ablation study plot
|
| 1440 |
+
- Scalability analysis plot
|
| 1441 |
+
- Privacy-utility tradeoff plots (1 per dataset)
|
| 1442 |
+
"""
|
| 1443 |
+
try:
|
| 1444 |
+
import matplotlib
|
| 1445 |
+
matplotlib.use('Agg')
|
| 1446 |
+
import matplotlib.pyplot as plt
|
| 1447 |
+
import seaborn as sns
|
| 1448 |
+
sns.set_theme(style="whitegrid")
|
| 1449 |
+
except ImportError:
|
| 1450 |
+
print("WARNING: matplotlib/seaborn not available. Skipping visualization.")
|
| 1451 |
+
return
|
| 1452 |
+
|
| 1453 |
+
colors = {
|
| 1454 |
+
"Original (No Unlearn)": "#95a5a6",
|
| 1455 |
+
"Gradient Ascent": "#e74c3c",
|
| 1456 |
+
"Fine-tuning": "#e67e22",
|
| 1457 |
+
"Fisher Forgetting": "#f39c12",
|
| 1458 |
+
"Manifold Mixup (P1)": "#27ae60",
|
| 1459 |
+
"Ferrari (P2)": "#2980b9",
|
| 1460 |
+
"UFUSC (Label Only)": "#8e44ad",
|
| 1461 |
+
"UFUSC (Feature Only)": "#1abc9c",
|
| 1462 |
+
"UFUSC (Joint)": "#c0392b",
|
| 1463 |
+
}
|
| 1464 |
+
|
| 1465 |
+
# ---- Comparison Bar Charts (one per dataset) ----
|
| 1466 |
+
for dataset_name, results in all_results.items():
|
| 1467 |
+
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
|
| 1468 |
+
fig.suptitle(f"{dataset_name} — Unlearning Method Comparison", fontsize=16, fontweight='bold')
|
| 1469 |
+
|
| 1470 |
+
methods = [r["method"] for r in results]
|
| 1471 |
+
method_colors = [colors.get(m, "#333333") for m in methods]
|
| 1472 |
+
|
| 1473 |
+
# Forget Accuracy (lower is better)
|
| 1474 |
+
vals = [r["forget_acc"] for r in results]
|
| 1475 |
+
axes[0].barh(methods, vals, color=method_colors)
|
| 1476 |
+
axes[0].set_xlabel("Forget Accuracy (%) ↓")
|
| 1477 |
+
axes[0].set_title("Forgetting Quality")
|
| 1478 |
+
axes[0].invert_yaxis()
|
| 1479 |
+
|
| 1480 |
+
# Retain Accuracy (higher is better)
|
| 1481 |
+
vals = [r["retain_acc"] for r in results]
|
| 1482 |
+
axes[1].barh(methods, vals, color=method_colors)
|
| 1483 |
+
axes[1].set_xlabel("Retain Accuracy (%) ↑")
|
| 1484 |
+
axes[1].set_title("Utility Preservation")
|
| 1485 |
+
axes[1].invert_yaxis()
|
| 1486 |
+
|
| 1487 |
+
# MIA ASR (lower is better)
|
| 1488 |
+
vals = [r["mia_asr"] for r in results]
|
| 1489 |
+
axes[2].barh(methods, vals, color=method_colors)
|
| 1490 |
+
axes[2].set_xlabel("MIA ASR (%) ↓")
|
| 1491 |
+
axes[2].set_title("Privacy Protection")
|
| 1492 |
+
axes[2].axvline(x=50, color='red', linestyle='--', alpha=0.5, label='Random (50%)')
|
| 1493 |
+
axes[2].invert_yaxis()
|
| 1494 |
+
axes[2].legend()
|
| 1495 |
+
|
| 1496 |
+
plt.tight_layout()
|
| 1497 |
+
plt.savefig(f"figures/{dataset_name.replace('-', '_')}_comparison.png", dpi=150, bbox_inches='tight')
|
| 1498 |
+
plt.close()
|
| 1499 |
+
print(f" Saved: figures/{dataset_name.replace('-', '_')}_comparison.png")
|
| 1500 |
+
|
| 1501 |
+
# ---- Radar Plots (one per dataset) ----
|
| 1502 |
+
for dataset_name, results in all_results.items():
|
| 1503 |
+
# Select key methods for radar
|
| 1504 |
+
key_methods = ["Gradient Ascent", "Manifold Mixup (P1)", "Ferrari (P2)", "UFUSC (Joint)"]
|
| 1505 |
+
key_results = [r for r in results if r["method"] in key_methods]
|
| 1506 |
+
|
| 1507 |
+
if len(key_results) < 2:
|
| 1508 |
+
continue
|
| 1509 |
+
|
| 1510 |
+
categories = ["Retain Acc", "1 - Forget Acc", "1 - MIA ASR", "Low Sensitivity"]
|
| 1511 |
+
N = len(categories)
|
| 1512 |
+
angles = [n / float(N) * 2 * np.pi for n in range(N)]
|
| 1513 |
+
angles += angles[:1] # Close the polygon
|
| 1514 |
+
|
| 1515 |
+
fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
|
| 1516 |
+
ax.set_title(f"{dataset_name} — Method Radar Comparison", fontsize=14, fontweight='bold', pad=20)
|
| 1517 |
+
|
| 1518 |
+
for r in key_results:
|
| 1519 |
+
values = [
|
| 1520 |
+
r["retain_acc"] / 100,
|
| 1521 |
+
(100 - r["forget_acc"]) / 100,
|
| 1522 |
+
(100 - r["mia_asr"]) / 100,
|
| 1523 |
+
max(0, 1 - r["feature_sensitivity"]),
|
| 1524 |
+
]
|
| 1525 |
+
values += values[:1]
|
| 1526 |
+
color = colors.get(r["method"], "#333333")
|
| 1527 |
+
ax.plot(angles, values, 'o-', linewidth=2, label=r["method"], color=color)
|
| 1528 |
+
ax.fill(angles, values, alpha=0.1, color=color)
|
| 1529 |
+
|
| 1530 |
+
ax.set_xticks(angles[:-1])
|
| 1531 |
+
ax.set_xticklabels(categories)
|
| 1532 |
+
ax.set_ylim(0, 1)
|
| 1533 |
+
ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
|
| 1534 |
+
|
| 1535 |
+
plt.tight_layout()
|
| 1536 |
+
plt.savefig(f"figures/{dataset_name.replace('-', '_')}_radar.png", dpi=150, bbox_inches='tight')
|
| 1537 |
+
plt.close()
|
| 1538 |
+
print(f" Saved: figures/{dataset_name.replace('-', '_')}_radar.png")
|
| 1539 |
+
|
| 1540 |
+
# ---- Ablation Study Plot ----
|
| 1541 |
+
if ablation_results:
|
| 1542 |
+
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
|
| 1543 |
+
fig.suptitle("UFUSC Ablation Study (MNIST)", fontsize=16, fontweight='bold')
|
| 1544 |
+
|
| 1545 |
+
params = {"alpha": "α (CFL weight)", "beta": "β (Sensitivity weight)",
|
| 1546 |
+
"gamma": "γ (Anchor weight)", "epochs": "Unlearning Epochs"}
|
| 1547 |
+
|
| 1548 |
+
for idx, (param_key, param_label) in enumerate(params.items()):
|
| 1549 |
+
ax = axes[idx // 2][idx % 2]
|
| 1550 |
+
param_results = [r for r in ablation_results if r["ablation_param"] == param_key]
|
| 1551 |
+
|
| 1552 |
+
if not param_results:
|
| 1553 |
+
continue
|
| 1554 |
+
|
| 1555 |
+
x_vals = [r["ablation_value"] for r in param_results]
|
| 1556 |
+
forget_vals = [r["forget_acc"] for r in param_results]
|
| 1557 |
+
retain_vals = [r["retain_acc"] for r in param_results]
|
| 1558 |
+
|
| 1559 |
+
ax.plot(x_vals, forget_vals, 's-', color='#e74c3c', label='Forget Acc ↓', linewidth=2, markersize=8)
|
| 1560 |
+
ax.plot(x_vals, retain_vals, 'o-', color='#2980b9', label='Retain Acc ↑', linewidth=2, markersize=8)
|
| 1561 |
+
ax.set_xlabel(param_label)
|
| 1562 |
+
ax.set_ylabel("Accuracy (%)")
|
| 1563 |
+
ax.set_title(f"Effect of {param_label}")
|
| 1564 |
+
ax.legend()
|
| 1565 |
+
ax.grid(True, alpha=0.3)
|
| 1566 |
+
|
| 1567 |
+
plt.tight_layout()
|
| 1568 |
+
plt.savefig("figures/ablation_study.png", dpi=150, bbox_inches='tight')
|
| 1569 |
+
plt.close()
|
| 1570 |
+
print(" Saved: figures/ablation_study.png")
|
| 1571 |
+
|
| 1572 |
+
# ---- Scalability Analysis Plot ----
|
| 1573 |
+
if scalability_results:
|
| 1574 |
+
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
|
| 1575 |
+
fig.suptitle("UFUSC Scalability Analysis (Varying K)", fontsize=14, fontweight='bold')
|
| 1576 |
+
|
| 1577 |
+
ks = [r["K"] for r in scalability_results]
|
| 1578 |
+
|
| 1579 |
+
# Accuracy metrics
|
| 1580 |
+
axes[0].plot(ks, [r["ufusc_forget_acc"] for r in scalability_results],
|
| 1581 |
+
's-', color='#e74c3c', label='Forget Acc ↓', linewidth=2, markersize=8)
|
| 1582 |
+
axes[0].plot(ks, [r["ufusc_retain_acc"] for r in scalability_results],
|
| 1583 |
+
'o-', color='#2980b9', label='Retain Acc ↑', linewidth=2, markersize=8)
|
| 1584 |
+
axes[0].plot(ks, [r["ufusc_mia_asr"] for r in scalability_results],
|
| 1585 |
+
'^-', color='#27ae60', label='MIA ASR ↓', linewidth=2, markersize=8)
|
| 1586 |
+
axes[0].set_xlabel("Number of Passive Parties (K)")
|
| 1587 |
+
axes[0].set_ylabel("Metric (%)")
|
| 1588 |
+
axes[0].set_title("Metrics vs K")
|
| 1589 |
+
axes[0].legend()
|
| 1590 |
+
axes[0].set_xticks(ks)
|
| 1591 |
+
|
| 1592 |
+
# Time
|
| 1593 |
+
axes[1].bar(ks, [r["time_seconds"] for r in scalability_results],
|
| 1594 |
+
color='#8e44ad', alpha=0.7)
|
| 1595 |
+
axes[1].set_xlabel("Number of Passive Parties (K)")
|
| 1596 |
+
axes[1].set_ylabel("Time (seconds)")
|
| 1597 |
+
axes[1].set_title("Unlearning Time vs K")
|
| 1598 |
+
axes[1].set_xticks(ks)
|
| 1599 |
+
|
| 1600 |
+
plt.tight_layout()
|
| 1601 |
+
plt.savefig("figures/scalability_analysis.png", dpi=150, bbox_inches='tight')
|
| 1602 |
+
plt.close()
|
| 1603 |
+
print(" Saved: figures/scalability_analysis.png")
|
| 1604 |
+
|
| 1605 |
+
# ---- Privacy-Utility Tradeoff Plots ----
|
| 1606 |
+
for dataset_name, results in all_results.items():
|
| 1607 |
+
fig, ax = plt.subplots(figsize=(10, 7))
|
| 1608 |
+
ax.set_title(f"{dataset_name} — Privacy-Utility Tradeoff", fontsize=14, fontweight='bold')
|
| 1609 |
+
|
| 1610 |
+
for r in results:
|
| 1611 |
+
if r["method"] == "Original (No Unlearn)":
|
| 1612 |
+
continue
|
| 1613 |
+
color = colors.get(r["method"], "#333333")
|
| 1614 |
+
marker = 'D' if 'UFUSC' in r["method"] else 'o'
|
| 1615 |
+
size = 200 if 'UFUSC' in r["method"] else 100
|
| 1616 |
+
ax.scatter(r["retain_acc"], 100 - r["mia_asr"],
|
| 1617 |
+
c=color, s=size, marker=marker,
|
| 1618 |
+
label=r["method"], edgecolors='black', linewidth=0.5, zorder=5)
|
| 1619 |
+
|
| 1620 |
+
ax.set_xlabel("Retain Accuracy (%) ↑ — Utility", fontsize=12)
|
| 1621 |
+
ax.set_ylabel("Privacy Protection (100 - MIA ASR) ↑", fontsize=12)
|
| 1622 |
+
ax.legend(fontsize=9, loc='best')
|
| 1623 |
+
ax.grid(True, alpha=0.3)
|
| 1624 |
+
|
| 1625 |
+
# Annotate ideal region
|
| 1626 |
+
ax.annotate("← Better Privacy & Utility →",
|
| 1627 |
+
xy=(0.5, 0.02), xycoords='axes fraction',
|
| 1628 |
+
fontsize=10, ha='center', alpha=0.5, style='italic')
|
| 1629 |
+
|
| 1630 |
+
plt.tight_layout()
|
| 1631 |
+
plt.savefig(f"figures/{dataset_name.replace('-', '_')}_tradeoff.png", dpi=150, bbox_inches='tight')
|
| 1632 |
+
plt.close()
|
| 1633 |
+
print(f" Saved: figures/{dataset_name.replace('-', '_')}_tradeoff.png")
|
| 1634 |
+
|
| 1635 |
+
|
| 1636 |
+
# ============================================================================
|
| 1637 |
+
# Main Execution
|
| 1638 |
+
# ============================================================================
|
| 1639 |
+
|
| 1640 |
+
def main():
|
| 1641 |
+
"""
|
| 1642 |
+
Full experimental pipeline:
|
| 1643 |
+
1. Run experiments on MNIST, Fashion-MNIST, CIFAR-10
|
| 1644 |
+
2. Run ablation study on MNIST
|
| 1645 |
+
3. Run scalability analysis on MNIST
|
| 1646 |
+
4. Generate all visualizations
|
| 1647 |
+
5. Save results to JSON
|
| 1648 |
+
"""
|
| 1649 |
+
print("=" * 70)
|
| 1650 |
+
print(" UFUSC: Unified Federated Unlearning via")
|
| 1651 |
+
print(" Sensitivity-Guided Contrastive Forgetting")
|
| 1652 |
+
print("=" * 70)
|
| 1653 |
+
print(f" Device: {DEVICE}")
|
| 1654 |
+
print(f" Seed: {SEED}")
|
| 1655 |
+
print(f" VFL Parties: {NUM_PASSIVE_PARTIES}")
|
| 1656 |
+
print(f" Batch Size: {BATCH_SIZE}")
|
| 1657 |
+
print(f" Train Epochs: {TRAIN_EPOCHS}")
|
| 1658 |
+
print(f" Unlearn Epochs: {UNLEARN_EPOCHS}")
|
| 1659 |
+
print(f" Forget Ratio: {FORGET_RATIO}")
|
| 1660 |
+
print(f" UFUSC params: α={ALPHA}, β={BETA}, γ={GAMMA}, Ω={OMEGA}, τ={TAU}")
|
| 1661 |
+
print()
|
| 1662 |
+
|
| 1663 |
+
# ---- Main Experiments ----
|
| 1664 |
+
all_results = {}
|
| 1665 |
+
for dataset_name in ["MNIST", "Fashion-MNIST", "CIFAR-10"]:
|
| 1666 |
+
results = run_single_experiment(dataset_name)
|
| 1667 |
+
all_results[dataset_name] = results
|
| 1668 |
+
|
| 1669 |
+
# Save main results
|
| 1670 |
+
with open("results/all_results.json", "w") as f:
|
| 1671 |
+
json.dump(all_results, f, indent=2)
|
| 1672 |
+
print("\n✓ Saved: results/all_results.json")
|
| 1673 |
+
|
| 1674 |
+
# ---- Ablation Study ----
|
| 1675 |
+
ablation_results = run_ablation_study("MNIST")
|
| 1676 |
+
with open("results/ablation_results.json", "w") as f:
|
| 1677 |
+
json.dump(ablation_results, f, indent=2)
|
| 1678 |
+
print("✓ Saved: results/ablation_results.json")
|
| 1679 |
+
|
| 1680 |
+
# ---- Scalability Analysis ----
|
| 1681 |
+
scalability_results = run_scalability_analysis("MNIST")
|
| 1682 |
+
with open("results/scalability_results.json", "w") as f:
|
| 1683 |
+
json.dump(scalability_results, f, indent=2)
|
| 1684 |
+
print("✓ Saved: results/scalability_results.json")
|
| 1685 |
+
|
| 1686 |
+
# ---- Visualizations ----
|
| 1687 |
+
print("\n" + "=" * 70)
|
| 1688 |
+
print(" GENERATING VISUALIZATIONS")
|
| 1689 |
+
print("=" * 70)
|
| 1690 |
+
create_visualizations(all_results, ablation_results, scalability_results)
|
| 1691 |
+
|
| 1692 |
+
# ---- Final Summary ----
|
| 1693 |
+
print("\n" + "=" * 70)
|
| 1694 |
+
print(" FINAL SUMMARY")
|
| 1695 |
+
print("=" * 70)
|
| 1696 |
+
|
| 1697 |
+
for dataset_name, results in all_results.items():
|
| 1698 |
+
joint = next((r for r in results if r["method"] == "UFUSC (Joint)"), None)
|
| 1699 |
+
if joint:
|
| 1700 |
+
print(f"\n {dataset_name}:")
|
| 1701 |
+
print(f" UFUSC-Joint → Retain: {joint['retain_acc']:.1f}%, "
|
| 1702 |
+
f"Forget: {joint['forget_acc']:.1f}%, MIA: {joint['mia_asr']:.1f}%")
|
| 1703 |
+
|
| 1704 |
+
print("\n All experiments complete!")
|
| 1705 |
+
print(f" Results: results/all_results.json")
|
| 1706 |
+
print(f" Ablation: results/ablation_results.json")
|
| 1707 |
+
print(f" Scalability: results/scalability_results.json")
|
| 1708 |
+
print(f" Figures: figures/*.png")
|
| 1709 |
+
print("=" * 70)
|
| 1710 |
+
|
| 1711 |
+
|
| 1712 |
+
if __name__ == "__main__":
|
| 1713 |
+
main()
|