Spaces:
Running
Running
| <html lang="en"> | |
| <head> | |
| <meta charset="UTF-8"> | |
| <meta name="viewport" content="width=device-width, initial-scale=1.0"> | |
| <title>Compare Yamoe and Binned MoE Implementations</title> | |
| <script> | |
| // Apply theme immediately to prevent flicker | |
(function () {
    // Resolve and apply the color theme before first paint so the page
    // does not flash the wrong theme.
    var configured = 'dark';
    var resolved;
    if (configured !== 'auto') {
        // Explicit config: a previously stored user choice still wins.
        resolved = localStorage.getItem('uvnote-theme') || configured;
    } else {
        // 'auto': follow the OS-level color-scheme preference.
        resolved = window.matchMedia('(prefers-color-scheme: dark)').matches
            ? 'dark'
            : 'light';
    }
    document.documentElement.setAttribute('data-theme', resolved);
})();
| </script> | |
| <style> | |
| :root[data-theme="light"] { | |
| --bg-primary: #ffffff; | |
| --bg-secondary: #f6f8fa; | |
| --bg-tertiary: #f8f9fa; | |
| --bg-code: #f8f9fa; | |
| --bg-error: #fdf2f2; | |
| --bg-artifact: #e6f3ff; | |
| --bg-artifact-hover: #d0e7ff; | |
| --text-primary: #333; | |
| --text-secondary: #656d76; | |
| --text-error: #c53030; | |
| --text-link: #0969da; | |
| --border-primary: #e1e5e9; | |
| --border-error: #e53e3e; | |
| --border-cell-failed: #d73a49; | |
| --shadow: rgba(0, 0, 0, 0.1); | |
| } | |
| :root[data-theme="dark"] { | |
| --bg-primary: #0a0a0a; | |
| --bg-secondary: #121212; | |
| --bg-tertiary: #181818; | |
| --bg-code: #0d0d0d; | |
| --bg-error: #1a0f0f; | |
| --bg-artifact: #151515; | |
| --bg-artifact-hover: #1a1a1a; | |
| --text-primary: #e0e0e0; | |
| --text-secondary: #888888; | |
| --text-error: #ff6b6b; | |
| --text-link: #64b5f6; | |
| --border-primary: #2a2a2a; | |
| --border-error: #ff6b6b; | |
| --border-cell-failed: #ff6b6b; | |
| --shadow: rgba(255, 255, 255, 0.05); | |
| } | |
| body { | |
| font-family: 'Cascadia Mono', 'Cascadia Code', 'JetBrains Mono', 'SF Mono', Monaco, 'Consolas', monospace; | |
| line-height: 1.4; | |
| max-width: 1000px; | |
| margin: 0 auto; | |
| padding: 15px; | |
| color: var(--text-primary); | |
| background: var(--bg-primary); | |
| transition: background-color 0.2s ease, color 0.2s ease; | |
| } | |
| /* Two panel layout removed */ | |
| .controls { | |
| position: fixed; | |
| top: 20px; | |
| right: 20px; | |
| display: flex; | |
| gap: 0.5rem; | |
| z-index: 1000; | |
| } | |
| .theme-toggle, .reset-toggle { | |
| background: var(--bg-secondary); | |
| border: 1px solid var(--border-primary); | |
| border-radius: 2px; | |
| padding: 0.4rem 0.6rem; | |
| cursor: pointer; | |
| font-family: inherit; | |
| font-size: 0.8rem; | |
| color: var(--text-secondary); | |
| user-select: none; | |
| transition: all 0.2s ease; | |
| text-transform: lowercase; | |
| letter-spacing: 0; | |
| } | |
| .theme-toggle:hover, .reset-toggle:hover { | |
| background: var(--bg-tertiary); | |
| border-color: var(--text-secondary); | |
| color: var(--text-primary); | |
| } | |
| .minimap { | |
| position: fixed; | |
| bottom: 20px; | |
| right: 20px; | |
| width: 220px; | |
| max-height: 400px; | |
| background: var(--bg-secondary); | |
| border: 1px solid var(--border-primary); | |
| border-radius: 2px; | |
| padding: 0.5rem; | |
| font-size: 0.7rem; | |
| overflow-y: auto; | |
| z-index: 100; | |
| opacity: 0.9; | |
| transition: opacity 0.2s ease; | |
| } | |
| .file-explorer { | |
| position: fixed; | |
| bottom: 20px; /* default; JS will stack */ | |
| right: 20px; | |
| left: auto; | |
| top: auto; | |
| width: 220px; | |
| max-height: 400px; | |
| background: var(--bg-secondary); | |
| border: 1px solid var(--border-primary); | |
| border-radius: 2px; | |
| padding: 0.5rem; | |
| font-size: 0.7rem; | |
| overflow-y: auto; | |
| z-index: 100; | |
| opacity: 0.9; | |
| transition: opacity 0.2s ease; | |
| } | |
| /* Drawing overlay */ | |
| .draw-overlay { | |
| position: fixed; | |
| top: 0; | |
| left: 0; | |
| width: 100vw; | |
| height: 100vh; | |
| z-index: 80; /* under widgets (100) and controls (1000) */ | |
| display: block; | |
| pointer-events: none; /* enabled only when a tool is active */ | |
| } | |
| /* Tools widget */ | |
| .tools-widget { | |
| position: fixed; | |
| bottom: 20px; /* default; JS will stack */ | |
| right: 20px; | |
| left: auto; | |
| top: auto; | |
| width: 220px; | |
| background: var(--bg-secondary); | |
| border: 1px solid var(--border-primary); | |
| border-radius: 2px; | |
| padding: 0.5rem; | |
| font-size: 0.7rem; | |
| z-index: 100; | |
| opacity: 0.95; | |
| } | |
| .tools-title { | |
| font-weight: bold; | |
| color: var(--text-secondary); | |
| margin-bottom: 0.5rem; | |
| padding-bottom: 0.25rem; | |
| border-bottom: 1px solid var(--border-primary); | |
| cursor: grab; | |
| user-select: none; | |
| } | |
| .tools-row { display: flex; gap: 0.4rem; flex-wrap: wrap; } | |
| .tool-button { | |
| background: var(--bg-tertiary); | |
| border: 1px solid var(--border-primary); | |
| border-radius: 2px; | |
| padding: 0.25rem 0.4rem; | |
| cursor: pointer; | |
| color: var(--text-secondary); | |
| font-family: inherit; | |
| font-size: 0.75rem; | |
| user-select: none; | |
| } | |
| .tool-button:hover { color: var(--text-primary); } | |
| .tool-button.active { color: var(--text-primary); border-color: var(--text-secondary); background: var(--bg-secondary); } | |
| .minimap:hover, .file-explorer:hover { | |
| opacity: 1; | |
| } | |
| .minimap-title { | |
| font-weight: bold; | |
| color: var(--text-secondary); | |
| margin-bottom: 0.5rem; | |
| padding-bottom: 0.25rem; | |
| border-bottom: 1px solid var(--border-primary); | |
| cursor: grab; /* drag handle */ | |
| user-select: none; | |
| } | |
| .minimap-item { | |
| display: block; | |
| color: var(--text-secondary); | |
| text-decoration: none; | |
| padding: 0.15rem 0; | |
| border-left: 2px solid transparent; | |
| padding-left: 0.5rem; | |
| transition: all 0.2s ease; | |
| cursor: pointer; | |
| } | |
| .minimap-item:hover { | |
| color: var(--text-primary); | |
| border-left-color: var(--text-secondary); | |
| } | |
| .minimap-item.active { | |
| color: var(--text-primary); | |
| border-left-color: var(--text-link); | |
| } | |
| .minimap-heading { | |
| font-weight: normal; | |
| } | |
| .minimap-heading.h1 { padding-left: 0.5rem; } | |
| .minimap-heading.h2 { padding-left: 1rem; } | |
| .minimap-heading.h3 { padding-left: 1.5rem; } | |
| .minimap-heading.h4 { padding-left: 2rem; } | |
| .minimap-heading.h5 { padding-left: 2.5rem; } | |
| .minimap-heading.h6 { padding-left: 3rem; } | |
| .minimap-cell { | |
| color: var(--text-link); | |
| opacity: 0.8; | |
| font-style: italic; | |
| } | |
| .minimap-cell:hover { | |
| opacity: 1; | |
| } | |
| .file-explorer-title { | |
| font-weight: bold; | |
| color: var(--text-secondary); | |
| margin-bottom: 0.5rem; | |
| padding-bottom: 0.25rem; | |
| border-bottom: 1px solid var(--border-primary); | |
| cursor: grab; /* drag handle */ | |
| user-select: none; | |
| } | |
| .file-explorer-section { | |
| margin-bottom: 0.75rem; | |
| } | |
| .file-explorer-section-title { | |
| font-weight: bold; | |
| color: var(--text-secondary); | |
| font-size: 0.65rem; | |
| margin-bottom: 0.25rem; | |
| text-transform: uppercase; | |
| letter-spacing: 0.5px; | |
| } | |
| .file-explorer-item { | |
| display: block; | |
| color: var(--text-secondary); | |
| text-decoration: none; | |
| padding: 0.1rem 0; | |
| margin-left: 0.5rem; | |
| transition: color 0.2s ease; | |
| cursor: pointer; | |
| font-family: monospace; | |
| } | |
| .file-explorer-item:hover { | |
| color: var(--text-primary); | |
| } | |
| .file-explorer-item.script { | |
| color: var(--text-link); | |
| } | |
| .file-explorer-item.artifact { | |
| color: var(--text-secondary); | |
| opacity: 0.8; | |
| } | |
| /* Slide functionality */ | |
| .minimap, .file-explorer, .tools-widget { | |
| transition: transform 0.3s ease; | |
| } | |
| .minimap.slide-off, .file-explorer.slide-off, .tools-widget.slide-off { | |
| transform: translateX(calc(100% - 20px)); | |
| } | |
| .minimap-title::before, .file-explorer-title::before, .tools-title::before { | |
| content: '‹'; | |
| float: left; | |
| cursor: pointer; | |
| color: var(--text-secondary); | |
| user-select: none; | |
| margin-right: 8px; | |
| } | |
| .minimap-title::after, .file-explorer-title::after, .tools-title::after { | |
| content: '›'; | |
| float: right; | |
| cursor: pointer; | |
| color: var(--text-secondary); | |
| user-select: none; | |
| } | |
| .minimap.slide-off .minimap-title::after, | |
| .file-explorer.slide-off .file-explorer-title::after, | |
| .tools-widget.slide-off .tools-title::after { | |
| content: '‹'; | |
| } | |
| /* Hide widgets on smaller screens */ | |
| @media (max-width: 768px) { | |
| .minimap, .file-explorer, .tools-widget { | |
| display: none; | |
| } | |
| } | |
| .cell { | |
| margin: 1rem 0; | |
| border: 1px solid var(--border-primary); | |
| border-radius: 2px; | |
| overflow: hidden; | |
| background: var(--bg-secondary); | |
| } | |
| .cell-header { | |
| background: var(--bg-secondary); | |
| padding: 0.5rem 1rem; | |
| border-bottom: 1px solid var(--border-primary); | |
| font-family: inherit; | |
| font-size: 0.85rem; | |
| color: var(--text-secondary); | |
| cursor: pointer; | |
| user-select: none; | |
| transition: background-color 0.2s ease; | |
| } | |
| .cell-header:hover { | |
| background: var(--bg-tertiary); | |
| } | |
| .collapse-indicators { | |
| color: var(--text-secondary); | |
| font-size: 0.8rem; | |
| opacity: 0.7; | |
| } | |
| .collapse-indicators span:hover { | |
| color: var(--text-primary); | |
| opacity: 1; | |
| } | |
| .cell-code { | |
| display: block; | |
| background: var(--bg-code); | |
| } | |
| .cell-code.collapsed { | |
| display: none; | |
| } | |
| .cell-code pre { | |
| margin: 0; | |
| padding: 0.75rem; | |
| background: var(--bg-code); | |
| overflow-x: auto; | |
| color: var(--text-primary); | |
| } | |
| .cell-output { | |
| padding: 0.75rem; | |
| background: var(--bg-primary); | |
| } | |
| .cell-output.collapsed { | |
| display: none; | |
| } | |
| .cell-stdout { | |
| background: var(--bg-tertiary); | |
| padding: 0.75rem; | |
| border-radius: 1px; | |
| margin: 0.25rem 0; | |
| font-family: inherit; | |
| font-size: 0.9rem; | |
| white-space: pre-wrap; | |
| color: var(--text-primary); | |
| } | |
| .cell-stderr { | |
| background: var(--bg-error); | |
| border-left: 2px solid var(--border-error); | |
| padding: 1rem; | |
| margin: 0.5rem 0; | |
| font-family: inherit; | |
| font-size: 0.9rem; | |
| color: var(--text-error); | |
| white-space: pre-wrap; | |
| } | |
| .cell-artifacts { | |
| margin: 1rem 0; | |
| } | |
| .cell-artifacts h4 { | |
| margin: 0 0 0.5rem 0; | |
| color: var(--text-secondary); | |
| font-size: 0.9rem; | |
| } | |
| .artifact { | |
| display: inline-block; | |
| background: var(--bg-artifact); | |
| padding: 0.25rem 0.5rem; | |
| border-radius: 1px; | |
| margin: 0.25rem 0.5rem 0.25rem 0; | |
| font-family: inherit; | |
| font-size: 0.8rem; | |
| color: var(--text-link); | |
| text-decoration: none; | |
| transition: background-color 0.2s ease; | |
| border: 1px solid var(--border-primary); | |
| } | |
| .artifact:hover { | |
| background: var(--bg-artifact-hover); | |
| } | |
| .artifact-preview { | |
| margin-top: 1rem; | |
| } | |
| .artifact-preview img { | |
| max-width: 100%; | |
| height: auto; | |
| border: 1px solid var(--border-primary); | |
| border-radius: 1px; | |
| } | |
| .artifact-preview svg { | |
| max-width: 100%; | |
| height: auto; | |
| border: 1px solid var(--border-primary); | |
| border-radius: 1px; | |
| display: block; | |
| } | |
/* Auto-theme inline SVG previews: transparent background so the page
   background shows through, and themed fill on group elements.
   (Previous comment claimed "text elements" but the rule targets `g`.) */
.artifact-preview svg {
    background: transparent;
}
.artifact-preview svg g {
    fill: var(--text-primary);
}
| .cell-failed { | |
| border-color: var(--border-cell-failed); | |
| } | |
| .cell-failed .cell-header { | |
| background: var(--bg-error); | |
| color: var(--text-error); | |
| } | |
| .run-btn { | |
| background: var(--bg-tertiary); | |
| border: 1px solid var(--border-primary); | |
| padding: 2px 6px; | |
| border-radius: 2px; | |
| color: var(--text-secondary); | |
| cursor: pointer; | |
| font-size: 0.75em; | |
| font-family: inherit; | |
| margin-left: 4px; | |
| } | |
| .run-btn:hover { | |
| color: var(--text-primary); | |
| background: var(--bg-primary); | |
| } | |
| .run-btn:disabled { | |
| opacity: 0.6; | |
| cursor: not-allowed; | |
| } | |
| .copy-btn { | |
| background: var(--bg-tertiary); | |
| border: 1px solid var(--border-primary); | |
| padding: 2px 6px; | |
| border-radius: 2px; | |
| color: var(--text-secondary); | |
| cursor: pointer; | |
| font-size: 0.75em; | |
| font-family: inherit; | |
| margin-left: 4px; | |
| } | |
| .copy-btn:hover { | |
| color: var(--text-primary); | |
| background: var(--bg-primary); | |
| } | |
| .copy-btn:disabled { | |
| opacity: 0.6; | |
| cursor: not-allowed; | |
| } | |
| .output-stale { | |
| opacity: 0.5; | |
| position: relative; | |
| } | |
| .output-stale::after { | |
| content: '⏳ updating...'; | |
| position: absolute; | |
| top: 8px; | |
| right: 8px; | |
| background: var(--bg-secondary); | |
| padding: 4px 8px; | |
| border-radius: 2px; | |
| font-size: 0.75em; | |
| color: var(--text-secondary); | |
| border: 1px solid var(--border-primary); | |
| } | |
| h1, h2, h3, h4, h5, h6 { | |
| margin-top: 1.5rem; | |
| margin-bottom: 0.75rem; | |
| color: var(--text-primary); | |
| } | |
| h1 { | |
| margin-top: 0; | |
| margin-bottom: 1rem; | |
| } | |
| p { | |
| margin: 0.75rem 0; | |
| color: var(--text-primary); | |
| } | |
| a { | |
| color: var(--text-link); | |
| } | |
| img { | |
| max-width: 100%; | |
| height: auto; | |
| border-radius: 1px; | |
| box-shadow: none; | |
| } | |
| pre, code { | |
| font-family: 'Cascadia Mono', 'Cascadia Code', 'JetBrains Mono', 'SF Mono', Monaco, 'Consolas', monospace; | |
| } | |
| /* Line numbers */ | |
| .highlight-with-lines { | |
| display: flex; | |
| } | |
| .line-numbers { | |
| background: var(--bg-tertiary); | |
| padding: 0.75rem 0.5rem; | |
| font-family: 'Cascadia Mono', 'Cascadia Code', 'JetBrains Mono', 'SF Mono', Monaco, 'Consolas', monospace; | |
| font-size: 0.9rem; | |
| color: var(--text-secondary); | |
| user-select: none; | |
| text-align: right; | |
| border-right: 1px solid var(--border-primary); | |
| } | |
| .line-numbers .line-number { | |
| display: block; | |
| line-height: 1.5; | |
| } | |
| .highlight-with-lines .highlight { | |
| flex: 1; | |
| } | |
| .highlight-with-lines .highlight pre { | |
| padding-left: 0.75rem; | |
| } | |
/* Collapsed code styling.
   NOTE(review): `.cell-code { display: block }` and
   `.cell-code.collapsed { display: none }` are already declared verbatim
   earlier in this sheet, so the duplicates are dropped here; only the
   `.expanded` state adds information. */
.cell-code.expanded {
    display: block;
}
| pre { line-height: 125%; } | |
| td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } | |
| span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } | |
| td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } | |
| span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } | |
| [data-theme="light"] .highlight .hll { background-color: #ffffcc } | |
| [data-theme="light"] .highlight { background: #f8f8f8; } | |
| [data-theme="light"] .highlight .c { color: #3D7B7B; font-style: italic } /* Comment */ | |
| [data-theme="light"] .highlight .err { border: 1px solid #F00 } /* Error */ | |
| [data-theme="light"] .highlight .k { color: #008000; font-weight: bold } /* Keyword */ | |
| [data-theme="light"] .highlight .o { color: #666 } /* Operator */ | |
| [data-theme="light"] .highlight .ch { color: #3D7B7B; font-style: italic } /* Comment.Hashbang */ | |
| [data-theme="light"] .highlight .cm { color: #3D7B7B; font-style: italic } /* Comment.Multiline */ | |
| [data-theme="light"] .highlight .cp { color: #9C6500 } /* Comment.Preproc */ | |
| [data-theme="light"] .highlight .cpf { color: #3D7B7B; font-style: italic } /* Comment.PreprocFile */ | |
| [data-theme="light"] .highlight .c1 { color: #3D7B7B; font-style: italic } /* Comment.Single */ | |
| [data-theme="light"] .highlight .cs { color: #3D7B7B; font-style: italic } /* Comment.Special */ | |
| [data-theme="light"] .highlight .gd { color: #A00000 } /* Generic.Deleted */ | |
| [data-theme="light"] .highlight .ge { font-style: italic } /* Generic.Emph */ | |
| [data-theme="light"] .highlight .ges { font-weight: bold; font-style: italic } /* Generic.EmphStrong */ | |
| [data-theme="light"] .highlight .gr { color: #E40000 } /* Generic.Error */ | |
| [data-theme="light"] .highlight .gh { color: #000080; font-weight: bold } /* Generic.Heading */ | |
| [data-theme="light"] .highlight .gi { color: #008400 } /* Generic.Inserted */ | |
| [data-theme="light"] .highlight .go { color: #717171 } /* Generic.Output */ | |
| [data-theme="light"] .highlight .gp { color: #000080; font-weight: bold } /* Generic.Prompt */ | |
| [data-theme="light"] .highlight .gs { font-weight: bold } /* Generic.Strong */ | |
| [data-theme="light"] .highlight .gu { color: #800080; font-weight: bold } /* Generic.Subheading */ | |
| [data-theme="light"] .highlight .gt { color: #04D } /* Generic.Traceback */ | |
| [data-theme="light"] .highlight .kc { color: #008000; font-weight: bold } /* Keyword.Constant */ | |
| [data-theme="light"] .highlight .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */ | |
| [data-theme="light"] .highlight .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */ | |
| [data-theme="light"] .highlight .kp { color: #008000 } /* Keyword.Pseudo */ | |
| [data-theme="light"] .highlight .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */ | |
| [data-theme="light"] .highlight .kt { color: #B00040 } /* Keyword.Type */ | |
| [data-theme="light"] .highlight .m { color: #666 } /* Literal.Number */ | |
| [data-theme="light"] .highlight .s { color: #BA2121 } /* Literal.String */ | |
| [data-theme="light"] .highlight .na { color: #687822 } /* Name.Attribute */ | |
| [data-theme="light"] .highlight .nb { color: #008000 } /* Name.Builtin */ | |
| [data-theme="light"] .highlight .nc { color: #00F; font-weight: bold } /* Name.Class */ | |
| [data-theme="light"] .highlight .no { color: #800 } /* Name.Constant */ | |
| [data-theme="light"] .highlight .nd { color: #A2F } /* Name.Decorator */ | |
| [data-theme="light"] .highlight .ni { color: #717171; font-weight: bold } /* Name.Entity */ | |
| [data-theme="light"] .highlight .ne { color: #CB3F38; font-weight: bold } /* Name.Exception */ | |
| [data-theme="light"] .highlight .nf { color: #00F } /* Name.Function */ | |
| [data-theme="light"] .highlight .nl { color: #767600 } /* Name.Label */ | |
| [data-theme="light"] .highlight .nn { color: #00F; font-weight: bold } /* Name.Namespace */ | |
| [data-theme="light"] .highlight .nt { color: #008000; font-weight: bold } /* Name.Tag */ | |
| [data-theme="light"] .highlight .nv { color: #19177C } /* Name.Variable */ | |
| [data-theme="light"] .highlight .ow { color: #A2F; font-weight: bold } /* Operator.Word */ | |
| [data-theme="light"] .highlight .w { color: #BBB } /* Text.Whitespace */ | |
| [data-theme="light"] .highlight .mb { color: #666 } /* Literal.Number.Bin */ | |
| [data-theme="light"] .highlight .mf { color: #666 } /* Literal.Number.Float */ | |
| [data-theme="light"] .highlight .mh { color: #666 } /* Literal.Number.Hex */ | |
| [data-theme="light"] .highlight .mi { color: #666 } /* Literal.Number.Integer */ | |
| [data-theme="light"] .highlight .mo { color: #666 } /* Literal.Number.Oct */ | |
| [data-theme="light"] .highlight .sa { color: #BA2121 } /* Literal.String.Affix */ | |
| [data-theme="light"] .highlight .sb { color: #BA2121 } /* Literal.String.Backtick */ | |
| [data-theme="light"] .highlight .sc { color: #BA2121 } /* Literal.String.Char */ | |
| [data-theme="light"] .highlight .dl { color: #BA2121 } /* Literal.String.Delimiter */ | |
| [data-theme="light"] .highlight .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */ | |
| [data-theme="light"] .highlight .s2 { color: #BA2121 } /* Literal.String.Double */ | |
| [data-theme="light"] .highlight .se { color: #AA5D1F; font-weight: bold } /* Literal.String.Escape */ | |
| [data-theme="light"] .highlight .sh { color: #BA2121 } /* Literal.String.Heredoc */ | |
| [data-theme="light"] .highlight .si { color: #A45A77; font-weight: bold } /* Literal.String.Interpol */ | |
| [data-theme="light"] .highlight .sx { color: #008000 } /* Literal.String.Other */ | |
| [data-theme="light"] .highlight .sr { color: #A45A77 } /* Literal.String.Regex */ | |
| [data-theme="light"] .highlight .s1 { color: #BA2121 } /* Literal.String.Single */ | |
| [data-theme="light"] .highlight .ss { color: #19177C } /* Literal.String.Symbol */ | |
| [data-theme="light"] .highlight .bp { color: #008000 } /* Name.Builtin.Pseudo */ | |
| [data-theme="light"] .highlight .fm { color: #00F } /* Name.Function.Magic */ | |
| [data-theme="light"] .highlight .vc { color: #19177C } /* Name.Variable.Class */ | |
| [data-theme="light"] .highlight .vg { color: #19177C } /* Name.Variable.Global */ | |
| [data-theme="light"] .highlight .vi { color: #19177C } /* Name.Variable.Instance */ | |
| [data-theme="light"] .highlight .vm { color: #19177C } /* Name.Variable.Magic */ | |
| [data-theme="light"] .highlight .il { color: #666 } /* Literal.Number.Integer.Long */ | |
| pre { line-height: 125%; } | |
| td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } | |
| span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } | |
| td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } | |
| span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } | |
| [data-theme="dark"] .highlight .hll { background-color: #49483e } | |
| [data-theme="dark"] .highlight { background: #272822; color: #F8F8F2 } | |
| [data-theme="dark"] .highlight .c { color: #959077 } /* Comment */ | |
| [data-theme="dark"] .highlight .err { color: #ED007E; background-color: #1E0010 } /* Error */ | |
| [data-theme="dark"] .highlight .esc { color: #F8F8F2 } /* Escape */ | |
| [data-theme="dark"] .highlight .g { color: #F8F8F2 } /* Generic */ | |
| [data-theme="dark"] .highlight .k { color: #66D9EF } /* Keyword */ | |
| [data-theme="dark"] .highlight .l { color: #AE81FF } /* Literal */ | |
| [data-theme="dark"] .highlight .n { color: #F8F8F2 } /* Name */ | |
| [data-theme="dark"] .highlight .o { color: #FF4689 } /* Operator */ | |
| [data-theme="dark"] .highlight .x { color: #F8F8F2 } /* Other */ | |
| [data-theme="dark"] .highlight .p { color: #F8F8F2 } /* Punctuation */ | |
| [data-theme="dark"] .highlight .ch { color: #959077 } /* Comment.Hashbang */ | |
| [data-theme="dark"] .highlight .cm { color: #959077 } /* Comment.Multiline */ | |
| [data-theme="dark"] .highlight .cp { color: #959077 } /* Comment.Preproc */ | |
| [data-theme="dark"] .highlight .cpf { color: #959077 } /* Comment.PreprocFile */ | |
| [data-theme="dark"] .highlight .c1 { color: #959077 } /* Comment.Single */ | |
| [data-theme="dark"] .highlight .cs { color: #959077 } /* Comment.Special */ | |
| [data-theme="dark"] .highlight .gd { color: #FF4689 } /* Generic.Deleted */ | |
| [data-theme="dark"] .highlight .ge { color: #F8F8F2; font-style: italic } /* Generic.Emph */ | |
| [data-theme="dark"] .highlight .ges { color: #F8F8F2; font-weight: bold; font-style: italic } /* Generic.EmphStrong */ | |
| [data-theme="dark"] .highlight .gr { color: #F8F8F2 } /* Generic.Error */ | |
| [data-theme="dark"] .highlight .gh { color: #F8F8F2 } /* Generic.Heading */ | |
| [data-theme="dark"] .highlight .gi { color: #A6E22E } /* Generic.Inserted */ | |
| [data-theme="dark"] .highlight .go { color: #66D9EF } /* Generic.Output */ | |
| [data-theme="dark"] .highlight .gp { color: #FF4689; font-weight: bold } /* Generic.Prompt */ | |
| [data-theme="dark"] .highlight .gs { color: #F8F8F2; font-weight: bold } /* Generic.Strong */ | |
| [data-theme="dark"] .highlight .gu { color: #959077 } /* Generic.Subheading */ | |
| [data-theme="dark"] .highlight .gt { color: #F8F8F2 } /* Generic.Traceback */ | |
| [data-theme="dark"] .highlight .kc { color: #66D9EF } /* Keyword.Constant */ | |
| [data-theme="dark"] .highlight .kd { color: #66D9EF } /* Keyword.Declaration */ | |
| [data-theme="dark"] .highlight .kn { color: #FF4689 } /* Keyword.Namespace */ | |
| [data-theme="dark"] .highlight .kp { color: #66D9EF } /* Keyword.Pseudo */ | |
| [data-theme="dark"] .highlight .kr { color: #66D9EF } /* Keyword.Reserved */ | |
| [data-theme="dark"] .highlight .kt { color: #66D9EF } /* Keyword.Type */ | |
| [data-theme="dark"] .highlight .ld { color: #E6DB74 } /* Literal.Date */ | |
| [data-theme="dark"] .highlight .m { color: #AE81FF } /* Literal.Number */ | |
| [data-theme="dark"] .highlight .s { color: #E6DB74 } /* Literal.String */ | |
| [data-theme="dark"] .highlight .na { color: #A6E22E } /* Name.Attribute */ | |
| [data-theme="dark"] .highlight .nb { color: #F8F8F2 } /* Name.Builtin */ | |
| [data-theme="dark"] .highlight .nc { color: #A6E22E } /* Name.Class */ | |
| [data-theme="dark"] .highlight .no { color: #66D9EF } /* Name.Constant */ | |
| [data-theme="dark"] .highlight .nd { color: #A6E22E } /* Name.Decorator */ | |
| [data-theme="dark"] .highlight .ni { color: #F8F8F2 } /* Name.Entity */ | |
| [data-theme="dark"] .highlight .ne { color: #A6E22E } /* Name.Exception */ | |
| [data-theme="dark"] .highlight .nf { color: #A6E22E } /* Name.Function */ | |
| [data-theme="dark"] .highlight .nl { color: #F8F8F2 } /* Name.Label */ | |
| [data-theme="dark"] .highlight .nn { color: #F8F8F2 } /* Name.Namespace */ | |
| [data-theme="dark"] .highlight .nx { color: #A6E22E } /* Name.Other */ | |
| [data-theme="dark"] .highlight .py { color: #F8F8F2 } /* Name.Property */ | |
| [data-theme="dark"] .highlight .nt { color: #FF4689 } /* Name.Tag */ | |
| [data-theme="dark"] .highlight .nv { color: #F8F8F2 } /* Name.Variable */ | |
| [data-theme="dark"] .highlight .ow { color: #FF4689 } /* Operator.Word */ | |
| [data-theme="dark"] .highlight .pm { color: #F8F8F2 } /* Punctuation.Marker */ | |
| [data-theme="dark"] .highlight .w { color: #F8F8F2 } /* Text.Whitespace */ | |
| [data-theme="dark"] .highlight .mb { color: #AE81FF } /* Literal.Number.Bin */ | |
| [data-theme="dark"] .highlight .mf { color: #AE81FF } /* Literal.Number.Float */ | |
| [data-theme="dark"] .highlight .mh { color: #AE81FF } /* Literal.Number.Hex */ | |
| [data-theme="dark"] .highlight .mi { color: #AE81FF } /* Literal.Number.Integer */ | |
| [data-theme="dark"] .highlight .mo { color: #AE81FF } /* Literal.Number.Oct */ | |
| [data-theme="dark"] .highlight .sa { color: #E6DB74 } /* Literal.String.Affix */ | |
| [data-theme="dark"] .highlight .sb { color: #E6DB74 } /* Literal.String.Backtick */ | |
| [data-theme="dark"] .highlight .sc { color: #E6DB74 } /* Literal.String.Char */ | |
| [data-theme="dark"] .highlight .dl { color: #E6DB74 } /* Literal.String.Delimiter */ | |
| [data-theme="dark"] .highlight .sd { color: #E6DB74 } /* Literal.String.Doc */ | |
| [data-theme="dark"] .highlight .s2 { color: #E6DB74 } /* Literal.String.Double */ | |
| [data-theme="dark"] .highlight .se { color: #AE81FF } /* Literal.String.Escape */ | |
| [data-theme="dark"] .highlight .sh { color: #E6DB74 } /* Literal.String.Heredoc */ | |
| [data-theme="dark"] .highlight .si { color: #E6DB74 } /* Literal.String.Interpol */ | |
| [data-theme="dark"] .highlight .sx { color: #E6DB74 } /* Literal.String.Other */ | |
| [data-theme="dark"] .highlight .sr { color: #E6DB74 } /* Literal.String.Regex */ | |
| [data-theme="dark"] .highlight .s1 { color: #E6DB74 } /* Literal.String.Single */ | |
| [data-theme="dark"] .highlight .ss { color: #E6DB74 } /* Literal.String.Symbol */ | |
| [data-theme="dark"] .highlight .bp { color: #F8F8F2 } /* Name.Builtin.Pseudo */ | |
| [data-theme="dark"] .highlight .fm { color: #A6E22E } /* Name.Function.Magic */ | |
| [data-theme="dark"] .highlight .vc { color: #F8F8F2 } /* Name.Variable.Class */ | |
| [data-theme="dark"] .highlight .vg { color: #F8F8F2 } /* Name.Variable.Global */ | |
| [data-theme="dark"] .highlight .vi { color: #F8F8F2 } /* Name.Variable.Instance */ | |
| [data-theme="dark"] .highlight .vm { color: #F8F8F2 } /* Name.Variable.Magic */ | |
| [data-theme="dark"] .highlight .il { color: #AE81FF } /* Literal.Number.Integer.Long */ | |
| /* Custom CSS from frontmatter */ | |
| .cell-stderr { display: none; } | |
| .minimap { display: none ; } | |
| .file-explorer { display: none ; } | |
| .cell-code { max-height: 400px; overflow: auto; } | |
| /* Cursor for tools */ | |
| body[data-tool="arrow"] .main-content { cursor: crosshair; } | |
| body[data-tool="pen"] .main-content { cursor: pointer; } | |
| body[data-tool="eraser"] .main-content { cursor: cell; } | |
| /* Color picker styles */ | |
| .tools-section-title { | |
| font-weight: bold; | |
| color: var(--text-secondary); | |
| font-size: 0.65rem; | |
| margin: 0.75rem 0 0.5rem 0; | |
| text-transform: uppercase; | |
| letter-spacing: 0.5px; | |
| } | |
| .color-row { | |
| display: grid; | |
| grid-template-columns: repeat(6, 1fr); | |
| gap: 0.25rem; | |
| margin-bottom: 0.5rem; | |
| } | |
| .color-swatch { | |
| width: 18px; | |
| height: 18px; | |
| border: 2px solid var(--border-primary); | |
| border-radius: 3px; | |
| cursor: pointer; | |
| transition: all 0.2s ease; | |
| position: relative; | |
| } | |
| .color-swatch:hover { | |
| transform: scale(1.1); | |
| border-color: var(--text-secondary); | |
| } | |
| .color-swatch.selected { | |
| border-color: var(--text-primary); | |
| box-shadow: 0 0 0 2px var(--text-link); | |
| } | |
/* Checkmark overlay shown on the currently selected color swatch. */
.color-swatch.selected::after {
  content: '✓';
  position: absolute;
  top: 50%;
  left: 50%;
  transform: translate(-50%, -50%);
  color: white;
  font-size: 10px;
  font-weight: bold;
  text-shadow: 1px 1px 1px black;
}
/* Native color picker input at the end of the swatch grid. */
.color-input {
  width: 24px;
  height: 24px;
  border: 2px solid var(--border-primary);
  border-radius: 3px;
  cursor: pointer;
  background: none;
  padding: 0;
  grid-column: span 2;
  justify-self: center;
}
.color-input:hover {
  border-color: var(--text-secondary);
}
/* Thickness slider styles */
.thickness-row {
  display: flex;
  align-items: center;
  gap: 0.5rem;
  margin-top: 0.75rem;
}
.thickness-slider {
  flex: 1;
  -webkit-appearance: none;
  appearance: none;
  height: 4px;
  background: var(--border-primary);
  border-radius: 2px;
  outline: none;
  opacity: 0.7;
  transition: opacity 0.2s;
}
.thickness-slider:hover {
  opacity: 1;
}
.thickness-slider::-webkit-slider-thumb {
  -webkit-appearance: none;
  appearance: none;
  width: 12px;
  height: 12px;
  background: var(--text-link);
  border-radius: 50%;
  cursor: pointer;
}
.thickness-slider::-moz-range-thumb {
  width: 12px;
  height: 12px;
  background: var(--text-link);
  border-radius: 50%;
  cursor: pointer;
  border: none;
}
.thickness-value {
  font-size: 0.7rem;
  color: var(--text-secondary);
  min-width: 20px;
  text-align: right;
}
.highlight {
  background: none;
}
/* Loading animations */
.loading-spinner {
  display: inline-block;
  width: 16px;
  height: 16px;
  border: 2px solid var(--border-primary);
  border-radius: 50%;
  border-top-color: var(--text-link);
  animation: spin 1s linear infinite;
  margin-right: 8px;
  vertical-align: middle;
}
@keyframes spin {
  to { transform: rotate(360deg); }
}
.loading-skeleton {
  display: inline-block;
  /* Solid fallback first, shimmer gradient for browsers that support it. */
  background: var(--bg-tertiary);
  background: linear-gradient(
    90deg,
    var(--bg-tertiary) 25%,
    var(--bg-secondary) 50%,
    var(--bg-tertiary) 75%
  );
  background-size: 200% 100%;
  animation: loading-shimmer 2s ease-in-out infinite;
  border-radius: 2px;
  height: 1em;
  width: 80px;
  vertical-align: middle;
}
@keyframes loading-shimmer {
  0% { background-position: -200% 0; }
  100% { background-position: 200% 0; }
}
/* Loading state for cell output */
.cell-output:has(.loading-spinner) {
  opacity: 0.7;
  background: var(--bg-secondary);
  border-left: 3px solid var(--text-link);
}
| </style> | |
| <script> | |
// --- Drag utilities ---
// Constrain `val` to the inclusive range [min, max].
function clamp(val, min, max) {
    if (val < min) return min;
    if (val > max) return max;
    return val;
}
// Re-apply a previously saved {left, top} pixel position from localStorage.
// Silently does nothing when the key is absent or the payload is malformed.
function restorePosition(el, storageKey) {
    try {
        const stored = localStorage.getItem(storageKey);
        if (!stored) return;
        const { left, top } = JSON.parse(stored);
        if (typeof left !== 'number' || typeof top !== 'number') return;
        el.style.left = left + 'px';
        el.style.top = top + 'px';
        // Switch anchoring to left/top so the saved coordinates take effect.
        el.style.right = 'auto';
        el.style.bottom = 'auto';
    } catch (_) {}
}
// Persist the element's current inline left/top (px) under storageKey.
// Nothing is written when either coordinate is missing or unparsable.
function savePosition(el, storageKey) {
    try {
        const left = parseFloat(el.style.left || 'NaN');
        const top = parseFloat(el.style.top || 'NaN');
        if (Number.isNaN(left) || Number.isNaN(top)) return;
        localStorage.setItem(storageKey, JSON.stringify({ left, top }));
    } catch (_) {}
}
// Wire the widget's title bar so clicking its left edge slides the widget
// back on-screen and clicking its right edge slides it off-screen.
function addSlideToggle(widget, titleEl) {
    const EDGE = 30; // px-wide hot zones at either end of the title bar
    titleEl.onclick = function (e) {
        const bounds = titleEl.getBoundingClientRect();
        const offsetX = e.clientX - bounds.left;
        if (offsetX < EDGE) {
            // Left arrow: always bring the widget back into view.
            widget.classList.remove('slide-off');
            e.stopPropagation();
        } else if (offsetX > bounds.width - EDGE) {
            // Right arrow: always push the widget off screen.
            widget.classList.add('slide-off');
            e.stopPropagation();
        }
    };
}
// Make `el` draggable within the viewport via mouse or touch, persisting
// its final position under `storageKey`. When `handleEl` is given, drags
// start only from that handle; its 30px edge zones are reserved for the
// slide-toggle arrows (see addSlideToggle) and never start a drag.
function makeDraggable(el, storageKey, handleEl) {
    let dragging = false;
    let startX = 0, startY = 0; // cursor position when the drag started
    let origLeft = 0, origTop = 0; // element position when the drag started
    const onMove = (e) => {
        if (!dragging) return;
        // Normalize mouse vs. touch coordinates.
        const clientX = e.touches ? e.touches[0].clientX : e.clientX;
        const clientY = e.touches ? e.touches[0].clientY : e.clientY;
        const dx = clientX - startX;
        const dy = clientY - startY;
        const w = el.offsetWidth;
        const h = el.offsetHeight;
        // Clamp so the element stays fully inside the viewport.
        const maxX = window.innerWidth - w;
        const maxY = window.innerHeight - h;
        const newLeft = clamp(origLeft + dx, 0, maxX);
        const newTop = clamp(origTop + dy, 0, maxY);
        el.style.left = newLeft + 'px';
        el.style.top = newTop + 'px';
        el.style.right = 'auto';
        el.style.bottom = 'auto';
    };
    const endDrag = () => {
        if (!dragging) return;
        dragging = false;
        // Drop the document-level listeners installed by startDrag.
        document.removeEventListener('mousemove', onMove);
        document.removeEventListener('mouseup', endDrag);
        document.removeEventListener('touchmove', onMove);
        document.removeEventListener('touchend', endDrag);
        handleEl && (handleEl.style.cursor = 'grab');
        savePosition(el, storageKey);
        // ensure no-overlap constraint after a drag
        try { layoutWidgetsStackedBottomRight(); } catch (_) {}
    };
    const startDrag = (e) => {
        // Check if click is on arrow areas - if so, don't start drag
        if (handleEl) {
            const rect = handleEl.getBoundingClientRect();
            const clickX = e.clientX - rect.left;
            if (clickX < 30 || clickX > rect.width - 30) {
                return; // Don't start drag on arrow areas
            }
        }
        // Start from element's current on-screen rect: switches the element
        // from right/bottom anchoring to left/top anchoring without a jump.
        const elRect = el.getBoundingClientRect();
        el.style.left = elRect.left + 'px';
        el.style.top = elRect.top + 'px';
        el.style.right = 'auto';
        el.style.bottom = 'auto';
        dragging = true;
        startX = e.touches ? e.touches[0].clientX : e.clientX;
        startY = e.touches ? e.touches[0].clientY : e.clientY;
        origLeft = elRect.left;
        origTop = elRect.top;
        document.addEventListener('mousemove', onMove);
        document.addEventListener('mouseup', endDrag);
        // passive: false so e.preventDefault() can block scroll during drags.
        document.addEventListener('touchmove', onMove, { passive: false });
        document.addEventListener('touchend', endDrag);
        handleEl && (handleEl.style.cursor = 'grabbing');
        e.preventDefault();
    };
    (handleEl || el).addEventListener('mousedown', startDrag);
    (handleEl || el).addEventListener('touchstart', startDrag, { passive: false });
    // Apply any saved position on init
    restorePosition(el, storageKey);
}
// Collapse/expand both the code and output panes of a cell, then refresh
// the header indicator arrows.
function toggleCell(cellId) {
    const codePane = document.getElementById('code-' + cellId);
    const outputPane = document.getElementById('output-' + cellId);
    if (codePane) codePane.classList.toggle('collapsed');
    if (outputPane) outputPane.classList.toggle('collapsed');
    updateIndicators(cellId);
}
// Collapse/expand only the code pane of a cell.
function toggleCode(cellId) {
    const pane = document.getElementById('code-' + cellId);
    if (!pane) return;
    pane.classList.toggle('collapsed');
    updateIndicators(cellId);
}
// Collapse/expand only the output pane of a cell.
function toggleOutput(cellId) {
    const pane = document.getElementById('output-' + cellId);
    if (!pane) return;
    pane.classList.toggle('collapsed');
    updateIndicators(cellId);
}
// Sync the "code"/"output" collapse-indicator arrows in a cell's header
// with the current collapsed state of the cell's panes.
function updateIndicators(cellId) {
    const codeElement = document.getElementById('code-' + cellId);
    const outputElement = document.getElementById('output-' + cellId);
    // Locate the indicator container by finding any element whose inline
    // onclick mentions this cell id, then walking up to its header.
    const indicators = document.querySelector(`[onclick*="${cellId}"]`)?.closest('.cell-header')?.querySelector('.collapse-indicators');
    if (indicators) {
        const codeCollapsed = codeElement && codeElement.classList.contains('collapsed');
        const outputCollapsed = outputElement && outputElement.classList.contains('collapsed');
        // ▶ = collapsed, ▼ = expanded
        const codeIcon = codeCollapsed ? '▶' : '▼';
        const outputIcon = outputCollapsed ? '▶' : '▼';
        const codeSpan = indicators.querySelector('[onclick*="toggleCode"]');
        const outputSpan = indicators.querySelector('[onclick*="toggleOutput"]');
        if (codeSpan) codeSpan.innerHTML = `${codeIcon} code`;
        if (outputSpan) outputSpan.innerHTML = `${outputIcon} output`;
    }
}
// Flip between light and dark themes, persist the choice, and refresh
// the toggle button's label.
function toggleTheme() {
    const root = document.documentElement;
    const next = root.getAttribute('data-theme') === 'dark' ? 'light' : 'dark';
    root.setAttribute('data-theme', next);
    localStorage.setItem('uvnote-theme', next);
    updateThemeIcon();
}
// Two panel code removed
// Label the theme toggle button with the theme it would switch TO.
function updateThemeIcon() {
    const toggle = document.querySelector('.theme-toggle');
    if (!toggle) return;
    const theme = document.documentElement.getAttribute('data-theme');
    toggle.textContent = theme === 'dark' ? 'light' : 'dark';
}
// Forget all persisted uvnote UI state (widget positions, theme, tool,
// shapes, ...) and reload so everything re-initializes with defaults.
function resetLayout() {
    try {
        Object.keys(localStorage)
            .filter((key) => key.startsWith('uvnote-'))
            .forEach((key) => localStorage.removeItem(key));
    } catch (_) {}
    // Reload to reinitialize UI with defaults
    location.reload();
}
// Layout: stack widgets bottom-right and equalize widths
// Truthy when the user has dragged any widget to a custom position
// (i.e. a saved position exists for at least one widget).
function hasCustomWidgetPositions() {
    const keys = [
        'uvnote-minimap-pos',
        'uvnote-file-explorer-pos',
        'uvnote-tools-pos'
    ];
    try {
        for (const key of keys) {
            const saved = localStorage.getItem(key);
            if (saved) return saved;
        }
        return null;
    } catch (_) { return false; }
}
// Axis-aligned rectangle intersection test; touching edges do not count
// as overlap.
function rectsOverlap(r1, r2) {
    const separated =
        r1.right <= r2.left ||
        r2.right <= r1.left ||
        r1.bottom <= r2.top ||
        r2.bottom <= r1.top;
    return !separated;
}
// True when any pair of widget elements' bounding boxes intersect.
// Rects are measured once per widget up front; the original re-measured
// inside the inner loop, and getBoundingClientRect forces layout, so this
// avoids O(n^2) reflows for the same result.
function widgetsOverlap(widgets) {
    const rects = widgets.map((w) => w.getBoundingClientRect());
    for (let i = 0; i < rects.length; i++) {
        for (let j = i + 1; j < rects.length; j++) {
            if (rectsOverlap(rects[i], rects[j])) return true;
        }
    }
    return false;
}
// Stack the given widgets along the right edge of the viewport, giving
// them equal widths and capping each one's height so the whole stack fits
// on screen. `order` lists the widgets bottom-to-top.
function applyStackLayout(widgets, order) {
    if (!widgets.length) return;
    // Fixed equal width for every widget.
    const fixedWidth = 220;
    widgets.forEach(el => { el.style.width = fixedWidth + 'px'; });
    // Cap heights so the stack (plus gaps) never overflows the viewport.
    const gap = 12;
    const available = Math.max(0, window.innerHeight - 40 - gap * (order.length - 1));
    const eachMax = Math.floor(available / order.length);
    order.forEach(el => {
        el.style.maxHeight = eachMax + 'px';
        el.style.overflowY = 'auto';
    });
    // Stack bottom-up in the requested order.
    let bottomOffset = 20; // base gutter from the viewport bottom
    order.forEach(el => {
        el.style.left = 'auto';
        el.style.top = 'auto';
        el.style.right = '20px';
        el.style.bottom = bottomOffset + 'px';
        bottomOffset += el.offsetHeight + gap;
    });
}
// Lay out the minimap, file-explorer, and tools widgets stacked at the
// bottom-right. User-dragged positions are respected unless the widgets
// currently overlap each other.
function layoutWidgetsStackedBottomRight() {
    const minimap = document.querySelector('.minimap');
    const fileExplorer = document.querySelector('.file-explorer');
    const tools = document.querySelector('.tools-widget');
    // Only consider widgets that exist and are currently visible.
    const widgets = [minimap, fileExplorer, tools].filter(el => el && getComputedStyle(el).display !== 'none');
    if (!widgets.length) return;
    const order = [minimap, fileExplorer, tools].filter(Boolean).filter(el => getComputedStyle(el).display !== 'none');
    // If user placed custom positions and there is no overlap, respect them.
    if (hasCustomWidgetPositions() && !widgetsOverlap(widgets)) return;
    applyStackLayout(widgets, order);
}
// Panel icon removed
// Scroll target (window) whose scroll events drive the minimap highlight.
let _minimapScrollContainer = null;
// Handler reference kept so teardownMinimap() can detach the listener.
let _minimapScrollHandler = null;
// Build the minimap widget, attach it to <body>, make it draggable and
// slide-away, and keep its active entry in sync with window scrolling.
function initMinimap() {
    // Generate minimap content
    const minimap = createMinimap();
    document.body.appendChild(minimap);
    // Make draggable and slideable (use title as handle)
    const mTitle = minimap.querySelector('.minimap-title');
    makeDraggable(minimap, 'uvnote-minimap-pos', mTitle);
    addSlideToggle(minimap, mTitle);
    // Attach scroll listener to window (two-panel mode removed, so the
    // container is always `window`; the else-branch is kept for safety).
    _minimapScrollContainer = window;
    if (_minimapScrollContainer) {
        _minimapScrollHandler = () => updateMinimapActive();
        if (_minimapScrollContainer === window) {
            window.addEventListener('scroll', _minimapScrollHandler);
        } else {
            _minimapScrollContainer.addEventListener('scroll', _minimapScrollHandler);
        }
    }
    // Highlight the correct entry for the initial scroll position.
    updateMinimapActive();
}
// Remove the minimap element and detach its scroll listener, then clear
// the module-level references so initMinimap() can run again cleanly.
function teardownMinimap() {
    const minimap = document.querySelector('.minimap');
    if (minimap && minimap.parentNode) minimap.parentNode.removeChild(minimap);
    if (_minimapScrollContainer && _minimapScrollHandler) {
        if (_minimapScrollContainer === window) {
            window.removeEventListener('scroll', _minimapScrollHandler);
        } else {
            _minimapScrollContainer.removeEventListener('scroll', _minimapScrollHandler);
        }
    }
    _minimapScrollContainer = null;
    _minimapScrollHandler = null;
}
// Build the file-explorer widget, attach it to <body>, and enable the
// slide-away behavior on its title bar.
function initFileExplorer() {
    const fileExplorer = createFileExplorer();
    document.body.appendChild(fileExplorer);
    addSlideToggle(fileExplorer, fileExplorer.querySelector('.file-explorer-title'));
}
// Build the minimap navigation widget: a panel listing every heading and
// cell in document order, each entry linking/scrolling to its target.
function createMinimap() {
    const minimap = document.createElement('div');
    minimap.className = 'minimap';
    const title = document.createElement('div');
    title.className = 'minimap-title';
    title.textContent = 'navigation';
    minimap.appendChild(title);
    // Find all headings and cells
    const root = document.querySelector('.main-content') || document;
    const headings = root.querySelectorAll('h1, h2, h3, h4, h5, h6');
    const cells = root.querySelectorAll('.cell');
    // Combine and sort by position
    const items = [];
    headings.forEach(heading => {
        // Ensure every heading has an id so it can be linked to.
        const id = heading.id || generateId(heading.textContent);
        if (!heading.id) heading.id = id;
        items.push({
            element: heading,
            type: 'heading',
            level: parseInt(heading.tagName.charAt(1)),
            text: heading.textContent.trim(),
            id: id,
            position: heading.getBoundingClientRect().top + window.scrollY
        });
    });
    cells.forEach(cell => {
        const header = cell.querySelector('.cell-header');
        if (header) {
            // Cells without ids get a generated, effectively-unique one.
            const id = cell.id || `cell-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`;
            if (!cell.id) cell.id = id;
            items.push({
                element: cell,
                type: 'cell',
                text: header.textContent.trim(),
                id: id,
                position: cell.getBoundingClientRect().top + window.scrollY
            });
        }
    });
    // Sort by vertical document position
    items.sort((a, b) => a.position - b.position);
    // Create one link per item (heading level is exposed as a CSS class).
    items.forEach(item => {
        const link = document.createElement('a');
        link.className = `minimap-item ${item.type === 'heading' ? 'minimap-heading' : 'minimap-cell'}`;
        if (item.type === 'heading') {
            link.classList.add(`h${item.level}`);
        }
        // Truncate long labels to keep the panel narrow.
        link.textContent = item.text.length > 25 ? item.text.substring(0, 22) + '...' : item.text;
        link.href = `#${item.id}`;
        // Smooth-scroll instead of the default jump navigation.
        link.onclick = function(e) {
            e.preventDefault();
            item.element.scrollIntoView({ behavior: 'smooth', block: 'start' });
        };
        minimap.appendChild(link);
    });
    return minimap;
}
// Build a short URL-safe slug from heading text: lowercase, runs of
// non-alphanumerics collapsed to single hyphens, edge hyphens stripped,
// truncated to 20 characters.
function generateId(text) {
    const slug = text
        .toLowerCase()
        .replace(/[^a-z0-9]+/g, '-')
        .replace(/^-+|-+$/g, '');
    return slug.substring(0, 20);
}
// Highlight the minimap entry for the section currently in view: the last
// item whose target starts above the scroll position (plus a 100px offset).
function updateMinimapActive() {
    const minimapItems = document.querySelectorAll('.minimap-item');
    const container = _minimapScrollContainer || window;
    // containerRect is only needed for element containers (legacy path).
    const containerRect = container === window ? null : container.getBoundingClientRect();
    const scrollPos = (container === window ? window.scrollY : container.scrollTop) + 100; // Offset for better detection
    let activeItem = null;
    minimapItems.forEach(item => {
        const targetId = item.getAttribute('href').substring(1);
        const target = document.getElementById(targetId);
        if (target) {
            const rectTop = target.getBoundingClientRect().top;
            // Convert the target's viewport position into scroll space.
            const targetPos = (container === window)
                ? rectTop + window.scrollY
                : rectTop - containerRect.top + container.scrollTop;
            if (targetPos <= scrollPos) {
                activeItem = item; // later (lower) items overwrite earlier ones
            }
        }
        item.classList.remove('active');
    });
    if (activeItem) {
        activeItem.classList.add('active');
    }
}
// Build the file-explorer widget: a draggable panel listing each cell's
// script file (scrolls to the cell on click) and every artifact link on
// the page (clicks through to the real link).
function createFileExplorer() {
    const fileExplorer = document.createElement('div');
    fileExplorer.className = 'file-explorer';
    const title = document.createElement('div');
    title.className = 'file-explorer-title';
    title.textContent = 'files';
    fileExplorer.appendChild(title);
    // Make draggable (use title as handle)
    makeDraggable(fileExplorer, 'uvnote-file-explorer-pos', title);
    // Scripts section
    const scriptsSection = document.createElement('div');
    scriptsSection.className = 'file-explorer-section';
    const scriptsTitle = document.createElement('div');
    scriptsTitle.className = 'file-explorer-section-title';
    scriptsTitle.textContent = 'scripts';
    scriptsSection.appendChild(scriptsTitle);
    // Find all cells and list their script files (single panel)
    const root = document.querySelector('.main-content') || document;
    const cells = root.querySelectorAll('.cell');
    cells.forEach(cell => {
        const header = cell.querySelector('.cell-header');
        if (header) {
            const cellText = header.textContent.trim();
            // Header text is expected to contain "Cell: <identifier>".
            const cellMatch = cellText.match(/Cell: ([a-zA-Z_][a-zA-Z0-9_]*)/);
            if (cellMatch) {
                const cellId = cellMatch[1];
                const scriptItem = document.createElement('div');
                scriptItem.className = 'file-explorer-item script';
                scriptItem.textContent = `${cellId}.py`;
                // Clicking a script entry scrolls its cell into view.
                scriptItem.onclick = function() {
                    cell.scrollIntoView({ behavior: 'smooth', block: 'start' });
                };
                scriptsSection.appendChild(scriptItem);
            }
        }
    });
    fileExplorer.appendChild(scriptsSection);
    // Artifacts section
    const artifactsSection = document.createElement('div');
    artifactsSection.className = 'file-explorer-section';
    const artifactsTitle = document.createElement('div');
    artifactsTitle.className = 'file-explorer-section-title';
    artifactsTitle.textContent = 'artifacts';
    artifactsSection.appendChild(artifactsTitle);
    // Find all artifact links (single panel)
    const artifactsRoot = document.querySelector('.main-content') || document;
    const artifacts = artifactsRoot.querySelectorAll('.artifact');
    if (artifacts.length === 0) {
        // Placeholder entry when the page produced no artifacts.
        const noArtifacts = document.createElement('div');
        noArtifacts.className = 'file-explorer-item artifact';
        noArtifacts.textContent = '(none)';
        noArtifacts.style.opacity = '0.5';
        artifactsSection.appendChild(noArtifacts);
    } else {
        artifacts.forEach(artifact => {
            const artifactItem = document.createElement('div');
            artifactItem.className = 'file-explorer-item artifact';
            artifactItem.textContent = artifact.textContent;
            // Clicking a list entry delegates to the real artifact link.
            artifactItem.onclick = function() {
                artifact.click();
            };
            artifactsSection.appendChild(artifactItem);
        });
    }
    fileExplorer.appendChild(artifactsSection);
    return fileExplorer;
}
// Tools widget
// Activate a drawing tool ('arrow' | 'pen' | 'eraser') or deactivate with
// 'none'/falsy. Persists the choice and toggles the capture overlay.
function setActiveTool(tool) {
    const active = tool && tool !== 'none';
    const name = active ? tool : 'none';
    document.body.dataset.tool = name;
    localStorage.setItem('uvnote-active-tool', name);
    setOverlayActive(!!active);
}
// Stored stroke color, falling back to the default red.
function getArrowColor() {
    return localStorage.getItem('uvnote-arrow-color') || '#e53935';
}
// Persist the stroke color; storage failures (private mode, quota) are
// ignored on purpose.
function setStoredArrowColor(color) {
    try {
        localStorage.setItem('uvnote-arrow-color', color);
    } catch (_) {}
}
// Stored stroke thickness in px; falls back to 4 when unset OR when the
// stored value does not parse as an integer (the original returned NaN in
// that case, which is an invalid canvas lineWidth).
function getLineThickness() {
    const saved = localStorage.getItem('uvnote-line-thickness');
    if (saved) {
        const parsed = parseInt(saved, 10);
        if (!Number.isNaN(parsed)) return parsed;
    }
    return 4; // default thickness
}
// Persist the stroke thickness; storage failures are ignored on purpose.
function setStoredLineThickness(thickness) {
    try {
        localStorage.setItem('uvnote-line-thickness', thickness);
    } catch (_) {}
}
// Build the drawing-tools widget: tool buttons (arrow/pen/eraser/clear),
// a color swatch grid with a free-form color input, and a line-thickness
// slider. Tool, color, and thickness choices persist via localStorage.
function createToolsWidget() {
    const tools = document.createElement('div');
    tools.className = 'tools-widget';
    const title = document.createElement('div');
    title.className = 'tools-title';
    title.textContent = 'tools';
    tools.appendChild(title);
    const row = document.createElement('div');
    row.className = 'tools-row';
    tools.appendChild(row);
    // Arrow tool — tool buttons are mutually exclusive; clicking the
    // currently active button deactivates all tools.
    const arrowBtn = document.createElement('div');
    arrowBtn.className = 'tool-button';
    arrowBtn.textContent = 'arrow';
    arrowBtn.onclick = function() {
        const isActive = arrowBtn.classList.contains('active');
        if (isActive) {
            arrowBtn.classList.remove('active');
            setActiveTool('none');
        } else {
            tools.querySelectorAll('.tool-button').forEach(b => b.classList.remove('active'));
            arrowBtn.classList.add('active');
            setActiveTool('arrow');
        }
    };
    row.appendChild(arrowBtn);
    // Pen tool
    const penBtn = document.createElement('div');
    penBtn.className = 'tool-button';
    penBtn.textContent = 'pen';
    penBtn.onclick = function() {
        const isActive = penBtn.classList.contains('active');
        if (isActive) {
            penBtn.classList.remove('active');
            setActiveTool('none');
        } else {
            tools.querySelectorAll('.tool-button').forEach(b => b.classList.remove('active'));
            penBtn.classList.add('active');
            setActiveTool('pen');
        }
    };
    row.appendChild(penBtn);
    // Eraser tool
    const eraseBtn = document.createElement('div');
    eraseBtn.className = 'tool-button';
    eraseBtn.textContent = 'eraser';
    eraseBtn.onclick = function() {
        const isActive = eraseBtn.classList.contains('active');
        if (isActive) {
            eraseBtn.classList.remove('active');
            setActiveTool('none');
        } else {
            tools.querySelectorAll('.tool-button').forEach(b => b.classList.remove('active'));
            eraseBtn.classList.add('active');
            setActiveTool('eraser');
        }
    };
    row.appendChild(eraseBtn);
    // Clear all — removes every committed shape immediately.
    const clearBtn = document.createElement('div');
    clearBtn.className = 'tool-button';
    clearBtn.textContent = 'clear';
    clearBtn.onclick = function() {
        _shapes = [];
        saveShapes();
        renderOverlay();
    };
    row.appendChild(clearBtn);
    // Restore active state from storage
    const saved = localStorage.getItem('uvnote-active-tool') || 'none';
    if (saved === 'arrow') {
        arrowBtn.classList.add('active');
        setActiveTool('arrow');
    } else if (saved === 'pen') {
        penBtn.classList.add('active');
        setActiveTool('pen');
    } else if (saved === 'eraser') {
        eraseBtn.classList.add('active');
        setActiveTool('eraser');
    }
    // Color selector
    const colorTitle = document.createElement('div');
    colorTitle.className = 'tools-section-title';
    colorTitle.textContent = 'color';
    tools.appendChild(colorTitle);
    const colorRow = document.createElement('div');
    colorRow.className = 'tools-row color-row';
    tools.appendChild(colorRow);
    const swatchColors = [
        // Primary colors
        '#e53935', '#fb8c00', '#fdd835', '#43a047', '#1e88e5', '#8e24aa',
        // Additional useful colors
        '#ff5722', '#795548', '#607d8b', '#9c27b0',
        // Grayscale
        '#000000', '#424242', '#9e9e9e', '#ffffff'
    ];
    const swatches = [];
    swatchColors.forEach(c => {
        const s = document.createElement('div');
        s.className = 'color-swatch';
        s.style.backgroundColor = c;
        s.title = c;
        s.onclick = () => {
            setStoredArrowColor(c);
            refreshColorUI(c);
        };
        colorRow.appendChild(s);
        swatches.push(s);
    });
    // Free-form color picker for colors outside the swatch set.
    const colorInput = document.createElement('input');
    colorInput.type = 'color';
    colorInput.className = 'color-input';
    colorInput.oninput = () => {
        setStoredArrowColor(colorInput.value);
        refreshColorUI(colorInput.value);
    };
    colorRow.appendChild(colorInput);
    // Mark the swatch matching `selected` (hex or rgb() string) as active
    // and sync the color input to it.
    function refreshColorUI(selected) {
        const selectedHex = selected.startsWith('#') ? selected.toLowerCase() : rgbToHex(selected);
        swatches.forEach((s, i) => {
            const swatchHex = swatchColors[i].toLowerCase();
            if (swatchHex === selectedHex) {
                s.classList.add('selected');
            } else {
                s.classList.remove('selected');
            }
        });
        try {
            colorInput.value = selectedHex;
        } catch (_) {}
    }
    // Convert an "rgb(r, g, b)" / "rgba(...)" string to "#rrggbb";
    // falls back to black when the string does not match.
    function rgbToHex(rgb) {
        const m = rgb.match(/rgba?\((\d+),\s*(\d+),\s*(\d+)/i);
        if (!m) return '#000000';
        const r = parseInt(m[1]).toString(16).padStart(2, '0');
        const g = parseInt(m[2]).toString(16).padStart(2, '0');
        const b = parseInt(m[3]).toString(16).padStart(2, '0');
        return `#${r}${g}${b}`;
    }
    // Restore color selection
    refreshColorUI(getArrowColor());
    // Thickness slider
    const thicknessTitle = document.createElement('div');
    thicknessTitle.className = 'tools-section-title';
    thicknessTitle.textContent = 'thickness';
    tools.appendChild(thicknessTitle);
    const thicknessRow = document.createElement('div');
    thicknessRow.className = 'thickness-row';
    tools.appendChild(thicknessRow);
    const thicknessSlider = document.createElement('input');
    thicknessSlider.type = 'range';
    thicknessSlider.className = 'thickness-slider';
    thicknessSlider.min = '1';
    thicknessSlider.max = '10';
    thicknessSlider.value = getLineThickness();
    const thicknessValue = document.createElement('span');
    thicknessValue.className = 'thickness-value';
    thicknessValue.textContent = thicknessSlider.value + 'px';
    // Persist and echo the thickness as the slider moves.
    thicknessSlider.oninput = function() {
        const value = parseInt(thicknessSlider.value, 10);
        setStoredLineThickness(value);
        thicknessValue.textContent = value + 'px';
    };
    thicknessRow.appendChild(thicknessSlider);
    thicknessRow.appendChild(thicknessValue);
    // Draggable behavior
    makeDraggable(tools, 'uvnote-tools-pos', title);
    return tools;
}
// Build the tools widget, attach it to <body>, and enable the slide-away
// behavior on its title bar.
function initTools() {
    const widget = createToolsWidget();
    document.body.appendChild(widget);
    addSlideToggle(widget, widget.querySelector('.tools-title'));
}
// Remove the tools widget from the DOM if it is present.
function teardownTools() {
    const widget = document.querySelector('.tools-widget');
    if (widget && widget.parentNode) widget.parentNode.removeChild(widget);
}
// --- Canvas overlay for tools ---
let _overlay = null; // <canvas> element the shapes are painted on
let _overlayCtx = null; // cached 2D context for _overlay
let _overlayContainer = null; // scroll container the overlay tracks (window)
let _overlayMode = 'single'; // layout mode; only 'single' remains (two-panel removed)
let _overlayResizeHandler = null; // bound resize listener, kept for teardown
let _overlayScrollHandler = null; // bound scroll listener, kept for teardown
let _drawing = null; // current in-progress arrow {x1,y1,x2,y2} (or pen {type,points})
let _shapes = []; // committed shapes for current mode
let _fadeTimer = null; // timer for fade animation
// localStorage key under which drawn shapes are persisted.
function getOverlayStorageKey() {
    return 'uvnote-shapes';
}
// Load persisted shapes into _shapes; any failure (missing key, corrupt
// JSON, storage unavailable) resets to an empty list.
function loadShapes() {
    let parsed = [];
    try {
        const raw = localStorage.getItem(getOverlayStorageKey());
        if (raw) parsed = JSON.parse(raw);
    } catch (_) {
        parsed = [];
    }
    _shapes = parsed;
}
// Persist the committed shapes; storage failures are ignored on purpose.
function saveShapes() {
    try {
        localStorage.setItem(getOverlayStorageKey(), JSON.stringify(_shapes));
    } catch (_) {}
}
// Age out drawn shapes: a shape starts fading 3s after creation and is
// removed at 5s. Persists and re-renders only when something changed.
function updateShapesFade() {
    const now = Date.now();
    const fadeStartTime = 3000; // Start fading after 3 seconds
    const fadeEndTime = 5000; // Fully gone after 5 seconds
    let needsUpdate = false;
    // Iterate backwards so splice() does not skip elements.
    for (let i = _shapes.length - 1; i >= 0; i--) {
        const shape = _shapes[i];
        if (!shape.createdAt) continue; // Skip old shapes without timestamps
        const age = now - shape.createdAt;
        if (age >= fadeEndTime) {
            // Remove completely faded shapes
            _shapes.splice(i, 1);
            needsUpdate = true;
        } else if (age >= fadeStartTime) {
            // Linearly interpolate opacity from 1 down to 0 over the window.
            const fadeProgress = (age - fadeStartTime) / (fadeEndTime - fadeStartTime);
            const newOpacity = 1 - fadeProgress;
            // Only flag a redraw for visible changes (> 0.01 delta).
            if (Math.abs(shape.opacity - newOpacity) > 0.01) {
                shape.opacity = newOpacity;
                needsUpdate = true;
            }
        }
    }
    if (needsUpdate) {
        saveShapes();
        renderOverlay();
    }
}
// The scrollable content container; always window (two-panel mode removed).
function getContentContainer() {
    return window;
}
// Reset the overlay to single-panel mode tracking window scrolling.
// (Two-panel mode was removed, so these are the only valid values.)
function updateOverlayModeAndContainer() {
    _overlayContainer = window;
    _overlayMode = 'single';
}
// Resize and reposition the overlay canvas to cover its container (the
// full viewport when the container is `window`), then repaint.
// Note: assigning canvas width/height also clears the canvas, hence the
// unconditional renderOverlay() at the end.
function updateOverlayBounds() {
    if (!_overlay) return;
    if (_overlayContainer === window) {
        _overlay.style.position = 'fixed';
        _overlay.style.left = '0px';
        _overlay.style.top = '0px';
        _overlay.width = window.innerWidth;
        _overlay.height = window.innerHeight;
    } else {
        // Legacy path: pin the canvas over an element container's rect.
        const rect = _overlayContainer.getBoundingClientRect();
        _overlay.style.position = 'fixed';
        _overlay.style.left = rect.left + 'px';
        _overlay.style.top = rect.top + 'px';
        _overlay.width = Math.max(0, Math.floor(rect.width));
        _overlay.height = Math.max(0, Math.floor(rect.height));
    }
    renderOverlay();
}
// Horizontal scroll offset of the overlay's container.
function containerScrollLeft() {
    if (_overlayContainer === window) return window.scrollX || 0;
    return _overlayContainer.scrollLeft || 0;
}
// Vertical scroll offset of the overlay's container.
function containerScrollTop() {
    if (_overlayContainer === window) return window.scrollY || 0;
    return _overlayContainer.scrollTop || 0;
}
// Convert viewport (client) coordinates to overlay-canvas-local coordinates.
function toCanvasCoords(clientX, clientY) {
    const { left, top } = _overlay.getBoundingClientRect();
    return { x: clientX - left, y: clientY - top };
}
// Route a pointer-down event to the handler for the active tool.
function onPointerDown(e) {
    switch (document.body.dataset.tool) {
        case 'arrow':
            startDrawArrow(e);
            break;
        case 'pen':
            startDrawPen(e);
            break;
        case 'eraser':
            eraseAt(e);
            break;
    }
}
// Extend the in-progress shape (if any) to follow the pointer.
function onPointerMove(e) {
    if (!_drawing) return;
    const handler = _drawing.type === 'pen' ? moveDrawPen : moveDrawArrow;
    handler(e);
}
// Commit the in-progress shape (if any) when the pointer is released.
function onPointerUp(e) {
    if (!_drawing) return;
    const finish = _drawing.type === 'pen' ? endDrawPen : endDrawArrow;
    finish();
}
// Begin an arrow drag: both endpoints start at the pointer position,
// recorded in document (scroll-adjusted) coordinates.
function startDrawArrow(e) {
    if (document.body.dataset.tool !== 'arrow') return;
    const source = e.touches ? e.touches[0] : e;
    const pt = toCanvasCoords(source.clientX, source.clientY);
    const x = pt.x + containerScrollLeft();
    const y = pt.y + containerScrollTop();
    _drawing = {
        x1: x,
        y1: y,
        x2: x,
        y2: y,
        color: getArrowColor(),
        width: getLineThickness()
    };
    renderOverlay();
    e.preventDefault();
}
// Drag the arrow's head endpoint to the current pointer position.
function moveDrawArrow(e) {
    if (!_drawing) return;
    const source = e.touches ? e.touches[0] : e;
    const pt = toCanvasCoords(source.clientX, source.clientY);
    _drawing.x2 = pt.x + containerScrollLeft();
    _drawing.y2 = pt.y + containerScrollTop();
    renderOverlay();
    e.preventDefault();
}
// Commit the dragged arrow as a persisted shape and clear the draft.
function endDrawArrow() {
    if (!_drawing) return;
    const shape = Object.assign({ type: 'arrow' }, _drawing, {
        createdAt: Date.now(),
        opacity: 1.0
    });
    _shapes.push(shape);
    _drawing = null;
    saveShapes();
    renderOverlay();
}
// Begin a freehand pen stroke at the pointer position (document coords).
function startDrawPen(e) {
    if (document.body.dataset.tool !== 'pen') return;
    const source = e.touches ? e.touches[0] : e;
    const pt = toCanvasCoords(source.clientX, source.clientY);
    const start = {
        x: pt.x + containerScrollLeft(),
        y: pt.y + containerScrollTop()
    };
    _drawing = {
        type: 'pen',
        points: [start],
        color: getArrowColor(),
        width: getLineThickness()
    };
    renderOverlay();
    e.preventDefault();
}
// Append the current pointer position to the in-progress pen stroke.
function moveDrawPen(e) {
    if (!_drawing || _drawing.type !== 'pen') return;
    const source = e.touches ? e.touches[0] : e;
    const pt = toCanvasCoords(source.clientX, source.clientY);
    _drawing.points.push({
        x: pt.x + containerScrollLeft(),
        y: pt.y + containerScrollTop()
    });
    renderOverlay();
    e.preventDefault();
}
// Commit the pen stroke (only when it has at least two points, i.e. the
// pointer actually moved) and clear the draft.
function endDrawPen() {
    if (!_drawing || _drawing.type !== 'pen') return;
    if (_drawing.points.length > 1) {
        const stroke = Object.assign({}, _drawing, {
            createdAt: Date.now(),
            opacity: 1.0
        });
        _shapes.push(stroke);
    }
    _drawing = null;
    saveShapes();
    renderOverlay();
}
function distPointToSegment(px, py, x1, y1, x2, y2) {
  // Euclidean distance from point (px,py) to the closed segment (x1,y1)-(x2,y2).
  const segX = x2 - x1;
  const segY = y2 - y1;
  // Degenerate segment: both endpoints coincide, so measure to the point.
  if (segX === 0 && segY === 0) return Math.hypot(px - x1, py - y1);
  // Project the point onto the segment and clamp to [0,1] so the closest
  // location never falls outside the endpoints.
  let t = ((px - x1) * segX + (py - y1) * segY) / (segX * segX + segY * segY);
  t = Math.min(1, Math.max(0, t));
  const closestX = x1 + t * segX;
  const closestY = y1 + t * segY;
  return Math.hypot(px - closestX, py - closestY);
}
function eraseAt(e) {
  // Delete the topmost (most recently drawn) shape within 10px of the pointer.
  const src = e.touches ? e.touches[0] : e;
  const pt = toCanvasCoords(src.clientX, src.clientY);
  const x = pt.x + containerScrollLeft();
  const y = pt.y + containerScrollTop();
  const HIT_RADIUS = 10; // pixels
  // Distance from the pointer to a shape's nearest line segment.
  const hitDistance = (s) => {
    if (s.type === 'arrow') {
      return distPointToSegment(x, y, s.x1, s.y1, s.x2, s.y2);
    }
    if (s.type === 'pen' && s.points) {
      let best = Infinity;
      for (let j = 1; j < s.points.length; j++) {
        const a = s.points[j - 1];
        const b = s.points[j];
        best = Math.min(best, distPointToSegment(x, y, a.x, a.y, b.x, b.y));
      }
      return best;
    }
    return Infinity;
  };
  // Iterate newest-first so the shape drawn on top wins; erase at most one.
  for (let i = _shapes.length - 1; i >= 0; i--) {
    if (hitDistance(_shapes[i]) <= HIT_RADIUS) {
      _shapes.splice(i, 1);
      saveShapes();
      renderOverlay();
      break;
    }
  }
  e.preventDefault();
}
function drawArrow(ctx, x1, y1, x2, y2, color, width, opacity = 1.0) {
  // Render a straight arrow from (x1,y1) to (x2,y2): a stroked shaft plus a
  // filled triangular head. globalAlpha is saved and restored around the draw.
  const previousAlpha = ctx.globalAlpha;
  ctx.globalAlpha = opacity;
  ctx.strokeStyle = color;
  ctx.fillStyle = color;
  ctx.lineWidth = width;
  ctx.lineCap = 'round';
  ctx.lineJoin = 'round';
  const angle = Math.atan2(y2 - y1, x2 - x1);
  const headLength = Math.min(15 + width * 1.5, 25); // head grows with width, capped at 25px
  const headAngle = Math.PI / 6; // 30 degrees per wing
  // The shaft stops 80% of a head-length short of the tip so the line does
  // not poke through the filled head.
  const shaftEndX = x2 - headLength * 0.8 * Math.cos(angle);
  const shaftEndY = y2 - headLength * 0.8 * Math.sin(angle);
  ctx.beginPath();
  ctx.moveTo(x1, y1);
  ctx.lineTo(shaftEndX, shaftEndY);
  ctx.stroke();
  // Wing points sit one head-length back from the tip, rotated +/- headAngle.
  const wing = (offset) => [
    x2 - headLength * Math.cos(angle + offset),
    y2 - headLength * Math.sin(angle + offset)
  ];
  const [leftX, leftY] = wing(-headAngle);
  const [rightX, rightY] = wing(headAngle);
  ctx.beginPath();
  ctx.moveTo(x2, y2);
  ctx.lineTo(leftX, leftY);
  ctx.lineTo(rightX, rightY);
  ctx.closePath();
  ctx.fill();
  ctx.globalAlpha = previousAlpha;
}
function drawPen(ctx, points, color, width, offX, offY, opacity = 1.0) {
  // Stroke a freehand polyline; points are in content space, (offX,offY) is
  // the current scroll offset subtracted to get canvas coordinates.
  if (!points || points.length < 2) return; // nothing visible with < 2 points
  const previousAlpha = ctx.globalAlpha;
  ctx.globalAlpha = opacity;
  ctx.strokeStyle = color;
  ctx.lineWidth = width;
  ctx.lineCap = 'round';
  ctx.lineJoin = 'round';
  ctx.beginPath();
  points.forEach((p, i) => {
    if (i === 0) {
      ctx.moveTo(p.x - offX, p.y - offY);
    } else {
      ctx.lineTo(p.x - offX, p.y - offY);
    }
  });
  ctx.stroke();
  ctx.globalAlpha = previousAlpha;
}
function renderOverlay() {
  // Repaint the whole overlay: committed shapes first, then the shape
  // currently being drawn (always at full opacity).
  if (!_overlay || !_overlayCtx) return;
  _overlayCtx.clearRect(0, 0, _overlay.width, _overlay.height);
  const offX = containerScrollLeft();
  const offY = containerScrollTop();
  _shapes.forEach((s) => {
    const opacity = (s.opacity === undefined) ? 1.0 : s.opacity;
    if (s.type === 'arrow') {
      drawArrow(_overlayCtx, s.x1 - offX, s.y1 - offY, s.x2 - offX, s.y2 - offY,
                s.color || '#f00', s.width || 2, opacity);
    } else if (s.type === 'pen') {
      drawPen(_overlayCtx, s.points, s.color || '#f00', s.width || 2, offX, offY, opacity);
    }
  });
  if (!_drawing) return;
  if (_drawing.type === 'pen') {
    drawPen(_overlayCtx, _drawing.points, _drawing.color, _drawing.width, offX, offY);
  } else {
    drawArrow(_overlayCtx, _drawing.x1 - offX, _drawing.y1 - offY,
              _drawing.x2 - offX, _drawing.y2 - offY, _drawing.color, _drawing.width);
  }
}
function setOverlayActive(active) {
  // Toggle whether the drawing canvas intercepts pointer input;
  // lazily creates the overlay on first use.
  if (!_overlay) initOverlay();
  _overlay.style.pointerEvents = active ? 'auto' : 'none';
  // Repaint so what's shown stays aligned with the content underneath.
  renderOverlay();
}
function initOverlay() {
  // One-time setup of the full-page drawing canvas (idempotent).
  if (_overlay) return;
  updateOverlayModeAndContainer();
  _overlay = document.createElement('canvas');
  _overlay.className = 'draw-overlay';
  _overlayCtx = _overlay.getContext('2d');
  document.body.appendChild(_overlay);
  updateOverlayBounds();
  loadShapes();
  renderOverlay();
  // Pointer input: start/move on the canvas; release anywhere in the document
  // so a drag that leaves the canvas still terminates cleanly.
  const canvasEvents = [
    ['mousedown', onPointerDown, undefined],
    ['mousemove', onPointerMove, undefined],
    ['touchstart', onPointerDown, { passive: false }],
    ['touchmove', onPointerMove, { passive: false }]
  ];
  for (const [type, handler, opts] of canvasEvents) {
    _overlay.addEventListener(type, handler, opts);
  }
  document.addEventListener('mouseup', onPointerUp);
  document.addEventListener('touchend', onPointerUp);
  _overlayResizeHandler = () => updateOverlayBounds();
  window.addEventListener('resize', _overlayResizeHandler);
  _overlayScrollHandler = () => renderOverlay();
  window.addEventListener('scroll', _overlayScrollHandler);
  // 100ms tick keeps the shape fade-out animation visually smooth.
  _fadeTimer = setInterval(updateShapesFade, 100);
}
function rebindOverlayContainer() {
  // Re-resolve the scroll container (e.g. after a layout mode change) and
  // refresh bounds, shapes, and the scroll handler.
  if (!_overlay) return;
  // Detach the stale handler before re-binding.
  if (_overlayScrollHandler) {
    window.removeEventListener('scroll', _overlayScrollHandler);
  }
  updateOverlayModeAndContainer();
  updateOverlayBounds();
  loadShapes();
  renderOverlay();
  _overlayScrollHandler = () => renderOverlay();
  window.addEventListener('scroll', _overlayScrollHandler);
}
function teardownOverlay() {
  // Tear down the drawing overlay: unbind every handler, stop the fade timer,
  // remove the canvas, and reset module state (idempotent).
  if (!_overlay) return;
  _overlay.removeEventListener('mousedown', onPointerDown);
  _overlay.removeEventListener('mousemove', onPointerMove);
  document.removeEventListener('mouseup', onPointerUp);
  _overlay.removeEventListener('touchstart', onPointerDown);
  _overlay.removeEventListener('touchmove', onPointerMove);
  document.removeEventListener('touchend', onPointerUp);
  if (_overlayResizeHandler) window.removeEventListener('resize', _overlayResizeHandler);
  if (_overlayScrollHandler) {
    // BUGFIX: initOverlay/rebindOverlayContainer always attach the scroll
    // handler to window, but the old teardown removed it from
    // _overlayContainer when that wasn't window — leaking the listener.
    // Remove from window unconditionally; the container removal is kept as a
    // defensive no-op in case another code path attached it there.
    window.removeEventListener('scroll', _overlayScrollHandler);
    if (_overlayContainer && _overlayContainer !== window) {
      _overlayContainer.removeEventListener('scroll', _overlayScrollHandler);
    }
  }
  if (_fadeTimer) {
    clearInterval(_fadeTimer);
    _fadeTimer = null;
  }
  if (_overlay.parentNode) _overlay.parentNode.removeChild(_overlay);
  _overlay = null;
  _overlayCtx = null;
  _overlayContainer = null;
  _overlayResizeHandler = null;
  _overlayScrollHandler = null;
  _drawing = null;
}
function teardownFileExplorer() {
  // Remove the file-explorer widget from the DOM if it exists.
  const explorer = document.querySelector('.file-explorer');
  if (explorer && explorer.parentNode) {
    explorer.parentNode.removeChild(explorer);
  }
}
function runCell(cellId) {
  // Execute a notebook cell server-side (POST /run/<id>) and replace its
  // output area with the returned stdout/stderr. The run button shows a
  // busy state while the request is in flight.
  const btn = document.querySelector('.run-btn[onclick*="' + cellId + '"]');
  const output = document.getElementById('output-' + cellId);
  // BUGFIX: server-returned text was interpolated into innerHTML unescaped,
  // so output containing '<', '&', etc. broke rendering and allowed markup
  // injection. Escape it so program output always renders literally.
  const escapeHtml = (s) => String(s).replace(/[&<>"']/g, (ch) => ({
    '&': '&amp;',
    '<': '&lt;',
    '>': '&gt;',
    '"': '&quot;',
    "'": '&#39;'
  }[ch]));
  const restoreBtn = () => {
    if (btn) { btn.textContent = '▶ run'; btn.disabled = false; }
  };
  if (btn) { btn.textContent = '⏳ running...'; btn.disabled = true; }
  if (output) output.classList.add('output-stale');
  fetch('/run/' + cellId, { method: 'POST' })
    .then((r) => r.json())
    .then((data) => {
      if (output) {
        output.classList.remove('output-stale');
        let html = '';
        if (data.stdout) html += '<div class="cell-stdout">' + escapeHtml(data.stdout) + '</div>';
        if (data.stderr) html += '<div class="cell-stderr">' + escapeHtml(data.stderr) + '</div>';
        output.innerHTML = html;
      }
      restoreBtn();
    })
    .catch((e) => {
      console.error('Run failed:', e);
      if (output) output.classList.remove('output-stale');
      restoreBtn();
    });
}
function copyCell(cellId) {
  // Copy a cell's source code to the clipboard and flash feedback on the
  // Copy button. Prefers the async Clipboard API, falling back to the
  // hidden-textarea + execCommand('copy') technique.
  //
  // CLEANUP: '#code-<id> code' already matches ANY descendant <code> of the
  // cell, so the previous chain of fallback selectors ('... pre code',
  // '... .highlight code', getElementById + querySelector('code')) was
  // unreachable dead code and has been removed, along with debug logging.
  const codeElement = document.querySelector('#code-' + cellId + ' code');
  const btn = document.querySelector('.copy-btn[onclick*="' + cellId + '"]');
  if (!codeElement) {
    console.error('Code element not found for cell:', cellId);
    return;
  }
  if (!btn) {
    console.error('Copy button not found for cell:', cellId);
    return;
  }
  const codeText = codeElement.textContent;
  // Shared success feedback: temporary label + .copied class for 2s.
  function showCopied() {
    btn.textContent = '✓ Copied!';
    btn.classList.add('copied');
    setTimeout(function () {
      btn.textContent = 'Copy';
      btn.classList.remove('copied');
    }, 2000);
  }
  function fallbackCopy() {
    // Off-screen textarea so select()/execCommand can operate without
    // disturbing the page layout.
    const textarea = document.createElement('textarea');
    textarea.value = codeText;
    textarea.style.position = 'absolute';
    textarea.style.left = '-9999px';
    document.body.appendChild(textarea);
    textarea.select();
    try {
      document.execCommand('copy');
      showCopied();
    } catch (err) {
      console.error('Fallback copy failed:', err);
      btn.textContent = 'Copy failed';
      setTimeout(function () {
        btn.textContent = 'Copy';
      }, 2000);
    }
    document.body.removeChild(textarea);
  }
  if (navigator.clipboard && navigator.clipboard.writeText) {
    navigator.clipboard.writeText(codeText).then(showCopied).catch(function (err) {
      // Clipboard API can be denied (permissions, insecure context) — fall back.
      console.warn('Clipboard copy failed:', err);
      fallbackCopy();
    });
  } else {
    fallbackCopy();
  }
}
// Live reload via Server-Sent Events (robust against transient drops)
(function () {
  if (!('EventSource' in window)) {
    console.warn('SSE not supported in this browser');
    return;
  }
  const source = new EventSource('/events');
  let opened = false;
  source.onopen = function () {
    opened = true;
    console.log('SSE connected');
  };
  source.onmessage = function (e) {
    const msg = (e.data || '').trim();
    if (!msg) return;
    console.log('SSE message:', msg);
    // 'loading' is deliberately ignored to avoid premature reload loops.
    if (msg === 'reload' || msg === 'incremental') {
      location.reload();
    }
  };
  source.onerror = function (e) {
    // EventSource reconnects on its own; don't force a page reload here.
    if (opened) console.warn('SSE error after open, retrying...', e);
  };
  window.addEventListener('beforeunload', function () {
    try { source.close(); } catch (_) {}
  });
})();
// One-time page bootstrap once the DOM is parsed: theme icon, minimap,
// file explorer, drawing tools/overlay, and widget layout.
document.addEventListener('DOMContentLoaded', function() {
    updateThemeIcon();
    initMinimap();
    initFileExplorer();
    initTools();
    initOverlay();
    layoutWidgetsStackedBottomRight();
    // Keep the bottom-right widget stack pinned when the window resizes.
    window.addEventListener('resize', layoutWidgetsStackedBottomRight);
});
| </script> | |
| </head> | |
| <body> | |
| <div class="controls"> | |
| <div class="theme-toggle" onclick="toggleTheme()">light</div> | |
| <div class="reset-toggle" onclick="resetLayout()">reset</div> | |
| </div> | |
| <div class="main-content"> | |
| <div class="cell"> | |
| <div class="cell-header"> | |
| <span class="collapse-indicators"> | |
| <span onclick="toggleCode('utils')" style="cursor: pointer;">▼ code</span> | |
| <span onclick="toggleOutput('utils')" style="cursor: pointer;">▼ output</span> | |
| </span> | | |
| Cell: utils | deps: torch, numpy | 30.61s | |
| | <button class="run-btn" onclick="runCell('utils')">▶ run</button> | |
| <button class="copy-btn" onclick="copyCell('utils')">Copy</button> | |
| </div> | |
| <div id="code-utils" class="cell-code"> | |
| <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span> | |
| <span class="normal"> 2</span> | |
| <span class="normal"> 3</span> | |
| <span class="normal"> 4</span> | |
| <span class="normal"> 5</span> | |
| <span class="normal"> 6</span> | |
| <span class="normal"> 7</span> | |
| <span class="normal"> 8</span> | |
| <span class="normal"> 9</span> | |
| <span class="normal">10</span> | |
| <span class="normal">11</span> | |
| <span class="normal">12</span> | |
| <span class="normal">13</span> | |
| <span class="normal">14</span> | |
| <span class="normal">15</span> | |
| <span class="normal">16</span> | |
| <span class="normal">17</span> | |
| <span class="normal">18</span> | |
| <span class="normal">19</span> | |
| <span class="normal">20</span> | |
| <span class="normal">21</span> | |
| <span class="normal">22</span> | |
| <span class="normal">23</span> | |
| <span class="normal">24</span> | |
| <span class="normal">25</span> | |
| <span class="normal">26</span> | |
| <span class="normal">27</span></pre></div></td><td class="code"><div><pre><span></span><span class="sd">"""Simple utilities for running the models."""</span> | |
| <span class="kn">import</span><span class="w"> </span><span class="nn">torch</span> | |
| <span class="k">def</span><span class="w"> </span><span class="nf">to_dtype</span><span class="p">(</span><span class="n">dtype_str</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span> | |
| <span class="w"> </span><span class="sd">"""Convert string to torch dtype."""</span> | |
| <span class="k">if</span> <span class="n">dtype_str</span> <span class="o">==</span> <span class="s2">"float16"</span><span class="p">:</span> | |
| <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">float16</span> | |
| <span class="k">if</span> <span class="n">dtype_str</span> <span class="o">==</span> <span class="s2">"bfloat16"</span><span class="p">:</span> | |
| <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">bfloat16</span> | |
| <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span> | |
| <span class="k">def</span><span class="w"> </span><span class="nf">tensor_stats</span><span class="p">(</span><span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span> | |
| <span class="w"> </span><span class="sd">"""Generate stats string for a tensor."""</span> | |
| <span class="k">return</span> <span class="p">(</span><span class="sa">f</span><span class="s2">"shape=</span><span class="si">{</span><span class="nb">tuple</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="si">}</span><span class="s2">, "</span> | |
| <span class="sa">f</span><span class="s2">"dtype=</span><span class="si">{</span><span class="n">t</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">, "</span> | |
| <span class="sa">f</span><span class="s2">"device=</span><span class="si">{</span><span class="n">t</span><span class="o">.</span><span class="n">device</span><span class="si">}</span><span class="s2">, "</span> | |
| <span class="sa">f</span><span class="s2">"mean=</span><span class="si">{</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">, "</span> | |
| <span class="sa">f</span><span class="s2">"std=</span><span class="si">{</span><span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="k">def</span><span class="w"> </span><span class="nf">set_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span> | |
| <span class="w"> </span><span class="sd">"""Set seeds for reproducibility."""</span> | |
| <span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span> | |
| <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">():</span> | |
| <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span> | |
| <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">manual_seed_all</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span> | |
| <span class="n">torch</span><span class="o">.</span><span class="n">backends</span><span class="o">.</span><span class="n">cudnn</span><span class="o">.</span><span class="n">deterministic</span> <span class="o">=</span> <span class="kc">True</span> | |
| <span class="n">torch</span><span class="o">.</span><span class="n">backends</span><span class="o">.</span><span class="n">cudnn</span><span class="o">.</span><span class="n">benchmark</span> <span class="o">=</span> <span class="kc">False</span> | |
| </pre></div></td></tr></table></div> | |
| </div> | |
| <div id="output-utils" class="cell-output"> | |
| <div class="cell-stderr">Downloading setuptools (1.1MiB) | |
| Downloading nvidia-cufile-cu12 (1.1MiB) | |
| Downloading nvidia-cuda-cupti-cu12 (9.8MiB) | |
| Downloading sympy (6.0MiB) | |
| Downloading torch (846.8MiB) | |
| Downloading nvidia-cublas-cu12 (566.8MiB) | |
| Downloading numpy (15.9MiB) | |
| Downloading nvidia-curand-cu12 (60.7MiB) | |
| Downloading networkx (1.9MiB) | |
| Downloading nvidia-cusolver-cu12 (255.1MiB) | |
| Downloading triton (148.4MiB) | |
| Downloading nvidia-nvjitlink-cu12 (37.4MiB) | |
| Downloading nvidia-cusparselt-cu12 (273.9MiB) | |
| Downloading nvidia-cusparse-cu12 (274.9MiB) | |
| Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB) | |
| Downloading nvidia-cufft-cu12 (184.2MiB) | |
| Downloading nvidia-cudnn-cu12 (674.0MiB) | |
| Downloading nvidia-nccl-cu12 (307.4MiB) | |
| Downloading nvidia-cufile-cu12 | |
| Downloading setuptools | |
| Downloading networkx | |
| Downloading nvidia-cuda-cupti-cu12 | |
| Downloading numpy | |
| Downloading sympy | |
| Downloading nvidia-nvjitlink-cu12 | |
| Downloading nvidia-curand-cu12 | |
| Downloading nvidia-cuda-nvrtc-cu12 | |
| Downloading triton | |
| Downloading nvidia-cufft-cu12 | |
| Downloading nvidia-cusolver-cu12 | |
| Downloading nvidia-cusparselt-cu12 | |
| Downloading nvidia-cusparse-cu12 | |
| Downloading nvidia-nccl-cu12 | |
| Downloading nvidia-cublas-cu12 | |
| Downloading nvidia-cudnn-cu12 | |
| Downloading torch | |
| Installed 26 packages in 234ms | |
| </div> | |
| </div> | |
| </div> | |
| <div class="cell"> | |
| <div class="cell-header"> | |
| <span class="collapse-indicators"> | |
| <span onclick="toggleCode('bench_utils')" style="cursor: pointer;">▼ code</span> | |
| <span onclick="toggleOutput('bench_utils')" style="cursor: pointer;">▼ output</span> | |
| </span> | | |
| Cell: bench_utils | deps: torch, numpy | 31.57s | |
| | <button class="run-btn" onclick="runCell('bench_utils')">▶ run</button> | |
| <button class="copy-btn" onclick="copyCell('bench_utils')">Copy</button> | |
| </div> | |
| <div id="code-bench_utils" class="cell-code"> | |
| <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span> | |
| <span class="normal"> 2</span> | |
| <span class="normal"> 3</span> | |
| <span class="normal"> 4</span> | |
| <span class="normal"> 5</span> | |
| <span class="normal"> 6</span> | |
| <span class="normal"> 7</span> | |
| <span class="normal"> 8</span> | |
| <span class="normal"> 9</span> | |
| <span class="normal"> 10</span> | |
| <span class="normal"> 11</span> | |
| <span class="normal"> 12</span> | |
| <span class="normal"> 13</span> | |
| <span class="normal"> 14</span> | |
| <span class="normal"> 15</span> | |
| <span class="normal"> 16</span> | |
| <span class="normal"> 17</span> | |
| <span class="normal"> 18</span> | |
| <span class="normal"> 19</span> | |
| <span class="normal"> 20</span> | |
| <span class="normal"> 21</span> | |
| <span class="normal"> 22</span> | |
| <span class="normal"> 23</span> | |
| <span class="normal"> 24</span> | |
| <span class="normal"> 25</span> | |
| <span class="normal"> 26</span> | |
| <span class="normal"> 27</span> | |
| <span class="normal"> 28</span> | |
| <span class="normal"> 29</span> | |
| <span class="normal"> 30</span> | |
| <span class="normal"> 31</span> | |
| <span class="normal"> 32</span> | |
| <span class="normal"> 33</span> | |
| <span class="normal"> 34</span> | |
| <span class="normal"> 35</span> | |
| <span class="normal"> 36</span> | |
| <span class="normal"> 37</span> | |
| <span class="normal"> 38</span> | |
| <span class="normal"> 39</span> | |
| <span class="normal"> 40</span> | |
| <span class="normal"> 41</span> | |
| <span class="normal"> 42</span> | |
| <span class="normal"> 43</span> | |
| <span class="normal"> 44</span> | |
| <span class="normal"> 45</span> | |
| <span class="normal"> 46</span> | |
| <span class="normal"> 47</span> | |
| <span class="normal"> 48</span> | |
| <span class="normal"> 49</span> | |
| <span class="normal"> 50</span> | |
| <span class="normal"> 51</span> | |
| <span class="normal"> 52</span> | |
| <span class="normal"> 53</span> | |
| <span class="normal"> 54</span> | |
| <span class="normal"> 55</span> | |
| <span class="normal"> 56</span> | |
| <span class="normal"> 57</span> | |
| <span class="normal"> 58</span> | |
| <span class="normal"> 59</span> | |
| <span class="normal"> 60</span> | |
| <span class="normal"> 61</span> | |
| <span class="normal"> 62</span> | |
| <span class="normal"> 63</span> | |
| <span class="normal"> 64</span> | |
| <span class="normal"> 65</span> | |
| <span class="normal"> 66</span> | |
| <span class="normal"> 67</span> | |
| <span class="normal"> 68</span> | |
| <span class="normal"> 69</span> | |
| <span class="normal"> 70</span> | |
| <span class="normal"> 71</span> | |
| <span class="normal"> 72</span> | |
| <span class="normal"> 73</span> | |
| <span class="normal"> 74</span> | |
| <span class="normal"> 75</span> | |
| <span class="normal"> 76</span> | |
| <span class="normal"> 77</span> | |
| <span class="normal"> 78</span> | |
| <span class="normal"> 79</span> | |
| <span class="normal"> 80</span> | |
| <span class="normal"> 81</span> | |
| <span class="normal"> 82</span> | |
| <span class="normal"> 83</span> | |
| <span class="normal"> 84</span> | |
| <span class="normal"> 85</span> | |
| <span class="normal"> 86</span> | |
| <span class="normal"> 87</span> | |
| <span class="normal"> 88</span> | |
| <span class="normal"> 89</span> | |
| <span class="normal"> 90</span> | |
| <span class="normal"> 91</span> | |
| <span class="normal"> 92</span> | |
| <span class="normal"> 93</span> | |
| <span class="normal"> 94</span> | |
| <span class="normal"> 95</span> | |
| <span class="normal"> 96</span> | |
| <span class="normal"> 97</span> | |
| <span class="normal"> 98</span> | |
| <span class="normal"> 99</span> | |
| <span class="normal">100</span> | |
| <span class="normal">101</span> | |
| <span class="normal">102</span> | |
| <span class="normal">103</span> | |
| <span class="normal">104</span> | |
| <span class="normal">105</span> | |
| <span class="normal">106</span> | |
| <span class="normal">107</span> | |
| <span class="normal">108</span> | |
| <span class="normal">109</span> | |
| <span class="normal">110</span> | |
| <span class="normal">111</span> | |
| <span class="normal">112</span> | |
| <span class="normal">113</span> | |
| <span class="normal">114</span> | |
| <span class="normal">115</span> | |
| <span class="normal">116</span> | |
| <span class="normal">117</span> | |
| <span class="normal">118</span> | |
| <span class="normal">119</span> | |
| <span class="normal">120</span> | |
| <span class="normal">121</span> | |
| <span class="normal">122</span> | |
| <span class="normal">123</span> | |
| <span class="normal">124</span> | |
| <span class="normal">125</span> | |
| <span class="normal">126</span> | |
| <span class="normal">127</span> | |
| <span class="normal">128</span> | |
| <span class="normal">129</span> | |
| <span class="normal">130</span> | |
| <span class="normal">131</span> | |
| <span class="normal">132</span> | |
| <span class="normal">133</span> | |
| <span class="normal">134</span> | |
| <span class="normal">135</span> | |
| <span class="normal">136</span> | |
| <span class="normal">137</span> | |
| <span class="normal">138</span> | |
| <span class="normal">139</span> | |
| <span class="normal">140</span> | |
| <span class="normal">141</span> | |
| <span class="normal">142</span> | |
| <span class="normal">143</span> | |
| <span class="normal">144</span> | |
| <span class="normal">145</span> | |
| <span class="normal">146</span> | |
| <span class="normal">147</span> | |
| <span class="normal">148</span> | |
| <span class="normal">149</span> | |
| <span class="normal">150</span> | |
| <span class="normal">151</span> | |
| <span class="normal">152</span> | |
| <span class="normal">153</span> | |
| <span class="normal">154</span> | |
| <span class="normal">155</span> | |
| <span class="normal">156</span> | |
| <span class="normal">157</span> | |
| <span class="normal">158</span> | |
| <span class="normal">159</span> | |
| <span class="normal">160</span> | |
| <span class="normal">161</span> | |
| <span class="normal">162</span> | |
| <span class="normal">163</span> | |
| <span class="normal">164</span> | |
| <span class="normal">165</span> | |
| <span class="normal">166</span> | |
| <span class="normal">167</span> | |
| <span class="normal">168</span> | |
| <span class="normal">169</span> | |
| <span class="normal">170</span> | |
| <span class="normal">171</span> | |
| <span class="normal">172</span> | |
| <span class="normal">173</span> | |
| <span class="normal">174</span> | |
| <span class="normal">175</span> | |
| <span class="normal">176</span> | |
| <span class="normal">177</span> | |
| <span class="normal">178</span> | |
| <span class="normal">179</span> | |
| <span class="normal">180</span> | |
| <span class="normal">181</span> | |
| <span class="normal">182</span> | |
| <span class="normal">183</span> | |
| <span class="normal">184</span> | |
| <span class="normal">185</span> | |
| <span class="normal">186</span> | |
| <span class="normal">187</span> | |
| <span class="normal">188</span> | |
| <span class="normal">189</span> | |
| <span class="normal">190</span> | |
| <span class="normal">191</span> | |
| <span class="normal">192</span> | |
| <span class="normal">193</span> | |
| <span class="normal">194</span></pre></div></td><td class="code"><div><pre><span></span><span class="sd">"""Reusable benchmarking utilities for performance testing."""</span> | |
| <span class="kn">import</span><span class="w"> </span><span class="nn">time</span> | |
| <span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">contextlib</span><span class="w"> </span><span class="kn">import</span> <span class="n">contextmanager</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">typing</span><span class="w"> </span><span class="kn">import</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Optional</span> | |
| <span class="kn">import</span><span class="w"> </span><span class="nn">torch</span> | |
| <span class="k">def</span><span class="w"> </span><span class="nf">to_dtype</span><span class="p">(</span><span class="n">dtype_str</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span> | |
| <span class="w"> </span><span class="sd">"""Convert string to torch dtype."""</span> | |
| <span class="k">if</span> <span class="n">dtype_str</span> <span class="o">==</span> <span class="s2">"float16"</span><span class="p">:</span> | |
| <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">float16</span> | |
| <span class="k">if</span> <span class="n">dtype_str</span> <span class="o">==</span> <span class="s2">"bfloat16"</span><span class="p">:</span> | |
| <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">bfloat16</span> | |
| <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span> | |
| <span class="k">def</span><span class="w"> </span><span class="nf">_sync</span><span class="p">(</span><span class="n">device</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span> | |
| <span class="w"> </span><span class="sd">"""Synchronize device if CUDA."""</span> | |
| <span class="k">if</span> <span class="n">device</span> <span class="o">==</span> <span class="s2">"cuda"</span><span class="p">:</span> | |
| <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span> | |
| <span class="k">def</span><span class="w"> </span><span class="nf">_compute_stats</span><span class="p">(</span><span class="n">times_s</span><span class="p">,</span> <span class="n">tokens</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">float</span><span class="p">]:</span> | |
| <span class="w"> </span><span class="sd">"""Compute comprehensive latency and throughput statistics."""</span> | |
| <span class="n">lat_ms</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">t</span> <span class="o">*</span> <span class="mf">1000.0</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">times_s</span><span class="p">])</span> | |
| <span class="n">lat_ms_sorted</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">lat_ms</span><span class="p">)</span> | |
| <span class="n">n</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">lat_ms</span><span class="p">)</span> | |
| <span class="n">stats</span> <span class="o">=</span> <span class="p">{</span> | |
| <span class="s2">"avg_ms"</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">lat_ms</span><span class="p">),</span> | |
| <span class="s2">"min_ms"</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">lat_ms</span><span class="p">),</span> | |
| <span class="s2">"max_ms"</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">lat_ms</span><span class="p">),</span> | |
| <span class="s2">"std_ms"</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">std</span><span class="p">(</span><span class="n">lat_ms</span><span class="p">),</span> | |
| <span class="s2">"p50_ms"</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">percentile</span><span class="p">(</span><span class="n">lat_ms</span><span class="p">,</span> <span class="mi">50</span><span class="p">),</span> | |
| <span class="s2">"p95_ms"</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">percentile</span><span class="p">(</span><span class="n">lat_ms</span><span class="p">,</span> <span class="mi">95</span><span class="p">),</span> | |
| <span class="s2">"p99_ms"</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">percentile</span><span class="p">(</span><span class="n">lat_ms</span><span class="p">,</span> <span class="mi">99</span><span class="p">),</span> | |
| <span class="s2">"num_iters"</span><span class="p">:</span> <span class="n">n</span> | |
| <span class="p">}</span> | |
| <span class="k">if</span> <span class="n">tokens</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">n</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span> | |
| <span class="n">avg_s</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">times_s</span><span class="p">)</span> | |
| <span class="n">stats</span><span class="p">[</span><span class="s2">"tokens_per_s"</span><span class="p">]</span> <span class="o">=</span> <span class="n">tokens</span> <span class="o">/</span> <span class="n">avg_s</span> <span class="k">if</span> <span class="n">avg_s</span> <span class="o">></span> <span class="mi">0</span> <span class="k">else</span> <span class="nb">float</span><span class="p">(</span><span class="s2">"inf"</span><span class="p">)</span> | |
| <span class="n">stats</span><span class="p">[</span><span class="s2">"throughput_variance"</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">std</span><span class="p">([</span><span class="n">tokens</span> <span class="o">/</span> <span class="n">t</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">times_s</span> <span class="k">if</span> <span class="n">t</span> <span class="o">></span> <span class="mi">0</span><span class="p">])</span> | |
| <span class="k">return</span> <span class="n">stats</span> | |
| <span class="k">def</span><span class="w"> </span><span class="nf">_format_timing_stats</span><span class="p">(</span><span class="n">stats</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span> <span class="n">tokens</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span> | |
| <span class="w"> </span><span class="sd">"""Format timing statistics for display."""</span> | |
| <span class="n">lines</span> <span class="o">=</span> <span class="p">[</span> | |
| <span class="s2">"</span><span class="se">\n</span><span class="s2">━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━"</span><span class="p">,</span> | |
| <span class="sa">f</span><span class="s2">"Iterations: </span><span class="si">{</span><span class="n">stats</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'num_iters'</span><span class="p">,</span><span class="w"> </span><span class="mi">0</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span><span class="p">,</span> | |
| <span class="s2">"</span><span class="se">\n</span><span class="s2">Latency Statistics:"</span><span class="p">,</span> | |
| <span class="sa">f</span><span class="s2">" Average: </span><span class="si">{</span><span class="n">stats</span><span class="p">[</span><span class="s1">'avg_ms'</span><span class="p">]</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> ms"</span><span class="p">,</span> | |
| <span class="sa">f</span><span class="s2">" Min: </span><span class="si">{</span><span class="n">stats</span><span class="p">[</span><span class="s1">'min_ms'</span><span class="p">]</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> ms"</span><span class="p">,</span> | |
| <span class="sa">f</span><span class="s2">" Max: </span><span class="si">{</span><span class="n">stats</span><span class="p">[</span><span class="s1">'max_ms'</span><span class="p">]</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> ms"</span><span class="p">,</span> | |
| <span class="sa">f</span><span class="s2">" Std Dev: </span><span class="si">{</span><span class="n">stats</span><span class="p">[</span><span class="s1">'std_ms'</span><span class="p">]</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> ms"</span><span class="p">,</span> | |
| <span class="s2">"</span><span class="se">\n</span><span class="s2">Percentiles:"</span><span class="p">,</span> | |
| <span class="sa">f</span><span class="s2">" P50 (median): </span><span class="si">{</span><span class="n">stats</span><span class="p">[</span><span class="s1">'p50_ms'</span><span class="p">]</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> ms"</span><span class="p">,</span> | |
| <span class="sa">f</span><span class="s2">" P95: </span><span class="si">{</span><span class="n">stats</span><span class="p">[</span><span class="s1">'p95_ms'</span><span class="p">]</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> ms"</span><span class="p">,</span> | |
| <span class="sa">f</span><span class="s2">" P99: </span><span class="si">{</span><span class="n">stats</span><span class="p">[</span><span class="s1">'p99_ms'</span><span class="p">]</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> ms"</span><span class="p">,</span> | |
| <span class="p">]</span> | |
| <span class="k">if</span> <span class="n">tokens</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="s1">'tokens_per_s'</span> <span class="ow">in</span> <span class="n">stats</span><span class="p">:</span> | |
| <span class="n">lines</span><span class="o">.</span><span class="n">extend</span><span class="p">([</span> | |
| <span class="s2">"</span><span class="se">\n</span><span class="s2">Throughput:"</span><span class="p">,</span> | |
| <span class="sa">f</span><span class="s2">" Tokens/sec: </span><span class="si">{</span><span class="n">stats</span><span class="p">[</span><span class="s1">'tokens_per_s'</span><span class="p">]</span><span class="si">:</span><span class="s2">.1f</span><span class="si">}</span><span class="s2">"</span><span class="p">,</span> | |
| <span class="sa">f</span><span class="s2">" Std Dev: </span><span class="si">{</span><span class="n">stats</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'throughput_variance'</span><span class="p">,</span><span class="w"> </span><span class="mi">0</span><span class="p">)</span><span class="si">:</span><span class="s2">.1f</span><span class="si">}</span><span class="s2">"</span><span class="p">,</span> | |
| <span class="p">])</span> | |
| <span class="n">lines</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="s2">"━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"</span><span class="p">)</span> | |
| <span class="k">return</span> <span class="s2">"</span><span class="se">\n</span><span class="s2">"</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">lines</span><span class="p">)</span> | |
| <span class="k">def</span><span class="w"> </span><span class="nf">_bench_engine</span><span class="p">(</span> | |
| <span class="n">call</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[[],</span> <span class="n">Any</span><span class="p">],</span> <span class="o">*</span><span class="p">,</span> <span class="n">warmup</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">iters</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">dtype</span> | |
| <span class="p">)</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="nb">list</span><span class="p">]:</span> | |
| <span class="w"> </span><span class="sd">"""Core benchmarking engine with warmup and timing."""</span> | |
| <span class="n">use_autocast</span> <span class="o">=</span> <span class="n">device</span> <span class="o">==</span> <span class="s2">"cuda"</span> <span class="ow">and</span> <span class="n">dtype</span> <span class="ow">in</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">bfloat16</span><span class="p">)</span> | |
| <span class="c1"># Warmup phase</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="se">\n</span><span class="s2">Warming up (</span><span class="si">{</span><span class="n">warmup</span><span class="si">}</span><span class="s2"> iterations)..."</span><span class="p">)</span> | |
| <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">inference_mode</span><span class="p">():</span> | |
| <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">max</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">warmup</span><span class="p">)):</span> | |
| <span class="k">if</span> <span class="n">use_autocast</span><span class="p">:</span> | |
| <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">autocast</span><span class="p">(</span><span class="n">device_type</span><span class="o">=</span><span class="s2">"cuda"</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">):</span> | |
| <span class="n">_</span> <span class="o">=</span> <span class="n">call</span><span class="p">()</span> | |
| <span class="k">else</span><span class="p">:</span> | |
| <span class="n">_</span> <span class="o">=</span> <span class="n">call</span><span class="p">()</span> | |
| <span class="n">_sync</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> | |
| <span class="c1"># Measurement phase</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Benchmarking (</span><span class="si">{</span><span class="n">iters</span><span class="si">}</span><span class="s2"> iterations)..."</span><span class="p">)</span> | |
| <span class="n">times_s</span> <span class="o">=</span> <span class="p">[]</span> | |
| <span class="n">last</span> <span class="o">=</span> <span class="kc">None</span> | |
| <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">inference_mode</span><span class="p">():</span> | |
| <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">max</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">iters</span><span class="p">)):</span> | |
| <span class="n">start</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">perf_counter</span><span class="p">()</span> | |
| <span class="k">if</span> <span class="n">use_autocast</span><span class="p">:</span> | |
| <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">autocast</span><span class="p">(</span><span class="n">device_type</span><span class="o">=</span><span class="s2">"cuda"</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">):</span> | |
| <span class="n">last</span> <span class="o">=</span> <span class="n">call</span><span class="p">()</span> | |
| <span class="k">else</span><span class="p">:</span> | |
| <span class="n">last</span> <span class="o">=</span> <span class="n">call</span><span class="p">()</span> | |
| <span class="n">_sync</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> | |
| <span class="n">end</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">perf_counter</span><span class="p">()</span> | |
| <span class="n">times_s</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">end</span> <span class="o">-</span> <span class="n">start</span><span class="p">)</span> | |
| <span class="c1"># Progress indicator every 20% of iterations</span> | |
| <span class="k">if</span> <span class="n">i</span> <span class="o">></span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">i</span> <span class="o">%</span> <span class="nb">max</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">iters</span> <span class="o">//</span> <span class="mi">5</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> | |
| <span class="n">pct</span> <span class="o">=</span> <span class="p">(</span><span class="n">i</span> <span class="o">/</span> <span class="n">iters</span><span class="p">)</span> <span class="o">*</span> <span class="mi">100</span> | |
| <span class="n">avg_so_far</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">times_s</span><span class="p">[:</span><span class="n">i</span><span class="p">])</span> <span class="o">*</span> <span class="mi">1000</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">" Progress: </span><span class="si">{</span><span class="n">pct</span><span class="si">:</span><span class="s2">.0f</span><span class="si">}</span><span class="s2">% complete (avg: </span><span class="si">{</span><span class="n">avg_so_far</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> ms)"</span><span class="p">)</span> | |
| <span class="k">return</span> <span class="n">last</span><span class="p">,</span> <span class="n">times_s</span> | |
| <span class="k">def</span><span class="w"> </span><span class="nf">tensor_stats</span><span class="p">(</span><span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span> | |
| <span class="w"> </span><span class="sd">"""Generate comprehensive stats string for a tensor."""</span> | |
| <span class="k">return</span> <span class="p">(</span><span class="sa">f</span><span class="s2">"shape=</span><span class="si">{</span><span class="nb">tuple</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="si">}</span><span class="s2">, "</span> | |
| <span class="sa">f</span><span class="s2">"dtype=</span><span class="si">{</span><span class="n">t</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">, "</span> | |
| <span class="sa">f</span><span class="s2">"device=</span><span class="si">{</span><span class="n">t</span><span class="o">.</span><span class="n">device</span><span class="si">}</span><span class="s2">, "</span> | |
| <span class="sa">f</span><span class="s2">"range=[</span><span class="si">{</span><span class="n">t</span><span class="o">.</span><span class="n">min</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">, </span><span class="si">{</span><span class="n">t</span><span class="o">.</span><span class="n">max</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">], "</span> | |
| <span class="sa">f</span><span class="s2">"mean=</span><span class="si">{</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">, "</span> | |
| <span class="sa">f</span><span class="s2">"std=</span><span class="si">{</span><span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">, "</span> | |
| <span class="sa">f</span><span class="s2">"norm=</span><span class="si">{</span><span class="n">t</span><span class="o">.</span><span class="n">norm</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="nd">@contextmanager</span> | |
| <span class="k">def</span><span class="w"> </span><span class="nf">bench_context</span><span class="p">(</span> | |
| <span class="o">*</span><span class="p">,</span> <span class="n">warmup</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">25</span><span class="p">,</span> <span class="n">iters</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"cuda"</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">tokens</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">verbose</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="n">save_json</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span> | |
| <span class="p">):</span> | |
| <span class="w"> </span><span class="sd">"""Context that yields a runner: runner(fn, *args, **kwargs) -> (result, stats)."""</span> | |
| <span class="k">def</span><span class="w"> </span><span class="nf">runner</span><span class="p">(</span><span class="n">fn</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">float</span><span class="p">]]:</span> | |
| <span class="c1"># Log configuration</span> | |
| <span class="k">if</span> <span class="n">verbose</span><span class="p">:</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="se">\n</span><span class="s2">┌─ Benchmark Configuration ─────────────────────────────┐"</span><span class="p">)</span> | |
| <span class="c1"># print(f"│ Device: {device:<15} Dtype: {dtype} │")</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"│ Warmup: </span><span class="si">{</span><span class="n">warmup</span><span class="si">:</span><span class="s2"><15</span><span class="si">}</span><span class="s2"> Iters: </span><span class="si">{</span><span class="n">iters</span><span class="si">}</span><span class="s2"> │"</span><span class="p">)</span> | |
| <span class="k">if</span> <span class="n">tokens</span><span class="p">:</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"│ Tokens: </span><span class="si">{</span><span class="n">tokens</span><span class="si">}</span><span class="s2"> │"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"└────────────────────────────────────────────────────────┘"</span><span class="p">)</span> | |
| <span class="c1"># Log input if it's a tensor</span> | |
| <span class="k">if</span> <span class="n">verbose</span> <span class="ow">and</span> <span class="n">args</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">args</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="se">\n</span><span class="s2">Input: </span><span class="si">{</span><span class="n">tensor_stats</span><span class="p">(</span><span class="n">args</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="n">call</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">fn</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> | |
| <span class="n">result</span><span class="p">,</span> <span class="n">times_s</span> <span class="o">=</span> <span class="n">_bench_engine</span><span class="p">(</span><span class="n">call</span><span class="p">,</span> <span class="n">warmup</span><span class="o">=</span><span class="n">warmup</span><span class="p">,</span> <span class="n">iters</span><span class="o">=</span><span class="n">iters</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> | |
| <span class="c1"># Log output if it's a tensor or tuple with tensors</span> | |
| <span class="k">if</span> <span class="n">verbose</span><span class="p">:</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="s2">"</span><span class="se">\n</span><span class="s2">Output tensors:"</span><span class="p">)</span> | |
| <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">result</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">" Primary: </span><span class="si">{</span><span class="n">tensor_stats</span><span class="p">(</span><span class="n">result</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">result</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">result</span><span class="p">)</span> <span class="o">></span> <span class="mi">0</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">result</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">" Primary: </span><span class="si">{</span><span class="n">tensor_stats</span><span class="p">(</span><span class="n">result</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">result</span><span class="p">)</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span> | |
| <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">result</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">" Auxiliary: </span><span class="si">{</span><span class="n">tensor_stats</span><span class="p">(</span><span class="n">result</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="k">else</span><span class="p">:</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">" Auxiliary: </span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">result</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="c1"># Compute and display statistics</span> | |
| <span class="n">stats</span> <span class="o">=</span> <span class="n">_compute_stats</span><span class="p">(</span><span class="n">times_s</span><span class="p">,</span> <span class="n">tokens</span><span class="o">=</span><span class="n">tokens</span><span class="p">)</span> | |
| <span class="k">if</span> <span class="n">verbose</span><span class="p">:</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="n">_format_timing_stats</span><span class="p">(</span><span class="n">stats</span><span class="p">,</span> <span class="n">tokens</span><span class="p">))</span> | |
| <span class="c1"># Save to JSON if requested</span> | |
| <span class="k">if</span> <span class="n">save_json</span><span class="p">:</span> | |
| <span class="kn">import</span><span class="w"> </span><span class="nn">json</span> | |
| <span class="n">json_data</span> <span class="o">=</span> <span class="p">{</span> | |
| <span class="s2">"implementation"</span><span class="p">:</span> <span class="n">save_json</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">".json"</span><span class="p">,</span> <span class="s2">""</span><span class="p">),</span> | |
| <span class="s2">"config"</span><span class="p">:</span> <span class="p">{</span> | |
| <span class="s2">"warmup"</span><span class="p">:</span> <span class="n">warmup</span><span class="p">,</span> | |
| <span class="s2">"iters"</span><span class="p">:</span> <span class="n">iters</span><span class="p">,</span> | |
| <span class="s2">"device"</span><span class="p">:</span> <span class="nb">str</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="c1"># Convert device to string</span> | |
| <span class="s2">"dtype"</span><span class="p">:</span> <span class="nb">str</span><span class="p">(</span><span class="n">dtype</span><span class="p">),</span> | |
| <span class="s2">"tokens"</span><span class="p">:</span> <span class="n">tokens</span> | |
| <span class="p">},</span> | |
| <span class="s2">"stats"</span><span class="p">:</span> <span class="n">stats</span><span class="p">,</span> | |
| <span class="s2">"output_sum"</span><span class="p">:</span> <span class="nb">float</span><span class="p">(</span><span class="n">result</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">result</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">result</span><span class="p">)</span> <span class="o">></span> <span class="mi">0</span> <span class="k">else</span> <span class="nb">float</span><span class="p">(</span><span class="n">result</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">result</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="k">else</span> <span class="kc">None</span> | |
| <span class="p">}</span> | |
| <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">save_json</span><span class="p">,</span> <span class="s1">'w'</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span> | |
| <span class="n">json</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="n">json_data</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">indent</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span> | |
| <span class="k">if</span> <span class="n">verbose</span><span class="p">:</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="se">\n</span><span class="s2">Saved benchmark results to </span><span class="si">{</span><span class="n">save_json</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="k">return</span> <span class="n">result</span><span class="p">,</span> <span class="n">stats</span> | |
| <span class="k">yield</span> <span class="n">runner</span> | |
| <span class="k">def</span><span class="w"> </span><span class="nf">set_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span> | |
| <span class="w"> </span><span class="sd">"""Set seeds for reproducibility."""</span> | |
| <span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span> | |
| <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">():</span> | |
| <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span> | |
| <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">manual_seed_all</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span> | |
| <span class="n">torch</span><span class="o">.</span><span class="n">backends</span><span class="o">.</span><span class="n">cudnn</span><span class="o">.</span><span class="n">deterministic</span> <span class="o">=</span> <span class="kc">True</span> | |
| <span class="n">torch</span><span class="o">.</span><span class="n">backends</span><span class="o">.</span><span class="n">cudnn</span><span class="o">.</span><span class="n">benchmark</span> <span class="o">=</span> <span class="kc">False</span> | |
| </pre></div></td></tr></table></div> | |
| </div> | |
| <div id="output-bench_utils" class="cell-output"> | |
| <div class="cell-stderr">Downloading networkx (1.9MiB) | |
| Downloading numpy (15.9MiB) | |
| Downloading nvidia-cusparselt-cu12 (273.9MiB) | |
| Downloading nvidia-cuda-cupti-cu12 (9.8MiB) | |
| Downloading nvidia-cufile-cu12 (1.1MiB) | |
| Downloading setuptools (1.1MiB) | |
| Downloading nvidia-cudnn-cu12 (674.0MiB) | |
| Downloading sympy (6.0MiB) | |
| Downloading nvidia-cufft-cu12 (184.2MiB) | |
| Downloading nvidia-cublas-cu12 (566.8MiB) | |
| Downloading nvidia-cusolver-cu12 (255.1MiB) | |
| Downloading nvidia-nvjitlink-cu12 (37.4MiB) | |
| Downloading nvidia-curand-cu12 (60.7MiB) | |
| Downloading torch (846.8MiB) | |
| Downloading nvidia-nccl-cu12 (307.4MiB) | |
| Downloading nvidia-cusparse-cu12 (274.9MiB) | |
| Downloading triton (148.4MiB) | |
| Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB) | |
| Downloading nvidia-cufile-cu12 | |
| Downloading setuptools | |
| Downloading networkx | |
| Downloading nvidia-cuda-cupti-cu12 | |
| Downloading numpy | |
| Downloading nvidia-nvjitlink-cu12 | |
| Downloading sympy | |
| Downloading nvidia-curand-cu12 | |
| Downloading nvidia-cuda-nvrtc-cu12 | |
| Downloading triton | |
| Downloading nvidia-cufft-cu12 | |
| Downloading nvidia-cusolver-cu12 | |
| Downloading nvidia-cusparselt-cu12 | |
| Downloading nvidia-cusparse-cu12 | |
| Downloading nvidia-nccl-cu12 | |
| Downloading nvidia-cublas-cu12 | |
| Downloading nvidia-cudnn-cu12 | |
| Downloading torch | |
| Installed 26 packages in 234ms | |
| </div> | |
| </div> | |
| </div> | |
| <p>This notebook runs the Yamoe and Binned MoE implementations once each with identical inputs to verify they produce consistent outputs.</p> | |
| <div class="cell"> | |
| <div class="cell-header"> | |
| <span class="collapse-indicators"> | |
| <span onclick="toggleCode('config')" style="cursor: pointer;">▼ code</span> | |
| <span onclick="toggleOutput('config')" style="cursor: pointer;">▼ output</span> | |
| </span> | | |
| Cell: config | deps: torch, numpy | 37.88s | |
| | <button class="run-btn" onclick="runCell('config')">▶ run</button> | |
| <button class="copy-btn" onclick="copyCell('config')">Copy</button> | |
| </div> | |
| <div id="code-config" class="cell-code"> | |
| <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span> | |
| <span class="normal"> 2</span> | |
| <span class="normal"> 3</span> | |
| <span class="normal"> 4</span> | |
| <span class="normal"> 5</span> | |
| <span class="normal"> 6</span> | |
| <span class="normal"> 7</span> | |
| <span class="normal"> 8</span> | |
| <span class="normal"> 9</span> | |
| <span class="normal">10</span> | |
| <span class="normal">11</span> | |
| <span class="normal">12</span> | |
| <span class="normal">13</span> | |
| <span class="normal">14</span> | |
| <span class="normal">15</span> | |
| <span class="normal">16</span> | |
| <span class="normal">17</span> | |
| <span class="normal">18</span> | |
| <span class="normal">19</span> | |
| <span class="normal">20</span></pre></div></td><td class="code"><div><pre><span></span><span class="sd">"""Shared configuration for both implementations."""</span> | |
| <span class="kn">import</span><span class="w"> </span><span class="nn">torch</span> | |
| <span class="c1"># Model configuration</span> | |
| <span class="n">NUM_EXPERTS</span> <span class="o">=</span> <span class="mi">128</span> | |
| <span class="n">HIDDEN_SIZE</span> <span class="o">=</span> <span class="mi">1152</span> | |
| <span class="n">INTERMEDIATE_SIZE</span> <span class="o">=</span> <span class="mi">3072</span> | |
| <span class="n">TOP_K</span> <span class="o">=</span> <span class="mi">4</span> | |
| <span class="c1"># Input configuration</span> | |
| <span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">1</span> | |
| <span class="n">SEQ_LEN</span> <span class="o">=</span> <span class="mi">100</span> | |
| <span class="n">DTYPE</span> <span class="o">=</span> <span class="s2">"float32"</span> | |
| <span class="n">DEVICE</span> <span class="o">=</span> <span class="s2">"cuda"</span> <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span class="k">else</span> <span class="s2">"cpu"</span> | |
| <span class="c1"># Seeds for reproducibility</span> | |
| <span class="n">WEIGHT_SEED</span> <span class="o">=</span> <span class="mi">999</span> | |
| <span class="n">EXPERT_SEED</span> <span class="o">=</span> <span class="mi">777</span> | |
| <span class="n">INPUT_SEED</span> <span class="o">=</span> <span class="mi">123</span> | |
| <span class="n">GENERAL_SEED</span> <span class="o">=</span> <span class="mi">42</span> | |
| </pre></div></td></tr></table></div> | |
| </div> | |
| <div id="output-config" class="cell-output"> | |
| <div class="cell-stderr">Downloading nvidia-cusparselt-cu12 (273.9MiB) | |
| Downloading triton (148.4MiB) | |
| Downloading nvidia-nccl-cu12 (307.4MiB) | |
| Downloading nvidia-cuda-cupti-cu12 (9.8MiB) | |
| Downloading nvidia-cublas-cu12 (566.8MiB) | |
| Downloading networkx (1.9MiB) | |
| Downloading nvidia-nvjitlink-cu12 (37.4MiB) | |
| Downloading nvidia-cufft-cu12 (184.2MiB) | |
| Downloading nvidia-curand-cu12 (60.7MiB) | |
| Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB) | |
| Downloading nvidia-cusparse-cu12 (274.9MiB) | |
| Downloading nvidia-cufile-cu12 (1.1MiB) | |
| Downloading nvidia-cudnn-cu12 (674.0MiB) | |
| Downloading setuptools (1.1MiB) | |
| Downloading sympy (6.0MiB) | |
| Downloading torch (846.8MiB) | |
| Downloading nvidia-cusolver-cu12 (255.1MiB) | |
| Downloading numpy (15.9MiB) | |
| Downloading nvidia-cufile-cu12 | |
| Downloading setuptools | |
| Downloading networkx | |
| Downloading nvidia-cuda-cupti-cu12 | |
| Downloading numpy | |
| Downloading sympy | |
| Downloading nvidia-nvjitlink-cu12 | |
| Downloading nvidia-curand-cu12 | |
| Downloading nvidia-cuda-nvrtc-cu12 | |
| Downloading triton | |
| Downloading nvidia-cufft-cu12 | |
| Downloading nvidia-cusolver-cu12 | |
| Downloading nvidia-cusparse-cu12 | |
| Downloading nvidia-cusparselt-cu12 | |
| Downloading nvidia-nccl-cu12 | |
| Downloading nvidia-cublas-cu12 | |
| Downloading nvidia-cudnn-cu12 | |
| Downloading torch | |
| Installed 26 packages in 225ms | |
| </div> | |
| </div> | |
| </div> | |
| <div class="cell"> | |
| <div class="cell-header"> | |
| <span class="collapse-indicators"> | |
| <span onclick="toggleCode('save_data')" style="cursor: pointer;">▼ code</span> | |
| <span onclick="toggleOutput('save_data')" style="cursor: pointer;">▼ output</span> | |
| </span> | | |
| Cell: save_data | deps: torch, numpy | 38.59s | |
| | <button class="run-btn" onclick="runCell('save_data')">▶ run</button> | |
| <button class="copy-btn" onclick="copyCell('save_data')">Copy</button> | |
| </div> | |
| <div id="code-save_data" class="cell-code"> | |
| <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span> | |
| <span class="normal"> 2</span> | |
| <span class="normal"> 3</span> | |
| <span class="normal"> 4</span> | |
| <span class="normal"> 5</span> | |
| <span class="normal"> 6</span> | |
| <span class="normal"> 7</span> | |
| <span class="normal"> 8</span> | |
| <span class="normal"> 9</span> | |
| <span class="normal">10</span> | |
| <span class="normal">11</span> | |
| <span class="normal">12</span> | |
| <span class="normal">13</span> | |
| <span class="normal">14</span> | |
| <span class="normal">15</span> | |
| <span class="normal">16</span> | |
| <span class="normal">17</span> | |
| <span class="normal">18</span> | |
| <span class="normal">19</span> | |
| <span class="normal">20</span> | |
| <span class="normal">21</span> | |
| <span class="normal">22</span> | |
| <span class="normal">23</span> | |
| <span class="normal">24</span> | |
| <span class="normal">25</span> | |
| <span class="normal">26</span> | |
| <span class="normal">27</span> | |
| <span class="normal">28</span> | |
| <span class="normal">29</span> | |
| <span class="normal">30</span> | |
| <span class="normal">31</span> | |
| <span class="normal">32</span> | |
| <span class="normal">33</span> | |
| <span class="normal">34</span> | |
| <span class="normal">35</span></pre></div></td><td class="code"><div><pre><span></span><span class="sd">"""</span> | |
| <span class="sd">Generate deterministic shared weights once and save as artifacts so</span> | |
| <span class="sd">both implementations load identical parameters.</span> | |
| <span class="sd">"""</span> | |
| <span class="kn">import</span><span class="w"> </span><span class="nn">torch</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">config</span><span class="w"> </span><span class="kn">import</span> <span class="n">NUM_EXPERTS</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">WEIGHT_SEED</span><span class="p">,</span> <span class="n">EXPERT_SEED</span> | |
| <span class="k">def</span><span class="w"> </span><span class="nf">save_shared_weights</span><span class="p">():</span> | |
| <span class="c1"># Router: Kaiming uniform as used by both, bias zeros</span> | |
| <span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">WEIGHT_SEED</span><span class="p">)</span> | |
| <span class="n">router_weight</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">NUM_EXPERTS</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">)</span> | |
| <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">kaiming_uniform_</span><span class="p">(</span><span class="n">router_weight</span><span class="p">)</span> | |
| <span class="n">router_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">NUM_EXPERTS</span><span class="p">)</span> | |
| <span class="c1"># Experts: normal(0, 0.02), biases zeros</span> | |
| <span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">EXPERT_SEED</span><span class="p">)</span> | |
| <span class="n">gate_up_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">NUM_EXPERTS</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">HIDDEN_SIZE</span><span class="p">)</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">0.02</span><span class="p">)</span> | |
| <span class="n">gate_up_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">NUM_EXPERTS</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">HIDDEN_SIZE</span><span class="p">)</span> | |
| <span class="n">down_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">NUM_EXPERTS</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">)</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">0.02</span><span class="p">)</span> | |
| <span class="n">down_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">NUM_EXPERTS</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">)</span> | |
| <span class="c1"># Save artifacts</span> | |
| <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">router_weight</span><span class="p">,</span> <span class="s1">'router_weight.pt'</span><span class="p">)</span> | |
| <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">router_bias</span><span class="p">,</span> <span class="s1">'router_bias.pt'</span><span class="p">)</span> | |
| <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">gate_up_proj</span><span class="p">,</span> <span class="s1">'gate_up_proj.pt'</span><span class="p">)</span> | |
| <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">gate_up_proj_bias</span><span class="p">,</span> <span class="s1">'gate_up_proj_bias.pt'</span><span class="p">)</span> | |
| <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">down_proj</span><span class="p">,</span> <span class="s1">'down_proj.pt'</span><span class="p">)</span> | |
| <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">down_proj_bias</span><span class="p">,</span> <span class="s1">'down_proj_bias.pt'</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="s2">"Saved shared weights to artifacts"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Router weight sum: </span><span class="si">{</span><span class="n">router_weight</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Gate/up sum: </span><span class="si">{</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Down sum: </span><span class="si">{</span><span class="n">down_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="n">save_shared_weights</span><span class="p">()</span> | |
| </pre></div></td></tr></table></div> | |
| </div> | |
| <div id="output-save_data" class="cell-output"> | |
| <div class="cell-stdout">Saved shared weights to artifacts | |
| Router weight sum: 12.588732 | |
| Gate/up sum: 1026.601807 | |
| Down sum: 206.729263 | |
| </div> | |
| <div class="cell-stderr">Downloading nvidia-cufile-cu12 (1.1MiB) | |
| Downloading nvidia-cuda-cupti-cu12 (9.8MiB) | |
| Downloading networkx (1.9MiB) | |
| Downloading nvidia-cudnn-cu12 (674.0MiB) | |
| Downloading nvidia-curand-cu12 (60.7MiB) | |
| Downloading nvidia-cusolver-cu12 (255.1MiB) | |
| Downloading nvidia-nccl-cu12 (307.4MiB) | |
| Downloading nvidia-cusparse-cu12 (274.9MiB) | |
| Downloading nvidia-nvjitlink-cu12 (37.4MiB) | |
| Downloading nvidia-cublas-cu12 (566.8MiB) | |
| Downloading triton (148.4MiB) | |
| Downloading nvidia-cufft-cu12 (184.2MiB) | |
| Downloading sympy (6.0MiB) | |
| Downloading numpy (15.9MiB) | |
| Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB) | |
| Downloading setuptools (1.1MiB) | |
| Downloading torch (846.8MiB) | |
| Downloading nvidia-cusparselt-cu12 (273.9MiB) | |
| Downloading nvidia-cufile-cu12 | |
| Downloading setuptools | |
| Downloading networkx | |
| Downloading nvidia-cuda-cupti-cu12 | |
| Downloading numpy | |
| Downloading nvidia-nvjitlink-cu12 | |
| Downloading sympy | |
| Downloading nvidia-curand-cu12 | |
| Downloading nvidia-cuda-nvrtc-cu12 | |
| Downloading triton | |
| Downloading nvidia-cufft-cu12 | |
| Downloading nvidia-cusolver-cu12 | |
| Downloading nvidia-cusparse-cu12 | |
| Downloading nvidia-cusparselt-cu12 | |
| Downloading nvidia-nccl-cu12 | |
| Downloading nvidia-cublas-cu12 | |
| Downloading nvidia-cudnn-cu12 | |
| Downloading torch | |
| Installed 26 packages in 239ms | |
| </div> | |
| <div class="cell-artifacts"> | |
| <h4>Artifacts:</h4> | |
| <a href="artifacts/save_data/down_proj.pt" class="artifact" target="_blank">down_proj.pt</a> | |
| <a href="artifacts/save_data/down_proj_bias.pt" class="artifact" target="_blank">down_proj_bias.pt</a> | |
| <a href="artifacts/save_data/gate_up_proj.pt" class="artifact" target="_blank">gate_up_proj.pt</a> | |
| <a href="artifacts/save_data/gate_up_proj_bias.pt" class="artifact" target="_blank">gate_up_proj_bias.pt</a> | |
| <a href="artifacts/save_data/router_bias.pt" class="artifact" target="_blank">router_bias.pt</a> | |
| <a href="artifacts/save_data/router_weight.pt" class="artifact" target="_blank">router_weight.pt</a> | |
| </div> | |
| </div> | |
| </div> | |
| <h2>Yamoe Implementation</h2> | |
| <p>This section runs the Yamoe MoE implementation with optimized Triton kernels.</p> | |
| <div class="cell"> | |
| <div class="cell-header"> | |
| <span class="collapse-indicators"> | |
| <span onclick="toggleCode('yamoe_run')" style="cursor: pointer;">▼ code</span> | |
| <span onclick="toggleOutput('yamoe_run')" style="cursor: pointer;">▼ output</span> | |
| </span> | | |
| Cell: yamoe_run | deps: torch, kernels, numpy | 35.75s | |
| | <button class="run-btn" onclick="runCell('yamoe_run')">▶ run</button> | |
| <button class="copy-btn" onclick="copyCell('yamoe_run')">Copy</button> | |
| </div> | |
| <div id="code-yamoe_run" class="cell-code"> | |
| <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span> | |
| <span class="normal"> 2</span> | |
| <span class="normal"> 3</span> | |
| <span class="normal"> 4</span> | |
| <span class="normal"> 5</span> | |
| <span class="normal"> 6</span> | |
| <span class="normal"> 7</span> | |
| <span class="normal"> 8</span> | |
| <span class="normal"> 9</span> | |
| <span class="normal"> 10</span> | |
| <span class="normal"> 11</span> | |
| <span class="normal"> 12</span> | |
| <span class="normal"> 13</span> | |
| <span class="normal"> 14</span> | |
| <span class="normal"> 15</span> | |
| <span class="normal"> 16</span> | |
| <span class="normal"> 17</span> | |
| <span class="normal"> 18</span> | |
| <span class="normal"> 19</span> | |
| <span class="normal"> 20</span> | |
| <span class="normal"> 21</span> | |
| <span class="normal"> 22</span> | |
| <span class="normal"> 23</span> | |
| <span class="normal"> 24</span> | |
| <span class="normal"> 25</span> | |
| <span class="normal"> 26</span> | |
| <span class="normal"> 27</span> | |
| <span class="normal"> 28</span> | |
| <span class="normal"> 29</span> | |
| <span class="normal"> 30</span> | |
| <span class="normal"> 31</span> | |
| <span class="normal"> 32</span> | |
| <span class="normal"> 33</span> | |
| <span class="normal"> 34</span> | |
| <span class="normal"> 35</span> | |
| <span class="normal"> 36</span> | |
| <span class="normal"> 37</span> | |
| <span class="normal"> 38</span> | |
| <span class="normal"> 39</span> | |
| <span class="normal"> 40</span> | |
| <span class="normal"> 41</span> | |
| <span class="normal"> 42</span> | |
| <span class="normal"> 43</span> | |
| <span class="normal"> 44</span> | |
| <span class="normal"> 45</span> | |
| <span class="normal"> 46</span> | |
| <span class="normal"> 47</span> | |
| <span class="normal"> 48</span> | |
| <span class="normal"> 49</span> | |
| <span class="normal"> 50</span> | |
| <span class="normal"> 51</span> | |
| <span class="normal"> 52</span> | |
| <span class="normal"> 53</span> | |
| <span class="normal"> 54</span> | |
| <span class="normal"> 55</span> | |
| <span class="normal"> 56</span> | |
| <span class="normal"> 57</span> | |
| <span class="normal"> 58</span> | |
| <span class="normal"> 59</span> | |
| <span class="normal"> 60</span> | |
| <span class="normal"> 61</span> | |
| <span class="normal"> 62</span> | |
| <span class="normal"> 63</span> | |
| <span class="normal"> 64</span> | |
| <span class="normal"> 65</span> | |
| <span class="normal"> 66</span> | |
| <span class="normal"> 67</span> | |
| <span class="normal"> 68</span> | |
| <span class="normal"> 69</span> | |
| <span class="normal"> 70</span> | |
| <span class="normal"> 71</span> | |
| <span class="normal"> 72</span> | |
| <span class="normal"> 73</span> | |
| <span class="normal"> 74</span> | |
| <span class="normal"> 75</span> | |
| <span class="normal"> 76</span> | |
| <span class="normal"> 77</span> | |
| <span class="normal"> 78</span> | |
| <span class="normal"> 79</span> | |
| <span class="normal"> 80</span> | |
| <span class="normal"> 81</span> | |
| <span class="normal"> 82</span> | |
| <span class="normal"> 83</span> | |
| <span class="normal"> 84</span> | |
| <span class="normal"> 85</span> | |
| <span class="normal"> 86</span> | |
| <span class="normal"> 87</span> | |
| <span class="normal"> 88</span> | |
| <span class="normal"> 89</span> | |
| <span class="normal"> 90</span> | |
| <span class="normal"> 91</span> | |
| <span class="normal"> 92</span> | |
| <span class="normal"> 93</span> | |
| <span class="normal"> 94</span> | |
| <span class="normal"> 95</span> | |
| <span class="normal"> 96</span> | |
| <span class="normal"> 97</span> | |
| <span class="normal"> 98</span> | |
| <span class="normal"> 99</span> | |
| <span class="normal">100</span> | |
| <span class="normal">101</span> | |
| <span class="normal">102</span> | |
| <span class="normal">103</span> | |
| <span class="normal">104</span> | |
| <span class="normal">105</span> | |
| <span class="normal">106</span> | |
| <span class="normal">107</span> | |
| <span class="normal">108</span> | |
| <span class="normal">109</span> | |
| <span class="normal">110</span> | |
| <span class="normal">111</span> | |
| <span class="normal">112</span> | |
| <span class="normal">113</span> | |
| <span class="normal">114</span> | |
| <span class="normal">115</span> | |
| <span class="normal">116</span> | |
| <span class="normal">117</span> | |
| <span class="normal">118</span> | |
| <span class="normal">119</span> | |
| <span class="normal">120</span> | |
| <span class="normal">121</span> | |
| <span class="normal">122</span> | |
| <span class="normal">123</span> | |
| <span class="normal">124</span> | |
| <span class="normal">125</span> | |
| <span class="normal">126</span> | |
| <span class="normal">127</span> | |
| <span class="normal">128</span></pre></div></td><td class="code"><div><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">torch</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">torch</span><span class="w"> </span><span class="kn">import</span> <span class="n">nn</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">torch.nn</span><span class="w"> </span><span class="kn">import</span> <span class="n">functional</span> <span class="k">as</span> <span class="n">F</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">kernels</span><span class="w"> </span><span class="kn">import</span> <span class="n">get_kernel</span><span class="p">,</span> <span class="n">get_local_kernel</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">bench_utils</span><span class="w"> </span><span class="kn">import</span> <span class="n">to_dtype</span><span class="p">,</span> <span class="n">tensor_stats</span><span class="p">,</span> <span class="n">set_seed</span><span class="p">,</span> <span class="n">bench_context</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">config</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span> | |
| <span class="n">NUM_EXPERTS</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">TOP_K</span><span class="p">,</span> | |
| <span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">SEQ_LEN</span><span class="p">,</span> <span class="n">DTYPE</span><span class="p">,</span> <span class="n">DEVICE</span><span class="p">,</span> | |
| <span class="n">WEIGHT_SEED</span><span class="p">,</span> <span class="n">EXPERT_SEED</span><span class="p">,</span> <span class="n">INPUT_SEED</span><span class="p">,</span> <span class="n">GENERAL_SEED</span> | |
| <span class="p">)</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">pathlib</span><span class="w"> </span><span class="kn">import</span> <span class="n">Path</span> | |
| <span class="kn">import</span><span class="w"> </span><span class="nn">os</span> | |
| <span class="c1"># Discover the upstream artifact directory from env</span> | |
| <span class="n">data_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'UVNOTE_INPUT_SAVE_DATA'</span><span class="p">,</span> <span class="s1">'.'</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Loading weights from: </span><span class="si">{</span><span class="n">data_dir</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="n">router_weight</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'router_weight.pt'</span><span class="p">)</span> | |
| <span class="n">router_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'router_bias.pt'</span><span class="p">)</span> | |
| <span class="n">gate_up_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'gate_up_proj.pt'</span><span class="p">)</span> | |
| <span class="n">gate_up_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'gate_up_proj_bias.pt'</span><span class="p">)</span> | |
| <span class="n">down_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'down_proj.pt'</span><span class="p">)</span> | |
| <span class="n">down_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'down_proj_bias.pt'</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="s2">"Loaded shared weights from artifacts"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Router weight sum: </span><span class="si">{</span><span class="n">router_weight</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Gate/up sum: </span><span class="si">{</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Down sum: </span><span class="si">{</span><span class="n">down_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="k">class</span><span class="w"> </span><span class="nc">YamoeRouter</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> | |
| <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">router_weight</span><span class="p">,</span> <span class="n">router_bias</span><span class="p">):</span> | |
| <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">top_k</span> <span class="o">=</span> <span class="n">TOP_K</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span> <span class="o">=</span> <span class="n">NUM_EXPERTS</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">hidden_dim</span> <span class="o">=</span> <span class="n">HIDDEN_SIZE</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">router_weight</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">router_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span> | |
| <span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">):</span> | |
| <span class="n">hidden_states</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_dim</span><span class="p">)</span> | |
| <span class="n">router_logits</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">)</span> | |
| <span class="n">router_top_value</span><span class="p">,</span> <span class="n">router_indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">topk</span><span class="p">(</span><span class="n">router_logits</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">top_k</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> | |
| <span class="n">router_top_value</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">router_top_value</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">router_top_value</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span> | |
| <span class="n">router_scores</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">router_logits</span><span class="p">)</span><span class="o">.</span><span class="n">scatter_</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">router_indices</span><span class="p">,</span> <span class="n">router_top_value</span><span class="p">)</span> | |
| <span class="k">return</span> <span class="n">router_scores</span><span class="p">,</span> <span class="n">router_indices</span> | |
| <span class="k">class</span><span class="w"> </span><span class="nc">YamoeMoEMLP</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> | |
| <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">router_weight</span><span class="p">,</span> <span class="n">router_bias</span><span class="p">,</span> <span class="n">gate_up_proj</span><span class="p">,</span> <span class="n">gate_up_proj_bias</span><span class="p">,</span> <span class="n">down_proj</span><span class="p">,</span> <span class="n">down_proj_bias</span><span class="p">):</span> | |
| <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">router</span> <span class="o">=</span> <span class="n">YamoeRouter</span><span class="p">(</span><span class="n">router_weight</span><span class="p">,</span> <span class="n">router_bias</span><span class="p">)</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span> <span class="o">=</span> <span class="n">NUM_EXPERTS</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">=</span> <span class="n">HIDDEN_SIZE</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">top_k</span> <span class="o">=</span> <span class="n">TOP_K</span> | |
| <span class="c1"># Load Yamoe kernel</span> | |
| <span class="c1"># self.yamoe = get_local_kernel(Path("/home/ubuntu/Projects/yamoe/result"), "yamoe")</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">yamoe</span> <span class="o">=</span> <span class="n">get_kernel</span><span class="p">(</span><span class="s2">"drbh/yamoe"</span><span class="p">,</span> <span class="n">revision</span><span class="o">=</span><span class="s2">"v0.2.0"</span><span class="p">)</span> | |
| <span class="c1"># Expert capacity - NOTE(review): set to 12 below (256 is commented out); a small capacity can drop overflow tokens</span> | |
| <span class="c1"># self.expert_capacity = 256</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">expert_capacity</span> <span class="o">=</span> <span class="mi">12</span> | |
| <span class="c1"># Expert weights - use the loaded weights</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj_bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">gate_up_proj_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">down_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">down_proj</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">down_proj_bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">down_proj_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span> | |
| <span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">):</span> | |
| <span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">hidden_dim</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">shape</span> | |
| <span class="c1"># Get routing decisions</span> | |
| <span class="n">routing_weights</span><span class="p">,</span> <span class="n">router_indices</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">router</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">)</span> | |
| <span class="c1"># Reshape for Yamoe kernel</span> | |
| <span class="n">hidden_states_flat</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">)</span> | |
| <span class="n">routing_weights_flat</span> <span class="o">=</span> <span class="n">routing_weights</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span><span class="p">)</span> | |
| <span class="c1"># Call Yamoe optimized kernel</span> | |
| <span class="n">output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">yamoe</span><span class="o">.</span><span class="n">experts</span><span class="p">(</span> | |
| <span class="n">hidden_states_flat</span><span class="p">,</span> | |
| <span class="n">router_indices</span><span class="p">,</span> | |
| <span class="n">routing_weights_flat</span><span class="p">,</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj</span><span class="p">,</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj_bias</span><span class="p">,</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">down_proj</span><span class="p">,</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">down_proj_bias</span><span class="p">,</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">expert_capacity</span><span class="p">,</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span><span class="p">,</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">top_k</span><span class="p">,</span> | |
| <span class="p">)</span> | |
| <span class="c1"># Reshape output back</span> | |
| <span class="n">output</span> <span class="o">=</span> <span class="n">output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">)</span> | |
| <span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">routing_weights</span> | |
| <span class="c1"># Run the model</span> | |
| <span class="n">set_seed</span><span class="p">(</span><span class="n">GENERAL_SEED</span><span class="p">)</span> | |
| <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="n">DEVICE</span> <span class="k">if</span> <span class="n">DEVICE</span> <span class="o">==</span> <span class="s2">"cuda"</span> <span class="k">else</span> <span class="s2">"cuda"</span><span class="p">)</span> | |
| <span class="n">dtype</span> <span class="o">=</span> <span class="n">to_dtype</span><span class="p">(</span><span class="n">DTYPE</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="s2">"</span><span class="se">\n</span><span class="s2">=== Yamoe Implementation ==="</span><span class="p">)</span> | |
| <span class="c1"># Initialize model with loaded weights</span> | |
| <span class="n">model</span> <span class="o">=</span> <span class="n">YamoeMoEMLP</span><span class="p">(</span> | |
| <span class="n">router_weight</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> | |
| <span class="n">router_bias</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> | |
| <span class="n">gate_up_proj</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> | |
| <span class="n">gate_up_proj_bias</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> | |
| <span class="n">down_proj</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> | |
| <span class="n">down_proj_bias</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> | |
| <span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Router weight sum: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">router</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Gate/up proj sum: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Down proj sum: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">down_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="c1"># Generate input</span> | |
| <span class="n">set_seed</span><span class="p">(</span><span class="n">INPUT_SEED</span><span class="p">)</span> | |
| <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">SEQ_LEN</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.1</span> | |
| <span class="c1"># Benchmark the model</span> | |
| <span class="n">tokens</span> <span class="o">=</span> <span class="n">BATCH_SIZE</span> <span class="o">*</span> <span class="n">SEQ_LEN</span> | |
| <span class="k">with</span> <span class="n">bench_context</span><span class="p">(</span><span class="n">warmup</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">iters</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">tokens</span><span class="o">=</span><span class="n">tokens</span><span class="p">,</span> <span class="n">save_json</span><span class="o">=</span><span class="s2">"yamoe_results.json"</span><span class="p">)</span> <span class="k">as</span> <span class="n">bench</span><span class="p">:</span> | |
| <span class="n">output</span><span class="p">,</span> <span class="n">stats</span> <span class="o">=</span> <span class="n">bench</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="se">\n</span><span class="s2">Output sum: </span><span class="si">{</span><span class="n">output</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| </pre></div></td></tr></table></div> | |
| </div> | |
| <div id="output-yamoe_run" class="cell-output"> | |
| <div class="cell-stdout">Loading weights from: /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/0dc3119d70b6b7e0618fb3e0070aede3d5fc82296ac58f1ab73305d459560b73 | |
| Loaded shared weights from artifacts | |
| Router weight sum: 12.588732 | |
| Gate/up sum: 1026.601807 | |
| Down sum: 206.729263 | |
| === Yamoe Implementation === | |
| Router weight sum: 12.588732 | |
| Gate/up proj sum: 1026.601807 | |
| Down proj sum: 206.729340 | |
| ┌─ Benchmark Configuration ─────────────────────────────┐ | |
| │ Warmup: 10 Iters: 50 │ | |
| │ Tokens: 100 │ | |
| └────────────────────────────────────────────────────────┘ | |
| Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=0.000016, std=0.099892, norm=33.904163 | |
| Warming up (10 iterations)... | |
| Benchmarking (50 iterations)... | |
| Progress: 20% complete (avg: 8.633 ms) | |
| Progress: 40% complete (avg: 8.627 ms) | |
| Progress: 60% complete (avg: 8.629 ms) | |
| Progress: 80% complete (avg: 8.630 ms) | |
| Output tensors: | |
| Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.050624, 0.050640], mean=-0.000005, std=0.011573, norm=3.928071 | |
| Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.294473], mean=0.007812, std=0.043541, norm=5.004632 | |
| ━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━ | |
| Iterations: 50 | |
| Latency Statistics: | |
| Average: 8.631 ms | |
| Min: 8.526 ms | |
| Max: 8.661 ms | |
| Std Dev: 0.022 ms | |
| Percentiles: | |
| P50 (median): 8.636 ms | |
| P95: 8.653 ms | |
| P99: 8.658 ms | |
| Throughput: | |
| Tokens/sec: 11586.6 | |
| Std Dev: 29.1 | |
| ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ | |
| Saved benchmark results to yamoe_results.json | |
| Output sum: -0.597250 | |
| </div> | |
| <div class="cell-stderr">Downloading nvidia-cuda-cupti-cu12 (9.8MiB) | |
| Downloading sympy (6.0MiB) | |
| Downloading networkx (1.9MiB) | |
| Downloading nvidia-nvjitlink-cu12 (37.4MiB) | |
| Downloading nvidia-nccl-cu12 (307.4MiB) | |
| Downloading nvidia-curand-cu12 (60.7MiB) | |
| Downloading nvidia-cufile-cu12 (1.1MiB) | |
| Downloading numpy (15.9MiB) | |
| Downloading nvidia-cusolver-cu12 (255.1MiB) | |
| Downloading torch (846.8MiB) | |
| Downloading nvidia-cudnn-cu12 (674.0MiB) | |
| Downloading hf-xet (3.0MiB) | |
| Downloading setuptools (1.1MiB) | |
| Downloading nvidia-cusparselt-cu12 (273.9MiB) | |
| Downloading nvidia-cublas-cu12 (566.8MiB) | |
| Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB) | |
| Downloading nvidia-cufft-cu12 (184.2MiB) | |
| Downloading nvidia-cusparse-cu12 (274.9MiB) | |
| Downloading triton (148.4MiB) | |
| Downloading nvidia-cufile-cu12 | |
| Downloading hf-xet | |
| Downloading setuptools | |
| Downloading networkx | |
| Downloading nvidia-cuda-cupti-cu12 | |
| Downloading numpy | |
| Downloading sympy | |
| Downloading nvidia-nvjitlink-cu12 | |
| Downloading nvidia-curand-cu12 | |
| Downloading nvidia-cuda-nvrtc-cu12 | |
| Downloading triton | |
| Downloading nvidia-cufft-cu12 | |
| Downloading nvidia-cusolver-cu12 | |
| Downloading nvidia-cusparse-cu12 | |
| Downloading nvidia-cusparselt-cu12 | |
| Downloading nvidia-nccl-cu12 | |
| Downloading nvidia-cublas-cu12 | |
| Downloading nvidia-cudnn-cu12 | |
| Downloading torch | |
| Installed 37 packages in 287ms | |
| Fetching 6 files: 0%| | 0/6 [00:00<?, ?it/s] | |
| Fetching 6 files: 17%|█▋ | 1/6 [00:00<00:01, 3.90it/s] | |
| Fetching 6 files: 50%|█████ | 3/6 [00:00<00:00, 3.70it/s] | |
| Fetching 6 files: 100%|██████████| 6/6 [00:00<00:00, 7.44it/s] | |
| </div> | |
| <div class="cell-artifacts"> | |
| <h4>Artifacts:</h4> | |
| <a href="artifacts/yamoe_run/yamoe_results.json" class="artifact" target="_blank">yamoe_results.json</a> | |
| </div> | |
| </div> | |
| </div> | |
| <h2>Binned Implementation</h2> | |
| <p>This section runs the binned implementation that manually handles token gathering/scattering.</p> | |
| <div class="cell"> | |
| <div class="cell-header"> | |
| <span class="collapse-indicators"> | |
| <span onclick="toggleCode('binned_run')" style="cursor: pointer;">▼ code</span> | |
| <span onclick="toggleOutput('binned_run')" style="cursor: pointer;">▼ output</span> | |
| </span> | | |
| Cell: binned_run | deps: torch, numpy | 42.05s | |
| | <button class="run-btn" onclick="runCell('binned_run')">▶ run</button> | |
| <button class="copy-btn" onclick="copyCell('binned_run')">Copy</button> | |
| </div> | |
| <div id="code-binned_run" class="cell-code"> | |
| <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span> | |
| <span class="normal"> 2</span> | |
| <span class="normal"> 3</span> | |
| <span class="normal"> 4</span> | |
| <span class="normal"> 5</span> | |
| <span class="normal"> 6</span> | |
| <span class="normal"> 7</span> | |
| <span class="normal"> 8</span> | |
| <span class="normal"> 9</span> | |
| <span class="normal"> 10</span> | |
| <span class="normal"> 11</span> | |
| <span class="normal"> 12</span> | |
| <span class="normal"> 13</span> | |
| <span class="normal"> 14</span> | |
| <span class="normal"> 15</span> | |
| <span class="normal"> 16</span> | |
| <span class="normal"> 17</span> | |
| <span class="normal"> 18</span> | |
| <span class="normal"> 19</span> | |
| <span class="normal"> 20</span> | |
| <span class="normal"> 21</span> | |
| <span class="normal"> 22</span> | |
| <span class="normal"> 23</span> | |
| <span class="normal"> 24</span> | |
| <span class="normal"> 25</span> | |
| <span class="normal"> 26</span> | |
| <span class="normal"> 27</span> | |
| <span class="normal"> 28</span> | |
| <span class="normal"> 29</span> | |
| <span class="normal"> 30</span> | |
| <span class="normal"> 31</span> | |
| <span class="normal"> 32</span> | |
| <span class="normal"> 33</span> | |
| <span class="normal"> 34</span> | |
| <span class="normal"> 35</span> | |
| <span class="normal"> 36</span> | |
| <span class="normal"> 37</span> | |
| <span class="normal"> 38</span> | |
| <span class="normal"> 39</span> | |
| <span class="normal"> 40</span> | |
| <span class="normal"> 41</span> | |
| <span class="normal"> 42</span> | |
| <span class="normal"> 43</span> | |
| <span class="normal"> 44</span> | |
| <span class="normal"> 45</span> | |
| <span class="normal"> 46</span> | |
| <span class="normal"> 47</span> | |
| <span class="normal"> 48</span> | |
| <span class="normal"> 49</span> | |
| <span class="normal"> 50</span> | |
| <span class="normal"> 51</span> | |
| <span class="normal"> 52</span> | |
| <span class="normal"> 53</span> | |
| <span class="normal"> 54</span> | |
| <span class="normal"> 55</span> | |
| <span class="normal"> 56</span> | |
| <span class="normal"> 57</span> | |
| <span class="normal"> 58</span> | |
| <span class="normal"> 59</span> | |
| <span class="normal"> 60</span> | |
| <span class="normal"> 61</span> | |
| <span class="normal"> 62</span> | |
| <span class="normal"> 63</span> | |
| <span class="normal"> 64</span> | |
| <span class="normal"> 65</span> | |
| <span class="normal"> 66</span> | |
| <span class="normal"> 67</span> | |
| <span class="normal"> 68</span> | |
| <span class="normal"> 69</span> | |
| <span class="normal"> 70</span> | |
| <span class="normal"> 71</span> | |
| <span class="normal"> 72</span> | |
| <span class="normal"> 73</span> | |
| <span class="normal"> 74</span> | |
| <span class="normal"> 75</span> | |
| <span class="normal"> 76</span> | |
| <span class="normal"> 77</span> | |
| <span class="normal"> 78</span> | |
| <span class="normal"> 79</span> | |
| <span class="normal"> 80</span> | |
| <span class="normal"> 81</span> | |
| <span class="normal"> 82</span> | |
| <span class="normal"> 83</span> | |
| <span class="normal"> 84</span> | |
| <span class="normal"> 85</span> | |
| <span class="normal"> 86</span> | |
| <span class="normal"> 87</span> | |
| <span class="normal"> 88</span> | |
| <span class="normal"> 89</span> | |
| <span class="normal"> 90</span> | |
| <span class="normal"> 91</span> | |
| <span class="normal"> 92</span> | |
| <span class="normal"> 93</span> | |
| <span class="normal"> 94</span> | |
| <span class="normal"> 95</span> | |
| <span class="normal"> 96</span> | |
| <span class="normal"> 97</span> | |
| <span class="normal"> 98</span> | |
| <span class="normal"> 99</span> | |
| <span class="normal">100</span> | |
| <span class="normal">101</span> | |
| <span class="normal">102</span> | |
| <span class="normal">103</span> | |
| <span class="normal">104</span> | |
| <span class="normal">105</span> | |
| <span class="normal">106</span> | |
| <span class="normal">107</span> | |
| <span class="normal">108</span> | |
| <span class="normal">109</span> | |
| <span class="normal">110</span> | |
| <span class="normal">111</span> | |
| <span class="normal">112</span> | |
| <span class="normal">113</span> | |
| <span class="normal">114</span> | |
| <span class="normal">115</span> | |
| <span class="normal">116</span> | |
| <span class="normal">117</span> | |
| <span class="normal">118</span> | |
| <span class="normal">119</span> | |
| <span class="normal">120</span> | |
| <span class="normal">121</span> | |
| <span class="normal">122</span> | |
| <span class="normal">123</span> | |
| <span class="normal">124</span> | |
| <span class="normal">125</span> | |
| <span class="normal">126</span> | |
| <span class="normal">127</span> | |
| <span class="normal">128</span> | |
| <span class="normal">129</span> | |
| <span class="normal">130</span> | |
| <span class="normal">131</span> | |
| <span class="normal">132</span> | |
| <span class="normal">133</span> | |
| <span class="normal">134</span> | |
| <span class="normal">135</span> | |
| <span class="normal">136</span> | |
| <span class="normal">137</span> | |
| <span class="normal">138</span> | |
| <span class="normal">139</span> | |
| <span class="normal">140</span> | |
| <span class="normal">141</span> | |
| <span class="normal">142</span> | |
| <span class="normal">143</span> | |
| <span class="normal">144</span> | |
| <span class="normal">145</span> | |
| <span class="normal">146</span> | |
| <span class="normal">147</span> | |
| <span class="normal">148</span> | |
| <span class="normal">149</span> | |
| <span class="normal">150</span> | |
| <span class="normal">151</span> | |
| <span class="normal">152</span> | |
| <span class="normal">153</span> | |
| <span class="normal">154</span> | |
| <span class="normal">155</span> | |
| <span class="normal">156</span> | |
| <span class="normal">157</span> | |
| <span class="normal">158</span> | |
| <span class="normal">159</span> | |
| <span class="normal">160</span> | |
| <span class="normal">161</span> | |
| <span class="normal">162</span> | |
| <span class="normal">163</span> | |
| <span class="normal">164</span> | |
| <span class="normal">165</span> | |
| <span class="normal">166</span> | |
| <span class="normal">167</span> | |
| <span class="normal">168</span> | |
| <span class="normal">169</span> | |
| <span class="normal">170</span> | |
| <span class="normal">171</span> | |
| <span class="normal">172</span> | |
| <span class="normal">173</span> | |
| <span class="normal">174</span> | |
| <span class="normal">175</span> | |
| <span class="normal">176</span> | |
| <span class="normal">177</span> | |
| <span class="normal">178</span> | |
| <span class="normal">179</span> | |
| <span class="normal">180</span> | |
| <span class="normal">181</span> | |
| <span class="normal">182</span> | |
| <span class="normal">183</span></pre></div></td><td class="code"><div><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">torch</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">torch</span><span class="w"> </span><span class="kn">import</span> <span class="n">nn</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">torch.nn</span><span class="w"> </span><span class="kn">import</span> <span class="n">functional</span> <span class="k">as</span> <span class="n">F</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">bench_utils</span><span class="w"> </span><span class="kn">import</span> <span class="n">to_dtype</span><span class="p">,</span> <span class="n">tensor_stats</span><span class="p">,</span> <span class="n">set_seed</span><span class="p">,</span> <span class="n">bench_context</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">config</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span> | |
| <span class="n">NUM_EXPERTS</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">TOP_K</span><span class="p">,</span> | |
| <span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">SEQ_LEN</span><span class="p">,</span> <span class="n">DTYPE</span><span class="p">,</span> <span class="n">DEVICE</span><span class="p">,</span> | |
| <span class="n">WEIGHT_SEED</span><span class="p">,</span> <span class="n">EXPERT_SEED</span><span class="p">,</span> <span class="n">INPUT_SEED</span><span class="p">,</span> <span class="n">GENERAL_SEED</span> | |
| <span class="p">)</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">pathlib</span><span class="w"> </span><span class="kn">import</span> <span class="n">Path</span> | |
| <span class="kn">import</span><span class="w"> </span><span class="nn">os</span> | |
# Discover the upstream artifact directory from env
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')

def _load_artifact(filename):
    # All shared tensors live in the directory exported by the producing cell.
    return torch.load(Path(data_dir) / filename)

router_weight = _load_artifact('router_weight.pt')
router_bias = _load_artifact('router_bias.pt')
gate_up_proj = _load_artifact('gate_up_proj.pt')
gate_up_proj_bias = _load_artifact('gate_up_proj_bias.pt')
down_proj = _load_artifact('down_proj.pt')
down_proj_bias = _load_artifact('down_proj_bias.pt')

print("Loaded shared weights from artifacts")
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")
def binned_gather(x, indices, bins, expert_capacity, top_k):
    """Copy token vectors into fixed-capacity per-expert buffers.

    Expert ``e`` owns the sorted flat positions in ``[bins[e-1], bins[e])``;
    any assignment beyond ``expert_capacity`` is silently dropped.

    Args:
        x: (num_tokens, hidden) token features.
        indices: flat (token * top_k + slot) positions, sorted by expert.
        bins: per-expert inclusive cumulative counts, shape (num_experts,).
        expert_capacity: max tokens materialized per expert.
        top_k: experts selected per token (used to recover the token id).

    Returns:
        (num_experts, expert_capacity, hidden) gathered buffer, zero-padded.
    """
    num_experts = bins.shape[0]
    hidden = x.shape[1]
    gathered = torch.zeros((num_experts, expert_capacity, hidden), device=x.device, dtype=x.dtype)
    offset = 0
    for expert in range(num_experts):
        upper = bins[expert]
        # Truncate to capacity; excess assignments for this expert are dropped.
        count = min(upper - offset, expert_capacity)
        for slot in range(count):
            token = indices[offset + slot] // top_k
            gathered[expert, slot] = x[token]
        offset = upper
    return gathered
def binned_scatter(x, indices, weights, bins, expert_capacity, top_k):
    """Scatter expert outputs back to (token, slot) and sum over slots.

    Args:
        x: (num_experts, expert_capacity, hidden) expert outputs.
        indices: flat (token * top_k + slot) positions, sorted by expert.
        weights: optional per-flat-position scaling; ``None`` means 1.0.
        bins: per-expert inclusive cumulative counts.
        expert_capacity: max tokens that were materialized per expert.
        top_k: experts selected per token.

    Returns:
        (num_tokens, hidden) combined output; dropped tokens stay zero.
    """
    num_experts, _, hidden = x.shape
    num_tokens = indices.shape[0] // top_k
    combined = torch.zeros((num_tokens, top_k, hidden), dtype=x.dtype, device=x.device)
    offset = 0
    for expert in range(num_experts):
        upper = bins[expert]
        count = upper - offset
        if count != 0:
            # Only the first `expert_capacity` assignments were computed.
            for slot in range(min(count, expert_capacity)):
                flat = indices[offset + slot]
                token = flat // top_k
                k_slot = flat % top_k
                gain = 1.0 if weights is None else weights[flat]
                combined[token, k_slot] = x[expert, slot] * gain
        offset = upper
    return combined.sum(dim=1)
def sort_tokens_by_expert(router_indices, num_experts):
    """Order flattened (token, slot) positions by their assigned expert.

    Returns:
        positions: argsort of the flattened expert assignments.
        expert_ids: the expert id at each sorted position.
        boundaries: inclusive cumulative token count per expert.
        counts: tokens assigned to each expert, length ``num_experts``.
    """
    flat = router_indices.flatten()
    expert_ids, positions = torch.sort(flat)
    counts = torch.bincount(expert_ids, minlength=num_experts)
    boundaries = counts.cumsum(dim=0)
    return positions, expert_ids, boundaries, counts
def binned_experts_ref(
    hidden_states,
    router_indices,
    routing_weights,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
    expert_capacity,
    limit=7.0,
    swiglu_alpha=1.702,
):
    """Reference MoE expert pass: gather by expert, compute, scatter back.

    Args:
        hidden_states: (B, S, H) input activations.
        router_indices: (B*S, K) chosen expert ids per token.
        routing_weights: (B*S, E) dense routing scores (zeros off the top-k).
        gate_up_proj / gate_up_proj_bias: per-expert fused gate+up weights;
            even channels are the gate, odd channels the up projection.
        down_proj / down_proj_bias: per-expert down projection.
        expert_capacity: max tokens processed per expert (overflow dropped).
        limit: clamp bound for the activations (was hard-coded 7.0).
        swiglu_alpha: sigmoid slope of the SiLU/GELU approximation
            (was hard-coded 1.702).

    Returns:
        (B, S, H) combined expert output.
    """
    B, S, H = hidden_states.shape
    E, K = routing_weights.shape[1], router_indices.shape[1]
    indices, _, bins, _ = sort_tokens_by_expert(router_indices, E)
    x = binned_gather(hidden_states.view(-1, H), indices, bins, expert_capacity, K)
    gate_up = torch.bmm(x, gate_up_proj)
    gate_up += gate_up_proj_bias[..., None, :]
    # De-interleave fused projection: even channels gate, odd channels up.
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
    # Clamp for numerical stability: gate one-sided, up two-sided.
    gate = gate.clamp(min=None, max=limit)
    up = up.clamp(min=-limit, max=limit)
    glu = gate * torch.sigmoid(gate * swiglu_alpha)
    x = (up + 1) * glu
    x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :]
    # build routing weights aligned to (token, slot)
    flat_dense = routing_weights.view(-1, E)
    flat_router = router_indices.view(-1, K)
    selected = torch.gather(flat_dense, 1, flat_router).reshape(-1)
    # scatter back
    y = binned_scatter(x, indices, selected, bins, expert_capacity, K)
    return y.view(B, S, H)
class BinnedRouter(nn.Module):
    """Top-k softmax router initialized from shared pretrained weights."""

    def __init__(self, router_weight, router_bias):
        super().__init__()
        self.top_k = TOP_K
        self.num_experts = NUM_EXPERTS
        self.hidden_dim = HIDDEN_SIZE
        # Clone so the module owns its parameters independently of the source.
        self.weight = nn.Parameter(router_weight.clone())
        self.bias = nn.Parameter(router_bias.clone())

    def forward(self, hidden_states):
        # Flatten (B, S, H) -> (B*S, H) before routing.
        flat = hidden_states.reshape(-1, self.hidden_dim)
        logits = F.linear(flat, self.weight, self.bias)
        top_values, top_indices = torch.topk(logits, self.top_k, dim=-1)
        # Softmax over the selected experts only.
        top_values = F.softmax(top_values, dim=1, dtype=top_values.dtype)
        # Dense score matrix: zeros everywhere except the chosen experts.
        scores = torch.zeros_like(logits).scatter_(1, top_indices, top_values)
        return scores, top_indices
class BinnedMoEMLP(nn.Module):
    """MoE MLP: routes tokens, then runs the binned reference expert pass."""

    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.router = BinnedRouter(router_weight, router_bias)
        self.num_experts = NUM_EXPERTS
        self.hidden_size = HIDDEN_SIZE
        # Fixed per-expert token budget; overflow assignments are dropped.
        self.expert_capacity = 256
        # Expert weights - use the loaded weights (cloned for ownership).
        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
        self.down_proj = nn.Parameter(down_proj.clone())
        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())

    def forward(self, hidden_states):
        scores, indices = self.router(hidden_states)
        expert_out = binned_experts_ref(
            hidden_states,
            indices,
            scores,
            self.gate_up_proj,
            self.gate_up_proj_bias,
            self.down_proj,
            self.down_proj_bias,
            self.expert_capacity,
        )
        return expert_out, scores
# Run the model
set_seed(GENERAL_SEED)
device = torch.device(DEVICE)
dtype = to_dtype(DTYPE)

print("\n=== Binned Implementation ===")

# Initialize model with loaded weights, moving every shared tensor to the
# target device first.
shared_tensors = (
    router_weight,
    router_bias,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
)
model = BinnedMoEMLP(*(t.to(device) for t in shared_tensors)).to(device=device)

print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}")
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Down proj sum: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">down_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="c1"># Generate the same input as Yamoe</span> | |
| <span class="n">set_seed</span><span class="p">(</span><span class="n">INPUT_SEED</span><span class="p">)</span> | |
| <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">SEQ_LEN</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.1</span> | |
| <span class="c1"># Benchmark the model</span> | |
| <span class="n">tokens</span> <span class="o">=</span> <span class="n">BATCH_SIZE</span> <span class="o">*</span> <span class="n">SEQ_LEN</span> | |
| <span class="k">with</span> <span class="n">bench_context</span><span class="p">(</span><span class="n">warmup</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">iters</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">tokens</span><span class="o">=</span><span class="n">tokens</span><span class="p">,</span> <span class="n">save_json</span><span class="o">=</span><span class="s2">"binned_results.json"</span><span class="p">)</span> <span class="k">as</span> <span class="n">bench</span><span class="p">:</span> | |
| <span class="n">output</span><span class="p">,</span> <span class="n">stats</span> <span class="o">=</span> <span class="n">bench</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="se">\n</span><span class="s2">Output sum: </span><span class="si">{</span><span class="n">output</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| </pre></div></td></tr></table></div> | |
| </div> | |
| <div id="output-binned_run" class="cell-output"> | |
| <div class="cell-stdout">Loaded shared weights from artifacts | |
| Router weight sum: 12.588732 | |
| Gate/up sum: 1026.601807 | |
| Down sum: 206.729263 | |
| === Binned Implementation === | |
| Router weight sum: 12.588732 | |
| Gate/up proj sum: 1026.601807 | |
| Down proj sum: 206.729340 | |
| ┌─ Benchmark Configuration ─────────────────────────────┐ | |
| │ Warmup: 10 Iters: 50 │ | |
| │ Tokens: 100 │ | |
| └────────────────────────────────────────────────────────┘ | |
| Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=0.000016, std=0.099892, norm=33.904163 | |
| Warming up (10 iterations)... | |
| Benchmarking (50 iterations)... | |
| Progress: 20% complete (avg: 104.222 ms) | |
| Progress: 40% complete (avg: 104.671 ms) | |
| Progress: 60% complete (avg: 105.372 ms) | |
| Progress: 80% complete (avg: 105.570 ms) | |
| Output tensors: | |
| Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.050624, 0.050640], mean=-0.000005, std=0.011573, norm=3.928071 | |
| Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.294473], mean=0.007812, std=0.043541, norm=5.004632 | |
| ━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━ | |
| Iterations: 50 | |
| Latency Statistics: | |
| Average: 105.618 ms | |
| Min: 103.417 ms | |
| Max: 107.809 ms | |
| Std Dev: 1.458 ms | |
| Percentiles: | |
| P50 (median): 105.048 ms | |
| P95: 107.729 ms | |
| P99: 107.790 ms | |
| Throughput: | |
| Tokens/sec: 946.8 | |
| Std Dev: 13.0 | |
| ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ | |
| Saved benchmark results to binned_results.json | |
| Output sum: -0.597248 | |
| </div> | |
| <div class="cell-stderr">Downloading nvidia-cusparselt-cu12 (273.9MiB) | |
| Downloading nvidia-cufft-cu12 (184.2MiB) | |
| Downloading triton (148.4MiB) | |
| Downloading nvidia-nvjitlink-cu12 (37.4MiB) | |
| Downloading torch (846.8MiB) | |
| Downloading networkx (1.9MiB) | |
| Downloading nvidia-curand-cu12 (60.7MiB) | |
| Downloading nvidia-cufile-cu12 (1.1MiB) | |
| Downloading nvidia-nccl-cu12 (307.4MiB) | |
| Downloading setuptools (1.1MiB) | |
| Downloading nvidia-cusolver-cu12 (255.1MiB) | |
| Downloading nvidia-cusparse-cu12 (274.9MiB) | |
| Downloading sympy (6.0MiB) | |
| Downloading nvidia-cudnn-cu12 (674.0MiB) | |
| Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB) | |
| Downloading nvidia-cuda-cupti-cu12 (9.8MiB) | |
| Downloading nvidia-cublas-cu12 (566.8MiB) | |
| Downloading numpy (15.9MiB) | |
| Downloading nvidia-cufile-cu12 | |
| Downloading setuptools | |
| Downloading networkx | |
| Downloading nvidia-cuda-cupti-cu12 | |
| Downloading numpy | |
| Downloading nvidia-nvjitlink-cu12 | |
| Downloading sympy | |
| Downloading nvidia-curand-cu12 | |
| Downloading nvidia-cuda-nvrtc-cu12 | |
| Downloading triton | |
| Downloading nvidia-cufft-cu12 | |
| Downloading nvidia-cusolver-cu12 | |
| Downloading nvidia-cusparse-cu12 | |
| Downloading nvidia-cusparselt-cu12 | |
| Downloading nvidia-nccl-cu12 | |
| Downloading nvidia-cublas-cu12 | |
| Downloading nvidia-cudnn-cu12 | |
| Downloading torch | |
| Installed 26 packages in 233ms | |
| </div> | |
| <div class="cell-artifacts"> | |
| <h4>Artifacts:</h4> | |
| <a href="artifacts/binned_run/binned_results.json" class="artifact" target="_blank">binned_results.json</a> | |
| </div> | |
| </div> | |
| </div> | |
| <h2>GPT-OSS Implementation</h2> | |
| <p>This section runs the GPT-OSS MoE implementation with manual expert loop handling.</p> | |
| <div class="cell"> | |
| <div class="cell-header"> | |
| <span class="collapse-indicators"> | |
| <span onclick="toggleCode('gptoss_run')" style="cursor: pointer;">▼ code</span> | |
| <span onclick="toggleOutput('gptoss_run')" style="cursor: pointer;">▼ output</span> | |
| </span> | | |
| Cell: gptoss_run | deps: torch, numpy | 37.86s | |
| | <button class="run-btn" onclick="runCell('gptoss_run')">▶ run</button> | |
| <button class="copy-btn" onclick="copyCell('gptoss_run')">Copy</button> | |
| </div> | |
| <div id="code-gptoss_run" class="cell-code"> | |
| <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span> | |
| <span class="normal"> 2</span> | |
| <span class="normal"> 3</span> | |
| <span class="normal"> 4</span> | |
| <span class="normal"> 5</span> | |
| <span class="normal"> 6</span> | |
| <span class="normal"> 7</span> | |
| <span class="normal"> 8</span> | |
| <span class="normal"> 9</span> | |
| <span class="normal"> 10</span> | |
| <span class="normal"> 11</span> | |
| <span class="normal"> 12</span> | |
| <span class="normal"> 13</span> | |
| <span class="normal"> 14</span> | |
| <span class="normal"> 15</span> | |
| <span class="normal"> 16</span> | |
| <span class="normal"> 17</span> | |
| <span class="normal"> 18</span> | |
| <span class="normal"> 19</span> | |
| <span class="normal"> 20</span> | |
| <span class="normal"> 21</span> | |
| <span class="normal"> 22</span> | |
| <span class="normal"> 23</span> | |
| <span class="normal"> 24</span> | |
| <span class="normal"> 25</span> | |
| <span class="normal"> 26</span> | |
| <span class="normal"> 27</span> | |
| <span class="normal"> 28</span> | |
| <span class="normal"> 29</span> | |
| <span class="normal"> 30</span> | |
| <span class="normal"> 31</span> | |
| <span class="normal"> 32</span> | |
| <span class="normal"> 33</span> | |
| <span class="normal"> 34</span> | |
| <span class="normal"> 35</span> | |
| <span class="normal"> 36</span> | |
| <span class="normal"> 37</span> | |
| <span class="normal"> 38</span> | |
| <span class="normal"> 39</span> | |
| <span class="normal"> 40</span> | |
| <span class="normal"> 41</span> | |
| <span class="normal"> 42</span> | |
| <span class="normal"> 43</span> | |
| <span class="normal"> 44</span> | |
| <span class="normal"> 45</span> | |
| <span class="normal"> 46</span> | |
| <span class="normal"> 47</span> | |
| <span class="normal"> 48</span> | |
| <span class="normal"> 49</span> | |
| <span class="normal"> 50</span> | |
| <span class="normal"> 51</span> | |
| <span class="normal"> 52</span> | |
| <span class="normal"> 53</span> | |
| <span class="normal"> 54</span> | |
| <span class="normal"> 55</span> | |
| <span class="normal"> 56</span> | |
| <span class="normal"> 57</span> | |
| <span class="normal"> 58</span> | |
| <span class="normal"> 59</span> | |
| <span class="normal"> 60</span> | |
| <span class="normal"> 61</span> | |
| <span class="normal"> 62</span> | |
| <span class="normal"> 63</span> | |
| <span class="normal"> 64</span> | |
| <span class="normal"> 65</span> | |
| <span class="normal"> 66</span> | |
| <span class="normal"> 67</span> | |
| <span class="normal"> 68</span> | |
| <span class="normal"> 69</span> | |
| <span class="normal"> 70</span> | |
| <span class="normal"> 71</span> | |
| <span class="normal"> 72</span> | |
| <span class="normal"> 73</span> | |
| <span class="normal"> 74</span> | |
| <span class="normal"> 75</span> | |
| <span class="normal"> 76</span> | |
| <span class="normal"> 77</span> | |
| <span class="normal"> 78</span> | |
| <span class="normal"> 79</span> | |
| <span class="normal"> 80</span> | |
| <span class="normal"> 81</span> | |
| <span class="normal"> 82</span> | |
| <span class="normal"> 83</span> | |
| <span class="normal"> 84</span> | |
| <span class="normal"> 85</span> | |
| <span class="normal"> 86</span> | |
| <span class="normal"> 87</span> | |
| <span class="normal"> 88</span> | |
| <span class="normal"> 89</span> | |
| <span class="normal"> 90</span> | |
| <span class="normal"> 91</span> | |
| <span class="normal"> 92</span> | |
| <span class="normal"> 93</span> | |
| <span class="normal"> 94</span> | |
| <span class="normal"> 95</span> | |
| <span class="normal"> 96</span> | |
| <span class="normal"> 97</span> | |
| <span class="normal"> 98</span> | |
| <span class="normal"> 99</span> | |
| <span class="normal">100</span> | |
| <span class="normal">101</span> | |
| <span class="normal">102</span> | |
| <span class="normal">103</span> | |
| <span class="normal">104</span> | |
| <span class="normal">105</span> | |
| <span class="normal">106</span> | |
| <span class="normal">107</span> | |
| <span class="normal">108</span> | |
| <span class="normal">109</span> | |
| <span class="normal">110</span> | |
| <span class="normal">111</span> | |
| <span class="normal">112</span> | |
| <span class="normal">113</span> | |
| <span class="normal">114</span> | |
| <span class="normal">115</span> | |
| <span class="normal">116</span> | |
| <span class="normal">117</span> | |
| <span class="normal">118</span> | |
| <span class="normal">119</span> | |
| <span class="normal">120</span> | |
| <span class="normal">121</span> | |
| <span class="normal">122</span> | |
| <span class="normal">123</span> | |
| <span class="normal">124</span> | |
| <span class="normal">125</span> | |
| <span class="normal">126</span> | |
| <span class="normal">127</span> | |
| <span class="normal">128</span> | |
| <span class="normal">129</span> | |
| <span class="normal">130</span> | |
| <span class="normal">131</span> | |
| <span class="normal">132</span> | |
| <span class="normal">133</span> | |
| <span class="normal">134</span> | |
| <span class="normal">135</span> | |
| <span class="normal">136</span> | |
| <span class="normal">137</span> | |
| <span class="normal">138</span> | |
| <span class="normal">139</span> | |
| <span class="normal">140</span></pre></div></td><td class="code"><div><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">torch</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">torch</span><span class="w"> </span><span class="kn">import</span> <span class="n">nn</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">torch.nn</span><span class="w"> </span><span class="kn">import</span> <span class="n">functional</span> <span class="k">as</span> <span class="n">F</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">bench_utils</span><span class="w"> </span><span class="kn">import</span> <span class="n">to_dtype</span><span class="p">,</span> <span class="n">tensor_stats</span><span class="p">,</span> <span class="n">set_seed</span><span class="p">,</span> <span class="n">bench_context</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">config</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span> | |
| <span class="n">NUM_EXPERTS</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">TOP_K</span><span class="p">,</span> | |
| <span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">SEQ_LEN</span><span class="p">,</span> <span class="n">DTYPE</span><span class="p">,</span> <span class="n">DEVICE</span><span class="p">,</span> | |
| <span class="n">WEIGHT_SEED</span><span class="p">,</span> <span class="n">EXPERT_SEED</span><span class="p">,</span> <span class="n">INPUT_SEED</span><span class="p">,</span> <span class="n">GENERAL_SEED</span> | |
| <span class="p">)</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">pathlib</span><span class="w"> </span><span class="kn">import</span> <span class="n">Path</span> | |
| <span class="kn">import</span><span class="w"> </span><span class="nn">os</span> | |
| <span class="c1"># Discover the upstream artifact directory from env</span> | |
| <span class="n">data_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'UVNOTE_INPUT_SAVE_DATA'</span><span class="p">,</span> <span class="s1">'.'</span><span class="p">)</span> | |
| <span class="n">router_weight</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'router_weight.pt'</span><span class="p">)</span> | |
| <span class="n">router_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'router_bias.pt'</span><span class="p">)</span> | |
| <span class="n">gate_up_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'gate_up_proj.pt'</span><span class="p">)</span> | |
| <span class="n">gate_up_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'gate_up_proj_bias.pt'</span><span class="p">)</span> | |
| <span class="n">down_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'down_proj.pt'</span><span class="p">)</span> | |
| <span class="n">down_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'down_proj_bias.pt'</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="s2">"Loaded shared weights from artifacts"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Router weight sum: </span><span class="si">{</span><span class="n">router_weight</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Gate/up sum: </span><span class="si">{</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Down sum: </span><span class="si">{</span><span class="n">down_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="k">class</span><span class="w"> </span><span class="nc">GptOssRouter</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> | |
| <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">router_weight</span><span class="p">,</span> <span class="n">router_bias</span><span class="p">):</span> | |
| <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">top_k</span> <span class="o">=</span> <span class="n">TOP_K</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span> <span class="o">=</span> <span class="n">NUM_EXPERTS</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">hidden_dim</span> <span class="o">=</span> <span class="n">HIDDEN_SIZE</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">router_weight</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">router_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span> | |
| <span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">):</span> | |
| <span class="n">hidden_states</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_dim</span><span class="p">)</span> | |
| <span class="n">router_logits</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">)</span> | |
| <span class="n">router_top_value</span><span class="p">,</span> <span class="n">router_indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">topk</span><span class="p">(</span><span class="n">router_logits</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">top_k</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> | |
| <span class="n">router_top_value</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">router_top_value</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">router_top_value</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span> | |
| <span class="n">router_scores</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">router_logits</span><span class="p">)</span><span class="o">.</span><span class="n">scatter_</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">router_indices</span><span class="p">,</span> <span class="n">router_top_value</span><span class="p">)</span> | |
| <span class="k">return</span> <span class="n">router_scores</span><span class="p">,</span> <span class="n">router_indices</span> | |
| <span class="k">class</span><span class="w"> </span><span class="nc">GptOssExperts</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> | |
| <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">gate_up_proj</span><span class="p">,</span> <span class="n">gate_up_proj_bias</span><span class="p">,</span> <span class="n">down_proj</span><span class="p">,</span> <span class="n">down_proj_bias</span><span class="p">):</span> | |
| <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span> <span class="o">=</span> <span class="n">NUM_EXPERTS</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">=</span> <span class="n">HIDDEN_SIZE</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">expert_dim</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj_bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">gate_up_proj_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">down_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">down_proj</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">down_proj_bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">down_proj_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="mf">1.702</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">limit</span> <span class="o">=</span> <span class="mf">7.0</span> | |
| <span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">router_indices</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">routing_weights</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span> | |
| <span class="n">batch_size</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> | |
| <span class="n">hidden_states</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)</span> | |
| <span class="n">num_experts</span> <span class="o">=</span> <span class="n">routing_weights</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> | |
| <span class="k">if</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">device</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="s2">"cpu"</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span><span class="p">:</span> | |
| <span class="n">next_states</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">hidden_states</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">hidden_states</span><span class="o">.</span><span class="n">device</span><span class="p">)</span> | |
| <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span> | |
| <span class="n">expert_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">router_indices</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="n">num_experts</span><span class="p">)</span> | |
| <span class="n">expert_mask</span> <span class="o">=</span> <span class="n">expert_mask</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> | |
| <span class="n">expert_hit</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">greater</span><span class="p">(</span><span class="n">expert_mask</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">2</span><span class="p">)),</span> <span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">nonzero</span><span class="p">()</span> | |
| <span class="k">for</span> <span class="n">expert_idx</span> <span class="ow">in</span> <span class="n">expert_hit</span><span class="p">[:]:</span> | |
| <span class="n">expert_idx</span> <span class="o">=</span> <span class="n">expert_idx</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> | |
| <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span> | |
| <span class="n">_</span><span class="p">,</span> <span class="n">token_idx</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">expert_mask</span><span class="p">[</span><span class="n">expert_idx</span><span class="p">])</span> | |
| <span class="n">current_state</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="p">[</span><span class="n">token_idx</span><span class="p">]</span> | |
| <span class="n">gate_up</span> <span class="o">=</span> <span class="n">current_state</span> <span class="o">@</span> <span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj</span><span class="p">[</span><span class="n">expert_idx</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj_bias</span><span class="p">[</span><span class="n">expert_idx</span><span class="p">]</span> | |
| <span class="n">gate</span><span class="p">,</span> <span class="n">up</span> <span class="o">=</span> <span class="n">gate_up</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="p">::</span><span class="mi">2</span><span class="p">],</span> <span class="n">gate_up</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> | |
| <span class="n">gate</span> <span class="o">=</span> <span class="n">gate</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">limit</span><span class="p">)</span> | |
| <span class="n">up</span> <span class="o">=</span> <span class="n">up</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=-</span><span class="bp">self</span><span class="o">.</span><span class="n">limit</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">limit</span><span class="p">)</span> | |
| <span class="n">glu</span> <span class="o">=</span> <span class="n">gate</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">gate</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">alpha</span><span class="p">)</span> | |
| <span class="n">gated_output</span> <span class="o">=</span> <span class="p">(</span><span class="n">up</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">glu</span> | |
| <span class="n">out</span> <span class="o">=</span> <span class="n">gated_output</span> <span class="o">@</span> <span class="bp">self</span><span class="o">.</span><span class="n">down_proj</span><span class="p">[</span><span class="n">expert_idx</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">down_proj_bias</span><span class="p">[</span><span class="n">expert_idx</span><span class="p">]</span> | |
| <span class="n">weighted_output</span> <span class="o">=</span> <span class="n">out</span> <span class="o">*</span> <span class="n">routing_weights</span><span class="p">[</span><span class="n">token_idx</span><span class="p">,</span> <span class="n">expert_idx</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span> | |
| <span class="n">next_states</span><span class="o">.</span><span class="n">index_add_</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">token_idx</span><span class="p">,</span> <span class="n">weighted_output</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">hidden_states</span><span class="o">.</span><span class="n">dtype</span><span class="p">))</span> | |
| <span class="n">next_states</span> <span class="o">=</span> <span class="n">next_states</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)</span> | |
| <span class="k">else</span><span class="p">:</span> | |
| <span class="n">hidden_states</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="n">num_experts</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> | |
| <span class="n">hidden_states</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">num_experts</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)</span> | |
| <span class="n">gate_up</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">bmm</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj_bias</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> | |
| <span class="n">gate</span><span class="p">,</span> <span class="n">up</span> <span class="o">=</span> <span class="n">gate_up</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="p">::</span><span class="mi">2</span><span class="p">],</span> <span class="n">gate_up</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> | |
| <span class="n">gate</span> <span class="o">=</span> <span class="n">gate</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">limit</span><span class="p">)</span> | |
| <span class="n">up</span> <span class="o">=</span> <span class="n">up</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=-</span><span class="bp">self</span><span class="o">.</span><span class="n">limit</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">limit</span><span class="p">)</span> | |
| <span class="n">glu</span> <span class="o">=</span> <span class="n">gate</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">gate</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">alpha</span><span class="p">)</span> | |
| <span class="n">next_states</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">bmm</span><span class="p">(((</span><span class="n">up</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">glu</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">down_proj</span><span class="p">)</span> | |
| <span class="n">next_states</span> <span class="o">=</span> <span class="n">next_states</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">down_proj_bias</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> | |
| <span class="n">next_states</span> <span class="o">=</span> <span class="n">next_states</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">num_experts</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)</span> | |
| <span class="n">next_states</span> <span class="o">=</span> <span class="n">next_states</span> <span class="o">*</span> <span class="n">routing_weights</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">num_experts</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span> | |
| <span class="n">next_states</span> <span class="o">=</span> <span class="n">next_states</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> | |
| <span class="k">return</span> <span class="n">next_states</span> | |
| <span class="k">class</span><span class="w"> </span><span class="nc">GptOssMoEMLP</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> | |
| <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">router_weight</span><span class="p">,</span> <span class="n">router_bias</span><span class="p">,</span> <span class="n">gate_up_proj</span><span class="p">,</span> <span class="n">gate_up_proj_bias</span><span class="p">,</span> <span class="n">down_proj</span><span class="p">,</span> <span class="n">down_proj_bias</span><span class="p">):</span> | |
| <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">router</span> <span class="o">=</span> <span class="n">GptOssRouter</span><span class="p">(</span><span class="n">router_weight</span><span class="p">,</span> <span class="n">router_bias</span><span class="p">)</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">experts</span> <span class="o">=</span> <span class="n">GptOssExperts</span><span class="p">(</span><span class="n">gate_up_proj</span><span class="p">,</span> <span class="n">gate_up_proj_bias</span><span class="p">,</span> <span class="n">down_proj</span><span class="p">,</span> <span class="n">down_proj_bias</span><span class="p">)</span> | |
| <span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">):</span> | |
| <span class="n">router_scores</span><span class="p">,</span> <span class="n">router_indices</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">router</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">)</span> | |
| <span class="n">routed_out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">experts</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="n">router_indices</span><span class="o">=</span><span class="n">router_indices</span><span class="p">,</span> <span class="n">routing_weights</span><span class="o">=</span><span class="n">router_scores</span><span class="p">)</span> | |
| <span class="k">return</span> <span class="n">routed_out</span><span class="p">,</span> <span class="n">router_scores</span> | |
| <span class="c1"># Run the model</span> | |
| <span class="n">set_seed</span><span class="p">(</span><span class="n">GENERAL_SEED</span><span class="p">)</span> | |
| <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="n">DEVICE</span><span class="p">)</span> | |
| <span class="n">dtype</span> <span class="o">=</span> <span class="n">to_dtype</span><span class="p">(</span><span class="n">DTYPE</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="s2">"</span><span class="se">\n</span><span class="s2">=== GPT-OSS Implementation ==="</span><span class="p">)</span> | |
| <span class="c1"># Initialize model with loaded weights</span> | |
| <span class="n">model</span> <span class="o">=</span> <span class="n">GptOssMoEMLP</span><span class="p">(</span> | |
| <span class="n">router_weight</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> | |
| <span class="n">router_bias</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> | |
| <span class="n">gate_up_proj</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> | |
| <span class="n">gate_up_proj_bias</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> | |
| <span class="n">down_proj</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> | |
| <span class="n">down_proj_bias</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> | |
| <span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Router weight sum: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">router</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Gate/up proj sum: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">experts</span><span class="o">.</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Down proj sum: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">experts</span><span class="o">.</span><span class="n">down_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="c1"># Generate the same input as other implementations</span> | |
| <span class="n">set_seed</span><span class="p">(</span><span class="n">INPUT_SEED</span><span class="p">)</span> | |
| <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">SEQ_LEN</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.1</span> | |
| <span class="c1"># Benchmark the model</span> | |
| <span class="n">tokens</span> <span class="o">=</span> <span class="n">BATCH_SIZE</span> <span class="o">*</span> <span class="n">SEQ_LEN</span> | |
| <span class="k">with</span> <span class="n">bench_context</span><span class="p">(</span><span class="n">warmup</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">iters</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">tokens</span><span class="o">=</span><span class="n">tokens</span><span class="p">,</span> <span class="n">save_json</span><span class="o">=</span><span class="s2">"gptoss_results.json"</span><span class="p">)</span> <span class="k">as</span> <span class="n">bench</span><span class="p">:</span> | |
| <span class="n">output</span><span class="p">,</span> <span class="n">stats</span> <span class="o">=</span> <span class="n">bench</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="se">\n</span><span class="s2">Output sum: </span><span class="si">{</span><span class="n">output</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| </pre></div></td></tr></table></div> | |
| </div> | |
| <div id="output-gptoss_run" class="cell-output"> | |
| <div class="cell-stdout">Loaded shared weights from artifacts | |
| Router weight sum: 12.588732 | |
| Gate/up sum: 1026.601807 | |
| Down sum: 206.729263 | |
| === GPT-OSS Implementation === | |
| Router weight sum: 12.588732 | |
| Gate/up proj sum: 1026.601807 | |
| Down proj sum: 206.729340 | |
| ┌─ Benchmark Configuration ─────────────────────────────┐ | |
| │ Warmup: 10 Iters: 50 │ | |
| │ Tokens: 100 │ | |
| └────────────────────────────────────────────────────────┘ | |
| Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=0.000016, std=0.099892, norm=33.904163 | |
| Warming up (10 iterations)... | |
| Benchmarking (50 iterations)... | |
| Progress: 20% complete (avg: 46.973 ms) | |
| Progress: 40% complete (avg: 47.262 ms) | |
| Progress: 60% complete (avg: 47.067 ms) | |
| Progress: 80% complete (avg: 46.985 ms) | |
| Output tensors: | |
| Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.050624, 0.050640], mean=-0.000005, std=0.011573, norm=3.928071 | |
| Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.294473], mean=0.007812, std=0.043541, norm=5.004632 | |
| ━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━ | |
| Iterations: 50 | |
| Latency Statistics: | |
| Average: 47.135 ms | |
| Min: 46.582 ms | |
| Max: 47.895 ms | |
| Std Dev: 0.503 ms | |
| Percentiles: | |
| P50 (median): 46.789 ms | |
| P95: 47.801 ms | |
| P99: 47.856 ms | |
| Throughput: | |
| Tokens/sec: 2121.6 | |
| Std Dev: 22.5 | |
| ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ | |
| Saved benchmark results to gptoss_results.json | |
| Output sum: -0.597250 | |
| </div> | |
| <div class="cell-stderr">Downloading nvidia-cufile-cu12 (1.1MiB) | |
| Downloading setuptools (1.1MiB) | |
| Downloading sympy (6.0MiB) | |
| Downloading nvidia-cusparse-cu12 (274.9MiB) | |
| Downloading nvidia-cusparselt-cu12 (273.9MiB) | |
| Downloading nvidia-cusolver-cu12 (255.1MiB) | |
| Downloading nvidia-cufft-cu12 (184.2MiB) | |
| Downloading nvidia-cuda-cupti-cu12 (9.8MiB) | |
| Downloading nvidia-nccl-cu12 (307.4MiB) | |
| Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB) | |
| Downloading nvidia-curand-cu12 (60.7MiB) | |
| Downloading nvidia-cublas-cu12 (566.8MiB) | |
| Downloading nvidia-cudnn-cu12 (674.0MiB) | |
| Downloading numpy (15.9MiB) | |
| Downloading networkx (1.9MiB) | |
| Downloading torch (846.8MiB) | |
| Downloading nvidia-nvjitlink-cu12 (37.4MiB) | |
| Downloading triton (148.4MiB) | |
| Downloading nvidia-cufile-cu12 | |
| Downloading setuptools | |
| Downloading networkx | |
| Downloading nvidia-cuda-cupti-cu12 | |
| Downloading numpy | |
| Downloading sympy | |
| Downloading nvidia-nvjitlink-cu12 | |
| Downloading nvidia-curand-cu12 | |
| Downloading nvidia-cuda-nvrtc-cu12 | |
| Downloading triton | |
| Downloading nvidia-cufft-cu12 | |
| Downloading nvidia-cusolver-cu12 | |
| Downloading nvidia-cusparse-cu12 | |
| Downloading nvidia-cusparselt-cu12 | |
| Downloading nvidia-nccl-cu12 | |
| Downloading nvidia-cublas-cu12 | |
| Downloading nvidia-cudnn-cu12 | |
| Downloading torch | |
| Installed 26 packages in 241ms | |
| </div> | |
| <div class="cell-artifacts"> | |
| <h4>Artifacts:</h4> | |
| <a href="artifacts/gptoss_run/gptoss_results.json" class="artifact" target="_blank">gptoss_results.json</a> | |
| </div> | |
| </div> | |
| </div> | |
| <h2>GPT-OSS Implementation (Training Mode)</h2> | |
| <p>This section runs the GPT-OSS MoE implementation with training mode enabled, which forces the per-expert loop path (index_add_ scatter per expert) instead of the batched <code>torch.bmm</code> path used in inference mode.</p> | |
| <div class="cell"> | |
| <div class="cell-header"> | |
| <span class="collapse-indicators"> | |
| <span onclick="toggleCode('gptoss_training_run')" style="cursor: pointer;">▼ code</span> | |
| <span onclick="toggleOutput('gptoss_training_run')" style="cursor: pointer;">▼ output</span> | |
| </span> | | |
| Cell: gptoss_training_run | deps: torch, numpy | 36.75s | |
| | <button class="run-btn" onclick="runCell('gptoss_training_run')">▶ run</button> | |
| <button class="copy-btn" onclick="copyCell('gptoss_training_run')">Copy</button> | |
| </div> | |
| <div id="code-gptoss_training_run" class="cell-code"> | |
| <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span> | |
| <span class="normal"> 2</span> | |
| <span class="normal"> 3</span> | |
| <span class="normal"> 4</span> | |
| <span class="normal"> 5</span> | |
| <span class="normal"> 6</span> | |
| <span class="normal"> 7</span> | |
| <span class="normal"> 8</span> | |
| <span class="normal"> 9</span> | |
| <span class="normal"> 10</span> | |
| <span class="normal"> 11</span> | |
| <span class="normal"> 12</span> | |
| <span class="normal"> 13</span> | |
| <span class="normal"> 14</span> | |
| <span class="normal"> 15</span> | |
| <span class="normal"> 16</span> | |
| <span class="normal"> 17</span> | |
| <span class="normal"> 18</span> | |
| <span class="normal"> 19</span> | |
| <span class="normal"> 20</span> | |
| <span class="normal"> 21</span> | |
| <span class="normal"> 22</span> | |
| <span class="normal"> 23</span> | |
| <span class="normal"> 24</span> | |
| <span class="normal"> 25</span> | |
| <span class="normal"> 26</span> | |
| <span class="normal"> 27</span> | |
| <span class="normal"> 28</span> | |
| <span class="normal"> 29</span> | |
| <span class="normal"> 30</span> | |
| <span class="normal"> 31</span> | |
| <span class="normal"> 32</span> | |
| <span class="normal"> 33</span> | |
| <span class="normal"> 34</span> | |
| <span class="normal"> 35</span> | |
| <span class="normal"> 36</span> | |
| <span class="normal"> 37</span> | |
| <span class="normal"> 38</span> | |
| <span class="normal"> 39</span> | |
| <span class="normal"> 40</span> | |
| <span class="normal"> 41</span> | |
| <span class="normal"> 42</span> | |
| <span class="normal"> 43</span> | |
| <span class="normal"> 44</span> | |
| <span class="normal"> 45</span> | |
| <span class="normal"> 46</span> | |
| <span class="normal"> 47</span> | |
| <span class="normal"> 48</span> | |
| <span class="normal"> 49</span> | |
| <span class="normal"> 50</span> | |
| <span class="normal"> 51</span> | |
| <span class="normal"> 52</span> | |
| <span class="normal"> 53</span> | |
| <span class="normal"> 54</span> | |
| <span class="normal"> 55</span> | |
| <span class="normal"> 56</span> | |
| <span class="normal"> 57</span> | |
| <span class="normal"> 58</span> | |
| <span class="normal"> 59</span> | |
| <span class="normal"> 60</span> | |
| <span class="normal"> 61</span> | |
| <span class="normal"> 62</span> | |
| <span class="normal"> 63</span> | |
| <span class="normal"> 64</span> | |
| <span class="normal"> 65</span> | |
| <span class="normal"> 66</span> | |
| <span class="normal"> 67</span> | |
| <span class="normal"> 68</span> | |
| <span class="normal"> 69</span> | |
| <span class="normal"> 70</span> | |
| <span class="normal"> 71</span> | |
| <span class="normal"> 72</span> | |
| <span class="normal"> 73</span> | |
| <span class="normal"> 74</span> | |
| <span class="normal"> 75</span> | |
| <span class="normal"> 76</span> | |
| <span class="normal"> 77</span> | |
| <span class="normal"> 78</span> | |
| <span class="normal"> 79</span> | |
| <span class="normal"> 80</span> | |
| <span class="normal"> 81</span> | |
| <span class="normal"> 82</span> | |
| <span class="normal"> 83</span> | |
| <span class="normal"> 84</span> | |
| <span class="normal"> 85</span> | |
| <span class="normal"> 86</span> | |
| <span class="normal"> 87</span> | |
| <span class="normal"> 88</span> | |
| <span class="normal"> 89</span> | |
| <span class="normal"> 90</span> | |
| <span class="normal"> 91</span> | |
| <span class="normal"> 92</span> | |
| <span class="normal"> 93</span> | |
| <span class="normal"> 94</span> | |
| <span class="normal"> 95</span> | |
| <span class="normal"> 96</span> | |
| <span class="normal"> 97</span> | |
| <span class="normal"> 98</span> | |
| <span class="normal"> 99</span> | |
| <span class="normal">100</span> | |
| <span class="normal">101</span> | |
| <span class="normal">102</span> | |
| <span class="normal">103</span> | |
| <span class="normal">104</span> | |
| <span class="normal">105</span> | |
| <span class="normal">106</span> | |
| <span class="normal">107</span> | |
| <span class="normal">108</span> | |
| <span class="normal">109</span> | |
| <span class="normal">110</span> | |
| <span class="normal">111</span> | |
| <span class="normal">112</span> | |
| <span class="normal">113</span> | |
| <span class="normal">114</span> | |
| <span class="normal">115</span> | |
| <span class="normal">116</span> | |
| <span class="normal">117</span> | |
| <span class="normal">118</span> | |
| <span class="normal">119</span> | |
| <span class="normal">120</span> | |
| <span class="normal">121</span> | |
| <span class="normal">122</span> | |
| <span class="normal">123</span> | |
| <span class="normal">124</span> | |
| <span class="normal">125</span> | |
| <span class="normal">126</span> | |
| <span class="normal">127</span> | |
| <span class="normal">128</span> | |
| <span class="normal">129</span> | |
| <span class="normal">130</span> | |
| <span class="normal">131</span></pre></div></td><td class="code"><div><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">torch</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">torch</span><span class="w"> </span><span class="kn">import</span> <span class="n">nn</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">torch.nn</span><span class="w"> </span><span class="kn">import</span> <span class="n">functional</span> <span class="k">as</span> <span class="n">F</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">bench_utils</span><span class="w"> </span><span class="kn">import</span> <span class="n">to_dtype</span><span class="p">,</span> <span class="n">tensor_stats</span><span class="p">,</span> <span class="n">set_seed</span><span class="p">,</span> <span class="n">bench_context</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">config</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span> | |
| <span class="n">NUM_EXPERTS</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">TOP_K</span><span class="p">,</span> | |
| <span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">SEQ_LEN</span><span class="p">,</span> <span class="n">DTYPE</span><span class="p">,</span> <span class="n">DEVICE</span><span class="p">,</span> | |
| <span class="n">WEIGHT_SEED</span><span class="p">,</span> <span class="n">EXPERT_SEED</span><span class="p">,</span> <span class="n">INPUT_SEED</span><span class="p">,</span> <span class="n">GENERAL_SEED</span> | |
| <span class="p">)</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">pathlib</span><span class="w"> </span><span class="kn">import</span> <span class="n">Path</span> | |
| <span class="kn">import</span><span class="w"> </span><span class="nn">os</span> | |
| <span class="c1"># Discover the upstream artifact directory from env</span> | |
| <span class="n">data_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'UVNOTE_INPUT_SAVE_DATA'</span><span class="p">,</span> <span class="s1">'.'</span><span class="p">)</span> | |
| <span class="n">router_weight</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'router_weight.pt'</span><span class="p">)</span> | |
| <span class="n">router_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'router_bias.pt'</span><span class="p">)</span> | |
| <span class="n">gate_up_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'gate_up_proj.pt'</span><span class="p">)</span> | |
| <span class="n">gate_up_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'gate_up_proj_bias.pt'</span><span class="p">)</span> | |
| <span class="n">down_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'down_proj.pt'</span><span class="p">)</span> | |
| <span class="n">down_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'down_proj_bias.pt'</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="s2">"Loaded shared weights from artifacts"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Router weight sum: </span><span class="si">{</span><span class="n">router_weight</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Gate/up sum: </span><span class="si">{</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Down sum: </span><span class="si">{</span><span class="n">down_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="k">class</span><span class="w"> </span><span class="nc">GptOssTrainingRouter</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> | |
| <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">router_weight</span><span class="p">,</span> <span class="n">router_bias</span><span class="p">):</span> | |
| <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">top_k</span> <span class="o">=</span> <span class="n">TOP_K</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span> <span class="o">=</span> <span class="n">NUM_EXPERTS</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">hidden_dim</span> <span class="o">=</span> <span class="n">HIDDEN_SIZE</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">router_weight</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">router_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span> | |
| <span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">):</span> | |
| <span class="n">hidden_states</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_dim</span><span class="p">)</span> | |
| <span class="n">router_logits</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">)</span> | |
| <span class="n">router_top_value</span><span class="p">,</span> <span class="n">router_indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">topk</span><span class="p">(</span><span class="n">router_logits</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">top_k</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> | |
| <span class="n">router_top_value</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">router_top_value</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">router_top_value</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span> | |
| <span class="n">router_scores</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">router_logits</span><span class="p">)</span><span class="o">.</span><span class="n">scatter_</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">router_indices</span><span class="p">,</span> <span class="n">router_top_value</span><span class="p">)</span> | |
| <span class="k">return</span> <span class="n">router_scores</span><span class="p">,</span> <span class="n">router_indices</span> | |
| <span class="k">class</span><span class="w"> </span><span class="nc">GptOssTrainingExperts</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> | |
| <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">gate_up_proj</span><span class="p">,</span> <span class="n">gate_up_proj_bias</span><span class="p">,</span> <span class="n">down_proj</span><span class="p">,</span> <span class="n">down_proj_bias</span><span class="p">):</span> | |
| <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span> <span class="o">=</span> <span class="n">NUM_EXPERTS</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">=</span> <span class="n">HIDDEN_SIZE</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">expert_dim</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj_bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">gate_up_proj_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">down_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">down_proj</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">down_proj_bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">down_proj_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="mf">1.702</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">limit</span> <span class="o">=</span> <span class="mf">7.0</span> | |
| <span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">router_indices</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">routing_weights</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span> | |
| <span class="n">batch_size</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> | |
| <span class="n">hidden_states</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)</span> | |
| <span class="n">num_experts</span> <span class="o">=</span> <span class="n">routing_weights</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> | |
| <span class="c1"># Force training mode path (expert loop instead of batched)</span> | |
| <span class="n">next_states</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">hidden_states</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">hidden_states</span><span class="o">.</span><span class="n">device</span><span class="p">)</span> | |
| <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span> | |
| <span class="n">expert_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">router_indices</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="n">num_experts</span><span class="p">)</span> | |
| <span class="n">expert_mask</span> <span class="o">=</span> <span class="n">expert_mask</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> | |
| <span class="n">expert_hit</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">greater</span><span class="p">(</span><span class="n">expert_mask</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">2</span><span class="p">)),</span> <span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">nonzero</span><span class="p">()</span> | |
| <span class="k">for</span> <span class="n">expert_idx</span> <span class="ow">in</span> <span class="n">expert_hit</span><span class="p">[:]:</span> | |
| <span class="n">expert_idx</span> <span class="o">=</span> <span class="n">expert_idx</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> | |
| <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span> | |
| <span class="n">_</span><span class="p">,</span> <span class="n">token_idx</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">expert_mask</span><span class="p">[</span><span class="n">expert_idx</span><span class="p">])</span> | |
| <span class="n">current_state</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="p">[</span><span class="n">token_idx</span><span class="p">]</span> | |
| <span class="n">gate_up</span> <span class="o">=</span> <span class="n">current_state</span> <span class="o">@</span> <span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj</span><span class="p">[</span><span class="n">expert_idx</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj_bias</span><span class="p">[</span><span class="n">expert_idx</span><span class="p">]</span> | |
| <span class="n">gate</span><span class="p">,</span> <span class="n">up</span> <span class="o">=</span> <span class="n">gate_up</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="p">::</span><span class="mi">2</span><span class="p">],</span> <span class="n">gate_up</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> | |
| <span class="n">gate</span> <span class="o">=</span> <span class="n">gate</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">limit</span><span class="p">)</span> | |
| <span class="n">up</span> <span class="o">=</span> <span class="n">up</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=-</span><span class="bp">self</span><span class="o">.</span><span class="n">limit</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">limit</span><span class="p">)</span> | |
| <span class="n">glu</span> <span class="o">=</span> <span class="n">gate</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">gate</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">alpha</span><span class="p">)</span> | |
| <span class="n">gated_output</span> <span class="o">=</span> <span class="p">(</span><span class="n">up</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">glu</span> | |
| <span class="n">out</span> <span class="o">=</span> <span class="n">gated_output</span> <span class="o">@</span> <span class="bp">self</span><span class="o">.</span><span class="n">down_proj</span><span class="p">[</span><span class="n">expert_idx</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">down_proj_bias</span><span class="p">[</span><span class="n">expert_idx</span><span class="p">]</span> | |
| <span class="n">weighted_output</span> <span class="o">=</span> <span class="n">out</span> <span class="o">*</span> <span class="n">routing_weights</span><span class="p">[</span><span class="n">token_idx</span><span class="p">,</span> <span class="n">expert_idx</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span> | |
| <span class="n">next_states</span><span class="o">.</span><span class="n">index_add_</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">token_idx</span><span class="p">,</span> <span class="n">weighted_output</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">hidden_states</span><span class="o">.</span><span class="n">dtype</span><span class="p">))</span> | |
| <span class="n">next_states</span> <span class="o">=</span> <span class="n">next_states</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)</span> | |
| <span class="k">return</span> <span class="n">next_states</span> | |
| <span class="k">class</span><span class="w"> </span><span class="nc">GptOssTrainingMoEMLP</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> | |
| <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">router_weight</span><span class="p">,</span> <span class="n">router_bias</span><span class="p">,</span> <span class="n">gate_up_proj</span><span class="p">,</span> <span class="n">gate_up_proj_bias</span><span class="p">,</span> <span class="n">down_proj</span><span class="p">,</span> <span class="n">down_proj_bias</span><span class="p">):</span> | |
| <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">router</span> <span class="o">=</span> <span class="n">GptOssTrainingRouter</span><span class="p">(</span><span class="n">router_weight</span><span class="p">,</span> <span class="n">router_bias</span><span class="p">)</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">experts</span> <span class="o">=</span> <span class="n">GptOssTrainingExperts</span><span class="p">(</span><span class="n">gate_up_proj</span><span class="p">,</span> <span class="n">gate_up_proj_bias</span><span class="p">,</span> <span class="n">down_proj</span><span class="p">,</span> <span class="n">down_proj_bias</span><span class="p">)</span> | |
| <span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">):</span> | |
| <span class="n">router_scores</span><span class="p">,</span> <span class="n">router_indices</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">router</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">)</span> | |
| <span class="n">routed_out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">experts</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="n">router_indices</span><span class="o">=</span><span class="n">router_indices</span><span class="p">,</span> <span class="n">routing_weights</span><span class="o">=</span><span class="n">router_scores</span><span class="p">)</span> | |
| <span class="k">return</span> <span class="n">routed_out</span><span class="p">,</span> <span class="n">router_scores</span> | |
| <span class="c1"># Run the model</span> | |
| <span class="n">set_seed</span><span class="p">(</span><span class="n">GENERAL_SEED</span><span class="p">)</span> | |
| <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="n">DEVICE</span><span class="p">)</span> | |
| <span class="n">dtype</span> <span class="o">=</span> <span class="n">to_dtype</span><span class="p">(</span><span class="n">DTYPE</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="s2">"</span><span class="se">\n</span><span class="s2">=== GPT-OSS Implementation (Training Mode - Expert Loop) ==="</span><span class="p">)</span> | |
| <span class="c1"># Initialize model with loaded weights and force training mode</span> | |
| <span class="n">model</span> <span class="o">=</span> <span class="n">GptOssTrainingMoEMLP</span><span class="p">(</span> | |
| <span class="n">router_weight</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> | |
| <span class="n">router_bias</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> | |
| <span class="n">gate_up_proj</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> | |
| <span class="n">gate_up_proj_bias</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> | |
| <span class="n">down_proj</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> | |
| <span class="n">down_proj_bias</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> | |
| <span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span> | |
| <span class="c1"># Set to training mode to force expert loop path</span> | |
| <span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">()</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Router weight sum: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">router</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Gate/up proj sum: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">experts</span><span class="o">.</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Down proj sum: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">experts</span><span class="o">.</span><span class="n">down_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Model training mode: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">training</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="c1"># Generate the same input as other implementations</span> | |
| <span class="n">set_seed</span><span class="p">(</span><span class="n">INPUT_SEED</span><span class="p">)</span> | |
| <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">SEQ_LEN</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.1</span> | |
| <span class="c1"># Benchmark the model</span> | |
| <span class="n">tokens</span> <span class="o">=</span> <span class="n">BATCH_SIZE</span> <span class="o">*</span> <span class="n">SEQ_LEN</span> | |
| <span class="k">with</span> <span class="n">bench_context</span><span class="p">(</span><span class="n">warmup</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">iters</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">tokens</span><span class="o">=</span><span class="n">tokens</span><span class="p">,</span> <span class="n">save_json</span><span class="o">=</span><span class="s2">"gptoss_training_results.json"</span><span class="p">)</span> <span class="k">as</span> <span class="n">bench</span><span class="p">:</span> | |
| <span class="n">output</span><span class="p">,</span> <span class="n">stats</span> <span class="o">=</span> <span class="n">bench</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="se">\n</span><span class="s2">Output sum: </span><span class="si">{</span><span class="n">output</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| </pre></div></td></tr></table></div> | |
| </div> | |
| <div id="output-gptoss_training_run" class="cell-output"> | |
| <div class="cell-stdout">Loaded shared weights from artifacts | |
| Router weight sum: 12.588732 | |
| Gate/up sum: 1026.601807 | |
| Down sum: 206.729263 | |
| === GPT-OSS Implementation (Training Mode - Expert Loop) === | |
| Router weight sum: 12.588732 | |
| Gate/up proj sum: 1026.601807 | |
| Down proj sum: 206.729340 | |
| Model training mode: True | |
| ┌─ Benchmark Configuration ─────────────────────────────┐ | |
| │ Warmup: 10 Iters: 50 │ | |
| │ Tokens: 100 │ | |
| └────────────────────────────────────────────────────────┘ | |
| Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=0.000016, std=0.099892, norm=33.904163 | |
| Warming up (10 iterations)... | |
| Benchmarking (50 iterations)... | |
| Progress: 20% complete (avg: 48.328 ms) | |
| Progress: 40% complete (avg: 48.764 ms) | |
| Progress: 60% complete (avg: 48.825 ms) | |
| Progress: 80% complete (avg: 48.769 ms) | |
| Output tensors: | |
| Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.050624, 0.050640], mean=-0.000005, std=0.011573, norm=3.928071 | |
| Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.294473], mean=0.007812, std=0.043541, norm=5.004632 | |
| ━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━ | |
| Iterations: 50 | |
| Latency Statistics: | |
| Average: 48.630 ms | |
| Min: 47.535 ms | |
| Max: 49.414 ms | |
| Std Dev: 0.559 ms | |
| Percentiles: | |
| P50 (median): 48.395 ms | |
| P95: 49.346 ms | |
| P99: 49.390 ms | |
| Throughput: | |
| Tokens/sec: 2056.3 | |
| Std Dev: 23.6 | |
| ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ | |
| Saved benchmark results to gptoss_training_results.json | |
| Output sum: -0.597250 | |
| </div> | |
| <div class="cell-stderr">Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB) | |
| Downloading numpy (15.9MiB) | |
| Downloading networkx (1.9MiB) | |
| Downloading nvidia-cufft-cu12 (184.2MiB) | |
| Downloading nvidia-cufile-cu12 (1.1MiB) | |
| Downloading sympy (6.0MiB) | |
| Downloading nvidia-nvjitlink-cu12 (37.4MiB) | |
| Downloading nvidia-cusolver-cu12 (255.1MiB) | |
| Downloading nvidia-nccl-cu12 (307.4MiB) | |
| Downloading nvidia-cusparse-cu12 (274.9MiB) | |
| Downloading nvidia-cusparselt-cu12 (273.9MiB) | |
| Downloading nvidia-curand-cu12 (60.7MiB) | |
| Downloading setuptools (1.1MiB) | |
| Downloading nvidia-cuda-cupti-cu12 (9.8MiB) | |
| Downloading nvidia-cublas-cu12 (566.8MiB) | |
| Downloading nvidia-cudnn-cu12 (674.0MiB) | |
| Downloading triton (148.4MiB) | |
| Downloading torch (846.8MiB) | |
| Downloading nvidia-cufile-cu12 | |
| Downloading setuptools | |
| Downloading networkx | |
| Downloading nvidia-cuda-cupti-cu12 | |
| Downloading numpy | |
| Downloading nvidia-nvjitlink-cu12 | |
| Downloading sympy | |
| Downloading nvidia-curand-cu12 | |
| Downloading nvidia-cuda-nvrtc-cu12 | |
| Downloading triton | |
| Downloading nvidia-cufft-cu12 | |
| Downloading nvidia-cusolver-cu12 | |
| Downloading nvidia-cusparse-cu12 | |
| Downloading nvidia-cusparselt-cu12 | |
| Downloading nvidia-nccl-cu12 | |
| Downloading nvidia-cublas-cu12 | |
| Downloading nvidia-cudnn-cu12 | |
| Downloading torch | |
| Installed 26 packages in 234ms | |
| </div> | |
| <div class="cell-artifacts"> | |
| <h4>Artifacts:</h4> | |
| <a href="artifacts/gptoss_training_run/gptoss_training_results.json" class="artifact" target="_blank">gptoss_training_results.json</a> | |
| </div> | |
| </div> | |
| </div> | |
| <h2>MegaBlocks Implementation</h2> | |
| <p>This section runs the MegaBlocks MoE implementation with optimized kernels from the Hugging Face hub.</p> | |
| <div class="cell"> | |
| <div class="cell-header"> | |
| <span class="collapse-indicators"> | |
| <span onclick="toggleCode('megablocks_run')" style="cursor: pointer;">▼ code</span> | |
| <span onclick="toggleOutput('megablocks_run')" style="cursor: pointer;">▼ output</span> | |
| </span> | | |
| Cell: megablocks_run | deps: torch, numpy, kernels | 43.51s | |
| | <button class="run-btn" onclick="runCell('megablocks_run')">▶ run</button> | |
| <button class="copy-btn" onclick="copyCell('megablocks_run')">Copy</button> | |
| </div> | |
| <div id="code-megablocks_run" class="cell-code"> | |
| <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span> | |
| <span class="normal"> 2</span> | |
| <span class="normal"> 3</span> | |
| <span class="normal"> 4</span> | |
| <span class="normal"> 5</span> | |
| <span class="normal"> 6</span> | |
| <span class="normal"> 7</span> | |
| <span class="normal"> 8</span> | |
| <span class="normal"> 9</span> | |
| <span class="normal"> 10</span> | |
| <span class="normal"> 11</span> | |
| <span class="normal"> 12</span> | |
| <span class="normal"> 13</span> | |
| <span class="normal"> 14</span> | |
| <span class="normal"> 15</span> | |
| <span class="normal"> 16</span> | |
| <span class="normal"> 17</span> | |
| <span class="normal"> 18</span> | |
| <span class="normal"> 19</span> | |
| <span class="normal"> 20</span> | |
| <span class="normal"> 21</span> | |
| <span class="normal"> 22</span> | |
| <span class="normal"> 23</span> | |
| <span class="normal"> 24</span> | |
| <span class="normal"> 25</span> | |
| <span class="normal"> 26</span> | |
| <span class="normal"> 27</span> | |
| <span class="normal"> 28</span> | |
| <span class="normal"> 29</span> | |
| <span class="normal"> 30</span> | |
| <span class="normal"> 31</span> | |
| <span class="normal"> 32</span> | |
| <span class="normal"> 33</span> | |
| <span class="normal"> 34</span> | |
| <span class="normal"> 35</span> | |
| <span class="normal"> 36</span> | |
| <span class="normal"> 37</span> | |
| <span class="normal"> 38</span> | |
| <span class="normal"> 39</span> | |
| <span class="normal"> 40</span> | |
| <span class="normal"> 41</span> | |
| <span class="normal"> 42</span> | |
| <span class="normal"> 43</span> | |
| <span class="normal"> 44</span> | |
| <span class="normal"> 45</span> | |
| <span class="normal"> 46</span> | |
| <span class="normal"> 47</span> | |
| <span class="normal"> 48</span> | |
| <span class="normal"> 49</span> | |
| <span class="normal"> 50</span> | |
| <span class="normal"> 51</span> | |
| <span class="normal"> 52</span> | |
| <span class="normal"> 53</span> | |
| <span class="normal"> 54</span> | |
| <span class="normal"> 55</span> | |
| <span class="normal"> 56</span> | |
| <span class="normal"> 57</span> | |
| <span class="normal"> 58</span> | |
| <span class="normal"> 59</span> | |
| <span class="normal"> 60</span> | |
| <span class="normal"> 61</span> | |
| <span class="normal"> 62</span> | |
| <span class="normal"> 63</span> | |
| <span class="normal"> 64</span> | |
| <span class="normal"> 65</span> | |
| <span class="normal"> 66</span> | |
| <span class="normal"> 67</span> | |
| <span class="normal"> 68</span> | |
| <span class="normal"> 69</span> | |
| <span class="normal"> 70</span> | |
| <span class="normal"> 71</span> | |
| <span class="normal"> 72</span> | |
| <span class="normal"> 73</span> | |
| <span class="normal"> 74</span> | |
| <span class="normal"> 75</span> | |
| <span class="normal"> 76</span> | |
| <span class="normal"> 77</span> | |
| <span class="normal"> 78</span> | |
| <span class="normal"> 79</span> | |
| <span class="normal"> 80</span> | |
| <span class="normal"> 81</span> | |
| <span class="normal"> 82</span> | |
| <span class="normal"> 83</span> | |
| <span class="normal"> 84</span> | |
| <span class="normal"> 85</span> | |
| <span class="normal"> 86</span> | |
| <span class="normal"> 87</span> | |
| <span class="normal"> 88</span> | |
| <span class="normal"> 89</span> | |
| <span class="normal"> 90</span> | |
| <span class="normal"> 91</span> | |
| <span class="normal"> 92</span> | |
| <span class="normal"> 93</span> | |
| <span class="normal"> 94</span> | |
| <span class="normal"> 95</span> | |
| <span class="normal"> 96</span> | |
| <span class="normal"> 97</span> | |
| <span class="normal"> 98</span> | |
| <span class="normal"> 99</span> | |
| <span class="normal">100</span> | |
| <span class="normal">101</span> | |
| <span class="normal">102</span> | |
| <span class="normal">103</span> | |
| <span class="normal">104</span> | |
| <span class="normal">105</span> | |
| <span class="normal">106</span></pre></div></td><td class="code"><div><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">torch</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">torch</span><span class="w"> </span><span class="kn">import</span> <span class="n">nn</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">torch.nn</span><span class="w"> </span><span class="kn">import</span> <span class="n">functional</span> <span class="k">as</span> <span class="n">F</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">kernels</span><span class="w"> </span><span class="kn">import</span> <span class="n">get_kernel</span><span class="p">,</span> <span class="n">get_local_kernel</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">bench_utils</span><span class="w"> </span><span class="kn">import</span> <span class="n">to_dtype</span><span class="p">,</span> <span class="n">tensor_stats</span><span class="p">,</span> <span class="n">set_seed</span><span class="p">,</span> <span class="n">bench_context</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">config</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span> | |
| <span class="n">NUM_EXPERTS</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">TOP_K</span><span class="p">,</span> | |
| <span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">SEQ_LEN</span><span class="p">,</span> <span class="n">DTYPE</span><span class="p">,</span> <span class="n">DEVICE</span><span class="p">,</span> | |
| <span class="n">WEIGHT_SEED</span><span class="p">,</span> <span class="n">EXPERT_SEED</span><span class="p">,</span> <span class="n">INPUT_SEED</span><span class="p">,</span> <span class="n">GENERAL_SEED</span> | |
| <span class="p">)</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">pathlib</span><span class="w"> </span><span class="kn">import</span> <span class="n">Path</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">collections</span><span class="w"> </span><span class="kn">import</span> <span class="n">namedtuple</span> | |
| <span class="kn">import</span><span class="w"> </span><span class="nn">os</span> | |
| <span class="c1"># Discover the upstream artifact directory from env</span> | |
| <span class="n">data_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'UVNOTE_INPUT_SAVE_DATA'</span><span class="p">,</span> <span class="s1">'.'</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Loading weights from: </span><span class="si">{</span><span class="n">data_dir</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="n">router_weight</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'router_weight.pt'</span><span class="p">)</span> | |
| <span class="n">router_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'router_bias.pt'</span><span class="p">)</span> | |
| <span class="n">gate_up_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'gate_up_proj.pt'</span><span class="p">)</span> | |
| <span class="n">gate_up_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'gate_up_proj_bias.pt'</span><span class="p">)</span> | |
| <span class="n">down_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'down_proj.pt'</span><span class="p">)</span> | |
| <span class="n">down_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'down_proj_bias.pt'</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="s2">"Loaded shared weights from artifacts"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Router weight sum: </span><span class="si">{</span><span class="n">router_weight</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Gate/up sum: </span><span class="si">{</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Down sum: </span><span class="si">{</span><span class="n">down_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="k">def</span><span class="w"> </span><span class="nf">build_megablocks_model</span><span class="p">(</span><span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">):</span> | |
| <span class="c1"># Download optimized kernels from the Hugging Face hub</span> | |
| <span class="n">megablocks</span> <span class="o">=</span> <span class="n">get_kernel</span><span class="p">(</span><span class="s2">"kernels-community/megablocks"</span><span class="p">)</span> | |
| <span class="c1"># megablocks = get_local_kernel(</span> | |
| <span class="c1"># Path("/home/ubuntu/Projects/megablocks-moe/build"), "megablocks")</span> | |
| <span class="n">model</span> <span class="o">=</span> <span class="n">megablocks</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">MegaBlocksMoeMLP</span><span class="p">()</span> | |
| <span class="c1"># Create attribute container for expert weights</span> | |
| <span class="n">model</span><span class="o">.</span><span class="n">experts</span> <span class="o">=</span> <span class="n">namedtuple</span><span class="p">(</span> | |
| <span class="s2">"Experts"</span><span class="p">,</span> <span class="p">[</span><span class="s2">"gate_up_proj"</span><span class="p">,</span> <span class="s2">"gate_up_proj_bias"</span><span class="p">,</span> <span class="s2">"down_proj"</span><span class="p">,</span> <span class="s2">"down_proj_bias"</span><span class="p">,</span> <span class="s2">"hidden_size"</span><span class="p">]</span> | |
| <span class="p">)</span> | |
| <span class="c1"># Use loaded router weights for consistency</span> | |
| <span class="n">model</span><span class="o">.</span><span class="n">router</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">NUM_EXPERTS</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span> | |
| <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span> | |
| <span class="n">model</span><span class="o">.</span><span class="n">router</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">copy_</span><span class="p">(</span><span class="n">router_weight</span><span class="p">)</span> | |
| <span class="n">model</span><span class="o">.</span><span class="n">router</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">copy_</span><span class="p">(</span><span class="n">router_bias</span><span class="p">)</span> | |
| <span class="c1"># Attach loaded expert weights to the experts container</span> | |
| <span class="n">e</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">experts</span> | |
| <span class="n">e</span><span class="o">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="mf">1.702</span> | |
| <span class="n">e</span><span class="o">.</span><span class="n">capacity_factor</span> <span class="o">=</span> <span class="mi">4</span> | |
| <span class="n">e</span><span class="o">.</span><span class="n">gate_up_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">))</span> | |
| <span class="n">e</span><span class="o">.</span><span class="n">gate_up_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">gate_up_proj_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">))</span> | |
| <span class="n">e</span><span class="o">.</span><span class="n">down_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">down_proj</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">))</span> | |
| <span class="n">e</span><span class="o">.</span><span class="n">down_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">down_proj_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">))</span> | |
| <span class="n">e</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">=</span> <span class="n">HIDDEN_SIZE</span> | |
| <span class="c1"># Log weight statistics for comparison</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"[MegaBlocks] Router weight sum: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">router</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"[MegaBlocks] Gate/up projection shape: </span><span class="si">{</span><span class="nb">tuple</span><span class="p">(</span><span class="n">e</span><span class="o">.</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="si">}</span><span class="s2">, sum: </span><span class="si">{</span><span class="n">e</span><span class="o">.</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"[MegaBlocks] Down projection shape: </span><span class="si">{</span><span class="nb">tuple</span><span class="p">(</span><span class="n">e</span><span class="o">.</span><span class="n">down_proj</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="si">}</span><span class="s2">, sum: </span><span class="si">{</span><span class="n">e</span><span class="o">.</span><span class="n">down_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="k">return</span> <span class="n">model</span> | |
| <span class="c1"># Create a wrapper to match the interface of other implementations</span> | |
| <span class="k">class</span><span class="w"> </span><span class="nc">MegaBlocksMoEWrapper</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> | |
| <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">megablocks_model</span><span class="p">):</span> | |
| <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> | |
| <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">megablocks_model</span> | |
| <span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">):</span> | |
| <span class="c1"># MegaBlocks expects input in the format (batch, seq_len, hidden_dim)</span> | |
| <span class="n">output</span><span class="p">,</span> <span class="n">dummy_routing_weights</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">)</span> | |
| <span class="c1"># Return output and dummy routing weights for consistency with other implementations</span> | |
| <span class="c1"># dummy_routing_weights = torch.zeros(</span> | |
| <span class="c1"># hidden_states.shape[0] * hidden_states.shape[1], </span> | |
| <span class="c1"># NUM_EXPERTS, </span> | |
| <span class="c1"># device=hidden_states.device,</span> | |
| <span class="c1"># dtype=hidden_states.dtype</span> | |
| <span class="c1"># )</span> | |
| <span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">dummy_routing_weights</span> | |
| <span class="c1"># Run the model</span> | |
| <span class="n">set_seed</span><span class="p">(</span><span class="n">GENERAL_SEED</span><span class="p">)</span> | |
| <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="n">DEVICE</span><span class="p">)</span> | |
| <span class="n">dtype</span> <span class="o">=</span> <span class="n">to_dtype</span><span class="p">(</span><span class="n">DTYPE</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="s2">"</span><span class="se">\n</span><span class="s2">=== MegaBlocks Implementation ==="</span><span class="p">)</span> | |
| <span class="c1"># Build MegaBlocks model with loaded weights</span> | |
| <span class="n">megablocks_model</span> <span class="o">=</span> <span class="n">build_megablocks_model</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> | |
| <span class="n">model</span> <span class="o">=</span> <span class="n">MegaBlocksMoEWrapper</span><span class="p">(</span><span class="n">megablocks_model</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span> | |
| <span class="c1"># Generate the same input as other implementations</span> | |
| <span class="n">set_seed</span><span class="p">(</span><span class="n">INPUT_SEED</span><span class="p">)</span> | |
| <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">SEQ_LEN</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.1</span> | |
| <span class="c1"># Benchmark the model</span> | |
| <span class="n">tokens</span> <span class="o">=</span> <span class="n">BATCH_SIZE</span> <span class="o">*</span> <span class="n">SEQ_LEN</span> | |
| <span class="k">with</span> <span class="n">bench_context</span><span class="p">(</span><span class="n">warmup</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">iters</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">tokens</span><span class="o">=</span><span class="n">tokens</span><span class="p">,</span> <span class="n">save_json</span><span class="o">=</span><span class="s2">"megablocks_results.json"</span><span class="p">)</span> <span class="k">as</span> <span class="n">bench</span><span class="p">:</span> | |
| <span class="n">output</span><span class="p">,</span> <span class="n">stats</span> <span class="o">=</span> <span class="n">bench</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="se">\n</span><span class="s2">Output sum: </span><span class="si">{</span><span class="n">output</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| </pre></div></td></tr></table></div> | |
| </div> | |
| <div id="output-megablocks_run" class="cell-output"> | |
| <div class="cell-stdout">Loading weights from: /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/0dc3119d70b6b7e0618fb3e0070aede3d5fc82296ac58f1ab73305d459560b73 | |
| Loaded shared weights from artifacts | |
| Router weight sum: 12.588732 | |
| Gate/up sum: 1026.601807 | |
| Down sum: 206.729263 | |
| === MegaBlocks Implementation === | |
| [MegaBlocks] Router weight sum: 12.588732 | |
| [MegaBlocks] Gate/up projection shape: (128, 1152, 2304), sum: 1026.601807 | |
| [MegaBlocks] Down projection shape: (128, 1152, 1152), sum: 206.729340 | |
| ┌─ Benchmark Configuration ─────────────────────────────┐ | |
| │ Warmup: 10 Iters: 50 │ | |
| │ Tokens: 100 │ | |
| └────────────────────────────────────────────────────────┘ | |
| Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=0.000016, std=0.099892, norm=33.904163 | |
| Warming up (10 iterations)... | |
| Benchmarking (50 iterations)... | |
| Progress: 20% complete (avg: 0.867 ms) | |
| Progress: 40% complete (avg: 0.853 ms) | |
| Progress: 60% complete (avg: 1.181 ms) | |
| Progress: 80% complete (avg: 3.026 ms) | |
| Output tensors: | |
| Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.050624, 0.050640], mean=-0.000005, std=0.011573, norm=3.928071 | |
| Auxiliary: shape=(100, 4), dtype=torch.float32, device=cuda:0, range=[0.220910, 0.294473], mean=0.250000, std=0.010777, norm=5.004632 | |
| ━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━ | |
| Iterations: 50 | |
| Latency Statistics: | |
| Average: 4.133 ms | |
| Min: 0.823 ms | |
| Max: 8.589 ms | |
| Std Dev: 3.781 ms | |
| Percentiles: | |
| P50 (median): 0.864 ms | |
| P95: 8.579 ms | |
| P99: 8.589 ms | |
| Throughput: | |
| Tokens/sec: 24194.9 | |
| Std Dev: 52511.7 | |
| ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ | |
| Saved benchmark results to megablocks_results.json | |
| Output sum: -0.597249 | |
| </div> | |
| <div class="cell-stderr">Downloading setuptools (1.1MiB) | |
| Downloading nvidia-cudnn-cu12 (674.0MiB) | |
| Downloading numpy (15.9MiB) | |
| Downloading nvidia-cusparse-cu12 (274.9MiB) | |
| Downloading nvidia-nvjitlink-cu12 (37.4MiB) | |
| Downloading hf-xet (3.0MiB) | |
| Downloading nvidia-cusolver-cu12 (255.1MiB) | |
| Downloading networkx (1.9MiB) | |
| Downloading nvidia-cufft-cu12 (184.2MiB) | |
| Downloading nvidia-cufile-cu12 (1.1MiB) | |
| Downloading triton (148.4MiB) | |
| Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB) | |
| Downloading nvidia-curand-cu12 (60.7MiB) | |
| Downloading sympy (6.0MiB) | |
| Downloading nvidia-cuda-cupti-cu12 (9.8MiB) | |
| Downloading nvidia-nccl-cu12 (307.4MiB) | |
| Downloading nvidia-cusparselt-cu12 (273.9MiB) | |
| Downloading nvidia-cublas-cu12 (566.8MiB) | |
| Downloading torch (846.8MiB) | |
| Downloading nvidia-cufile-cu12 | |
| Downloading hf-xet | |
| Downloading setuptools | |
| Downloading networkx | |
| Downloading nvidia-cuda-cupti-cu12 | |
| Downloading numpy | |
| Downloading sympy | |
| Downloading nvidia-nvjitlink-cu12 | |
| Downloading nvidia-curand-cu12 | |
| Downloading nvidia-cuda-nvrtc-cu12 | |
| Downloading triton | |
| Downloading nvidia-cufft-cu12 | |
| Downloading nvidia-cusolver-cu12 | |
| Downloading nvidia-cusparse-cu12 | |
| Downloading nvidia-cusparselt-cu12 | |
| Downloading nvidia-nccl-cu12 | |
| Downloading nvidia-cublas-cu12 | |
| Downloading nvidia-cudnn-cu12 | |
| Downloading torch | |
| Installed 37 packages in 216ms | |
| Fetching 66 files: 0%| | 0/66 [00:00<?, ?it/s] | |
| Fetching 66 files: 2%|▏ | 1/66 [00:00<00:22, 2.87it/s] | |
| Fetching 66 files: 26%|██▌ | 17/66 [00:01<00:04, 11.84it/s] | |
| Fetching 66 files: 100%|██████████| 66/66 [00:01<00:00, 43.56it/s] | |
| </div> | |
| <div class="cell-artifacts"> | |
| <h4>Artifacts:</h4> | |
| <a href="artifacts/megablocks_run/megablocks_results.json" class="artifact" target="_blank">megablocks_results.json</a> | |
| </div> | |
| </div> | |
| </div> | |
| <h2>Performance Visualization</h2> | |
| <p>This section reads all benchmark results and creates a comprehensive performance comparison chart.</p> | |
| <div class="cell"> | |
| <div class="cell-header"> | |
| <span class="collapse-indicators"> | |
| <span onclick="toggleCode('visualization')" style="cursor: pointer;">▼ code</span> | |
| <span onclick="toggleOutput('visualization')" style="cursor: pointer;">▼ output</span> | |
| </span> | | |
| Cell: visualization | deps: matplotlib | 3.96s | |
| | <button class="run-btn" onclick="runCell('visualization')">▶ run</button> | |
| <button class="copy-btn" onclick="copyCell('visualization')">Copy</button> | |
| </div> | |
| <div id="code-visualization" class="cell-code"> | |
| <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span> | |
| <span class="normal"> 2</span> | |
| <span class="normal"> 3</span> | |
| <span class="normal"> 4</span> | |
| <span class="normal"> 5</span> | |
| <span class="normal"> 6</span> | |
| <span class="normal"> 7</span> | |
| <span class="normal"> 8</span> | |
| <span class="normal"> 9</span> | |
| <span class="normal"> 10</span> | |
| <span class="normal"> 11</span> | |
| <span class="normal"> 12</span> | |
| <span class="normal"> 13</span> | |
| <span class="normal"> 14</span> | |
| <span class="normal"> 15</span> | |
| <span class="normal"> 16</span> | |
| <span class="normal"> 17</span> | |
| <span class="normal"> 18</span> | |
| <span class="normal"> 19</span> | |
| <span class="normal"> 20</span> | |
| <span class="normal"> 21</span> | |
| <span class="normal"> 22</span> | |
| <span class="normal"> 23</span> | |
| <span class="normal"> 24</span> | |
| <span class="normal"> 25</span> | |
| <span class="normal"> 26</span> | |
| <span class="normal"> 27</span> | |
| <span class="normal"> 28</span> | |
| <span class="normal"> 29</span> | |
| <span class="normal"> 30</span> | |
| <span class="normal"> 31</span> | |
| <span class="normal"> 32</span> | |
| <span class="normal"> 33</span> | |
| <span class="normal"> 34</span> | |
| <span class="normal"> 35</span> | |
| <span class="normal"> 36</span> | |
| <span class="normal"> 37</span> | |
| <span class="normal"> 38</span> | |
| <span class="normal"> 39</span> | |
| <span class="normal"> 40</span> | |
| <span class="normal"> 41</span> | |
| <span class="normal"> 42</span> | |
| <span class="normal"> 43</span> | |
| <span class="normal"> 44</span> | |
| <span class="normal"> 45</span> | |
| <span class="normal"> 46</span> | |
| <span class="normal"> 47</span> | |
| <span class="normal"> 48</span> | |
| <span class="normal"> 49</span> | |
| <span class="normal"> 50</span> | |
| <span class="normal"> 51</span> | |
| <span class="normal"> 52</span> | |
| <span class="normal"> 53</span> | |
| <span class="normal"> 54</span> | |
| <span class="normal"> 55</span> | |
| <span class="normal"> 56</span> | |
| <span class="normal"> 57</span> | |
| <span class="normal"> 58</span> | |
| <span class="normal"> 59</span> | |
| <span class="normal"> 60</span> | |
| <span class="normal"> 61</span> | |
| <span class="normal"> 62</span> | |
| <span class="normal"> 63</span> | |
| <span class="normal"> 64</span> | |
| <span class="normal"> 65</span> | |
| <span class="normal"> 66</span> | |
| <span class="normal"> 67</span> | |
| <span class="normal"> 68</span> | |
| <span class="normal"> 69</span> | |
| <span class="normal"> 70</span> | |
| <span class="normal"> 71</span> | |
| <span class="normal"> 72</span> | |
| <span class="normal"> 73</span> | |
| <span class="normal"> 74</span> | |
| <span class="normal"> 75</span> | |
| <span class="normal"> 76</span> | |
| <span class="normal"> 77</span> | |
| <span class="normal"> 78</span> | |
| <span class="normal"> 79</span> | |
| <span class="normal"> 80</span> | |
| <span class="normal"> 81</span> | |
| <span class="normal"> 82</span> | |
| <span class="normal"> 83</span> | |
| <span class="normal"> 84</span> | |
| <span class="normal"> 85</span> | |
| <span class="normal"> 86</span> | |
| <span class="normal"> 87</span> | |
| <span class="normal"> 88</span> | |
| <span class="normal"> 89</span> | |
| <span class="normal"> 90</span> | |
| <span class="normal"> 91</span> | |
| <span class="normal"> 92</span> | |
| <span class="normal"> 93</span> | |
| <span class="normal"> 94</span> | |
| <span class="normal"> 95</span> | |
| <span class="normal"> 96</span> | |
| <span class="normal"> 97</span> | |
| <span class="normal"> 98</span> | |
| <span class="normal"> 99</span> | |
| <span class="normal">100</span> | |
| <span class="normal">101</span> | |
| <span class="normal">102</span> | |
| <span class="normal">103</span> | |
| <span class="normal">104</span> | |
| <span class="normal">105</span> | |
| <span class="normal">106</span> | |
| <span class="normal">107</span> | |
| <span class="normal">108</span> | |
| <span class="normal">109</span> | |
| <span class="normal">110</span></pre></div></td><td class="code"><div><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">json</span> | |
| <span class="kn">import</span><span class="w"> </span><span class="nn">matplotlib.pyplot</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">plt</span> | |
| <span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span> | |
| <span class="kn">from</span><span class="w"> </span><span class="nn">pathlib</span><span class="w"> </span><span class="kn">import</span> <span class="n">Path</span> | |
| <span class="kn">import</span><span class="w"> </span><span class="nn">os</span> | |
| <span class="c1"># List of expected result files</span> | |
| <span class="n">yamoe_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'UVNOTE_INPUT_YAMOE_RUN'</span><span class="p">,</span> <span class="s1">'.'</span><span class="p">)</span> | |
| <span class="n">binned_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'UVNOTE_INPUT_BINNED_RUN'</span><span class="p">,</span> <span class="s1">'.'</span><span class="p">)</span> | |
| <span class="n">gptoss_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'UVNOTE_INPUT_GPTOSS_RUN'</span><span class="p">,</span> <span class="s1">'.'</span><span class="p">)</span> | |
| <span class="n">gptoss_training_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'UVNOTE_INPUT_GPTOSS_TRAINING_RUN'</span><span class="p">,</span> <span class="s1">'.'</span><span class="p">)</span> | |
| <span class="n">megablocks_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'UVNOTE_INPUT_MEGABLOCKS_RUN'</span><span class="p">,</span> <span class="s1">'.'</span><span class="p">)</span> | |
| <span class="n">result_files</span> <span class="o">=</span> <span class="p">[</span> | |
| <span class="n">Path</span><span class="p">(</span><span class="n">yamoe_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s2">"yamoe_results.json"</span><span class="p">,</span> | |
| <span class="n">Path</span><span class="p">(</span><span class="n">binned_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s2">"binned_results.json"</span><span class="p">,</span> | |
| <span class="n">Path</span><span class="p">(</span><span class="n">gptoss_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s2">"gptoss_results.json"</span><span class="p">,</span> | |
| <span class="n">Path</span><span class="p">(</span><span class="n">gptoss_training_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s2">"gptoss_training_results.json"</span><span class="p">,</span> | |
| <span class="n">Path</span><span class="p">(</span><span class="n">megablocks_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s2">"megablocks_results.json"</span> | |
| <span class="p">]</span> | |
| <span class="c1"># Load all benchmark results</span> | |
| <span class="n">results</span> <span class="o">=</span> <span class="p">{}</span> | |
| <span class="k">for</span> <span class="n">file</span> <span class="ow">in</span> <span class="n">result_files</span><span class="p">:</span> | |
| <span class="k">if</span> <span class="n">Path</span><span class="p">(</span><span class="n">file</span><span class="p">)</span><span class="o">.</span><span class="n">exists</span><span class="p">():</span> | |
| <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">file</span><span class="p">,</span> <span class="s1">'r'</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span> | |
| <span class="n">data</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">f</span><span class="p">)</span> | |
| <span class="n">results</span><span class="p">[</span><span class="n">data</span><span class="p">[</span><span class="s1">'implementation'</span><span class="p">]]</span> <span class="o">=</span> <span class="n">data</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Loaded </span><span class="si">{</span><span class="n">file</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="k">else</span><span class="p">:</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Missing </span><span class="si">{</span><span class="n">file</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="k">if</span> <span class="ow">not</span> <span class="n">results</span><span class="p">:</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="s2">"No benchmark results found. Run the benchmark cells first."</span><span class="p">)</span> | |
| <span class="k">else</span><span class="p">:</span> | |
| <span class="c1"># Extract data for plotting</span> | |
| <span class="n">implementations</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">results</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> | |
| <span class="n">avg_latencies</span> <span class="o">=</span> <span class="p">[</span><span class="n">results</span><span class="p">[</span><span class="n">impl</span><span class="p">][</span><span class="s1">'stats'</span><span class="p">][</span><span class="s1">'avg_ms'</span><span class="p">]</span> <span class="k">for</span> <span class="n">impl</span> <span class="ow">in</span> <span class="n">implementations</span><span class="p">]</span> | |
| <span class="n">p95_latencies</span> <span class="o">=</span> <span class="p">[</span><span class="n">results</span><span class="p">[</span><span class="n">impl</span><span class="p">][</span><span class="s1">'stats'</span><span class="p">][</span><span class="s1">'p95_ms'</span><span class="p">]</span> <span class="k">for</span> <span class="n">impl</span> <span class="ow">in</span> <span class="n">implementations</span><span class="p">]</span> | |
| <span class="n">throughputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">results</span><span class="p">[</span><span class="n">impl</span><span class="p">][</span><span class="s1">'stats'</span><span class="p">]</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'tokens_per_s'</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">impl</span> <span class="ow">in</span> <span class="n">implementations</span><span class="p">]</span> | |
| <span class="c1"># Create figure with subplots</span> | |
| <span class="n">fig</span><span class="p">,</span> <span class="p">(</span><span class="n">ax1</span><span class="p">,</span> <span class="n">ax2</span><span class="p">,</span> <span class="n">ax3</span><span class="p">)</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">18</span><span class="p">,</span> <span class="mi">6</span><span class="p">))</span> | |
| <span class="n">fig</span><span class="o">.</span><span class="n">suptitle</span><span class="p">(</span><span class="s1">'MoE Implementation Performance Comparison'</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">fontweight</span><span class="o">=</span><span class="s1">'bold'</span><span class="p">)</span> | |
| <span class="c1"># Colors for each implementation</span> | |
| <span class="n">colors</span> <span class="o">=</span> <span class="p">[</span><span class="s1">'#FF6B6B'</span><span class="p">,</span> <span class="s1">'#4ECDC4'</span><span class="p">,</span> <span class="s1">'#45B7D1'</span><span class="p">,</span> <span class="s1">'#96CEB4'</span><span class="p">,</span> <span class="s1">'#FECA57'</span><span class="p">][:</span><span class="nb">len</span><span class="p">(</span><span class="n">implementations</span><span class="p">)]</span> | |
| <span class="c1"># 1. Average Latency Chart</span> | |
| <span class="n">bars1</span> <span class="o">=</span> <span class="n">ax1</span><span class="o">.</span><span class="n">bar</span><span class="p">(</span><span class="n">implementations</span><span class="p">,</span> <span class="n">avg_latencies</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="n">colors</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.8</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s1">'black'</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> | |
| <span class="n">ax1</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s1">'Average Latency'</span><span class="p">,</span> <span class="n">fontweight</span><span class="o">=</span><span class="s1">'bold'</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span> | |
| <span class="n">ax1</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s1">'Latency (ms)'</span><span class="p">,</span> <span class="n">fontweight</span><span class="o">=</span><span class="s1">'bold'</span><span class="p">)</span> | |
| <span class="n">ax1</span><span class="o">.</span><span class="n">tick_params</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="s1">'x'</span><span class="p">,</span> <span class="n">rotation</span><span class="o">=</span><span class="mi">45</span><span class="p">)</span> | |
| <span class="n">ax1</span><span class="o">.</span><span class="n">grid</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="s1">'y'</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.3</span><span class="p">)</span> | |
| <span class="c1"># Add value labels on bars</span> | |
| <span class="k">for</span> <span class="n">bar</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">bars1</span><span class="p">,</span> <span class="n">avg_latencies</span><span class="p">):</span> | |
| <span class="n">ax1</span><span class="o">.</span><span class="n">text</span><span class="p">(</span><span class="n">bar</span><span class="o">.</span><span class="n">get_x</span><span class="p">()</span> <span class="o">+</span> <span class="n">bar</span><span class="o">.</span><span class="n">get_width</span><span class="p">()</span><span class="o">/</span><span class="mi">2</span><span class="p">,</span> <span class="n">bar</span><span class="o">.</span><span class="n">get_height</span><span class="p">()</span> <span class="o">+</span> <span class="nb">max</span><span class="p">(</span><span class="n">avg_latencies</span><span class="p">)</span><span class="o">*</span><span class="mf">0.01</span><span class="p">,</span> | |
| <span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">val</span><span class="si">:</span><span class="s1">.2f</span><span class="si">}</span><span class="s1">ms'</span><span class="p">,</span> <span class="n">ha</span><span class="o">=</span><span class="s1">'center'</span><span class="p">,</span> <span class="n">va</span><span class="o">=</span><span class="s1">'bottom'</span><span class="p">,</span> <span class="n">fontweight</span><span class="o">=</span><span class="s1">'bold'</span><span class="p">)</span> | |
| <span class="c1"># 2. P95 Latency Chart</span> | |
| <span class="n">bars2</span> <span class="o">=</span> <span class="n">ax2</span><span class="o">.</span><span class="n">bar</span><span class="p">(</span><span class="n">implementations</span><span class="p">,</span> <span class="n">p95_latencies</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="n">colors</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.8</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s1">'black'</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> | |
| <span class="n">ax2</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s1">'95th Percentile Latency'</span><span class="p">,</span> <span class="n">fontweight</span><span class="o">=</span><span class="s1">'bold'</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span> | |
| <span class="n">ax2</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s1">'Latency (ms)'</span><span class="p">,</span> <span class="n">fontweight</span><span class="o">=</span><span class="s1">'bold'</span><span class="p">)</span> | |
| <span class="n">ax2</span><span class="o">.</span><span class="n">tick_params</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="s1">'x'</span><span class="p">,</span> <span class="n">rotation</span><span class="o">=</span><span class="mi">45</span><span class="p">)</span> | |
| <span class="n">ax2</span><span class="o">.</span><span class="n">grid</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="s1">'y'</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.3</span><span class="p">)</span> | |
| <span class="c1"># Add value labels on bars</span> | |
| <span class="k">for</span> <span class="n">bar</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">bars2</span><span class="p">,</span> <span class="n">p95_latencies</span><span class="p">):</span> | |
| <span class="n">ax2</span><span class="o">.</span><span class="n">text</span><span class="p">(</span><span class="n">bar</span><span class="o">.</span><span class="n">get_x</span><span class="p">()</span> <span class="o">+</span> <span class="n">bar</span><span class="o">.</span><span class="n">get_width</span><span class="p">()</span><span class="o">/</span><span class="mi">2</span><span class="p">,</span> <span class="n">bar</span><span class="o">.</span><span class="n">get_height</span><span class="p">()</span> <span class="o">+</span> <span class="nb">max</span><span class="p">(</span><span class="n">p95_latencies</span><span class="p">)</span><span class="o">*</span><span class="mf">0.01</span><span class="p">,</span> | |
| <span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">val</span><span class="si">:</span><span class="s1">.2f</span><span class="si">}</span><span class="s1">ms'</span><span class="p">,</span> <span class="n">ha</span><span class="o">=</span><span class="s1">'center'</span><span class="p">,</span> <span class="n">va</span><span class="o">=</span><span class="s1">'bottom'</span><span class="p">,</span> <span class="n">fontweight</span><span class="o">=</span><span class="s1">'bold'</span><span class="p">)</span> | |
| <span class="c1"># 3. Throughput Chart</span> | |
| <span class="n">bars3</span> <span class="o">=</span> <span class="n">ax3</span><span class="o">.</span><span class="n">bar</span><span class="p">(</span><span class="n">implementations</span><span class="p">,</span> <span class="n">throughputs</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="n">colors</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.8</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s1">'black'</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> | |
| <span class="n">ax3</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s1">'Throughput'</span><span class="p">,</span> <span class="n">fontweight</span><span class="o">=</span><span class="s1">'bold'</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span> | |
| <span class="n">ax3</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s1">'Tokens/sec'</span><span class="p">,</span> <span class="n">fontweight</span><span class="o">=</span><span class="s1">'bold'</span><span class="p">)</span> | |
| <span class="n">ax3</span><span class="o">.</span><span class="n">tick_params</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="s1">'x'</span><span class="p">,</span> <span class="n">rotation</span><span class="o">=</span><span class="mi">45</span><span class="p">)</span> | |
| <span class="n">ax3</span><span class="o">.</span><span class="n">grid</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="s1">'y'</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.3</span><span class="p">)</span> | |
| <span class="c1"># Add value labels on bars</span> | |
| <span class="k">for</span> <span class="n">bar</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">bars3</span><span class="p">,</span> <span class="n">throughputs</span><span class="p">):</span> | |
| <span class="k">if</span> <span class="n">val</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span> <span class="c1"># Only show label if throughput was calculated</span> | |
| <span class="n">ax3</span><span class="o">.</span><span class="n">text</span><span class="p">(</span><span class="n">bar</span><span class="o">.</span><span class="n">get_x</span><span class="p">()</span> <span class="o">+</span> <span class="n">bar</span><span class="o">.</span><span class="n">get_width</span><span class="p">()</span><span class="o">/</span><span class="mi">2</span><span class="p">,</span> <span class="n">bar</span><span class="o">.</span><span class="n">get_height</span><span class="p">()</span> <span class="o">+</span> <span class="nb">max</span><span class="p">(</span><span class="n">throughputs</span><span class="p">)</span><span class="o">*</span><span class="mf">0.01</span><span class="p">,</span> | |
| <span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">val</span><span class="si">:</span><span class="s1">.0f</span><span class="si">}</span><span class="s1">'</span><span class="p">,</span> <span class="n">ha</span><span class="o">=</span><span class="s1">'center'</span><span class="p">,</span> <span class="n">va</span><span class="o">=</span><span class="s1">'bottom'</span><span class="p">,</span> <span class="n">fontweight</span><span class="o">=</span><span class="s1">'bold'</span><span class="p">)</span> | |
| <span class="n">plt</span><span class="o">.</span><span class="n">tight_layout</span><span class="p">()</span> | |
| <span class="n">plt</span><span class="o">.</span><span class="n">savefig</span><span class="p">(</span><span class="s2">"moe_performance_comparison.png"</span><span class="p">,</span> <span class="n">dpi</span><span class="o">=</span><span class="mi">300</span><span class="p">)</span> | |
| <span class="c1"># Print summary table</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="s2">"</span><span class="se">\n</span><span class="s2">Performance Summary:"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="s1">'Implementation'</span><span class="si">:</span><span class="s2"><30</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="s1">'Avg (ms)'</span><span class="si">:</span><span class="s2"><12</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="s1">'P95 (ms)'</span><span class="si">:</span><span class="s2"><12</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="s1">'Tokens/sec'</span><span class="si">:</span><span class="s2"><12</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="s1">'Relative Speed'</span><span class="si">:</span><span class="s2"><15</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="s2">"-"</span><span class="o">*</span><span class="mi">80</span><span class="p">)</span> | |
| <span class="c1"># Sort by average latency for relative speed calculation</span> | |
| <span class="n">sorted_results</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">results</span><span class="o">.</span><span class="n">items</span><span class="p">(),</span> <span class="n">key</span><span class="o">=</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="p">[</span><span class="mi">1</span><span class="p">][</span><span class="s1">'stats'</span><span class="p">][</span><span class="s1">'avg_ms'</span><span class="p">])</span> | |
| <span class="n">fastest_latency</span> <span class="o">=</span> <span class="n">sorted_results</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">][</span><span class="s1">'stats'</span><span class="p">][</span><span class="s1">'avg_ms'</span><span class="p">]</span> | |
| <span class="k">for</span> <span class="n">impl</span><span class="p">,</span> <span class="n">data</span> <span class="ow">in</span> <span class="n">sorted_results</span><span class="p">:</span> | |
| <span class="n">avg_ms</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="s1">'stats'</span><span class="p">][</span><span class="s1">'avg_ms'</span><span class="p">]</span> | |
| <span class="n">p95_ms</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="s1">'stats'</span><span class="p">][</span><span class="s1">'p95_ms'</span><span class="p">]</span> | |
| <span class="n">tokens_s</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="s1">'stats'</span><span class="p">]</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'tokens_per_s'</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> | |
| <span class="n">relative_speed</span> <span class="o">=</span> <span class="n">fastest_latency</span> <span class="o">/</span> <span class="n">avg_ms</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">impl</span><span class="si">:</span><span class="s2"><30</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">avg_ms</span><span class="si">:</span><span class="s2">>8.2f</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">p95_ms</span><span class="si">:</span><span class="s2">>8.2f</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">tokens_s</span><span class="si">:</span><span class="s2">>8.0f</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">relative_speed</span><span class="si">:</span><span class="s2">>6.2f</span><span class="si">}</span><span class="s2">x"</span><span class="p">)</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="se">\n</span><span class="s2">Fastest: </span><span class="si">{</span><span class="n">sorted_results</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2"> (</span><span class="si">{</span><span class="n">sorted_results</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">][</span><span class="s1">'stats'</span><span class="p">][</span><span class="s1">'avg_ms'</span><span class="p">]</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2">ms avg)"</span><span class="p">)</span> | |
| <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">sorted_results</span><span class="p">)</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Slowest: </span><span class="si">{</span><span class="n">sorted_results</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2"> (</span><span class="si">{</span><span class="n">sorted_results</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">][</span><span class="mi">1</span><span class="p">][</span><span class="s1">'stats'</span><span class="p">][</span><span class="s1">'avg_ms'</span><span class="p">]</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2">ms avg)"</span><span class="p">)</span> | |
| <span class="n">speedup</span> <span class="o">=</span> <span class="n">sorted_results</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">][</span><span class="mi">1</span><span class="p">][</span><span class="s1">'stats'</span><span class="p">][</span><span class="s1">'avg_ms'</span><span class="p">]</span> <span class="o">/</span> <span class="n">sorted_results</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">][</span><span class="s1">'stats'</span><span class="p">][</span><span class="s1">'avg_ms'</span><span class="p">]</span> | |
| <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Max Speedup: </span><span class="si">{</span><span class="n">speedup</span><span class="si">:</span><span class="s2">.1f</span><span class="si">}</span><span class="s2">x"</span><span class="p">)</span> | |
| </pre></div></td></tr></table></div> | |
| </div> | |
| <div id="output-visualization" class="cell-output"> | |
| <div class="cell-stdout">Loaded /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/c5c8a351e1080ea89737c25df783e5c81cd76df0f2b017cedfd813e3bdf2f9f9/yamoe_results.json | |
| Loaded /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/af01d090b967f1cb05cacea7795553418933b27fc2f188da52f7c4642e456c24/binned_results.json | |
| Loaded /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/cf359ebbdbfd10241ce11898ee298eefd5da768c42d502b034caf3ba5b16aed6/gptoss_results.json | |
| Loaded /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/28eb2a85c2dc94e627a0c6373b55120bd67c549ef80cd5b5e94ae756ecd11aff/gptoss_training_results.json | |
| Loaded /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/a712c225c474c8776a91d23a96a2d4dd5dde0716ed16f6eb0dce9d92b65e06b8/megablocks_results.json | |
| Performance Summary: | |
| Implementation Avg (ms) P95 (ms) Tokens/sec Relative Speed | |
| -------------------------------------------------------------------------------- | |
| megablocks_results 4.13 8.58 24195 1.00x | |
| yamoe_results 8.63 8.65 11587 0.48x | |
| gptoss_results 47.14 47.80 2122 0.09x | |
| gptoss_training_results 48.63 49.35 2056 0.08x | |
| binned_results 105.62 107.73 947 0.04x | |
| Fastest: megablocks_results (4.13ms avg) | |
| Slowest: binned_results (105.62ms avg) | |
| Max Speedup: 25.6x | |
| </div> | |
| <div class="cell-stderr">Downloading numpy (15.9MiB) | |
| Downloading fonttools (4.7MiB) | |
| Downloading pillow (6.3MiB) | |
| Downloading matplotlib (8.3MiB) | |
| Downloading kiwisolver (1.4MiB) | |
| Downloading kiwisolver | |
| Downloading pillow | |
| Downloading fonttools | |
| Downloading matplotlib | |
| Downloading numpy | |
| Installed 11 packages in 24ms | |
| </div> | |
| <div class="cell-artifacts"> | |
| <h4>Artifacts:</h4> | |
| <a href="artifacts/visualization/moe_performance_comparison.png" class="artifact" target="_blank" rel="noopener noreferrer">moe_performance_comparison.png</a> | |
| <div class="artifact-preview"> | |
| <img src="artifacts/visualization/moe_performance_comparison.png" alt="Performance comparison chart of MoE implementations (megablocks, yamoe, gpt-oss, binned)"> | |
| </div> | |
| </div> | |
| </div> | |
| </div> | |
| </div> | |
| </body> | |
| </html> |