<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Compare Yamoe and Binned MoE Implementations</title>
<script>
// Apply theme immediately to prevent flicker
(function() {
const configTheme = 'dark';
let theme;
if (configTheme === 'auto') {
theme = window.matchMedia('(prefers-color-scheme: dark)').matches ? 'dark' : 'light';
} else {
theme = localStorage.getItem('uvnote-theme') || configTheme;
}
document.documentElement.setAttribute('data-theme', theme);
})();
</script>
<style>
:root[data-theme="light"] {
--bg-primary: #ffffff;
--bg-secondary: #f6f8fa;
--bg-tertiary: #f8f9fa;
--bg-code: #f8f9fa;
--bg-error: #fdf2f2;
--bg-artifact: #e6f3ff;
--bg-artifact-hover: #d0e7ff;
--text-primary: #333;
--text-secondary: #656d76;
--text-error: #c53030;
--text-link: #0969da;
--border-primary: #e1e5e9;
--border-error: #e53e3e;
--border-cell-failed: #d73a49;
--shadow: rgba(0, 0, 0, 0.1);
}
:root[data-theme="dark"] {
--bg-primary: #0a0a0a;
--bg-secondary: #121212;
--bg-tertiary: #181818;
--bg-code: #0d0d0d;
--bg-error: #1a0f0f;
--bg-artifact: #151515;
--bg-artifact-hover: #1a1a1a;
--text-primary: #e0e0e0;
--text-secondary: #888888;
--text-error: #ff6b6b;
--text-link: #64b5f6;
--border-primary: #2a2a2a;
--border-error: #ff6b6b;
--border-cell-failed: #ff6b6b;
--shadow: rgba(255, 255, 255, 0.05);
}
body {
font-family: 'Cascadia Mono', 'Cascadia Code', 'JetBrains Mono', 'SF Mono', Monaco, 'Consolas', monospace;
line-height: 1.4;
max-width: 1000px;
margin: 0 auto;
padding: 15px;
color: var(--text-primary);
background: var(--bg-primary);
transition: background-color 0.2s ease, color 0.2s ease;
}
/* Two panel layout removed */
.controls {
position: fixed;
top: 20px;
right: 20px;
display: flex;
gap: 0.5rem;
z-index: 1000;
}
.theme-toggle, .reset-toggle {
background: var(--bg-secondary);
border: 1px solid var(--border-primary);
border-radius: 2px;
padding: 0.4rem 0.6rem;
cursor: pointer;
font-family: inherit;
font-size: 0.8rem;
color: var(--text-secondary);
user-select: none;
transition: all 0.2s ease;
text-transform: lowercase;
letter-spacing: 0;
}
.theme-toggle:hover, .reset-toggle:hover {
background: var(--bg-tertiary);
border-color: var(--text-secondary);
color: var(--text-primary);
}
.minimap {
position: fixed;
bottom: 20px;
right: 20px;
width: 220px;
max-height: 400px;
background: var(--bg-secondary);
border: 1px solid var(--border-primary);
border-radius: 2px;
padding: 0.5rem;
font-size: 0.7rem;
overflow-y: auto;
z-index: 100;
opacity: 0.9;
transition: opacity 0.2s ease;
}
.file-explorer {
position: fixed;
bottom: 20px; /* default; JS will stack */
right: 20px;
left: auto;
top: auto;
width: 220px;
max-height: 400px;
background: var(--bg-secondary);
border: 1px solid var(--border-primary);
border-radius: 2px;
padding: 0.5rem;
font-size: 0.7rem;
overflow-y: auto;
z-index: 100;
opacity: 0.9;
transition: opacity 0.2s ease;
}
/* Drawing overlay */
.draw-overlay {
position: fixed;
top: 0;
left: 0;
width: 100vw;
height: 100vh;
z-index: 80; /* under widgets (100) and controls (1000) */
display: block;
pointer-events: none; /* enabled only when a tool is active */
}
/* Tools widget */
.tools-widget {
position: fixed;
bottom: 20px; /* default; JS will stack */
right: 20px;
left: auto;
top: auto;
width: 220px;
background: var(--bg-secondary);
border: 1px solid var(--border-primary);
border-radius: 2px;
padding: 0.5rem;
font-size: 0.7rem;
z-index: 100;
opacity: 0.95;
}
.tools-title {
font-weight: bold;
color: var(--text-secondary);
margin-bottom: 0.5rem;
padding-bottom: 0.25rem;
border-bottom: 1px solid var(--border-primary);
cursor: grab;
user-select: none;
}
.tools-row { display: flex; gap: 0.4rem; flex-wrap: wrap; }
.tool-button {
background: var(--bg-tertiary);
border: 1px solid var(--border-primary);
border-radius: 2px;
padding: 0.25rem 0.4rem;
cursor: pointer;
color: var(--text-secondary);
font-family: inherit;
font-size: 0.75rem;
user-select: none;
}
.tool-button:hover { color: var(--text-primary); }
.tool-button.active { color: var(--text-primary); border-color: var(--text-secondary); background: var(--bg-secondary); }
.minimap:hover, .file-explorer:hover {
opacity: 1;
}
.minimap-title {
font-weight: bold;
color: var(--text-secondary);
margin-bottom: 0.5rem;
padding-bottom: 0.25rem;
border-bottom: 1px solid var(--border-primary);
cursor: grab; /* drag handle */
user-select: none;
}
.minimap-item {
display: block;
color: var(--text-secondary);
text-decoration: none;
padding: 0.15rem 0;
border-left: 2px solid transparent;
padding-left: 0.5rem;
transition: all 0.2s ease;
cursor: pointer;
}
.minimap-item:hover {
color: var(--text-primary);
border-left-color: var(--text-secondary);
}
.minimap-item.active {
color: var(--text-primary);
border-left-color: var(--text-link);
}
.minimap-heading {
font-weight: normal;
}
.minimap-heading.h1 { padding-left: 0.5rem; }
.minimap-heading.h2 { padding-left: 1rem; }
.minimap-heading.h3 { padding-left: 1.5rem; }
.minimap-heading.h4 { padding-left: 2rem; }
.minimap-heading.h5 { padding-left: 2.5rem; }
.minimap-heading.h6 { padding-left: 3rem; }
.minimap-cell {
color: var(--text-link);
opacity: 0.8;
font-style: italic;
}
.minimap-cell:hover {
opacity: 1;
}
.file-explorer-title {
font-weight: bold;
color: var(--text-secondary);
margin-bottom: 0.5rem;
padding-bottom: 0.25rem;
border-bottom: 1px solid var(--border-primary);
cursor: grab; /* drag handle */
user-select: none;
}
.file-explorer-section {
margin-bottom: 0.75rem;
}
.file-explorer-section-title {
font-weight: bold;
color: var(--text-secondary);
font-size: 0.65rem;
margin-bottom: 0.25rem;
text-transform: uppercase;
letter-spacing: 0.5px;
}
.file-explorer-item {
display: block;
color: var(--text-secondary);
text-decoration: none;
padding: 0.1rem 0;
margin-left: 0.5rem;
transition: color 0.2s ease;
cursor: pointer;
font-family: monospace;
}
.file-explorer-item:hover {
color: var(--text-primary);
}
.file-explorer-item.script {
color: var(--text-link);
}
.file-explorer-item.artifact {
color: var(--text-secondary);
opacity: 0.8;
}
/* Slide functionality */
.minimap, .file-explorer, .tools-widget {
transition: transform 0.3s ease;
}
.minimap.slide-off, .file-explorer.slide-off, .tools-widget.slide-off {
transform: translateX(calc(100% - 20px));
}
.minimap-title::before, .file-explorer-title::before, .tools-title::before {
content: '‹';
float: left;
cursor: pointer;
color: var(--text-secondary);
user-select: none;
margin-right: 8px;
}
.minimap-title::after, .file-explorer-title::after, .tools-title::after {
content: '›';
float: right;
cursor: pointer;
color: var(--text-secondary);
user-select: none;
}
.minimap.slide-off .minimap-title::after,
.file-explorer.slide-off .file-explorer-title::after,
.tools-widget.slide-off .tools-title::after {
content: '‹';
}
/* Hide widgets on smaller screens */
@media (max-width: 768px) {
.minimap, .file-explorer, .tools-widget {
display: none;
}
}
.cell {
margin: 1rem 0;
border: 1px solid var(--border-primary);
border-radius: 2px;
overflow: hidden;
background: var(--bg-secondary);
}
.cell-header {
background: var(--bg-secondary);
padding: 0.5rem 1rem;
border-bottom: 1px solid var(--border-primary);
font-family: inherit;
font-size: 0.85rem;
color: var(--text-secondary);
cursor: pointer;
user-select: none;
transition: background-color 0.2s ease;
}
.cell-header:hover {
background: var(--bg-tertiary);
}
.collapse-indicators {
color: var(--text-secondary);
font-size: 0.8rem;
opacity: 0.7;
}
.collapse-indicators span:hover {
color: var(--text-primary);
opacity: 1;
}
.cell-code {
display: block;
background: var(--bg-code);
}
.cell-code.collapsed {
display: none;
}
.cell-code pre {
margin: 0;
padding: 0.75rem;
background: var(--bg-code);
overflow-x: auto;
color: var(--text-primary);
}
.cell-output {
padding: 0.75rem;
background: var(--bg-primary);
}
.cell-output.collapsed {
display: none;
}
.cell-stdout {
background: var(--bg-tertiary);
padding: 0.75rem;
border-radius: 1px;
margin: 0.25rem 0;
font-family: inherit;
font-size: 0.9rem;
white-space: pre-wrap;
color: var(--text-primary);
}
.cell-stderr {
background: var(--bg-error);
border-left: 2px solid var(--border-error);
padding: 1rem;
margin: 0.5rem 0;
font-family: inherit;
font-size: 0.9rem;
color: var(--text-error);
white-space: pre-wrap;
}
.cell-artifacts {
margin: 1rem 0;
}
.cell-artifacts h4 {
margin: 0 0 0.5rem 0;
color: var(--text-secondary);
font-size: 0.9rem;
}
.artifact {
display: inline-block;
background: var(--bg-artifact);
padding: 0.25rem 0.5rem;
border-radius: 1px;
margin: 0.25rem 0.5rem 0.25rem 0;
font-family: inherit;
font-size: 0.8rem;
color: var(--text-link);
text-decoration: none;
transition: background-color 0.2s ease;
border: 1px solid var(--border-primary);
}
.artifact:hover {
background: var(--bg-artifact-hover);
}
.artifact-preview {
margin-top: 1rem;
}
.artifact-preview img {
max-width: 100%;
height: auto;
border: 1px solid var(--border-primary);
border-radius: 1px;
}
.artifact-preview svg {
max-width: 100%;
height: auto;
border: 1px solid var(--border-primary);
border-radius: 1px;
display: block;
}
/* Style SVG text elements */
.artifact-preview svg g {
fill: var(--text-primary) !important;
}
/* Auto-theme SVG elements */
.artifact-preview svg {
background: transparent;
}
.cell-failed {
border-color: var(--border-cell-failed);
}
.cell-failed .cell-header {
background: var(--bg-error);
color: var(--text-error);
}
.run-btn {
background: var(--bg-tertiary);
border: 1px solid var(--border-primary);
padding: 2px 6px;
border-radius: 2px;
color: var(--text-secondary);
cursor: pointer;
font-size: 0.75em;
font-family: inherit;
margin-left: 4px;
}
.run-btn:hover {
color: var(--text-primary);
background: var(--bg-primary);
}
.run-btn:disabled {
opacity: 0.6;
cursor: not-allowed;
}
.copy-btn {
background: var(--bg-tertiary);
border: 1px solid var(--border-primary);
padding: 2px 6px;
border-radius: 2px;
color: var(--text-secondary);
cursor: pointer;
font-size: 0.75em;
font-family: inherit;
margin-left: 4px;
}
.copy-btn:hover {
color: var(--text-primary);
background: var(--bg-primary);
}
.copy-btn:disabled {
opacity: 0.6;
cursor: not-allowed;
}
.output-stale {
opacity: 0.5;
position: relative;
}
.output-stale::after {
content: '⏳ updating...';
position: absolute;
top: 8px;
right: 8px;
background: var(--bg-secondary);
padding: 4px 8px;
border-radius: 2px;
font-size: 0.75em;
color: var(--text-secondary);
border: 1px solid var(--border-primary);
}
h1, h2, h3, h4, h5, h6 {
margin-top: 1.5rem;
margin-bottom: 0.75rem;
color: var(--text-primary);
}
h1 {
margin-top: 0;
margin-bottom: 1rem;
}
p {
margin: 0.75rem 0;
color: var(--text-primary);
}
a {
color: var(--text-link);
}
img {
max-width: 100%;
height: auto;
border-radius: 1px;
box-shadow: none;
}
pre, code {
font-family: 'Cascadia Mono', 'Cascadia Code', 'JetBrains Mono', 'SF Mono', Monaco, 'Consolas', monospace;
}
/* Line numbers */
.highlight-with-lines {
display: flex;
}
.line-numbers {
background: var(--bg-tertiary);
padding: 0.75rem 0.5rem;
font-family: 'Cascadia Mono', 'Cascadia Code', 'JetBrains Mono', 'SF Mono', Monaco, 'Consolas', monospace;
font-size: 0.9rem;
color: var(--text-secondary);
user-select: none;
text-align: right;
border-right: 1px solid var(--border-primary);
}
.line-numbers .line-number {
display: block;
line-height: 1.5;
}
.highlight-with-lines .highlight {
flex: 1;
}
.highlight-with-lines .highlight pre {
padding-left: 0.75rem;
}
/* Collapsed code styling */
.cell-code.expanded {
display: block;
}
pre { line-height: 125%; }
td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }
span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }
td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }
span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }
[data-theme="light"] .highlight .hll { background-color: #ffffcc }
[data-theme="light"] .highlight { background: #f8f8f8; }
[data-theme="light"] .highlight .c { color: #3D7B7B; font-style: italic } /* Comment */
[data-theme="light"] .highlight .err { border: 1px solid #F00 } /* Error */
[data-theme="light"] .highlight .k { color: #008000; font-weight: bold } /* Keyword */
[data-theme="light"] .highlight .o { color: #666 } /* Operator */
[data-theme="light"] .highlight .ch { color: #3D7B7B; font-style: italic } /* Comment.Hashbang */
[data-theme="light"] .highlight .cm { color: #3D7B7B; font-style: italic } /* Comment.Multiline */
[data-theme="light"] .highlight .cp { color: #9C6500 } /* Comment.Preproc */
[data-theme="light"] .highlight .cpf { color: #3D7B7B; font-style: italic } /* Comment.PreprocFile */
[data-theme="light"] .highlight .c1 { color: #3D7B7B; font-style: italic } /* Comment.Single */
[data-theme="light"] .highlight .cs { color: #3D7B7B; font-style: italic } /* Comment.Special */
[data-theme="light"] .highlight .gd { color: #A00000 } /* Generic.Deleted */
[data-theme="light"] .highlight .ge { font-style: italic } /* Generic.Emph */
[data-theme="light"] .highlight .ges { font-weight: bold; font-style: italic } /* Generic.EmphStrong */
[data-theme="light"] .highlight .gr { color: #E40000 } /* Generic.Error */
[data-theme="light"] .highlight .gh { color: #000080; font-weight: bold } /* Generic.Heading */
[data-theme="light"] .highlight .gi { color: #008400 } /* Generic.Inserted */
[data-theme="light"] .highlight .go { color: #717171 } /* Generic.Output */
[data-theme="light"] .highlight .gp { color: #000080; font-weight: bold } /* Generic.Prompt */
[data-theme="light"] .highlight .gs { font-weight: bold } /* Generic.Strong */
[data-theme="light"] .highlight .gu { color: #800080; font-weight: bold } /* Generic.Subheading */
[data-theme="light"] .highlight .gt { color: #04D } /* Generic.Traceback */
[data-theme="light"] .highlight .kc { color: #008000; font-weight: bold } /* Keyword.Constant */
[data-theme="light"] .highlight .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */
[data-theme="light"] .highlight .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */
[data-theme="light"] .highlight .kp { color: #008000 } /* Keyword.Pseudo */
[data-theme="light"] .highlight .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */
[data-theme="light"] .highlight .kt { color: #B00040 } /* Keyword.Type */
[data-theme="light"] .highlight .m { color: #666 } /* Literal.Number */
[data-theme="light"] .highlight .s { color: #BA2121 } /* Literal.String */
[data-theme="light"] .highlight .na { color: #687822 } /* Name.Attribute */
[data-theme="light"] .highlight .nb { color: #008000 } /* Name.Builtin */
[data-theme="light"] .highlight .nc { color: #00F; font-weight: bold } /* Name.Class */
[data-theme="light"] .highlight .no { color: #800 } /* Name.Constant */
[data-theme="light"] .highlight .nd { color: #A2F } /* Name.Decorator */
[data-theme="light"] .highlight .ni { color: #717171; font-weight: bold } /* Name.Entity */
[data-theme="light"] .highlight .ne { color: #CB3F38; font-weight: bold } /* Name.Exception */
[data-theme="light"] .highlight .nf { color: #00F } /* Name.Function */
[data-theme="light"] .highlight .nl { color: #767600 } /* Name.Label */
[data-theme="light"] .highlight .nn { color: #00F; font-weight: bold } /* Name.Namespace */
[data-theme="light"] .highlight .nt { color: #008000; font-weight: bold } /* Name.Tag */
[data-theme="light"] .highlight .nv { color: #19177C } /* Name.Variable */
[data-theme="light"] .highlight .ow { color: #A2F; font-weight: bold } /* Operator.Word */
[data-theme="light"] .highlight .w { color: #BBB } /* Text.Whitespace */
[data-theme="light"] .highlight .mb { color: #666 } /* Literal.Number.Bin */
[data-theme="light"] .highlight .mf { color: #666 } /* Literal.Number.Float */
[data-theme="light"] .highlight .mh { color: #666 } /* Literal.Number.Hex */
[data-theme="light"] .highlight .mi { color: #666 } /* Literal.Number.Integer */
[data-theme="light"] .highlight .mo { color: #666 } /* Literal.Number.Oct */
[data-theme="light"] .highlight .sa { color: #BA2121 } /* Literal.String.Affix */
[data-theme="light"] .highlight .sb { color: #BA2121 } /* Literal.String.Backtick */
[data-theme="light"] .highlight .sc { color: #BA2121 } /* Literal.String.Char */
[data-theme="light"] .highlight .dl { color: #BA2121 } /* Literal.String.Delimiter */
[data-theme="light"] .highlight .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */
[data-theme="light"] .highlight .s2 { color: #BA2121 } /* Literal.String.Double */
[data-theme="light"] .highlight .se { color: #AA5D1F; font-weight: bold } /* Literal.String.Escape */
[data-theme="light"] .highlight .sh { color: #BA2121 } /* Literal.String.Heredoc */
[data-theme="light"] .highlight .si { color: #A45A77; font-weight: bold } /* Literal.String.Interpol */
[data-theme="light"] .highlight .sx { color: #008000 } /* Literal.String.Other */
[data-theme="light"] .highlight .sr { color: #A45A77 } /* Literal.String.Regex */
[data-theme="light"] .highlight .s1 { color: #BA2121 } /* Literal.String.Single */
[data-theme="light"] .highlight .ss { color: #19177C } /* Literal.String.Symbol */
[data-theme="light"] .highlight .bp { color: #008000 } /* Name.Builtin.Pseudo */
[data-theme="light"] .highlight .fm { color: #00F } /* Name.Function.Magic */
[data-theme="light"] .highlight .vc { color: #19177C } /* Name.Variable.Class */
[data-theme="light"] .highlight .vg { color: #19177C } /* Name.Variable.Global */
[data-theme="light"] .highlight .vi { color: #19177C } /* Name.Variable.Instance */
[data-theme="light"] .highlight .vm { color: #19177C } /* Name.Variable.Magic */
[data-theme="light"] .highlight .il { color: #666 } /* Literal.Number.Integer.Long */
[data-theme="dark"] .highlight .hll { background-color: #49483e }
[data-theme="dark"] .highlight { background: #272822; color: #F8F8F2 }
[data-theme="dark"] .highlight .c { color: #959077 } /* Comment */
[data-theme="dark"] .highlight .err { color: #ED007E; background-color: #1E0010 } /* Error */
[data-theme="dark"] .highlight .esc { color: #F8F8F2 } /* Escape */
[data-theme="dark"] .highlight .g { color: #F8F8F2 } /* Generic */
[data-theme="dark"] .highlight .k { color: #66D9EF } /* Keyword */
[data-theme="dark"] .highlight .l { color: #AE81FF } /* Literal */
[data-theme="dark"] .highlight .n { color: #F8F8F2 } /* Name */
[data-theme="dark"] .highlight .o { color: #FF4689 } /* Operator */
[data-theme="dark"] .highlight .x { color: #F8F8F2 } /* Other */
[data-theme="dark"] .highlight .p { color: #F8F8F2 } /* Punctuation */
[data-theme="dark"] .highlight .ch { color: #959077 } /* Comment.Hashbang */
[data-theme="dark"] .highlight .cm { color: #959077 } /* Comment.Multiline */
[data-theme="dark"] .highlight .cp { color: #959077 } /* Comment.Preproc */
[data-theme="dark"] .highlight .cpf { color: #959077 } /* Comment.PreprocFile */
[data-theme="dark"] .highlight .c1 { color: #959077 } /* Comment.Single */
[data-theme="dark"] .highlight .cs { color: #959077 } /* Comment.Special */
[data-theme="dark"] .highlight .gd { color: #FF4689 } /* Generic.Deleted */
[data-theme="dark"] .highlight .ge { color: #F8F8F2; font-style: italic } /* Generic.Emph */
[data-theme="dark"] .highlight .ges { color: #F8F8F2; font-weight: bold; font-style: italic } /* Generic.EmphStrong */
[data-theme="dark"] .highlight .gr { color: #F8F8F2 } /* Generic.Error */
[data-theme="dark"] .highlight .gh { color: #F8F8F2 } /* Generic.Heading */
[data-theme="dark"] .highlight .gi { color: #A6E22E } /* Generic.Inserted */
[data-theme="dark"] .highlight .go { color: #66D9EF } /* Generic.Output */
[data-theme="dark"] .highlight .gp { color: #FF4689; font-weight: bold } /* Generic.Prompt */
[data-theme="dark"] .highlight .gs { color: #F8F8F2; font-weight: bold } /* Generic.Strong */
[data-theme="dark"] .highlight .gu { color: #959077 } /* Generic.Subheading */
[data-theme="dark"] .highlight .gt { color: #F8F8F2 } /* Generic.Traceback */
[data-theme="dark"] .highlight .kc { color: #66D9EF } /* Keyword.Constant */
[data-theme="dark"] .highlight .kd { color: #66D9EF } /* Keyword.Declaration */
[data-theme="dark"] .highlight .kn { color: #FF4689 } /* Keyword.Namespace */
[data-theme="dark"] .highlight .kp { color: #66D9EF } /* Keyword.Pseudo */
[data-theme="dark"] .highlight .kr { color: #66D9EF } /* Keyword.Reserved */
[data-theme="dark"] .highlight .kt { color: #66D9EF } /* Keyword.Type */
[data-theme="dark"] .highlight .ld { color: #E6DB74 } /* Literal.Date */
[data-theme="dark"] .highlight .m { color: #AE81FF } /* Literal.Number */
[data-theme="dark"] .highlight .s { color: #E6DB74 } /* Literal.String */
[data-theme="dark"] .highlight .na { color: #A6E22E } /* Name.Attribute */
[data-theme="dark"] .highlight .nb { color: #F8F8F2 } /* Name.Builtin */
[data-theme="dark"] .highlight .nc { color: #A6E22E } /* Name.Class */
[data-theme="dark"] .highlight .no { color: #66D9EF } /* Name.Constant */
[data-theme="dark"] .highlight .nd { color: #A6E22E } /* Name.Decorator */
[data-theme="dark"] .highlight .ni { color: #F8F8F2 } /* Name.Entity */
[data-theme="dark"] .highlight .ne { color: #A6E22E } /* Name.Exception */
[data-theme="dark"] .highlight .nf { color: #A6E22E } /* Name.Function */
[data-theme="dark"] .highlight .nl { color: #F8F8F2 } /* Name.Label */
[data-theme="dark"] .highlight .nn { color: #F8F8F2 } /* Name.Namespace */
[data-theme="dark"] .highlight .nx { color: #A6E22E } /* Name.Other */
[data-theme="dark"] .highlight .py { color: #F8F8F2 } /* Name.Property */
[data-theme="dark"] .highlight .nt { color: #FF4689 } /* Name.Tag */
[data-theme="dark"] .highlight .nv { color: #F8F8F2 } /* Name.Variable */
[data-theme="dark"] .highlight .ow { color: #FF4689 } /* Operator.Word */
[data-theme="dark"] .highlight .pm { color: #F8F8F2 } /* Punctuation.Marker */
[data-theme="dark"] .highlight .w { color: #F8F8F2 } /* Text.Whitespace */
[data-theme="dark"] .highlight .mb { color: #AE81FF } /* Literal.Number.Bin */
[data-theme="dark"] .highlight .mf { color: #AE81FF } /* Literal.Number.Float */
[data-theme="dark"] .highlight .mh { color: #AE81FF } /* Literal.Number.Hex */
[data-theme="dark"] .highlight .mi { color: #AE81FF } /* Literal.Number.Integer */
[data-theme="dark"] .highlight .mo { color: #AE81FF } /* Literal.Number.Oct */
[data-theme="dark"] .highlight .sa { color: #E6DB74 } /* Literal.String.Affix */
[data-theme="dark"] .highlight .sb { color: #E6DB74 } /* Literal.String.Backtick */
[data-theme="dark"] .highlight .sc { color: #E6DB74 } /* Literal.String.Char */
[data-theme="dark"] .highlight .dl { color: #E6DB74 } /* Literal.String.Delimiter */
[data-theme="dark"] .highlight .sd { color: #E6DB74 } /* Literal.String.Doc */
[data-theme="dark"] .highlight .s2 { color: #E6DB74 } /* Literal.String.Double */
[data-theme="dark"] .highlight .se { color: #AE81FF } /* Literal.String.Escape */
[data-theme="dark"] .highlight .sh { color: #E6DB74 } /* Literal.String.Heredoc */
[data-theme="dark"] .highlight .si { color: #E6DB74 } /* Literal.String.Interpol */
[data-theme="dark"] .highlight .sx { color: #E6DB74 } /* Literal.String.Other */
[data-theme="dark"] .highlight .sr { color: #E6DB74 } /* Literal.String.Regex */
[data-theme="dark"] .highlight .s1 { color: #E6DB74 } /* Literal.String.Single */
[data-theme="dark"] .highlight .ss { color: #E6DB74 } /* Literal.String.Symbol */
[data-theme="dark"] .highlight .bp { color: #F8F8F2 } /* Name.Builtin.Pseudo */
[data-theme="dark"] .highlight .fm { color: #A6E22E } /* Name.Function.Magic */
[data-theme="dark"] .highlight .vc { color: #F8F8F2 } /* Name.Variable.Class */
[data-theme="dark"] .highlight .vg { color: #F8F8F2 } /* Name.Variable.Global */
[data-theme="dark"] .highlight .vi { color: #F8F8F2 } /* Name.Variable.Instance */
[data-theme="dark"] .highlight .vm { color: #F8F8F2 } /* Name.Variable.Magic */
[data-theme="dark"] .highlight .il { color: #AE81FF } /* Literal.Number.Integer.Long */
/* Custom CSS from frontmatter */
.cell-stderr { display: none; }
.minimap { display: none !important; }
.file-explorer { display: none !important; }
.cell-code { max-height: 400px; overflow: auto; }
/* Cursor for tools */
body[data-tool="arrow"] .main-content { cursor: crosshair; }
body[data-tool="pen"] .main-content { cursor: pointer; }
body[data-tool="eraser"] .main-content { cursor: cell; }
/* Color picker styles */
.tools-section-title {
font-weight: bold;
color: var(--text-secondary);
font-size: 0.65rem;
margin: 0.75rem 0 0.5rem 0;
text-transform: uppercase;
letter-spacing: 0.5px;
}
.color-row {
display: grid;
grid-template-columns: repeat(6, 1fr);
gap: 0.25rem;
margin-bottom: 0.5rem;
}
.color-swatch {
width: 18px;
height: 18px;
border: 2px solid var(--border-primary);
border-radius: 3px;
cursor: pointer;
transition: all 0.2s ease;
position: relative;
}
.color-swatch:hover {
transform: scale(1.1);
border-color: var(--text-secondary);
}
.color-swatch.selected {
border-color: var(--text-primary);
box-shadow: 0 0 0 2px var(--text-link);
}
.color-swatch.selected::after {
content: '✓';
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
color: white;
font-size: 10px;
font-weight: bold;
text-shadow: 1px 1px 1px black;
}
.color-input {
width: 24px;
height: 24px;
border: 2px solid var(--border-primary);
border-radius: 3px;
cursor: pointer;
background: none;
padding: 0;
grid-column: span 2;
justify-self: center;
}
.color-input:hover {
border-color: var(--text-secondary);
}
/* Thickness slider styles */
.thickness-row {
display: flex;
align-items: center;
gap: 0.5rem;
margin-top: 0.75rem;
}
.thickness-slider {
flex: 1;
-webkit-appearance: none;
appearance: none;
height: 4px;
background: var(--border-primary);
border-radius: 2px;
outline: none;
opacity: 0.7;
transition: opacity 0.2s;
}
.thickness-slider:hover {
opacity: 1;
}
.thickness-slider::-webkit-slider-thumb {
-webkit-appearance: none;
appearance: none;
width: 12px;
height: 12px;
background: var(--text-link);
border-radius: 50%;
cursor: pointer;
}
.thickness-slider::-moz-range-thumb {
width: 12px;
height: 12px;
background: var(--text-link);
border-radius: 50%;
cursor: pointer;
border: none;
}
.thickness-value {
font-size: 0.7rem;
color: var(--text-secondary);
min-width: 20px;
text-align: right;
}
.highlight {
background: none !important;
}
/* Loading animations */
.loading-spinner {
display: inline-block;
width: 16px;
height: 16px;
border: 2px solid var(--border-primary);
border-radius: 50%;
border-top-color: var(--text-link);
animation: spin 1s linear infinite;
margin-right: 8px;
vertical-align: middle;
}
@keyframes spin {
to { transform: rotate(360deg); }
}
.loading-skeleton {
display: inline-block;
background: var(--bg-tertiary);
background: linear-gradient(
90deg,
var(--bg-tertiary) 25%,
var(--bg-secondary) 50%,
var(--bg-tertiary) 75%
);
background-size: 200% 100%;
animation: loading-shimmer 2s ease-in-out infinite;
border-radius: 2px;
height: 1em;
width: 80px;
vertical-align: middle;
}
@keyframes loading-shimmer {
0% { background-position: -200% 0; }
100% { background-position: 200% 0; }
}
/* Loading state for cell output */
.cell-output:has(.loading-spinner) {
opacity: 0.7;
background: var(--bg-secondary);
border-left: 3px solid var(--text-link);
}
</style>
<script>
// --- Drag utilities ---
function clamp(val, min, max) { return Math.max(min, Math.min(max, val)); }
function restorePosition(el, storageKey) {
try {
const raw = localStorage.getItem(storageKey);
if (!raw) return;
const pos = JSON.parse(raw);
if (typeof pos.left === 'number' && typeof pos.top === 'number') {
el.style.left = pos.left + 'px';
el.style.top = pos.top + 'px';
el.style.right = 'auto';
el.style.bottom = 'auto';
}
} catch (_) {}
}
function savePosition(el, storageKey) {
try {
const left = parseFloat(el.style.left || 'NaN');
const top = parseFloat(el.style.top || 'NaN');
if (!Number.isNaN(left) && !Number.isNaN(top)) {
localStorage.setItem(storageKey, JSON.stringify({ left, top }));
}
} catch (_) {}
}
function addSlideToggle(widget, titleEl) {
titleEl.onclick = function(e) {
const rect = titleEl.getBoundingClientRect();
const clickX = e.clientX - rect.left;
// Left arrow (always slides back on screen)
if (clickX < 30) {
widget.classList.remove('slide-off');
e.stopPropagation();
}
// Right arrow (always slides off screen)
else if (clickX > rect.width - 30) {
widget.classList.add('slide-off');
e.stopPropagation();
}
};
}
function makeDraggable(el, storageKey, handleEl) {
let dragging = false;
let startX = 0, startY = 0; // cursor
let origLeft = 0, origTop = 0; // element
const onMove = (e) => {
if (!dragging) return;
const clientX = e.touches ? e.touches[0].clientX : e.clientX;
const clientY = e.touches ? e.touches[0].clientY : e.clientY;
const dx = clientX - startX;
const dy = clientY - startY;
const w = el.offsetWidth;
const h = el.offsetHeight;
const maxX = window.innerWidth - w;
const maxY = window.innerHeight - h;
const newLeft = clamp(origLeft + dx, 0, maxX);
const newTop = clamp(origTop + dy, 0, maxY);
el.style.left = newLeft + 'px';
el.style.top = newTop + 'px';
el.style.right = 'auto';
el.style.bottom = 'auto';
};
const endDrag = () => {
if (!dragging) return;
dragging = false;
document.removeEventListener('mousemove', onMove);
document.removeEventListener('mouseup', endDrag);
document.removeEventListener('touchmove', onMove);
document.removeEventListener('touchend', endDrag);
handleEl && (handleEl.style.cursor = 'grab');
savePosition(el, storageKey);
// ensure no-overlap constraint after a drag
try { layoutWidgetsStackedBottomRight(); } catch (_) {}
};
const startDrag = (e) => {
// Check if click is on arrow areas - if so, don't start drag
if (handleEl) {
const rect = handleEl.getBoundingClientRect();
const clickX = e.clientX - rect.left;
if (clickX < 30 || clickX > rect.width - 30) {
return; // Don't start drag on arrow areas
}
}
// Start from element's current on-screen rect
const elRect = el.getBoundingClientRect();
el.style.left = elRect.left + 'px';
el.style.top = elRect.top + 'px';
el.style.right = 'auto';
el.style.bottom = 'auto';
dragging = true;
startX = e.touches ? e.touches[0].clientX : e.clientX;
startY = e.touches ? e.touches[0].clientY : e.clientY;
origLeft = elRect.left;
origTop = elRect.top;
document.addEventListener('mousemove', onMove);
document.addEventListener('mouseup', endDrag);
document.addEventListener('touchmove', onMove, { passive: false });
document.addEventListener('touchend', endDrag);
handleEl && (handleEl.style.cursor = 'grabbing');
e.preventDefault();
};
(handleEl || el).addEventListener('mousedown', startDrag);
(handleEl || el).addEventListener('touchstart', startDrag, { passive: false });
// Apply any saved position on init
restorePosition(el, storageKey);
}
function toggleCell(cellId) {
const codeElement = document.getElementById('code-' + cellId);
const outputElement = document.getElementById('output-' + cellId);
if (codeElement) {
codeElement.classList.toggle('collapsed');
}
if (outputElement) {
outputElement.classList.toggle('collapsed');
}
updateIndicators(cellId);
}
function toggleCode(cellId) {
const codeElement = document.getElementById('code-' + cellId);
if (codeElement) {
codeElement.classList.toggle('collapsed');
updateIndicators(cellId);
}
}
function toggleOutput(cellId) {
const outputElement = document.getElementById('output-' + cellId);
if (outputElement) {
outputElement.classList.toggle('collapsed');
updateIndicators(cellId);
}
}
function updateIndicators(cellId) {
const codeElement = document.getElementById('code-' + cellId);
const outputElement = document.getElementById('output-' + cellId);
const indicators = document.querySelector(`[onclick*="${cellId}"]`)?.closest('.cell-header')?.querySelector('.collapse-indicators');
if (indicators) {
const codeCollapsed = codeElement && codeElement.classList.contains('collapsed');
const outputCollapsed = outputElement && outputElement.classList.contains('collapsed');
const codeIcon = codeCollapsed ? '▶' : '▼';
const outputIcon = outputCollapsed ? '▶' : '▼';
const codeSpan = indicators.querySelector('[onclick*="toggleCode"]');
const outputSpan = indicators.querySelector('[onclick*="toggleOutput"]');
if (codeSpan) codeSpan.innerHTML = `${codeIcon} code`;
if (outputSpan) outputSpan.innerHTML = `${outputIcon} output`;
}
}
function toggleTheme() {
const html = document.documentElement;
const currentTheme = html.getAttribute('data-theme');
const newTheme = currentTheme === 'dark' ? 'light' : 'dark';
html.setAttribute('data-theme', newTheme);
localStorage.setItem('uvnote-theme', newTheme);
updateThemeIcon();
}
// Two panel code removed
function updateThemeIcon() {
const theme = document.documentElement.getAttribute('data-theme');
const toggle = document.querySelector('.theme-toggle');
if (toggle) {
toggle.textContent = theme === 'dark' ? 'light' : 'dark';
}
}
function resetLayout() {
try {
// Clear all uvnote-* keys
const allKeys = Object.keys(localStorage);
const uvnoteKeys = allKeys.filter(key => key.startsWith('uvnote-'));
uvnoteKeys.forEach(k => localStorage.removeItem(k));
} catch (_) {}
// Reload to reinitialize UI with defaults
location.reload();
}
// Layout: stack widgets bottom-right and equalize widths
function hasCustomWidgetPositions() {
try {
return (
localStorage.getItem('uvnote-minimap-pos') ||
localStorage.getItem('uvnote-file-explorer-pos') ||
localStorage.getItem('uvnote-tools-pos')
);
} catch (_) { return false; }
}
function rectsOverlap(r1, r2) {
return !(r1.right <= r2.left || r2.right <= r1.left || r1.bottom <= r2.top || r2.bottom <= r1.top);
}
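// Illustrative sketch (demo-only; `rectsOverlapDemo` is a hypothetical local name
// not used by the page): axis-aligned rects overlap iff neither lies entirely to
// one side of the other. Note that rects merely touching at an edge do NOT count.
function rectsOverlapDemo(r1, r2) {
return !(r1.right <= r2.left || r2.right <= r1.left || r1.bottom <= r2.top || r2.bottom <= r1.top);
}
console.assert(rectsOverlapDemo({left: 0, top: 0, right: 10, bottom: 10}, {left: 5, top: 5, right: 15, bottom: 15}) === true);
console.assert(rectsOverlapDemo({left: 0, top: 0, right: 10, bottom: 10}, {left: 10, top: 0, right: 20, bottom: 10}) === false);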
function widgetsOverlap(widgets) {
for (let i = 0; i < widgets.length; i++) {
const a = widgets[i];
const ra = a.getBoundingClientRect();
for (let j = i + 1; j < widgets.length; j++) {
const b = widgets[j];
const rb = b.getBoundingClientRect();
if (rectsOverlap(ra, rb)) return true;
}
}
return false;
}
function applyStackLayout(widgets, order) {
if (!widgets.length) return;
// Fixed equal width
const fixedWidth = 220;
widgets.forEach(el => { el.style.width = fixedWidth + 'px'; });
// Fit heights if needed to avoid overflow
const gap = 12;
const available = Math.max(0, window.innerHeight - 40 - gap * (order.length - 1));
const eachMax = Math.floor(available / order.length);
order.forEach(el => {
el.style.maxHeight = eachMax + 'px';
el.style.overflowY = 'auto';
});
// Stack bottom-up in the requested order
let bottomOffset = 20; // base gutter
order.forEach(el => {
el.style.left = 'auto';
el.style.top = 'auto';
el.style.right = '20px';
el.style.bottom = bottomOffset + 'px';
bottomOffset += el.offsetHeight + gap;
});
}
function layoutWidgetsStackedBottomRight() {
const minimap = document.querySelector('.minimap');
const fileExplorer = document.querySelector('.file-explorer');
const tools = document.querySelector('.tools-widget');
const widgets = [minimap, fileExplorer, tools].filter(el => el && getComputedStyle(el).display !== 'none');
if (!widgets.length) return;
const order = [minimap, fileExplorer, tools].filter(Boolean).filter(el => getComputedStyle(el).display !== 'none');
// If user placed custom positions and there is no overlap, respect them.
if (hasCustomWidgetPositions() && !widgetsOverlap(widgets)) return;
applyStackLayout(widgets, order);
}
// Panel icon removed
let _minimapScrollContainer = null;
let _minimapScrollHandler = null;
function initMinimap() {
// Generate minimap content
const minimap = createMinimap();
document.body.appendChild(minimap);
// Make draggable and slideable (use title as handle)
const mTitle = minimap.querySelector('.minimap-title');
makeDraggable(minimap, 'uvnote-minimap-pos', mTitle);
addSlideToggle(minimap, mTitle);
// Attach scroll listener to window (two-panel removed)
_minimapScrollContainer = window;
if (_minimapScrollContainer) {
_minimapScrollHandler = () => updateMinimapActive();
if (_minimapScrollContainer === window) {
window.addEventListener('scroll', _minimapScrollHandler);
} else {
_minimapScrollContainer.addEventListener('scroll', _minimapScrollHandler);
}
}
updateMinimapActive();
}
function teardownMinimap() {
const minimap = document.querySelector('.minimap');
if (minimap && minimap.parentNode) minimap.parentNode.removeChild(minimap);
if (_minimapScrollContainer && _minimapScrollHandler) {
if (_minimapScrollContainer === window) {
window.removeEventListener('scroll', _minimapScrollHandler);
} else {
_minimapScrollContainer.removeEventListener('scroll', _minimapScrollHandler);
}
}
_minimapScrollContainer = null;
_minimapScrollHandler = null;
}
function initFileExplorer() {
// Generate file explorer content
const fileExplorer = createFileExplorer();
document.body.appendChild(fileExplorer);
const title = fileExplorer.querySelector('.file-explorer-title');
addSlideToggle(fileExplorer, title);
}
function createMinimap() {
const minimap = document.createElement('div');
minimap.className = 'minimap';
const title = document.createElement('div');
title.className = 'minimap-title';
title.textContent = 'navigation';
minimap.appendChild(title);
// Find all headings and cells
const root = document.querySelector('.main-content') || document;
const headings = root.querySelectorAll('h1, h2, h3, h4, h5, h6');
const cells = root.querySelectorAll('.cell');
// Combine and sort by position
const items = [];
headings.forEach(heading => {
const id = heading.id || generateId(heading.textContent);
if (!heading.id) heading.id = id;
items.push({
element: heading,
type: 'heading',
level: parseInt(heading.tagName.charAt(1)),
text: heading.textContent.trim(),
id: id,
position: heading.getBoundingClientRect().top + window.scrollY
});
});
cells.forEach(cell => {
const header = cell.querySelector('.cell-header');
if (header) {
const id = cell.id || `cell-${Date.now()}-${Math.random().toString(36).slice(2, 11)}`;
if (!cell.id) cell.id = id;
items.push({
element: cell,
type: 'cell',
text: header.textContent.trim(),
id: id,
position: cell.getBoundingClientRect().top + window.scrollY
});
}
});
// Sort by position
items.sort((a, b) => a.position - b.position);
// Create minimap items
items.forEach(item => {
const link = document.createElement('a');
link.className = `minimap-item ${item.type === 'heading' ? 'minimap-heading' : 'minimap-cell'}`;
if (item.type === 'heading') {
link.classList.add(`h${item.level}`);
}
link.textContent = item.text.length > 25 ? item.text.substring(0, 22) + '...' : item.text;
link.href = `#${item.id}`;
link.onclick = function(e) {
e.preventDefault();
item.element.scrollIntoView({ behavior: 'smooth', block: 'start' });
};
minimap.appendChild(link);
});
return minimap;
}
function generateId(text) {
return text.toLowerCase()
.replace(/[^a-z0-9]+/g, '-')
.replace(/^-+|-+$/g, '')
.substring(0, 20);
}
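// Illustrative sketch (demo-only; `slugDemo` is a hypothetical local name that
// mirrors generateId above): heading text slugifies to a lowercase, dash-separated
// anchor id, with edge dashes trimmed and the result capped at 20 characters.
function slugDemo(text) {
return text.toLowerCase()
.replace(/[^a-z0-9]+/g, '-')
.replace(/^-+|-+$/g, '')
.substring(0, 20);
}
console.assert(slugDemo('Compare: Yamoe vs Binned MoE') === 'compare-yamoe-vs-bin');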
function updateMinimapActive() {
const minimapItems = document.querySelectorAll('.minimap-item');
const container = _minimapScrollContainer || window;
const containerRect = container === window ? null : container.getBoundingClientRect();
const scrollPos = (container === window ? window.scrollY : container.scrollTop) + 100; // Offset for better detection
let activeItem = null;
minimapItems.forEach(item => {
const targetId = item.getAttribute('href').substring(1);
const target = document.getElementById(targetId);
if (target) {
const rectTop = target.getBoundingClientRect().top;
const targetPos = (container === window)
? rectTop + window.scrollY
: rectTop - containerRect.top + container.scrollTop;
if (targetPos <= scrollPos) {
activeItem = item;
}
}
item.classList.remove('active');
});
if (activeItem) {
activeItem.classList.add('active');
}
}
function createFileExplorer() {
const fileExplorer = document.createElement('div');
fileExplorer.className = 'file-explorer';
const title = document.createElement('div');
title.className = 'file-explorer-title';
title.textContent = 'files';
fileExplorer.appendChild(title);
// Make draggable (use title as handle)
makeDraggable(fileExplorer, 'uvnote-file-explorer-pos', title);
// Scripts section
const scriptsSection = document.createElement('div');
scriptsSection.className = 'file-explorer-section';
const scriptsTitle = document.createElement('div');
scriptsTitle.className = 'file-explorer-section-title';
scriptsTitle.textContent = 'scripts';
scriptsSection.appendChild(scriptsTitle);
// Find all cells and list their script files (single panel)
const root = document.querySelector('.main-content') || document;
const cells = root.querySelectorAll('.cell');
cells.forEach(cell => {
const header = cell.querySelector('.cell-header');
if (header) {
const cellText = header.textContent.trim();
const cellMatch = cellText.match(/Cell: ([a-zA-Z_][a-zA-Z0-9_]*)/);
if (cellMatch) {
const cellId = cellMatch[1];
const scriptItem = document.createElement('div');
scriptItem.className = 'file-explorer-item script';
scriptItem.textContent = `${cellId}.py`;
scriptItem.onclick = function() {
cell.scrollIntoView({ behavior: 'smooth', block: 'start' });
};
scriptsSection.appendChild(scriptItem);
}
}
});
fileExplorer.appendChild(scriptsSection);
// Artifacts section
const artifactsSection = document.createElement('div');
artifactsSection.className = 'file-explorer-section';
const artifactsTitle = document.createElement('div');
artifactsTitle.className = 'file-explorer-section-title';
artifactsTitle.textContent = 'artifacts';
artifactsSection.appendChild(artifactsTitle);
// Find all artifact links (single panel)
const artifactsRoot = document.querySelector('.main-content') || document;
const artifacts = artifactsRoot.querySelectorAll('.artifact');
if (artifacts.length === 0) {
const noArtifacts = document.createElement('div');
noArtifacts.className = 'file-explorer-item artifact';
noArtifacts.textContent = '(none)';
noArtifacts.style.opacity = '0.5';
artifactsSection.appendChild(noArtifacts);
} else {
artifacts.forEach(artifact => {
const artifactItem = document.createElement('div');
artifactItem.className = 'file-explorer-item artifact';
artifactItem.textContent = artifact.textContent;
artifactItem.onclick = function() {
artifact.click();
};
artifactsSection.appendChild(artifactItem);
});
}
fileExplorer.appendChild(artifactsSection);
return fileExplorer;
}
// Tools widget
function setActiveTool(tool) {
if (!tool || tool === 'none') {
document.body.dataset.tool = 'none';
localStorage.setItem('uvnote-active-tool', 'none');
setOverlayActive(false);
return;
}
document.body.dataset.tool = tool;
localStorage.setItem('uvnote-active-tool', tool);
setOverlayActive(true);
}
function getArrowColor() {
const saved = localStorage.getItem('uvnote-arrow-color');
if (saved) return saved;
return '#e53935'; // Default red color
}
function setStoredArrowColor(color) {
try { localStorage.setItem('uvnote-arrow-color', color); } catch (_) {}
}
function getLineThickness() {
const saved = localStorage.getItem('uvnote-line-thickness');
if (saved) return parseInt(saved, 10);
return 4; // default thickness
}
function setStoredLineThickness(thickness) {
try { localStorage.setItem('uvnote-line-thickness', thickness); } catch (_) {}
}
function createToolsWidget() {
const tools = document.createElement('div');
tools.className = 'tools-widget';
const title = document.createElement('div');
title.className = 'tools-title';
title.textContent = 'tools';
tools.appendChild(title);
const row = document.createElement('div');
row.className = 'tools-row';
tools.appendChild(row);
// Arrow tool
const arrowBtn = document.createElement('div');
arrowBtn.className = 'tool-button';
arrowBtn.textContent = 'arrow';
arrowBtn.onclick = function() {
const isActive = arrowBtn.classList.contains('active');
if (isActive) {
arrowBtn.classList.remove('active');
setActiveTool('none');
} else {
tools.querySelectorAll('.tool-button').forEach(b => b.classList.remove('active'));
arrowBtn.classList.add('active');
setActiveTool('arrow');
}
};
row.appendChild(arrowBtn);
// Pen tool
const penBtn = document.createElement('div');
penBtn.className = 'tool-button';
penBtn.textContent = 'pen';
penBtn.onclick = function() {
const isActive = penBtn.classList.contains('active');
if (isActive) {
penBtn.classList.remove('active');
setActiveTool('none');
} else {
tools.querySelectorAll('.tool-button').forEach(b => b.classList.remove('active'));
penBtn.classList.add('active');
setActiveTool('pen');
}
};
row.appendChild(penBtn);
// Eraser tool
const eraseBtn = document.createElement('div');
eraseBtn.className = 'tool-button';
eraseBtn.textContent = 'eraser';
eraseBtn.onclick = function() {
const isActive = eraseBtn.classList.contains('active');
if (isActive) {
eraseBtn.classList.remove('active');
setActiveTool('none');
} else {
tools.querySelectorAll('.tool-button').forEach(b => b.classList.remove('active'));
eraseBtn.classList.add('active');
setActiveTool('eraser');
}
};
row.appendChild(eraseBtn);
// Clear all
const clearBtn = document.createElement('div');
clearBtn.className = 'tool-button';
clearBtn.textContent = 'clear';
clearBtn.onclick = function() {
_shapes = [];
saveShapes();
renderOverlay();
};
row.appendChild(clearBtn);
// Restore active state from storage
const saved = localStorage.getItem('uvnote-active-tool') || 'none';
if (saved === 'arrow') {
arrowBtn.classList.add('active');
setActiveTool('arrow');
} else if (saved === 'pen') {
penBtn.classList.add('active');
setActiveTool('pen');
} else if (saved === 'eraser') {
eraseBtn.classList.add('active');
setActiveTool('eraser');
}
// Color selector
const colorTitle = document.createElement('div');
colorTitle.className = 'tools-section-title';
colorTitle.textContent = 'color';
tools.appendChild(colorTitle);
const colorRow = document.createElement('div');
colorRow.className = 'tools-row color-row';
tools.appendChild(colorRow);
const swatchColors = [
// Primary colors
'#e53935', '#fb8c00', '#fdd835', '#43a047', '#1e88e5', '#8e24aa',
// Additional useful colors
'#ff5722', '#795548', '#607d8b', '#9c27b0',
// Grayscale
'#000000', '#424242', '#9e9e9e', '#ffffff'
];
const swatches = [];
swatchColors.forEach(c => {
const s = document.createElement('div');
s.className = 'color-swatch';
s.style.backgroundColor = c;
s.title = c;
s.onclick = () => {
setStoredArrowColor(c);
refreshColorUI(c);
};
colorRow.appendChild(s);
swatches.push(s);
});
const colorInput = document.createElement('input');
colorInput.type = 'color';
colorInput.className = 'color-input';
colorInput.oninput = () => {
setStoredArrowColor(colorInput.value);
refreshColorUI(colorInput.value);
};
colorRow.appendChild(colorInput);
function refreshColorUI(selected) {
const selectedHex = selected.startsWith('#') ? selected.toLowerCase() : rgbToHex(selected);
swatches.forEach((s, i) => {
const swatchHex = swatchColors[i].toLowerCase();
if (swatchHex === selectedHex) {
s.classList.add('selected');
} else {
s.classList.remove('selected');
}
});
try {
colorInput.value = selectedHex;
} catch (_) {}
}
function rgbToHex(rgb) {
const m = rgb.match(/rgba?\((\d+),\s*(\d+),\s*(\d+)/i);
if (!m) return '#000000';
const r = parseInt(m[1]).toString(16).padStart(2, '0');
const g = parseInt(m[2]).toString(16).padStart(2, '0');
const b = parseInt(m[3]).toString(16).padStart(2, '0');
return `#${r}${g}${b}`;
}
// Restore color selection
refreshColorUI(getArrowColor());
// Thickness slider
const thicknessTitle = document.createElement('div');
thicknessTitle.className = 'tools-section-title';
thicknessTitle.textContent = 'thickness';
tools.appendChild(thicknessTitle);
const thicknessRow = document.createElement('div');
thicknessRow.className = 'thickness-row';
tools.appendChild(thicknessRow);
const thicknessSlider = document.createElement('input');
thicknessSlider.type = 'range';
thicknessSlider.className = 'thickness-slider';
thicknessSlider.min = '1';
thicknessSlider.max = '10';
thicknessSlider.value = getLineThickness();
const thicknessValue = document.createElement('span');
thicknessValue.className = 'thickness-value';
thicknessValue.textContent = thicknessSlider.value + 'px';
thicknessSlider.oninput = function() {
const value = parseInt(thicknessSlider.value, 10);
setStoredLineThickness(value);
thicknessValue.textContent = value + 'px';
};
thicknessRow.appendChild(thicknessSlider);
thicknessRow.appendChild(thicknessValue);
// Draggable behavior
makeDraggable(tools, 'uvnote-tools-pos', title);
return tools;
}
function initTools() {
const widget = createToolsWidget();
document.body.appendChild(widget);
const title = widget.querySelector('.tools-title');
addSlideToggle(widget, title);
}
function teardownTools() {
const w = document.querySelector('.tools-widget');
if (w && w.parentNode) w.parentNode.removeChild(w);
}
// --- Canvas overlay for tools ---
let _overlay = null;
let _overlayCtx = null;
let _overlayContainer = null; // window
let _overlayMode = 'single';
let _overlayResizeHandler = null;
let _overlayScrollHandler = null;
let _drawing = null; // current in-progress arrow {x1,y1,x2,y2}
let _shapes = []; // committed shapes for current mode
let _fadeTimer = null; // timer for fade animation
function getOverlayStorageKey() { return 'uvnote-shapes'; }
function loadShapes() {
try {
const raw = localStorage.getItem(getOverlayStorageKey());
_shapes = raw ? JSON.parse(raw) : [];
} catch (_) { _shapes = []; }
}
function saveShapes() {
try { localStorage.setItem(getOverlayStorageKey(), JSON.stringify(_shapes)); } catch (_) {}
}
function updateShapesFade() {
const now = Date.now();
const fadeStartTime = 3000; // Start fading after 3 seconds
const fadeEndTime = 5000; // Fully gone after 5 seconds
let needsUpdate = false;
for (let i = _shapes.length - 1; i >= 0; i--) {
const shape = _shapes[i];
if (!shape.createdAt) continue; // Skip old shapes without timestamps
const age = now - shape.createdAt;
if (age >= fadeEndTime) {
// Remove completely faded shapes
_shapes.splice(i, 1);
needsUpdate = true;
} else if (age >= fadeStartTime) {
// Update opacity for fading shapes
const fadeProgress = (age - fadeStartTime) / (fadeEndTime - fadeStartTime);
const newOpacity = 1 - fadeProgress;
if (Math.abs(shape.opacity - newOpacity) > 0.01) {
shape.opacity = newOpacity;
needsUpdate = true;
}
}
}
if (needsUpdate) {
saveShapes();
renderOverlay();
}
}
function getContentContainer() { return window; }
function updateOverlayModeAndContainer() {
_overlayContainer = window;
_overlayMode = 'single';
}
function updateOverlayBounds() {
if (!_overlay) return;
if (_overlayContainer === window) {
_overlay.style.position = 'fixed';
_overlay.style.left = '0px';
_overlay.style.top = '0px';
_overlay.width = window.innerWidth;
_overlay.height = window.innerHeight;
} else {
const rect = _overlayContainer.getBoundingClientRect();
_overlay.style.position = 'fixed';
_overlay.style.left = rect.left + 'px';
_overlay.style.top = rect.top + 'px';
_overlay.width = Math.max(0, Math.floor(rect.width));
_overlay.height = Math.max(0, Math.floor(rect.height));
}
renderOverlay();
}
function containerScrollLeft() {
return (_overlayContainer === window) ? (window.scrollX || 0) : (_overlayContainer.scrollLeft || 0);
}
function containerScrollTop() {
return (_overlayContainer === window) ? (window.scrollY || 0) : (_overlayContainer.scrollTop || 0);
}
function toCanvasCoords(clientX, clientY) {
const rect = _overlay.getBoundingClientRect();
return { x: clientX - rect.left, y: clientY - rect.top };
}
function onPointerDown(e) {
const tool = document.body.dataset.tool;
if (tool === 'arrow') {
startDrawArrow(e);
} else if (tool === 'pen') {
startDrawPen(e);
} else if (tool === 'eraser') {
eraseAt(e);
}
}
function onPointerMove(e) {
if (!_drawing) return;
if (_drawing.type === 'pen') {
moveDrawPen(e);
} else {
moveDrawArrow(e);
}
}
function onPointerUp(e) {
if (!_drawing) return;
if (_drawing.type === 'pen') {
endDrawPen();
} else {
endDrawArrow();
}
}
function startDrawArrow(e) {
if (document.body.dataset.tool !== 'arrow') return;
const pt = toCanvasCoords(e.touches ? e.touches[0].clientX : e.clientX, e.touches ? e.touches[0].clientY : e.clientY);
_drawing = {
x1: pt.x + containerScrollLeft(),
y1: pt.y + containerScrollTop(),
x2: pt.x + containerScrollLeft(),
y2: pt.y + containerScrollTop(),
color: getArrowColor(),
width: getLineThickness()
};
renderOverlay();
e.preventDefault();
}
function moveDrawArrow(e) {
if (!_drawing) return;
const pt = toCanvasCoords(e.touches ? e.touches[0].clientX : e.clientX, e.touches ? e.touches[0].clientY : e.clientY);
_drawing.x2 = pt.x + containerScrollLeft();
_drawing.y2 = pt.y + containerScrollTop();
renderOverlay();
e.preventDefault();
}
function endDrawArrow() {
if (!_drawing) return;
_shapes.push({
type: 'arrow',
..._drawing,
createdAt: Date.now(),
opacity: 1.0
});
_drawing = null;
saveShapes();
renderOverlay();
}
function startDrawPen(e) {
if (document.body.dataset.tool !== 'pen') return;
const pt = toCanvasCoords(e.touches ? e.touches[0].clientX : e.clientX, e.touches ? e.touches[0].clientY : e.clientY);
_drawing = {
type: 'pen',
points: [{
x: pt.x + containerScrollLeft(),
y: pt.y + containerScrollTop()
}],
color: getArrowColor(),
width: getLineThickness()
};
renderOverlay();
e.preventDefault();
}
function moveDrawPen(e) {
if (!_drawing || _drawing.type !== 'pen') return;
const pt = toCanvasCoords(e.touches ? e.touches[0].clientX : e.clientX, e.touches ? e.touches[0].clientY : e.clientY);
_drawing.points.push({
x: pt.x + containerScrollLeft(),
y: pt.y + containerScrollTop()
});
renderOverlay();
e.preventDefault();
}
function endDrawPen() {
if (!_drawing || _drawing.type !== 'pen') return;
if (_drawing.points.length > 1) {
_shapes.push({
..._drawing,
createdAt: Date.now(),
opacity: 1.0
});
}
_drawing = null;
saveShapes();
renderOverlay();
}
function distPointToSegment(px, py, x1, y1, x2, y2) {
const dx = x2 - x1, dy = y2 - y1;
if (dx === 0 && dy === 0) return Math.hypot(px - x1, py - y1);
const t = Math.max(0, Math.min(1, ((px - x1) * dx + (py - y1) * dy) / (dx*dx + dy*dy)));
const cx = x1 + t * dx, cy = y1 + t * dy;
return Math.hypot(px - cx, py - cy);
}
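// Illustrative sketch (demo-only; `segDistDemo` is a hypothetical local name that
// mirrors distPointToSegment above): project the point onto the segment, clamp the
// projection parameter t to [0, 1] so it never leaves the segment, then measure
// the distance to that closest point. Degenerate zero-length segments fall back
// to plain point distance.
function segDistDemo(px, py, x1, y1, x2, y2) {
const dx = x2 - x1, dy = y2 - y1;
if (dx === 0 && dy === 0) return Math.hypot(px - x1, py - y1);
const t = Math.max(0, Math.min(1, ((px - x1) * dx + (py - y1) * dy) / (dx * dx + dy * dy)));
return Math.hypot(px - (x1 + t * dx), py - (y1 + t * dy));
}
// (5,5) projects onto the middle of (0,0)-(10,0); (15,0) clamps to endpoint (10,0).
console.assert(segDistDemo(5, 5, 0, 0, 10, 0) === 5);
console.assert(segDistDemo(15, 0, 0, 0, 10, 0) === 5);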
function eraseAt(e) {
const pt = toCanvasCoords(e.touches ? e.touches[0].clientX : e.clientX, e.touches ? e.touches[0].clientY : e.clientY);
const x = pt.x + containerScrollLeft();
const y = pt.y + containerScrollTop();
const threshold = 10; // pixels
for (let i = _shapes.length - 1; i >= 0; i--) {
const s = _shapes[i];
if (s.type === 'arrow') {
const d = distPointToSegment(x, y, s.x1, s.y1, s.x2, s.y2);
if (d <= threshold) {
_shapes.splice(i, 1);
saveShapes();
renderOverlay();
break;
}
} else if (s.type === 'pen' && s.points) {
// Check if click is near any line segment in the pen stroke
let minDist = Infinity;
for (let j = 1; j < s.points.length; j++) {
const d = distPointToSegment(x, y, s.points[j-1].x, s.points[j-1].y, s.points[j].x, s.points[j].y);
minDist = Math.min(minDist, d);
}
if (minDist <= threshold) {
_shapes.splice(i, 1);
saveShapes();
renderOverlay();
break;
}
}
}
e.preventDefault();
}
function drawArrow(ctx, x1, y1, x2, y2, color, width, opacity = 1.0) {
// Set opacity
const oldAlpha = ctx.globalAlpha;
ctx.globalAlpha = opacity;
ctx.strokeStyle = color;
ctx.fillStyle = color;
ctx.lineWidth = width;
ctx.lineCap = 'round';
ctx.lineJoin = 'round';
// Calculate arrow geometry
const angle = Math.atan2(y2 - y1, x2 - x1);
const headLength = Math.min(15 + width * 1.5, 25); // Cap the max head size
const headAngle = Math.PI / 6; // 30 degrees
// Calculate where the line should end (before the arrowhead)
const lineEndX = x2 - headLength * 0.8 * Math.cos(angle);
const lineEndY = y2 - headLength * 0.8 * Math.sin(angle);
// Draw the line
ctx.beginPath();
ctx.moveTo(x1, y1);
ctx.lineTo(lineEndX, lineEndY);
ctx.stroke();
// Calculate arrowhead points
const hx1 = x2 - headLength * Math.cos(angle - headAngle);
const hy1 = y2 - headLength * Math.sin(angle - headAngle);
const hx2 = x2 - headLength * Math.cos(angle + headAngle);
const hy2 = y2 - headLength * Math.sin(angle + headAngle);
// Draw arrowhead
ctx.beginPath();
ctx.moveTo(x2, y2);
ctx.lineTo(hx1, hy1);
ctx.lineTo(hx2, hy2);
ctx.closePath();
ctx.fill();
// Restore opacity
ctx.globalAlpha = oldAlpha;
}
function drawPen(ctx, points, color, width, offX, offY, opacity = 1.0) {
if (!points || points.length < 2) return;
// Set opacity
const oldAlpha = ctx.globalAlpha;
ctx.globalAlpha = opacity;
ctx.strokeStyle = color;
ctx.lineWidth = width;
ctx.lineCap = 'round';
ctx.lineJoin = 'round';
ctx.beginPath();
ctx.moveTo(points[0].x - offX, points[0].y - offY);
for (let i = 1; i < points.length; i++) {
ctx.lineTo(points[i].x - offX, points[i].y - offY);
}
ctx.stroke();
// Restore opacity
ctx.globalAlpha = oldAlpha;
}
function renderOverlay() {
if (!_overlay || !_overlayCtx) return;
_overlayCtx.clearRect(0, 0, _overlay.width, _overlay.height);
const offX = containerScrollLeft();
const offY = containerScrollTop();
// Draw committed shapes for current mode
for (const s of _shapes) {
const opacity = s.opacity !== undefined ? s.opacity : 1.0;
if (s.type === 'arrow') {
drawArrow(_overlayCtx, s.x1 - offX, s.y1 - offY, s.x2 - offX, s.y2 - offY, s.color || '#f00', s.width || 2, opacity);
} else if (s.type === 'pen') {
drawPen(_overlayCtx, s.points, s.color || '#f00', s.width || 2, offX, offY, opacity);
}
}
// Draw current drawing
if (_drawing) {
if (_drawing.type === 'pen') {
drawPen(_overlayCtx, _drawing.points, _drawing.color, _drawing.width, offX, offY);
} else {
drawArrow(_overlayCtx, _drawing.x1 - offX, _drawing.y1 - offY, _drawing.x2 - offX, _drawing.y2 - offY, _drawing.color, _drawing.width);
}
}
}
function setOverlayActive(active) {
if (!_overlay) initOverlay();
_overlay.style.pointerEvents = active ? 'auto' : 'none';
// Re-render to ensure visibility aligns with content
renderOverlay();
}
function initOverlay() {
if (_overlay) return;
updateOverlayModeAndContainer();
_overlay = document.createElement('canvas');
_overlay.className = 'draw-overlay';
_overlayCtx = _overlay.getContext('2d');
document.body.appendChild(_overlay);
updateOverlayBounds();
loadShapes();
renderOverlay();
// Events
_overlay.addEventListener('mousedown', onPointerDown);
_overlay.addEventListener('mousemove', onPointerMove);
document.addEventListener('mouseup', onPointerUp);
_overlay.addEventListener('touchstart', onPointerDown, { passive: false });
_overlay.addEventListener('touchmove', onPointerMove, { passive: false });
document.addEventListener('touchend', onPointerUp);
_overlayResizeHandler = () => updateOverlayBounds();
window.addEventListener('resize', _overlayResizeHandler);
_overlayScrollHandler = () => renderOverlay();
window.addEventListener('scroll', _overlayScrollHandler);
// Start fade animation timer
_fadeTimer = setInterval(updateShapesFade, 100); // Update every 100ms for smooth fade
}
function rebindOverlayContainer() {
if (!_overlay) return;
// Remove old scroll handler
if (_overlayScrollHandler) { window.removeEventListener('scroll', _overlayScrollHandler); }
updateOverlayModeAndContainer();
updateOverlayBounds();
loadShapes();
renderOverlay();
_overlayScrollHandler = () => renderOverlay();
window.addEventListener('scroll', _overlayScrollHandler);
}
function teardownOverlay() {
if (!_overlay) return;
_overlay.removeEventListener('mousedown', onPointerDown);
_overlay.removeEventListener('mousemove', onPointerMove);
document.removeEventListener('mouseup', onPointerUp);
_overlay.removeEventListener('touchstart', onPointerDown);
_overlay.removeEventListener('touchmove', onPointerMove);
document.removeEventListener('touchend', onPointerUp);
if (_overlayResizeHandler) window.removeEventListener('resize', _overlayResizeHandler);
if (_overlayScrollHandler) {
if (_overlayContainer === window) {
window.removeEventListener('scroll', _overlayScrollHandler);
} else if (_overlayContainer) {
_overlayContainer.removeEventListener('scroll', _overlayScrollHandler);
}
}
if (_fadeTimer) {
clearInterval(_fadeTimer);
_fadeTimer = null;
}
if (_overlay.parentNode) _overlay.parentNode.removeChild(_overlay);
_overlay = null; _overlayCtx = null; _overlayContainer = null; _overlayResizeHandler = null; _overlayScrollHandler = null; _drawing = null;
}
function teardownFileExplorer() {
const fe = document.querySelector('.file-explorer');
if (fe && fe.parentNode) fe.parentNode.removeChild(fe);
}
function runCell(cellId) {
const btn = document.querySelector('.run-btn[onclick*="' + cellId + '"]');
const output = document.getElementById('output-' + cellId);
if (btn) { btn.textContent = '⏳ running...'; btn.disabled = true; }
if (output) { output.classList.add('output-stale'); }
fetch('/run/' + cellId, { method: 'POST' }).then(r => r.json()).then(data => {
if (output) {
output.classList.remove('output-stale');
let html = '';
if (data.stdout) html += '<div class="cell-stdout">' + data.stdout + '</div>';
if (data.stderr) html += '<div class="cell-stderr">' + data.stderr + '</div>';
output.innerHTML = html;
}
if (btn) { btn.textContent = '▶ run'; btn.disabled = false; }
}).catch(e => {
console.error('Run failed:', e);
if (output) { output.classList.remove('output-stale'); }
if (btn) { btn.textContent = '▶ run'; btn.disabled = false; }
});
}
function copyCell(cellId){
console.log('copyCell called with cellId:', cellId);
// Try multiple selectors to find the code element
let codeElement = document.querySelector('#code-'+cellId+' code');
if (!codeElement) {
codeElement = document.querySelector('#code-'+cellId+' pre code');
}
if (!codeElement) {
codeElement = document.querySelector('#code-'+cellId+' .highlight code');
}
if (!codeElement) {
// Try finding any code element within the cell
const codeDiv = document.getElementById('code-'+cellId);
if (codeDiv) {
codeElement = codeDiv.querySelector('code');
}
}
const btn = document.querySelector('.copy-btn[onclick*="'+cellId+'"]');
console.log('Found codeElement:', codeElement);
console.log('Found btn:', btn);
console.log('Code div structure:', document.getElementById('code-'+cellId));
if (!codeElement) {
console.error('Code element not found for cell:', cellId);
// Log the actual structure for debugging
const codeDiv = document.getElementById('code-'+cellId);
if (codeDiv) {
console.log('Code div HTML:', codeDiv.innerHTML);
}
return;
}
if (!btn) {
console.error('Copy button not found for cell:', cellId);
return;
}
const codeText = codeElement.textContent;
console.log('Code text to copy:', codeText ? codeText.substring(0, 50) + '...' : 'empty');
if (navigator.clipboard && navigator.clipboard.writeText) {
navigator.clipboard.writeText(codeText).then(function() {
console.log('Clipboard copy successful');
btn.textContent = '✓ Copied!';
btn.classList.add('copied');
setTimeout(function() {
btn.textContent = 'Copy';
btn.classList.remove('copied');
}, 2000);
}).catch(function(err) {
console.warn('Clipboard copy failed:', err);
fallbackCopy();
});
} else {
console.log('Using fallback copy method');
fallbackCopy();
}
function fallbackCopy() {
const textarea = document.createElement('textarea');
textarea.value = codeText;
textarea.style.position = 'absolute';
textarea.style.left = '-9999px';
document.body.appendChild(textarea);
textarea.select();
try {
const success = document.execCommand('copy');
console.log('Fallback copy success:', success);
btn.textContent = '✓ Copied!';
btn.classList.add('copied');
setTimeout(function() {
btn.textContent = 'Copy';
btn.classList.remove('copied');
}, 2000);
} catch (err) {
console.error('Fallback copy failed:', err);
btn.textContent = 'Copy failed';
setTimeout(function() {
btn.textContent = 'Copy';
}, 2000);
}
document.body.removeChild(textarea);
}
}
// Live reload functionality (robust SSE handling)
(function(){
if (!('EventSource' in window)) {
console.warn('SSE not supported in this browser');
return;
}
let source = new EventSource('/events');
let isOpen = false;
source.onopen = function(){ isOpen = true; console.log('SSE connected'); };
source.onmessage = function(e){
const msg=(e.data||'').trim(); if(!msg) return;
console.log('SSE message:', msg);
if (msg==='reload' || msg==='incremental') { location.reload(); }
// Ignore 'loading' to avoid premature reload loops
};
source.onerror = function(e){
// Let EventSource auto-reconnect instead of forcing a reload
if (isOpen) console.warn('SSE error after open, retrying...', e);
};
window.addEventListener('beforeunload', function(){ try{source.close();}catch(_){} });
})();
document.addEventListener('DOMContentLoaded', function() {
updateThemeIcon();
initMinimap();
initFileExplorer();
initTools();
initOverlay();
layoutWidgetsStackedBottomRight();
window.addEventListener('resize', layoutWidgetsStackedBottomRight);
});
</script>
</head>
<body>
<div class="controls">
<div class="theme-toggle" onclick="toggleTheme()">light</div>
<div class="reset-toggle" onclick="resetLayout()">reset</div>
</div>
<div class="main-content">
<div class="cell">
<div class="cell-header">
<span class="collapse-indicators">
<span onclick="toggleCode('utils')" style="cursor: pointer;">▼ code</span>
<span onclick="toggleOutput('utils')" style="cursor: pointer;">▼ output</span>
</span> |
Cell: utils | deps: torch, numpy | 30.61s
| <button class="run-btn" onclick="runCell('utils')">▶ run</button>
<button class="copy-btn" onclick="copyCell('utils')">Copy</button>
</div>
<div id="code-utils" class="cell-code">
<div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span>
<span class="normal"> 2</span>
<span class="normal"> 3</span>
<span class="normal"> 4</span>
<span class="normal"> 5</span>
<span class="normal"> 6</span>
<span class="normal"> 7</span>
<span class="normal"> 8</span>
<span class="normal"> 9</span>
<span class="normal">10</span>
<span class="normal">11</span>
<span class="normal">12</span>
<span class="normal">13</span>
<span class="normal">14</span>
<span class="normal">15</span>
<span class="normal">16</span>
<span class="normal">17</span>
<span class="normal">18</span>
<span class="normal">19</span>
<span class="normal">20</span>
<span class="normal">21</span>
<span class="normal">22</span>
<span class="normal">23</span>
<span class="normal">24</span>
<span class="normal">25</span>
<span class="normal">26</span>
<span class="normal">27</span></pre></div></td><td class="code"><div><pre><span></span><span class="sd">&quot;&quot;&quot;Simple utilities for running the models.&quot;&quot;&quot;</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
<span class="k">def</span><span class="w"> </span><span class="nf">to_dtype</span><span class="p">(</span><span class="n">dtype_str</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Convert string to torch dtype.&quot;&quot;&quot;</span>
<span class="k">if</span> <span class="n">dtype_str</span> <span class="o">==</span> <span class="s2">&quot;float16&quot;</span><span class="p">:</span>
<span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">float16</span>
<span class="k">if</span> <span class="n">dtype_str</span> <span class="o">==</span> <span class="s2">&quot;bfloat16&quot;</span><span class="p">:</span>
<span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">bfloat16</span>
<span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
<span class="k">def</span><span class="w"> </span><span class="nf">tensor_stats</span><span class="p">(</span><span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Generate stats string for a tensor.&quot;&quot;&quot;</span>
<span class="k">return</span> <span class="p">(</span><span class="sa">f</span><span class="s2">&quot;shape=</span><span class="si">{</span><span class="nb">tuple</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="si">}</span><span class="s2">, &quot;</span>
<span class="sa">f</span><span class="s2">&quot;dtype=</span><span class="si">{</span><span class="n">t</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">, &quot;</span>
<span class="sa">f</span><span class="s2">&quot;device=</span><span class="si">{</span><span class="n">t</span><span class="o">.</span><span class="n">device</span><span class="si">}</span><span class="s2">, &quot;</span>
<span class="sa">f</span><span class="s2">&quot;mean=</span><span class="si">{</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">, &quot;</span>
<span class="sa">f</span><span class="s2">&quot;std=</span><span class="si">{</span><span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">set_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Set seeds for reproducibility.&quot;&quot;&quot;</span>
<span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span>
<span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">():</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">manual_seed_all</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">backends</span><span class="o">.</span><span class="n">cudnn</span><span class="o">.</span><span class="n">deterministic</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">torch</span><span class="o">.</span><span class="n">backends</span><span class="o">.</span><span class="n">cudnn</span><span class="o">.</span><span class="n">benchmark</span> <span class="o">=</span> <span class="kc">False</span>
</pre></div></td></tr></table></div>
</div>
<div id="output-utils" class="cell-output">
<div class="cell-stderr">Downloading setuptools (1.1MiB)
Downloading nvidia-cufile-cu12 (1.1MiB)
Downloading nvidia-cuda-cupti-cu12 (9.8MiB)
Downloading sympy (6.0MiB)
Downloading torch (846.8MiB)
Downloading nvidia-cublas-cu12 (566.8MiB)
Downloading numpy (15.9MiB)
Downloading nvidia-curand-cu12 (60.7MiB)
Downloading networkx (1.9MiB)
Downloading nvidia-cusolver-cu12 (255.1MiB)
Downloading triton (148.4MiB)
Downloading nvidia-nvjitlink-cu12 (37.4MiB)
Downloading nvidia-cusparselt-cu12 (273.9MiB)
Downloading nvidia-cusparse-cu12 (274.9MiB)
Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB)
Downloading nvidia-cufft-cu12 (184.2MiB)
Downloading nvidia-cudnn-cu12 (674.0MiB)
Downloading nvidia-nccl-cu12 (307.4MiB)
Downloading nvidia-cufile-cu12
Downloading setuptools
Downloading networkx
Downloading nvidia-cuda-cupti-cu12
Downloading numpy
Downloading sympy
Downloading nvidia-nvjitlink-cu12
Downloading nvidia-curand-cu12
Downloading nvidia-cuda-nvrtc-cu12
Downloading triton
Downloading nvidia-cufft-cu12
Downloading nvidia-cusolver-cu12
Downloading nvidia-cusparselt-cu12
Downloading nvidia-cusparse-cu12
Downloading nvidia-nccl-cu12
Downloading nvidia-cublas-cu12
Downloading nvidia-cudnn-cu12
Downloading torch
Installed 26 packages in 234ms
</div>
</div>
</div>
<div class="cell">
<div class="cell-header">
<span class="collapse-indicators">
<span onclick="toggleCode('bench_utils')" style="cursor: pointer;">▼ code</span>
<span onclick="toggleOutput('bench_utils')" style="cursor: pointer;">▼ output</span>
</span> |
Cell: bench_utils | deps: torch, numpy | 31.57s
| <button class="run-btn" onclick="runCell('bench_utils')">▶ run</button>
<button class="copy-btn" onclick="copyCell('bench_utils')">Copy</button>
</div>
<div id="code-bench_utils" class="cell-code">
<div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span>
<span class="normal"> 2</span>
<span class="normal"> 3</span>
<span class="normal"> 4</span>
<span class="normal"> 5</span>
<span class="normal"> 6</span>
<span class="normal"> 7</span>
<span class="normal"> 8</span>
<span class="normal"> 9</span>
<span class="normal"> 10</span>
<span class="normal"> 11</span>
<span class="normal"> 12</span>
<span class="normal"> 13</span>
<span class="normal"> 14</span>
<span class="normal"> 15</span>
<span class="normal"> 16</span>
<span class="normal"> 17</span>
<span class="normal"> 18</span>
<span class="normal"> 19</span>
<span class="normal"> 20</span>
<span class="normal"> 21</span>
<span class="normal"> 22</span>
<span class="normal"> 23</span>
<span class="normal"> 24</span>
<span class="normal"> 25</span>
<span class="normal"> 26</span>
<span class="normal"> 27</span>
<span class="normal"> 28</span>
<span class="normal"> 29</span>
<span class="normal"> 30</span>
<span class="normal"> 31</span>
<span class="normal"> 32</span>
<span class="normal"> 33</span>
<span class="normal"> 34</span>
<span class="normal"> 35</span>
<span class="normal"> 36</span>
<span class="normal"> 37</span>
<span class="normal"> 38</span>
<span class="normal"> 39</span>
<span class="normal"> 40</span>
<span class="normal"> 41</span>
<span class="normal"> 42</span>
<span class="normal"> 43</span>
<span class="normal"> 44</span>
<span class="normal"> 45</span>
<span class="normal"> 46</span>
<span class="normal"> 47</span>
<span class="normal"> 48</span>
<span class="normal"> 49</span>
<span class="normal"> 50</span>
<span class="normal"> 51</span>
<span class="normal"> 52</span>
<span class="normal"> 53</span>
<span class="normal"> 54</span>
<span class="normal"> 55</span>
<span class="normal"> 56</span>
<span class="normal"> 57</span>
<span class="normal"> 58</span>
<span class="normal"> 59</span>
<span class="normal"> 60</span>
<span class="normal"> 61</span>
<span class="normal"> 62</span>
<span class="normal"> 63</span>
<span class="normal"> 64</span>
<span class="normal"> 65</span>
<span class="normal"> 66</span>
<span class="normal"> 67</span>
<span class="normal"> 68</span>
<span class="normal"> 69</span>
<span class="normal"> 70</span>
<span class="normal"> 71</span>
<span class="normal"> 72</span>
<span class="normal"> 73</span>
<span class="normal"> 74</span>
<span class="normal"> 75</span>
<span class="normal"> 76</span>
<span class="normal"> 77</span>
<span class="normal"> 78</span>
<span class="normal"> 79</span>
<span class="normal"> 80</span>
<span class="normal"> 81</span>
<span class="normal"> 82</span>
<span class="normal"> 83</span>
<span class="normal"> 84</span>
<span class="normal"> 85</span>
<span class="normal"> 86</span>
<span class="normal"> 87</span>
<span class="normal"> 88</span>
<span class="normal"> 89</span>
<span class="normal"> 90</span>
<span class="normal"> 91</span>
<span class="normal"> 92</span>
<span class="normal"> 93</span>
<span class="normal"> 94</span>
<span class="normal"> 95</span>
<span class="normal"> 96</span>
<span class="normal"> 97</span>
<span class="normal"> 98</span>
<span class="normal"> 99</span>
<span class="normal">100</span>
<span class="normal">101</span>
<span class="normal">102</span>
<span class="normal">103</span>
<span class="normal">104</span>
<span class="normal">105</span>
<span class="normal">106</span>
<span class="normal">107</span>
<span class="normal">108</span>
<span class="normal">109</span>
<span class="normal">110</span>
<span class="normal">111</span>
<span class="normal">112</span>
<span class="normal">113</span>
<span class="normal">114</span>
<span class="normal">115</span>
<span class="normal">116</span>
<span class="normal">117</span>
<span class="normal">118</span>
<span class="normal">119</span>
<span class="normal">120</span>
<span class="normal">121</span>
<span class="normal">122</span>
<span class="normal">123</span>
<span class="normal">124</span>
<span class="normal">125</span>
<span class="normal">126</span>
<span class="normal">127</span>
<span class="normal">128</span>
<span class="normal">129</span>
<span class="normal">130</span>
<span class="normal">131</span>
<span class="normal">132</span>
<span class="normal">133</span>
<span class="normal">134</span>
<span class="normal">135</span>
<span class="normal">136</span>
<span class="normal">137</span>
<span class="normal">138</span>
<span class="normal">139</span>
<span class="normal">140</span>
<span class="normal">141</span>
<span class="normal">142</span>
<span class="normal">143</span>
<span class="normal">144</span>
<span class="normal">145</span>
<span class="normal">146</span>
<span class="normal">147</span>
<span class="normal">148</span>
<span class="normal">149</span>
<span class="normal">150</span>
<span class="normal">151</span>
<span class="normal">152</span>
<span class="normal">153</span>
<span class="normal">154</span>
<span class="normal">155</span>
<span class="normal">156</span>
<span class="normal">157</span>
<span class="normal">158</span>
<span class="normal">159</span>
<span class="normal">160</span>
<span class="normal">161</span>
<span class="normal">162</span>
<span class="normal">163</span>
<span class="normal">164</span>
<span class="normal">165</span>
<span class="normal">166</span>
<span class="normal">167</span>
<span class="normal">168</span>
<span class="normal">169</span>
<span class="normal">170</span>
<span class="normal">171</span>
<span class="normal">172</span>
<span class="normal">173</span>
<span class="normal">174</span>
<span class="normal">175</span>
<span class="normal">176</span>
<span class="normal">177</span>
<span class="normal">178</span>
<span class="normal">179</span>
<span class="normal">180</span>
<span class="normal">181</span>
<span class="normal">182</span>
<span class="normal">183</span>
<span class="normal">184</span>
<span class="normal">185</span>
<span class="normal">186</span>
<span class="normal">187</span>
<span class="normal">188</span>
<span class="normal">189</span>
<span class="normal">190</span>
<span class="normal">191</span>
<span class="normal">192</span>
<span class="normal">193</span>
<span class="normal">194</span></pre></div></td><td class="code"><div><pre><span></span><span class="sd">&quot;&quot;&quot;Reusable benchmarking utilities for performance testing.&quot;&quot;&quot;</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">time</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">contextlib</span><span class="w"> </span><span class="kn">import</span> <span class="n">contextmanager</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">typing</span><span class="w"> </span><span class="kn">import</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Optional</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
<span class="k">def</span><span class="w"> </span><span class="nf">to_dtype</span><span class="p">(</span><span class="n">dtype_str</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Convert string to torch dtype.&quot;&quot;&quot;</span>
<span class="k">if</span> <span class="n">dtype_str</span> <span class="o">==</span> <span class="s2">&quot;float16&quot;</span><span class="p">:</span>
<span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">float16</span>
<span class="k">if</span> <span class="n">dtype_str</span> <span class="o">==</span> <span class="s2">&quot;bfloat16&quot;</span><span class="p">:</span>
<span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">bfloat16</span>
<span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_sync</span><span class="p">(</span><span class="n">device</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Synchronize device if CUDA.&quot;&quot;&quot;</span>
<span class="k">if</span> <span class="n">device</span> <span class="o">==</span> <span class="s2">&quot;cuda&quot;</span><span class="p">:</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_compute_stats</span><span class="p">(</span><span class="n">times_s</span><span class="p">,</span> <span class="n">tokens</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">float</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Compute comprehensive latency and throughput statistics.&quot;&quot;&quot;</span>
<span class="n">lat_ms</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">t</span> <span class="o">*</span> <span class="mf">1000.0</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">times_s</span><span class="p">])</span>
<span class="n">lat_ms_sorted</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">lat_ms</span><span class="p">)</span>
<span class="n">n</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">lat_ms</span><span class="p">)</span>
<span class="n">stats</span> <span class="o">=</span> <span class="p">{</span>
<span class="s2">&quot;avg_ms&quot;</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">lat_ms</span><span class="p">),</span>
<span class="s2">&quot;min_ms&quot;</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">lat_ms</span><span class="p">),</span>
<span class="s2">&quot;max_ms&quot;</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">lat_ms</span><span class="p">),</span>
<span class="s2">&quot;std_ms&quot;</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">std</span><span class="p">(</span><span class="n">lat_ms</span><span class="p">),</span>
<span class="s2">&quot;p50_ms&quot;</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">percentile</span><span class="p">(</span><span class="n">lat_ms</span><span class="p">,</span> <span class="mi">50</span><span class="p">),</span>
<span class="s2">&quot;p95_ms&quot;</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">percentile</span><span class="p">(</span><span class="n">lat_ms</span><span class="p">,</span> <span class="mi">95</span><span class="p">),</span>
<span class="s2">&quot;p99_ms&quot;</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">percentile</span><span class="p">(</span><span class="n">lat_ms</span><span class="p">,</span> <span class="mi">99</span><span class="p">),</span>
<span class="s2">&quot;num_iters&quot;</span><span class="p">:</span> <span class="n">n</span>
<span class="p">}</span>
<span class="k">if</span> <span class="n">tokens</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">n</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">avg_s</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">times_s</span><span class="p">)</span>
<span class="n">stats</span><span class="p">[</span><span class="s2">&quot;tokens_per_s&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">tokens</span> <span class="o">/</span> <span class="n">avg_s</span> <span class="k">if</span> <span class="n">avg_s</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">else</span> <span class="nb">float</span><span class="p">(</span><span class="s2">&quot;inf&quot;</span><span class="p">)</span>
<span class="n">stats</span><span class="p">[</span><span class="s2">&quot;throughput_variance&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">std</span><span class="p">([</span><span class="n">tokens</span> <span class="o">/</span> <span class="n">t</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">times_s</span> <span class="k">if</span> <span class="n">t</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">])</span>
<span class="k">return</span> <span class="n">stats</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_format_timing_stats</span><span class="p">(</span><span class="n">stats</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span> <span class="n">tokens</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Format timing statistics for display.&quot;&quot;&quot;</span>
<span class="n">lines</span> <span class="o">=</span> <span class="p">[</span>
<span class="s2">&quot;</span><span class="se">\n</span><span class="s2">━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━&quot;</span><span class="p">,</span>
<span class="sa">f</span><span class="s2">&quot;Iterations: </span><span class="si">{</span><span class="n">stats</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;num_iters&#39;</span><span class="p">,</span><span class="w"> </span><span class="mi">0</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span>
<span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Latency Statistics:&quot;</span><span class="p">,</span>
<span class="sa">f</span><span class="s2">&quot; Average: </span><span class="si">{</span><span class="n">stats</span><span class="p">[</span><span class="s1">&#39;avg_ms&#39;</span><span class="p">]</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> ms&quot;</span><span class="p">,</span>
<span class="sa">f</span><span class="s2">&quot; Min: </span><span class="si">{</span><span class="n">stats</span><span class="p">[</span><span class="s1">&#39;min_ms&#39;</span><span class="p">]</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> ms&quot;</span><span class="p">,</span>
<span class="sa">f</span><span class="s2">&quot; Max: </span><span class="si">{</span><span class="n">stats</span><span class="p">[</span><span class="s1">&#39;max_ms&#39;</span><span class="p">]</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> ms&quot;</span><span class="p">,</span>
<span class="sa">f</span><span class="s2">&quot; Std Dev: </span><span class="si">{</span><span class="n">stats</span><span class="p">[</span><span class="s1">&#39;std_ms&#39;</span><span class="p">]</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> ms&quot;</span><span class="p">,</span>
<span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Percentiles:&quot;</span><span class="p">,</span>
<span class="sa">f</span><span class="s2">&quot; P50 (median): </span><span class="si">{</span><span class="n">stats</span><span class="p">[</span><span class="s1">&#39;p50_ms&#39;</span><span class="p">]</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> ms&quot;</span><span class="p">,</span>
<span class="sa">f</span><span class="s2">&quot; P95: </span><span class="si">{</span><span class="n">stats</span><span class="p">[</span><span class="s1">&#39;p95_ms&#39;</span><span class="p">]</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> ms&quot;</span><span class="p">,</span>
<span class="sa">f</span><span class="s2">&quot; P99: </span><span class="si">{</span><span class="n">stats</span><span class="p">[</span><span class="s1">&#39;p99_ms&#39;</span><span class="p">]</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> ms&quot;</span><span class="p">,</span>
<span class="p">]</span>
<span class="k">if</span> <span class="n">tokens</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="s1">&#39;tokens_per_s&#39;</span> <span class="ow">in</span> <span class="n">stats</span><span class="p">:</span>
<span class="n">lines</span><span class="o">.</span><span class="n">extend</span><span class="p">([</span>
<span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Throughput:&quot;</span><span class="p">,</span>
<span class="sa">f</span><span class="s2">&quot; Tokens/sec: </span><span class="si">{</span><span class="n">stats</span><span class="p">[</span><span class="s1">&#39;tokens_per_s&#39;</span><span class="p">]</span><span class="si">:</span><span class="s2">.1f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span>
<span class="sa">f</span><span class="s2">&quot; Std Dev: </span><span class="si">{</span><span class="n">stats</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;throughput_variance&#39;</span><span class="p">,</span><span class="w"> </span><span class="mi">0</span><span class="p">)</span><span class="si">:</span><span class="s2">.1f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span>
<span class="p">])</span>
<span class="n">lines</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="s2">&quot;━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━&quot;</span><span class="p">)</span>
<span class="k">return</span> <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">lines</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_bench_engine</span><span class="p">(</span>
<span class="n">call</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[[],</span> <span class="n">Any</span><span class="p">],</span> <span class="o">*</span><span class="p">,</span> <span class="n">warmup</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">iters</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">dtype</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="nb">list</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Core benchmarking engine with warmup and timing.&quot;&quot;&quot;</span>
<span class="n">use_autocast</span> <span class="o">=</span> <span class="n">device</span> <span class="o">==</span> <span class="s2">&quot;cuda&quot;</span> <span class="ow">and</span> <span class="n">dtype</span> <span class="ow">in</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">bfloat16</span><span class="p">)</span>
<span class="c1"># Warmup phase</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Warming up (</span><span class="si">{</span><span class="n">warmup</span><span class="si">}</span><span class="s2"> iterations)...&quot;</span><span class="p">)</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">inference_mode</span><span class="p">():</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">max</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">warmup</span><span class="p">)):</span>
<span class="k">if</span> <span class="n">use_autocast</span><span class="p">:</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">autocast</span><span class="p">(</span><span class="n">device_type</span><span class="o">=</span><span class="s2">&quot;cuda&quot;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">):</span>
<span class="n">_</span> <span class="o">=</span> <span class="n">call</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">_</span> <span class="o">=</span> <span class="n">call</span><span class="p">()</span>
<span class="n">_sync</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="c1"># Measurement phase</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Benchmarking (</span><span class="si">{</span><span class="n">iters</span><span class="si">}</span><span class="s2"> iterations)...&quot;</span><span class="p">)</span>
<span class="n">times_s</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">last</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">inference_mode</span><span class="p">():</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">max</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">iters</span><span class="p">)):</span>
<span class="n">start</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">perf_counter</span><span class="p">()</span>
<span class="k">if</span> <span class="n">use_autocast</span><span class="p">:</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">autocast</span><span class="p">(</span><span class="n">device_type</span><span class="o">=</span><span class="s2">&quot;cuda&quot;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">):</span>
<span class="n">last</span> <span class="o">=</span> <span class="n">call</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">last</span> <span class="o">=</span> <span class="n">call</span><span class="p">()</span>
<span class="n">_sync</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="n">end</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">perf_counter</span><span class="p">()</span>
<span class="n">times_s</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">end</span> <span class="o">-</span> <span class="n">start</span><span class="p">)</span>
<span class="c1"># Progress indicator every 20% of iterations</span>
<span class="k">if</span> <span class="n">i</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">i</span> <span class="o">%</span> <span class="nb">max</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">iters</span> <span class="o">//</span> <span class="mi">5</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">pct</span> <span class="o">=</span> <span class="p">(</span><span class="n">i</span> <span class="o">/</span> <span class="n">iters</span><span class="p">)</span> <span class="o">*</span> <span class="mi">100</span>
<span class="n">avg_so_far</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">times_s</span><span class="p">[:</span><span class="n">i</span><span class="p">])</span> <span class="o">*</span> <span class="mi">1000</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot; Progress: </span><span class="si">{</span><span class="n">pct</span><span class="si">:</span><span class="s2">.0f</span><span class="si">}</span><span class="s2">% complete (avg: </span><span class="si">{</span><span class="n">avg_so_far</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> ms)&quot;</span><span class="p">)</span>
<span class="k">return</span> <span class="n">last</span><span class="p">,</span> <span class="n">times_s</span>
<span class="k">def</span><span class="w"> </span><span class="nf">tensor_stats</span><span class="p">(</span><span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Generate comprehensive stats string for a tensor.&quot;&quot;&quot;</span>
<span class="k">return</span> <span class="p">(</span><span class="sa">f</span><span class="s2">&quot;shape=</span><span class="si">{</span><span class="nb">tuple</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="si">}</span><span class="s2">, &quot;</span>
<span class="sa">f</span><span class="s2">&quot;dtype=</span><span class="si">{</span><span class="n">t</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">, &quot;</span>
<span class="sa">f</span><span class="s2">&quot;device=</span><span class="si">{</span><span class="n">t</span><span class="o">.</span><span class="n">device</span><span class="si">}</span><span class="s2">, &quot;</span>
<span class="sa">f</span><span class="s2">&quot;range=[</span><span class="si">{</span><span class="n">t</span><span class="o">.</span><span class="n">min</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">, </span><span class="si">{</span><span class="n">t</span><span class="o">.</span><span class="n">max</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">], &quot;</span>
<span class="sa">f</span><span class="s2">&quot;mean=</span><span class="si">{</span><span class="n">t</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">, &quot;</span>
<span class="sa">f</span><span class="s2">&quot;std=</span><span class="si">{</span><span class="n">t</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">, &quot;</span>
<span class="sa">f</span><span class="s2">&quot;norm=</span><span class="si">{</span><span class="n">t</span><span class="o">.</span><span class="n">norm</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nd">@contextmanager</span>
<span class="k">def</span><span class="w"> </span><span class="nf">bench_context</span><span class="p">(</span>
<span class="o">*</span><span class="p">,</span> <span class="n">warmup</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">25</span><span class="p">,</span> <span class="n">iters</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;cuda&quot;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">tokens</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">verbose</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="n">save_json</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Context that yields a runner: runner(fn, *args, **kwargs) -&gt; (result, stats).&quot;&quot;&quot;</span>
<span class="k">def</span><span class="w"> </span><span class="nf">runner</span><span class="p">(</span><span class="n">fn</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">float</span><span class="p">]]:</span>
<span class="c1"># Log configuration</span>
<span class="k">if</span> <span class="n">verbose</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">┌─ Benchmark Configuration ─────────────────────────────┐&quot;</span><span class="p">)</span>
<span class="c1"># print(f&quot;│ Device: {device:&lt;15} Dtype: {dtype} │&quot;)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;│ Warmup: </span><span class="si">{</span><span class="n">warmup</span><span class="si">:</span><span class="s2">&lt;15</span><span class="si">}</span><span class="s2"> Iters: </span><span class="si">{</span><span class="n">iters</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">tokens</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;│ Tokens: </span><span class="si">{</span><span class="n">tokens</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;└────────────────────────────────────────────────────────┘&quot;</span><span class="p">)</span>
<span class="c1"># Log input if it&#39;s a tensor</span>
<span class="k">if</span> <span class="n">verbose</span> <span class="ow">and</span> <span class="n">args</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">args</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Input: </span><span class="si">{</span><span class="n">tensor_stats</span><span class="p">(</span><span class="n">args</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">call</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">fn</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="n">result</span><span class="p">,</span> <span class="n">times_s</span> <span class="o">=</span> <span class="n">_bench_engine</span><span class="p">(</span><span class="n">call</span><span class="p">,</span> <span class="n">warmup</span><span class="o">=</span><span class="n">warmup</span><span class="p">,</span> <span class="n">iters</span><span class="o">=</span><span class="n">iters</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="c1"># Log output if it&#39;s a tensor or tuple with tensors</span>
<span class="k">if</span> <span class="n">verbose</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Output tensors:&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">result</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot; Primary: </span><span class="si">{</span><span class="n">tensor_stats</span><span class="p">(</span><span class="n">result</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">result</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">result</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">result</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot; Primary: </span><span class="si">{</span><span class="n">tensor_stats</span><span class="p">(</span><span class="n">result</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">result</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">result</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot; Auxiliary: </span><span class="si">{</span><span class="n">tensor_stats</span><span class="p">(</span><span class="n">result</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot; Auxiliary: </span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">result</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="c1"># Compute and display statistics</span>
<span class="n">stats</span> <span class="o">=</span> <span class="n">_compute_stats</span><span class="p">(</span><span class="n">times_s</span><span class="p">,</span> <span class="n">tokens</span><span class="o">=</span><span class="n">tokens</span><span class="p">)</span>
<span class="k">if</span> <span class="n">verbose</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="n">_format_timing_stats</span><span class="p">(</span><span class="n">stats</span><span class="p">,</span> <span class="n">tokens</span><span class="p">))</span>
<span class="c1"># Save to JSON if requested</span>
<span class="k">if</span> <span class="n">save_json</span><span class="p">:</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">json</span>
<span class="n">json_data</span> <span class="o">=</span> <span class="p">{</span>
<span class="s2">&quot;implementation&quot;</span><span class="p">:</span> <span class="n">save_json</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;.json&quot;</span><span class="p">,</span> <span class="s2">&quot;&quot;</span><span class="p">),</span>
<span class="s2">&quot;config&quot;</span><span class="p">:</span> <span class="p">{</span>
<span class="s2">&quot;warmup&quot;</span><span class="p">:</span> <span class="n">warmup</span><span class="p">,</span>
<span class="s2">&quot;iters&quot;</span><span class="p">:</span> <span class="n">iters</span><span class="p">,</span>
<span class="s2">&quot;device&quot;</span><span class="p">:</span> <span class="nb">str</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="c1"># Convert device to string</span>
<span class="s2">&quot;dtype&quot;</span><span class="p">:</span> <span class="nb">str</span><span class="p">(</span><span class="n">dtype</span><span class="p">),</span>
<span class="s2">&quot;tokens&quot;</span><span class="p">:</span> <span class="n">tokens</span>
<span class="p">},</span>
<span class="s2">&quot;stats&quot;</span><span class="p">:</span> <span class="n">stats</span><span class="p">,</span>
<span class="s2">&quot;output_sum&quot;</span><span class="p">:</span> <span class="nb">float</span><span class="p">(</span><span class="n">result</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">result</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">result</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">else</span> <span class="nb">float</span><span class="p">(</span><span class="n">result</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">result</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="k">else</span> <span class="kc">None</span>
<span class="p">}</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">save_json</span><span class="p">,</span> <span class="s1">&#39;w&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="n">json</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="n">json_data</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">indent</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="k">if</span> <span class="n">verbose</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Saved benchmark results to </span><span class="si">{</span><span class="n">save_json</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">return</span> <span class="n">result</span><span class="p">,</span> <span class="n">stats</span>
<span class="k">yield</span> <span class="n">runner</span>
<span class="k">def</span><span class="w"> </span><span class="nf">set_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Set seeds for reproducibility.&quot;&quot;&quot;</span>
<span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span>
<span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">():</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">manual_seed_all</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">backends</span><span class="o">.</span><span class="n">cudnn</span><span class="o">.</span><span class="n">deterministic</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">torch</span><span class="o">.</span><span class="n">backends</span><span class="o">.</span><span class="n">cudnn</span><span class="o">.</span><span class="n">benchmark</span> <span class="o">=</span> <span class="kc">False</span>
</pre></div></td></tr></table></div>
</div>
<div id="output-bench_utils" class="cell-output">
<div class="cell-stderr">Downloading networkx (1.9MiB)
Downloading numpy (15.9MiB)
Downloading nvidia-cusparselt-cu12 (273.9MiB)
Downloading nvidia-cuda-cupti-cu12 (9.8MiB)
Downloading nvidia-cufile-cu12 (1.1MiB)
Downloading setuptools (1.1MiB)
Downloading nvidia-cudnn-cu12 (674.0MiB)
Downloading sympy (6.0MiB)
Downloading nvidia-cufft-cu12 (184.2MiB)
Downloading nvidia-cublas-cu12 (566.8MiB)
Downloading nvidia-cusolver-cu12 (255.1MiB)
Downloading nvidia-nvjitlink-cu12 (37.4MiB)
Downloading nvidia-curand-cu12 (60.7MiB)
Downloading torch (846.8MiB)
Downloading nvidia-nccl-cu12 (307.4MiB)
Downloading nvidia-cusparse-cu12 (274.9MiB)
Downloading triton (148.4MiB)
Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB)
Downloading nvidia-cufile-cu12
Downloading setuptools
Downloading networkx
Downloading nvidia-cuda-cupti-cu12
Downloading numpy
Downloading nvidia-nvjitlink-cu12
Downloading sympy
Downloading nvidia-curand-cu12
Downloading nvidia-cuda-nvrtc-cu12
Downloading triton
Downloading nvidia-cufft-cu12
Downloading nvidia-cusolver-cu12
Downloading nvidia-cusparselt-cu12
Downloading nvidia-cusparse-cu12
Downloading nvidia-nccl-cu12
Downloading nvidia-cublas-cu12
Downloading nvidia-cudnn-cu12
Downloading torch
Installed 26 packages in 234ms
</div>
</div>
</div>
<p>This notebook runs the Yamoe and Binned MoE implementations once each on identical inputs and identical shared weights to verify they produce consistent outputs.</p>
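<p>The consistency check boils down to an element-wise closeness comparison between the two implementations' output tensors (in PyTorch one would typically reach for <code>torch.allclose</code>). As a minimal, torch-free sketch of that comparison, with <code>yamoe_out</code> / <code>binned_out</code> as purely illustrative values:</p>
<pre>
```python
def outputs_match(a, b, rtol=1e-5, atol=1e-8):
    """Element-wise closeness check mirroring torch.allclose semantics:
    |x - y| <= atol + rtol * |y| for every pair of elements."""
    return len(a) == len(b) and all(
        abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b)
    )

# Hypothetical flattened outputs from the two implementations.
yamoe_out = [0.1234567, -0.9876543]
binned_out = [0.1234568, -0.9876544]

print(outputs_match(yamoe_out, binned_out))  # True: differs only in the 7th decimal
```
</pre>
<p>For float32 CUDA kernels, exact bitwise equality is generally too strict (reduction order differs between kernels), so a tolerance-based check like this is the standard criterion.</p>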
<div class="cell">
<div class="cell-header">
<span class="collapse-indicators">
<span onclick="toggleCode('config')" style="cursor: pointer;">▼ code</span>
<span onclick="toggleOutput('config')" style="cursor: pointer;">▼ output</span>
</span> |
Cell: config | deps: torch, numpy | 37.88s
| <button class="run-btn" onclick="runCell('config')">▶ run</button>
<button class="copy-btn" onclick="copyCell('config')">Copy</button>
</div>
<div id="code-config" class="cell-code">
<div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span>
<span class="normal"> 2</span>
<span class="normal"> 3</span>
<span class="normal"> 4</span>
<span class="normal"> 5</span>
<span class="normal"> 6</span>
<span class="normal"> 7</span>
<span class="normal"> 8</span>
<span class="normal"> 9</span>
<span class="normal">10</span>
<span class="normal">11</span>
<span class="normal">12</span>
<span class="normal">13</span>
<span class="normal">14</span>
<span class="normal">15</span>
<span class="normal">16</span>
<span class="normal">17</span>
<span class="normal">18</span>
<span class="normal">19</span>
<span class="normal">20</span></pre></div></td><td class="code"><div><pre><span></span><span class="sd">&quot;&quot;&quot;Shared configuration for both implementations.&quot;&quot;&quot;</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
<span class="c1"># Model configuration</span>
<span class="n">NUM_EXPERTS</span> <span class="o">=</span> <span class="mi">128</span>
<span class="n">HIDDEN_SIZE</span> <span class="o">=</span> <span class="mi">1152</span>
<span class="n">INTERMEDIATE_SIZE</span> <span class="o">=</span> <span class="mi">3072</span>
<span class="n">TOP_K</span> <span class="o">=</span> <span class="mi">4</span>
<span class="c1"># Input configuration</span>
<span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">SEQ_LEN</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">DTYPE</span> <span class="o">=</span> <span class="s2">&quot;float32&quot;</span>
<span class="n">DEVICE</span> <span class="o">=</span> <span class="s2">&quot;cuda&quot;</span> <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span class="k">else</span> <span class="s2">&quot;cpu&quot;</span>
<span class="c1"># Seeds for reproducibility</span>
<span class="n">WEIGHT_SEED</span> <span class="o">=</span> <span class="mi">999</span>
<span class="n">EXPERT_SEED</span> <span class="o">=</span> <span class="mi">777</span>
<span class="n">INPUT_SEED</span> <span class="o">=</span> <span class="mi">123</span>
<span class="n">GENERAL_SEED</span> <span class="o">=</span> <span class="mi">42</span>
</pre></div></td></tr></table></div>
</div>
<div id="output-config" class="cell-output">
<div class="cell-stderr">Downloading nvidia-cusparselt-cu12 (273.9MiB)
Downloading triton (148.4MiB)
Downloading nvidia-nccl-cu12 (307.4MiB)
Downloading nvidia-cuda-cupti-cu12 (9.8MiB)
Downloading nvidia-cublas-cu12 (566.8MiB)
Downloading networkx (1.9MiB)
Downloading nvidia-nvjitlink-cu12 (37.4MiB)
Downloading nvidia-cufft-cu12 (184.2MiB)
Downloading nvidia-curand-cu12 (60.7MiB)
Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB)
Downloading nvidia-cusparse-cu12 (274.9MiB)
Downloading nvidia-cufile-cu12 (1.1MiB)
Downloading nvidia-cudnn-cu12 (674.0MiB)
Downloading setuptools (1.1MiB)
Downloading sympy (6.0MiB)
Downloading torch (846.8MiB)
Downloading nvidia-cusolver-cu12 (255.1MiB)
Downloading numpy (15.9MiB)
Downloading nvidia-cufile-cu12
Downloading setuptools
Downloading networkx
Downloading nvidia-cuda-cupti-cu12
Downloading numpy
Downloading sympy
Downloading nvidia-nvjitlink-cu12
Downloading nvidia-curand-cu12
Downloading nvidia-cuda-nvrtc-cu12
Downloading triton
Downloading nvidia-cufft-cu12
Downloading nvidia-cusolver-cu12
Downloading nvidia-cusparse-cu12
Downloading nvidia-cusparselt-cu12
Downloading nvidia-nccl-cu12
Downloading nvidia-cublas-cu12
Downloading nvidia-cudnn-cu12
Downloading torch
Installed 26 packages in 225ms
</div>
</div>
</div>
<div class="cell">
<div class="cell-header">
<span class="collapse-indicators">
<span onclick="toggleCode('save_data')" style="cursor: pointer;">▼ code</span>
<span onclick="toggleOutput('save_data')" style="cursor: pointer;">▼ output</span>
</span> |
Cell: save_data | deps: torch, numpy | 38.59s
| <button class="run-btn" onclick="runCell('save_data')">▶ run</button>
<button class="copy-btn" onclick="copyCell('save_data')">Copy</button>
</div>
<div id="code-save_data" class="cell-code">
<div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span>
<span class="normal"> 2</span>
<span class="normal"> 3</span>
<span class="normal"> 4</span>
<span class="normal"> 5</span>
<span class="normal"> 6</span>
<span class="normal"> 7</span>
<span class="normal"> 8</span>
<span class="normal"> 9</span>
<span class="normal">10</span>
<span class="normal">11</span>
<span class="normal">12</span>
<span class="normal">13</span>
<span class="normal">14</span>
<span class="normal">15</span>
<span class="normal">16</span>
<span class="normal">17</span>
<span class="normal">18</span>
<span class="normal">19</span>
<span class="normal">20</span>
<span class="normal">21</span>
<span class="normal">22</span>
<span class="normal">23</span>
<span class="normal">24</span>
<span class="normal">25</span>
<span class="normal">26</span>
<span class="normal">27</span>
<span class="normal">28</span>
<span class="normal">29</span>
<span class="normal">30</span>
<span class="normal">31</span>
<span class="normal">32</span>
<span class="normal">33</span>
<span class="normal">34</span>
<span class="normal">35</span></pre></div></td><td class="code"><div><pre><span></span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">Generate deterministic shared weights once and save as artifacts so</span>
<span class="sd">both implementations load identical parameters.</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">config</span><span class="w"> </span><span class="kn">import</span> <span class="n">NUM_EXPERTS</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">WEIGHT_SEED</span><span class="p">,</span> <span class="n">EXPERT_SEED</span>
<span class="k">def</span><span class="w"> </span><span class="nf">save_shared_weights</span><span class="p">():</span>
<span class="c1"># Router: Kaiming uniform as used by both, bias zeros</span>
<span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">WEIGHT_SEED</span><span class="p">)</span>
<span class="n">router_weight</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">NUM_EXPERTS</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">kaiming_uniform_</span><span class="p">(</span><span class="n">router_weight</span><span class="p">)</span>
<span class="n">router_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">NUM_EXPERTS</span><span class="p">)</span>
<span class="c1"># Experts: normal(0, 0.02), biases zeros</span>
<span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">EXPERT_SEED</span><span class="p">)</span>
<span class="n">gate_up_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">NUM_EXPERTS</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">HIDDEN_SIZE</span><span class="p">)</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">0.02</span><span class="p">)</span>
<span class="n">gate_up_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">NUM_EXPERTS</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">HIDDEN_SIZE</span><span class="p">)</span>
<span class="n">down_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">NUM_EXPERTS</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">)</span><span class="o">.</span><span class="n">normal_</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">0.02</span><span class="p">)</span>
<span class="n">down_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">NUM_EXPERTS</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">)</span>
<span class="c1"># Save artifacts</span>
<span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">router_weight</span><span class="p">,</span> <span class="s1">&#39;router_weight.pt&#39;</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">router_bias</span><span class="p">,</span> <span class="s1">&#39;router_bias.pt&#39;</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">gate_up_proj</span><span class="p">,</span> <span class="s1">&#39;gate_up_proj.pt&#39;</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">gate_up_proj_bias</span><span class="p">,</span> <span class="s1">&#39;gate_up_proj_bias.pt&#39;</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">down_proj</span><span class="p">,</span> <span class="s1">&#39;down_proj.pt&#39;</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">down_proj_bias</span><span class="p">,</span> <span class="s1">&#39;down_proj_bias.pt&#39;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Saved shared weights to artifacts&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Router weight sum: </span><span class="si">{</span><span class="n">router_weight</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Gate/up sum: </span><span class="si">{</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Down sum: </span><span class="si">{</span><span class="n">down_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">save_shared_weights</span><span class="p">()</span>
</pre></div></td></tr></table></div>
</div>
<div id="output-save_data" class="cell-output">
<div class="cell-stdout">Saved shared weights to artifacts
Router weight sum: 12.588732
Gate/up sum: 1026.601807
Down sum: 206.729263
</div>
<div class="cell-stderr">Downloading nvidia-cufile-cu12 (1.1MiB)
Downloading nvidia-cuda-cupti-cu12 (9.8MiB)
Downloading networkx (1.9MiB)
Downloading nvidia-cudnn-cu12 (674.0MiB)
Downloading nvidia-curand-cu12 (60.7MiB)
Downloading nvidia-cusolver-cu12 (255.1MiB)
Downloading nvidia-nccl-cu12 (307.4MiB)
Downloading nvidia-cusparse-cu12 (274.9MiB)
Downloading nvidia-nvjitlink-cu12 (37.4MiB)
Downloading nvidia-cublas-cu12 (566.8MiB)
Downloading triton (148.4MiB)
Downloading nvidia-cufft-cu12 (184.2MiB)
Downloading sympy (6.0MiB)
Downloading numpy (15.9MiB)
Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB)
Downloading setuptools (1.1MiB)
Downloading torch (846.8MiB)
Downloading nvidia-cusparselt-cu12 (273.9MiB)
Downloading nvidia-cufile-cu12
Downloading setuptools
Downloading networkx
Downloading nvidia-cuda-cupti-cu12
Downloading numpy
Downloading nvidia-nvjitlink-cu12
Downloading sympy
Downloading nvidia-curand-cu12
Downloading nvidia-cuda-nvrtc-cu12
Downloading triton
Downloading nvidia-cufft-cu12
Downloading nvidia-cusolver-cu12
Downloading nvidia-cusparse-cu12
Downloading nvidia-cusparselt-cu12
Downloading nvidia-nccl-cu12
Downloading nvidia-cublas-cu12
Downloading nvidia-cudnn-cu12
Downloading torch
Installed 26 packages in 239ms
</div>
<div class="cell-artifacts">
<h4>Artifacts:</h4>
<a href="artifacts/save_data/down_proj.pt" class="artifact" target="_blank">down_proj.pt</a>
<a href="artifacts/save_data/down_proj_bias.pt" class="artifact" target="_blank">down_proj_bias.pt</a>
<a href="artifacts/save_data/gate_up_proj.pt" class="artifact" target="_blank">gate_up_proj.pt</a>
<a href="artifacts/save_data/gate_up_proj_bias.pt" class="artifact" target="_blank">gate_up_proj_bias.pt</a>
<a href="artifacts/save_data/router_bias.pt" class="artifact" target="_blank">router_bias.pt</a>
<a href="artifacts/save_data/router_weight.pt" class="artifact" target="_blank">router_weight.pt</a>
</div>
</div>
</div>
<h2>Yamoe Implementation</h2>
<p>This section runs the Yamoe MoE implementation with optimized Triton kernels.</p>
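<p>The cell below hard-codes <code>expert_capacity = 12</code>; tokens routed to an expert beyond its capacity are dropped. As context (a sketch, not part of the cell, with <code>capacity_factor</code> as an assumed knob), capacity is commonly sized from the average tokens-per-expert plus headroom:</p>

```python
import math

def expert_capacity(num_tokens: int, top_k: int, num_experts: int,
                    capacity_factor: float = 1.0) -> int:
    # Each token is routed to top_k experts, so on average an expert
    # receives num_tokens * top_k / num_experts tokens; the capacity
    # factor adds headroom so fewer tokens are dropped under imbalance.
    return math.ceil(num_tokens * top_k / num_experts * capacity_factor)

# Hypothetical shapes: 100 tokens, top_k=4, 128 experts
print(expert_capacity(100, 4, 128, capacity_factor=2.0))
```

<p>A larger factor trades memory and compute for fewer dropped tokens, which is why the commented-out value of 256 in the cell is described as generous.</p>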
<div class="cell">
<div class="cell-header">
<span class="collapse-indicators">
<span onclick="toggleCode('yamoe_run')" style="cursor: pointer;">▼ code</span>
<span onclick="toggleOutput('yamoe_run')" style="cursor: pointer;">▼ output</span>
</span> |
Cell: yamoe_run | deps: torch, kernels, numpy | 35.75s
| <button class="run-btn" onclick="runCell('yamoe_run')">▶ run</button>
<button class="copy-btn" onclick="copyCell('yamoe_run')">Copy</button>
</div>
<div id="code-yamoe_run" class="cell-code">
<div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span>
<span class="normal"> 2</span>
<span class="normal"> 3</span>
<span class="normal"> 4</span>
<span class="normal"> 5</span>
<span class="normal"> 6</span>
<span class="normal"> 7</span>
<span class="normal"> 8</span>
<span class="normal"> 9</span>
<span class="normal"> 10</span>
<span class="normal"> 11</span>
<span class="normal"> 12</span>
<span class="normal"> 13</span>
<span class="normal"> 14</span>
<span class="normal"> 15</span>
<span class="normal"> 16</span>
<span class="normal"> 17</span>
<span class="normal"> 18</span>
<span class="normal"> 19</span>
<span class="normal"> 20</span>
<span class="normal"> 21</span>
<span class="normal"> 22</span>
<span class="normal"> 23</span>
<span class="normal"> 24</span>
<span class="normal"> 25</span>
<span class="normal"> 26</span>
<span class="normal"> 27</span>
<span class="normal"> 28</span>
<span class="normal"> 29</span>
<span class="normal"> 30</span>
<span class="normal"> 31</span>
<span class="normal"> 32</span>
<span class="normal"> 33</span>
<span class="normal"> 34</span>
<span class="normal"> 35</span>
<span class="normal"> 36</span>
<span class="normal"> 37</span>
<span class="normal"> 38</span>
<span class="normal"> 39</span>
<span class="normal"> 40</span>
<span class="normal"> 41</span>
<span class="normal"> 42</span>
<span class="normal"> 43</span>
<span class="normal"> 44</span>
<span class="normal"> 45</span>
<span class="normal"> 46</span>
<span class="normal"> 47</span>
<span class="normal"> 48</span>
<span class="normal"> 49</span>
<span class="normal"> 50</span>
<span class="normal"> 51</span>
<span class="normal"> 52</span>
<span class="normal"> 53</span>
<span class="normal"> 54</span>
<span class="normal"> 55</span>
<span class="normal"> 56</span>
<span class="normal"> 57</span>
<span class="normal"> 58</span>
<span class="normal"> 59</span>
<span class="normal"> 60</span>
<span class="normal"> 61</span>
<span class="normal"> 62</span>
<span class="normal"> 63</span>
<span class="normal"> 64</span>
<span class="normal"> 65</span>
<span class="normal"> 66</span>
<span class="normal"> 67</span>
<span class="normal"> 68</span>
<span class="normal"> 69</span>
<span class="normal"> 70</span>
<span class="normal"> 71</span>
<span class="normal"> 72</span>
<span class="normal"> 73</span>
<span class="normal"> 74</span>
<span class="normal"> 75</span>
<span class="normal"> 76</span>
<span class="normal"> 77</span>
<span class="normal"> 78</span>
<span class="normal"> 79</span>
<span class="normal"> 80</span>
<span class="normal"> 81</span>
<span class="normal"> 82</span>
<span class="normal"> 83</span>
<span class="normal"> 84</span>
<span class="normal"> 85</span>
<span class="normal"> 86</span>
<span class="normal"> 87</span>
<span class="normal"> 88</span>
<span class="normal"> 89</span>
<span class="normal"> 90</span>
<span class="normal"> 91</span>
<span class="normal"> 92</span>
<span class="normal"> 93</span>
<span class="normal"> 94</span>
<span class="normal"> 95</span>
<span class="normal"> 96</span>
<span class="normal"> 97</span>
<span class="normal"> 98</span>
<span class="normal"> 99</span>
<span class="normal">100</span>
<span class="normal">101</span>
<span class="normal">102</span>
<span class="normal">103</span>
<span class="normal">104</span>
<span class="normal">105</span>
<span class="normal">106</span>
<span class="normal">107</span>
<span class="normal">108</span>
<span class="normal">109</span>
<span class="normal">110</span>
<span class="normal">111</span>
<span class="normal">112</span>
<span class="normal">113</span>
<span class="normal">114</span>
<span class="normal">115</span>
<span class="normal">116</span>
<span class="normal">117</span>
<span class="normal">118</span>
<span class="normal">119</span>
<span class="normal">120</span>
<span class="normal">121</span>
<span class="normal">122</span>
<span class="normal">123</span>
<span class="normal">124</span>
<span class="normal">125</span>
<span class="normal">126</span>
<span class="normal">127</span>
<span class="normal">128</span></pre></div></td><td class="code"><div><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">torch</span><span class="w"> </span><span class="kn">import</span> <span class="n">nn</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">torch.nn</span><span class="w"> </span><span class="kn">import</span> <span class="n">functional</span> <span class="k">as</span> <span class="n">F</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">kernels</span><span class="w"> </span><span class="kn">import</span> <span class="n">get_kernel</span><span class="p">,</span> <span class="n">get_local_kernel</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">bench_utils</span><span class="w"> </span><span class="kn">import</span> <span class="n">to_dtype</span><span class="p">,</span> <span class="n">tensor_stats</span><span class="p">,</span> <span class="n">set_seed</span><span class="p">,</span> <span class="n">bench_context</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">config</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span>
<span class="n">NUM_EXPERTS</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">TOP_K</span><span class="p">,</span>
<span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">SEQ_LEN</span><span class="p">,</span> <span class="n">DTYPE</span><span class="p">,</span> <span class="n">DEVICE</span><span class="p">,</span>
<span class="n">WEIGHT_SEED</span><span class="p">,</span> <span class="n">EXPERT_SEED</span><span class="p">,</span> <span class="n">INPUT_SEED</span><span class="p">,</span> <span class="n">GENERAL_SEED</span>
<span class="p">)</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">pathlib</span><span class="w"> </span><span class="kn">import</span> <span class="n">Path</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">os</span>
<span class="c1"># Discover the upstream artifact directory from env</span>
<span class="n">data_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;UVNOTE_INPUT_SAVE_DATA&#39;</span><span class="p">,</span> <span class="s1">&#39;.&#39;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Loading weights from: </span><span class="si">{</span><span class="n">data_dir</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">router_weight</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;router_weight.pt&#39;</span><span class="p">)</span>
<span class="n">router_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;router_bias.pt&#39;</span><span class="p">)</span>
<span class="n">gate_up_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;gate_up_proj.pt&#39;</span><span class="p">)</span>
<span class="n">gate_up_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;gate_up_proj_bias.pt&#39;</span><span class="p">)</span>
<span class="n">down_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;down_proj.pt&#39;</span><span class="p">)</span>
<span class="n">down_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;down_proj_bias.pt&#39;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Loaded shared weights from artifacts&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Router weight sum: </span><span class="si">{</span><span class="n">router_weight</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Gate/up sum: </span><span class="si">{</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Down sum: </span><span class="si">{</span><span class="n">down_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">class</span><span class="w"> </span><span class="nc">YamoeRouter</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">router_weight</span><span class="p">,</span> <span class="n">router_bias</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_k</span> <span class="o">=</span> <span class="n">TOP_K</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span> <span class="o">=</span> <span class="n">NUM_EXPERTS</span>
<span class="bp">self</span><span class="o">.</span><span class="n">hidden_dim</span> <span class="o">=</span> <span class="n">HIDDEN_SIZE</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">router_weight</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">router_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">):</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_dim</span><span class="p">)</span>
<span class="n">router_logits</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">)</span>
<span class="n">router_top_value</span><span class="p">,</span> <span class="n">router_indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">topk</span><span class="p">(</span><span class="n">router_logits</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">top_k</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">router_top_value</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">router_top_value</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">router_top_value</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">router_scores</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">router_logits</span><span class="p">)</span><span class="o">.</span><span class="n">scatter_</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">router_indices</span><span class="p">,</span> <span class="n">router_top_value</span><span class="p">)</span>
<span class="k">return</span> <span class="n">router_scores</span><span class="p">,</span> <span class="n">router_indices</span>
<span class="k">class</span><span class="w"> </span><span class="nc">YamoeMoEMLP</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">router_weight</span><span class="p">,</span> <span class="n">router_bias</span><span class="p">,</span> <span class="n">gate_up_proj</span><span class="p">,</span> <span class="n">gate_up_proj_bias</span><span class="p">,</span> <span class="n">down_proj</span><span class="p">,</span> <span class="n">down_proj_bias</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">router</span> <span class="o">=</span> <span class="n">YamoeRouter</span><span class="p">(</span><span class="n">router_weight</span><span class="p">,</span> <span class="n">router_bias</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span> <span class="o">=</span> <span class="n">NUM_EXPERTS</span>
<span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">=</span> <span class="n">HIDDEN_SIZE</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_k</span> <span class="o">=</span> <span class="n">TOP_K</span>
<span class="c1"># Load Yamoe kernel</span>
<span class="c1"># self.yamoe = get_local_kernel(Path(&quot;/home/ubuntu/Projects/yamoe/result&quot;), &quot;yamoe&quot;)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">yamoe</span> <span class="o">=</span> <span class="n">get_kernel</span><span class="p">(</span><span class="s2">&quot;drbh/yamoe&quot;</span><span class="p">,</span> <span class="n">revision</span><span class="o">=</span><span class="s2">&quot;v0.2.0&quot;</span><span class="p">)</span>
<span class="c1"># Expert capacity: tokens routed beyond this per-expert limit are dropped</span>
<span class="c1"># self.expert_capacity = 256</span>
<span class="bp">self</span><span class="o">.</span><span class="n">expert_capacity</span> <span class="o">=</span> <span class="mi">12</span>
<span class="c1"># Expert weights - use the loaded weights</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj_bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">gate_up_proj_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">down_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">down_proj</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">down_proj_bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">down_proj_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">):</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">hidden_dim</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">shape</span>
<span class="c1"># Get routing decisions</span>
<span class="n">routing_weights</span><span class="p">,</span> <span class="n">router_indices</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">router</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">)</span>
<span class="c1"># Reshape for Yamoe kernel</span>
<span class="n">hidden_states_flat</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">)</span>
<span class="n">routing_weights_flat</span> <span class="o">=</span> <span class="n">routing_weights</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span><span class="p">)</span>
<span class="c1"># Call Yamoe optimized kernel</span>
<span class="n">output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">yamoe</span><span class="o">.</span><span class="n">experts</span><span class="p">(</span>
<span class="n">hidden_states_flat</span><span class="p">,</span>
<span class="n">router_indices</span><span class="p">,</span>
<span class="n">routing_weights_flat</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj_bias</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">down_proj</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">down_proj_bias</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">expert_capacity</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_k</span><span class="p">,</span>
<span class="p">)</span>
<span class="c1"># Reshape output back</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">routing_weights</span>
<span class="c1"># Run the model</span>
<span class="n">set_seed</span><span class="p">(</span><span class="n">GENERAL_SEED</span><span class="p">)</span>
<span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="n">DEVICE</span> <span class="k">if</span> <span class="n">DEVICE</span> <span class="o">==</span> <span class="s2">&quot;cuda&quot;</span> <span class="k">else</span> <span class="s2">&quot;cpu&quot;</span><span class="p">)</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">to_dtype</span><span class="p">(</span><span class="n">DTYPE</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">=== Yamoe Implementation ===&quot;</span><span class="p">)</span>
<span class="c1"># Initialize model with loaded weights</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">YamoeMoEMLP</span><span class="p">(</span>
<span class="n">router_weight</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
<span class="n">router_bias</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
<span class="n">gate_up_proj</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
<span class="n">gate_up_proj_bias</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
<span class="n">down_proj</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
<span class="n">down_proj_bias</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Router weight sum: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">router</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Gate/up proj sum: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Down proj sum: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">down_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="c1"># Generate input</span>
<span class="n">set_seed</span><span class="p">(</span><span class="n">INPUT_SEED</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">SEQ_LEN</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.1</span>
<span class="c1"># Benchmark the model</span>
<span class="n">tokens</span> <span class="o">=</span> <span class="n">BATCH_SIZE</span> <span class="o">*</span> <span class="n">SEQ_LEN</span>
<span class="k">with</span> <span class="n">bench_context</span><span class="p">(</span><span class="n">warmup</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">iters</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">tokens</span><span class="o">=</span><span class="n">tokens</span><span class="p">,</span> <span class="n">save_json</span><span class="o">=</span><span class="s2">&quot;yamoe_results.json&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">bench</span><span class="p">:</span>
<span class="n">output</span><span class="p">,</span> <span class="n">stats</span> <span class="o">=</span> <span class="n">bench</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Output sum: </span><span class="si">{</span><span class="n">output</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
</pre></div></td></tr></table></div>
</div>
<div id="output-yamoe_run" class="cell-output">
<div class="cell-stdout">Loading weights from: /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/0dc3119d70b6b7e0618fb3e0070aede3d5fc82296ac58f1ab73305d459560b73
Loaded shared weights from artifacts
Router weight sum: 12.588732
Gate/up sum: 1026.601807
Down sum: 206.729263
=== Yamoe Implementation ===
Router weight sum: 12.588732
Gate/up proj sum: 1026.601807
Down proj sum: 206.729340
┌─ Benchmark Configuration ─────────────────────────────┐
│ Warmup: 10 Iters: 50 │
│ Tokens: 100 │
└────────────────────────────────────────────────────────┘
Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=0.000016, std=0.099892, norm=33.904163
Warming up (10 iterations)...
Benchmarking (50 iterations)...
Progress: 20% complete (avg: 8.633 ms)
Progress: 40% complete (avg: 8.627 ms)
Progress: 60% complete (avg: 8.629 ms)
Progress: 80% complete (avg: 8.630 ms)
Output tensors:
Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.050624, 0.050640], mean=-0.000005, std=0.011573, norm=3.928071
Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.294473], mean=0.007812, std=0.043541, norm=5.004632
━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━
Iterations: 50
Latency Statistics:
Average: 8.631 ms
Min: 8.526 ms
Max: 8.661 ms
Std Dev: 0.022 ms
Percentiles:
P50 (median): 8.636 ms
P95: 8.653 ms
P99: 8.658 ms
Throughput:
Tokens/sec: 11586.6
Std Dev: 29.1
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Saved benchmark results to yamoe_results.json
Output sum: -0.597250
</div>
<div class="cell-stderr">Downloading nvidia-cuda-cupti-cu12 (9.8MiB)
Downloading sympy (6.0MiB)
Downloading networkx (1.9MiB)
Downloading nvidia-nvjitlink-cu12 (37.4MiB)
Downloading nvidia-nccl-cu12 (307.4MiB)
Downloading nvidia-curand-cu12 (60.7MiB)
Downloading nvidia-cufile-cu12 (1.1MiB)
Downloading numpy (15.9MiB)
Downloading nvidia-cusolver-cu12 (255.1MiB)
Downloading torch (846.8MiB)
Downloading nvidia-cudnn-cu12 (674.0MiB)
Downloading hf-xet (3.0MiB)
Downloading setuptools (1.1MiB)
Downloading nvidia-cusparselt-cu12 (273.9MiB)
Downloading nvidia-cublas-cu12 (566.8MiB)
Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB)
Downloading nvidia-cufft-cu12 (184.2MiB)
Downloading nvidia-cusparse-cu12 (274.9MiB)
Downloading triton (148.4MiB)
Downloading nvidia-cufile-cu12
Downloading hf-xet
Downloading setuptools
Downloading networkx
Downloading nvidia-cuda-cupti-cu12
Downloading numpy
Downloading sympy
Downloading nvidia-nvjitlink-cu12
Downloading nvidia-curand-cu12
Downloading nvidia-cuda-nvrtc-cu12
Downloading triton
Downloading nvidia-cufft-cu12
Downloading nvidia-cusolver-cu12
Downloading nvidia-cusparse-cu12
Downloading nvidia-cusparselt-cu12
Downloading nvidia-nccl-cu12
Downloading nvidia-cublas-cu12
Downloading nvidia-cudnn-cu12
Downloading torch
Installed 37 packages in 287ms
Fetching 6 files: 0%| | 0/6 [00:00&lt;?, ?it/s]
Fetching 6 files: 17%|█▋ | 1/6 [00:00&lt;00:01, 3.90it/s]
Fetching 6 files: 50%|█████ | 3/6 [00:00&lt;00:00, 3.70it/s]
Fetching 6 files: 100%|██████████| 6/6 [00:00&lt;00:00, 7.44it/s]
</div>
<div class="cell-artifacts">
<h4>Artifacts:</h4>
<a href="artifacts/yamoe_run/yamoe_results.json" class="artifact" target="_blank">yamoe_results.json</a>
</div>
</div>
</div>
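<p>As a sanity check on the numbers above: the reported throughput is just tokens per iteration divided by average latency. A minimal recomputation (the small gap versus the reported 11586.6 is expected, since the harness averages per-iteration throughput rather than dividing by the mean latency):</p>

```python
# Recompute throughput from the Yamoe run's reported figures.
tokens = 100            # tokens per iteration (BATCH_SIZE * SEQ_LEN)
avg_latency_ms = 8.631  # average latency from the benchmark output

tokens_per_sec = tokens / (avg_latency_ms / 1000.0)
print(f"{tokens_per_sec:.1f}")  # ~11586.1, close to the reported 11586.6
```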
<h2>Binned Implementation</h2>
<p>This section runs the binned implementation that manually handles token gathering/scattering.</p>
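<p>The core bookkeeping behind gathering/scattering — flattening the per-token expert assignments, sorting them so each expert's tokens are contiguous, and computing per-expert bin boundaries via a prefix sum — can be sketched with NumPy (toy sizes chosen for illustration, not the benchmark configuration):</p>

```python
import numpy as np

# Toy routing: 3 tokens, each assigned to top_k=2 of 4 experts.
num_experts, top_k = 4, 2
router_indices = np.array([[0, 2], [1, 2], [3, 0]])

# Flatten and stably sort so slots routed to the same expert are contiguous.
flat = router_indices.flatten()
order = np.argsort(flat, kind="stable")
sorted_experts = flat[order]

# Per-expert counts; their cumulative sum gives each expert's bin end.
tokens_per_expert = np.bincount(sorted_experts, minlength=num_experts)
bins = np.cumsum(tokens_per_expert)

# The source token for each sorted slot is its flat position // top_k,
# exactly as binned_gather/binned_scatter recover it below.
source_tokens = order // top_k

print(tokens_per_expert.tolist())  # [2, 1, 2, 1]
print(bins.tolist())               # [2, 3, 5, 6]
```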
<div class="cell">
<div class="cell-header">
<span class="collapse-indicators">
<span onclick="toggleCode('binned_run')" style="cursor: pointer;">▼ code</span>
<span onclick="toggleOutput('binned_run')" style="cursor: pointer;">▼ output</span>
</span> |
Cell: binned_run | deps: torch, numpy | 42.05s
| <button class="run-btn" onclick="runCell('binned_run')">▶ run</button>
<button class="copy-btn" onclick="copyCell('binned_run')">Copy</button>
</div>
<div id="code-binned_run" class="cell-code">
<div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span>
<span class="normal"> 2</span>
<span class="normal"> 3</span>
<span class="normal"> 4</span>
<span class="normal"> 5</span>
<span class="normal"> 6</span>
<span class="normal"> 7</span>
<span class="normal"> 8</span>
<span class="normal"> 9</span>
<span class="normal"> 10</span>
<span class="normal"> 11</span>
<span class="normal"> 12</span>
<span class="normal"> 13</span>
<span class="normal"> 14</span>
<span class="normal"> 15</span>
<span class="normal"> 16</span>
<span class="normal"> 17</span>
<span class="normal"> 18</span>
<span class="normal"> 19</span>
<span class="normal"> 20</span>
<span class="normal"> 21</span>
<span class="normal"> 22</span>
<span class="normal"> 23</span>
<span class="normal"> 24</span>
<span class="normal"> 25</span>
<span class="normal"> 26</span>
<span class="normal"> 27</span>
<span class="normal"> 28</span>
<span class="normal"> 29</span>
<span class="normal"> 30</span>
<span class="normal"> 31</span>
<span class="normal"> 32</span>
<span class="normal"> 33</span>
<span class="normal"> 34</span>
<span class="normal"> 35</span>
<span class="normal"> 36</span>
<span class="normal"> 37</span>
<span class="normal"> 38</span>
<span class="normal"> 39</span>
<span class="normal"> 40</span>
<span class="normal"> 41</span>
<span class="normal"> 42</span>
<span class="normal"> 43</span>
<span class="normal"> 44</span>
<span class="normal"> 45</span>
<span class="normal"> 46</span>
<span class="normal"> 47</span>
<span class="normal"> 48</span>
<span class="normal"> 49</span>
<span class="normal"> 50</span>
<span class="normal"> 51</span>
<span class="normal"> 52</span>
<span class="normal"> 53</span>
<span class="normal"> 54</span>
<span class="normal"> 55</span>
<span class="normal"> 56</span>
<span class="normal"> 57</span>
<span class="normal"> 58</span>
<span class="normal"> 59</span>
<span class="normal"> 60</span>
<span class="normal"> 61</span>
<span class="normal"> 62</span>
<span class="normal"> 63</span>
<span class="normal"> 64</span>
<span class="normal"> 65</span>
<span class="normal"> 66</span>
<span class="normal"> 67</span>
<span class="normal"> 68</span>
<span class="normal"> 69</span>
<span class="normal"> 70</span>
<span class="normal"> 71</span>
<span class="normal"> 72</span>
<span class="normal"> 73</span>
<span class="normal"> 74</span>
<span class="normal"> 75</span>
<span class="normal"> 76</span>
<span class="normal"> 77</span>
<span class="normal"> 78</span>
<span class="normal"> 79</span>
<span class="normal"> 80</span>
<span class="normal"> 81</span>
<span class="normal"> 82</span>
<span class="normal"> 83</span>
<span class="normal"> 84</span>
<span class="normal"> 85</span>
<span class="normal"> 86</span>
<span class="normal"> 87</span>
<span class="normal"> 88</span>
<span class="normal"> 89</span>
<span class="normal"> 90</span>
<span class="normal"> 91</span>
<span class="normal"> 92</span>
<span class="normal"> 93</span>
<span class="normal"> 94</span>
<span class="normal"> 95</span>
<span class="normal"> 96</span>
<span class="normal"> 97</span>
<span class="normal"> 98</span>
<span class="normal"> 99</span>
<span class="normal">100</span>
<span class="normal">101</span>
<span class="normal">102</span>
<span class="normal">103</span>
<span class="normal">104</span>
<span class="normal">105</span>
<span class="normal">106</span>
<span class="normal">107</span>
<span class="normal">108</span>
<span class="normal">109</span>
<span class="normal">110</span>
<span class="normal">111</span>
<span class="normal">112</span>
<span class="normal">113</span>
<span class="normal">114</span>
<span class="normal">115</span>
<span class="normal">116</span>
<span class="normal">117</span>
<span class="normal">118</span>
<span class="normal">119</span>
<span class="normal">120</span>
<span class="normal">121</span>
<span class="normal">122</span>
<span class="normal">123</span>
<span class="normal">124</span>
<span class="normal">125</span>
<span class="normal">126</span>
<span class="normal">127</span>
<span class="normal">128</span>
<span class="normal">129</span>
<span class="normal">130</span>
<span class="normal">131</span>
<span class="normal">132</span>
<span class="normal">133</span>
<span class="normal">134</span>
<span class="normal">135</span>
<span class="normal">136</span>
<span class="normal">137</span>
<span class="normal">138</span>
<span class="normal">139</span>
<span class="normal">140</span>
<span class="normal">141</span>
<span class="normal">142</span>
<span class="normal">143</span>
<span class="normal">144</span>
<span class="normal">145</span>
<span class="normal">146</span>
<span class="normal">147</span>
<span class="normal">148</span>
<span class="normal">149</span>
<span class="normal">150</span>
<span class="normal">151</span>
<span class="normal">152</span>
<span class="normal">153</span>
<span class="normal">154</span>
<span class="normal">155</span>
<span class="normal">156</span>
<span class="normal">157</span>
<span class="normal">158</span>
<span class="normal">159</span>
<span class="normal">160</span>
<span class="normal">161</span>
<span class="normal">162</span>
<span class="normal">163</span>
<span class="normal">164</span>
<span class="normal">165</span>
<span class="normal">166</span>
<span class="normal">167</span>
<span class="normal">168</span>
<span class="normal">169</span>
<span class="normal">170</span>
<span class="normal">171</span>
<span class="normal">172</span>
<span class="normal">173</span>
<span class="normal">174</span>
<span class="normal">175</span>
<span class="normal">176</span>
<span class="normal">177</span>
<span class="normal">178</span>
<span class="normal">179</span>
<span class="normal">180</span>
<span class="normal">181</span>
<span class="normal">182</span>
<span class="normal">183</span></pre></div></td><td class="code"><div><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">torch</span><span class="w"> </span><span class="kn">import</span> <span class="n">nn</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">torch.nn</span><span class="w"> </span><span class="kn">import</span> <span class="n">functional</span> <span class="k">as</span> <span class="n">F</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">bench_utils</span><span class="w"> </span><span class="kn">import</span> <span class="n">to_dtype</span><span class="p">,</span> <span class="n">tensor_stats</span><span class="p">,</span> <span class="n">set_seed</span><span class="p">,</span> <span class="n">bench_context</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">config</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span>
<span class="n">NUM_EXPERTS</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">TOP_K</span><span class="p">,</span>
<span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">SEQ_LEN</span><span class="p">,</span> <span class="n">DTYPE</span><span class="p">,</span> <span class="n">DEVICE</span><span class="p">,</span>
<span class="n">WEIGHT_SEED</span><span class="p">,</span> <span class="n">EXPERT_SEED</span><span class="p">,</span> <span class="n">INPUT_SEED</span><span class="p">,</span> <span class="n">GENERAL_SEED</span>
<span class="p">)</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">pathlib</span><span class="w"> </span><span class="kn">import</span> <span class="n">Path</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">os</span>
<span class="c1"># Discover the upstream artifact directory from env</span>
<span class="n">data_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;UVNOTE_INPUT_SAVE_DATA&#39;</span><span class="p">,</span> <span class="s1">&#39;.&#39;</span><span class="p">)</span>
<span class="n">router_weight</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;router_weight.pt&#39;</span><span class="p">)</span>
<span class="n">router_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;router_bias.pt&#39;</span><span class="p">)</span>
<span class="n">gate_up_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;gate_up_proj.pt&#39;</span><span class="p">)</span>
<span class="n">gate_up_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;gate_up_proj_bias.pt&#39;</span><span class="p">)</span>
<span class="n">down_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;down_proj.pt&#39;</span><span class="p">)</span>
<span class="n">down_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;down_proj_bias.pt&#39;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Loaded shared weights from artifacts&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Router weight sum: </span><span class="si">{</span><span class="n">router_weight</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Gate/up sum: </span><span class="si">{</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Down sum: </span><span class="si">{</span><span class="n">down_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">binned_gather</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">bins</span><span class="p">,</span> <span class="n">expert_capacity</span><span class="p">,</span> <span class="n">top_k</span><span class="p">):</span>
<span class="n">E</span><span class="p">,</span> <span class="n">H</span> <span class="o">=</span> <span class="n">bins</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">E</span><span class="p">,</span> <span class="n">expert_capacity</span><span class="p">,</span> <span class="n">H</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">for</span> <span class="n">e</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">E</span><span class="p">):</span>
<span class="n">start</span> <span class="o">=</span> <span class="mi">0</span> <span class="k">if</span> <span class="n">e</span> <span class="o">==</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">bins</span><span class="p">[</span><span class="n">e</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span>
<span class="n">end</span> <span class="o">=</span> <span class="n">bins</span><span class="p">[</span><span class="n">e</span><span class="p">]</span>
<span class="n">n</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">end</span> <span class="o">-</span> <span class="n">start</span><span class="p">,</span> <span class="n">expert_capacity</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
<span class="n">flat_pos</span> <span class="o">=</span> <span class="n">indices</span><span class="p">[</span><span class="n">start</span> <span class="o">+</span> <span class="n">i</span><span class="p">]</span>
<span class="n">tok</span> <span class="o">=</span> <span class="n">flat_pos</span> <span class="o">//</span> <span class="n">top_k</span>
<span class="n">out</span><span class="p">[</span><span class="n">e</span><span class="p">,</span> <span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">tok</span><span class="p">]</span>
<span class="k">return</span> <span class="n">out</span>
<span class="k">def</span><span class="w"> </span><span class="nf">binned_scatter</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">weights</span><span class="p">,</span> <span class="n">bins</span><span class="p">,</span> <span class="n">expert_capacity</span><span class="p">,</span> <span class="n">top_k</span><span class="p">):</span>
<span class="n">E</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">H</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
<span class="n">N</span> <span class="o">=</span> <span class="n">indices</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">//</span> <span class="n">top_k</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">N</span><span class="p">,</span> <span class="n">top_k</span><span class="p">,</span> <span class="n">H</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">for</span> <span class="n">e</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">E</span><span class="p">):</span>
<span class="n">start</span> <span class="o">=</span> <span class="mi">0</span> <span class="k">if</span> <span class="n">e</span> <span class="o">==</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">bins</span><span class="p">[</span><span class="n">e</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span>
<span class="n">end</span> <span class="o">=</span> <span class="n">bins</span><span class="p">[</span><span class="n">e</span><span class="p">]</span>
<span class="n">n</span> <span class="o">=</span> <span class="n">end</span> <span class="o">-</span> <span class="n">start</span>
<span class="k">if</span> <span class="n">n</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">continue</span>
<span class="n">take</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">expert_capacity</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">take</span><span class="p">):</span>
<span class="n">flat_pos</span> <span class="o">=</span> <span class="n">indices</span><span class="p">[</span><span class="n">start</span> <span class="o">+</span> <span class="n">i</span><span class="p">]</span>
<span class="n">tok</span> <span class="o">=</span> <span class="n">flat_pos</span> <span class="o">//</span> <span class="n">top_k</span>
<span class="n">slot</span> <span class="o">=</span> <span class="n">flat_pos</span> <span class="o">%</span> <span class="n">top_k</span>
<span class="n">scale</span> <span class="o">=</span> <span class="n">weights</span><span class="p">[</span><span class="n">flat_pos</span><span class="p">]</span> <span class="k">if</span> <span class="n">weights</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="mf">1.0</span>
<span class="n">out</span><span class="p">[</span><span class="n">tok</span><span class="p">,</span> <span class="n">slot</span><span class="p">]</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">e</span><span class="p">,</span> <span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="n">scale</span>
<span class="k">return</span> <span class="n">out</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">sort_tokens_by_expert</span><span class="p">(</span><span class="n">router_indices</span><span class="p">,</span> <span class="n">num_experts</span><span class="p">):</span>
<span class="n">flat_indices</span> <span class="o">=</span> <span class="n">router_indices</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span>
<span class="n">sorted_values</span><span class="p">,</span> <span class="n">sorted_indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">flat_indices</span><span class="p">)</span>
<span class="n">tokens_per_expert</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">bincount</span><span class="p">(</span><span class="n">sorted_values</span><span class="p">,</span> <span class="n">minlength</span><span class="o">=</span><span class="n">num_experts</span><span class="p">)</span>
<span class="n">bins</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">tokens_per_expert</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="k">return</span> <span class="n">sorted_indices</span><span class="p">,</span> <span class="n">sorted_values</span><span class="p">,</span> <span class="n">bins</span><span class="p">,</span> <span class="n">tokens_per_expert</span>
<span class="k">def</span><span class="w"> </span><span class="nf">binned_experts_ref</span><span class="p">(</span>
<span class="n">hidden_states</span><span class="p">,</span>
<span class="n">router_indices</span><span class="p">,</span>
<span class="n">routing_weights</span><span class="p">,</span>
<span class="n">gate_up_proj</span><span class="p">,</span>
<span class="n">gate_up_proj_bias</span><span class="p">,</span>
<span class="n">down_proj</span><span class="p">,</span>
<span class="n">down_proj_bias</span><span class="p">,</span>
<span class="n">expert_capacity</span><span class="p">,</span>
<span class="p">):</span>
<span class="n">B</span><span class="p">,</span> <span class="n">S</span><span class="p">,</span> <span class="n">H</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">shape</span>
<span class="n">E</span><span class="p">,</span> <span class="n">K</span> <span class="o">=</span> <span class="n">routing_weights</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">router_indices</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">indices</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">bins</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">sort_tokens_by_expert</span><span class="p">(</span><span class="n">router_indices</span><span class="p">,</span> <span class="n">E</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">binned_gather</span><span class="p">(</span><span class="n">hidden_states</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">H</span><span class="p">),</span> <span class="n">indices</span><span class="p">,</span> <span class="n">bins</span><span class="p">,</span> <span class="n">expert_capacity</span><span class="p">,</span> <span class="n">K</span><span class="p">)</span>
<span class="n">gate_up</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">bmm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">gate_up_proj</span><span class="p">)</span>
<span class="n">gate_up</span> <span class="o">+=</span> <span class="n">gate_up_proj_bias</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">gate</span><span class="p">,</span> <span class="n">up</span> <span class="o">=</span> <span class="n">gate_up</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="p">::</span><span class="mi">2</span><span class="p">],</span> <span class="n">gate_up</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span>
<span class="c1"># clamp activations for stability: gate only from above, up symmetrically</span>
<span class="n">limit</span> <span class="o">=</span> <span class="mf">7.0</span>
<span class="n">gate</span> <span class="o">=</span> <span class="n">gate</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="n">limit</span><span class="p">)</span>
<span class="n">up</span> <span class="o">=</span> <span class="n">up</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=-</span><span class="n">limit</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="n">limit</span><span class="p">)</span>
<span class="n">glu</span> <span class="o">=</span> <span class="n">gate</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">gate</span> <span class="o">*</span> <span class="mf">1.702</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="p">(</span><span class="n">up</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">glu</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">bmm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">down_proj</span><span class="p">)</span> <span class="o">+</span> <span class="n">down_proj_bias</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
<span class="c1"># build routing weights aligned to (token, slot)</span>
<span class="n">flat_dense</span> <span class="o">=</span> <span class="n">routing_weights</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">E</span><span class="p">)</span>
<span class="n">flat_router</span> <span class="o">=</span> <span class="n">router_indices</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">K</span><span class="p">)</span>
<span class="n">selected</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">gather</span><span class="p">(</span><span class="n">flat_dense</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">flat_router</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># scatter expert outputs back to token order, scaled by routing weights</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">binned_scatter</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">selected</span><span class="p">,</span> <span class="n">bins</span><span class="p">,</span> <span class="n">expert_capacity</span><span class="p">,</span> <span class="n">K</span><span class="p">)</span>
<span class="k">return</span> <span class="n">y</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">S</span><span class="p">,</span> <span class="n">H</span><span class="p">)</span>
<span class="k">class</span><span class="w"> </span><span class="nc">BinnedRouter</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">router_weight</span><span class="p">,</span> <span class="n">router_bias</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_k</span> <span class="o">=</span> <span class="n">TOP_K</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span> <span class="o">=</span> <span class="n">NUM_EXPERTS</span>
<span class="bp">self</span><span class="o">.</span><span class="n">hidden_dim</span> <span class="o">=</span> <span class="n">HIDDEN_SIZE</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">router_weight</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">router_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">):</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_dim</span><span class="p">)</span>
<span class="n">router_logits</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">)</span>
<span class="n">router_top_value</span><span class="p">,</span> <span class="n">router_indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">topk</span><span class="p">(</span><span class="n">router_logits</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">top_k</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">router_top_value</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">router_top_value</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">router_top_value</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">router_scores</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">router_logits</span><span class="p">)</span><span class="o">.</span><span class="n">scatter_</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">router_indices</span><span class="p">,</span> <span class="n">router_top_value</span><span class="p">)</span>
<span class="k">return</span> <span class="n">router_scores</span><span class="p">,</span> <span class="n">router_indices</span>
<span class="k">class</span><span class="w"> </span><span class="nc">BinnedMoEMLP</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">router_weight</span><span class="p">,</span> <span class="n">router_bias</span><span class="p">,</span> <span class="n">gate_up_proj</span><span class="p">,</span> <span class="n">gate_up_proj_bias</span><span class="p">,</span> <span class="n">down_proj</span><span class="p">,</span> <span class="n">down_proj_bias</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">router</span> <span class="o">=</span> <span class="n">BinnedRouter</span><span class="p">(</span><span class="n">router_weight</span><span class="p">,</span> <span class="n">router_bias</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span> <span class="o">=</span> <span class="n">NUM_EXPERTS</span>
<span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">=</span> <span class="n">HIDDEN_SIZE</span>
<span class="bp">self</span><span class="o">.</span><span class="n">expert_capacity</span> <span class="o">=</span> <span class="mi">256</span>
<span class="c1"># Expert weights: initialized from the shared pre-loaded tensors</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj_bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">gate_up_proj_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">down_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">down_proj</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">down_proj_bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">down_proj_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">):</span>
<span class="n">router_scores</span><span class="p">,</span> <span class="n">router_indices</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">router</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">binned_experts_ref</span><span class="p">(</span>
<span class="n">hidden_states</span><span class="p">,</span>
<span class="n">router_indices</span><span class="p">,</span>
<span class="n">router_scores</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj_bias</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">down_proj</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">down_proj_bias</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">expert_capacity</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">router_scores</span>
<span class="c1"># Run the model</span>
<span class="n">set_seed</span><span class="p">(</span><span class="n">GENERAL_SEED</span><span class="p">)</span>
<span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="n">DEVICE</span><span class="p">)</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">to_dtype</span><span class="p">(</span><span class="n">DTYPE</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">=== Binned Implementation ===&quot;</span><span class="p">)</span>
<span class="c1"># Initialize model with loaded weights</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">BinnedMoEMLP</span><span class="p">(</span>
<span class="n">router_weight</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
<span class="n">router_bias</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
<span class="n">gate_up_proj</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
<span class="n">gate_up_proj_bias</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
<span class="n">down_proj</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
<span class="n">down_proj_bias</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Router weight sum: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">router</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Gate/up proj sum: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Down proj sum: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">down_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="c1"># Generate the same input as Yamoe</span>
<span class="n">set_seed</span><span class="p">(</span><span class="n">INPUT_SEED</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">SEQ_LEN</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.1</span>
<span class="c1"># Benchmark the model</span>
<span class="n">tokens</span> <span class="o">=</span> <span class="n">BATCH_SIZE</span> <span class="o">*</span> <span class="n">SEQ_LEN</span>
<span class="k">with</span> <span class="n">bench_context</span><span class="p">(</span><span class="n">warmup</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">iters</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">tokens</span><span class="o">=</span><span class="n">tokens</span><span class="p">,</span> <span class="n">save_json</span><span class="o">=</span><span class="s2">&quot;binned_results.json&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">bench</span><span class="p">:</span>
<span class="n">output</span><span class="p">,</span> <span class="n">stats</span> <span class="o">=</span> <span class="n">bench</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Output sum: </span><span class="si">{</span><span class="n">output</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
</pre></div></td></tr></table></div>
</div>
<div id="output-binned_run" class="cell-output">
<div class="cell-stdout">Loaded shared weights from artifacts
Router weight sum: 12.588732
Gate/up sum: 1026.601807
Down sum: 206.729263
=== Binned Implementation ===
Router weight sum: 12.588732
Gate/up proj sum: 1026.601807
Down proj sum: 206.729340
┌─ Benchmark Configuration ─────────────────────────────┐
│ Warmup: 10 Iters: 50 │
│ Tokens: 100 │
└────────────────────────────────────────────────────────┘
Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=0.000016, std=0.099892, norm=33.904163
Warming up (10 iterations)...
Benchmarking (50 iterations)...
Progress: 20% complete (avg: 104.222 ms)
Progress: 40% complete (avg: 104.671 ms)
Progress: 60% complete (avg: 105.372 ms)
Progress: 80% complete (avg: 105.570 ms)
Output tensors:
Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.050624, 0.050640], mean=-0.000005, std=0.011573, norm=3.928071
Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.294473], mean=0.007812, std=0.043541, norm=5.004632
━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━
Iterations: 50
Latency Statistics:
Average: 105.618 ms
Min: 103.417 ms
Max: 107.809 ms
Std Dev: 1.458 ms
Percentiles:
P50 (median): 105.048 ms
P95: 107.729 ms
P99: 107.790 ms
Throughput:
Tokens/sec: 946.8
Std Dev: 13.0
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Saved benchmark results to binned_results.json
Output sum: -0.597248
</div>
<div class="cell-stderr">Downloading nvidia-cusparselt-cu12 (273.9MiB)
Downloading nvidia-cufft-cu12 (184.2MiB)
Downloading triton (148.4MiB)
Downloading nvidia-nvjitlink-cu12 (37.4MiB)
Downloading torch (846.8MiB)
Downloading networkx (1.9MiB)
Downloading nvidia-curand-cu12 (60.7MiB)
Downloading nvidia-cufile-cu12 (1.1MiB)
Downloading nvidia-nccl-cu12 (307.4MiB)
Downloading setuptools (1.1MiB)
Downloading nvidia-cusolver-cu12 (255.1MiB)
Downloading nvidia-cusparse-cu12 (274.9MiB)
Downloading sympy (6.0MiB)
Downloading nvidia-cudnn-cu12 (674.0MiB)
Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB)
Downloading nvidia-cuda-cupti-cu12 (9.8MiB)
Downloading nvidia-cublas-cu12 (566.8MiB)
Downloading numpy (15.9MiB)
Installed 26 packages in 233ms
</div>
<div class="cell-artifacts">
<h4>Artifacts:</h4>
<a href="artifacts/binned_run/binned_results.json" class="artifact" target="_blank">binned_results.json</a>
</div>
</div>
</div>
<h2>GPT-OSS Implementation</h2>
<p>This section runs the GPT-OSS MoE implementation, which routes tokens with a manual loop over experts rather than the binned gather/scatter used above.</p>
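<p>The expert-loop pattern can be sketched roughly as below. This is an illustrative minimal version, not the actual GPT-OSS code: all names (<code>expert_loop_moe</code>, <code>expert_weights</code>, …) are hypothetical, and each expert is a single linear layer rather than the gated gate/up + down MLP used in the cells above.</p>

```python
# Hypothetical sketch of a manual expert-loop MoE forward pass.
# Route each token to its top-k experts, then loop over experts and
# process only the tokens assigned to each one, accumulating the
# weighted expert outputs back into token order.
import torch
import torch.nn.functional as F

def expert_loop_moe(hidden_states, router_weight, router_bias,
                    expert_weights, expert_biases, top_k):
    B, S, H = hidden_states.shape
    x = hidden_states.reshape(-1, H)                  # (tokens, H)
    logits = F.linear(x, router_weight, router_bias)  # (tokens, E)
    top_vals, top_idx = torch.topk(logits, top_k, dim=-1)
    weights = torch.softmax(top_vals, dim=-1)         # renormalize over top-k
    out = torch.zeros_like(x)
    for e in range(router_weight.shape[0]):           # manual expert loop
        token_ids, slot = torch.where(top_idx == e)   # tokens routed to e
        if token_ids.numel() == 0:
            continue                                  # expert received no tokens
        h = x[token_ids] @ expert_weights[e] + expert_biases[e]
        out.index_add_(0, token_ids, h * weights[token_ids, slot, None])
    return out.view(B, S, H)
```

<p>Unlike the binned path, this variant has no fixed expert capacity (no token dropping), but the Python-level loop over experts launches many small kernels, which is typically why it benchmarks slower on GPU.</p>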
<div class="cell">
<div class="cell-header">
<span class="collapse-indicators">
<span onclick="toggleCode('gptoss_run')" style="cursor: pointer;">▼ code</span>
<span onclick="toggleOutput('gptoss_run')" style="cursor: pointer;">▼ output</span>
</span> |
Cell: gptoss_run | deps: torch, numpy | 37.86s
| <button class="run-btn" onclick="runCell('gptoss_run')">▶ run</button>
<button class="copy-btn" onclick="copyCell('gptoss_run')">Copy</button>
</div>
<div id="code-gptoss_run" class="cell-code">
<div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span>
<span class="normal"> 2</span>
<span class="normal"> 3</span>
<span class="normal"> 4</span>
<span class="normal"> 5</span>
<span class="normal"> 6</span>
<span class="normal"> 7</span>
<span class="normal"> 8</span>
<span class="normal"> 9</span>
<span class="normal"> 10</span>
<span class="normal"> 11</span>
<span class="normal"> 12</span>
<span class="normal"> 13</span>
<span class="normal"> 14</span>
<span class="normal"> 15</span>
<span class="normal"> 16</span>
<span class="normal"> 17</span>
<span class="normal"> 18</span>
<span class="normal"> 19</span>
<span class="normal"> 20</span>
<span class="normal"> 21</span>
<span class="normal"> 22</span>
<span class="normal"> 23</span>
<span class="normal"> 24</span>
<span class="normal"> 25</span>
<span class="normal"> 26</span>
<span class="normal"> 27</span>
<span class="normal"> 28</span>
<span class="normal"> 29</span>
<span class="normal"> 30</span>
<span class="normal"> 31</span>
<span class="normal"> 32</span>
<span class="normal"> 33</span>
<span class="normal"> 34</span>
<span class="normal"> 35</span>
<span class="normal"> 36</span>
<span class="normal"> 37</span>
<span class="normal"> 38</span>
<span class="normal"> 39</span>
<span class="normal"> 40</span>
<span class="normal"> 41</span>
<span class="normal"> 42</span>
<span class="normal"> 43</span>
<span class="normal"> 44</span>
<span class="normal"> 45</span>
<span class="normal"> 46</span>
<span class="normal"> 47</span>
<span class="normal"> 48</span>
<span class="normal"> 49</span>
<span class="normal"> 50</span>
<span class="normal"> 51</span>
<span class="normal"> 52</span>
<span class="normal"> 53</span>
<span class="normal"> 54</span>
<span class="normal"> 55</span>
<span class="normal"> 56</span>
<span class="normal"> 57</span>
<span class="normal"> 58</span>
<span class="normal"> 59</span>
<span class="normal"> 60</span>
<span class="normal"> 61</span>
<span class="normal"> 62</span>
<span class="normal"> 63</span>
<span class="normal"> 64</span>
<span class="normal"> 65</span>
<span class="normal"> 66</span>
<span class="normal"> 67</span>
<span class="normal"> 68</span>
<span class="normal"> 69</span>
<span class="normal"> 70</span>
<span class="normal"> 71</span>
<span class="normal"> 72</span>
<span class="normal"> 73</span>
<span class="normal"> 74</span>
<span class="normal"> 75</span>
<span class="normal"> 76</span>
<span class="normal"> 77</span>
<span class="normal"> 78</span>
<span class="normal"> 79</span>
<span class="normal"> 80</span>
<span class="normal"> 81</span>
<span class="normal"> 82</span>
<span class="normal"> 83</span>
<span class="normal"> 84</span>
<span class="normal"> 85</span>
<span class="normal"> 86</span>
<span class="normal"> 87</span>
<span class="normal"> 88</span>
<span class="normal"> 89</span>
<span class="normal"> 90</span>
<span class="normal"> 91</span>
<span class="normal"> 92</span>
<span class="normal"> 93</span>
<span class="normal"> 94</span>
<span class="normal"> 95</span>
<span class="normal"> 96</span>
<span class="normal"> 97</span>
<span class="normal"> 98</span>
<span class="normal"> 99</span>
<span class="normal">100</span>
<span class="normal">101</span>
<span class="normal">102</span>
<span class="normal">103</span>
<span class="normal">104</span>
<span class="normal">105</span>
<span class="normal">106</span>
<span class="normal">107</span>
<span class="normal">108</span>
<span class="normal">109</span>
<span class="normal">110</span>
<span class="normal">111</span>
<span class="normal">112</span>
<span class="normal">113</span>
<span class="normal">114</span>
<span class="normal">115</span>
<span class="normal">116</span>
<span class="normal">117</span>
<span class="normal">118</span>
<span class="normal">119</span>
<span class="normal">120</span>
<span class="normal">121</span>
<span class="normal">122</span>
<span class="normal">123</span>
<span class="normal">124</span>
<span class="normal">125</span>
<span class="normal">126</span>
<span class="normal">127</span>
<span class="normal">128</span>
<span class="normal">129</span>
<span class="normal">130</span>
<span class="normal">131</span>
<span class="normal">132</span>
<span class="normal">133</span>
<span class="normal">134</span>
<span class="normal">135</span>
<span class="normal">136</span>
<span class="normal">137</span>
<span class="normal">138</span>
<span class="normal">139</span>
<span class="normal">140</span></pre></div></td><td class="code"><div><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">torch</span><span class="w"> </span><span class="kn">import</span> <span class="n">nn</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">torch.nn</span><span class="w"> </span><span class="kn">import</span> <span class="n">functional</span> <span class="k">as</span> <span class="n">F</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">bench_utils</span><span class="w"> </span><span class="kn">import</span> <span class="n">to_dtype</span><span class="p">,</span> <span class="n">tensor_stats</span><span class="p">,</span> <span class="n">set_seed</span><span class="p">,</span> <span class="n">bench_context</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">config</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span>
<span class="n">NUM_EXPERTS</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">TOP_K</span><span class="p">,</span>
<span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">SEQ_LEN</span><span class="p">,</span> <span class="n">DTYPE</span><span class="p">,</span> <span class="n">DEVICE</span><span class="p">,</span>
<span class="n">WEIGHT_SEED</span><span class="p">,</span> <span class="n">EXPERT_SEED</span><span class="p">,</span> <span class="n">INPUT_SEED</span><span class="p">,</span> <span class="n">GENERAL_SEED</span>
<span class="p">)</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">pathlib</span><span class="w"> </span><span class="kn">import</span> <span class="n">Path</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">os</span>
<span class="c1"># Discover the upstream artifact directory from env</span>
<span class="n">data_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;UVNOTE_INPUT_SAVE_DATA&#39;</span><span class="p">,</span> <span class="s1">&#39;.&#39;</span><span class="p">)</span>
<span class="n">router_weight</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;router_weight.pt&#39;</span><span class="p">)</span>
<span class="n">router_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;router_bias.pt&#39;</span><span class="p">)</span>
<span class="n">gate_up_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;gate_up_proj.pt&#39;</span><span class="p">)</span>
<span class="n">gate_up_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;gate_up_proj_bias.pt&#39;</span><span class="p">)</span>
<span class="n">down_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;down_proj.pt&#39;</span><span class="p">)</span>
<span class="n">down_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;down_proj_bias.pt&#39;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Loaded shared weights from artifacts&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Router weight sum: </span><span class="si">{</span><span class="n">router_weight</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Gate/up sum: </span><span class="si">{</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Down sum: </span><span class="si">{</span><span class="n">down_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">class</span><span class="w"> </span><span class="nc">GptOssRouter</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">router_weight</span><span class="p">,</span> <span class="n">router_bias</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_k</span> <span class="o">=</span> <span class="n">TOP_K</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span> <span class="o">=</span> <span class="n">NUM_EXPERTS</span>
<span class="bp">self</span><span class="o">.</span><span class="n">hidden_dim</span> <span class="o">=</span> <span class="n">HIDDEN_SIZE</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">router_weight</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">router_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">):</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_dim</span><span class="p">)</span>
<span class="n">router_logits</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">)</span>
<span class="n">router_top_value</span><span class="p">,</span> <span class="n">router_indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">topk</span><span class="p">(</span><span class="n">router_logits</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">top_k</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">router_top_value</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">router_top_value</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">router_top_value</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">router_scores</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">router_logits</span><span class="p">)</span><span class="o">.</span><span class="n">scatter_</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">router_indices</span><span class="p">,</span> <span class="n">router_top_value</span><span class="p">)</span>
<span class="k">return</span> <span class="n">router_scores</span><span class="p">,</span> <span class="n">router_indices</span>
<span class="k">class</span><span class="w"> </span><span class="nc">GptOssExperts</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">gate_up_proj</span><span class="p">,</span> <span class="n">gate_up_proj_bias</span><span class="p">,</span> <span class="n">down_proj</span><span class="p">,</span> <span class="n">down_proj_bias</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span> <span class="o">=</span> <span class="n">NUM_EXPERTS</span>
<span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">=</span> <span class="n">HIDDEN_SIZE</span>
<span class="bp">self</span><span class="o">.</span><span class="n">expert_dim</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj_bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">gate_up_proj_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">down_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">down_proj</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">down_proj_bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">down_proj_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="mf">1.702</span>
<span class="bp">self</span><span class="o">.</span><span class="n">limit</span> <span class="o">=</span> <span class="mf">7.0</span>
<span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">router_indices</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">routing_weights</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)</span>
<span class="n">num_experts</span> <span class="o">=</span> <span class="n">routing_weights</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="k">if</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">device</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="s2">&quot;cpu&quot;</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span><span class="p">:</span>
<span class="n">next_states</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">hidden_states</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">hidden_states</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="n">expert_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">router_indices</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="n">num_experts</span><span class="p">)</span>
<span class="n">expert_mask</span> <span class="o">=</span> <span class="n">expert_mask</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">expert_hit</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">greater</span><span class="p">(</span><span class="n">expert_mask</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">2</span><span class="p">)),</span> <span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">nonzero</span><span class="p">()</span>
<span class="k">for</span> <span class="n">expert_idx</span> <span class="ow">in</span> <span class="n">expert_hit</span><span class="p">[:]:</span>
<span class="n">expert_idx</span> <span class="o">=</span> <span class="n">expert_idx</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="n">_</span><span class="p">,</span> <span class="n">token_idx</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">expert_mask</span><span class="p">[</span><span class="n">expert_idx</span><span class="p">])</span>
<span class="n">current_state</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="p">[</span><span class="n">token_idx</span><span class="p">]</span>
<span class="n">gate_up</span> <span class="o">=</span> <span class="n">current_state</span> <span class="o">@</span> <span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj</span><span class="p">[</span><span class="n">expert_idx</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj_bias</span><span class="p">[</span><span class="n">expert_idx</span><span class="p">]</span>
<span class="n">gate</span><span class="p">,</span> <span class="n">up</span> <span class="o">=</span> <span class="n">gate_up</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="p">::</span><span class="mi">2</span><span class="p">],</span> <span class="n">gate_up</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span>
<span class="n">gate</span> <span class="o">=</span> <span class="n">gate</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">limit</span><span class="p">)</span>
<span class="n">up</span> <span class="o">=</span> <span class="n">up</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=-</span><span class="bp">self</span><span class="o">.</span><span class="n">limit</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">limit</span><span class="p">)</span>
<span class="n">glu</span> <span class="o">=</span> <span class="n">gate</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">gate</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">alpha</span><span class="p">)</span>
<span class="n">gated_output</span> <span class="o">=</span> <span class="p">(</span><span class="n">up</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">glu</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">gated_output</span> <span class="o">@</span> <span class="bp">self</span><span class="o">.</span><span class="n">down_proj</span><span class="p">[</span><span class="n">expert_idx</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">down_proj_bias</span><span class="p">[</span><span class="n">expert_idx</span><span class="p">]</span>
<span class="n">weighted_output</span> <span class="o">=</span> <span class="n">out</span> <span class="o">*</span> <span class="n">routing_weights</span><span class="p">[</span><span class="n">token_idx</span><span class="p">,</span> <span class="n">expert_idx</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span>
<span class="n">next_states</span><span class="o">.</span><span class="n">index_add_</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">token_idx</span><span class="p">,</span> <span class="n">weighted_output</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">hidden_states</span><span class="o">.</span><span class="n">dtype</span><span class="p">))</span>
<span class="n">next_states</span> <span class="o">=</span> <span class="n">next_states</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="n">num_experts</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">num_experts</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)</span>
<span class="n">gate_up</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">bmm</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj_bias</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">gate</span><span class="p">,</span> <span class="n">up</span> <span class="o">=</span> <span class="n">gate_up</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="p">::</span><span class="mi">2</span><span class="p">],</span> <span class="n">gate_up</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span>
<span class="n">gate</span> <span class="o">=</span> <span class="n">gate</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">limit</span><span class="p">)</span>
<span class="n">up</span> <span class="o">=</span> <span class="n">up</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=-</span><span class="bp">self</span><span class="o">.</span><span class="n">limit</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">limit</span><span class="p">)</span>
<span class="n">glu</span> <span class="o">=</span> <span class="n">gate</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">gate</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">alpha</span><span class="p">)</span>
<span class="n">next_states</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">bmm</span><span class="p">(((</span><span class="n">up</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">glu</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">down_proj</span><span class="p">)</span>
<span class="n">next_states</span> <span class="o">=</span> <span class="n">next_states</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">down_proj_bias</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">next_states</span> <span class="o">=</span> <span class="n">next_states</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">num_experts</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)</span>
<span class="n">next_states</span> <span class="o">=</span> <span class="n">next_states</span> <span class="o">*</span> <span class="n">routing_weights</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">num_experts</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span>
<span class="n">next_states</span> <span class="o">=</span> <span class="n">next_states</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="k">return</span> <span class="n">next_states</span>
<span class="k">class</span><span class="w"> </span><span class="nc">GptOssMoEMLP</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">router_weight</span><span class="p">,</span> <span class="n">router_bias</span><span class="p">,</span> <span class="n">gate_up_proj</span><span class="p">,</span> <span class="n">gate_up_proj_bias</span><span class="p">,</span> <span class="n">down_proj</span><span class="p">,</span> <span class="n">down_proj_bias</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">router</span> <span class="o">=</span> <span class="n">GptOssRouter</span><span class="p">(</span><span class="n">router_weight</span><span class="p">,</span> <span class="n">router_bias</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">experts</span> <span class="o">=</span> <span class="n">GptOssExperts</span><span class="p">(</span><span class="n">gate_up_proj</span><span class="p">,</span> <span class="n">gate_up_proj_bias</span><span class="p">,</span> <span class="n">down_proj</span><span class="p">,</span> <span class="n">down_proj_bias</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">):</span>
<span class="n">router_scores</span><span class="p">,</span> <span class="n">router_indices</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">router</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">)</span>
<span class="n">routed_out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">experts</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="n">router_indices</span><span class="o">=</span><span class="n">router_indices</span><span class="p">,</span> <span class="n">routing_weights</span><span class="o">=</span><span class="n">router_scores</span><span class="p">)</span>
<span class="k">return</span> <span class="n">routed_out</span><span class="p">,</span> <span class="n">router_scores</span>
<span class="c1"># Run the model</span>
<span class="n">set_seed</span><span class="p">(</span><span class="n">GENERAL_SEED</span><span class="p">)</span>
<span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="n">DEVICE</span><span class="p">)</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">to_dtype</span><span class="p">(</span><span class="n">DTYPE</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">=== GPT-OSS Implementation ===&quot;</span><span class="p">)</span>
<span class="c1"># Initialize model with loaded weights</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">GptOssMoEMLP</span><span class="p">(</span>
<span class="n">router_weight</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
<span class="n">router_bias</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
<span class="n">gate_up_proj</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
<span class="n">gate_up_proj_bias</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
<span class="n">down_proj</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
<span class="n">down_proj_bias</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Router weight sum: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">router</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Gate/up proj sum: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">experts</span><span class="o">.</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Down proj sum: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">experts</span><span class="o">.</span><span class="n">down_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="c1"># Generate the same input as other implementations</span>
<span class="n">set_seed</span><span class="p">(</span><span class="n">INPUT_SEED</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">SEQ_LEN</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.1</span>
<span class="c1"># Benchmark the model</span>
<span class="n">tokens</span> <span class="o">=</span> <span class="n">BATCH_SIZE</span> <span class="o">*</span> <span class="n">SEQ_LEN</span>
<span class="k">with</span> <span class="n">bench_context</span><span class="p">(</span><span class="n">warmup</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">iters</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">tokens</span><span class="o">=</span><span class="n">tokens</span><span class="p">,</span> <span class="n">save_json</span><span class="o">=</span><span class="s2">&quot;gptoss_results.json&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">bench</span><span class="p">:</span>
<span class="n">output</span><span class="p">,</span> <span class="n">stats</span> <span class="o">=</span> <span class="n">bench</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Output sum: </span><span class="si">{</span><span class="n">output</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
</pre></div></td></tr></table></div>
</div>
<div id="output-gptoss_run" class="cell-output">
<div class="cell-stdout">Loaded shared weights from artifacts
Router weight sum: 12.588732
Gate/up sum: 1026.601807
Down sum: 206.729263
=== GPT-OSS Implementation ===
Router weight sum: 12.588732
Gate/up proj sum: 1026.601807
Down proj sum: 206.729340
┌─ Benchmark Configuration ─────────────────────────────┐
│ Warmup: 10 Iters: 50 │
│ Tokens: 100 │
└────────────────────────────────────────────────────────┘
Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=0.000016, std=0.099892, norm=33.904163
Warming up (10 iterations)...
Benchmarking (50 iterations)...
Progress: 20% complete (avg: 46.973 ms)
Progress: 40% complete (avg: 47.262 ms)
Progress: 60% complete (avg: 47.067 ms)
Progress: 80% complete (avg: 46.985 ms)
Output tensors:
Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.050624, 0.050640], mean=-0.000005, std=0.011573, norm=3.928071
Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.294473], mean=0.007812, std=0.043541, norm=5.004632
━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━
Iterations: 50
Latency Statistics:
Average: 47.135 ms
Min: 46.582 ms
Max: 47.895 ms
Std Dev: 0.503 ms
Percentiles:
P50 (median): 46.789 ms
P95: 47.801 ms
P99: 47.856 ms
Throughput:
Tokens/sec: 2121.6
Std Dev: 22.5
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Saved benchmark results to gptoss_results.json
Output sum: -0.597250
</div>
<div class="cell-stderr">Downloading nvidia-cufile-cu12 (1.1MiB)
Downloading setuptools (1.1MiB)
Downloading sympy (6.0MiB)
Downloading nvidia-cusparse-cu12 (274.9MiB)
Downloading nvidia-cusparselt-cu12 (273.9MiB)
Downloading nvidia-cusolver-cu12 (255.1MiB)
Downloading nvidia-cufft-cu12 (184.2MiB)
Downloading nvidia-cuda-cupti-cu12 (9.8MiB)
Downloading nvidia-nccl-cu12 (307.4MiB)
Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB)
Downloading nvidia-curand-cu12 (60.7MiB)
Downloading nvidia-cublas-cu12 (566.8MiB)
Downloading nvidia-cudnn-cu12 (674.0MiB)
Downloading numpy (15.9MiB)
Downloading networkx (1.9MiB)
Downloading torch (846.8MiB)
Downloading nvidia-nvjitlink-cu12 (37.4MiB)
Downloading triton (148.4MiB)
Downloading nvidia-cufile-cu12
Downloading setuptools
Downloading networkx
Downloading nvidia-cuda-cupti-cu12
Downloading numpy
Downloading sympy
Downloading nvidia-nvjitlink-cu12
Downloading nvidia-curand-cu12
Downloading nvidia-cuda-nvrtc-cu12
Downloading triton
Downloading nvidia-cufft-cu12
Downloading nvidia-cusolver-cu12
Downloading nvidia-cusparse-cu12
Downloading nvidia-cusparselt-cu12
Downloading nvidia-nccl-cu12
Downloading nvidia-cublas-cu12
Downloading nvidia-cudnn-cu12
Downloading torch
Installed 26 packages in 241ms
</div>
<div class="cell-artifacts">
<h4>Artifacts:</h4>
<a href="artifacts/gptoss_run/gptoss_results.json" class="artifact" target="_blank">gptoss_results.json</a>
</div>
</div>
</div>
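<p>The throughput figure in the benchmark output above follows directly from the latency: tokens/sec is tokens-per-iteration divided by average latency. A quick sanity check of the reported numbers (100 tokens, 47.135 ms average):</p>

```python
# Sanity-check the reported throughput: tokens/sec is just
# tokens per iteration divided by the average latency.
tokens = 100                 # BATCH_SIZE * SEQ_LEN from the run above
avg_latency_s = 47.135e-3    # average latency reported by the benchmark
throughput = tokens / avg_latency_s
print(f"{throughput:.1f} tokens/sec")  # matches the reported 2121.6
```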
<h2>GPT-OSS Implementation (Training Mode)</h2>
<p>This section runs the GPT-OSS MoE implementation with training mode forced on, so routing takes the per-expert loop path rather than the batched dense path.</p>
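<p>The expert-loop path works by scattering the top-k routing weights into a dense (tokens, experts) score matrix, then iterating over each expert that was hit, gathering its assigned tokens, and scatter-adding the weighted outputs back. A minimal stdlib-only sketch of that dispatch logic (the expert transforms here are toy scalar functions, not the real gate/up/down projections in the cell below):</p>

```python
import math

def route_top_k(logits, k):
    """Return (indices, weights) of the top-k experts per token,
    with softmax applied over just the top-k logits (as the router
    in this document does)."""
    out = []
    for row in logits:
        top = sorted(range(len(row)), key=lambda e: row[e], reverse=True)[:k]
        m = max(row[e] for e in top)
        exps = [math.exp(row[e] - m) for e in top]
        z = sum(exps)
        out.append((top, [v / z for v in exps]))
    return out

def moe_forward(tokens, logits, expert_fns, k=2):
    """Expert-loop dispatch: each expert processes only the tokens
    routed to it, scaled by the routing weight, accumulated per token."""
    routed = route_top_k(logits, k)
    out = [0.0] * len(tokens)
    for e, fn in enumerate(expert_fns):          # loop over "hit" experts
        for t, (idx, w) in enumerate(routed):
            if e in idx:                         # expert e was hit by token t
                out[t] += w[idx.index(e)] * fn(tokens[t])
    return out

experts = [lambda x: 2 * x, lambda x: -x, lambda x: x + 1.0]
y = moe_forward([1.0, 3.0], [[0.1, 0.9, 0.2], [2.0, 0.1, 1.5]], experts, k=2)
```

The real implementation below does the same thing vectorized per expert with <code>one_hot</code>, <code>torch.where</code>, and <code>index_add_</code>, but the control flow is identical: a Python-level loop over experts, each handling its own token subset.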
<div class="cell">
<div class="cell-header">
<span class="collapse-indicators">
<span onclick="toggleCode('gptoss_training_run')" style="cursor: pointer;">▼ code</span>
<span onclick="toggleOutput('gptoss_training_run')" style="cursor: pointer;">▼ output</span>
</span> |
Cell: gptoss_training_run | deps: torch, numpy | 36.75s
| <button class="run-btn" onclick="runCell('gptoss_training_run')">▶ run</button>
<button class="copy-btn" onclick="copyCell('gptoss_training_run')">Copy</button>
</div>
<div id="code-gptoss_training_run" class="cell-code">
<div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span>
<span class="normal"> 2</span>
<span class="normal"> 3</span>
<span class="normal"> 4</span>
<span class="normal"> 5</span>
<span class="normal"> 6</span>
<span class="normal"> 7</span>
<span class="normal"> 8</span>
<span class="normal"> 9</span>
<span class="normal"> 10</span>
<span class="normal"> 11</span>
<span class="normal"> 12</span>
<span class="normal"> 13</span>
<span class="normal"> 14</span>
<span class="normal"> 15</span>
<span class="normal"> 16</span>
<span class="normal"> 17</span>
<span class="normal"> 18</span>
<span class="normal"> 19</span>
<span class="normal"> 20</span>
<span class="normal"> 21</span>
<span class="normal"> 22</span>
<span class="normal"> 23</span>
<span class="normal"> 24</span>
<span class="normal"> 25</span>
<span class="normal"> 26</span>
<span class="normal"> 27</span>
<span class="normal"> 28</span>
<span class="normal"> 29</span>
<span class="normal"> 30</span>
<span class="normal"> 31</span>
<span class="normal"> 32</span>
<span class="normal"> 33</span>
<span class="normal"> 34</span>
<span class="normal"> 35</span>
<span class="normal"> 36</span>
<span class="normal"> 37</span>
<span class="normal"> 38</span>
<span class="normal"> 39</span>
<span class="normal"> 40</span>
<span class="normal"> 41</span>
<span class="normal"> 42</span>
<span class="normal"> 43</span>
<span class="normal"> 44</span>
<span class="normal"> 45</span>
<span class="normal"> 46</span>
<span class="normal"> 47</span>
<span class="normal"> 48</span>
<span class="normal"> 49</span>
<span class="normal"> 50</span>
<span class="normal"> 51</span>
<span class="normal"> 52</span>
<span class="normal"> 53</span>
<span class="normal"> 54</span>
<span class="normal"> 55</span>
<span class="normal"> 56</span>
<span class="normal"> 57</span>
<span class="normal"> 58</span>
<span class="normal"> 59</span>
<span class="normal"> 60</span>
<span class="normal"> 61</span>
<span class="normal"> 62</span>
<span class="normal"> 63</span>
<span class="normal"> 64</span>
<span class="normal"> 65</span>
<span class="normal"> 66</span>
<span class="normal"> 67</span>
<span class="normal"> 68</span>
<span class="normal"> 69</span>
<span class="normal"> 70</span>
<span class="normal"> 71</span>
<span class="normal"> 72</span>
<span class="normal"> 73</span>
<span class="normal"> 74</span>
<span class="normal"> 75</span>
<span class="normal"> 76</span>
<span class="normal"> 77</span>
<span class="normal"> 78</span>
<span class="normal"> 79</span>
<span class="normal"> 80</span>
<span class="normal"> 81</span>
<span class="normal"> 82</span>
<span class="normal"> 83</span>
<span class="normal"> 84</span>
<span class="normal"> 85</span>
<span class="normal"> 86</span>
<span class="normal"> 87</span>
<span class="normal"> 88</span>
<span class="normal"> 89</span>
<span class="normal"> 90</span>
<span class="normal"> 91</span>
<span class="normal"> 92</span>
<span class="normal"> 93</span>
<span class="normal"> 94</span>
<span class="normal"> 95</span>
<span class="normal"> 96</span>
<span class="normal"> 97</span>
<span class="normal"> 98</span>
<span class="normal"> 99</span>
<span class="normal">100</span>
<span class="normal">101</span>
<span class="normal">102</span>
<span class="normal">103</span>
<span class="normal">104</span>
<span class="normal">105</span>
<span class="normal">106</span>
<span class="normal">107</span>
<span class="normal">108</span>
<span class="normal">109</span>
<span class="normal">110</span>
<span class="normal">111</span>
<span class="normal">112</span>
<span class="normal">113</span>
<span class="normal">114</span>
<span class="normal">115</span>
<span class="normal">116</span>
<span class="normal">117</span>
<span class="normal">118</span>
<span class="normal">119</span>
<span class="normal">120</span>
<span class="normal">121</span>
<span class="normal">122</span>
<span class="normal">123</span>
<span class="normal">124</span>
<span class="normal">125</span>
<span class="normal">126</span>
<span class="normal">127</span>
<span class="normal">128</span>
<span class="normal">129</span>
<span class="normal">130</span>
<span class="normal">131</span></pre></div></td><td class="code"><div><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">torch</span><span class="w"> </span><span class="kn">import</span> <span class="n">nn</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">torch.nn</span><span class="w"> </span><span class="kn">import</span> <span class="n">functional</span> <span class="k">as</span> <span class="n">F</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">bench_utils</span><span class="w"> </span><span class="kn">import</span> <span class="n">to_dtype</span><span class="p">,</span> <span class="n">tensor_stats</span><span class="p">,</span> <span class="n">set_seed</span><span class="p">,</span> <span class="n">bench_context</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">config</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span>
<span class="n">NUM_EXPERTS</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">TOP_K</span><span class="p">,</span>
<span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">SEQ_LEN</span><span class="p">,</span> <span class="n">DTYPE</span><span class="p">,</span> <span class="n">DEVICE</span><span class="p">,</span>
<span class="n">WEIGHT_SEED</span><span class="p">,</span> <span class="n">EXPERT_SEED</span><span class="p">,</span> <span class="n">INPUT_SEED</span><span class="p">,</span> <span class="n">GENERAL_SEED</span>
<span class="p">)</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">pathlib</span><span class="w"> </span><span class="kn">import</span> <span class="n">Path</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">os</span>
<span class="c1"># Discover the upstream artifact directory from env</span>
<span class="n">data_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;UVNOTE_INPUT_SAVE_DATA&#39;</span><span class="p">,</span> <span class="s1">&#39;.&#39;</span><span class="p">)</span>
<span class="n">router_weight</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;router_weight.pt&#39;</span><span class="p">)</span>
<span class="n">router_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;router_bias.pt&#39;</span><span class="p">)</span>
<span class="n">gate_up_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;gate_up_proj.pt&#39;</span><span class="p">)</span>
<span class="n">gate_up_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;gate_up_proj_bias.pt&#39;</span><span class="p">)</span>
<span class="n">down_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;down_proj.pt&#39;</span><span class="p">)</span>
<span class="n">down_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;down_proj_bias.pt&#39;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Loaded shared weights from artifacts&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Router weight sum: </span><span class="si">{</span><span class="n">router_weight</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Gate/up sum: </span><span class="si">{</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Down sum: </span><span class="si">{</span><span class="n">down_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">class</span><span class="w"> </span><span class="nc">GptOssTrainingRouter</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">router_weight</span><span class="p">,</span> <span class="n">router_bias</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_k</span> <span class="o">=</span> <span class="n">TOP_K</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span> <span class="o">=</span> <span class="n">NUM_EXPERTS</span>
<span class="bp">self</span><span class="o">.</span><span class="n">hidden_dim</span> <span class="o">=</span> <span class="n">HIDDEN_SIZE</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">router_weight</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">router_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">):</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_dim</span><span class="p">)</span>
<span class="n">router_logits</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">)</span>
<span class="n">router_top_value</span><span class="p">,</span> <span class="n">router_indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">topk</span><span class="p">(</span><span class="n">router_logits</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">top_k</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">router_top_value</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">router_top_value</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">router_top_value</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">router_scores</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">router_logits</span><span class="p">)</span><span class="o">.</span><span class="n">scatter_</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">router_indices</span><span class="p">,</span> <span class="n">router_top_value</span><span class="p">)</span>
<span class="k">return</span> <span class="n">router_scores</span><span class="p">,</span> <span class="n">router_indices</span>
<span class="k">class</span><span class="w"> </span><span class="nc">GptOssTrainingExperts</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">gate_up_proj</span><span class="p">,</span> <span class="n">gate_up_proj_bias</span><span class="p">,</span> <span class="n">down_proj</span><span class="p">,</span> <span class="n">down_proj_bias</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span> <span class="o">=</span> <span class="n">NUM_EXPERTS</span>
<span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">=</span> <span class="n">HIDDEN_SIZE</span>
<span class="bp">self</span><span class="o">.</span><span class="n">expert_dim</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj_bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">gate_up_proj_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">down_proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">down_proj</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">down_proj_bias</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">down_proj_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="mf">1.702</span>
<span class="bp">self</span><span class="o">.</span><span class="n">limit</span> <span class="o">=</span> <span class="mf">7.0</span>
<span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">router_indices</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">routing_weights</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)</span>
<span class="n">num_experts</span> <span class="o">=</span> <span class="n">routing_weights</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="c1"># Force training mode path (expert loop instead of batched)</span>
<span class="n">next_states</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">hidden_states</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">hidden_states</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="n">expert_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">router_indices</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="n">num_experts</span><span class="p">)</span>
<span class="n">expert_mask</span> <span class="o">=</span> <span class="n">expert_mask</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">expert_hit</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">greater</span><span class="p">(</span><span class="n">expert_mask</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">2</span><span class="p">)),</span> <span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">nonzero</span><span class="p">()</span>
<span class="k">for</span> <span class="n">expert_idx</span> <span class="ow">in</span> <span class="n">expert_hit</span><span class="p">[:]:</span>
<span class="n">expert_idx</span> <span class="o">=</span> <span class="n">expert_idx</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="n">_</span><span class="p">,</span> <span class="n">token_idx</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">expert_mask</span><span class="p">[</span><span class="n">expert_idx</span><span class="p">])</span>
<span class="n">current_state</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="p">[</span><span class="n">token_idx</span><span class="p">]</span>
<span class="n">gate_up</span> <span class="o">=</span> <span class="n">current_state</span> <span class="o">@</span> <span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj</span><span class="p">[</span><span class="n">expert_idx</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">gate_up_proj_bias</span><span class="p">[</span><span class="n">expert_idx</span><span class="p">]</span>
<span class="n">gate</span><span class="p">,</span> <span class="n">up</span> <span class="o">=</span> <span class="n">gate_up</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="p">::</span><span class="mi">2</span><span class="p">],</span> <span class="n">gate_up</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span>
<span class="n">gate</span> <span class="o">=</span> <span class="n">gate</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">limit</span><span class="p">)</span>
<span class="n">up</span> <span class="o">=</span> <span class="n">up</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=-</span><span class="bp">self</span><span class="o">.</span><span class="n">limit</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">limit</span><span class="p">)</span>
<span class="n">glu</span> <span class="o">=</span> <span class="n">gate</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">gate</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">alpha</span><span class="p">)</span>
<span class="n">gated_output</span> <span class="o">=</span> <span class="p">(</span><span class="n">up</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">glu</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">gated_output</span> <span class="o">@</span> <span class="bp">self</span><span class="o">.</span><span class="n">down_proj</span><span class="p">[</span><span class="n">expert_idx</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">down_proj_bias</span><span class="p">[</span><span class="n">expert_idx</span><span class="p">]</span>
<span class="n">weighted_output</span> <span class="o">=</span> <span class="n">out</span> <span class="o">*</span> <span class="n">routing_weights</span><span class="p">[</span><span class="n">token_idx</span><span class="p">,</span> <span class="n">expert_idx</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span>
<span class="n">next_states</span><span class="o">.</span><span class="n">index_add_</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">token_idx</span><span class="p">,</span> <span class="n">weighted_output</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">hidden_states</span><span class="o">.</span><span class="n">dtype</span><span class="p">))</span>
<span class="n">next_states</span> <span class="o">=</span> <span class="n">next_states</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)</span>
<span class="k">return</span> <span class="n">next_states</span>
<span class="k">class</span><span class="w"> </span><span class="nc">GptOssTrainingMoEMLP</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">router_weight</span><span class="p">,</span> <span class="n">router_bias</span><span class="p">,</span> <span class="n">gate_up_proj</span><span class="p">,</span> <span class="n">gate_up_proj_bias</span><span class="p">,</span> <span class="n">down_proj</span><span class="p">,</span> <span class="n">down_proj_bias</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">router</span> <span class="o">=</span> <span class="n">GptOssTrainingRouter</span><span class="p">(</span><span class="n">router_weight</span><span class="p">,</span> <span class="n">router_bias</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">experts</span> <span class="o">=</span> <span class="n">GptOssTrainingExperts</span><span class="p">(</span><span class="n">gate_up_proj</span><span class="p">,</span> <span class="n">gate_up_proj_bias</span><span class="p">,</span> <span class="n">down_proj</span><span class="p">,</span> <span class="n">down_proj_bias</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">):</span>
<span class="n">router_scores</span><span class="p">,</span> <span class="n">router_indices</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">router</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">)</span>
<span class="n">routed_out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">experts</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="n">router_indices</span><span class="o">=</span><span class="n">router_indices</span><span class="p">,</span> <span class="n">routing_weights</span><span class="o">=</span><span class="n">router_scores</span><span class="p">)</span>
<span class="k">return</span> <span class="n">routed_out</span><span class="p">,</span> <span class="n">router_scores</span>
<span class="c1"># Run the model</span>
<span class="n">set_seed</span><span class="p">(</span><span class="n">GENERAL_SEED</span><span class="p">)</span>
<span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="n">DEVICE</span><span class="p">)</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">to_dtype</span><span class="p">(</span><span class="n">DTYPE</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">=== GPT-OSS Implementation (Training Mode - Expert Loop) ===&quot;</span><span class="p">)</span>
<span class="c1"># Initialize model with loaded weights and force training mode</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">GptOssTrainingMoEMLP</span><span class="p">(</span>
<span class="n">router_weight</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
<span class="n">router_bias</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
<span class="n">gate_up_proj</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
<span class="n">gate_up_proj_bias</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
<span class="n">down_proj</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
<span class="n">down_proj_bias</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
<span class="c1"># Set to training mode to force expert loop path</span>
<span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">()</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Router weight sum: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">router</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Gate/up proj sum: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">experts</span><span class="o">.</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Down proj sum: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">experts</span><span class="o">.</span><span class="n">down_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Model training mode: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">training</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="c1"># Generate the same input as other implementations</span>
<span class="n">set_seed</span><span class="p">(</span><span class="n">INPUT_SEED</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">SEQ_LEN</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.1</span>
<span class="c1"># Benchmark the model</span>
<span class="n">tokens</span> <span class="o">=</span> <span class="n">BATCH_SIZE</span> <span class="o">*</span> <span class="n">SEQ_LEN</span>
<span class="k">with</span> <span class="n">bench_context</span><span class="p">(</span><span class="n">warmup</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">iters</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">tokens</span><span class="o">=</span><span class="n">tokens</span><span class="p">,</span> <span class="n">save_json</span><span class="o">=</span><span class="s2">&quot;gptoss_training_results.json&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">bench</span><span class="p">:</span>
<span class="n">output</span><span class="p">,</span> <span class="n">stats</span> <span class="o">=</span> <span class="n">bench</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Output sum: </span><span class="si">{</span><span class="n">output</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
</pre></div></td></tr></table></div>
</div>
<div id="output-gptoss_training_run" class="cell-output">
<div class="cell-stdout">Loaded shared weights from artifacts
Router weight sum: 12.588732
Gate/up sum: 1026.601807
Down sum: 206.729263
=== GPT-OSS Implementation (Training Mode - Expert Loop) ===
Router weight sum: 12.588732
Gate/up proj sum: 1026.601807
Down proj sum: 206.729340
Model training mode: True
┌─ Benchmark Configuration ─────────────────────────────┐
│ Warmup: 10 Iters: 50 │
│ Tokens: 100 │
└────────────────────────────────────────────────────────┘
Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=0.000016, std=0.099892, norm=33.904163
Warming up (10 iterations)...
Benchmarking (50 iterations)...
Progress: 20% complete (avg: 48.328 ms)
Progress: 40% complete (avg: 48.764 ms)
Progress: 60% complete (avg: 48.825 ms)
Progress: 80% complete (avg: 48.769 ms)
Output tensors:
Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.050624, 0.050640], mean=-0.000005, std=0.011573, norm=3.928071
Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.294473], mean=0.007812, std=0.043541, norm=5.004632
━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━
Iterations: 50
Latency Statistics:
Average: 48.630 ms
Min: 47.535 ms
Max: 49.414 ms
Std Dev: 0.559 ms
Percentiles:
P50 (median): 48.395 ms
P95: 49.346 ms
P99: 49.390 ms
Throughput:
Tokens/sec: 2056.3
Std Dev: 23.6
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Saved benchmark results to gptoss_training_results.json
Output sum: -0.597250
</div>
<div class="cell-stderr">Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB)
Downloading numpy (15.9MiB)
Downloading networkx (1.9MiB)
Downloading nvidia-cufft-cu12 (184.2MiB)
Downloading nvidia-cufile-cu12 (1.1MiB)
Downloading sympy (6.0MiB)
Downloading nvidia-nvjitlink-cu12 (37.4MiB)
Downloading nvidia-cusolver-cu12 (255.1MiB)
Downloading nvidia-nccl-cu12 (307.4MiB)
Downloading nvidia-cusparse-cu12 (274.9MiB)
Downloading nvidia-cusparselt-cu12 (273.9MiB)
Downloading nvidia-curand-cu12 (60.7MiB)
Downloading setuptools (1.1MiB)
Downloading nvidia-cuda-cupti-cu12 (9.8MiB)
Downloading nvidia-cublas-cu12 (566.8MiB)
Downloading nvidia-cudnn-cu12 (674.0MiB)
Downloading triton (148.4MiB)
Downloading torch (846.8MiB)
Downloading nvidia-cufile-cu12
Downloading setuptools
Downloading networkx
Downloading nvidia-cuda-cupti-cu12
Downloading numpy
Downloading nvidia-nvjitlink-cu12
Downloading sympy
Downloading nvidia-curand-cu12
Downloading nvidia-cuda-nvrtc-cu12
Downloading triton
Downloading nvidia-cufft-cu12
Downloading nvidia-cusolver-cu12
Downloading nvidia-cusparse-cu12
Downloading nvidia-cusparselt-cu12
Downloading nvidia-nccl-cu12
Downloading nvidia-cublas-cu12
Downloading nvidia-cudnn-cu12
Downloading torch
Installed 26 packages in 234ms
</div>
<div class="cell-artifacts">
<h4>Artifacts:</h4>
<a href="artifacts/gptoss_training_run/gptoss_training_results.json" class="artifact" target="_blank">gptoss_training_results.json</a>
</div>
</div>
</div>
<h2>MegaBlocks Implementation</h2>
<p>This section runs the MegaBlocks MoE implementation, using optimized kernels fetched from the Hugging Face Hub via <code>get_kernel(&quot;kernels-community/megablocks&quot;)</code>.</p>
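<p>Before the kernel-backed cell below, the routing step that all of these implementations share can be sketched in plain Python. This is a minimal illustration of top-k routing with a softmax over the selected logits (the GPT-OSS-style convention used in the cells above), not the MegaBlocks kernel itself:</p>

```python
import math

def route(logits, top_k):
    """Pick the top_k experts for one token and softmax-normalize their logits.

    Illustration only: the implementations on this page do this with torch
    on the GPU, batched over all tokens at once.
    """
    # indices of the top_k largest router logits
    idx = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]
    # numerically stable softmax over just the selected logits
    m = max(logits[i] for i in idx)
    exps = [math.exp(logits[i] - m) for i in idx]
    total = sum(exps)
    return idx, [e / total for e in exps]

# 4 experts, top-2 routing for a single token
idx, weights = route([0.1, 2.0, -1.0, 0.5], top_k=2)
```
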
<div class="cell">
<div class="cell-header">
<span class="collapse-indicators">
<span onclick="toggleCode('megablocks_run')" style="cursor: pointer;">▼ code</span>
<span onclick="toggleOutput('megablocks_run')" style="cursor: pointer;">▼ output</span>
</span> |
Cell: megablocks_run | deps: torch, numpy, kernels | 43.51s
| <button class="run-btn" onclick="runCell('megablocks_run')">▶ run</button>
<button class="copy-btn" onclick="copyCell('megablocks_run')">Copy</button>
</div>
<div id="code-megablocks_run" class="cell-code">
<div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span>
<span class="normal"> 2</span>
<span class="normal"> 3</span>
<span class="normal"> 4</span>
<span class="normal"> 5</span>
<span class="normal"> 6</span>
<span class="normal"> 7</span>
<span class="normal"> 8</span>
<span class="normal"> 9</span>
<span class="normal"> 10</span>
<span class="normal"> 11</span>
<span class="normal"> 12</span>
<span class="normal"> 13</span>
<span class="normal"> 14</span>
<span class="normal"> 15</span>
<span class="normal"> 16</span>
<span class="normal"> 17</span>
<span class="normal"> 18</span>
<span class="normal"> 19</span>
<span class="normal"> 20</span>
<span class="normal"> 21</span>
<span class="normal"> 22</span>
<span class="normal"> 23</span>
<span class="normal"> 24</span>
<span class="normal"> 25</span>
<span class="normal"> 26</span>
<span class="normal"> 27</span>
<span class="normal"> 28</span>
<span class="normal"> 29</span>
<span class="normal"> 30</span>
<span class="normal"> 31</span>
<span class="normal"> 32</span>
<span class="normal"> 33</span>
<span class="normal"> 34</span>
<span class="normal"> 35</span>
<span class="normal"> 36</span>
<span class="normal"> 37</span>
<span class="normal"> 38</span>
<span class="normal"> 39</span>
<span class="normal"> 40</span>
<span class="normal"> 41</span>
<span class="normal"> 42</span>
<span class="normal"> 43</span>
<span class="normal"> 44</span>
<span class="normal"> 45</span>
<span class="normal"> 46</span>
<span class="normal"> 47</span>
<span class="normal"> 48</span>
<span class="normal"> 49</span>
<span class="normal"> 50</span>
<span class="normal"> 51</span>
<span class="normal"> 52</span>
<span class="normal"> 53</span>
<span class="normal"> 54</span>
<span class="normal"> 55</span>
<span class="normal"> 56</span>
<span class="normal"> 57</span>
<span class="normal"> 58</span>
<span class="normal"> 59</span>
<span class="normal"> 60</span>
<span class="normal"> 61</span>
<span class="normal"> 62</span>
<span class="normal"> 63</span>
<span class="normal"> 64</span>
<span class="normal"> 65</span>
<span class="normal"> 66</span>
<span class="normal"> 67</span>
<span class="normal"> 68</span>
<span class="normal"> 69</span>
<span class="normal"> 70</span>
<span class="normal"> 71</span>
<span class="normal"> 72</span>
<span class="normal"> 73</span>
<span class="normal"> 74</span>
<span class="normal"> 75</span>
<span class="normal"> 76</span>
<span class="normal"> 77</span>
<span class="normal"> 78</span>
<span class="normal"> 79</span>
<span class="normal"> 80</span>
<span class="normal"> 81</span>
<span class="normal"> 82</span>
<span class="normal"> 83</span>
<span class="normal"> 84</span>
<span class="normal"> 85</span>
<span class="normal"> 86</span>
<span class="normal"> 87</span>
<span class="normal"> 88</span>
<span class="normal"> 89</span>
<span class="normal"> 90</span>
<span class="normal"> 91</span>
<span class="normal"> 92</span>
<span class="normal"> 93</span>
<span class="normal"> 94</span>
<span class="normal"> 95</span>
<span class="normal"> 96</span>
<span class="normal"> 97</span>
<span class="normal"> 98</span>
<span class="normal"> 99</span>
<span class="normal">100</span>
<span class="normal">101</span>
<span class="normal">102</span>
<span class="normal">103</span>
<span class="normal">104</span>
<span class="normal">105</span>
<span class="normal">106</span></pre></div></td><td class="code"><div><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">torch</span><span class="w"> </span><span class="kn">import</span> <span class="n">nn</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">torch.nn</span><span class="w"> </span><span class="kn">import</span> <span class="n">functional</span> <span class="k">as</span> <span class="n">F</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">kernels</span><span class="w"> </span><span class="kn">import</span> <span class="n">get_kernel</span><span class="p">,</span> <span class="n">get_local_kernel</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">bench_utils</span><span class="w"> </span><span class="kn">import</span> <span class="n">to_dtype</span><span class="p">,</span> <span class="n">tensor_stats</span><span class="p">,</span> <span class="n">set_seed</span><span class="p">,</span> <span class="n">bench_context</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">config</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span>
<span class="n">NUM_EXPERTS</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">TOP_K</span><span class="p">,</span>
<span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">SEQ_LEN</span><span class="p">,</span> <span class="n">DTYPE</span><span class="p">,</span> <span class="n">DEVICE</span><span class="p">,</span>
<span class="n">WEIGHT_SEED</span><span class="p">,</span> <span class="n">EXPERT_SEED</span><span class="p">,</span> <span class="n">INPUT_SEED</span><span class="p">,</span> <span class="n">GENERAL_SEED</span>
<span class="p">)</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">pathlib</span><span class="w"> </span><span class="kn">import</span> <span class="n">Path</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">collections</span><span class="w"> </span><span class="kn">import</span> <span class="n">namedtuple</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">os</span>
<span class="c1"># Discover the upstream artifact directory from env</span>
<span class="n">data_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;UVNOTE_INPUT_SAVE_DATA&#39;</span><span class="p">,</span> <span class="s1">&#39;.&#39;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Loading weights from: </span><span class="si">{</span><span class="n">data_dir</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">router_weight</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;router_weight.pt&#39;</span><span class="p">)</span>
<span class="n">router_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;router_bias.pt&#39;</span><span class="p">)</span>
<span class="n">gate_up_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;gate_up_proj.pt&#39;</span><span class="p">)</span>
<span class="n">gate_up_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;gate_up_proj_bias.pt&#39;</span><span class="p">)</span>
<span class="n">down_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;down_proj.pt&#39;</span><span class="p">)</span>
<span class="n">down_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">&#39;down_proj_bias.pt&#39;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Loaded shared weights from artifacts&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Router weight sum: </span><span class="si">{</span><span class="n">router_weight</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Gate/up sum: </span><span class="si">{</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Down sum: </span><span class="si">{</span><span class="n">down_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">build_megablocks_model</span><span class="p">(</span><span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">):</span>
<span class="c1"># Download optimized kernels from the Hugging Face hub</span>
<span class="n">megablocks</span> <span class="o">=</span> <span class="n">get_kernel</span><span class="p">(</span><span class="s2">&quot;kernels-community/megablocks&quot;</span><span class="p">)</span>
<span class="c1"># megablocks = get_local_kernel(</span>
<span class="c1"># Path(&quot;/home/ubuntu/Projects/megablocks-moe/build&quot;), &quot;megablocks&quot;)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">megablocks</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">MegaBlocksMoeMLP</span><span class="p">()</span>
<span class="c1"># Use a namedtuple class as a lightweight attribute container for the expert weights</span>
<span class="n">model</span><span class="o">.</span><span class="n">experts</span> <span class="o">=</span> <span class="n">namedtuple</span><span class="p">(</span>
<span class="s2">&quot;Experts&quot;</span><span class="p">,</span> <span class="p">[</span><span class="s2">&quot;gate_up_proj&quot;</span><span class="p">,</span> <span class="s2">&quot;gate_up_proj_bias&quot;</span><span class="p">,</span> <span class="s2">&quot;down_proj&quot;</span><span class="p">,</span> <span class="s2">&quot;down_proj_bias&quot;</span><span class="p">,</span> <span class="s2">&quot;hidden_size&quot;</span><span class="p">]</span>
<span class="p">)</span>
<span class="c1"># Use loaded router weights for consistency</span>
<span class="n">model</span><span class="o">.</span><span class="n">router</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">NUM_EXPERTS</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="n">model</span><span class="o">.</span><span class="n">router</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">copy_</span><span class="p">(</span><span class="n">router_weight</span><span class="p">)</span>
<span class="n">model</span><span class="o">.</span><span class="n">router</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">copy_</span><span class="p">(</span><span class="n">router_bias</span><span class="p">)</span>
<span class="c1"># Attach loaded expert weights to the experts container</span>
<span class="n">e</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">experts</span>
<span class="n">e</span><span class="o">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="mf">1.702</span>
<span class="n">e</span><span class="o">.</span><span class="n">capacity_factor</span> <span class="o">=</span> <span class="mi">4</span>
<span class="n">e</span><span class="o">.</span><span class="n">gate_up_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">))</span>
<span class="n">e</span><span class="o">.</span><span class="n">gate_up_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">gate_up_proj_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">))</span>
<span class="n">e</span><span class="o">.</span><span class="n">down_proj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">down_proj</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">))</span>
<span class="n">e</span><span class="o">.</span><span class="n">down_proj_bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">down_proj_bias</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">))</span>
<span class="n">e</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">=</span> <span class="n">HIDDEN_SIZE</span>
<span class="c1"># Log weight statistics for comparison</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;[MegaBlocks] Router weight sum: </span><span class="si">{</span><span class="n">model</span><span class="o">.</span><span class="n">router</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;[MegaBlocks] Gate/up projection shape: </span><span class="si">{</span><span class="nb">tuple</span><span class="p">(</span><span class="n">e</span><span class="o">.</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="si">}</span><span class="s2">, sum: </span><span class="si">{</span><span class="n">e</span><span class="o">.</span><span class="n">gate_up_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;[MegaBlocks] Down projection shape: </span><span class="si">{</span><span class="nb">tuple</span><span class="p">(</span><span class="n">e</span><span class="o">.</span><span class="n">down_proj</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="si">}</span><span class="s2">, sum: </span><span class="si">{</span><span class="n">e</span><span class="o">.</span><span class="n">down_proj</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">return</span> <span class="n">model</span>
<span class="c1"># Create a wrapper to match the interface of other implementations</span>
<span class="k">class</span><span class="w"> </span><span class="nc">MegaBlocksMoEWrapper</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">megablocks_model</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">megablocks_model</span>
<span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">):</span>
<span class="c1"># MegaBlocks expects input in the format (batch, seq_len, hidden_dim)</span>
<span class="n">output</span><span class="p">,</span> <span class="n">routing_weights</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">)</span>
<span class="c1"># The MegaBlocks kernel already returns the routing weights</span>
<span class="c1"># alongside the expert output, so no placeholder tensor is</span>
<span class="c1"># needed here; the weights are passed through unchanged so the</span>
<span class="c1"># wrapper matches the (output, routing_weights) interface of</span>
<span class="c1"># the other implementations in this comparison.</span>
<span class="c1"># (An earlier revision constructed a zero tensor of shape</span>
<span class="c1"># (batch * seq_len, NUM_EXPERTS) as a stand-in.)</span>
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">routing_weights</span>
<span class="c1"># Run the model</span>
<span class="n">set_seed</span><span class="p">(</span><span class="n">GENERAL_SEED</span><span class="p">)</span>
<span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="n">DEVICE</span><span class="p">)</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">to_dtype</span><span class="p">(</span><span class="n">DTYPE</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">=== MegaBlocks Implementation ===&quot;</span><span class="p">)</span>
<span class="c1"># Build MegaBlocks model with loaded weights</span>
<span class="n">megablocks_model</span> <span class="o">=</span> <span class="n">build_megablocks_model</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">MegaBlocksMoEWrapper</span><span class="p">(</span><span class="n">megablocks_model</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
<span class="c1"># Generate the same input as other implementations</span>
<span class="n">set_seed</span><span class="p">(</span><span class="n">INPUT_SEED</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">SEQ_LEN</span><span class="p">,</span> <span class="n">HIDDEN_SIZE</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.1</span>
<span class="c1"># Benchmark the model</span>
<span class="n">tokens</span> <span class="o">=</span> <span class="n">BATCH_SIZE</span> <span class="o">*</span> <span class="n">SEQ_LEN</span>
<span class="k">with</span> <span class="n">bench_context</span><span class="p">(</span><span class="n">warmup</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">iters</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">tokens</span><span class="o">=</span><span class="n">tokens</span><span class="p">,</span> <span class="n">save_json</span><span class="o">=</span><span class="s2">&quot;megablocks_results.json&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">bench</span><span class="p">:</span>
<span class="n">output</span><span class="p">,</span> <span class="n">stats</span> <span class="o">=</span> <span class="n">bench</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Output sum: </span><span class="si">{</span><span class="n">output</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
</pre></div></td></tr></table></div>
</div>
<div id="output-megablocks_run" class="cell-output">
<div class="cell-stdout">Loading weights from: /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/0dc3119d70b6b7e0618fb3e0070aede3d5fc82296ac58f1ab73305d459560b73
Loaded shared weights from artifacts
Router weight sum: 12.588732
Gate/up sum: 1026.601807
Down sum: 206.729263
=== MegaBlocks Implementation ===
[MegaBlocks] Router weight sum: 12.588732
[MegaBlocks] Gate/up projection shape: (128, 1152, 2304), sum: 1026.601807
[MegaBlocks] Down projection shape: (128, 1152, 1152), sum: 206.729340
┌─ Benchmark Configuration ─────────────────────────────┐
│ Warmup: 10 Iters: 50 │
│ Tokens: 100 │
└────────────────────────────────────────────────────────┘
Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=0.000016, std=0.099892, norm=33.904163
Warming up (10 iterations)...
Benchmarking (50 iterations)...
Progress: 20% complete (avg: 0.867 ms)
Progress: 40% complete (avg: 0.853 ms)
Progress: 60% complete (avg: 1.181 ms)
Progress: 80% complete (avg: 3.026 ms)
Output tensors:
Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.050624, 0.050640], mean=-0.000005, std=0.011573, norm=3.928071
Auxiliary: shape=(100, 4), dtype=torch.float32, device=cuda:0, range=[0.220910, 0.294473], mean=0.250000, std=0.010777, norm=5.004632
━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━
Iterations: 50
Latency Statistics:
Average: 4.133 ms
Min: 0.823 ms
Max: 8.589 ms
Std Dev: 3.781 ms
Percentiles:
P50 (median): 0.864 ms
P95: 8.579 ms
P99: 8.589 ms
Throughput:
Tokens/sec: 24194.9
Std Dev: 52511.7
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Saved benchmark results to megablocks_results.json
Output sum: -0.597249
</div>
<div class="cell-stderr">Downloading setuptools (1.1MiB)
Downloading nvidia-cudnn-cu12 (674.0MiB)
Downloading numpy (15.9MiB)
Downloading nvidia-cusparse-cu12 (274.9MiB)
Downloading nvidia-nvjitlink-cu12 (37.4MiB)
Downloading hf-xet (3.0MiB)
Downloading nvidia-cusolver-cu12 (255.1MiB)
Downloading networkx (1.9MiB)
Downloading nvidia-cufft-cu12 (184.2MiB)
Downloading nvidia-cufile-cu12 (1.1MiB)
Downloading triton (148.4MiB)
Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB)
Downloading nvidia-curand-cu12 (60.7MiB)
Downloading sympy (6.0MiB)
Downloading nvidia-cuda-cupti-cu12 (9.8MiB)
Downloading nvidia-nccl-cu12 (307.4MiB)
Downloading nvidia-cusparselt-cu12 (273.9MiB)
Downloading nvidia-cublas-cu12 (566.8MiB)
Downloading torch (846.8MiB)
Downloading nvidia-cufile-cu12
Downloading hf-xet
Downloading setuptools
Downloading networkx
Downloading nvidia-cuda-cupti-cu12
Downloading numpy
Downloading sympy
Downloading nvidia-nvjitlink-cu12
Downloading nvidia-curand-cu12
Downloading nvidia-cuda-nvrtc-cu12
Downloading triton
Downloading nvidia-cufft-cu12
Downloading nvidia-cusolver-cu12
Downloading nvidia-cusparse-cu12
Downloading nvidia-cusparselt-cu12
Downloading nvidia-nccl-cu12
Downloading nvidia-cublas-cu12
Downloading nvidia-cudnn-cu12
Downloading torch
Installed 37 packages in 216ms
Fetching 66 files: 0%| | 0/66 [00:00&lt;?, ?it/s]
Fetching 66 files: 2%|▏ | 1/66 [00:00&lt;00:22, 2.87it/s]
Fetching 66 files: 26%|██▌ | 17/66 [00:01&lt;00:04, 11.84it/s]
Fetching 66 files: 100%|██████████| 66/66 [00:01&lt;00:00, 43.56it/s]
</div>
<div class="cell-artifacts">
<h4>Artifacts:</h4>
<a href="artifacts/megablocks_run/megablocks_results.json" class="artifact" target="_blank">megablocks_results.json</a>
</div>
</div>
</div>
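The latency statistics printed above (average, min/max, standard deviation, and P50/P95/P99 percentiles) come from the `bench_context` helper, whose implementation is not shown in this cell. As a rough sketch of how such a summary can be derived from raw per-iteration timings, assuming a simple nearest-rank percentile method (an assumption, not the harness's confirmed behavior):

```python
import statistics

# Hedged sketch: summarize per-iteration latencies the way the benchmark
# report above presents them. The nearest-rank percentile method is an
# assumption; the real bench_context helper may differ.
def summarize(latencies_ms):
    s = sorted(latencies_ms)
    # Nearest-rank percentile over the sorted samples
    pct = lambda p: s[min(len(s) - 1, int(round(p / 100 * (len(s) - 1))))]
    return {
        "avg_ms": sum(s) / len(s),
        "min_ms": s[0],
        "max_ms": s[-1],
        "std_ms": statistics.stdev(s),
        "p50_ms": pct(50),
        "p95_ms": pct(95),
        "p99_ms": pct(99),
    }

# A bimodal timing distribution like the one above (fast steady-state
# iterations plus a few slow outliers) pulls the average well above the median.
stats = summarize([0.8, 0.9, 1.0, 1.1, 8.5])
print(round(stats["avg_ms"], 2), stats["p50_ms"], stats["p95_ms"])
```

This also explains the report above: with outliers near 8.6 ms, the average (4.133 ms) sits far from the median (0.864 ms), so P50 is the more representative steady-state figure.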
<h2>Performance Visualization</h2>
<p>This section reads all benchmark results and creates a comprehensive performance comparison chart.</p>
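The plotting cell below keys each loaded file by `data['implementation']` and reads `stats.avg_ms`, `stats.p95_ms`, and the optional `stats.tokens_per_s`. A minimal sketch of the result-file schema this implies (field names inferred from the keys the code reads; real files may carry additional fields):

```python
import json

# Hedged sketch of one *_results.json payload, using values from the
# MegaBlocks run above. Only the keys the visualization reads are shown.
example = {
    "implementation": "megablocks",
    "stats": {"avg_ms": 4.133, "p95_ms": 8.579, "tokens_per_s": 24194.9},
}
blob = json.dumps(example)
data = json.loads(blob)
# The visualization falls back to 0 when throughput is absent:
throughput = data["stats"].get("tokens_per_s", 0)
print(data["implementation"], throughput)
```

Keying results by the `implementation` field (rather than by filename) means a missing or renamed file simply drops out of the charts instead of breaking the plot.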
<div class="cell">
<div class="cell-header">
<span class="collapse-indicators">
<span onclick="toggleCode('visualization')" style="cursor: pointer;">▼ code</span>
<span onclick="toggleOutput('visualization')" style="cursor: pointer;">▼ output</span>
</span> |
Cell: visualization | deps: matplotlib | 3.96s
| <button class="run-btn" onclick="runCell('visualization')">▶ run</button>
<button class="copy-btn" onclick="copyCell('visualization')">Copy</button>
</div>
<div id="code-visualization" class="cell-code">
<div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span class="normal"> 1</span>
<span class="normal"> 2</span>
<span class="normal"> 3</span>
<span class="normal"> 4</span>
<span class="normal"> 5</span>
<span class="normal"> 6</span>
<span class="normal"> 7</span>
<span class="normal"> 8</span>
<span class="normal"> 9</span>
<span class="normal"> 10</span>
<span class="normal"> 11</span>
<span class="normal"> 12</span>
<span class="normal"> 13</span>
<span class="normal"> 14</span>
<span class="normal"> 15</span>
<span class="normal"> 16</span>
<span class="normal"> 17</span>
<span class="normal"> 18</span>
<span class="normal"> 19</span>
<span class="normal"> 20</span>
<span class="normal"> 21</span>
<span class="normal"> 22</span>
<span class="normal"> 23</span>
<span class="normal"> 24</span>
<span class="normal"> 25</span>
<span class="normal"> 26</span>
<span class="normal"> 27</span>
<span class="normal"> 28</span>
<span class="normal"> 29</span>
<span class="normal"> 30</span>
<span class="normal"> 31</span>
<span class="normal"> 32</span>
<span class="normal"> 33</span>
<span class="normal"> 34</span>
<span class="normal"> 35</span>
<span class="normal"> 36</span>
<span class="normal"> 37</span>
<span class="normal"> 38</span>
<span class="normal"> 39</span>
<span class="normal"> 40</span>
<span class="normal"> 41</span>
<span class="normal"> 42</span>
<span class="normal"> 43</span>
<span class="normal"> 44</span>
<span class="normal"> 45</span>
<span class="normal"> 46</span>
<span class="normal"> 47</span>
<span class="normal"> 48</span>
<span class="normal"> 49</span>
<span class="normal"> 50</span>
<span class="normal"> 51</span>
<span class="normal"> 52</span>
<span class="normal"> 53</span>
<span class="normal"> 54</span>
<span class="normal"> 55</span>
<span class="normal"> 56</span>
<span class="normal"> 57</span>
<span class="normal"> 58</span>
<span class="normal"> 59</span>
<span class="normal"> 60</span>
<span class="normal"> 61</span>
<span class="normal"> 62</span>
<span class="normal"> 63</span>
<span class="normal"> 64</span>
<span class="normal"> 65</span>
<span class="normal"> 66</span>
<span class="normal"> 67</span>
<span class="normal"> 68</span>
<span class="normal"> 69</span>
<span class="normal"> 70</span>
<span class="normal"> 71</span>
<span class="normal"> 72</span>
<span class="normal"> 73</span>
<span class="normal"> 74</span>
<span class="normal"> 75</span>
<span class="normal"> 76</span>
<span class="normal"> 77</span>
<span class="normal"> 78</span>
<span class="normal"> 79</span>
<span class="normal"> 80</span>
<span class="normal"> 81</span>
<span class="normal"> 82</span>
<span class="normal"> 83</span>
<span class="normal"> 84</span>
<span class="normal"> 85</span>
<span class="normal"> 86</span>
<span class="normal"> 87</span>
<span class="normal"> 88</span>
<span class="normal"> 89</span>
<span class="normal"> 90</span>
<span class="normal"> 91</span>
<span class="normal"> 92</span>
<span class="normal"> 93</span>
<span class="normal"> 94</span>
<span class="normal"> 95</span>
<span class="normal"> 96</span>
<span class="normal"> 97</span>
<span class="normal"> 98</span>
<span class="normal"> 99</span>
<span class="normal">100</span>
<span class="normal">101</span>
<span class="normal">102</span>
<span class="normal">103</span>
<span class="normal">104</span>
<span class="normal">105</span>
<span class="normal">106</span>
<span class="normal">107</span>
<span class="normal">108</span>
<span class="normal">109</span>
<span class="normal">110</span></pre></div></td><td class="code"><div><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">json</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">matplotlib.pyplot</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">plt</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">pathlib</span><span class="w"> </span><span class="kn">import</span> <span class="n">Path</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">os</span>
<span class="c1"># List of expected result files</span>
<span class="n">yamoe_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;UVNOTE_INPUT_YAMOE_RUN&#39;</span><span class="p">,</span> <span class="s1">&#39;.&#39;</span><span class="p">)</span>
<span class="n">binned_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;UVNOTE_INPUT_BINNED_RUN&#39;</span><span class="p">,</span> <span class="s1">&#39;.&#39;</span><span class="p">)</span>
<span class="n">gptoss_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;UVNOTE_INPUT_GPTOSS_RUN&#39;</span><span class="p">,</span> <span class="s1">&#39;.&#39;</span><span class="p">)</span>
<span class="n">gptoss_training_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;UVNOTE_INPUT_GPTOSS_TRAINING_RUN&#39;</span><span class="p">,</span> <span class="s1">&#39;.&#39;</span><span class="p">)</span>
<span class="n">megablocks_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;UVNOTE_INPUT_MEGABLOCKS_RUN&#39;</span><span class="p">,</span> <span class="s1">&#39;.&#39;</span><span class="p">)</span>
<span class="n">result_files</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">Path</span><span class="p">(</span><span class="n">yamoe_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s2">&quot;yamoe_results.json&quot;</span><span class="p">,</span>
<span class="n">Path</span><span class="p">(</span><span class="n">binned_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s2">&quot;binned_results.json&quot;</span><span class="p">,</span>
<span class="n">Path</span><span class="p">(</span><span class="n">gptoss_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s2">&quot;gptoss_results.json&quot;</span><span class="p">,</span>
<span class="n">Path</span><span class="p">(</span><span class="n">gptoss_training_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s2">&quot;gptoss_training_results.json&quot;</span><span class="p">,</span>
<span class="n">Path</span><span class="p">(</span><span class="n">megablocks_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s2">&quot;megablocks_results.json&quot;</span>
<span class="p">]</span>
<span class="c1"># Load all benchmark results</span>
<span class="n">results</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">for</span> <span class="n">file</span> <span class="ow">in</span> <span class="n">result_files</span><span class="p">:</span>
<span class="k">if</span> <span class="n">Path</span><span class="p">(</span><span class="n">file</span><span class="p">)</span><span class="o">.</span><span class="n">exists</span><span class="p">():</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">file</span><span class="p">,</span> <span class="s1">&#39;r&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
<span class="n">results</span><span class="p">[</span><span class="n">data</span><span class="p">[</span><span class="s1">&#39;implementation&#39;</span><span class="p">]]</span> <span class="o">=</span> <span class="n">data</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Loaded </span><span class="si">{</span><span class="n">file</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Missing </span><span class="si">{</span><span class="n">file</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">results</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;No benchmark results found. Run the benchmark cells first.&quot;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># Extract data for plotting</span>
<span class="n">implementations</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">results</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span>
<span class="n">avg_latencies</span> <span class="o">=</span> <span class="p">[</span><span class="n">results</span><span class="p">[</span><span class="n">impl</span><span class="p">][</span><span class="s1">&#39;stats&#39;</span><span class="p">][</span><span class="s1">&#39;avg_ms&#39;</span><span class="p">]</span> <span class="k">for</span> <span class="n">impl</span> <span class="ow">in</span> <span class="n">implementations</span><span class="p">]</span>
<span class="n">p95_latencies</span> <span class="o">=</span> <span class="p">[</span><span class="n">results</span><span class="p">[</span><span class="n">impl</span><span class="p">][</span><span class="s1">&#39;stats&#39;</span><span class="p">][</span><span class="s1">&#39;p95_ms&#39;</span><span class="p">]</span> <span class="k">for</span> <span class="n">impl</span> <span class="ow">in</span> <span class="n">implementations</span><span class="p">]</span>
<span class="n">throughputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">results</span><span class="p">[</span><span class="n">impl</span><span class="p">][</span><span class="s1">&#39;stats&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;tokens_per_s&#39;</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">impl</span> <span class="ow">in</span> <span class="n">implementations</span><span class="p">]</span>
<span class="c1"># Create figure with subplots</span>
<span class="n">fig</span><span class="p">,</span> <span class="p">(</span><span class="n">ax1</span><span class="p">,</span> <span class="n">ax2</span><span class="p">,</span> <span class="n">ax3</span><span class="p">)</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">18</span><span class="p">,</span> <span class="mi">6</span><span class="p">))</span>
<span class="n">fig</span><span class="o">.</span><span class="n">suptitle</span><span class="p">(</span><span class="s1">&#39;MoE Implementation Performance Comparison&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">fontweight</span><span class="o">=</span><span class="s1">&#39;bold&#39;</span><span class="p">)</span>
<span class="c1"># Colors for each implementation</span>
<span class="n">colors</span> <span class="o">=</span> <span class="p">[</span><span class="s1">&#39;#FF6B6B&#39;</span><span class="p">,</span> <span class="s1">&#39;#4ECDC4&#39;</span><span class="p">,</span> <span class="s1">&#39;#45B7D1&#39;</span><span class="p">,</span> <span class="s1">&#39;#96CEB4&#39;</span><span class="p">,</span> <span class="s1">&#39;#FECA57&#39;</span><span class="p">][:</span><span class="nb">len</span><span class="p">(</span><span class="n">implementations</span><span class="p">)]</span>
<span class="c1"># 1. Average Latency Chart</span>
<span class="n">bars1</span> <span class="o">=</span> <span class="n">ax1</span><span class="o">.</span><span class="n">bar</span><span class="p">(</span><span class="n">implementations</span><span class="p">,</span> <span class="n">avg_latencies</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="n">colors</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.8</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s1">&#39;black&#39;</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">ax1</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s1">&#39;Average Latency&#39;</span><span class="p">,</span> <span class="n">fontweight</span><span class="o">=</span><span class="s1">&#39;bold&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">ax1</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s1">&#39;Latency (ms)&#39;</span><span class="p">,</span> <span class="n">fontweight</span><span class="o">=</span><span class="s1">&#39;bold&#39;</span><span class="p">)</span>
<span class="n">ax1</span><span class="o">.</span><span class="n">tick_params</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="s1">&#39;x&#39;</span><span class="p">,</span> <span class="n">rotation</span><span class="o">=</span><span class="mi">45</span><span class="p">)</span>
<span class="n">ax1</span><span class="o">.</span><span class="n">grid</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="s1">&#39;y&#39;</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.3</span><span class="p">)</span>
<span class="c1"># Add value labels on bars</span>
<span class="k">for</span> <span class="n">bar</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">bars1</span><span class="p">,</span> <span class="n">avg_latencies</span><span class="p">):</span>
<span class="n">ax1</span><span class="o">.</span><span class="n">text</span><span class="p">(</span><span class="n">bar</span><span class="o">.</span><span class="n">get_x</span><span class="p">()</span> <span class="o">+</span> <span class="n">bar</span><span class="o">.</span><span class="n">get_width</span><span class="p">()</span><span class="o">/</span><span class="mi">2</span><span class="p">,</span> <span class="n">bar</span><span class="o">.</span><span class="n">get_height</span><span class="p">()</span> <span class="o">+</span> <span class="nb">max</span><span class="p">(</span><span class="n">avg_latencies</span><span class="p">)</span><span class="o">*</span><span class="mf">0.01</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">val</span><span class="si">:</span><span class="s1">.2f</span><span class="si">}</span><span class="s1">ms&#39;</span><span class="p">,</span> <span class="n">ha</span><span class="o">=</span><span class="s1">&#39;center&#39;</span><span class="p">,</span> <span class="n">va</span><span class="o">=</span><span class="s1">&#39;bottom&#39;</span><span class="p">,</span> <span class="n">fontweight</span><span class="o">=</span><span class="s1">&#39;bold&#39;</span><span class="p">)</span>
<span class="c1"># 2. P95 Latency Chart</span>
<span class="n">bars2</span> <span class="o">=</span> <span class="n">ax2</span><span class="o">.</span><span class="n">bar</span><span class="p">(</span><span class="n">implementations</span><span class="p">,</span> <span class="n">p95_latencies</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="n">colors</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.8</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s1">&#39;black&#39;</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">ax2</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s1">&#39;95th Percentile Latency&#39;</span><span class="p">,</span> <span class="n">fontweight</span><span class="o">=</span><span class="s1">&#39;bold&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">ax2</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s1">&#39;Latency (ms)&#39;</span><span class="p">,</span> <span class="n">fontweight</span><span class="o">=</span><span class="s1">&#39;bold&#39;</span><span class="p">)</span>
<span class="n">ax2</span><span class="o">.</span><span class="n">tick_params</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="s1">&#39;x&#39;</span><span class="p">,</span> <span class="n">rotation</span><span class="o">=</span><span class="mi">45</span><span class="p">)</span>
<span class="n">ax2</span><span class="o">.</span><span class="n">grid</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="s1">&#39;y&#39;</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.3</span><span class="p">)</span>
<span class="c1"># Add value labels on bars</span>
<span class="k">for</span> <span class="n">bar</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">bars2</span><span class="p">,</span> <span class="n">p95_latencies</span><span class="p">):</span>
<span class="n">ax2</span><span class="o">.</span><span class="n">text</span><span class="p">(</span><span class="n">bar</span><span class="o">.</span><span class="n">get_x</span><span class="p">()</span> <span class="o">+</span> <span class="n">bar</span><span class="o">.</span><span class="n">get_width</span><span class="p">()</span><span class="o">/</span><span class="mi">2</span><span class="p">,</span> <span class="n">bar</span><span class="o">.</span><span class="n">get_height</span><span class="p">()</span> <span class="o">+</span> <span class="nb">max</span><span class="p">(</span><span class="n">p95_latencies</span><span class="p">)</span><span class="o">*</span><span class="mf">0.01</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">val</span><span class="si">:</span><span class="s1">.2f</span><span class="si">}</span><span class="s1">ms&#39;</span><span class="p">,</span> <span class="n">ha</span><span class="o">=</span><span class="s1">&#39;center&#39;</span><span class="p">,</span> <span class="n">va</span><span class="o">=</span><span class="s1">&#39;bottom&#39;</span><span class="p">,</span> <span class="n">fontweight</span><span class="o">=</span><span class="s1">&#39;bold&#39;</span><span class="p">)</span>
<span class="c1"># 3. Throughput Chart</span>
<span class="n">bars3</span> <span class="o">=</span> <span class="n">ax3</span><span class="o">.</span><span class="n">bar</span><span class="p">(</span><span class="n">implementations</span><span class="p">,</span> <span class="n">throughputs</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="n">colors</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.8</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s1">&#39;black&#39;</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">ax3</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s1">&#39;Throughput&#39;</span><span class="p">,</span> <span class="n">fontweight</span><span class="o">=</span><span class="s1">&#39;bold&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">ax3</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s1">&#39;Tokens/sec&#39;</span><span class="p">,</span> <span class="n">fontweight</span><span class="o">=</span><span class="s1">&#39;bold&#39;</span><span class="p">)</span>
<span class="n">ax3</span><span class="o">.</span><span class="n">tick_params</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="s1">&#39;x&#39;</span><span class="p">,</span> <span class="n">rotation</span><span class="o">=</span><span class="mi">45</span><span class="p">)</span>
<span class="n">ax3</span><span class="o">.</span><span class="n">grid</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="s1">&#39;y&#39;</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.3</span><span class="p">)</span>
<span class="c1"># Add value labels on bars</span>
<span class="k">for</span> <span class="n">bar</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">bars3</span><span class="p">,</span> <span class="n">throughputs</span><span class="p">):</span>
<span class="k">if</span> <span class="n">val</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span> <span class="c1"># Only show label if throughput was calculated</span>
<span class="n">ax3</span><span class="o">.</span><span class="n">text</span><span class="p">(</span><span class="n">bar</span><span class="o">.</span><span class="n">get_x</span><span class="p">()</span> <span class="o">+</span> <span class="n">bar</span><span class="o">.</span><span class="n">get_width</span><span class="p">()</span><span class="o">/</span><span class="mi">2</span><span class="p">,</span> <span class="n">bar</span><span class="o">.</span><span class="n">get_height</span><span class="p">()</span> <span class="o">+</span> <span class="nb">max</span><span class="p">(</span><span class="n">throughputs</span><span class="p">)</span><span class="o">*</span><span class="mf">0.01</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">val</span><span class="si">:</span><span class="s1">.0f</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span> <span class="n">ha</span><span class="o">=</span><span class="s1">&#39;center&#39;</span><span class="p">,</span> <span class="n">va</span><span class="o">=</span><span class="s1">&#39;bottom&#39;</span><span class="p">,</span> <span class="n">fontweight</span><span class="o">=</span><span class="s1">&#39;bold&#39;</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">tight_layout</span><span class="p">()</span>
<span class="n">plt</span><span class="o">.</span><span class="n">savefig</span><span class="p">(</span><span class="s2">&quot;moe_performance_comparison.png&quot;</span><span class="p">,</span> <span class="n">dpi</span><span class="o">=</span><span class="mi">300</span><span class="p">)</span>
<span class="c1"># Print summary table, sorted fastest-first by average latency</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Performance Summary:&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="s1">&#39;Implementation&#39;</span><span class="si">:</span><span class="s2">&lt;30</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="s1">&#39;Avg (ms)&#39;</span><span class="si">:</span><span class="s2">&lt;12</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="s1">&#39;P95 (ms)&#39;</span><span class="si">:</span><span class="s2">&lt;12</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="s1">&#39;Tokens/sec&#39;</span><span class="si">:</span><span class="s2">&lt;12</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="s1">&#39;Relative Speed&#39;</span><span class="si">:</span><span class="s2">&lt;15</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;-&quot;</span><span class="o">*</span><span class="mi">80</span><span class="p">)</span>
<span class="c1"># Sort by average latency for relative speed calculation</span>
<span class="n">sorted_results</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">results</span><span class="o">.</span><span class="n">items</span><span class="p">(),</span> <span class="n">key</span><span class="o">=</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="p">[</span><span class="mi">1</span><span class="p">][</span><span class="s1">&#39;stats&#39;</span><span class="p">][</span><span class="s1">&#39;avg_ms&#39;</span><span class="p">])</span>
<span class="n">fastest_latency</span> <span class="o">=</span> <span class="n">sorted_results</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">][</span><span class="s1">&#39;stats&#39;</span><span class="p">][</span><span class="s1">&#39;avg_ms&#39;</span><span class="p">]</span>
<span class="k">for</span> <span class="n">impl</span><span class="p">,</span> <span class="n">data</span> <span class="ow">in</span> <span class="n">sorted_results</span><span class="p">:</span>
<span class="n">avg_ms</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="s1">&#39;stats&#39;</span><span class="p">][</span><span class="s1">&#39;avg_ms&#39;</span><span class="p">]</span>
<span class="n">p95_ms</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="s1">&#39;stats&#39;</span><span class="p">][</span><span class="s1">&#39;p95_ms&#39;</span><span class="p">]</span>
<span class="n">tokens_s</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="s1">&#39;stats&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;tokens_per_s&#39;</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">relative_speed</span> <span class="o">=</span> <span class="n">fastest_latency</span> <span class="o">/</span> <span class="n">avg_ms</span>  <span class="c1"># values below 1.0 mean slower than the fastest implementation</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">impl</span><span class="si">:</span><span class="s2">&lt;30</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">avg_ms</span><span class="si">:</span><span class="s2">&gt;8.2f</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">p95_ms</span><span class="si">:</span><span class="s2">&gt;8.2f</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">tokens_s</span><span class="si">:</span><span class="s2">&gt;8.0f</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">relative_speed</span><span class="si">:</span><span class="s2">&gt;6.2f</span><span class="si">}</span><span class="s2">x&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Fastest: </span><span class="si">{</span><span class="n">sorted_results</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2"> (</span><span class="si">{</span><span class="n">sorted_results</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">][</span><span class="s1">&#39;stats&#39;</span><span class="p">][</span><span class="s1">&#39;avg_ms&#39;</span><span class="p">]</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2">ms avg)&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">sorted_results</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Slowest: </span><span class="si">{</span><span class="n">sorted_results</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2"> (</span><span class="si">{</span><span class="n">sorted_results</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">][</span><span class="mi">1</span><span class="p">][</span><span class="s1">&#39;stats&#39;</span><span class="p">][</span><span class="s1">&#39;avg_ms&#39;</span><span class="p">]</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2">ms avg)&quot;</span><span class="p">)</span>
<span class="n">speedup</span> <span class="o">=</span> <span class="n">sorted_results</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">][</span><span class="mi">1</span><span class="p">][</span><span class="s1">&#39;stats&#39;</span><span class="p">][</span><span class="s1">&#39;avg_ms&#39;</span><span class="p">]</span> <span class="o">/</span> <span class="n">sorted_results</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">][</span><span class="s1">&#39;stats&#39;</span><span class="p">][</span><span class="s1">&#39;avg_ms&#39;</span><span class="p">]</span>  <span class="c1"># slowest avg latency over fastest avg latency</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Max Speedup: </span><span class="si">{</span><span class="n">speedup</span><span class="si">:</span><span class="s2">.1f</span><span class="si">}</span><span class="s2">x&quot;</span><span class="p">)</span>
</pre></div></td></tr></table></div>
</div>
<div id="output-visualization" class="cell-output">
<div class="cell-stdout">Loaded /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/c5c8a351e1080ea89737c25df783e5c81cd76df0f2b017cedfd813e3bdf2f9f9/yamoe_results.json
Loaded /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/af01d090b967f1cb05cacea7795553418933b27fc2f188da52f7c4642e456c24/binned_results.json
Loaded /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/cf359ebbdbfd10241ce11898ee298eefd5da768c42d502b034caf3ba5b16aed6/gptoss_results.json
Loaded /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/28eb2a85c2dc94e627a0c6373b55120bd67c549ef80cd5b5e94ae756ecd11aff/gptoss_training_results.json
Loaded /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/a712c225c474c8776a91d23a96a2d4dd5dde0716ed16f6eb0dce9d92b65e06b8/megablocks_results.json
Performance Summary:
Implementation Avg (ms) P95 (ms) Tokens/sec Relative Speed
--------------------------------------------------------------------------------
megablocks_results 4.13 8.58 24195 1.00x
yamoe_results 8.63 8.65 11587 0.48x
gptoss_results 47.14 47.80 2122 0.09x
gptoss_training_results 48.63 49.35 2056 0.08x
binned_results 105.62 107.73 947 0.04x
Fastest: megablocks_results (4.13ms avg)
Slowest: binned_results (105.62ms avg)
Max Speedup: 25.6x
</div>
<div class="cell-stderr">Downloading numpy (15.9MiB)
Downloading fonttools (4.7MiB)
Downloading pillow (6.3MiB)
Downloading matplotlib (8.3MiB)
Downloading kiwisolver (1.4MiB)
Downloading kiwisolver
Downloading pillow
Downloading fonttools
Downloading matplotlib
Downloading numpy
Installed 11 packages in 24ms
</div>
<div class="cell-artifacts">
<h4>Artifacts:</h4>
<a href="artifacts/visualization/moe_performance_comparison.png" class="artifact" target="_blank">moe_performance_comparison.png</a>
<div class="artifact-preview">
<img src="artifacts/visualization/moe_performance_comparison.png" alt="moe_performance_comparison.png">
</div>
</div>
</div>
</div>
</div>
</body>
</html>