diff --git "a/Geolip_Procrustes_Bert_Model_Step_Model_Scaling.ipynb" "b/Geolip_Procrustes_Bert_Model_Step_Model_Scaling.ipynb" new file mode 100644--- /dev/null +++ "b/Geolip_Procrustes_Bert_Model_Step_Model_Scaling.ipynb" @@ -0,0 +1,3522 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "collapsed_sections": [ + "dp7hwReFXokU", + "0wpUVBCiXmJg" + ], + "machine_shape": "hm", + "gpuType": "G4" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU", + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "d5c2a63f6f8544e79268b9bade807345": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_f027a57fcb6a49cfbbd28b95a6c0adf7", + "IPY_MODEL_4df257d047cb4e7ab2a0a6c57be58b81", + "IPY_MODEL_214861fbbd124b45ba164e934f91a024" + ], + "layout": "IPY_MODEL_f6cf4e1cd78c4c0da8bf756a7cc8760a" + } + }, + "f027a57fcb6a49cfbbd28b95a6c0adf7": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bd3077d18b9640abb14a4c385cac7c39", + "placeholder": "​", + "style": "IPY_MODEL_c39efce145764de59d3dc9cb7a818d33", + "value": "Loading weights: 100%" + } + }, + "4df257d047cb4e7ab2a0a6c57be58b81": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b9fc695d6d70408690bbd7ef8bb5841a", + "max": 202, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_a1486fd248674e9584e90d907b5156c2", + "value": 202 + } + }, + "214861fbbd124b45ba164e934f91a024": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_3472ba9857a543568ec3cd2f340571d6", + "placeholder": "​", + "style": "IPY_MODEL_89eae54541d447bdaa9625cbbe357e55", + "value": " 202/202 [00:00<00:00, 4335.55it/s, Materializing param=cls.predictions.transform.dense.weight]" + } + }, + "f6cf4e1cd78c4c0da8bf756a7cc8760a": { + "model_module": "@jupyter-widgets/base", 
+ "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "bd3077d18b9640abb14a4c385cac7c39": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c39efce145764de59d3dc9cb7a818d33": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "b9fc695d6d70408690bbd7ef8bb5841a": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + 
"grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a1486fd248674e9584e90d907b5156c2": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "3472ba9857a543568ec3cd2f340571d6": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "89eae54541d447bdaa9625cbbe357e55": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "ae1bdb89a9704d09a1c03ec82354905e": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_8d03191c19b64a698be6d6fc141817cb", + "IPY_MODEL_8d24919b783b47748aefac2a6c234313", + "IPY_MODEL_38bea3901f6c4e339d17a0269f0e535b" + ], + "layout": "IPY_MODEL_e4655e0917814b1b950d53a8f59d1fa5" + } + }, + "8d03191c19b64a698be6d6fc141817cb": { + "model_module": "@jupyter-widgets/controls", + "model_name": 
"HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_325923e6a4054298be70581c2cd1061d", + "placeholder": "​", + "style": "IPY_MODEL_97bcf95ab5ac4823b2f696f42db534da", + "value": "Loading weights: 100%" + } + }, + "8d24919b783b47748aefac2a6c234313": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_866d9713b1f5436b924a5505e36b9e7d", + "max": 202, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_8f481ca713e04111a7dcab82f19a072b", + "value": 202 + } + }, + "38bea3901f6c4e339d17a0269f0e535b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_8566ff523c5d4bb5b799cdd12a95c633", + "placeholder": "​", + "style": "IPY_MODEL_bd53b01466524897aea749b68000684a", + "value": " 202/202 [00:00<00:00, 4209.01it/s, Materializing param=cls.predictions.transform.dense.weight]" + } + }, + "e4655e0917814b1b950d53a8f59d1fa5": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "325923e6a4054298be70581c2cd1061d": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + 
"_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "97bcf95ab5ac4823b2f696f42db534da": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "866d9713b1f5436b924a5505e36b9e7d": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8f481ca713e04111a7dcab82f19a072b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "8566ff523c5d4bb5b799cdd12a95c633": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": 
"LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "bd53b01466524897aea749b68000684a": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + } + } + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# geolip-rescale experimentation" + ], + "metadata": { + "id": "dp7hwReFXokU" + } + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "KrVSx-_rMPBx", + "outputId": "0972f0ef-a170-4554-cfac-13f98e6a894a" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "======================================================================\n", + "ITERATIVE MULTI-SCALE GEOMETRIC TRANSFER\n", + "======================================================================\n", + " Device: cuda\n", + "\n", + " Task: 16-class pattern recognition, seq_len=64, noise=0.3\n", + " Chance accuracy: 6.2%\n", + " Scales: 256 → 224 → 192 → 160 → 131 → 113 → 97 → 73 → 64\n", + " CV tolerance: ±0.05\n", + "\n", + "======================================================================\n", + "SCALE 0: 256-dim (ROOT — train from scratch)\n", + "======================================================================\n", + " Params: 153,872\n", + " CV before training: 0.0771\n", + " Trained: 200 epochs → acc=0.8720, cv=0.0839\n", + " layer_0: CV=0.1256 eff_rank=62.2\n", + " layer_1: CV=0.0966 eff_rank=203.9\n", + " layer_2: CV=0.1023 eff_rank=194.6\n", + " layer_3: CV=0.0308 eff_rank=15.9\n", + "\n", + "======================================================================\n", + "SCALE 1: 256-dim → 224-dim (12% reduction)\n", + "======================================================================\n", + " Params: 120,304 (78.2% of root)\n", + "\n", + " Projecting 256 → 224...\n", + " After transfer: acc=0.8214, cv=0.1000\n", + " layer_0: CV=0.1393 eff_rank=61.9\n", + " layer_1: CV=0.0941 eff_rank=176.7\n", + " layer_2: CV=0.1014 eff_rank=169.4\n", + " layer_3: CV=0.0395 eff_rank=15.9\n", + "\n", + " Healing toward parent CV=0.0839 (±0.05)...\n", + " Healed: 1 epochs (0.5s) → acc=0.8600, cv=0.0950\n", + " layer_0: CV=0.1305 eff_rank=61.9\n", + " layer_1: CV=0.0888 eff_rank=176.7\n", + " layer_2: CV=0.1015 eff_rank=170.0\n", + " layer_3: CV=0.0421 eff_rank=15.9\n", + "\n", + "======================================================================\n", 
+ "SCALE 2: 224-dim → 192-dim (14% reduction)\n", + "======================================================================\n", + " Params: 90,832 (59.0% of root)\n", + "\n", + " Projecting 224 → 192...\n", + " After transfer: acc=0.8190, cv=0.0968\n", + " layer_0: CV=0.1230 eff_rank=61.4\n", + " layer_1: CV=0.1080 eff_rank=151.3\n", + " layer_2: CV=0.1105 eff_rank=145.6\n", + " layer_3: CV=0.0510 eff_rank=15.9\n", + "\n", + " Healing toward parent CV=0.0950 (±0.05)...\n", + " Healed: 1 epochs (0.5s) → acc=0.8855, cv=0.0971\n", + " layer_0: CV=0.1356 eff_rank=61.5\n", + " layer_1: CV=0.0970 eff_rank=151.3\n", + " layer_2: CV=0.1015 eff_rank=145.9\n", + " layer_3: CV=0.0538 eff_rank=15.9\n", + "\n", + "======================================================================\n", + "SCALE 3: 192-dim → 160-dim (17% reduction)\n", + "======================================================================\n", + " Params: 65,456 (42.5% of root)\n", + "\n", + " Projecting 192 → 160...\n", + " After transfer: acc=0.7910, cv=0.0999\n", + " layer_0: CV=0.1229 eff_rank=60.8\n", + " layer_1: CV=0.1212 eff_rank=125.2\n", + " layer_2: CV=0.1105 eff_rank=122.0\n", + " layer_3: CV=0.0496 eff_rank=15.8\n", + "\n", + " Healing toward parent CV=0.0971 (±0.05)...\n", + " Healed: 1 epochs (0.5s) → acc=0.8845, cv=0.1006\n", + " layer_0: CV=0.1307 eff_rank=60.8\n", + " layer_1: CV=0.1207 eff_rank=125.3\n", + " layer_2: CV=0.1217 eff_rank=122.1\n", + " layer_3: CV=0.0535 eff_rank=15.9\n", + "\n", + "======================================================================\n", + "SCALE 4: 160-dim → 131-dim (18% reduction)\n", + "======================================================================\n", + " Params: 45,997 (29.9% of root)\n", + "\n", + " Projecting 160 → 131...\n", + " After transfer: acc=0.7444, cv=0.1079\n", + " layer_0: CV=0.1331 eff_rank=59.9\n", + " layer_1: CV=0.1135 eff_rank=102.2\n", + " layer_2: CV=0.1180 eff_rank=100.1\n", + " layer_3: CV=0.0790 eff_rank=15.8\n", + "\n", + " Healing toward parent CV=0.1006 (±0.05)...\n", + " Healed: 1 epochs (0.5s) → acc=0.8670, cv=0.1126\n", + " layer_0: CV=0.1277 eff_rank=59.9\n", + " layer_1: CV=0.1292 eff_rank=102.2\n", + " layer_2: CV=0.1101 eff_rank=99.9\n", + " layer_3: CV=0.0695 eff_rank=15.8\n", + "\n", + "======================================================================\n", + "SCALE 5: 131-dim → 113-dim (14% reduction)\n", + "======================================================================\n", + " Params: 35,611 (23.1% of root)\n", + "\n", + " Projecting 131 → 113...\n", + " After transfer: acc=0.7944, cv=0.1136\n", + " layer_0: CV=0.1412 eff_rank=59.0\n", + " layer_1: CV=0.1237 eff_rank=88.6\n", + " layer_2: CV=0.1269 eff_rank=86.3\n", + " layer_3: CV=0.0863 eff_rank=15.8\n", + "\n", + " Healing toward parent CV=0.1126 (±0.05)...\n", + " Healed: 1 epochs (0.5s) → acc=0.8775, cv=0.1123\n", + " layer_0: CV=0.1186 eff_rank=59.0\n", + " layer_1: CV=0.1226 eff_rank=88.6\n", + " layer_2: CV=0.1189 eff_rank=86.1\n", + " layer_3: CV=0.0726 eff_rank=15.8\n", + "\n", + "======================================================================\n", + "SCALE 6: 113-dim → 97-dim (14% reduction)\n", + "======================================================================\n", + " Params: 27,467 (17.9% of root)\n", + "\n", + " Projecting 113 → 97...\n", + " After transfer: acc=0.7706, cv=0.1218\n", + " layer_0: CV=0.1447 eff_rank=57.8\n", + " layer_1: CV=0.1389 eff_rank=76.0\n", + " layer_2: CV=0.1464 eff_rank=73.9\n", + " layer_3: CV=0.0872 eff_rank=15.7\n", 
+ "\n", + " Healing toward parent CV=0.1123 (±0.05)...\n", + " Healed: 1 epochs (0.5s) → acc=0.8685, cv=0.1193\n", + " layer_0: CV=0.1383 eff_rank=57.8\n", + " layer_1: CV=0.1388 eff_rank=76.0\n", + " layer_2: CV=0.1364 eff_rank=73.5\n", + " layer_3: CV=0.0780 eff_rank=15.8\n", + "\n", + "======================================================================\n", + "SCALE 7: 97-dim → 73-dim (25% reduction)\n", + "======================================================================\n", + " Params: 17,171 (11.2% of root)\n", + "\n", + " Projecting 97 → 73...\n", + " After transfer: acc=0.6966, cv=0.1427\n", + " layer_0: CV=0.1180 eff_rank=54.3\n", + " layer_1: CV=0.1580 eff_rank=56.0\n", + " layer_2: CV=0.1436 eff_rank=54.4\n", + " layer_3: CV=0.1031 eff_rank=15.6\n", + "\n", + " Healing toward parent CV=0.1193 (±0.05)...\n", + " Healed: 1 epochs (0.5s) → acc=0.8670, cv=0.1325\n", + " layer_0: CV=0.1279 eff_rank=54.3\n", + " layer_1: CV=0.1600 eff_rank=56.0\n", + " layer_2: CV=0.1592 eff_rank=54.3\n", + " layer_3: CV=0.0945 eff_rank=15.7\n", + "\n", + "======================================================================\n", + "SCALE 8: 73-dim → 64-dim (12% reduction)\n", + "======================================================================\n", + " Params: 13,904 (9.0% of root)\n", + "\n", + " Projecting 73 → 64...\n", + " After transfer: acc=0.7416, cv=0.1433\n", + " layer_0: CV=0.1420 eff_rank=51.5\n", + " layer_1: CV=0.1614 eff_rank=49.2\n", + " layer_2: CV=0.1705 eff_rank=48.2\n", + " layer_3: CV=0.1047 eff_rank=15.6\n", + "\n", + " Healing toward parent CV=0.1325 (±0.05)...\n", + " Healed: 1 epochs (0.5s) → acc=0.8455, cv=0.1372\n", + " layer_0: CV=0.1499 eff_rank=51.6\n", + " layer_1: CV=0.1602 eff_rank=49.2\n", + " layer_2: CV=0.1749 eff_rank=47.9\n", + " layer_3: CV=0.0936 eff_rank=15.7\n", + "\n", + "======================================================================\n", + "BASELINE: Train 64-dim from scratch\n", + "======================================================================\n", + " Trained: 200 epochs → acc=0.8560, cv=0.1436\n", + "\n", + "======================================================================\n", + "DIRECT PROJECTION: 256 → 64 (single jump)\n", + "======================================================================\n", + " After direct transfer: acc=0.2960, cv=0.1740\n", + "\n", + "======================================================================\n", + "RESULTS — ITERATIVE CASCADE\n", + "======================================================================\n", + "\n", + " Scale Params Acc(proj) Acc(heal) CV(proj) CV(heal) Epochs\n", + " ──────── ──────── ────────── ────────── ───────── ───────── ───────\n", + " 256 153,872 0.8720 0.8720 0.0839 0.0839 200\n", + " 224 120,304 0.8214 0.8600 0.1000 0.0950 1\n", + " 192 90,832 0.8190 0.8855 0.0968 0.0971 1\n", + " 160 65,456 0.7910 0.8845 0.0999 0.1006 1\n", + " 131 45,997 0.7444 0.8670 0.1079 0.1126 1\n", + " 113 35,611 0.7944 0.8775 0.1136 0.1123 1\n", + " 97 27,467 0.7706 0.8685 0.1218 0.1193 1\n", + " 73 17,171 0.6966 0.8670 0.1427 0.1325 1\n", + " 64 13,904 0.7416 0.8455 0.1433 0.1372 1\n", + "\n", + " COMPARISONS\n", + " ────────────────────────────────────────────────────────────\n", + " Cascade 64-dim: acc=0.8455 cv=0.1372 (total 8 heal epochs)\n", + " Direct proj 64-dim: acc=0.2960 cv=0.1740 (0 training)\n", + " Scratch 64-dim: acc=0.8560 cv=0.1436 (200 epochs)\n", + " Chance: acc=0.0625\n", + "\n", + " COMPRESSION:\n", + " Root: 153,872 params\n", + " Target: 13,904 params 
(9.0%)\n", + " Ratio: 11.1×\n", + "\n", + " GEOMETRIC PRESERVATION:\n", + " Root CV: 0.0839\n", + " Final CV: 0.1372\n", + " Δ CV: 0.0533\n", + " Direct CV: 0.1740\n", + " Scratch CV: 0.1436\n", + "\n", + "Done.\n" + ] + } + ], + "source": [ + "# ============================================================================\n", + "# ITERATIVE MULTI-SCALE GEOMETRIC TRANSFER\n", + "#\n", + "# Cascade: 256 → 224 → 192 → 160 → 131\n", + "# At each scale:\n", + "# 1. Procrustes-project from parent\n", + "# 2. Measure accuracy + CV\n", + "# 3. Train ONLY until CV reaches parent's CV band (±tolerance)\n", + "# 4. Measure accuracy again\n", + "# 5. Project down to next scale\n", + "#\n", + "# The hypothesis: small iterative steps preserve geometric structure\n", + "# better than one large jump, because each intermediate model can\n", + "# \"heal\" the projection distortion through minimal training.\n", + "# ============================================================================\n", + "\n", + "import math\n", + "import time\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import numpy as np\n", + "\n", + "\n", + "# ══════════════════════════════════════════════════════════════════\n", + "# GEOMETRIC UTILITIES\n", + "# ══════════════════════════════════════════════════════════════════\n", + "\n", + "def cayley_menger_vol2(pts):\n", + " with torch.amp.autocast(\"cuda\", enabled=False):\n", + " pts = pts.float()\n", + " diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)\n", + " d2 = (diff * diff).sum(-1)\n", + " B, V, _ = d2.shape\n", + " cm = torch.zeros(B, V+1, V+1, device=d2.device, dtype=torch.float32)\n", + " cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2\n", + " s = (-1.0)**V; f = math.factorial(V-1)\n", + " return s / ((2.0**(V-1)) * f*f) * torch.linalg.det(cm)\n", + "\n", + "\n", + "def pentachoron_cv(embeddings, n_samples=200):\n", + " B = embeddings.shape[0]\n", + " if B < 5:\n", + " return 0.0\n", + " vols = []\n", + " for _ in range(n_samples):\n", + " idx = torch.randperm(B, device=embeddings.device)[:5]\n", + " v2 = cayley_menger_vol2(embeddings[idx].unsqueeze(0))\n", + " v = torch.sqrt(F.relu(v2[0]) + 1e-12).item()\n", + " if v > 0:\n", + " vols.append(v)\n", + " if len(vols) < 10:\n", + " return 0.0\n", + " a = np.array(vols, dtype=np.float64)\n", + " return float(a.std() / max(a.mean(), 1e-12))\n", + "\n", + "\n", + "def profile_model(model):\n", + " \"\"\"Profile all linear layers: CV, effective rank.\"\"\"\n", + " results = {}\n", + " for i, layer in enumerate(model.get_linear_layers()):\n", + " W = layer.weight.detach().float()\n", + " cv = pentachoron_cv(W, n_samples=200)\n", + " S = torch.linalg.svdvals(W)\n", + " S_norm = S / S.sum()\n", + " eff_rank = torch.exp(-torch.sum(S_norm * torch.log(S_norm + 1e-12))).item()\n", + " results[f\"layer_{i}\"] = {\"cv\": cv, \"eff_rank\": eff_rank}\n", + " mean_cv = np.mean([v[\"cv\"] for v in results.values()])\n", + " return results, mean_cv\n", + "\n", + "\n", + "# ══════════════════════════════════════════════════════════════════\n", + "# TASK: Multi-class sequence pattern recognition (harder than needle)\n", + "# ══════════════════════════════════════════════════════════════════\n", + "\n", + "class PatternTask:\n", + " \"\"\"\n", + " Multi-pattern classification. 
+ "# ══════════════════════════════════════════════════════════════════\n",
+ "# TASK: Multi-class sequence pattern recognition (harder than needle)\n",
+ "# ══════════════════════════════════════════════════════════════════\n",
+ "\n",
+ "class PatternTask:\n",
+ "    \"\"\"\n",
+ "    Multi-pattern classification. Each class has a distinct learned template.\n",
+ "    The model must learn to recognize WHICH pattern, not just WHERE.\n",
+ "    This forces genuine geometric restructuring during training.\n",
+ "\n",
+ "    Input: (B, seq_len) — noisy pattern\n",
+ "    Target: (B,) — pattern class\n",
+ "    \"\"\"\n",
+ "    def __init__(self, n_classes=16, seq_len=64, noise=0.3, device=\"cpu\"):\n",
+ "        self.n_classes = n_classes\n",
+ "        self.seq_len = seq_len\n",
+ "        self.noise = noise\n",
+ "        self.device = device\n",
+ "\n",
+ "        # Fixed random templates — each class has a unique pattern\n",
+ "        torch.manual_seed(42)\n",
+ "        self.templates = torch.randn(n_classes, seq_len, device=device)\n",
+ "        self.templates = F.normalize(self.templates, dim=-1)\n",
+ "\n",
+ "    def generate(self, n_samples):\n",
+ "        labels = torch.randint(0, self.n_classes, (n_samples,), device=self.device)\n",
+ "        patterns = self.templates[labels]\n",
+ "        noise = torch.randn_like(patterns) * self.noise\n",
+ "        inputs = patterns + noise\n",
+ "        return inputs, labels\n",
+ "\n",
+ "\n",
+ "# ══════════════════════════════════════════════════════════════════\n",
+ "# MODEL\n",
+ "# ══════════════════════════════════════════════════════════════════\n",
+ "\n",
+ "class PatternModel(nn.Module):\n",
+ "    def __init__(self, seq_len, hidden_dim, n_classes, n_layers=4):\n",
+ "        super().__init__()\n",
+ "        self.seq_len = seq_len\n",
+ "        self.hidden_dim = hidden_dim\n",
+ "        self.n_classes = n_classes\n",
+ "\n",
+ "        layers = []\n",
+ "        layers.append(nn.Linear(seq_len, hidden_dim))\n",
+ "        layers.append(nn.GELU())\n",
+ "        layers.append(nn.LayerNorm(hidden_dim))\n",
+ "        for _ in range(n_layers - 2):\n",
+ "            layers.append(nn.Linear(hidden_dim, hidden_dim))\n",
+ "            layers.append(nn.GELU())\n",
+ "            layers.append(nn.LayerNorm(hidden_dim))\n",
+ "        layers.append(nn.Linear(hidden_dim, n_classes))\n",
+ "        self.network = nn.Sequential(*layers)\n",
+ "\n",
+ "    def forward(self, x):\n",
+ "        return self.network(x)\n",
+ "\n",
+ "    def get_linear_layers(self):\n",
+ "        return [m for m in self.network.modules() if isinstance(m, nn.Linear)]\n",
+ "\n",
+ "\n",
+ "# ══════════════════════════════════════════════════════════════════\n",
+ "# PROCRUSTES PROJECTION: truncated SVD\n",
+ "# ══════════════════════════════════════════════════════════════════\n",
+ "\n",
+ "@torch.no_grad()\n",
+ "def svd_project(W_large, out_dim, in_dim):\n",
+ "    \"\"\"Project weight matrix via truncated SVD reconstruction.\"\"\"\n",
+ "    W = W_large.float()\n",
+ "    U, S, Vt = torch.linalg.svd(W, full_matrices=True)\n",
+ "    k = min(S.shape[0], out_dim, in_dim)\n",
+ "    U_k = U[:min(W.shape[0], out_dim), :k]\n",
+ "    Vt_k = Vt[:k, :min(W.shape[1], in_dim)]\n",
+ "    W_small = U_k @ torch.diag(S[:k]) @ Vt_k\n",
+ "    result = torch.zeros(out_dim, in_dim, device=W.device)\n",
+ "    r, c = W_small.shape\n",
+ "    result[:r, :c] = W_small\n",
+ "    return result\n",
+ "\n",
+ "\n",
b)\n", + " elif b.shape[0] < to:\n", + " S.bias.data.zero_()\n", + " S.bias.data[:b.shape[0]].copy_(b)\n", + " else:\n", + " S.bias.data.copy_(b)\n", + "\n", + " # LayerNorms\n", + " src_norms = [m for m in source.network.modules() if isinstance(m, nn.LayerNorm)]\n", + " tgt_norms = [m for m in target.network.modules() if isinstance(m, nn.LayerNorm)]\n", + " for ln_s, ln_t in zip(src_norms, tgt_norms):\n", + " d = min(ln_s.weight.shape[0], ln_t.weight.shape[0])\n", + " ln_t.weight.data[:d].copy_(ln_s.weight.data[:d])\n", + " ln_t.bias.data[:d].copy_(ln_s.bias.data[:d])\n", + "\n", + "\n", + "# ══════════════════════════════════════════════════════════════════\n", + "# CV-GATED TRAINING\n", + "# ══════════════════════════════════════════════════════════════════\n", + "\n", + "def train_until_cv(model, task, target_cv, cv_tolerance=0.05,\n", + " max_epochs=200, lr=3e-4, batch_size=256):\n", + " \"\"\"\n", + " Train until mean CV reaches target ± tolerance.\n", + " Returns: epochs_used, final_acc, final_cv\n", + " \"\"\"\n", + " device = next(model.parameters()).device\n", + " train_x, train_y = task.generate(10000)\n", + " test_x, test_y = task.generate(2000)\n", + " optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n", + " scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)\n", + "\n", + " for epoch in range(max_epochs):\n", + " model.train()\n", + " perm = torch.randperm(train_x.shape[0], device=device)\n", + " for i in range(0, train_x.shape[0], batch_size):\n", + " idx = perm[i:i+batch_size]\n", + " loss = F.cross_entropy(model(train_x[idx]), train_y[idx])\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + " scheduler.step()\n", + "\n", + " # Check CV every 5 epochs\n", + " if (epoch + 1) % 5 == 0 or epoch == 0:\n", + " model.eval()\n", + " _, mean_cv = profile_model(model)\n", + " with torch.no_grad():\n", + " acc = (model(test_x).argmax(-1) == test_y).float().mean().item()\n", + "\n", + " if abs(mean_cv - target_cv) <= cv_tolerance:\n", + " return epoch + 1, acc, mean_cv\n", + " if acc >= 0.99 and epoch > 20:\n", + " return epoch + 1, acc, mean_cv\n", + "\n", + " # Max epochs reached\n", + " model.eval()\n", + " _, final_cv = profile_model(model)\n", + " with torch.no_grad():\n", + " final_acc = (model(test_x).argmax(-1) == test_y).float().mean().item()\n", + " return max_epochs, final_acc, final_cv\n", + "\n", + "\n", + "# ══════════════════════════════════════════════════════════════════\n", + "# EXPERIMENT\n", + "# ══════════════════════════════════════════════════════════════════\n", + "\n", + "def run_experiment():\n", + " print(\"=\" * 70)\n", + " print(\"ITERATIVE MULTI-SCALE GEOMETRIC TRANSFER\")\n", + " print(\"=\" * 70)\n", + "\n", + " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + " print(f\" Device: {device}\")\n", + "\n", + " # ── Configuration ──\n", + " SEQ_LEN = 64\n", + " N_CLASSES = 16\n", + " NOISE = 0.3\n", + " N_LAYERS = 4\n", + " SCALES = [256, 224, 192, 160, 131, 113, 97, 73, 64] # cascade\n", + " CV_TOLERANCE = 0.05\n", + " MAX_HEAL_EPOCHS = 500\n", + "\n", + " print(f\"\\n Task: {N_CLASSES}-class pattern recognition, seq_len={SEQ_LEN}, noise={NOISE}\")\n", + " print(f\" Chance accuracy: {1/N_CLASSES:.1%}\")\n", + " print(f\" Scales: {' → '.join(str(s) for s in SCALES)}\")\n", + " print(f\" CV tolerance: ±{CV_TOLERANCE}\")\n", + "\n", + " task = PatternTask(N_CLASSES, SEQ_LEN, NOISE, device)\n", + " test_x, test_y = task.generate(5000)\n", + "\n", + " # 
+ "# ══════════════════════════════════════════════════════════════════\n",
+ "# EXPERIMENT\n",
+ "# ══════════════════════════════════════════════════════════════════\n",
+ "\n",
+ "def run_experiment():\n",
+ "    print(\"=\" * 70)\n",
+ "    print(\"ITERATIVE MULTI-SCALE GEOMETRIC TRANSFER\")\n",
+ "    print(\"=\" * 70)\n",
+ "\n",
+ "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "    print(f\" Device: {device}\")\n",
+ "\n",
+ "    # ── Configuration ──\n",
+ "    SEQ_LEN = 64\n",
+ "    N_CLASSES = 16\n",
+ "    NOISE = 0.3\n",
+ "    N_LAYERS = 4\n",
+ "    SCALES = [256, 224, 192, 160, 131, 113, 97, 73, 64]  # cascade\n",
+ "    CV_TOLERANCE = 0.05\n",
+ "    MAX_HEAL_EPOCHS = 500\n",
+ "\n",
+ "    print(f\"\\n Task: {N_CLASSES}-class pattern recognition, seq_len={SEQ_LEN}, noise={NOISE}\")\n",
+ "    print(f\" Chance accuracy: {1/N_CLASSES:.1%}\")\n",
+ "    print(f\" Scales: {' → '.join(str(s) for s in SCALES)}\")\n",
+ "    print(f\" CV tolerance: ±{CV_TOLERANCE}\")\n",
+ "\n",
+ "    task = PatternTask(N_CLASSES, SEQ_LEN, NOISE, device)\n",
+ "    test_x, test_y = task.generate(5000)\n",
+ "\n",
+ "    # ══════════════════════════════════════════════════════════\n",
+ "    # STEP 0: Train root model (256-dim) to convergence\n",
+ "    # ══════════════════════════════════════════════════════════\n",
+ "\n",
+ "    print(f\"\\n{'='*70}\")\n",
+ "    print(f\"SCALE 0: {SCALES[0]}-dim (ROOT — train from scratch)\")\n",
+ "    print(f\"{'='*70}\")\n",
+ "\n",
+ "    root = PatternModel(SEQ_LEN, SCALES[0], N_CLASSES, N_LAYERS).to(device)\n",
+ "    n_root = sum(p.numel() for p in root.parameters())\n",
+ "    print(f\" Params: {n_root:,}\")\n",
+ "\n",
+ "    # Profile before training\n",
+ "    _, cv_before = profile_model(root)\n",
+ "    print(f\" CV before training: {cv_before:.4f}\")\n",
+ "\n",
+ "    epochs, acc, cv = train_until_cv(root, task, target_cv=0.20,\n",
+ "                                     cv_tolerance=0.1, max_epochs=200)\n",
+ "    print(f\" Trained: {epochs} epochs → acc={acc:.4f}, cv={cv:.4f}\")\n",
+ "\n",
+ "    # Full profile\n",
+ "    profile, _ = profile_model(root)\n",
+ "    for name, stats in profile.items():\n",
+ "        print(f\" {name}: CV={stats['cv']:.4f} eff_rank={stats['eff_rank']:.1f}\")\n",
+ "\n",
+ "    # ══════════════════════════════════════════════════════════\n",
+ "    # ITERATIVE CASCADE\n",
+ "    # ══════════════════════════════════════════════════════════\n",
+ "\n",
+ "    results = [{\n",
+ "        \"scale\": SCALES[0],\n",
+ "        \"params\": n_root,\n",
+ "        \"acc_after_transfer\": acc,\n",
+ "        \"acc_after_heal\": acc,\n",
+ "        \"cv_after_transfer\": cv,\n",
+ "        \"cv_after_heal\": cv,\n",
+ "        \"heal_epochs\": epochs,\n",
+ "        \"source\": \"scratch\",\n",
+ "    }]\n",
+ "\n",
+ "    parent_model = root\n",
+ "    parent_cv = cv\n",
+ "\n",
+ "    for i in range(1, len(SCALES)):\n",
+ "        dim = SCALES[i]\n",
+ "        parent_dim = SCALES[i-1]\n",
+ "\n",
+ "        print(f\"\\n{'='*70}\")\n",
+ "        print(f\"SCALE {i}: {parent_dim}-dim → {dim}-dim \"\n",
+ "              f\"({(parent_dim-dim)/parent_dim:.0%} reduction)\")\n",
+ "        print(f\"{'='*70}\")\n",
+ "\n",
+ "        # Build target model\n",
+ "        child = PatternModel(SEQ_LEN, dim, N_CLASSES, N_LAYERS).to(device)\n",
+ "        n_child = sum(p.numel() for p in child.parameters())\n",
+ "        print(f\" Params: {n_child:,} ({n_child/n_root:.1%} of root)\")\n",
+ "\n",
+ "        # ── Transfer ──\n",
+ "        print(f\"\\n Projecting {parent_dim} → {dim}...\")\n",
+ "        transfer_weights(parent_model, child)\n",
+ "\n",
+ "        # Measure immediately after transfer (no training)\n",
+ "        child.eval()\n",
+ "        _, cv_transfer = profile_model(child)\n",
+ "        with torch.no_grad():\n",
+ "            acc_transfer = (child(test_x).argmax(-1) == test_y).float().mean().item()\n",
+ "        print(f\" After transfer: acc={acc_transfer:.4f}, cv={cv_transfer:.4f}\")\n",
+ "\n",
+ "        child_profile, _ = profile_model(child)\n",
+ "        for name, stats in child_profile.items():\n",
+ "            print(f\" {name}: CV={stats['cv']:.4f} eff_rank={stats['eff_rank']:.1f}\")\n",
+ "\n",
+ "        # ── Heal: train until CV matches parent ──\n",
+ "        print(f\"\\n Healing toward parent CV={parent_cv:.4f} (±{CV_TOLERANCE})...\")\n",
+ "        t0 = time.time()\n",
+ "        heal_epochs, acc_heal, cv_heal = train_until_cv(\n",
+ "            child, task, target_cv=parent_cv,\n",
+ "            cv_tolerance=CV_TOLERANCE, max_epochs=MAX_HEAL_EPOCHS)\n",
+ "        elapsed = time.time() - t0\n",
+ "        print(f\" Healed: {heal_epochs} epochs ({elapsed:.1f}s) → \"\n",
+ "              f\"acc={acc_heal:.4f}, cv={cv_heal:.4f}\")\n",
+ "\n",
+ "        # Post-heal profile\n",
+ "        heal_profile, _ = profile_model(child)\n",
+ "        for name, stats in heal_profile.items():\n",
+ "            print(f\" {name}: CV={stats['cv']:.4f} eff_rank={stats['eff_rank']:.1f}\")\n",
+ "\n",
+ "        results.append({\n",
+ "            \"scale\": dim,\n",
\"params\": n_child,\n", + " \"acc_after_transfer\": acc_transfer,\n", + " \"acc_after_heal\": acc_heal,\n", + " \"cv_after_transfer\": cv_transfer,\n", + " \"cv_after_heal\": cv_heal,\n", + " \"heal_epochs\": heal_epochs,\n", + " \"source\": f\"projected from {parent_dim}\",\n", + " })\n", + "\n", + " # This child becomes next parent\n", + " parent_model = child\n", + " parent_cv = cv_heal\n", + "\n", + " # ══════════════════════════════════════════════════════════\n", + " # BASELINE: Train 131-dim from scratch\n", + " # ══════════════════════════════════════════════════════════\n", + "\n", + " print(f\"\\n{'='*70}\")\n", + " print(f\"BASELINE: Train {SCALES[-1]}-dim from scratch\")\n", + " print(f\"{'='*70}\")\n", + "\n", + " baseline = PatternModel(SEQ_LEN, SCALES[-1], N_CLASSES, N_LAYERS).to(device)\n", + " n_base = sum(p.numel() for p in baseline.parameters())\n", + " base_epochs, base_acc, base_cv = train_until_cv(\n", + " baseline, task, target_cv=0.20,\n", + " cv_tolerance=0.05, max_epochs=200)\n", + " print(f\" Trained: {base_epochs} epochs → acc={base_acc:.4f}, cv={base_cv:.4f}\")\n", + "\n", + " # ══════════════════════════════════════════════════════════\n", + " # DIRECT PROJECTION: 256 → 131 (single jump baseline)\n", + " # ══════════════════════════════════════════════════════════\n", + "\n", + " print(f\"\\n{'='*70}\")\n", + " print(f\"DIRECT PROJECTION: {SCALES[0]} → {SCALES[-1]} (single jump)\")\n", + " print(f\"{'='*70}\")\n", + "\n", + " direct = PatternModel(SEQ_LEN, SCALES[-1], N_CLASSES, N_LAYERS).to(device)\n", + " transfer_weights(root, direct)\n", + " direct.eval()\n", + " _, direct_cv = profile_model(direct)\n", + " with torch.no_grad():\n", + " direct_acc = (direct(test_x).argmax(-1) == test_y).float().mean().item()\n", + " print(f\" After direct transfer: acc={direct_acc:.4f}, cv={direct_cv:.4f}\")\n", + "\n", + " # ══════════════════════════════════════════════════════════\n", + " # FINAL REPORT\n", + " # ══════════════════════════════════════════════════════════\n", + "\n", + " print(f\"\\n{'='*70}\")\n", + " print(f\"RESULTS — ITERATIVE CASCADE\")\n", + " print(f\"{'='*70}\\n\")\n", + "\n", + " print(f\" {'Scale':<8s} {'Params':>8s} {'Acc(proj)':>10s} {'Acc(heal)':>10s} \"\n", + " f\"{'CV(proj)':>9s} {'CV(heal)':>9s} {'Epochs':>7s}\")\n", + " print(f\" {'─'*8} {'─'*8} {'─'*10} {'─'*10} {'─'*9} {'─'*9} {'─'*7}\")\n", + "\n", + " total_heal_epochs = 0\n", + " for r in results:\n", + " print(f\" {r['scale']:<8d} {r['params']:>8,} {r['acc_after_transfer']:>10.4f} \"\n", + " f\"{r['acc_after_heal']:>10.4f} {r['cv_after_transfer']:>9.4f} \"\n", + " f\"{r['cv_after_heal']:>9.4f} {r['heal_epochs']:>7d}\")\n", + " if r['source'] != 'scratch':\n", + " total_heal_epochs += r['heal_epochs']\n", + "\n", + " print(f\"\\n {'COMPARISONS':}\")\n", + " print(f\" {'─'*60}\")\n", + " print(f\" Cascade {SCALES[-1]}-dim: acc={results[-1]['acc_after_heal']:.4f} \"\n", + " f\"cv={results[-1]['cv_after_heal']:.4f} \"\n", + " f\"(total {total_heal_epochs} heal epochs)\")\n", + " print(f\" Direct proj {SCALES[-1]}-dim: acc={direct_acc:.4f} \"\n", + " f\"cv={direct_cv:.4f} (0 training)\")\n", + " print(f\" Scratch {SCALES[-1]}-dim: acc={base_acc:.4f} \"\n", + " f\"cv={base_cv:.4f} ({base_epochs} epochs)\")\n", + " print(f\" Chance: acc={1/N_CLASSES:.4f}\")\n", + "\n", + " print(f\"\\n COMPRESSION:\")\n", + " print(f\" Root: {n_root:>8,} params\")\n", + " print(f\" Target: {n_child:>8,} params ({n_child/n_root:.1%})\")\n", + " print(f\" Ratio: {n_root/n_child:.1f}×\")\n", + "\n", + " # 
+ "    # ── Geometric preservation ──\n",
+ "    root_cv = results[0][\"cv_after_heal\"]\n",
+ "    final_cv = results[-1][\"cv_after_heal\"]\n",
+ "    print(f\"\\n GEOMETRIC PRESERVATION:\")\n",
+ "    print(f\" Root CV: {root_cv:.4f}\")\n",
+ "    print(f\" Final CV: {final_cv:.4f}\")\n",
+ "    print(f\" Δ CV: {abs(root_cv - final_cv):.4f}\")\n",
+ "    print(f\" Direct CV: {direct_cv:.4f}\")\n",
+ "    print(f\" Scratch CV: {base_cv:.4f}\")\n",
+ "\n",
+ "    print(f\"\\nDone.\")\n",
+ "    return results\n",
+ "\n",
+ "\n",
+ "if __name__ == \"__main__\":\n",
+ "    results = run_experiment()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# ============================================================================\n",
+ "# OPTIMAL SCALING RATIO EXPERIMENT\n",
+ "#\n",
+ "# Sweep: What ratio between consecutive scales minimizes accuracy loss\n",
+ "# while maximizing compression?\n",
+ "#\n",
+ "# For each ratio r ∈ {0.50, 0.55, 0.60, 0.618, 0.65, 0.70, 0.707, 0.75,\n",
+ "#                     0.80, 0.85, 0.90, 0.95}:\n",
+ "#   - Build cascade from 256 → 64 using steps of dim[i+1] = round(dim[i] * r)\n",
+ "#   - Apply iterative project + 1-epoch heal at each step\n",
+ "#   - Measure: final accuracy, total heal epochs, CV preservation\n",
+ "#\n",
+ "# Natural candidates:\n",
+ "#   φ⁻¹ = 0.6180 (golden ratio inverse — nature's scaling constant)\n",
+ "#   2⁻⁰·⁵ = 0.7071 (inverse sqrt 2 — octave halving)\n",
+ "#   1-0.29514 = 0.7049 (Phil's recurring ratio complement)\n",
+ "#   e⁻¹ = 0.3679 (too aggressive — defined below but left out of the sweep)\n",
+ "# ============================================================================\n",
+ "\n",
+ "import math\n",
+ "import time\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "import numpy as np\n",
+ "\n",
+ "\n",
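+ "# Rough step-count arithmetic for the sweep (illustration only, never called;\n",
+ "# build_cascade below is authoritative and its rounding can shift counts by one).\n",
+ "# A cascade from start_dim to end_dim at constant ratio r needs the smallest n\n",
+ "# with start_dim * r**n <= end_dim, i.e. n = ceil(log(end/start) / log(r)).\n",
+ "def _expected_steps(start_dim, end_dim, ratio):\n",
+ "    # tiny epsilon guards float edge cases such as ratio = 1/sqrt(2) exactly\n",
+ "    return math.ceil(math.log(end_dim / start_dim) / math.log(ratio) - 1e-9)\n",
+ "\n",
+ "# e.g. 256 → 64: r=0.500 → 2 steps, r=0.618 → 3, r=0.707 → 4, r=0.950 → ~28.\n",
+ "\n",
+ "\n",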
PatternTask:\n", + " def __init__(self, n_classes=16, seq_len=64, noise=0.3, device=\"cpu\"):\n", + " self.n_classes = n_classes\n", + " self.seq_len = seq_len\n", + " self.noise = noise\n", + " self.device = device\n", + " torch.manual_seed(42)\n", + " self.templates = F.normalize(\n", + " torch.randn(n_classes, seq_len, device=device), dim=-1)\n", + "\n", + " def generate(self, n_samples):\n", + " labels = torch.randint(0, self.n_classes, (n_samples,), device=self.device)\n", + " patterns = self.templates[labels]\n", + " return patterns + torch.randn_like(patterns) * self.noise, labels\n", + "\n", + "\n", + "# ══════════════════════════════════════════════════════════════════\n", + "# MODEL\n", + "# ══════════════════════════════════════════════════════════════════\n", + "\n", + "class PatternModel(nn.Module):\n", + " def __init__(self, seq_len, hidden_dim, n_classes, n_layers=4):\n", + " super().__init__()\n", + " layers = []\n", + " layers.append(nn.Linear(seq_len, hidden_dim))\n", + " layers.append(nn.GELU())\n", + " layers.append(nn.LayerNorm(hidden_dim))\n", + " for _ in range(n_layers - 2):\n", + " layers.append(nn.Linear(hidden_dim, hidden_dim))\n", + " layers.append(nn.GELU())\n", + " layers.append(nn.LayerNorm(hidden_dim))\n", + " layers.append(nn.Linear(hidden_dim, n_classes))\n", + " self.network = nn.Sequential(*layers)\n", + "\n", + " def forward(self, x):\n", + " return self.network(x)\n", + "\n", + " def get_linear_layers(self):\n", + " return [m for m in self.network.modules() if isinstance(m, nn.Linear)]\n", + "\n", + "\n", + "# ══════════════════════════���═══════════════════════════════════════\n", + "# PROJECTION + HEALING\n", + "# ══════════════════════════════════════════════════════════════════\n", + "\n", + "@torch.no_grad()\n", + "def svd_project(W_large, out_dim, in_dim):\n", + " W = W_large.float()\n", + " U, S, Vt = torch.linalg.svd(W, full_matrices=True)\n", + " k = min(S.shape[0], out_dim, in_dim)\n", + " U_k = U[:min(W.shape[0], out_dim), :k]\n", + " Vt_k = Vt[:k, :min(W.shape[1], in_dim)]\n", + " W_small = U_k @ torch.diag(S[:k]) @ Vt_k\n", + " result = torch.zeros(out_dim, in_dim, device=W.device)\n", + " r, c = W_small.shape\n", + " result[:r, :c] = W_small\n", + " return result\n", + "\n", + "\n", + "@torch.no_grad()\n", + "def transfer_weights(source, target):\n", + " src_layers = source.get_linear_layers()\n", + " tgt_layers = target.get_linear_layers()\n", + " for L, Sm in zip(src_layers, tgt_layers):\n", + " to, ti = Sm.weight.shape\n", + " Sm.weight.data.copy_(svd_project(L.weight.data, to, ti))\n", + " if L.bias is not None and Sm.bias is not None:\n", + " b = L.bias.data.float()\n", + " if b.shape[0] > to:\n", + " U, _, _ = torch.linalg.svd(L.weight.data.float(), full_matrices=True)\n", + " Sm.bias.data.copy_(U[:, :to].T @ b)\n", + " elif b.shape[0] < to:\n", + " Sm.bias.data.zero_()\n", + " Sm.bias.data[:b.shape[0]].copy_(b)\n", + " else:\n", + " Sm.bias.data.copy_(b)\n", + " src_norms = [m for m in source.network.modules() if isinstance(m, nn.LayerNorm)]\n", + " tgt_norms = [m for m in target.network.modules() if isinstance(m, nn.LayerNorm)]\n", + " for ln_s, ln_t in zip(src_norms, tgt_norms):\n", + " d = min(ln_s.weight.shape[0], ln_t.weight.shape[0])\n", + " ln_t.weight.data[:d].copy_(ln_s.weight.data[:d])\n", + " ln_t.bias.data[:d].copy_(ln_s.bias.data[:d])\n", + "\n", + "\n", + "def heal_one_epoch(model, task, batch_size=256, lr=3e-4):\n", + " \"\"\"Single healing epoch. 
+ "def heal_one_epoch(model, task, batch_size=256, lr=3e-4):\n",
+ "    \"\"\"Run a single healing epoch in place; returns the model.\"\"\"\n",
+ "    device = next(model.parameters()).device\n",
+ "    train_x, train_y = task.generate(10000)\n",
+ "    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n",
+ "    model.train()\n",
+ "    perm = torch.randperm(train_x.shape[0], device=device)\n",
+ "    for i in range(0, train_x.shape[0], batch_size):\n",
+ "        idx = perm[i:i+batch_size]\n",
+ "        loss = F.cross_entropy(model(train_x[idx]), train_y[idx])\n",
+ "        optimizer.zero_grad()\n",
+ "        loss.backward()\n",
+ "        optimizer.step()\n",
+ "    return model\n",
+ "\n",
+ "\n",
+ "def evaluate(model, test_x, test_y):\n",
+ "    model.eval()\n",
+ "    with torch.no_grad():\n",
+ "        return (model(test_x).argmax(-1) == test_y).float().mean().item()\n",
+ "\n",
+ "\n",
+ "# ══════════════════════════════════════════════════════════════════\n",
+ "# BUILD SCALE CASCADE FOR A GIVEN RATIO\n",
+ "# ══════════════════════════════════════════════════════════════════\n",
+ "\n",
+ "def build_cascade(start_dim, end_dim, ratio):\n",
+ "    \"\"\"Generate dimension sequence: start, round(start*r), round(start*r²), ... until ≤ end.\"\"\"\n",
+ "    dims = [start_dim]\n",
+ "    while True:\n",
+ "        next_dim = max(round(dims[-1] * ratio), end_dim)\n",
+ "        if next_dim >= dims[-1]:\n",
+ "            # Ratio too close to 1, force a step down\n",
+ "            next_dim = dims[-1] - 1\n",
+ "        if next_dim <= end_dim:\n",
+ "            if dims[-1] != end_dim:\n",
+ "                dims.append(end_dim)\n",
+ "            break\n",
+ "        dims.append(next_dim)\n",
+ "    return dims\n",
+ "\n",
+ "\n",
+ "def run_cascade(root_model, task, test_x, test_y, scales, device):\n",
+ "    \"\"\"\n",
+ "    Run a full cascade: project + heal at each scale.\n",
+ "    Returns: final_acc, total_heal_epochs, final_cv, per-step data.\n",
+ "    \"\"\"\n",
+ "    parent = root_model\n",
+ "    steps = []\n",
+ "    total_epochs = 0\n",
+ "\n",
+ "    for i in range(1, len(scales)):\n",
+ "        dim = scales[i]\n",
+ "        child = PatternModel(task.seq_len, dim, task.n_classes, 4).to(device)\n",
+ "        transfer_weights(parent, child)\n",
+ "\n",
+ "        acc_proj = evaluate(child, test_x, test_y)\n",
+ "        cv_proj = profile_model(child)\n",
+ "\n",
+ "        # 1 healing epoch\n",
+ "        heal_one_epoch(child, task)\n",
+ "        total_epochs += 1\n",
+ "\n",
+ "        acc_heal = evaluate(child, test_x, test_y)\n",
+ "        cv_heal = profile_model(child)\n",
+ "\n",
+ "        steps.append({\n",
+ "            \"from\": scales[i-1], \"to\": dim,\n",
+ "            \"acc_proj\": acc_proj, \"acc_heal\": acc_heal,\n",
+ "            \"cv_proj\": cv_proj, \"cv_heal\": cv_heal,\n",
+ "        })\n",
+ "\n",
+ "        parent = child\n",
+ "\n",
+ "    return acc_heal, total_epochs, cv_heal, steps\n",
+ "\n",
+ "\n",
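+ "# For illustration, the cascades build_cascade generates for a few ratios\n",
+ "# (hand-traced from the rounding rules above; the sweep printout is authoritative):\n",
+ "#   build_cascade(256, 64, 0.500) → [256, 128, 64]           (2 steps)\n",
+ "#   build_cascade(256, 64, 0.618) → [256, 158, 98, 64]       (3 steps)\n",
+ "#   build_cascade(256, 64, 0.707) → [256, 181, 128, 91, 64]  (4 steps)\n",
+ "\n",
+ "\n",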
+ " (0.65, \"0.650\"),\n", + " (0.70, \"0.700\"),\n", + " (PHIL_COMP, f\"0.705 (1-0.295)\"),\n", + " (SQRT2_INV, f\"0.707 (1/√2)\"),\n", + " (0.75, \"0.750\"),\n", + " (0.80, \"0.800\"),\n", + " (0.85, \"0.850\"),\n", + " (0.90, \"0.900\"),\n", + " (0.95, \"0.950\"),\n", + " ]\n", + "\n", + " task = PatternTask(N_CLASSES, SEQ_LEN, NOISE, device)\n", + " test_x, test_y = task.generate(5000)\n", + "\n", + " print(f\"\\n Task: {N_CLASSES}-class, seq_len={SEQ_LEN}, noise={NOISE}\")\n", + " print(f\" Compression: {START_DIM} → {END_DIM}\")\n", + " print(f\" Testing {len(RATIOS)} scaling ratios\")\n", + "\n", + " # ── Train root model ──\n", + " print(f\"\\n Training root model ({START_DIM}-dim)...\")\n", + " root = PatternModel(SEQ_LEN, START_DIM, N_CLASSES, N_LAYERS).to(device)\n", + " optimizer = torch.optim.AdamW(root.parameters(), lr=3e-4)\n", + " scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)\n", + " train_x, train_y = task.generate(10000)\n", + "\n", + " for epoch in range(200):\n", + " root.train()\n", + " perm = torch.randperm(10000, device=device)\n", + " for i in range(0, 10000, 256):\n", + " idx = perm[i:i+256]\n", + " loss = F.cross_entropy(root(train_x[idx]), train_y[idx])\n", + " optimizer.zero_grad(); loss.backward(); optimizer.step()\n", + " scheduler.step()\n", + " if (epoch+1) % 50 == 0:\n", + " acc = evaluate(root, test_x, test_y)\n", + " print(f\" Epoch {epoch+1}: acc={acc:.4f}\")\n", + "\n", + " root_acc = evaluate(root, test_x, test_y)\n", + " root_cv = profile_model(root)\n", + " print(f\" Root: acc={root_acc:.4f}, cv={root_cv:.4f}\")\n", + "\n", + " # ── Sweep ratios ──\n", + " print(f\"\\n{'='*70}\")\n", + " print(\"RATIO SWEEP\")\n", + " print(f\"{'='*70}\\n\")\n", + "\n", + " results = []\n", + "\n", + " for ratio, name in RATIOS:\n", + " scales = build_cascade(START_DIM, END_DIM, ratio)\n", + " n_steps = len(scales) - 1\n", + "\n", + " t0 = time.time()\n", + " final_acc, total_epochs, final_cv, steps = run_cascade(\n", + " root, task, test_x, test_y, scales, device)\n", + " elapsed = time.time() - t0\n", + "\n", + " # Compute efficiency metric: accuracy retained per heal epoch\n", + " acc_retained = final_acc / max(root_acc, 1e-8)\n", + " efficiency = acc_retained / max(total_epochs, 1)\n", + "\n", + " result = {\n", + " \"ratio\": ratio,\n", + " \"name\": name,\n", + " \"scales\": scales,\n", + " \"n_steps\": n_steps,\n", + " \"final_acc\": final_acc,\n", + " \"final_cv\": final_cv,\n", + " \"total_epochs\": total_epochs,\n", + " \"acc_retained\": acc_retained,\n", + " \"efficiency\": efficiency,\n", + " \"elapsed\": elapsed,\n", + " \"steps\": steps,\n", + " }\n", + " results.append(result)\n", + "\n", + " scale_str = \"→\".join(str(s) for s in scales)\n", + " print(f\" r={ratio:.3f} ({name:20s}): {n_steps} steps \"\n", + " f\"acc={final_acc:.4f} ret={acc_retained:.1%} \"\n", + " f\"cv={final_cv:.4f} eff={efficiency:.4f} \"\n", + " f\"[{scale_str}]\")\n", + "\n", + " # ── Direct jump baseline ──\n", + " direct = PatternModel(SEQ_LEN, END_DIM, N_CLASSES, N_LAYERS).to(device)\n", + " transfer_weights(root, direct)\n", + " direct_acc = evaluate(direct, test_x, test_y)\n", + " heal_one_epoch(direct, task)\n", + " direct_heal_acc = evaluate(direct, test_x, test_y)\n", + "\n", + " # ══════════════════════════════════════════════════════════════\n", + " # ANALYSIS\n", + " # ══════════════════════════════════════════════════════════════\n", + "\n", + " # Sort by final accuracy\n", + " results.sort(key=lambda x: x[\"final_acc\"], 
reverse=True)\n", + "\n", + " print(f\"\\n{'='*70}\")\n", + " print(\"RESULTS — SORTED BY ACCURACY\")\n", + " print(f\"{'='*70}\\n\")\n", + "\n", + " print(f\" {'Ratio':<22s} {'Steps':>5s} {'Acc':>7s} {'Retained':>9s} \"\n", + " f\"{'CV':>7s} {'Epochs':>6s} {'Eff':>7s}\")\n", + " print(f\" {'─'*22} {'─'*5} {'─'*7} {'─'*9} {'─'*7} {'─'*6} {'─'*7}\")\n", + "\n", + " for r in results:\n", + " marker = \" ★\" if \"golden\" in r[\"name\"] or \"0.295\" in r[\"name\"] or \"√2\" in r[\"name\"] else \"\"\n", + " print(f\" {r['name']:<22s} {r['n_steps']:>5d} {r['final_acc']:>7.4f} \"\n", + " f\"{r['acc_retained']:>8.1%} {r['final_cv']:>7.4f} \"\n", + " f\"{r['total_epochs']:>6d} {r['efficiency']:>7.4f}{marker}\")\n", + "\n", + " print(f\"\\n {'Direct 256→64':22s} {'1':>5s} {direct_heal_acc:>7.4f} \"\n", + " f\"{direct_heal_acc/root_acc:>8.1%} {'—':>7s} {'1':>6s}\")\n", + " print(f\" {'Root (256)':22s} {'—':>5s} {root_acc:>7.4f} \"\n", + " f\"{'100.0%':>9s} {root_cv:>7.4f} {'200':>6s}\")\n", + "\n", + " # ── Find optimal ──\n", + " best = results[0]\n", + " print(f\"\\n OPTIMAL RATIO: {best['name']}\")\n", + " print(f\" Accuracy: {best['final_acc']:.4f} ({best['acc_retained']:.1%} retained)\")\n", + " print(f\" Steps: {best['n_steps']}\")\n", + " print(f\" Scales: {'→'.join(str(s) for s in best['scales'])}\")\n", + " print(f\" CV: {best['final_cv']:.4f} (root: {root_cv:.4f})\")\n", + "\n", + " # ── Natural constant comparison ──\n", + " phi_result = next(r for r in results if \"golden\" in r[\"name\"])\n", + " sqrt2_result = next(r for r in results if \"√2\" in r[\"name\"])\n", + " phil_result = next(r for r in results if \"0.295\" in r[\"name\"])\n", + "\n", + " print(f\"\\n NATURAL CONSTANTS:\")\n", + " print(f\" 1/φ (0.618): acc={phi_result['final_acc']:.4f} \"\n", + " f\"steps={phi_result['n_steps']} scales={'→'.join(str(s) for s in phi_result['scales'])}\")\n", + " print(f\" 1/√2 (0.707): acc={sqrt2_result['final_acc']:.4f} \"\n", + " f\"steps={sqrt2_result['n_steps']} scales={'→'.join(str(s) for s in sqrt2_result['scales'])}\")\n", + " print(f\" 1-0.295(0.705): acc={phil_result['final_acc']:.4f} \"\n", + " f\"steps={phil_result['n_steps']} scales={'→'.join(str(s) for s in phil_result['scales'])}\")\n", + "\n", + " # ── Pareto analysis: accuracy vs training cost ──\n", + " print(f\"\\n PARETO FRONTIER (accuracy vs epochs):\")\n", + " print(f\" {'─'*50}\")\n", + " pareto = []\n", + " best_acc_so_far = 0\n", + " for r in sorted(results, key=lambda x: x[\"total_epochs\"]):\n", + " if r[\"final_acc\"] > best_acc_so_far:\n", + " best_acc_so_far = r[\"final_acc\"]\n", + " pareto.append(r)\n", + " print(f\" {r['total_epochs']:2d} epochs → {r['final_acc']:.4f} \"\n", + " f\"({r['name']})\")\n", + "\n", + " print(f\"\\nDone.\")\n", + " return results\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + " results = run_experiment()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "gZk3JSs4UCke", + "outputId": "2e8b6628-056e-44f0-c168-931134031e84" + }, + "execution_count": 11, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "======================================================================\n", + "OPTIMAL SCALING RATIO EXPERIMENT\n", + "======================================================================\n", + " Device: cuda\n", + "\n", + " Task: 16-class, seq_len=64, noise=0.3\n", + " Compression: 256 → 64\n", + " Testing 13 scaling ratios\n", + "\n", + " Training root model (256-dim)...\n", + " Epoch 50: acc=0.8788\n", + 
" Epoch 100: acc=0.8766\n", + " Epoch 150: acc=0.8746\n", + " Epoch 200: acc=0.8746\n", + " Root: acc=0.8746, cv=0.0929\n", + "\n", + "======================================================================\n", + "RATIO SWEEP\n", + "======================================================================\n", + "\n", + " r=0.500 (0.500 (halving) ): 2 steps acc=0.7724 ret=88.3% cv=0.1617 eff=0.4416 [256→128→64]\n", + " r=0.550 (0.550 ): 3 steps acc=0.8236 ret=94.2% cv=0.1478 eff=0.3139 [256→141→78→64]\n", + " r=0.600 (0.600 ): 3 steps acc=0.8148 ret=93.2% cv=0.1406 eff=0.3105 [256→154→92→64]\n", + " r=0.618 (0.618 (1/φ golden) ): 3 steps acc=0.8114 ret=92.8% cv=0.1378 eff=0.3092 [256→158→98→64]\n", + " r=0.650 (0.650 ): 4 steps acc=0.8416 ret=96.2% cv=0.1357 eff=0.2406 [256→166→108→70→64]\n", + " r=0.700 (0.700 ): 4 steps acc=0.8248 ret=94.3% cv=0.1418 eff=0.2358 [256→179→125→88→64]\n", + " r=0.705 (0.705 (1-0.295) ): 4 steps acc=0.8232 ret=94.1% cv=0.1378 eff=0.2353 [256→180→127→90→64]\n", + " r=0.707 (0.707 (1/√2) ): 4 steps acc=0.8194 ret=93.7% cv=0.1404 eff=0.2342 [256→181→128→91→64]\n", + " r=0.750 (0.750 ): 5 steps acc=0.8350 ret=95.5% cv=0.1483 eff=0.1909 [256→192→144→108→81→64]\n", + " r=0.800 (0.800 ): 7 steps acc=0.8626 ret=98.6% cv=0.1432 eff=0.1409 [256→205→164→131→105→84→67→64]\n", + " r=0.850 (0.850 ): 9 steps acc=0.8680 ret=99.2% cv=0.1400 eff=0.1103 [256→218→185→157→133→113→96→82→70→64]\n", + " r=0.900 (0.900 ): 14 steps acc=0.8852 ret=101.2% cv=0.1316 eff=0.0723 [256→230→207→186→167→150→135→122→110→99→89→80→72→65→64]\n", + " r=0.950 (0.950 ): 27 steps acc=0.8916 ret=101.9% cv=0.1318 eff=0.0378 [256→243→231→219→208→198→188→179→170→162→154→146→139→132→125→119→113→107→102→97→92→87→83→79→75→71→67→64]\n", + "\n", + "======================================================================\n", + "RESULTS — SORTED BY ACCURACY\n", + "======================================================================\n", + "\n", + " Ratio Steps Acc Retained CV Epochs Eff\n", + " ────────────────────── ───── ─────── ───────── ─────── ────── ─────���─\n", + " 0.950 27 0.8916 101.9% 0.1318 27 0.0378\n", + " 0.900 14 0.8852 101.2% 0.1316 14 0.0723\n", + " 0.850 9 0.8680 99.2% 0.1400 9 0.1103\n", + " 0.800 7 0.8626 98.6% 0.1432 7 0.1409\n", + " 0.650 4 0.8416 96.2% 0.1357 4 0.2406\n", + " 0.750 5 0.8350 95.5% 0.1483 5 0.1909\n", + " 0.700 4 0.8248 94.3% 0.1418 4 0.2358\n", + " 0.550 3 0.8236 94.2% 0.1478 3 0.3139\n", + " 0.705 (1-0.295) 4 0.8232 94.1% 0.1378 4 0.2353 ★\n", + " 0.707 (1/√2) 4 0.8194 93.7% 0.1404 4 0.2342 ★\n", + " 0.600 3 0.8148 93.2% 0.1406 3 0.3105\n", + " 0.618 (1/φ golden) 3 0.8114 92.8% 0.1378 3 0.3092 ★\n", + " 0.500 (halving) 2 0.7724 88.3% 0.1617 2 0.4416\n", + "\n", + " Direct 256→64 1 0.7426 84.9% — 1\n", + " Root (256) — 0.8746 100.0% 0.0929 200\n", + "\n", + " OPTIMAL RATIO: 0.950\n", + " Accuracy: 0.8916 (101.9% retained)\n", + " Steps: 27\n", + " Scales: 256→243→231→219→208→198→188→179→170→162→154→146→139→132→125→119→113→107→102→97→92→87→83→79→75→71→67→64\n", + " CV: 0.1318 (root: 0.0929)\n", + "\n", + " NATURAL CONSTANTS:\n", + " 1/φ (0.618): acc=0.8114 steps=3 scales=256→158→98→64\n", + " 1/√2 (0.707): acc=0.8194 steps=4 scales=256→181→128→91→64\n", + " 1-0.295(0.705): acc=0.8232 steps=4 scales=256→180→127→90→64\n", + "\n", + " PARETO FRONTIER (accuracy vs epochs):\n", + " ──────────────────────────────────────────────────\n", + " 2 epochs → 0.7724 (0.500 (halving))\n", + " 3 epochs → 0.8236 (0.550)\n", + " 4 epochs → 0.8416 (0.650)\n", + " 7 epochs → 0.8626 (0.800)\n", + " 9 
epochs → 0.8680 (0.850)\n", + " 14 epochs → 0.8852 (0.900)\n", + " 27 epochs → 0.8916 (0.950)\n", + "\n", + "Done.\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# BERT rescaling" + ], + "metadata": { + "id": "0wpUVBCiXmJg" + } + }, + { + "cell_type": "code", + "source": [ + "# ============================================================================\n", + "# ITERATIVE GEOMETRIC CASCADE ON PRETRAINED BERT\n", + "#\n", + "# Take BERT-base (768-dim, 12 layers, 110M params) and cascade it down:\n", + "# 768 → 720 → 672 → 624 → 576 → 528 → 480 → 432 → 384\n", + "#\n", + "# At each scale:\n", + "# 1. SVD-project ALL weight matrices from parent\n", + "# 2. Evaluate MLM accuracy (can it still predict masked words?)\n", + "# 3. Heal with HEAL_EPOCHS epochs of MLM\n", + "# 4. Project to next scale\n", + "#\n", + "# No fine-tuning on any downstream task. Pure compression of pretrained\n", + "# knowledge via geometric projection.\n", + "# ============================================================================\n", + "\n", + "import math\n", + "import time\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import numpy as np\n", + "from dataclasses import dataclass\n", + "from typing import Dict, List, Tuple, Optional\n", + "from transformers import (\n", + " BertForMaskedLM, BertTokenizer, BertConfig,\n", + " DataCollatorForLanguageModeling\n", + ")\n", + "from datasets import load_dataset\n", + "\n", + "\n", + "# ══════════════════════════════════════════════════════════════════\n", + "# GEOMETRIC UTILITIES\n", + "# ══════════════════════════════════════════════════════════════════\n", + "\n", + "def cayley_menger_vol2(pts):\n", + " with torch.amp.autocast(\"cuda\", enabled=False):\n", + " pts = pts.float()\n", + " diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)\n", + " d2 = (diff * diff).sum(-1)\n", + " B, V, _ = d2.shape\n", + " cm = torch.zeros(B, V+1, V+1, device=d2.device, dtype=torch.float32)\n", + " cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2\n", + " s = (-1.0)**V; f = math.factorial(V-1)\n", + " return s / ((2.0**(V-1)) * f*f) * torch.linalg.det(cm)\n", + "\n", + "\n", + "def pentachoron_cv(W, n_samples=200):\n", + " \"\"\"CV of sampled 5-point (pentachoron) volumes over weight-matrix rows.\"\"\"\n", + " if W.dim() != 2 or W.shape[0] < 5:\n", + " return 0.0\n", + " B = W.shape[0]\n", + " vols = []\n", + " for _ in range(n_samples):\n", + " idx = torch.randperm(B, device=W.device)[:5]\n", + " v2 = cayley_menger_vol2(W[idx].unsqueeze(0))\n", + " v = torch.sqrt(F.relu(v2[0]) + 1e-12).item()\n", + " if v > 0:\n", + " vols.append(v)\n", + " if len(vols) < 10:\n", + " return 0.0\n", + " a = np.array(vols, dtype=np.float64)\n", + " return float(a.std() / max(a.mean(), 1e-12))\n", + "\n", + "\n", + "# ══════════════════════════════════════════════════════════════════\n", + "# SVD PROJECTION\n", + "# ══════════════════════════════════════════════════════════════════\n", + "\n", + "@torch.no_grad()\n", + "def svd_project_matrix(W, out_dim, in_dim):\n", + " \"\"\"Project weight matrix via truncated SVD.\"\"\"\n", + " W = W.float()\n", + " U, S, Vt = torch.linalg.svd(W, full_matrices=True)\n", + " k = min(S.shape[0], out_dim, in_dim)\n", + " U_k = U[:min(W.shape[0], out_dim), :k]\n", + " Vt_k = Vt[:k, :min(W.shape[1], in_dim)]\n", + " W_small = U_k @ torch.diag(S[:k]) @ Vt_k\n", + " result = torch.zeros(out_dim, in_dim, dtype=W.dtype, device=W.device)\n", + " r, c = W_small.shape\n", + " result[:r, :c] = W_small\n", + " return result\n", + "\n", + "\n", + "@torch.no_grad()\n", + "def 
svd_project_vector(v, target_dim):\n", + " \"\"\"Project 1D vector (bias, layernorm) by truncation or padding.\"\"\"\n", + " if v.shape[0] == target_dim:\n", + " return v.clone()\n", + " elif v.shape[0] > target_dim:\n", + " return v[:target_dim].clone()\n", + " else:\n", + " result = torch.zeros(target_dim, dtype=v.dtype, device=v.device)\n", + " result[:v.shape[0]] = v\n", + " return result\n", + "\n", + "\n", + "@torch.no_grad()\n", + "def svd_project_embedding(E, target_dim):\n", + " \"\"\"Project embedding matrix (vocab_size, hidden) → (vocab_size, target_dim).\"\"\"\n", + " E = E.float()\n", + " # Keep all vocab rows, reduce hidden dim via SVD on the embedding matrix\n", + " U, S, Vt = torch.linalg.svd(E, full_matrices=False)\n", + " k = min(S.shape[0], target_dim)\n", + " # Reconstruct at reduced dimension\n", + " projected = U[:, :k] @ torch.diag(S[:k]) @ Vt[:k, :target_dim]\n", + " if projected.shape[1] < target_dim:\n", + " result = torch.zeros(E.shape[0], target_dim, dtype=E.dtype, device=E.device)\n", + " result[:, :projected.shape[1]] = projected\n", + " return result\n", + " return projected\n", + "\n", + "\n", + "# ══════════════════════════════════════════════════════════════════\n", + "# BERT WEIGHT TRANSFER\n", + "# ══════════════════════════════════════════════════════════════════\n", + "\n", + "@torch.no_grad()\n", + "def create_scaled_bert(source_model, target_hidden, target_intermediate, device):\n", + " \"\"\"\n", + " Create a new BERT with smaller hidden/intermediate dims,\n", + " SVD-projecting all weights from source.\n", + " \"\"\"\n", + " src_config = source_model.config\n", + " src_hidden = src_config.hidden_size\n", + " src_inter = src_config.intermediate_size\n", + " n_heads = src_config.num_attention_heads\n", + " head_dim = target_hidden // n_heads\n", + "\n", + " # New config\n", + " new_config = BertConfig(\n", + " vocab_size=src_config.vocab_size,\n", + " hidden_size=target_hidden,\n", + " num_hidden_layers=src_config.num_hidden_layers,\n", + " num_attention_heads=n_heads,\n", + " intermediate_size=target_intermediate,\n", + " max_position_embeddings=src_config.max_position_embeddings,\n", + " type_vocab_size=src_config.type_vocab_size,\n", + " hidden_act=src_config.hidden_act,\n", + " hidden_dropout_prob=0.0,\n", + " attention_probs_dropout_prob=0.0,\n", + " )\n", + "\n", + " new_model = BertForMaskedLM(new_config).to(device)\n", + " src_sd = source_model.state_dict()\n", + " new_sd = new_model.state_dict()\n", + "\n", + " transferred = {}\n", + "\n", + " for name, param in new_sd.items():\n", + " if name not in src_sd:\n", + " continue\n", + " src_p = src_sd[name].to(device)\n", + "\n", + " if src_p.shape == param.shape:\n", + " transferred[name] = src_p.clone()\n", + " elif src_p.dim() == 2:\n", + " transferred[name] = svd_project_matrix(\n", + " src_p, param.shape[0], param.shape[1])\n", + " elif src_p.dim() == 1:\n", + " transferred[name] = svd_project_vector(src_p, param.shape[0])\n", + " else:\n", + " # Skip or pad higher-dim tensors\n", + " transferred[name] = param.clone()\n", + "\n", + " # Load transferred weights\n", + " missing, unexpected = new_model.load_state_dict(transferred, strict=False)\n", + " n_transferred = len(transferred)\n", + " n_total = len(new_sd)\n", + " print(f\" Transferred {n_transferred}/{n_total} params, \"\n", + " f\"{len(missing)} missing, {len(unexpected)} unexpected\")\n", + "\n", + " return new_model\n", + "\n", + "\n", + "# ══════════════════════════════════════════════════════════════════\n", + "# PROFILING\n", + 
"# ══════════════════════════════════════════════════════════════════\n", + "\n", + "@torch.no_grad()\n", + "def profile_bert(model, tag=\"\"):\n", + " \"\"\"Profile CV of attention and FFN weight matrices.\"\"\"\n", + " cvs = []\n", + " for name, param in model.named_parameters():\n", + " if param.dim() == 2 and param.shape[0] >= 5 and param.shape[1] >= 5:\n", + " if \"weight\" in name and (\"dense\" in name or \"query\" in name\n", + " or \"key\" in name or \"value\" in name):\n", + " cv = pentachoron_cv(param.detach(), n_samples=100)\n", + " cvs.append(cv)\n", + " mean_cv = np.mean(cvs) if cvs else 0.0\n", + " n_params = sum(p.numel() for p in model.parameters())\n", + " if tag:\n", + " print(f\" [{tag}] {n_params:,} params, mean CV={mean_cv:.4f} \"\n", + " f\"(across {len(cvs)} weight matrices)\")\n", + " return mean_cv, n_params\n", + "\n", + "\n", + "# ══════════════════════════════════════════════════════════════════\n", + "# EVALUATION: MLM accuracy on short stories\n", + "# ══════════════════════════════════════════════════════════════════\n", + "\n", + "@torch.no_grad()\n", + "def evaluate_mlm(model, tokenizer, texts, device, mask_prob=0.15, max_len=128):\n", + " \"\"\"\n", + " Mask random tokens, see if model predicts them correctly.\n", + " Returns: top-1 accuracy, top-5 accuracy.\n", + " \"\"\"\n", + " model.eval()\n", + " total_correct_1 = 0\n", + " total_correct_5 = 0\n", + " total_masked = 0\n", + "\n", + " for text in texts:\n", + " tokens = tokenizer(text, return_tensors=\"pt\", max_length=max_len,\n", + " truncation=True, padding=False).to(device)\n", + " input_ids = tokens[\"input_ids\"][0]\n", + " seq_len = input_ids.shape[0]\n", + "\n", + " if seq_len < 5:\n", + " continue\n", + "\n", + " # Create masks (skip [CLS], [SEP], [PAD])\n", + " special_mask = torch.zeros(seq_len, dtype=torch.bool, device=device)\n", + " special_mask[0] = True # CLS\n", + " special_mask[seq_len - 1] = True # SEP\n", + " special_mask[input_ids == tokenizer.pad_token_id] = True\n", + "\n", + " maskable = ~special_mask\n", + " n_mask = max(1, int(maskable.sum().item() * mask_prob))\n", + " mask_positions = maskable.nonzero(as_tuple=True)[0]\n", + " if len(mask_positions) == 0:\n", + " continue\n", + " chosen = mask_positions[torch.randperm(len(mask_positions))[:n_mask]]\n", + "\n", + " # Save originals\n", + " original_ids = input_ids[chosen].clone()\n", + "\n", + " # Mask\n", + " masked_ids = input_ids.clone()\n", + " masked_ids[chosen] = tokenizer.mask_token_id\n", + "\n", + " # Forward\n", + " outputs = model(masked_ids.unsqueeze(0),\n", + " attention_mask=tokens[\"attention_mask\"])\n", + " logits = outputs.logits[0, chosen] # (n_mask, vocab_size)\n", + "\n", + " # Top-1\n", + " preds = logits.argmax(dim=-1)\n", + " total_correct_1 += (preds == original_ids).sum().item()\n", + "\n", + " # Top-5\n", + " top5 = logits.topk(5, dim=-1).indices\n", + " total_correct_5 += (top5 == original_ids.unsqueeze(-1)).any(dim=-1).sum().item()\n", + "\n", + " total_masked += n_mask\n", + "\n", + " if total_masked == 0:\n", + " return 0.0, 0.0\n", + " return total_correct_1 / total_masked, total_correct_5 / total_masked\n", + "\n", + "\n", + "# ══════════════════════════════════════════════════════════════════\n", + "# HEAL: minimal MLM training\n", + "# ══════════════════════════════════════════════════════════════════\n", + "\n", + "def heal_mlm(model, tokenizer, texts, device, n_epochs=1,\n", + " lr=5e-5, max_len=128, batch_size=16):\n", + " \"\"\"Quick MLM training to heal projection distortion.\"\"\"\n", 
+ " model.train()\n", + "\n", + " # Tokenize\n", + " encodings = tokenizer(texts, max_length=max_len, truncation=True,\n", + " padding=\"max_length\", return_tensors=\"pt\")\n", + " dataset_ids = encodings[\"input_ids\"]\n", + " dataset_mask = encodings[\"attention_mask\"]\n", + "\n", + " collator = DataCollatorForLanguageModeling(\n", + " tokenizer=tokenizer, mlm=True, mlm_probability=0.15)\n", + "\n", + " optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n", + " n_samples = dataset_ids.shape[0]\n", + "\n", + " total_loss = 0\n", + " n_batches = 0\n", + "\n", + " for epoch in range(n_epochs):\n", + " perm = torch.randperm(n_samples)\n", + " for i in range(0, n_samples, batch_size):\n", + " idx = perm[i:i+batch_size]\n", + " batch_ids = dataset_ids[idx]\n", + " batch_mask = dataset_mask[idx]\n", + "\n", + " # Manual masking\n", + " collated = collator([{\"input_ids\": ids, \"attention_mask\": m}\n", + " for ids, m in zip(batch_ids, batch_mask)])\n", + "\n", + " input_ids = collated[\"input_ids\"].to(device)\n", + " attention_mask = collated[\"attention_mask\"].to(device)\n", + " labels = collated[\"labels\"].to(device)\n", + "\n", + " outputs = model(input_ids=input_ids,\n", + " attention_mask=attention_mask,\n", + " labels=labels)\n", + " loss = outputs.loss\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + " total_loss += loss.item()\n", + " n_batches += 1\n", + "\n", + " return total_loss / max(n_batches, 1)\n", + "\n", + "\n", + "# ══════════════════════════════════════════════════════════════════\n", + "# EXPERIMENT\n", + "# ══════════════════════════════════════════════════════════════════\n", + "\n", + "def run_experiment():\n", + " print(\"=\" * 70)\n", + " print(\"ITERATIVE CASCADE ON PRETRAINED BERT-BASE\")\n", + " print(\"=\" * 70)\n", + "\n", + " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + " print(f\" Device: {device}\")\n", + "\n", + " # ── Configuration ──\n", + " # hidden_size must be divisible by num_heads=12\n", + " # 768/12=64, 672/12=56, 576/12=48, 480/12=40, 384/12=32\n", + " SCALES = [768, 720, 672, 624, 576, 528, 480, 432, 384]\n", + "\n", + " # intermediate scales proportionally: 3072 → ...\n", + " INTER_SCALES = [3072, 2880, 2688, 2496, 2304, 2112, 1920, 1728, 1536]\n", + " N_EVAL_TEXTS = 200\n", + " N_HEAL_TEXTS = 5000\n", + " HEAL_EPOCHS = 5\n", + "\n", + " print(f\" Scales: {' → '.join(str(s) for s in SCALES)}\")\n", + " print(f\" Compression: {SCALES[0]} → {SCALES[-1]} \"\n", + " f\"({SCALES[-1]/SCALES[0]:.0%})\")\n", + "\n", + " # ── Load data ──\n", + " print(f\"\\n Loading evaluation data...\")\n", + " ds = load_dataset(\"wikitext\", \"wikitext-103-raw-v1\", split=\"validation\")\n", + " eval_texts = [r[\"text\"].strip() for r in ds if len(r[\"text\"].strip()) > 50]\n", + " eval_texts = eval_texts[:N_EVAL_TEXTS]\n", + " print(f\" {len(eval_texts)} eval texts\")\n", + "\n", + " ds_train = load_dataset(\"wikitext\", \"wikitext-103-raw-v1\", split=\"train\")\n", + " heal_texts = [r[\"text\"].strip() for r in ds_train if len(r[\"text\"].strip()) > 100]\n", + " heal_texts = heal_texts[:N_HEAL_TEXTS]\n", + " print(f\" {len(heal_texts)} heal texts\")\n", + "\n", + " # ── Load BERT-base ──\n", + " print(f\"\\n Loading BERT-base...\")\n", + " tokenizer = BertTokenizer.from_pretrained(\"google-bert/bert-base-uncased\")\n", + " root_model = BertForMaskedLM.from_pretrained(\"google-bert/bert-base-uncased\").to(device)\n", + " root_model.eval()\n", + "\n", + " # ── Profile + evaluate root ──\n", 
+ " root_cv, root_params = profile_bert(root_model, \"Root 768-dim\")\n", + "\n", + " print(f\" Evaluating root MLM accuracy...\")\n", + " root_top1, root_top5 = evaluate_mlm(root_model, tokenizer, eval_texts, device)\n", + " print(f\" Root: top1={root_top1:.4f} top5={root_top5:.4f}\")\n", + "\n", + " # ── Cascade ──\n", + " results = [{\n", + " \"scale\": SCALES[0],\n", + " \"inter\": INTER_SCALES[0],\n", + " \"params\": root_params,\n", + " \"cv\": root_cv,\n", + " \"top1_proj\": root_top1,\n", + " \"top5_proj\": root_top5,\n", + " \"top1_heal\": root_top1,\n", + " \"top5_heal\": root_top5,\n", + " \"heal_loss\": 0,\n", + " \"heal_epochs\": 0,\n", + " }]\n", + "\n", + " parent_model = root_model\n", + "\n", + " for i in range(1, len(SCALES)):\n", + " hidden = SCALES[i]\n", + " inter = INTER_SCALES[i]\n", + " parent_hidden = SCALES[i-1]\n", + "\n", + " print(f\"\\n{'='*70}\")\n", + " print(f\"SCALE {i}: {parent_hidden} → {hidden} \"\n", + " f\"({(parent_hidden-hidden)/parent_hidden:.0%} reduction)\")\n", + " print(f\"{'='*70}\")\n", + "\n", + " # ── Project ──\n", + " print(f\" Projecting...\")\n", + " t0 = time.time()\n", + " child_model = create_scaled_bert(\n", + " parent_model, hidden, inter, device)\n", + " proj_time = time.time() - t0\n", + " print(f\" Projection took {proj_time:.1f}s\")\n", + "\n", + " # ── Profile + evaluate after projection ──\n", + " child_cv, child_params = profile_bert(child_model, f\"Projected {hidden}-dim\")\n", + "\n", + " print(f\" Evaluating MLM after projection...\")\n", + " proj_top1, proj_top5 = evaluate_mlm(\n", + " child_model, tokenizer, eval_texts, device)\n", + " print(f\" After projection: top1={proj_top1:.4f} top5={proj_top5:.4f}\")\n", + "\n", + " # ── Heal ──\n", + " print(f\" Healing ({HEAL_EPOCHS} epoch MLM)...\")\n", + " t0 = time.time()\n", + " heal_loss = heal_mlm(child_model, tokenizer, heal_texts, device,\n", + " n_epochs=HEAL_EPOCHS, lr=5e-5)\n", + " heal_time = time.time() - t0\n", + " print(f\" Heal loss: {heal_loss:.4f} ({heal_time:.1f}s)\")\n", + "\n", + " # ── Profile + evaluate after heal ──\n", + " heal_cv, _ = profile_bert(child_model, f\"Healed {hidden}-dim\")\n", + "\n", + " print(f\" Evaluating MLM after heal...\")\n", + " heal_top1, heal_top5 = evaluate_mlm(\n", + " child_model, tokenizer, eval_texts, device)\n", + " print(f\" After heal: top1={heal_top1:.4f} top5={heal_top5:.4f}\")\n", + "\n", + " results.append({\n", + " \"scale\": hidden,\n", + " \"inter\": inter,\n", + " \"params\": child_params,\n", + " \"cv\": heal_cv,\n", + " \"top1_proj\": proj_top1,\n", + " \"top5_proj\": proj_top5,\n", + " \"top1_heal\": heal_top1,\n", + " \"top5_heal\": heal_top5,\n", + " \"heal_loss\": heal_loss,\n", + " \"heal_epochs\": HEAL_EPOCHS,\n", + " })\n", + "\n", + " parent_model = child_model\n", + "\n", + " # ── Direct jump: 768 → 384 ──\n", + " print(f\"\\n{'='*70}\")\n", + " print(f\"DIRECT PROJECTION: {SCALES[0]} → {SCALES[-1]}\")\n", + " print(f\"{'='*70}\")\n", + "\n", + " direct_model = create_scaled_bert(\n", + " root_model, SCALES[-1], INTER_SCALES[-1], device)\n", + " direct_cv, direct_params = profile_bert(direct_model, f\"Direct {SCALES[-1]}-dim\")\n", + " direct_top1, direct_top5 = evaluate_mlm(\n", + " direct_model, tokenizer, eval_texts, device)\n", + " print(f\" Direct: top1={direct_top1:.4f} top5={direct_top5:.4f}\")\n", + "\n", + " # ── Report ──\n", + " print(f\"\\n{'='*70}\")\n", + " print(f\"RESULTS\")\n", + " print(f\"{'='*70}\\n\")\n", + "\n", + " print(f\" {'Scale':>6s} {'Params':>12s} {'Top1(proj)':>11s} 
{'Top1(heal)':>11s} \"\n", + " f\"{'Top5(proj)':>11s} {'Top5(heal)':>11s} {'CV':>7s}\")\n", + " print(f\" {'─'*6} {'─'*12} {'─'*11} {'─'*11} {'─'*11} {'─'*11} {'─'*7}\")\n", + "\n", + " for r in results:\n", + " print(f\" {r['scale']:>6d} {r['params']:>12,} {r['top1_proj']:>11.4f} \"\n", + " f\"{r['top1_heal']:>11.4f} {r['top5_proj']:>11.4f} \"\n", + " f\"{r['top5_heal']:>11.4f} {r['cv']:>7.4f}\")\n", + "\n", + " print(f\"\\n DIRECT {SCALES[-1]}: {direct_params:>12,} \"\n", + " f\"top1={direct_top1:.4f} top5={direct_top5:.4f} cv={direct_cv:.4f}\")\n", + "\n", + " # ── Retention ──\n", + " final = results[-1]\n", + " print(f\"\\n SUMMARY:\")\n", + " print(f\" Root: {root_params:>12,} params top1={root_top1:.4f} top5={root_top5:.4f}\")\n", + " print(f\" Cascade: {final['params']:>12,} params \"\n", + " f\"top1={final['top1_heal']:.4f} top5={final['top5_heal']:.4f}\")\n", + " print(f\" Direct: {direct_params:>12,} params \"\n", + " f\"top1={direct_top1:.4f} top5={direct_top5:.4f}\")\n", + " print(f\" Compression: {root_params/final['params']:.1f}×\")\n", + " print(f\" Top1 retained (cascade): {final['top1_heal']/root_top1:.1%}\")\n", + " print(f\" Top1 retained (direct): {direct_top1/root_top1:.1%}\")\n", + "\n", + " print(f\"\\nDone.\")\n", + " return results\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + " results = run_experiment()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000, + "referenced_widgets": [ + "d5c2a63f6f8544e79268b9bade807345", + "f027a57fcb6a49cfbbd28b95a6c0adf7", + "4df257d047cb4e7ab2a0a6c57be58b81", + "214861fbbd124b45ba164e934f91a024", + "f6cf4e1cd78c4c0da8bf756a7cc8760a", + "bd3077d18b9640abb14a4c385cac7c39", + "c39efce145764de59d3dc9cb7a818d33", + "b9fc695d6d70408690bbd7ef8bb5841a", + "a1486fd248674e9584e90d907b5156c2", + "3472ba9857a543568ec3cd2f340571d6", + "89eae54541d447bdaa9625cbbe357e55" + ] + }, + "id": "dOhM9mTJXnQl", + "outputId": "f68265c8-ba6a-4182-fba7-e5736bbaf37d" + }, + "execution_count": 7, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "======================================================================\n", + "ITERATIVE CASCADE ON PRETRAINED BERT-BASE\n", + "======================================================================\n", + " Device: cuda\n", + " Scales: 768 → 720 → 672 → 624 → 576 → 528 → 480 → 432 → 384\n", + " Compression: 768 → 384 (50%)\n", + "\n", + " Loading evaluation data...\n", + " 200 eval texts\n", + " 5000 heal texts\n", + "\n", + " Loading BERT-base...\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Loading weights: 0%| | 0/202 [00:00 0:\n", + " vols.append(v)\n", + " if len(vols) < 10:\n", + " return 0.0\n", + " a = np.array(vols, dtype=np.float64)\n", + " return float(a.std() / max(a.mean(), 1e-12))\n", + "\n", + "def profile_bert(model, tag=\"\"):\n", + " cvs = []\n", + " for name, param in model.named_parameters():\n", + " if param.dim() == 2 and param.shape[0] >= 5 and param.shape[1] >= 5:\n", + " if \"weight\" in name and (\"dense\" in name or \"query\" in name\n", + " or \"key\" in name or \"value\" in name):\n", + " cv = pentachoron_cv(param.detach(), n_samples=100)\n", + " cvs.append(cv)\n", + " mean_cv = np.mean(cvs) if cvs else 0.0\n", + " n_params = sum(p.numel() for p in model.parameters())\n", + " if tag:\n", + " print(f\" [{tag}] {n_params:,} params, CV={mean_cv:.4f} ({len(cvs)} matrices)\")\n", + " return mean_cv, n_params\n", + "\n", + "\n", + "# 
══════════════════════════════════════════════════════════════════\n", + "# FIX 3: SVD PROJECTION WITH PROPER BIAS HANDLING\n", + "# ══════════════════════════════════════════════════════════════════\n", + "\n", + "@torch.no_grad()\n", + "def svd_project_weight_and_bias(W, b, out_dim, in_dim):\n", + " \"\"\"\n", + " Project weight matrix via truncated SVD AND project bias\n", + " into the new SVD output basis.\n", + "\n", + " W: (D_out, D_in) → (out_dim, in_dim)\n", + " b: (D_out,) → (out_dim,) projected via U_k.T\n", + "\n", + " Returns: W_new, b_new\n", + " \"\"\"\n", + " W = W.float()\n", + " U, S, Vt = torch.linalg.svd(W, full_matrices=True)\n", + " k = min(S.shape[0], out_dim, in_dim)\n", + "\n", + " U_k = U[:, :k] # (D_out, k) — output basis\n", + " S_k = S[:k] # (k,)\n", + " Vt_k = Vt[:k, :] # (k, D_in)\n", + "\n", + " # Truncate input dimension\n", + " Vt_k_trunc = Vt_k[:, :min(W.shape[1], in_dim)]\n", + "\n", + " # Reconstruct at target dimensions\n", + " W_new = torch.zeros(out_dim, in_dim, dtype=W.dtype, device=W.device)\n", + " core = U_k[:min(W.shape[0], out_dim), :] @ torch.diag(S_k) @ Vt_k_trunc\n", + " r, c = core.shape\n", + " W_new[:r, :c] = core\n", + "\n", + " # Project bias into new output basis\n", + " b_new = torch.zeros(out_dim, dtype=W.dtype, device=W.device)\n", + " if b is not None:\n", + " b = b.float()\n", + " # b_projected = U_k[:out_dim, :].T @ b → but U_k might be (D_out, k) with k < out_dim\n", + " # Use the same truncation as W\n", + " b_proj = U_k[:min(W.shape[0], out_dim), :].T @ b[:min(W.shape[0], out_dim)]\n", + " b_new[:min(k, out_dim)] = b_proj[:min(k, out_dim)]\n", + "\n", + " return W_new, b_new\n", + "\n", + "\n", + "@torch.no_grad()\n", + "def svd_project_matrix_only(W, out_dim, in_dim):\n", + " \"\"\"SVD project weight matrix without bias.\"\"\"\n", + " W = W.float()\n", + " U, S, Vt = torch.linalg.svd(W, full_matrices=True)\n", + " k = min(S.shape[0], out_dim, in_dim)\n", + " U_k = U[:min(W.shape[0], out_dim), :k]\n", + " Vt_k = Vt[:k, :min(W.shape[1], in_dim)]\n", + " W_small = U_k @ torch.diag(S[:k]) @ Vt_k\n", + " result = torch.zeros(out_dim, in_dim, dtype=W.dtype, device=W.device)\n", + " r, c = W_small.shape\n", + " result[:r, :c] = W_small\n", + " return result\n", + "\n", + "\n", + "# ══════════════════════════════════════════════════════════════════\n", + "# FIX 4: L1 MAGNITUDE PRUNING FOR FFN INTERMEDIATE\n", + "# ══════════════════════════════════════════════════════════════════\n", + "\n", + "@torch.no_grad()\n", + "def l1_prune_ffn(W_up, b_up, W_down, b_down, target_intermediate):\n", + " \"\"\"\n", + " Prune FFN intermediate dimension by keeping rows/cols with highest L1 norm.\n", + " Preserves coordinate alignment with GELU nonlinearity.\n", + "\n", + " W_up: (src_inter, src_hidden) — expands\n", + " b_up: (src_inter,)\n", + " W_down: (src_hidden, src_inter) — contracts\n", + " b_down: (src_hidden,)\n", + "\n", + " Returns: pruned W_up, b_up, W_down (columns pruned)\n", + " \"\"\"\n", + " src_inter = W_up.shape[0]\n", + " if src_inter <= target_intermediate:\n", + " return W_up, b_up, W_down\n", + "\n", + " # Importance = L1 norm of each intermediate neuron\n", + " # Combined from both up-projection row and down-projection column\n", + " importance = W_up.float().abs().sum(dim=1) + W_down.float().abs().sum(dim=0)\n", + "\n", + " # Keep top-k\n", + " _, keep_idx = importance.topk(target_intermediate)\n", + " keep_idx = keep_idx.sort().values\n", + "\n", + " W_up_pruned = W_up[keep_idx, :]\n", + " b_up_pruned = b_up[keep_idx] if 
b_up is not None else None\n", + " W_down_pruned = W_down[:, keep_idx]\n", + " # b_down stays same dimension (hidden_size)\n", + "\n", + " return W_up_pruned, b_up_pruned, W_down_pruned\n", + "\n", + "\n", + "# ══════════════════════════════════════════════════════════════════\n", + "# CORRECTED BERT PROJECTION\n", + "# ══════════════════════════════════════════════════════════════════\n", + "\n", + "@torch.no_grad()\n", + "def create_projected_bert(source_model, target_hidden, target_intermediate, device):\n", + " \"\"\"\n", + " Project BERT with:\n", + " - SVD + proper bias projection for attention/embedding matrices\n", + " - L1 magnitude pruning for FFN intermediate (respects GELU)\n", + " \"\"\"\n", + " src_config = source_model.config\n", + " new_config = BertConfig(\n", + " vocab_size=src_config.vocab_size,\n", + " hidden_size=target_hidden,\n", + " num_hidden_layers=src_config.num_hidden_layers,\n", + " num_attention_heads=src_config.num_attention_heads,\n", + " intermediate_size=target_intermediate,\n", + " max_position_embeddings=src_config.max_position_embeddings,\n", + " type_vocab_size=src_config.type_vocab_size,\n", + " hidden_act=src_config.hidden_act,\n", + " hidden_dropout_prob=0.0,\n", + " attention_probs_dropout_prob=0.0,\n", + " )\n", + " target = BertForMaskedLM(new_config).to(device)\n", + " src = source_model\n", + "\n", + " # ── Embeddings (SVD on hidden dim, keep vocab) ──\n", + " for emb_name in [\"word_embeddings\", \"position_embeddings\", \"token_type_embeddings\"]:\n", + " src_w = getattr(src.bert.embeddings, emb_name).weight.data\n", + " tgt_w = getattr(target.bert.embeddings, emb_name).weight.data\n", + " tgt_w.copy_(svd_project_matrix_only(src_w, tgt_w.shape[0], tgt_w.shape[1]))\n", + "\n", + " # Embedding LayerNorm — truncate (element-wise, no rotation)\n", + " target.bert.embeddings.LayerNorm.weight.data.copy_(\n", + " src.bert.embeddings.LayerNorm.weight.data[:target_hidden])\n", + " target.bert.embeddings.LayerNorm.bias.data.copy_(\n", + " src.bert.embeddings.LayerNorm.bias.data[:target_hidden])\n", + "\n", + " # ── Encoder layers ──\n", + " for i, (src_layer, tgt_layer) in enumerate(\n", + " zip(src.bert.encoder.layer, target.bert.encoder.layer)):\n", + "\n", + " # Q, K, V: (src_hidden, src_hidden) → (target_hidden, target_hidden)\n", + " for attr in [\"query\", \"key\", \"value\"]:\n", + " src_mod = getattr(src_layer.attention.self, attr)\n", + " tgt_mod = getattr(tgt_layer.attention.self, attr)\n", + " W_new, b_new = svd_project_weight_and_bias(\n", + " src_mod.weight.data, src_mod.bias.data,\n", + " target_hidden, target_hidden)\n", + " tgt_mod.weight.data.copy_(W_new)\n", + " tgt_mod.bias.data.copy_(b_new)\n", + "\n", + " # Attention output: (src_hidden, src_hidden) → (target_hidden, target_hidden)\n", + " W_new, b_new = svd_project_weight_and_bias(\n", + " src_layer.attention.output.dense.weight.data,\n", + " src_layer.attention.output.dense.bias.data,\n", + " target_hidden, target_hidden)\n", + " tgt_layer.attention.output.dense.weight.data.copy_(W_new)\n", + " tgt_layer.attention.output.dense.bias.data.copy_(b_new)\n", + "\n", + " # Attention LayerNorm — truncate\n", + " tgt_layer.attention.output.LayerNorm.weight.data.copy_(\n", + " src_layer.attention.output.LayerNorm.weight.data[:target_hidden])\n", + " tgt_layer.attention.output.LayerNorm.bias.data.copy_(\n", + " src_layer.attention.output.LayerNorm.bias.data[:target_hidden])\n", + "\n", + " # FFN: L1 magnitude pruning (respects GELU coordinate alignment)\n", + " W_up = 
src_layer.intermediate.dense.weight.data.to(device)\n", + " b_up = src_layer.intermediate.dense.bias.data.to(device)\n", + " W_down = src_layer.output.dense.weight.data.to(device)\n", + " b_down = src_layer.output.dense.bias.data.to(device)\n", + "\n", + " # First prune intermediate dimension\n", + " W_up_p, b_up_p, W_down_p = l1_prune_ffn(\n", + " W_up, b_up, W_down, b_down, target_intermediate)\n", + "\n", + " # Then SVD-project the hidden dimensions with proper bias\n", + " W_up_final, b_up_final = svd_project_weight_and_bias(\n", + " W_up_p, b_up_p, target_intermediate, target_hidden)\n", + " tgt_layer.intermediate.dense.weight.data.copy_(W_up_final)\n", + " tgt_layer.intermediate.dense.bias.data.copy_(b_up_final)\n", + "\n", + " W_down_final, b_down_final = svd_project_weight_and_bias(\n", + " W_down_p, b_down, target_hidden, target_intermediate)\n", + " tgt_layer.output.dense.weight.data.copy_(W_down_final)\n", + " tgt_layer.output.dense.bias.data.copy_(b_down_final)\n", + "\n", + " # Output LayerNorm — truncate\n", + " tgt_layer.output.LayerNorm.weight.data.copy_(\n", + " src_layer.output.LayerNorm.weight.data[:target_hidden])\n", + " tgt_layer.output.LayerNorm.bias.data.copy_(\n", + " src_layer.output.LayerNorm.bias.data[:target_hidden])\n", + "\n", + " # ── MLM Head ──\n", + " if hasattr(src.cls.predictions.transform, 'dense'):\n", + " W_new, b_new = svd_project_weight_and_bias(\n", + " src.cls.predictions.transform.dense.weight.data,\n", + " src.cls.predictions.transform.dense.bias.data,\n", + " target_hidden, target_hidden)\n", + " target.cls.predictions.transform.dense.weight.data.copy_(W_new)\n", + " target.cls.predictions.transform.dense.bias.data.copy_(b_new)\n", + "\n", + " if hasattr(src.cls.predictions.transform, 'LayerNorm'):\n", + " target.cls.predictions.transform.LayerNorm.weight.data.copy_(\n", + " src.cls.predictions.transform.LayerNorm.weight.data[:target_hidden])\n", + " target.cls.predictions.transform.LayerNorm.bias.data.copy_(\n", + " src.cls.predictions.transform.LayerNorm.bias.data[:target_hidden])\n", + "\n", + " if hasattr(src.cls.predictions, 'bias'):\n", + " target.cls.predictions.bias.data.copy_(src.cls.predictions.bias.data)\n", + "\n", + " return target\n", + "\n", + "\n", + "# ══════════════════════════════════════════════════════════════════\n", + "# FIX 2: FROZEN PER-LAYER PROJECTORS\n", + "# ══════════════════════════════════════════════════════════════════\n", + "\n", + "class FrozenLayerProjectors(nn.Module):\n", + " \"\"\"\n", + " Per-layer projectors initialized from Procrustes, then FROZEN.\n", + " The student must move to match the fixed target — no shortcut collapse.\n", + " \"\"\"\n", + " def __init__(self, teacher_dim, student_dim, n_layers, device):\n", + " super().__init__()\n", + " self.projectors = nn.ModuleList([\n", + " nn.Linear(teacher_dim, student_dim, bias=False).to(device)\n", + " for _ in range(n_layers + 1)\n", + " ])\n", + "\n", + " @torch.no_grad()\n", + " def init_from_layer_procrustes(self, teacher_model, student_dim, device):\n", + " teacher_dim = teacher_model.config.hidden_size\n", + " for i, layer in enumerate(teacher_model.bert.encoder.layer):\n", + " weights = [\n", + " layer.attention.self.query.weight.data.T,\n", + " layer.attention.self.key.weight.data.T,\n", + " layer.attention.self.value.weight.data.T,\n", + " layer.intermediate.dense.weight.data.T,\n", + " ]\n", + " L = torch.cat(weights, dim=1).float().to(device)\n", + " U, S, Vt = torch.linalg.svd(L, full_matrices=False)\n", + " P_layer = U[:, :student_dim] # 
(teacher_dim, student_dim)\n", + " self.projectors[i + 1].weight.data.copy_(P_layer.T)\n", + " if i == 0:\n", + " self.projectors[0].weight.data.copy_(P_layer.T)\n", + "\n", + " # FREEZE all projectors\n", + " for p in self.parameters():\n", + " p.requires_grad = False\n", + "\n", + " print(f\" {len(self.projectors)} projectors initialized + FROZEN\")\n", + "\n", + " def forward(self, teacher_hiddens):\n", + " projected = []\n", + " for t_h, proj in zip(teacher_hiddens, self.projectors):\n", + " projected.append(proj(t_h.float()))\n", + " return projected\n", + "\n", + "\n", + "# ══════════════════════════════════════════════════════════════════\n", + "# FIX 1: DISTILLATION LOSS ON ACTIVE TOKENS ONLY\n", + "# ══════════════════════════════════════════════════════════════════\n", + "\n", + "class TeacherGuidedHealerV3:\n", + " def __init__(self, teacher_model, projectors, device,\n", + " mlm_weight=1.0, distill_weight=2.0):\n", + " self.teacher = teacher_model\n", + " self.teacher.eval()\n", + " self.projectors = projectors # FROZEN\n", + " self.device = device\n", + " self.mlm_weight = mlm_weight\n", + " self.distill_weight = distill_weight\n", + "\n", + " def compute_distillation_loss(self, student_model, input_ids, attention_mask):\n", + " with torch.no_grad():\n", + " teacher_out = self.teacher.bert(\n", + " input_ids=input_ids, attention_mask=attention_mask,\n", + " output_hidden_states=True, return_dict=True)\n", + "\n", + " student_out = student_model.bert(\n", + " input_ids=input_ids, attention_mask=attention_mask,\n", + " output_hidden_states=True, return_dict=True)\n", + "\n", + " # Per-layer projection (frozen projectors)\n", + " projected = self.projectors(teacher_out.hidden_states)\n", + " student_hiddens = student_out.hidden_states\n", + "\n", + " n_layers = min(len(projected), len(student_hiddens))\n", + " total_loss = torch.tensor(0.0, device=self.device)\n", + "\n", + " # FIX 1: Only compute loss on active (non-padding) tokens\n", + " active_mask = attention_mask.float() # (B, seq), 1=active, 0=pad\n", + " n_active = active_mask.sum().clamp(min=1.0)\n", + "\n", + " for layer_idx in range(1, n_layers):\n", + " t_proj = projected[layer_idx] # (B, seq, student_dim)\n", + " s_h = student_hiddens[layer_idx].float()\n", + "\n", + " # Cosine similarity per token\n", + " t_norm = F.normalize(t_proj, dim=-1)\n", + " s_norm = F.normalize(s_h, dim=-1)\n", + " cos_sim = (t_norm * s_norm).sum(-1) # (B, seq)\n", + "\n", + " # Mask out padding, average over active tokens only\n", + " cos_sim_active = cos_sim * active_mask\n", + " layer_loss = 1.0 - cos_sim_active.sum() / n_active\n", + " total_loss = total_loss + layer_loss\n", + "\n", + " return total_loss / max(n_layers - 1, 1)\n", + "\n", + " def heal(self, student_model, tokenizer, texts, n_epochs=5,\n", + " lr=5e-5, max_len=128, batch_size=16):\n", + " student_model.train()\n", + " # Only student params — projectors are frozen\n", + " optimizer = torch.optim.AdamW(student_model.parameters(), lr=lr)\n", + "\n", + " enc = tokenizer(texts, max_length=max_len, truncation=True,\n", + " padding=\"max_length\", return_tensors=\"pt\")\n", + " ids, masks = enc[\"input_ids\"], enc[\"attention_mask\"]\n", + " collator = DataCollatorForLanguageModeling(\n", + " tokenizer=tokenizer, mlm=True, mlm_probability=0.15)\n", + " n = ids.shape[0]\n", + " total_loss = 0\n", + " n_batches = 0\n", + "\n", + " for epoch in range(n_epochs):\n", + " perm = torch.randperm(n)\n", + " for i in range(0, n, batch_size):\n", + " idx = perm[i:i+batch_size]\n", + " 
batch = [{\"input_ids\": ids[j], \"attention_mask\": masks[j]}\n", + " for j in idx]\n", + " c = collator(batch)\n", + " c_ids = c[\"input_ids\"].to(self.device)\n", + " c_mask = c[\"attention_mask\"].to(self.device)\n", + " c_labels = c[\"labels\"].to(self.device)\n", + "\n", + " mlm_out = student_model(\n", + " input_ids=c_ids, attention_mask=c_mask, labels=c_labels)\n", + "\n", + " # Distillation on unmasked input\n", + " orig_ids = ids[idx].to(self.device)\n", + " orig_mask = masks[idx].to(self.device)\n", + " distill_loss = self.compute_distillation_loss(\n", + " student_model, orig_ids, orig_mask)\n", + "\n", + " loss = (self.mlm_weight * mlm_out.loss +\n", + " self.distill_weight * distill_loss)\n", + "\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + " total_loss += loss.item()\n", + " n_batches += 1\n", + "\n", + " if (epoch + 1) % 2 == 0 or epoch == 0:\n", + " nb = max(n_batches, 1)\n", + " print(f\" Epoch {epoch+1}: loss={total_loss/nb:.4f} \"\n", + " f\"(mlm≈{mlm_out.loss.item():.3f}, \"\n", + " f\"distill≈{distill_loss.item():.3f})\")\n", + "\n", + " return total_loss / max(n_batches, 1)\n", + "\n", + "\n", + "# ══════════════════════════════════════════════════════════════════\n", + "# FIX 5: SEEDED EVALUATION\n", + "# ══════════════════════════════════════════════════════════════════\n", + "\n", + "@torch.no_grad()\n", + "def evaluate_mlm(model, tokenizer, texts, device, mask_prob=0.15,\n", + " max_len=128, seed=42):\n", + " \"\"\"Deterministic masking for consistent cross-scale comparison.\"\"\"\n", + " model.eval()\n", + " gen = torch.Generator().manual_seed(seed)\n", + " total_1 = total_5 = total_m = 0\n", + "\n", + " for text in texts:\n", + " tokens = tokenizer(text, return_tensors=\"pt\", max_length=max_len,\n", + " truncation=True, padding=False).to(device)\n", + " input_ids = tokens[\"input_ids\"][0]\n", + " seq_len = input_ids.shape[0]\n", + " if seq_len < 5:\n", + " continue\n", + " special = torch.zeros(seq_len, dtype=torch.bool, device=device)\n", + " special[0] = special[seq_len-1] = True\n", + " special[input_ids == tokenizer.pad_token_id] = True\n", + " maskable = (~special).nonzero(as_tuple=True)[0]\n", + " if len(maskable) == 0:\n", + " continue\n", + " n_mask = max(1, int(len(maskable) * mask_prob))\n", + " chosen = maskable[torch.randperm(len(maskable), generator=gen)[:n_mask]]\n", + " orig = input_ids[chosen].clone()\n", + " masked = input_ids.clone()\n", + " masked[chosen] = tokenizer.mask_token_id\n", + " logits = model(masked.unsqueeze(0),\n", + " attention_mask=tokens[\"attention_mask\"]).logits[0, chosen]\n", + " total_1 += (logits.argmax(-1) == orig).sum().item()\n", + " top5 = logits.topk(5, dim=-1).indices\n", + " total_5 += (top5 == orig.unsqueeze(-1)).any(-1).sum().item()\n", + " total_m += n_mask\n", + " if total_m == 0:\n", + " return 0.0, 0.0\n", + " return total_1 / total_m, total_5 / total_m\n", + "\n", + "\n", + "# ══════════════════════════════════════════════════════════════════\n", + "# EXPERIMENT\n", + "# ══════════════════════════════════════════════════════════════════\n", + "\n", + "def run_experiment():\n", + " print(\"=\" * 70)\n", + " print(\"TEACHER-GUIDED CASCADE v3 — ALL FIXES\")\n", + " print(\"=\" * 70)\n", + "\n", + " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + " print(f\" Device: {device}\")\n", + "\n", + " SCALES = [768, 720, 672, 624, 576, 528, 480, 432, 384]\n", + " INTER_SCALES = [3072, 2880, 2688, 2496, 2304, 2112, 1920, 1728, 1536]\n", + " N_EVAL 
= 200\n", + " N_HEAL = 5000\n", + " HEAL_EPOCHS = 5\n", + "\n", + " print(f\" Scales: {' → '.join(str(s) for s in SCALES)}\")\n", + " print(f\" Fixes: padding-masked loss, frozen projectors, \"\n", + " f\"SVD bias projection, L1 FFN pruning, seeded eval\")\n", + "\n", + " # ── Data ──\n", + " print(f\"\\n Loading data...\")\n", + " ds_val = load_dataset(\"wikitext\", \"wikitext-103-raw-v1\", split=\"validation\")\n", + " eval_texts = [r[\"text\"].strip() for r in ds_val if len(r[\"text\"].strip()) > 50][:N_EVAL]\n", + " ds_train = load_dataset(\"wikitext\", \"wikitext-103-raw-v1\", split=\"train\")\n", + " heal_texts = [r[\"text\"].strip() for r in ds_train if len(r[\"text\"].strip()) > 100][:N_HEAL]\n", + " print(f\" {len(eval_texts)} eval, {len(heal_texts)} heal texts\")\n", + "\n", + " # ── Teacher ──\n", + " print(f\"\\n Loading BERT-base (teacher)...\")\n", + " tokenizer = BertTokenizer.from_pretrained(\"google-bert/bert-base-uncased\")\n", + " teacher = BertForMaskedLM.from_pretrained(\"google-bert/bert-base-uncased\").to(device)\n", + " teacher.eval()\n", + " for p in teacher.parameters():\n", + " p.requires_grad = False\n", + "\n", + " root_cv, root_params = profile_bert(teacher, \"Teacher 768\")\n", + " root_top1, root_top5 = evaluate_mlm(teacher, tokenizer, eval_texts, device)\n", + " print(f\" Teacher: top1={root_top1:.4f} top5={root_top5:.4f}\")\n", + "\n", + " results = [{\n", + " \"scale\": 768, \"params\": root_params, \"cv\": root_cv,\n", + " \"top1_proj\": root_top1, \"top5_proj\": root_top5,\n", + " \"top1_heal\": root_top1, \"top5_heal\": root_top5,\n", + " }]\n", + "\n", + " parent = teacher\n", + " n_encoder_layers = teacher.config.num_hidden_layers\n", + "\n", + " for i in range(1, len(SCALES)):\n", + " hidden = SCALES[i]\n", + " inter = INTER_SCALES[i]\n", + " parent_hidden = SCALES[i-1]\n", + "\n", + " print(f\"\\n{'='*70}\")\n", + " print(f\"SCALE {i}: {parent_hidden} → {hidden} \"\n", + " f\"({(parent_hidden-hidden)/parent_hidden:.0%} reduction)\")\n", + " print(f\"{'='*70}\")\n", + "\n", + " # ── Project ──\n", + " print(f\" Projecting (SVD + L1 FFN prune)...\")\n", + " t0 = time.time()\n", + " child = create_projected_bert(parent, hidden, inter, device)\n", + " print(f\" Projection: {time.time()-t0:.1f}s\")\n", + "\n", + " child_cv, child_params = profile_bert(child, f\"Projected {hidden}\")\n", + " proj_top1, proj_top5 = evaluate_mlm(child, tokenizer, eval_texts, device)\n", + " print(f\" After proj: top1={proj_top1:.4f} top5={proj_top5:.4f}\")\n", + "\n", + " # ── Frozen per-layer projectors ──\n", + " print(f\" Initializing frozen per-layer projectors...\")\n", + " projectors = FrozenLayerProjectors(768, hidden, n_encoder_layers, device)\n", + " projectors.init_from_layer_procrustes(teacher, hidden, device)\n", + "\n", + " # ── Teacher-guided healing ──\n", + " print(f\" Healing ({HEAL_EPOCHS} epochs)...\")\n", + " healer = TeacherGuidedHealerV3(\n", + " teacher, projectors, device,\n", + " mlm_weight=1.0, distill_weight=2.0)\n", + " t0 = time.time()\n", + " heal_loss = healer.heal(child, tokenizer, heal_texts,\n", + " n_epochs=HEAL_EPOCHS, lr=5e-5)\n", + " print(f\" Heal: {time.time()-t0:.1f}s\")\n", + "\n", + " heal_cv, _ = profile_bert(child, f\"Healed {hidden}\")\n", + " heal_top1, heal_top5 = evaluate_mlm(child, tokenizer, eval_texts, device)\n", + " print(f\" After heal: top1={heal_top1:.4f} top5={heal_top5:.4f}\")\n", + "\n", + " results.append({\n", + " \"scale\": hidden, \"params\": child_params, \"cv\": heal_cv,\n", + " \"top1_proj\": proj_top1, 
\"top5_proj\": proj_top5,\n", + " \"top1_heal\": heal_top1, \"top5_heal\": heal_top5,\n", + " })\n", + "\n", + " parent = child\n", + "\n", + " # ── Report ──\n", + " print(f\"\\n{'='*70}\")\n", + " print(f\"RESULTS\")\n", + " print(f\"{'='*70}\\n\")\n", + "\n", + " print(f\" {'Scale':>6s} {'Params':>12s} {'Top1(proj)':>11s} {'Top1(heal)':>11s} \"\n", + " f\"{'Top5(proj)':>11s} {'Top5(heal)':>11s} {'CV':>7s}\")\n", + " print(f\" {'─'*6} {'─'*12} {'─'*11} {'─'*11} {'─'*11} {'─'*11} {'─'*7}\")\n", + "\n", + " for r in results:\n", + " print(f\" {r['scale']:>6d} {r['params']:>12,} {r['top1_proj']:>11.4f} \"\n", + " f\"{r['top1_heal']:>11.4f} {r['top5_proj']:>11.4f} \"\n", + " f\"{r['top5_heal']:>11.4f} {r['cv']:>7.4f}\")\n", + "\n", + " final = results[-1]\n", + " print(f\"\\n SUMMARY:\")\n", + " print(f\" Teacher: {root_params:>12,} top1={root_top1:.4f} top5={root_top5:.4f}\")\n", + " print(f\" Cascade: {final['params']:>12,} \"\n", + " f\"top1={final['top1_heal']:.4f} top5={final['top5_heal']:.4f}\")\n", + " print(f\" Compression: {root_params/final['params']:.1f}×\")\n", + " print(f\" Top1 retained: {final['top1_heal']/root_top1:.1%}\")\n", + "\n", + " print(f\"\\n ALL APPROACHES:\")\n", + " print(f\" v1 Independent SVD + blind MLM: 61.5%\")\n", + " print(f\" v2 + teacher global P: 62.5%\")\n", + " print(f\" v2 + per-layer projectors (buggy): ???\")\n", + " print(f\" v3 all fixes: {final['top1_heal']/root_top1:.1%}\")\n", + "\n", + " print(f\"\\nDone.\")\n", + " return results\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + " results = run_experiment()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000, + "referenced_widgets": [ + "ae1bdb89a9704d09a1c03ec82354905e", + "8d03191c19b64a698be6d6fc141817cb", + "8d24919b783b47748aefac2a6c234313", + "38bea3901f6c4e339d17a0269f0e535b", + "e4655e0917814b1b950d53a8f59d1fa5", + "325923e6a4054298be70581c2cd1061d", + "97bcf95ab5ac4823b2f696f42db534da", + "866d9713b1f5436b924a5505e36b9e7d", + "8f481ca713e04111a7dcab82f19a072b", + "8566ff523c5d4bb5b799cdd12a95c633", + "bd53b01466524897aea749b68000684a" + ] + }, + "id": "ddAQ1-RGp3Fx", + "outputId": "a9f9c9f5-fc9a-4120-d1c4-06a50e9f7483" + }, + "execution_count": 20, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "======================================================================\n", + "TEACHER-GUIDED CASCADE v3 — ALL FIXES\n", + "======================================================================\n", + " Device: cuda\n", + " Scales: 768 → 720 → 672 → 624 → 576 → 528 → 480 → 432 → 384\n", + " Fixes: padding-masked loss, frozen projectors, SVD bias projection, L1 FFN pruning, seeded eval\n", + "\n", + " Loading data...\n", + " 200 eval, 5000 heal texts\n", + "\n", + " Loading BERT-base (teacher)...\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Loading weights: 0%| | 0/202 [00:00