Spaces:
Running
Running
| var e=(e=>typeof require<`u`?require:typeof Proxy<`u`?new Proxy(e,{get:(e,t)=>(typeof require<`u`?require:e)[t]}):e)(function(e){if(typeof require<`u`)return require.apply(this,arguments);throw Error('Calling `require` for "'+e+"\" in an environment that doesn't expose the `require` function. See https://rolldown.rs/in-depth/bundling-cjs#require-external-modules for more details.")}),t=Object.defineProperty,n=Object.getOwnPropertyDescriptor,r=Object.getOwnPropertyNames,i=Object.prototype.hasOwnProperty,a=(t=>typeof e<`u`?e:typeof Proxy<`u`?new Proxy(t,{get:(t,n)=>(typeof e<`u`?e:t)[n]}):t)(function(t){if(typeof e<`u`)return e.apply(this,arguments);throw Error(`Dynamic require of "`+t+`" is not supported`)}),o=(e,t)=>()=>(e&&(t=e(e=0)),t),s=(e,n)=>{for(var r in n)t(e,r,{get:n[r],enumerable:!0})},c=(e,a,o,s)=>{if(a&&typeof a==`object`||typeof a==`function`)for(let c of r(a))!i.call(e,c)&&c!==o&&t(e,c,{get:()=>a[c],enumerable:!(s=n(a,c))||s.enumerable});return e},l=e=>c(t({},`__esModule`,{value:!0}),e),u,d,f,p,m,h=o(()=>{u=new Map,d=[],f=(e,t,n)=>{if(t&&typeof t.init==`function`&&typeof t.createInferenceSessionHandler==`function`){let r=u.get(e);if(r===void 0)u.set(e,{backend:t,priority:n});else{if(r.priority>n)return;if(r.priority===n&&r.backend!==t)throw Error(`cannot register backend "${e}" using priority ${n}`)}if(n>=0){let t=d.indexOf(e);t!==-1&&d.splice(t,1);for(let t=0;t<d.length;t++)if(u.get(d[t]).priority<=n){d.splice(t,0,e);return}d.push(e)}return}throw TypeError(`not a valid backend`)},p=async e=>{let t=u.get(e);if(!t)return`backend not found.`;if(t.initialized)return t.backend;if(t.aborted)return t.error;{let n=!!t.initPromise;try{return n||(t.initPromise=t.backend.init(e)),await t.initPromise,t.initialized=!0,t.backend}catch(e){return n||(t.error=`${e}`,t.aborted=!0),t.error}finally{delete t.initPromise}}},m=async e=>{let t=e.executionProviders||[],n=t.map(e=>typeof e==`string`?e:e.name),r=n.length===0?d:n,i,a=[],o=new Set;for(let e of r){let t=await p(e);typeof t==`string`?a.push({name:e,err:t}):(i||=t,i===t&&o.add(e))}if(!i)throw Error(`no available backend found. ERR: ${a.map(e=>`[${e.name}] ${e.err}`).join(`, `)}`);for(let{name:e,err:t}of a)n.includes(e)&&console.warn(`removing requested execution provider "${e}" from session options because it is not available: ${t}`);let s=t.filter(e=>o.has(typeof e==`string`?e:e.name));return[i,new Proxy(e,{get:(e,t)=>t===`executionProviders`?s:Reflect.get(e,t)})]}}),g=o(()=>{h()}),_,v=o(()=>{_=`1.24.0-dev.20251116-b39e144322`}),y,b,x=o(()=>{v(),y=`warning`,b={wasm:{},webgl:{},webgpu:{},versions:{common:_},set logLevel(e){if(e!==void 0){if(typeof e!=`string`||[`verbose`,`info`,`warning`,`error`,`fatal`].indexOf(e)===-1)throw Error(`Unsupported logging level: ${e}`);y=e}},get logLevel(){return y}},Object.defineProperty(b,`logLevel`,{enumerable:!0})}),S,C=o(()=>{x(),S=b}),w,ee,T=o(()=>{w=(e,t)=>{let n=typeof document<`u`?document.createElement(`canvas`):new OffscreenCanvas(1,1);n.width=e.dims[3],n.height=e.dims[2];let r=n.getContext(`2d`);if(r!=null){let i,a;t?.tensorLayout!==void 0&&t.tensorLayout===`NHWC`?(i=e.dims[2],a=e.dims[3]):(i=e.dims[3],a=e.dims[2]);let o=t?.format===void 0?`RGB`:t.format,s=t?.norm,c,l;s===void 0||s.mean===void 0?c=[255,255,255,255]:typeof s.mean==`number`?c=[s.mean,s.mean,s.mean,s.mean]:(c=[s.mean[0],s.mean[1],s.mean[2],0],s.mean[3]!==void 0&&(c[3]=s.mean[3])),s===void 0||s.bias===void 0?l=[0,0,0,0]:typeof s.bias==`number`?l=[s.bias,s.bias,s.bias,s.bias]:(l=[s.bias[0],s.bias[1],s.bias[2],0],s.bias[3]!==void 0&&(l[3]=s.bias[3]));let u=a*i,d=0,f=u,p=u*2,m=-1;o===`RGBA`?(d=0,f=u,p=u*2,m=u*3):o===`RGB`?(d=0,f=u,p=u*2):o===`RBG`&&(d=0,p=u,f=u*2);for(let t=0;t<a;t++)for(let n=0;n<i;n++){let i=(e.data[d++]-l[0])*c[0],a=(e.data[f++]-l[1])*c[1],o=(e.data[p++]-l[2])*c[2],s=m===-1?255:(e.data[m++]-l[3])*c[3];r.fillStyle=`rgba(`+i+`,`+a+`,`+o+`,`+s+`)`,r.fillRect(n,t,1,1)}if(`toDataURL`in n)return n.toDataURL();throw Error(`toDataURL is not supported`)}else throw Error(`Can not access image data`)},ee=(e,t)=>{let n=typeof document<`u`?document.createElement(`canvas`).getContext(`2d`):new OffscreenCanvas(1,1).getContext(`2d`),r;if(n!=null){let i,a,o;t?.tensorLayout!==void 0&&t.tensorLayout===`NHWC`?(i=e.dims[2],a=e.dims[1],o=e.dims[3]):(i=e.dims[3],a=e.dims[2],o=e.dims[1]);let s=t!==void 0&&t.format!==void 0?t.format:`RGB`,c=t?.norm,l,u;c===void 0||c.mean===void 0?l=[255,255,255,255]:typeof c.mean==`number`?l=[c.mean,c.mean,c.mean,c.mean]:(l=[c.mean[0],c.mean[1],c.mean[2],255],c.mean[3]!==void 0&&(l[3]=c.mean[3])),c===void 0||c.bias===void 0?u=[0,0,0,0]:typeof c.bias==`number`?u=[c.bias,c.bias,c.bias,c.bias]:(u=[c.bias[0],c.bias[1],c.bias[2],0],c.bias[3]!==void 0&&(u[3]=c.bias[3]));let d=a*i;if(t!==void 0&&(t.format!==void 0&&o===4&&t.format!==`RGBA`||o===3&&t.format!==`RGB`&&t.format!==`BGR`))throw Error(`Tensor format doesn't match input tensor dims`);let f=0,p=1,m=2,h=3,g=0,_=d,v=d*2,y=-1;s===`RGBA`?(g=0,_=d,v=d*2,y=d*3):s===`RGB`?(g=0,_=d,v=d*2):s===`RBG`&&(g=0,v=d,_=d*2),r=n.createImageData(i,a);for(let t=0;t<a*i;f+=4,p+=4,m+=4,h+=4,t++)r.data[f]=(e.data[g++]-u[0])*l[0],r.data[p]=(e.data[_++]-u[1])*l[1],r.data[m]=(e.data[v++]-u[2])*l[2],r.data[h]=y===-1?255:(e.data[y++]-u[3])*l[3]}else throw Error(`Can not access image data`);return r}}),E,D,O,k,A,te,ne=o(()=>{fe(),E=(e,t)=>{if(e===void 0)throw Error(`Image buffer must be defined`);if(t.height===void 0||t.width===void 0)throw Error(`Image height and width must be defined`);if(t.tensorLayout===`NHWC`)throw Error(`NHWC Tensor layout is not supported yet`);let{height:n,width:r}=t,i=t.norm??{mean:255,bias:0},a,o;a=typeof i.mean==`number`?[i.mean,i.mean,i.mean,i.mean]:[i.mean[0],i.mean[1],i.mean[2],i.mean[3]??255],o=typeof i.bias==`number`?[i.bias,i.bias,i.bias,i.bias]:[i.bias[0],i.bias[1],i.bias[2],i.bias[3]??0];let s=t.format===void 0?`RGBA`:t.format,c=t.tensorFormat!==void 0&&t.tensorFormat!==void 0?t.tensorFormat:`RGB`,l=n*r,u=c===`RGBA`?new Float32Array(l*4):new Float32Array(l*3),d=4,f=0,p=1,m=2,h=3,g=0,_=l,v=l*2,y=-1;s===`RGB`&&(d=3,f=0,p=1,m=2,h=-1),c===`RGBA`?y=l*3:c===`RBG`?(g=0,v=l,_=l*2):c===`BGR`&&(v=0,_=l,g=l*2);for(let t=0;t<l;t++,f+=d,m+=d,p+=d,h+=d)u[g++]=(e[f]+o[0])/a[0],u[_++]=(e[p]+o[1])/a[1],u[v++]=(e[m]+o[2])/a[2],y!==-1&&h!==-1&&(u[y++]=(e[h]+o[3])/a[3]);return c===`RGBA`?new de(`float32`,u,[1,4,n,r]):new de(`float32`,u,[1,3,n,r])},D=async(e,t)=>{let n=typeof HTMLImageElement<`u`&&e instanceof HTMLImageElement,r=typeof ImageData<`u`&&e instanceof ImageData,i=typeof ImageBitmap<`u`&&e instanceof ImageBitmap,a=typeof e==`string`,o,s=t??{},c=()=>{if(typeof document<`u`)return document.createElement(`canvas`);if(typeof OffscreenCanvas<`u`)return new OffscreenCanvas(1,1);throw Error(`Canvas is not supported`)},l=e=>typeof HTMLCanvasElement<`u`&&e instanceof HTMLCanvasElement||e instanceof OffscreenCanvas?e.getContext(`2d`):null;if(n){let n=c();n.width=e.width,n.height=e.height;let r=l(n);if(r!=null){let n=e.height,i=e.width;if(t!==void 0&&t.resizedHeight!==void 0&&t.resizedWidth!==void 0&&(n=t.resizedHeight,i=t.resizedWidth),t!==void 0){if(s=t,t.tensorFormat!==void 0)throw Error(`Image input config format must be RGBA for HTMLImageElement`);s.tensorFormat=`RGBA`,s.height=n,s.width=i}else s.tensorFormat=`RGBA`,s.height=n,s.width=i;r.drawImage(e,0,0),o=r.getImageData(0,0,i,n).data}else throw Error(`Can not access image data`)}else if(r){let n,r;if(t!==void 0&&t.resizedWidth!==void 0&&t.resizedHeight!==void 0?(n=t.resizedHeight,r=t.resizedWidth):(n=e.height,r=e.width),t!==void 0&&(s=t),s.format=`RGBA`,s.height=n,s.width=r,t!==void 0){let t=c();t.width=r,t.height=n;let i=l(t);if(i!=null)i.putImageData(e,0,0),o=i.getImageData(0,0,r,n).data;else throw Error(`Can not access image data`)}else o=e.data}else if(i){if(t===void 0)throw Error(`Please provide image config with format for Imagebitmap`);let n=c();n.width=e.width,n.height=e.height;let r=l(n);if(r!=null){let t=e.height,n=e.width;return r.drawImage(e,0,0,n,t),o=r.getImageData(0,0,n,t).data,s.height=t,s.width=n,E(o,s)}else throw Error(`Can not access image data`)}else{if(a)return new Promise((t,n)=>{let r=c(),i=l(r);if(!e||!i)return n();let a=new Image;a.crossOrigin=`Anonymous`,a.src=e,a.onload=()=>{r.width=a.width,r.height=a.height,i.drawImage(a,0,0,r.width,r.height);let e=i.getImageData(0,0,r.width,r.height);s.height=r.height,s.width=r.width,t(E(e.data,s))}});throw Error(`Input data provided is not supported - aborted tensor creation`)}if(o!==void 0)return E(o,s);throw Error(`Input data provided is not supported - aborted tensor creation`)},O=(e,t)=>{let{width:n,height:r,download:i,dispose:a}=t;return new de({location:`texture`,type:`float32`,texture:e,dims:[1,r,n,4],download:i,dispose:a})},k=(e,t)=>{let{dataType:n,dims:r,download:i,dispose:a}=t;return new de({location:`gpu-buffer`,type:n??`float32`,gpuBuffer:e,dims:r,download:i,dispose:a})},A=(e,t)=>{let{dataType:n,dims:r,download:i,dispose:a}=t;return new de({location:`ml-tensor`,type:n??`float32`,mlTensor:e,dims:r,download:i,dispose:a})},te=(e,t,n)=>new de({location:`cpu-pinned`,type:e,data:t,dims:n??[t.length]})}),re,ie,ae,oe,se=o(()=>{re=new Map([[`float32`,Float32Array],[`uint8`,Uint8Array],[`int8`,Int8Array],[`uint16`,Uint16Array],[`int16`,Int16Array],[`int32`,Int32Array],[`bool`,Uint8Array],[`float64`,Float64Array],[`uint32`,Uint32Array],[`int4`,Uint8Array],[`uint4`,Uint8Array]]),ie=new Map([[Float32Array,`float32`],[Uint8Array,`uint8`],[Int8Array,`int8`],[Uint16Array,`uint16`],[Int16Array,`int16`],[Int32Array,`int32`],[Float64Array,`float64`],[Uint32Array,`uint32`]]),ae=!1,oe=()=>{if(!ae){ae=!0;let e=typeof BigInt64Array<`u`&&BigInt64Array.from,t=typeof BigUint64Array<`u`&&BigUint64Array.from,n=globalThis.Float16Array,r=typeof n<`u`&&n.from;e&&(re.set(`int64`,BigInt64Array),ie.set(BigInt64Array,`int64`)),t&&(re.set(`uint64`,BigUint64Array),ie.set(BigUint64Array,`uint64`)),r?(re.set(`float16`,n),ie.set(n,`float16`)):re.set(`float16`,Uint16Array)}}}),ce,le,ue=o(()=>{fe(),ce=e=>{let t=1;for(let n=0;n<e.length;n++){let r=e[n];if(typeof r!=`number`||!Number.isSafeInteger(r))throw TypeError(`dims[${n}] must be an integer, got: ${r}`);if(r<0)throw RangeError(`dims[${n}] must be a non-negative integer, got: ${r}`);t*=r}return t},le=(e,t)=>{switch(e.location){case`cpu`:return new de(e.type,e.data,t);case`cpu-pinned`:return new de({location:`cpu-pinned`,data:e.data,type:e.type,dims:t});case`texture`:return new de({location:`texture`,texture:e.texture,type:e.type,dims:t});case`gpu-buffer`:return new de({location:`gpu-buffer`,gpuBuffer:e.gpuBuffer,type:e.type,dims:t});case`ml-tensor`:return new de({location:`ml-tensor`,mlTensor:e.mlTensor,type:e.type,dims:t});default:throw Error(`tensorReshape: tensor location ${e.location} is not supported`)}}}),de,fe=o(()=>{T(),ne(),se(),ue(),de=class{constructor(e,t,n){oe();let r,i;if(typeof e==`object`&&`location`in e)switch(this.dataLocation=e.location,r=e.type,i=e.dims,e.location){case`cpu-pinned`:{let t=re.get(r);if(!t)throw TypeError(`unsupported type "${r}" to create tensor from pinned buffer`);if(!(e.data instanceof t))throw TypeError(`buffer should be of type ${t.name}`);this.cpuData=e.data;break}case`texture`:if(r!==`float32`)throw TypeError(`unsupported type "${r}" to create tensor from texture`);this.gpuTextureData=e.texture,this.downloader=e.download,this.disposer=e.dispose;break;case`gpu-buffer`:if(r!==`float32`&&r!==`float16`&&r!==`int32`&&r!==`int64`&&r!==`uint32`&&r!==`uint8`&&r!==`bool`&&r!==`uint4`&&r!==`int4`)throw TypeError(`unsupported type "${r}" to create tensor from gpu buffer`);this.gpuBufferData=e.gpuBuffer,this.downloader=e.download,this.disposer=e.dispose;break;case`ml-tensor`:if(r!==`float32`&&r!==`float16`&&r!==`int32`&&r!==`int64`&&r!==`uint32`&&r!==`uint64`&&r!==`int8`&&r!==`uint8`&&r!==`bool`&&r!==`uint4`&&r!==`int4`)throw TypeError(`unsupported type "${r}" to create tensor from MLTensor`);this.mlTensorData=e.mlTensor,this.downloader=e.download,this.disposer=e.dispose;break;default:throw Error(`Tensor constructor: unsupported location '${this.dataLocation}'`)}else{let a,o;if(typeof e==`string`)if(r=e,o=n,e===`string`){if(!Array.isArray(t))throw TypeError(`A string tensor's data must be a string array.`);a=t}else{let n=re.get(e);if(n===void 0)throw TypeError(`Unsupported tensor type: ${e}.`);if(Array.isArray(t)){if(e===`float16`&&n===Uint16Array||e===`uint4`||e===`int4`)throw TypeError(`Creating a ${e} tensor from number array is not supported. Please use ${n.name} as data.`);a=e===`uint64`||e===`int64`?n.from(t,BigInt):n.from(t)}else if(t instanceof n)a=t;else if(t instanceof Uint8ClampedArray)if(e===`uint8`)a=Uint8Array.from(t);else throw TypeError(`A Uint8ClampedArray tensor's data must be type of uint8`);else if(e===`float16`&&t instanceof Uint16Array&&n!==Uint16Array)a=new globalThis.Float16Array(t.buffer,t.byteOffset,t.length);else throw TypeError(`A ${r} tensor's data must be type of ${n}`)}else if(o=t,Array.isArray(e)){if(e.length===0)throw TypeError(`Tensor type cannot be inferred from an empty array.`);let t=typeof e[0];if(t===`string`)r=`string`,a=e;else if(t===`boolean`)r=`bool`,a=Uint8Array.from(e);else throw TypeError(`Invalid element type of data array: ${t}.`)}else if(e instanceof Uint8ClampedArray)r=`uint8`,a=Uint8Array.from(e);else{let t=ie.get(e.constructor);if(t===void 0)throw TypeError(`Unsupported type for tensor data: ${e.constructor}.`);r=t,a=e}if(o===void 0)o=[a.length];else if(!Array.isArray(o))throw TypeError(`A tensor's dims must be a number array`);i=o,this.cpuData=a,this.dataLocation=`cpu`}let a=ce(i);if(this.cpuData&&a!==this.cpuData.length&&!((r===`uint4`||r===`int4`)&&Math.ceil(a/2)===this.cpuData.length))throw Error(`Tensor's size(${a}) does not match data length(${this.cpuData.length}).`);this.type=r,this.dims=i,this.size=a}static async fromImage(e,t){return D(e,t)}static fromTexture(e,t){return O(e,t)}static fromGpuBuffer(e,t){return k(e,t)}static fromMLTensor(e,t){return A(e,t)}static fromPinnedBuffer(e,t,n){return te(e,t,n)}toDataURL(e){return w(this,e)}toImageData(e){return ee(this,e)}get data(){if(this.ensureValid(),!this.cpuData)throw Error("The data is not on CPU. Use `getData()` to download GPU data to CPU, or use `texture` or `gpuBuffer` property to access the GPU data directly.");return this.cpuData}get location(){return this.dataLocation}get texture(){if(this.ensureValid(),!this.gpuTextureData)throw Error(`The data is not stored as a WebGL texture.`);return this.gpuTextureData}get gpuBuffer(){if(this.ensureValid(),!this.gpuBufferData)throw Error(`The data is not stored as a WebGPU buffer.`);return this.gpuBufferData}get mlTensor(){if(this.ensureValid(),!this.mlTensorData)throw Error(`The data is not stored as a WebNN MLTensor.`);return this.mlTensorData}async getData(e){switch(this.ensureValid(),this.dataLocation){case`cpu`:case`cpu-pinned`:return this.data;case`texture`:case`gpu-buffer`:case`ml-tensor`:if(!this.downloader)throw Error(`The current tensor is not created with a specified data downloader.`);if(this.isDownloading)throw Error(`The current tensor is being downloaded.`);try{this.isDownloading=!0;let t=await this.downloader();return this.downloader=void 0,this.dataLocation=`cpu`,this.cpuData=t,e&&this.disposer&&(this.disposer(),this.disposer=void 0),t}finally{this.isDownloading=!1}default:throw Error(`cannot get data from location: ${this.dataLocation}`)}}dispose(){if(this.isDownloading)throw Error(`The current tensor is being downloaded.`);this.disposer&&=(this.disposer(),void 0),this.cpuData=void 0,this.gpuTextureData=void 0,this.gpuBufferData=void 0,this.mlTensorData=void 0,this.downloader=void 0,this.isDownloading=void 0,this.dataLocation=`none`}ensureValid(){if(this.dataLocation===`none`)throw Error(`The tensor is disposed.`)}reshape(e){if(this.ensureValid(),this.downloader||this.disposer)throw Error(`Cannot reshape a tensor that owns GPU resource.`);return le(this,e)}}}),pe,me=o(()=>{fe(),pe=de}),he,ge,_e,j,M,ve,ye=o(()=>{x(),he=(e,t)=>{(typeof b.trace>`u`?!b.wasm.trace:!b.trace)||console.timeStamp(`${e}::ORT::${t}`)},ge=(e,t)=>{let n=Error().stack?.split(/\r\n|\r|\n/g)||[],r=!1;for(let i=0;i<n.length;i++){if(r&&!n[i].includes(`TRACE_FUNC`)){let r=`FUNC_${e}::${n[i].trim().split(` `)[1]}`;t&&(r+=`::${t}`),he(`CPU`,r);return}n[i].includes(`TRACE_FUNC`)&&(r=!0)}},_e=e=>{(typeof b.trace>`u`?!b.wasm.trace:!b.trace)||ge(`BEGIN`,e)},j=e=>{(typeof b.trace>`u`?!b.wasm.trace:!b.trace)||ge(`END`,e)},M=e=>{(typeof b.trace>`u`?!b.wasm.trace:!b.trace)||console.time(`ORT::${e}`)},ve=e=>{(typeof b.trace>`u`?!b.wasm.trace:!b.trace)||console.timeEnd(`ORT::${e}`)}}),be,xe=o(()=>{h(),me(),ye(),be=class e{constructor(e){this.handler=e}async run(e,t,n){_e(),M(`InferenceSession.run`);let r={},i={};if(typeof e!=`object`||!e||e instanceof pe||Array.isArray(e))throw TypeError(`'feeds' must be an object that use input names as keys and OnnxValue as corresponding values.`);let a=!0;if(typeof t==`object`){if(t===null)throw TypeError(`Unexpected argument[1]: cannot be null.`);if(t instanceof pe)throw TypeError(`'fetches' cannot be a Tensor`);if(Array.isArray(t)){if(t.length===0)throw TypeError(`'fetches' cannot be an empty array.`);a=!1;for(let e of t){if(typeof e!=`string`)throw TypeError(`'fetches' must be a string array or an object.`);if(this.outputNames.indexOf(e)===-1)throw RangeError(`'fetches' contains invalid output name: ${e}.`);r[e]=null}if(typeof n==`object`&&n)i=n;else if(typeof n<`u`)throw TypeError(`'options' must be an object.`)}else{let e=!1,o=Object.getOwnPropertyNames(t);for(let n of this.outputNames)if(o.indexOf(n)!==-1){let i=t[n];(i===null||i instanceof pe)&&(e=!0,a=!1,r[n]=i)}if(e){if(typeof n==`object`&&n)i=n;else if(typeof n<`u`)throw TypeError(`'options' must be an object.`)}else i=t}}else if(typeof t<`u`)throw TypeError(`Unexpected argument[1]: must be 'fetches' or 'options'.`);for(let t of this.inputNames)if(typeof e[t]>`u`)throw Error(`input '${t}' is missing in 'feeds'.`);if(a)for(let e of this.outputNames)r[e]=null;let o=await this.handler.run(e,r,i),s={};for(let e in o)if(Object.hasOwnProperty.call(o,e)){let t=o[e];t instanceof pe?s[e]=t:s[e]=new pe(t.type,t.data,t.dims)}return ve(`InferenceSession.run`),j(),s}async release(){return this.handler.dispose()}static async create(t,n,r,i){_e(),M(`InferenceSession.create`);let a,o={};if(typeof t==`string`){if(a=t,typeof n==`object`&&n)o=n;else if(typeof n<`u`)throw TypeError(`'options' must be an object.`)}else if(t instanceof Uint8Array){if(a=t,typeof n==`object`&&n)o=n;else if(typeof n<`u`)throw TypeError(`'options' must be an object.`)}else if(t instanceof ArrayBuffer||typeof SharedArrayBuffer<`u`&&t instanceof SharedArrayBuffer){let e=t,s=0,c=t.byteLength;if(typeof n==`object`&&n)o=n;else if(typeof n==`number`){if(s=n,!Number.isSafeInteger(s))throw RangeError(`'byteOffset' must be an integer.`);if(s<0||s>=e.byteLength)throw RangeError(`'byteOffset' is out of range [0, ${e.byteLength}).`);if(c=t.byteLength-s,typeof r==`number`){if(c=r,!Number.isSafeInteger(c))throw RangeError(`'byteLength' must be an integer.`);if(c<=0||s+c>e.byteLength)throw RangeError(`'byteLength' is out of range (0, ${e.byteLength-s}].`);if(typeof i==`object`&&i)o=i;else if(typeof i<`u`)throw TypeError(`'options' must be an object.`)}else if(typeof r<`u`)throw TypeError(`'byteLength' must be a number.`)}else if(typeof n<`u`)throw TypeError(`'options' must be an object.`);a=new Uint8Array(e,s,c)}else throw TypeError(`Unexpected argument[0]: must be 'path' or 'buffer'.`);let[s,c]=await m(o),l=await s.createInferenceSessionHandler(a,c);return ve(`InferenceSession.create`),j(),new e(l)}startProfiling(){this.handler.startProfiling()}endProfiling(){this.handler.endProfiling()}get inputNames(){return this.handler.inputNames}get outputNames(){return this.handler.outputNames}get inputMetadata(){return this.handler.inputMetadata}get outputMetadata(){return this.handler.outputMetadata}}}),Se,Ce=o(()=>{xe(),Se=be}),we=o(()=>{}),Te=o(()=>{}),Ee=o(()=>{}),De=o(()=>{}),Oe={};s(Oe,{InferenceSession:()=>Se,TRACE:()=>he,TRACE_EVENT_BEGIN:()=>M,TRACE_EVENT_END:()=>ve,TRACE_FUNC_BEGIN:()=>_e,TRACE_FUNC_END:()=>j,Tensor:()=>pe,env:()=>S,registerBackend:()=>f});var N=o(()=>{g(),C(),Ce(),me(),we(),Te(),ye(),Ee(),De()}),ke=o(()=>{}),Ae={};s(Ae,{default:()=>Ne});var je,Me,Ne,Pe=o(()=>{Wu(),ct(),$e(),je=`ort-wasm-proxy-worker`,Me=globalThis.self?.name===je,Me&&(self.onmessage=e=>{let{type:t,in:n}=e.data;try{switch(t){case`init-wasm`:st(n.wasm).then(()=>{Mu(n).then(()=>{postMessage({type:t})},e=>{postMessage({type:t,err:e})})},e=>{postMessage({type:t,err:e})});break;case`init-ep`:{let{epName:e,env:r}=n;Nu(r,e).then(()=>{postMessage({type:t})},e=>{postMessage({type:t,err:e})});break}case`copy-from`:{let{buffer:e}=n,r=Lu(e);postMessage({type:t,out:r});break}case`create`:{let{model:e,options:r}=n;Ru(e,r).then(e=>{postMessage({type:t,out:e})},e=>{postMessage({type:t,err:e})});break}case`release`:zu(n),postMessage({type:t});break;case`run`:{let{sessionId:e,inputIndices:r,inputs:i,outputIndices:a,options:o}=n;Vu(e,r,i,a,Array(a.length).fill(null),o).then(e=>{e.some(e=>e[3]!==`cpu`)?postMessage({type:t,err:`Proxy does not support non-cpu tensor location.`}):postMessage({type:t,out:e},Uu([...i,...e]))},e=>{postMessage({type:t,err:e})});break}case`end-profiling`:Hu(n),postMessage({type:t});break;default:}}catch(e){postMessage({type:t,err:e})}}),Ne=Me?null:e=>new Worker(e??Ue,{type:`module`,name:je})}),Fe={};s(Fe,{default:()=>Le});async function Ie(e={}){var t=e,n=!!globalThis.window,r=!!globalThis.WorkerGlobalScope,i=r&&self.name?.startsWith(`em-pthread`);t.mountExternalData=(e,n)=>{e.startsWith(`./`)&&(e=e.substring(2)),(t.Zc||=new Map).set(e,n)},t.unmountExternalData=()=>{delete t.Zc},globalThis.SharedArrayBuffer??new WebAssembly.Memory({initial:0,maximum:0,ae:!0}).buffer.constructor;let a=e=>async(...n)=>{try{if(t.$c)throw Error(`Session already started`);let r=t.$c={Nd:n[0],errors:[]},i=await e(...n);if(t.$c!==r)throw Error(`Session mismatch`);t.gd?.flush();let a=r.errors;if(0<a.length){let e=await Promise.all(a);if(e=e.filter(e=>e),0<e.length)throw Error(e.join(` | |
| `))}return i}finally{t.$c=null}};t.jsepInit=(e,n)=>{if(e===`webgpu`){[t.gd,t.Dd,t.Hd,t.jd,t.Gd,t.ac,t.Id,t.Kd,t.Ed,t.Fd,t.Jd]=n;let e=t.gd;t.jsepRegisterBuffer=(t,n,r,i)=>e.registerBuffer(t,n,r,i),t.jsepGetBuffer=t=>e.getBuffer(t),t.jsepCreateDownloader=(t,n,r)=>e.createDownloader(t,n,r),t.jsepOnCreateSession=t=>{e.onCreateSession(t)},t.jsepOnReleaseSession=t=>{e.onReleaseSession(t)},t.jsepOnRunStart=t=>e.onRunStart(t),t.Ld=(t,n)=>{e.upload(t,n)}}else if(e===`webnn`){let e=n[0];[t.Zd,t.vd,t.webnnEnsureTensor,t.wd,t.webnnDownloadTensor,t.Yd,t.webnnEnableTraceEvent]=n.slice(1),t.webnnReleaseTensorId=t.vd,t.webnnUploadTensor=t.wd,t.webnnRegisterMLContext=t.Yd,t.webnnOnRunStart=t=>e.onRunStart(t),t.webnnOnRunEnd=e.onRunEnd.bind(e),t.webnnOnReleaseSession=t=>{e.onReleaseSession(t)},t.webnnCreateMLTensorDownloader=(t,n)=>e.createMLTensorDownloader(t,n),t.webnnRegisterMLTensor=(t,n,r,i)=>e.registerMLTensor(t,n,r,i),t.webnnCreateMLContext=t=>e.createMLContext(t),t.webnnRegisterMLConstant=(n,r,i,a,o,s)=>e.registerMLConstant(n,r,i,a,o,t.Zc,s),t.webnnRegisterGraphInput=e.registerGraphInput.bind(e),t.webnnIsGraphInput=e.isGraphInput.bind(e),t.webnnRegisterGraphOutput=e.registerGraphOutput.bind(e),t.webnnIsGraphOutput=e.isGraphOutput.bind(e),t.webnnCreateTemporaryTensor=e.createTemporaryTensor.bind(e),t.webnnIsGraphInputOutputTypeSupported=e.isGraphInputOutputTypeSupported.bind(e)}};let o=()=>{let e=e=>(...t)=>{let n=Qt;return t=e(...t),Qt==n?t:new Promise((e,t)=>{on={resolve:e,reject:t}})};(()=>{for(let n of[`_OrtAppendExecutionProvider`,`_OrtCreateSession`,`_OrtRun`,`_OrtRunWithBinding`,`_OrtBindInput`])t[n]=e(t[n])})(),a!==void 0&&(t._OrtRun=a(t._OrtRun),t._OrtRunWithBinding=a(t._OrtRunWithBinding)),o=void 0};t.asyncInit=()=>{o?.()};var s,c,l=(e,t)=>{throw t},u=import.meta.url,d=``;if(n||r){try{d=new URL(`.`,u).href}catch{}r&&(c=e=>{var t=new XMLHttpRequest;return t.open(`GET`,e,!1),t.responseType=`arraybuffer`,t.send(null),new Uint8Array(t.response)}),s=async e=>{if(C(e))return new Promise((t,n)=>{var r=new XMLHttpRequest;r.open(`GET`,e,!0),r.responseType=`arraybuffer`,r.onload=()=>{r.status==200||r.status==0&&r.response?t(r.response):n(r.status)},r.onerror=n,r.send(null)});var t=await fetch(e,{credentials:`same-origin`});if(t.ok)return t.arrayBuffer();throw Error(t.status+` : `+t.url)}}var f,p,m,h,g,_,v=console.log.bind(console),y=console.error.bind(console),b=v,x=y,S=!1,C=e=>e.startsWith(`file://`);function w(){N.buffer!=T.buffer&&se()}if(i){let e=function(n){try{var r=n.data,i=r.Uc;if(i===`load`){let n=[];self.onmessage=e=>n.push(e),_=()=>{postMessage({Uc:`loaded`});for(let t of n)e(t);self.onmessage=e};for(let e of r.Ad)t[e]&&!t[e].proxy||(t[e]=(...t)=>{postMessage({Uc:`callHandler`,zd:e,args:t})},e==`print`&&(b=t[e]),e==`printErr`&&(x=t[e]));N=r.Vd,se(),p=r.Wd,de(),ia()}else if(i===`run`){(function(e){var t=(w(),A)[e+52>>>2>>>0];e=(w(),A)[e+56>>>2>>>0],Sr(t,t-e),Z(t)})(r.Tc),mr(r.Tc,0,0,1,0,0),Ee(),Ht(r.Tc),ee||=(ur(),!0);try{ke(r.Pd,r.dd)}catch(e){if(e!=`unwind`)throw e}}else r.target!==`setimmediate`&&(i===`checkMailbox`?ee&&Ut():i&&(x(`worker: received unknown command ${i}`),x(r)))}catch(e){throw hr(),e}};var ee=!1;self.onunhandledrejection=e=>{throw e.reason||e},self.onmessage=e}var T,E,D,O,k,A,te,ne,re,ie,ae,oe=!1;function se(){var e=N.buffer;t.HEAP8=T=new Int8Array(e),D=new Int16Array(e),t.HEAPU8=E=new Uint8Array(e),O=new Uint16Array(e),t.HEAP32=k=new Int32Array(e),t.HEAPU32=A=new Uint32Array(e),te=new Float32Array(e),ne=new Float64Array(e),re=new BigInt64Array(e),ie=new BigUint64Array(e)}function ce(){oe=!0,i?_():hi.tb()}function le(e){throw x(e=`Aborted(`+e+`)`),S=!0,e=new WebAssembly.RuntimeError(e+`. Build with -sASSERTIONS for more info.`),g?.(e),e}function ue(){return{a:{ma:yi,hb:vi,g:Me,J:Pe,f:ze,n:Be,h:Ve,ha:He,b:Ue,T:We,Ia:Ke,o:qe,_:Ze,Ya:Qe,Ea:$e,Ga:et,Za:tt,Wa:nt,Pa:rt,Va:it,ka:at,Fa:ot,Ca:st,Xa:F,Da:ct,cb:lt,ea:gt,xa:_t,va:Tt,da:Dt,O:Ot,H:L,wa:jt,Z:Rt,ya:zt,Sa:z,Aa:B,Ja:Gt,ta:Kt,fa:qt,Ra:Ht,$a:Jt,R:ln,r:gn,c:bt,ib:_n,y:vn,M:yn,D:bn,m:V,t:H,jb:xn,I:Sn,S:U,j:Cn,u:W,q:G,l:wn,Ma:Tn,Na:q,Oa:J,Ka:On,La:kn,ua:Y,eb:Mn,bb:Fn,v:Rn,aa:zn,ga:Bn,ab:Nn,V:Vn,_a:Hn,Ba:Un,F:jn,U:Wn,la:Yn,za:Xn,gb:Jn,fb:Zn,Ta:tr,Ua:nr,Ha:be,$:rr,ja:ir,Qa:ar,ia:sr,lb:na,na:Yi,mb:ta,oa:Ji,G:zi,d:Ci,s:xi,w:bi,B:Ni,pb:Gi,K:Ii,x:Ti,pa:Ki,X:Xi,ba:Wi,nb:ea,ob:$i,ra:Bi,qa:Ui,qb:Vi,N:Li,Y:qi,e:wi,A:Ei,k:Si,kb:ra,p:Oi,z:ki,C:Di,E:Ai,L:Pi,rb:Ri,Q:Zi,ca:Fi,W:Qi,sb:Mi,sa:ji,P:Hi,i:cr,a:N,db:ve}}}async function de(){function e(e,n){var r=hi=e.exports;e={};for(let[t,n]of Object.entries(r))typeof n==`function`?(r=Xt(n),e[t]=r):e[t]=n;return hi=e,hi=function(){var e=hi,t=e=>t=>e(t)>>>0,n=e=>()=>e()>>>0;return(e=Object.assign({},e)).ub=t(e.ub),e.Yb=n(e.Yb),e._b=t(e._b),e.mc=t(e.mc),e.nc=n(e.nc),e.rc=t(e.rc),e}(),Ce.push(hi.$b),lr=(e=hi).ub,ur=e.vb,t._OrtInit=e.wb,t._OrtGetLastError=e.xb,t._OrtCreateSessionOptions=e.yb,t._OrtAppendExecutionProvider=e.zb,t._OrtAddFreeDimensionOverride=e.Ab,t._OrtAddSessionConfigEntry=e.Bb,t._OrtReleaseSessionOptions=e.Cb,t._OrtCreateSession=e.Db,t._OrtReleaseSession=e.Eb,t._OrtGetInputOutputCount=e.Fb,t._OrtGetInputOutputMetadata=e.Gb,t._OrtFree=e.Hb,t._OrtCreateTensor=e.Ib,t._OrtGetTensorData=e.Jb,t._OrtReleaseTensor=e.Kb,t._OrtCreateRunOptions=e.Lb,t._OrtAddRunConfigEntry=e.Mb,t._OrtReleaseRunOptions=e.Nb,t._OrtCreateBinding=e.Ob,t._OrtBindInput=e.Pb,t._OrtBindOutput=e.Qb,t._OrtClearBoundOutputs=e.Rb,t._OrtReleaseBinding=e.Sb,t._OrtRunWithBinding=e.Tb,t._OrtRun=e.Ub,t._OrtEndProfiling=e.Vb,t._JsepOutput=e.Wb,t._JsepGetNodeName=e.Xb,dr=e.Yb,fr=t._free=e.Zb,pr=t._malloc=e._b,mr=e.bc,hr=e.cc,gr=e.dc,_r=e.ec,vr=e.fc,yr=e.gc,br=e.hc,X=e.ic,xr=e.jc,Sr=e.kc,Z=e.lc,Cr=e.mc,Q=e.nc,wr=e.oc,Tr=e.pc,Er=e.qc,Dr=e.rc,Or=e.sc,kr=e.tc,Ar=e.uc,jr=e.vc,Mr=e.wc,Nr=e.xc,Pr=e.yc,Fr=e.zc,Ir=e.Ac,Lr=e.Bc,Rr=e.Cc,zr=e.Dc,Br=e.Ec,Vr=e.Fc,Hr=e.Gc,Ur=e.Hc,Wr=e.Ic,Gr=e.Jc,Kr=e.Kc,qr=e.Lc,Jr=e.Mc,Yr=e.Nc,Xr=e.Oc,Zr=e.Pc,Qr=e.Rc,$r=e.Sc,ei=e.bd,ti=e.cd,ni=e.hd,ri=e.ld,$=e.md,ii=e.nd,ai=e.od,oi=e.pd,si=e.qd,ci=e.rd,li=e.sd,ui=e.xd,di=e.Rd,fi=e.Sd,pi=e.Td,mi=e.Ud,p=n,hi}var n,r=ue();return t.instantiateWasm?new Promise(n=>{t.instantiateWasm(r,(t,r)=>{n(e(t,r))})}):i?e(new WebAssembly.Instance(p,ue()),p):(ae??=t.locateFile?t.locateFile?t.locateFile(`ort-wasm-simd-threaded.jsep.wasm`,d):d+`ort-wasm-simd-threaded.jsep.wasm`:new URL(`/assets/ort-wasm-simd-threaded.jsep-Bzonanhp.wasm`,``+import.meta.url).href,n=await async function(e){var t=ae;if(!f&&!C(t))try{var n=fetch(t,{credentials:`same-origin`});return await WebAssembly.instantiateStreaming(n,e)}catch(e){x(`wasm streaming compile failed: ${e}`),x(`falling back to ArrayBuffer instantiation`)}return async function(e,t){try{var n=await async function(e){if(!f)try{var t=await s(e);return new Uint8Array(t)}catch{}if(e==ae&&f)e=new Uint8Array(f);else{if(!c)throw`both async and sync fetching of the wasm failed`;e=c(e)}return e}(e);return await WebAssembly.instantiate(n,t)}catch(e){x(`failed to asynchronously prepare wasm: ${e}`),le(e)}}(t,e)}(r),e(n.instance,n.module))}class fe{name=`ExitStatus`;constructor(e){this.message=`Program terminated with exit(${e})`,this.status=e}}var pe=e=>{e.terminate(),e.onmessage=()=>{}},me=[],he=0,ge=null,_e=e=>{xe.length==0&&(Oe(),De(xe[0]));var t=xe.pop();if(!t)return 6;Se.push(t),we[e.Tc]=t,t.Tc=e.Tc;var n={Uc:`run`,Pd:e.Od,dd:e.dd,Tc:e.Tc};return t.postMessage(n,e.ud),0},j=0,M=(e,t,...n)=>{var r,i=16*n.length,a=Q(),o=Cr(i),s=o>>>3;for(r of n)typeof r==`bigint`?((w(),re)[s++>>>0]=1n,(w(),re)[s++>>>0]=r):((w(),re)[s++>>>0]=0n,(w(),ne)[s++>>>0]=r);return e=gr(e,0,i,o,t),Z(a),e};function ve(e){if(i)return M(0,1,e);if(m=e,!(0<j)){for(var t of Se)pe(t);for(t of xe)pe(t);xe=[],Se=[],we={},S=!0}l(0,new fe(e))}function ye(e){if(i)return M(1,0,e);be(e)}var be=e=>{if(m=e,i)throw ye(e),`unwind`;ve(e)},xe=[],Se=[],Ce=[],we={},Te=e=>{var t=e.Tc;delete we[t],xe.push(e),Se.splice(Se.indexOf(e),1),e.Tc=0,_r(t)};function Ee(){Ce.forEach(e=>e())}var De=e=>new Promise(n=>{e.onmessage=r=>{var i=r.data;if(r=i.Uc,i.ad&&i.ad!=dr()){var a=we[i.ad];a?a.postMessage(i,i.ud):x(`Internal error! Worker sent a message "${r}" to target pthread ${i.ad}, but that thread no longer exists!`)}else r===`checkMailbox`?Ut():r===`spawnThread`?_e(i):r===`cleanupThread`?Bt(()=>{Te(we[i.Qd])}):r===`loaded`?(e.loaded=!0,n(e)):i.target===`setimmediate`?e.postMessage(i):r===`uncaughtException`?e.onerror(i.error):r===`callHandler`?t[i.zd](...i.args):r&&x(`worker sent an unknown command ${r}`)},e.onerror=e=>{throw x(`worker sent an error! ${e.filename}:${e.lineno}: ${e.message}`),e};var r,i=[];for(r of[])t.propertyIsEnumerable(r)&&i.push(r);e.postMessage({Uc:`load`,Ad:i,Vd:N,Wd:p})});function Oe(){var e=new Worker((()=>{let e=URL;return import.meta.url>`file:`&&import.meta.url<`file;`?new e(`ort.bundle.min.mjs`,import.meta.url):new URL(import.meta.url)})(),{type:`module`,workerData:`em-pthread`,name:`em-pthread`});xe.push(e)}var N,ke=(e,t)=>{j=0,e=kr(e,t),0<j?m=e:vr(e)},Ae=[],je=0;function Me(e){var t=new Le(e>>>=0);return(w(),T)[t.Vc+12>>>0]==0&&(Fe(t,!0),je--),Ie(t,!1),Ae.push(t),Dr(e)}var Ne=0,Pe=()=>{X(0,0);var e=Ae.pop();wr(e.ed),Ne=0};function Fe(e,t){t=t?1:0,(w(),T)[e.Vc+12>>>0]=t}function Ie(e,t){t=t?1:0,(w(),T)[e.Vc+13>>>0]=t}class Le{constructor(e){this.ed=e,this.Vc=e-24}}var Re=e=>{var t=Ne;if(!t)return xr(0),0;var n=new Le(t);(w(),A)[n.Vc+16>>>2>>>0]=t;var r=(w(),A)[n.Vc+4>>>2>>>0];if(!r)return xr(0),t;for(var i of e){if(i===0||i===r)break;if(Er(i,r,n.Vc+16))return xr(i),t}return xr(r),t};function ze(){return Re([])}function Be(e){return Re([e>>>0])}function Ve(e,t,n,r){return Re([e>>>0,t>>>0,n>>>0,r>>>0])}var He=()=>{var e=Ae.pop();e||le(`no exception to throw`);var t=e.ed;throw(w(),T)[e.Vc+13>>>0]==0&&(Ae.push(e),Ie(e,!0),Fe(e,!1),je++),Tr(t),Ne=t};function Ue(e,t,n){var r=new Le(e>>>=0);throw t>>>=0,n>>>=0,(w(),A)[r.Vc+16>>>2>>>0]=0,(w(),A)[r.Vc+4>>>2>>>0]=t,(w(),A)[r.Vc+8>>>2>>>0]=n,Tr(e),je++,Ne=e}var We=()=>je;function Ge(e,t,n,r){return i?M(2,1,e,t,n,r):Ke(e,t,n,r)}function Ke(e,t,n,r){if(e>>>=0,t>>>=0,n>>>=0,r>>>=0,!globalThis.SharedArrayBuffer)return 6;var a=[];return i&&a.length===0?Ge(e,t,n,r):(e={Od:n,Tc:e,dd:r,ud:a},i?(e.Uc=`spawnThread`,postMessage(e,a),0):_e(e))}function qe(e){throw Ne||=e>>>0,Ne}var Je=globalThis.TextDecoder&&new TextDecoder,Ye=(e,t,n,r)=>{if(n=t+n,r)return n;for(;e[t]&&!(t>=n);)++t;return t},Xe=(e,t=0,n,r)=>{if(16<(n=Ye(e,t>>>=0,n,r))-t&&e.buffer&&Je)return Je.decode(e.buffer instanceof ArrayBuffer?e.subarray(t,n):e.slice(t,n));for(r=``;t<n;){var i=e[t++];if(128&i){var a=63&e[t++];if((224&i)==192)r+=String.fromCharCode((31&i)<<6|a);else{var o=63&e[t++];65536>(i=(240&i)==224?(15&i)<<12|a<<6|o:(7&i)<<18|a<<12|o<<6|63&e[t++])?r+=String.fromCharCode(i):(i-=65536,r+=String.fromCharCode(55296|i>>10,56320|1023&i))}}else r+=String.fromCharCode(i)}return r},P=(e,t,n)=>(e>>>=0)?Xe((w(),E),e,t,n):``;function Ze(e,t,n){return i?M(3,1,e,t,n):0}function Qe(e,t){if(i)return M(4,1,e,t)}function $e(e,t){if(i)return M(5,1,e,t)}function et(e,t,n){if(i)return M(6,1,e,t,n)}function tt(e,t,n){return i?M(7,1,e,t,n):0}function nt(e,t){if(i)return M(8,1,e,t)}function rt(e,t,n){if(i)return M(9,1,e,t,n)}function it(e,t,n,r){if(i)return M(10,1,e,t,n,r)}function at(e,t,n,r){if(i)return M(11,1,e,t,n,r)}function ot(e,t,n,r){if(i)return M(12,1,e,t,n,r)}function st(e){if(i)return M(13,1,e)}function F(e,t){if(i)return M(14,1,e,t)}function ct(e,t,n){if(i)return M(15,1,e,t,n)}var lt=()=>le(``),ut=e=>{e>>>=0;for(var t=``;;){var n=(w(),E)[e++>>>0];if(!n)return t;t+=String.fromCharCode(n)}},I={},dt={},ft={},pt=class extends Error{constructor(e){super(e),this.name=`BindingError`}};function mt(e,t,n={}){return function(e,t,n={}){var r=t.name;if(!e)throw new pt(`type "${r}" must have a positive integer typeid pointer`);if(dt.hasOwnProperty(e)){if(n.Bd)return;throw new pt(`Cannot register type '${r}' twice`)}dt[e]=t,delete ft[e],I.hasOwnProperty(e)&&(t=I[e],delete I[e],t.forEach(e=>e()))}(e,t,n)}var ht=(e,t,n)=>{switch(t){case 1:return n?e=>(w(),T)[e>>>0]:e=>(w(),E)[e>>>0];case 2:return n?e=>(w(),D)[e>>>1>>>0]:e=>(w(),O)[e>>>1>>>0];case 4:return n?e=>(w(),k)[e>>>2>>>0]:e=>(w(),A)[e>>>2>>>0];case 8:return n?e=>(w(),re)[e>>>3>>>0]:e=>(w(),ie)[e>>>3>>>0];default:throw TypeError(`invalid integer width (${t}): ${e}`)}};function gt(e,t,n,r,i){e>>>=0,n>>>=0,t=ut(t>>>0);let a=e=>e;if(r=r===0n){let e=8*n;a=t=>BigInt.asUintN(e,t),i=a(i)}mt(e,{name:t,Qc:a,Xc:(e,t)=>(typeof t==`number`&&(t=BigInt(t)),t),Wc:ht(t,n,!r),Yc:null})}function _t(e,t,n,r){mt(e>>>=0,{name:t=ut(t>>>0),Qc:function(e){return!!e},Xc:function(e,t){return t?n:r},Wc:function(e){return this.Qc((w(),E)[e>>>0])},Yc:null})}var vt=[],yt=[0,1,,1,null,1,!0,1,!1,1];function bt(e){9<(e>>>=0)&&--yt[e+1]==0&&(yt[e]=void 0,vt.push(e))}var xt=e=>{if(!e)throw new pt(`Cannot use deleted val. handle = ${e}`);return yt[e]},St=e=>{switch(e){case void 0:return 2;case null:return 4;case!0:return 6;case!1:return 8;default:let t=vt.pop()||yt.length;return yt[t]=e,yt[t+1]=1,t}};function Ct(e){return this.Qc((w(),A)[e>>>2>>>0])}var wt={name:`emscripten::val`,Qc:e=>{var t=xt(e);return bt(e),t},Xc:(e,t)=>St(t),Wc:Ct,Yc:null};function Tt(e){return mt(e>>>0,wt)}var Et=(e,t)=>{switch(t){case 4:return function(e){return this.Qc((w(),te)[e>>>2>>>0])};case 8:return function(e){return this.Qc((w(),ne)[e>>>3>>>0])};default:throw TypeError(`invalid float width (${t}): ${e}`)}};function Dt(e,t,n){n>>>=0,mt(e>>>=0,{name:t=ut(t>>>0),Qc:e=>e,Xc:(e,t)=>t,Wc:Et(t,n),Yc:null})}function Ot(e,t,n,r,i){e>>>=0,n>>>=0,t=ut(t>>>0);let a=e=>e;if(r===0){var o=32-8*n;a=e=>e<<o>>>o,i=a(i)}mt(e,{name:t,Qc:a,Xc:(e,t)=>t,Wc:ht(t,n,r!==0),Yc:null})}function L(e,t,n){function r(e){var t=(w(),A)[e>>>2>>>0];return e=(w(),A)[e+4>>>2>>>0],new i((w(),T).buffer,e,t)}var i=[Int8Array,Uint8Array,Int16Array,Uint16Array,Int32Array,Uint32Array,Float32Array,Float64Array,BigInt64Array,BigUint64Array][t];mt(e>>>=0,{name:n=ut(n>>>0),Qc:r,Wc:r},{Bd:!0})}var kt=(e,t,n)=>{var r=(w(),E);if(t>>>=0,0<n){var i=t;n=t+n-1;for(var a=0;a<e.length;++a){var o=e.codePointAt(a);if(127>=o){if(t>=n)break;r[t++>>>0]=o}else if(2047>=o){if(t+1>=n)break;r[t++>>>0]=192|o>>6,r[t++>>>0]=128|63&o}else if(65535>=o){if(t+2>=n)break;r[t++>>>0]=224|o>>12,r[t++>>>0]=128|o>>6&63,r[t++>>>0]=128|63&o}else{if(t+3>=n)break;r[t++>>>0]=240|o>>18,r[t++>>>0]=128|o>>12&63,r[t++>>>0]=128|o>>6&63,r[t++>>>0]=128|63&o,a++}}r[t>>>0]=0,e=t-i}else e=0;return e},At=e=>{for(var t=0,n=0;n<e.length;++n){var r=e.charCodeAt(n);127>=r?t++:2047>=r?t+=2:55296<=r&&57343>=r?(t+=4,++n):t+=3}return t};function jt(e,t){mt(e>>>=0,{name:t=ut(t>>>0),Qc(e){var t=(w(),A)[e>>>2>>>0];return t=P(e+4,t,!0),fr(e),t},Xc(e,t){t instanceof ArrayBuffer&&(t=new Uint8Array(t));var n=typeof t==`string`;if(!(n||ArrayBuffer.isView(t)&&t.BYTES_PER_ELEMENT==1))throw new pt(`Cannot pass non-string to std::string`);var r=n?At(t):t.length,i=pr(4+r+1),a=i+4;return(w(),A)[i>>>2>>>0]=r,n?kt(t,a,r+1):(w(),E).set(t,a>>>0),e!==null&&e.push(fr,i),i},Wc:Ct,Yc(e){fr(e)}})}var Mt=globalThis.TextDecoder?new TextDecoder(`utf-16le`):void 0,Nt=(e,t,n)=>{if(e>>>=1,16<(t=Ye((w(),O),e,t/2,n))-e&&Mt)return Mt.decode((w(),O).slice(e,t));for(n=``;e<t;++e){var r=(w(),O)[e>>>0];n+=String.fromCharCode(r)}return n},Pt=(e,t,n)=>{if(n??=2147483647,2>n)return 0;var r=t;n=(n-=2)<2*e.length?n/2:e.length;for(var i=0;i<n;++i){var a=e.charCodeAt(i);(w(),D)[t>>>1>>>0]=a,t+=2}return(w(),D)[t>>>1>>>0]=0,t-r},Ft=e=>2*e.length,It=(e,t,n)=>{var r=``;e>>>=2;for(var i=0;!(i>=t/4);i++){var a=(w(),A)[e+i>>>0];if(!a&&!n)break;r+=String.fromCodePoint(a)}return r},R=(e,t,n)=>{if(t>>>=0,n??=2147483647,4>n)return 0;var r=t;n=r+n-4;for(var i=0;i<e.length;++i){var a=e.codePointAt(i);if(65535<a&&i++,(w(),k)[t>>>2>>>0]=a,(t+=4)+4>n)break}return(w(),k)[t>>>2>>>0]=0,t-r},Lt=e=>{for(var t=0,n=0;n<e.length;++n)65535<e.codePointAt(n)&&n++,t+=4;return t};function Rt(e,t,n){if(e>>>=0,t>>>=0,n=ut(n>>>=0),t===2)var r=Nt,i=Pt,a=Ft;else r=It,i=R,a=Lt;mt(e,{name:n,Qc:e=>{var n=(w(),A)[e>>>2>>>0];return n=r(e+4,n*t,!0),fr(e),n},Xc:(e,r)=>{if(typeof r!=`string`)throw new pt(`Cannot pass non-string to C++ string type ${n}`);var o=a(r),s=pr(4+o+t);return(w(),A)[s>>>2>>>0]=o/t,i(r,s+4,o+t),e!==null&&e.push(fr,s),s},Wc:Ct,Yc(e){fr(e)}})}function zt(e,t){mt(e>>>=0,{Cd:!0,name:t=ut(t>>>0),Qc:()=>{},Xc:()=>{}})}function z(e){mr(e>>>0,!r,1,!n,131072,!1),Ee()}var Bt=e=>{if(!S)try{if(e(),!(0<j))try{i?dr()&&vr(m):be(m)}catch(e){e instanceof fe||e==`unwind`||l(0,e)}}catch(e){e instanceof fe||e==`unwind`||l(0,e)}},Vt=!Atomics.waitAsync||globalThis.navigator?.userAgent&&91>Number((navigator.userAgent.match(/Chrom(e|ium)\/([0-9]+)\./)||[])[2]);function Ht(e){e>>>=0,Vt||(Atomics.waitAsync((w(),k),e>>>2,e).value.then(Ut),e+=128,Atomics.store((w(),k),e>>>2,1))}var Ut=()=>Bt(()=>{var e=dr();e&&(Ht(e),br())});function B(e,t){(e>>>=0)==t>>>0?setTimeout(Ut):i?postMessage({ad:e,Uc:`checkMailbox`}):(e=we[e])&&e.postMessage({Uc:`checkMailbox`})}var Wt=[];function Gt(e,t,n,r,i){for(t>>>=0,i>>>=0,Wt.length=0,n=i>>>3,r=i+r>>>3;n<r;){var a=(w(),re)[n++>>>0]?(w(),re)[n++>>>0]:(w(),ne)[n++>>>0];Wt.push(a)}return(t?_i[t]:gi[e])(...Wt)}var Kt=()=>{j=0};function qt(e){e>>>=0,i?postMessage({Uc:`cleanupThread`,Qd:e}):Te(we[e])}function Jt(e){}var Yt=e=>{try{e()}catch(e){le(e)}};function Xt(e){var t=(...t)=>{en.push(e);try{return e(...t)}finally{S||(en.pop(),Qt&&Zt===1&&en.length===0&&(Zt=0,j+=1,Yt(fi),typeof Fibers<`u`&&Fibers.ce()))}};return rn.set(e,t),t}var Zt=0,Qt=null,$t=0,en=[],tn=new Map,nn=new Map,rn=new Map,an=0,on=null,sn=[],cn=e=>function(e){if(!S){if(Zt===0){var t=!1,n=!1;e((e=0)=>{if(!S&&($t=e,t=!0,n)){Zt=2,Yt(()=>pi(Qt)),typeof MainLoop<`u`&&MainLoop.yd&&MainLoop.resume(),e=!1;try{var r=function(){var e=(w(),k)[Qt+8>>>2>>>0];return e=nn.get(e),e=rn.get(e),--j,e()}()}catch(t){r=t,e=!0}var i=!1;if(!Qt){var a=on;a&&(on=null,(e?a.reject:a.resolve)(r),i=!0)}if(e&&!i)throw r}}),n=!0,t||(Zt=1,Qt=function(){var e=pr(65548),t=e+12;if((w(),A)[e>>>2>>>0]=t,(w(),A)[e+4>>>2>>>0]=t+65536,t=en[0],!tn.has(t)){var n=an++;tn.set(t,n),nn.set(n,t)}return t=tn.get(t),(w(),k)[e+8>>>2>>>0]=t,e}(),typeof MainLoop<`u`&&MainLoop.yd&&MainLoop.pause(),Yt(()=>di(Qt)))}else Zt===2?(Zt=0,Yt(mi),fr(Qt),Qt=null,sn.forEach(Bt)):le(`invalid state: ${Zt}`);return $t}}(t=>{e().then(t)});function ln(e){return e>>>=0,cn(async()=>St(await xt(e)))}var un=[],dn=e=>{var t=un.length;return un.push(e),t},fn=(e,t)=>{for(var n=Array(e),r=0;r<e;++r){var i=r,a=(w(),A)[t+4*r>>>2>>>0],o=dt[a];if(o===void 0)throw e=`parameter ${r}`,a=lr(a),t=ut(a),fr(a),new pt(`${e} has unknown type ${t}`);n[i]=o}return n},pn=(e,t,n)=>{var r=[];return e=e(r,n),r.length&&((w(),A)[t>>>2>>>0]=St(r)),e},mn={},hn=e=>{var t=mn[e];return t===void 0?ut(e):t};function gn(e,t,n){var[r,...i]=fn(e,t>>>0);t=r.Xc.bind(r);var a=i.map(e=>e.Wc.bind(e));e--;var o={toValue:xt};switch(e=a.map((e,t)=>{var n=`argFromPtr${t}`;return o[n]=e,`${n}(args${t?`+`+8*t:``})`}),n){case 0:var s=`toValue(handle)`;break;case 2:s=`new (toValue(handle))`;break;case 3:s=``;break;case 1:o.getStringOrSymbol=hn,s=`toValue(handle)[getStringOrSymbol(methodName)]`}return s+=`(${e})`,r.Cd||(o.toReturnWire=t,o.emval_returnValue=pn,s=`return emval_returnValue(toReturnWire, destructorsRef, ${s})`),s=`return function (handle, methodName, destructorsRef, args) { | |
| ${s} | |
| }`,n=Function(Object.keys(o),s)(...Object.values(o)),s=`methodCaller<(${i.map(e=>e.name)}) => ${r.name}>`,dn(Object.defineProperty(n,`name`,{value:s}))}function _n(e,t){return t>>>=0,(e=xt(e>>>0))==xt(t)}function vn(e){return(e>>>=0)?(e=hn(e),St(globalThis[e])):St(globalThis)}function yn(e){return e=hn(e>>>0),St(t[e])}function bn(e,t){return t>>>=0,e=xt(e>>>0),t=xt(t),St(e[t])}function V(e){9<(e>>>=0)&&(yt[e+1]+=1)}function H(e,t,n,r,i){return un[e>>>0](t>>>0,n>>>0,r>>>0,i>>>0)}function xn(e,t,n,r,i){return H(e>>>0,t>>>0,n>>>0,r>>>0,i>>>0)}function Sn(){return St([])}function U(e){e=xt(e>>>0);for(var t=Array(e.length),n=0;n<e.length;n++)t[n]=e[n];return St(t)}function Cn(e){return St(hn(e>>>0))}function W(){return St({})}function G(e){for(var t=xt(e>>>=0);t.length;){var n=t.pop();t.pop()(n)}bt(e)}function wn(e,t,n){t>>>=0,n>>>=0,e=xt(e>>>0),t=xt(t),n=xt(n),e[t]=n}function Tn(e,t){e=-9007199254740992>e||9007199254740992<e?NaN:Number(e),t>>>=0,e=new Date(1e3*e),(w(),k)[t>>>2>>>0]=e.getUTCSeconds(),(w(),k)[t+4>>>2>>>0]=e.getUTCMinutes(),(w(),k)[t+8>>>2>>>0]=e.getUTCHours(),(w(),k)[t+12>>>2>>>0]=e.getUTCDate(),(w(),k)[t+16>>>2>>>0]=e.getUTCMonth(),(w(),k)[t+20>>>2>>>0]=e.getUTCFullYear()-1900,(w(),k)[t+24>>>2>>>0]=e.getUTCDay(),e=(e.getTime()-Date.UTC(e.getUTCFullYear(),0,1,0,0,0,0))/864e5|0,(w(),k)[t+28>>>2>>>0]=e}var En=e=>e%4==0&&(e%100!=0||e%400==0),K=[0,31,60,91,121,152,182,213,244,274,305,335],Dn=[0,31,59,90,120,151,181,212,243,273,304,334];function q(e,t){e=-9007199254740992>e||9007199254740992<e?NaN:Number(e),t>>>=0,e=new Date(1e3*e),(w(),k)[t>>>2>>>0]=e.getSeconds(),(w(),k)[t+4>>>2>>>0]=e.getMinutes(),(w(),k)[t+8>>>2>>>0]=e.getHours(),(w(),k)[t+12>>>2>>>0]=e.getDate(),(w(),k)[t+16>>>2>>>0]=e.getMonth(),(w(),k)[t+20>>>2>>>0]=e.getFullYear()-1900,(w(),k)[t+24>>>2>>>0]=e.getDay();var n=(En(e.getFullYear())?K:Dn)[e.getMonth()]+e.getDate()-1|0;(w(),k)[t+28>>>2>>>0]=n,(w(),k)[t+36>>>2>>>0]=-60*e.getTimezoneOffset(),n=new Date(e.getFullYear(),6,1).getTimezoneOffset();var r=new Date(e.getFullYear(),0,1).getTimezoneOffset();e=0|(n!=r&&e.getTimezoneOffset()==Math.min(r,n)),(w(),k)[t+32>>>2>>>0]=e}function J(e){e>>>=0;var t=new Date((w(),k)[e+20>>>2>>>0]+1900,(w(),k)[e+16>>>2>>>0],(w(),k)[e+12>>>2>>>0],(w(),k)[e+8>>>2>>>0],(w(),k)[e+4>>>2>>>0],(w(),k)[e>>>2>>>0],0),n=(w(),k)[e+32>>>2>>>0],r=t.getTimezoneOffset(),i=new Date(t.getFullYear(),6,1).getTimezoneOffset(),a=new Date(t.getFullYear(),0,1).getTimezoneOffset(),o=Math.min(a,i);return 0>n?(w(),k)[e+32>>>2>>>0]=+(i!=a&&o==r):0<n!=(o==r)&&(i=Math.max(a,i),t.setTime(t.getTime()+6e4*((0<n?o:i)-r))),(w(),k)[e+24>>>2>>>0]=t.getDay(),n=(En(t.getFullYear())?K:Dn)[t.getMonth()]+t.getDate()-1|0,(w(),k)[e+28>>>2>>>0]=n,(w(),k)[e>>>2>>>0]=t.getSeconds(),(w(),k)[e+4>>>2>>>0]=t.getMinutes(),(w(),k)[e+8>>>2>>>0]=t.getHours(),(w(),k)[e+12>>>2>>>0]=t.getDate(),(w(),k)[e+16>>>2>>>0]=t.getMonth(),(w(),k)[e+20>>>2>>>0]=t.getYear(),e=t.getTime(),BigInt(isNaN(e)?-1:e/1e3)}function On(e,t,n,r,a,o,s){return i?M(16,1,e,t,n,r,a,o,s):-52}function kn(e,t,n,r,a,o){if(i)return M(17,1,e,t,n,r,a,o)}var An={},jn=()=>performance.timeOrigin+performance.now();function Y(e,t){return i?M(18,1,e,t):(An[e]&&(clearTimeout(An[e].id),delete An[e]),t&&(An[e]={id:setTimeout(()=>{delete An[e],Bt(()=>yr(e,performance.timeOrigin+performance.now()))},t),be:t}),0)}function Mn(e,t,n,r){e>>>=0,t>>>=0,n>>>=0,r>>>=0;var i=new Date().getFullYear(),a=new Date(i,0,1).getTimezoneOffset();i=new Date(i,6,1).getTimezoneOffset();var o=Math.max(a,i);(w(),A)[e>>>2>>>0]=60*o,(w(),k)[t>>>2>>>0]=+(a!=i),e=(t=e=>{var t=Math.abs(e);return`UTC${0<=e?`-`:`+`}${String(Math.floor(t/60)).padStart(2,`0`)}${String(t%60).padStart(2,`0`)}`})(a),t=t(i),i<a?(kt(e,n,17),kt(t,r,17)):(kt(e,r,17),kt(t,n,17))}var Nn=()=>Date.now(),Pn=1;function Fn(e,t,n){if(n>>>=0,!(0<=e&&3>=e))return 28;if(e===0)e=Date.now();else{if(!Pn)return 52;e=performance.timeOrigin+performance.now()}return e=Math.round(1e6*e),(w(),re)[n>>>3>>>0]=BigInt(e),0}var In=[],Ln=(e,t)=>{In.length=0;for(var n;n=(w(),E)[e++>>>0];){var r=n!=105;t+=(r&=n!=112)&&t%8?4:0,In.push(n==112?(w(),A)[t>>>2>>>0]:n==106?(w(),re)[t>>>3>>>0]:n==105?(w(),k)[t>>>2>>>0]:(w(),ne)[t>>>3>>>0]),t+=r?8:4}return In};function Rn(e,t,n){return e>>>=0,t=Ln(t>>>0,n>>>0),_i[e](...t)}function zn(e,t,n){return e>>>=0,t=Ln(t>>>0,n>>>0),_i[e](...t)}var Bn=()=>{};function Vn(e,t){return x(P(e>>>0,t>>>0))}var Hn=()=>{throw j+=1,`unwind`};function Un(){return 4294901760}var Wn=()=>navigator.hardwareConcurrency,Gn={},Kn=e=>{var t;return(t=/\bwasm-function\[\d+\]:(0x[0-9a-f]+)/.exec(e))?+t[1]:(t=/:(\d+):\d+(?:\)|$)/.exec(e))?2147483648|t[1]:0},qn=e=>{for(var t of e)(e=Kn(t))&&(Gn[e]=t)};function Jn(){var e=Error().stack.toString().split(` | |
| `);return e[0]==`Error`&&e.shift(),qn(e),Gn.kd=Kn(e[3]),Gn.Md=e,Gn.kd}function Yn(e){if(!(e=Gn[e>>>0]))return 0;var t;if(t=/^\s+at .*\.wasm\.(.*) \(.*\)$/.exec(e))e=t[1];else if(t=/^\s+at (.*) \(.*\)$/.exec(e))e=t[1];else{if(!(t=/^(.+?)@/.exec(e)))return 0;e=t[1]}fr(Yn.td??0),t=At(e)+1;var n=pr(t);return n&&kt(e,n,t),Yn.td=n,Yn.td}function Xn(e){e>>>=0;var t=(w(),E).length;if(e<=t||4294901760<e)return!1;for(var n=1;4>=n;n*=2){var r=t*(1+.2/n);r=Math.min(r,e+100663296);e:{r=(Math.min(4294901760,65536*Math.ceil(Math.max(e,r)/65536))-N.buffer.byteLength+65535)/65536|0;try{N.grow(r),se();var i=1;break e}catch{}i=void 0}if(i)return!0}return!1}function Zn(e,t,n){if(e>>>=0,t>>>=0,Gn.kd==e)var r=Gn.Md;else (r=Error().stack.toString().split(` | |
| `))[0]==`Error`&&r.shift(),qn(r);for(var i=3;r[i]&&Kn(r[i])!=e;)++i;for(e=0;e<n&&r[e+i];++e)(w(),k)[t+4*e>>>2>>>0]=Kn(r[e+i]);return e}var Qn,$n={},er=()=>{if(!Qn){var e,t={USER:`web_user`,LOGNAME:`web_user`,PATH:`/`,PWD:`/`,HOME:`/home/web_user`,LANG:(globalThis.navigator?.language??`C`).replace(`-`,`_`)+`.UTF-8`,_:`./this.program`};for(e in $n)$n[e]===void 0?delete t[e]:t[e]=$n[e];var n=[];for(e in t)n.push(`${e}=${t[e]}`);Qn=n}return Qn};function tr(e,t){if(i)return M(19,1,e,t);e>>>=0,t>>>=0;var n,r=0,a=0;for(n of er()){var o=t+r;(w(),A)[e+a>>>2>>>0]=o,r+=kt(n,o,1/0)+1,a+=4}return 0}function nr(e,t){if(i)return M(20,1,e,t);e>>>=0,t>>>=0;var n=er();for(var r of((w(),A)[e>>>2>>>0]=n.length,e=0,n))e+=At(r)+1;return(w(),A)[t>>>2>>>0]=e,0}function rr(e){return i?M(21,1,e):52}function ir(e,t,n,r){return i?M(22,1,e,t,n,r):52}function ar(e,t,n,r){return i?M(23,1,e,t,n,r):70}var or=[null,[],[]];function sr(e,t,n,r){if(i)return M(24,1,e,t,n,r);t>>>=0,n>>>=0,r>>>=0;for(var a=0,o=0;o<n;o++){var s=(w(),A)[t>>>2>>>0],c=(w(),A)[t+4>>>2>>>0];t+=8;for(var l=0;l<c;l++){var u=e,d=(w(),E)[s+l>>>0],f=or[u];d===0||d===10?((u===1?b:x)(Xe(f)),f.length=0):f.push(d)}a+=c}return(w(),A)[r>>>2>>>0]=a,0}function cr(e){return e>>>0}i||function(){for(var e=t.numThreads-1;e--;)Oe();me.push(async()=>{var e=async function(){if(!i)return Promise.all(xe.map(De))}();he++,await e,--he==0&&ge&&(e=ge,ge=null,e())})}(),i||(N=new WebAssembly.Memory({initial:256,maximum:65536,shared:!0}),se()),t.wasmBinary&&(f=t.wasmBinary),t.stackSave=()=>Q(),t.stackRestore=e=>Z(e),t.stackAlloc=e=>Cr(e),t.setValue=function(e,t,n=`i8`){switch(n.endsWith(`*`)&&(n=`*`),n){case`i1`:case`i8`:(w(),T)[e>>>0]=t;break;case`i16`:(w(),D)[e>>>1>>>0]=t;break;case`i32`:(w(),k)[e>>>2>>>0]=t;break;case`i64`:(w(),re)[e>>>3>>>0]=BigInt(t);break;case`float`:(w(),te)[e>>>2>>>0]=t;break;case`double`:(w(),ne)[e>>>3>>>0]=t;break;case`*`:(w(),A)[e>>>2>>>0]=t;break;default:le(`invalid type for setValue: ${n}`)}},t.getValue=function(e,t=`i8`){switch(t.endsWith(`*`)&&(t=`*`),t){case`i1`:case`i8`:return(w(),T)[e>>>0];case`i16`:return(w(),D)[e>>>1>>>0];case`i32`:return(w(),k)[e>>>2>>>0];case`i64`:return(w(),re)[e>>>3>>>0];case`float`:return(w(),te)[e>>>2>>>0];case`double`:return(w(),ne)[e>>>3>>>0];case`*`:return(w(),A)[e>>>2>>>0];default:le(`invalid type for getValue: ${t}`)}},t.UTF8ToString=P,t.stringToUTF8=kt,t.lengthBytesUTF8=At;var lr,ur,dr,fr,pr,mr,hr,gr,_r,vr,yr,br,X,xr,Sr,Z,Cr,Q,wr,Tr,Er,Dr,Or,kr,Ar,jr,Mr,Nr,Pr,Fr,Ir,Lr,Rr,zr,Br,Vr,Hr,Ur,Wr,Gr,Kr,qr,Jr,Yr,Xr,Zr,Qr,$r,ei,ti,ni,ri,$,ii,ai,oi,si,ci,li,ui,di,fi,pi,mi,hi,gi=[ve,ye,Ge,Ze,Qe,$e,et,tt,nt,rt,it,at,ot,st,F,ct,On,kn,Y,tr,nr,rr,ir,ar,sr],_i={937996:(e,n,r,i,a)=>{if(t===void 0||!t.Zc)return 1;if((e=P(Number(e>>>0))).startsWith(`./`)&&(e=e.substring(2)),!(e=t.Zc.get(e)))return 2;if(n=Number(n>>>0),r=Number(r>>>0),i=Number(i>>>0),n+r>e.byteLength)return 3;try{let o=e.subarray(n,n+r);switch(a){case 0:(w(),E).set(o,i>>>0);break;case 1:t.Xd?t.Xd(i,o):t.Ld(i,o);break;default:return 4}return 0}catch{return 4}},938820:(e,n,r)=>{t.wd(e,(w(),E).subarray(n>>>0,n+r>>>0))},938884:()=>t.Zd(),938926:e=>{t.vd(e)},938963:()=>{t.Ed()},938994:()=>{t.Fd()},939023:()=>{t.Jd()},939048:e=>t.Dd(e),939081:e=>t.Hd(e),939113:(e,n,r)=>{t.jd(Number(e),Number(n),Number(r),!0)},939176:(e,n,r)=>{t.jd(Number(e),Number(n),Number(r))},939233:()=>typeof wasmOffsetConverter<`u`,939290:e=>{t.ac(`Abs`,e,void 0)},939341:e=>{t.ac(`Neg`,e,void 0)},939392:e=>{t.ac(`Floor`,e,void 0)},939445:e=>{t.ac(`Ceil`,e,void 0)},939497:e=>{t.ac(`Reciprocal`,e,void 0)},939555:e=>{t.ac(`Sqrt`,e,void 0)},939607:e=>{t.ac(`Exp`,e,void 0)},939658:e=>{t.ac(`Erf`,e,void 0)},939709:e=>{t.ac(`Sigmoid`,e,void 0)},939764:(e,n,r)=>{t.ac(`HardSigmoid`,e,{alpha:n,beta:r})},939843:e=>{t.ac(`Log`,e,void 0)},939894:e=>{t.ac(`Sin`,e,void 0)},939945:e=>{t.ac(`Cos`,e,void 0)},939996:e=>{t.ac(`Tan`,e,void 0)},940047:e=>{t.ac(`Asin`,e,void 0)},940099:e=>{t.ac(`Acos`,e,void 0)},940151:e=>{t.ac(`Atan`,e,void 0)},940203:e=>{t.ac(`Sinh`,e,void 0)},940255:e=>{t.ac(`Cosh`,e,void 0)},940307:e=>{t.ac(`Asinh`,e,void 0)},940360:e=>{t.ac(`Acosh`,e,void 0)},940413:e=>{t.ac(`Atanh`,e,void 0)},940466:e=>{t.ac(`Tanh`,e,void 0)},940518:e=>{t.ac(`Not`,e,void 0)},940569:(e,n,r)=>{t.ac(`Clip`,e,{min:n,max:r})},940638:e=>{t.ac(`Clip`,e,void 0)},940690:(e,n)=>{t.ac(`Elu`,e,{alpha:n})},940748:e=>{t.ac(`Gelu`,e,void 0)},940800:e=>{t.ac(`Relu`,e,void 0)},940852:(e,n)=>{t.ac(`LeakyRelu`,e,{alpha:n})},940916:(e,n)=>{t.ac(`ThresholdedRelu`,e,{alpha:n})},940986:(e,n)=>{t.ac(`Cast`,e,{to:n})},941044:e=>{t.ac(`Add`,e,void 0)},941095:e=>{t.ac(`Sub`,e,void 0)},941146:e=>{t.ac(`Mul`,e,void 0)},941197:e=>{t.ac(`Div`,e,void 0)},941248:e=>{t.ac(`Pow`,e,void 0)},941299:e=>{t.ac(`Equal`,e,void 0)},941352:e=>{t.ac(`Greater`,e,void 0)},941407:e=>{t.ac(`GreaterOrEqual`,e,void 0)},941469:e=>{t.ac(`Less`,e,void 0)},941521:e=>{t.ac(`LessOrEqual`,e,void 0)},941580:(e,n,r,i,a)=>{t.ac(`ReduceMean`,e,{keepDims:!!n,noopWithEmptyAxes:!!r,axes:i?Array.from((w(),k).subarray(Number(i)>>>0,Number(a)>>>0)):[]})},941755:(e,n,r,i,a)=>{t.ac(`ReduceMax`,e,{keepDims:!!n,noopWithEmptyAxes:!!r,axes:i?Array.from((w(),k).subarray(Number(i)>>>0,Number(a)>>>0)):[]})},941929:(e,n,r,i,a)=>{t.ac(`ReduceMin`,e,{keepDims:!!n,noopWithEmptyAxes:!!r,axes:i?Array.from((w(),k).subarray(Number(i)>>>0,Number(a)>>>0)):[]})},942103:(e,n,r,i,a)=>{t.ac(`ReduceProd`,e,{keepDims:!!n,noopWithEmptyAxes:!!r,axes:i?Array.from((w(),k).subarray(Number(i)>>>0,Number(a)>>>0)):[]})},942278:(e,n,r,i,a)=>{t.ac(`ReduceSum`,e,{keepDims:!!n,noopWithEmptyAxes:!!r,axes:i?Array.from((w(),k).subarray(Number(i)>>>0,Number(a)>>>0)):[]})},942452:(e,n,r,i,a)=>{t.ac(`ReduceL1`,e,{keepDims:!!n,noopWithEmptyAxes:!!r,axes:i?Array.from((w(),k).subarray(Number(i)>>>0,Number(a)>>>0)):[]})},942625:(e,n,r,i,a)=>{t.ac(`ReduceL2`,e,{keepDims:!!n,noopWithEmptyAxes:!!r,axes:i?Array.from((w(),k).subarray(Number(i)>>>0,Number(a)>>>0)):[]})},942798:(e,n,r,i,a)=>{t.ac(`ReduceLogSum`,e,{keepDims:!!n,noopWithEmptyAxes:!!r,axes:i?Array.from((w(),k).subarray(Number(i)>>>0,Number(a)>>>0)):[]})},942975:(e,n,r,i,a)=>{t.ac(`ReduceSumSquare`,e,{keepDims:!!n,noopWithEmptyAxes:!!r,axes:i?Array.from((w(),k).subarray(Number(i)>>>0,Number(a)>>>0)):[]})},943155:(e,n,r,i,a)=>{t.ac(`ReduceLogSumExp`,e,{keepDims:!!n,noopWithEmptyAxes:!!r,axes:i?Array.from((w(),k).subarray(Number(i)>>>0,Number(a)>>>0)):[]})},943335:e=>{t.ac(`Where`,e,void 0)},943388:(e,n,r)=>{t.ac(`Transpose`,e,{perm:n?Array.from((w(),k).subarray(Number(n)>>>0,Number(r)>>>0)):[]})},943512:(e,n,r,i)=>{t.ac(`DepthToSpace`,e,{blocksize:n,mode:P(r),format:i?`NHWC`:`NCHW`})},943645:(e,n,r,i)=>{t.ac(`DepthToSpace`,e,{blocksize:n,mode:P(r),format:i?`NHWC`:`NCHW`})},943778:(e,n,r,i,a,o,s,c,l,u,d,f,p,m,h)=>{t.ac(`ConvTranspose`,e,{format:l?`NHWC`:`NCHW`,autoPad:n,dilations:[r],group:i,kernelShape:[a],pads:[o,s],strides:[c],wIsConst:()=>!!(w(),T)[u>>>0],outputPadding:d?Array.from((w(),k).subarray(Number(d)>>>0,Number(f)>>>0)):[],outputShape:p?Array.from((w(),k).subarray(Number(p)>>>0,Number(m)>>>0)):[],activation:P(h)})},944211:(e,n,r,i,a,o,s,c,l,u,d,f,p,m)=>{t.ac(`ConvTranspose`,e,{format:c?`NHWC`:`NCHW`,autoPad:n,dilations:Array.from((w(),k).subarray(Number(r)>>>0,2+(Number(r)>>>0)>>>0)),group:i,kernelShape:Array.from((w(),k).subarray(Number(a)>>>0,2+(Number(a)>>>0)>>>0)),pads:Array.from((w(),k).subarray(Number(o)>>>0,4+(Number(o)>>>0)>>>0)),strides:Array.from((w(),k).subarray(Number(s)>>>0,2+(Number(s)>>>0)>>>0)),wIsConst:()=>!!(w(),T)[l>>>0],outputPadding:u?Array.from((w(),k).subarray(Number(u)>>>0,Number(d)>>>0)):[],outputShape:f?Array.from((w(),k).subarray(Number(f)>>>0,Number(p)>>>0)):[],activation:P(m)})},944872:(e,n,r,i,a,o,s,c,l,u,d,f,p,m,h)=>{t.ac(`ConvTranspose`,e,{format:l?`NHWC`:`NCHW`,autoPad:n,dilations:[r],group:i,kernelShape:[a],pads:[o,s],strides:[c],wIsConst:()=>!!(w(),T)[u>>>0],outputPadding:d?Array.from((w(),k).subarray(Number(d)>>>0,Number(f)>>>0)):[],outputShape:p?Array.from((w(),k).subarray(Number(p)>>>0,Number(m)>>>0)):[],activation:P(h)})},945305:(e,n,r,i,a,o,s,c,l,u,d,f,p,m)=>{t.ac(`ConvTranspose`,e,{format:c?`NHWC`:`NCHW`,autoPad:n,dilations:Array.from((w(),k).subarray(Number(r)>>>0,2+(Number(r)>>>0)>>>0)),group:i,kernelShape:Array.from((w(),k).subarray(Number(a)>>>0,2+(Number(a)>>>0)>>>0)),pads:Array.from((w(),k).subarray(Number(o)>>>0,4+(Number(o)>>>0)>>>0)),strides:Array.from((w(),k).subarray(Number(s)>>>0,2+(Number(s)>>>0)>>>0)),wIsConst:()=>!!(w(),T)[l>>>0],outputPadding:u?Array.from((w(),k).subarray(Number(u)>>>0,Number(d)>>>0)):[],outputShape:f?Array.from((w(),k).subarray(Number(f)>>>0,Number(p)>>>0)):[],activation:P(m)})},945966:(e,n)=>{t.ac(`GlobalAveragePool`,e,{format:n?`NHWC`:`NCHW`})},946057:(e,n,r,i,a,o,s,c,l,u,d,f,p,m)=>{t.ac(`AveragePool`,e,{format:m?`NHWC`:`NCHW`,auto_pad:n,ceil_mode:r,count_include_pad:i,storage_order:a,dilations:o?Array.from((w(),k).subarray(Number(o)>>>0,Number(s)>>>0)):[],kernel_shape:c?Array.from((w(),k).subarray(Number(c)>>>0,Number(l)>>>0)):[],pads:u?Array.from((w(),k).subarray(Number(u)>>>0,Number(d)>>>0)):[],strides:f?Array.from((w(),k).subarray(Number(f)>>>0,Number(p)>>>0)):[]})},946536:(e,n)=>{t.ac(`GlobalAveragePool`,e,{format:n?`NHWC`:`NCHW`})},946627:(e,n,r,i,a,o,s,c,l,u,d,f,p,m)=>{t.ac(`AveragePool`,e,{format:m?`NHWC`:`NCHW`,auto_pad:n,ceil_mode:r,count_include_pad:i,storage_order:a,dilations:o?Array.from((w(),k).subarray(Number(o)>>>0,Number(s)>>>0)):[],kernel_shape:c?Array.from((w(),k).subarray(Number(c)>>>0,Number(l)>>>0)):[],pads:u?Array.from((w(),k).subarray(Number(u)>>>0,Number(d)>>>0)):[],strides:f?Array.from((w(),k).subarray(Number(f)>>>0,Number(p)>>>0)):[]})},947106:(e,n)=>{t.ac(`GlobalMaxPool`,e,{format:n?`NHWC`:`NCHW`})},947193:(e,n,r,i,a,o,s,c,l,u,d,f,p,m)=>{t.ac(`MaxPool`,e,{format:m?`NHWC`:`NCHW`,auto_pad:n,ceil_mode:r,count_include_pad:i,storage_order:a,dilations:o?Array.from((w(),k).subarray(Number(o)>>>0,Number(s)>>>0)):[],kernel_shape:c?Array.from((w(),k).subarray(Number(c)>>>0,Number(l)>>>0)):[],pads:u?Array.from((w(),k).subarray(Number(u)>>>0,Number(d)>>>0)):[],strides:f?Array.from((w(),k).subarray(Number(f)>>>0,Number(p)>>>0)):[]})},947668:(e,n)=>{t.ac(`GlobalMaxPool`,e,{format:n?`NHWC`:`NCHW`})},947755:(e,n,r,i,a,o,s,c,l,u,d,f,p,m)=>{t.ac(`MaxPool`,e,{format:m?`NHWC`:`NCHW`,auto_pad:n,ceil_mode:r,count_include_pad:i,storage_order:a,dilations:o?Array.from((w(),k).subarray(Number(o)>>>0,Number(s)>>>0)):[],kernel_shape:c?Array.from((w(),k).subarray(Number(c)>>>0,Number(l)>>>0)):[],pads:u?Array.from((w(),k).subarray(Number(u)>>>0,Number(d)>>>0)):[],strides:f?Array.from((w(),k).subarray(Number(f)>>>0,Number(p)>>>0)):[]})},948230:(e,n,r,i,a)=>{t.ac(`Gemm`,e,{alpha:n,beta:r,transA:i,transB:a})},948334:e=>{t.ac(`MatMul`,e,void 0)},948388:(e,n,r,i)=>{t.ac(`ArgMax`,e,{keepDims:!!n,selectLastIndex:!!r,axis:i})},948496:(e,n,r,i)=>{t.ac(`ArgMin`,e,{keepDims:!!n,selectLastIndex:!!r,axis:i})},948604:(e,n)=>{t.ac(`Softmax`,e,{axis:n})},948667:(e,n)=>{t.ac(`Concat`,e,{axis:n})},948727:(e,n,r,i,a)=>{t.ac(`Split`,e,{axis:n,numOutputs:r,splitSizes:i?Array.from((w(),k).subarray(Number(i)>>>0,Number(a)>>>0)):[]})},948883:e=>{t.ac(`Expand`,e,void 0)},948937:(e,n)=>{t.ac(`Gather`,e,{axis:Number(n)})},949008:(e,n)=>{t.ac(`GatherElements`,e,{axis:Number(n)})},949087:(e,n)=>{t.ac(`GatherND`,e,{batch_dims:Number(n)})},949166:(e,n,r,i,a,o,s,c,l,u,d)=>{t.ac(`Resize`,e,{antialias:n,axes:r?Array.from((w(),k).subarray(Number(r)>>>0,Number(i)>>>0)):[],coordinateTransformMode:P(a),cubicCoeffA:o,excludeOutside:s,extrapolationValue:c,keepAspectRatioPolicy:P(l),mode:P(u),nearestMode:P(d)})},949528:(e,n,r,i,a,o,s)=>{t.ac(`Slice`,e,{starts:n?Array.from((w(),k).subarray(Number(n)>>>0,Number(r)>>>0)):[],ends:i?Array.from((w(),k).subarray(Number(i)>>>0,Number(a)>>>0)):[],axes:o?Array.from((w(),k).subarray(Number(o)>>>0,Number(s)>>>0)):[]})},949792:e=>{t.ac(`Tile`,e,void 0)},949844:(e,n,r)=>{t.ac(`InstanceNormalization`,e,{epsilon:n,format:r?`NHWC`:`NCHW`})},949958:(e,n,r)=>{t.ac(`InstanceNormalization`,e,{epsilon:n,format:r?`NHWC`:`NCHW`})},950072:e=>{t.ac(`Range`,e,void 0)},950125:(e,n)=>{t.ac(`Einsum`,e,{equation:P(n)})},950206:(e,n,r,i,a)=>{t.ac(`Pad`,e,{mode:n,value:r,pads:i?Array.from((w(),k).subarray(Number(i)>>>0,Number(a)>>>0)):[]})},950349:(e,n,r,i,a,o)=>{t.ac(`BatchNormalization`,e,{epsilon:n,momentum:r,spatial:!!a,trainingMode:!!i,format:o?`NHWC`:`NCHW`})},950518:(e,n,r,i,a,o)=>{t.ac(`BatchNormalization`,e,{epsilon:n,momentum:r,spatial:!!a,trainingMode:!!i,format:o?`NHWC`:`NCHW`})},950687:(e,n,r)=>{t.ac(`CumSum`,e,{exclusive:Number(n),reverse:Number(r)})},950784:(e,n,r)=>{t.ac(`DequantizeLinear`,e,{axis:n,blockSize:r})},950874:(e,n,r,i,a)=>{t.ac(`GridSample`,e,{align_corners:n,mode:P(r),padding_mode:P(i),format:a?`NHWC`:`NCHW`})},951044:(e,n,r,i,a)=>{t.ac(`GridSample`,e,{align_corners:n,mode:P(r),padding_mode:P(i),format:a?`NHWC`:`NCHW`})},951214:(e,n)=>{t.ac(`ScatterND`,e,{reduction:P(n)})},951299:(e,n,r,i,a,o,s,c,l)=>{t.ac(`Attention`,e,{numHeads:n,isUnidirectional:r,maskFilterValue:i,scale:a,doRotary:o,qkvHiddenSizes:s?Array.from((w(),k).subarray(Number(c)>>>0,Number(c)+s>>>0)):[],pastPresentShareBuffer:!!l})},951571:e=>{t.ac(`BiasAdd`,e,void 0)},951626:e=>{t.ac(`BiasSplitGelu`,e,void 0)},951687:e=>{t.ac(`FastGelu`,e,void 0)},951743:(e,n,r,i,a,o,s,c,l,u,d,f,p,m,h,g)=>{t.ac(`Conv`,e,{format:f?`NHWC`:`NCHW`,auto_pad:n,dilations:r?Array.from((w(),k).subarray(Number(r)>>>0,Number(i)>>>0)):[],group:a,kernel_shape:o?Array.from((w(),k).subarray(Number(o)>>>0,Number(s)>>>0)):[],pads:c?Array.from((w(),k).subarray(Number(c)>>>0,Number(l)>>>0)):[],strides:u?Array.from((w(),k).subarray(Number(u)>>>0,Number(d)>>>0)):[],w_is_const:()=>!!(w(),T)[Number(p)>>>0],activation:P(m),activation_params:h?Array.from((w(),te).subarray(Number(h)>>>0,Number(g)>>>0)):[]})},952327:e=>{t.ac(`Gelu`,e,void 0)},952379:(e,n,r,i,a,o,s,c,l)=>{t.ac(`GroupQueryAttention`,e,{numHeads:n,kvNumHeads:r,scale:i,softcap:a,doRotary:o,rotaryInterleaved:s,smoothSoftmax:c,localWindowSize:l})},952596:(e,n,r,i)=>{t.ac(`LayerNormalization`,e,{axis:n,epsilon:r,simplified:!!i})},952707:(e,n,r,i)=>{t.ac(`LayerNormalization`,e,{axis:n,epsilon:r,simplified:!!i})},952818:(e,n,r,i,a,o)=>{t.ac(`MatMulNBits`,e,{k:n,n:r,accuracyLevel:i,bits:a,blockSize:o})},952945:(e,n,r,i,a,o)=>{t.ac(`MultiHeadAttention`,e,{numHeads:n,isUnidirectional:r,maskFilterValue:i,scale:a,doRotary:o})},953104:(e,n)=>{t.ac(`QuickGelu`,e,{alpha:n})},953168:(e,n,r,i,a)=>{t.ac(`RotaryEmbedding`,e,{interleaved:!!n,numHeads:r,rotaryEmbeddingDim:i,scale:a})},953307:(e,n,r)=>{t.ac(`SkipLayerNormalization`,e,{epsilon:n,simplified:!!r})},953409:(e,n,r)=>{t.ac(`SkipLayerNormalization`,e,{epsilon:n,simplified:!!r})},953511:(e,n,r,i)=>{t.ac(`GatherBlockQuantized`,e,{gatherAxis:n,quantizeAxis:r,blockSize:i})},953632:e=>{t.Id(e)},953666:(e,n)=>t.Kd(Number(e),Number(n),t.$c.Nd,t.$c.errors)};function vi(e,n,r){return cn(async()=>{await t.Gd(Number(e),Number(n),Number(r))})}function yi(){return typeof wasmOffsetConverter<`u`}function bi(e,t,n,r){var i=Q();try{return Lr(e,t,n,r)}catch(e){if(Z(i),e!==e+0)throw e;X(1,0)}}function xi(e,t,n){var r=Q();try{return Nr(e,t,n)}catch(e){if(Z(r),e!==e+0)throw e;X(1,0)}}function Si(e,t,n){var r=Q();try{Or(e,t,n)}catch(e){if(Z(r),e!==e+0)throw e;X(1,0)}}function Ci(e,t){var n=Q();try{return kr(e,t)}catch(e){if(Z(n),e!==e+0)throw e;X(1,0)}}function wi(e){var t=Q();try{Ar(e)}catch(e){if(Z(t),e!==e+0)throw e;X(1,0)}}function Ti(e,t,n,r,i,a,o){var s=Q();try{return Fr(e,t,n,r,i,a,o)}catch(e){if(Z(s),e!==e+0)throw e;X(1,0)}}function Ei(e,t){var n=Q();try{Rr(e,t)}catch(e){if(Z(n),e!==e+0)throw e;X(1,0)}}function Di(e,t,n,r,i,a){var o=Q();try{jr(e,t,n,r,i,a)}catch(e){if(Z(o),e!==e+0)throw e;X(1,0)}}function Oi(e,t,n,r){var i=Q();try{Ir(e,t,n,r)}catch(e){if(Z(i),e!==e+0)throw e;X(1,0)}}function ki(e,t,n,r,i){var a=Q();try{Mr(e,t,n,r,i)}catch(e){if(Z(a),e!==e+0)throw e;X(1,0)}}function Ai(e,t,n,r,i,a,o){var s=Q();try{Br(e,t,n,r,i,a,o)}catch(e){if(Z(s),e!==e+0)throw e;X(1,0)}}function ji(e,t,n,r,i,a,o){var s=Q();try{Vr(e,t,n,r,i,a,o)}catch(e){if(Z(s),e!==e+0)throw e;X(1,0)}}function Mi(e,t,n,r,i,a,o,s){var c=Q();try{Gr(e,t,n,r,i,a,o,s)}catch(e){if(Z(c),e!==e+0)throw e;X(1,0)}}function Ni(e,t,n,r,i){var a=Q();try{return zr(e,t,n,r,i)}catch(e){if(Z(a),e!==e+0)throw e;X(1,0)}}function Pi(e,t,n,r,i,a,o,s){var c=Q();try{Kr(e,t,n,r,i,a,o,s)}catch(e){if(Z(c),e!==e+0)throw e;X(1,0)}}function Fi(e,t,n,r,i,a,o,s,c,l,u,d){var f=Q();try{Hr(e,t,n,r,i,a,o,s,c,l,u,d)}catch(e){if(Z(f),e!==e+0)throw e;X(1,0)}}function Ii(e,t,n,r,i,a){var o=Q();try{return Ur(e,t,n,r,i,a)}catch(e){if(Z(o),e!==e+0)throw e;X(1,0)}}function Li(e,t,n){var r=Q();try{return qr(e,t,n)}catch(e){if(Z(r),e!==e+0)throw e;return X(1,0),0n}}function Ri(e,t,n,r,i,a,o,s,c){var l=Q();try{Pr(e,t,n,r,i,a,o,s,c)}catch(e){if(Z(l),e!==e+0)throw e;X(1,0)}}function zi(e){var t=Q();try{return Jr(e)}catch(e){if(Z(t),e!==e+0)throw e;X(1,0)}}function Bi(e,t,n){var r=Q();try{return Yr(e,t,n)}catch(e){if(Z(r),e!==e+0)throw e;X(1,0)}}function Vi(e,t){var n=Q();try{return ui(e,t)}catch(e){if(Z(n),e!==e+0)throw e;return X(1,0),0n}}function Hi(e,t,n,r,i){var a=Q();try{Xr(e,t,n,r,i)}catch(e){if(Z(a),e!==e+0)throw e;X(1,0)}}function Ui(e){var t=Q();try{return Zr(e)}catch(e){if(Z(t),e!==e+0)throw e;return X(1,0),0n}}function Wi(e,t,n,r,i,a){var o=Q();try{return ri(e,t,n,r,i,a)}catch(e){if(Z(o),e!==e+0)throw e;X(1,0)}}function Gi(e,t,n,r,i,a){var o=Q();try{return $(e,t,n,r,i,a)}catch(e){if(Z(o),e!==e+0)throw e;X(1,0)}}function Ki(e,t,n,r,i,a,o,s){var c=Q();try{return Wr(e,t,n,r,i,a,o,s)}catch(e){if(Z(c),e!==e+0)throw e;X(1,0)}}function qi(e,t,n,r,i){var a=Q();try{return ii(e,t,n,r,i)}catch(e){if(Z(a),e!==e+0)throw e;return X(1,0),0n}}function Ji(e,t,n,r){var i=Q();try{return ai(e,t,n,r)}catch(e){if(Z(i),e!==e+0)throw e;X(1,0)}}function Yi(e,t,n,r){var i=Q();try{return oi(e,t,n,r)}catch(e){if(Z(i),e!==e+0)throw e;X(1,0)}}function Xi(e,t,n,r,i,a,o,s,c,l,u,d){var f=Q();try{return si(e,t,n,r,i,a,o,s,c,l,u,d)}catch(e){if(Z(f),e!==e+0)throw e;X(1,0)}}function Zi(e,t,n,r,i,a,o,s,c,l,u){var d=Q();try{ti(e,t,n,r,i,a,o,s,c,l,u)}catch(e){if(Z(d),e!==e+0)throw e;X(1,0)}}function Qi(e,t,n,r,i,a,o,s,c,l,u,d,f,p,m,h){var g=Q();try{ni(e,t,n,r,i,a,o,s,c,l,u,d,f,p,m,h)}catch(e){if(Z(g),e!==e+0)throw e;X(1,0)}}function $i(e,t,n,r){var i=Q();try{return ci(e,t,n,r)}catch(e){if(Z(i),e!==e+0)throw e;X(1,0)}}function ea(e,t,n,r,i){var a=Q();try{return li(e,t,n,r,i)}catch(e){if(Z(a),e!==e+0)throw e;X(1,0)}}function ta(e,t,n){var r=Q();try{return Qr(e,t,n)}catch(e){if(Z(r),e!==e+0)throw e;X(1,0)}}function na(e,t,n){var r=Q();try{return $r(e,t,n)}catch(e){if(Z(r),e!==e+0)throw e;X(1,0)}}function ra(e,t,n,r){var i=Q();try{ei(e,t,n,r)}catch(e){if(Z(i),e!==e+0)throw e;X(1,0)}}function ia(){if(0<he)ge=ia;else if(i)h?.(t),ce();else{for(var e=me;0<e.length;)e.shift()(t);0<he?ge=ia:(t.calledRun=!0,S||(ce(),h?.(t)))}}return i||(hi=await de(),ia()),t.PTR_SIZE=4,oe?t:new Promise((e,t)=>{h=e,g=t})}var Le,Re,ze=o(()=>{Le=Ie,Re=globalThis.self?.name?.startsWith(`em-pthread`),Re&&Ie()}),Be,Ve,He,Ue,We,Ge,Ke,qe,Je,Ye,Xe,P,Ze,Qe,$e=o(()=>{ke(),Be=typeof location>`u`?void 0:location.origin,Ve=import.meta.url>`file:`&&import.meta.url<`file;`,He=()=>Ve?new URL(new URL(`ort.bundle.min.mjs`,import.meta.url).href,Be).href:import.meta.url,Ue=He(),We=()=>{if(Ue&&!Ue.startsWith(`blob:`))return Ue.substring(0,Ue.lastIndexOf(`/`)+1)},Ge=(e,t)=>{try{let n=t??Ue;return(n?new URL(e,n):new URL(e)).origin===Be}catch{return!1}},Ke=(e,t)=>{let n=t??Ue;try{return(n?new URL(e,n):new URL(e)).href}catch{return}},qe=(e,t)=>`${t??`./`}${e}`,Je=async e=>{let t=await(await fetch(e,{credentials:`same-origin`})).blob();return URL.createObjectURL(t)},Ye=async e=>(await import(e)).default,Xe=(Pe(),l(Ae)).default,P=async()=>{if(!Ue)throw Error(`Failed to load proxy worker: cannot determine the script source URL.`);if(Ge(Ue))return[void 0,Xe()];let e=await Je(Ue);return[e,Xe(e)]},Ze=(ze(),l(Fe)).default,Qe=async(e,t,n,r)=>{let i=Ze&&!(e||t);if(i)if(Ue)i=Ge(Ue)||r&&!n;else if(r&&!n)i=!0;else throw Error(`cannot determine the script source URL.`);if(i)return[void 0,Ze];{let r=`ort-wasm-simd-threaded.jsep.mjs`,i=e??Ke(r,t),a=n&&i&&!Ge(i,t),o=a?await Je(i):i??qe(r,t);return[a?o:void 0,await Ye(o)]}}}),et,tt,nt,rt,it,at,ot,st,F,ct=o(()=>{$e(),tt=!1,nt=!1,rt=!1,it=()=>{if(typeof SharedArrayBuffer>`u`)return!1;try{return typeof MessageChannel<`u`&&new MessageChannel().port1.postMessage(new SharedArrayBuffer(1)),WebAssembly.validate(new Uint8Array([0,97,115,109,1,0,0,0,1,4,1,96,0,0,3,2,1,0,5,4,1,3,1,1,10,11,1,9,0,65,0,254,16,2,0,26,11]))}catch{return!1}},at=()=>{try{return WebAssembly.validate(new Uint8Array([0,97,115,109,1,0,0,0,1,4,1,96,0,0,3,2,1,0,10,30,1,28,0,65,0,253,15,253,12,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,253,186,1,26,11]))}catch{return!1}},ot=()=>{try{return WebAssembly.validate(new Uint8Array([0,97,115,109,1,0,0,0,1,5,1,96,0,1,123,3,2,1,0,10,19,1,17,0,65,1,253,15,65,2,253,15,65,3,253,15,253,147,2,11]))}catch{return!1}},st=async e=>{if(tt)return Promise.resolve();if(nt)throw Error(`multiple calls to 'initializeWebAssembly()' detected.`);if(rt)throw Error(`previous call to 'initializeWebAssembly()' failed.`);nt=!0;let t=e.initTimeout,n=e.numThreads;if(e.simd!==!1){if(e.simd===`relaxed`){if(!ot())throw Error(`Relaxed WebAssembly SIMD is not supported in the current environment.`)}else if(!at())throw Error(`WebAssembly SIMD is not supported in the current environment.`)}let r=it();n>1&&!r&&(typeof self<`u`&&!self.crossOriginIsolated&&console.warn(`env.wasm.numThreads is set to `+n+`, but this will not work unless you enable crossOriginIsolated mode. See https://web.dev/cross-origin-isolation-guide/ for more info.`),console.warn(`WebAssembly multi-threading is not supported in the current environment. Falling back to single-threading.`),e.numThreads=n=1);let i=e.wasmPaths,a=typeof i==`string`?i:void 0,o=i?.mjs,s=o?.href??o,c=i?.wasm,l=c?.href??c,u=e.wasmBinary,[d,f]=await Qe(s,a,n>1,!!u||!!l),p=!1,m=[];if(t>0&&m.push(new Promise(e=>{setTimeout(()=>{p=!0,e()},t)})),m.push(new Promise((e,t)=>{let r={numThreads:n};if(u)r.wasmBinary=u,r.locateFile=e=>e;else if(l||a)r.locateFile=e=>l??a+e;else if(s&&s.indexOf(`blob:`)!==0)r.locateFile=e=>new URL(e,s).href;else if(d){let e=We();e&&(r.locateFile=t=>e+t)}f(r).then(t=>{nt=!1,tt=!0,et=t,e(),d&&URL.revokeObjectURL(d)},e=>{nt=!1,rt=!0,t(e)})})),await Promise.race(m),p)throw Error(`WebAssembly backend initializing failed due to timeout: ${t}ms`)},F=()=>{if(tt&&et)return et;throw Error(`WebAssembly is not initialized yet.`)}}),lt,ut,I,dt=o(()=>{ct(),lt=(e,t)=>{let n=F(),r=n.lengthBytesUTF8(e)+1,i=n._malloc(r);return n.stringToUTF8(e,i,r),t.push(i),i},ut=(e,t,n,r)=>{if(typeof e==`object`&&e){if(n.has(e))throw Error(`Circular reference in options`);n.add(e)}Object.entries(e).forEach(([e,i])=>{let a=t?t+e:e;if(typeof i==`object`)ut(i,a+`.`,n,r);else if(typeof i==`string`||typeof i==`number`)r(a,i.toString());else if(typeof i==`boolean`)r(a,i?`1`:`0`);else throw Error(`Can't handle extra config type: ${typeof i}`)})},I=e=>{let t=F(),n=t.stackSave();try{let n=t.PTR_SIZE,r=t.stackAlloc(2*n);t._OrtGetLastError(r,r+n);let i=Number(t.getValue(r,n===4?`i32`:`i64`)),a=t.getValue(r+n,`*`),o=a?t.UTF8ToString(a):``;throw Error(`${e} ERROR_CODE: ${i}, ERROR_MESSAGE: ${o}`)}finally{t.stackRestore(n)}}}),ft,pt=o(()=>{ct(),dt(),ft=e=>{let t=F(),n=0,r=[],i=e||{};try{if(e?.logSeverityLevel===void 0)i.logSeverityLevel=2;else if(typeof e.logSeverityLevel!=`number`||!Number.isInteger(e.logSeverityLevel)||e.logSeverityLevel<0||e.logSeverityLevel>4)throw Error(`log severity level is not valid: ${e.logSeverityLevel}`);if(e?.logVerbosityLevel===void 0)i.logVerbosityLevel=0;else if(typeof e.logVerbosityLevel!=`number`||!Number.isInteger(e.logVerbosityLevel))throw Error(`log verbosity level is not valid: ${e.logVerbosityLevel}`);e?.terminate===void 0&&(i.terminate=!1);let a=0;return e?.tag!==void 0&&(a=lt(e.tag,r)),n=t._OrtCreateRunOptions(i.logSeverityLevel,i.logVerbosityLevel,!!i.terminate,a),n===0&&I(`Can't create run options.`),e?.extra!==void 0&&ut(e.extra,``,new WeakSet,(e,i)=>{let a=lt(e,r),o=lt(i,r);t._OrtAddRunConfigEntry(n,a,o)!==0&&I(`Can't set a run config entry: ${e} - ${i}.`)}),[n,r]}catch(e){throw n!==0&&t._OrtReleaseRunOptions(n),r.forEach(e=>t._free(e)),e}}}),mt,ht,gt,_t,vt,yt,bt=o(()=>{ct(),dt(),mt=e=>{switch(e){case`disabled`:return 0;case`basic`:return 1;case`extended`:return 2;case`layout`:return 3;case`all`:return 99;default:throw Error(`unsupported graph optimization level: ${e}`)}},ht=e=>{switch(e){case`sequential`:return 0;case`parallel`:return 1;default:throw Error(`unsupported execution mode: ${e}`)}},gt=e=>{e.extra||={},e.extra.session||(e.extra.session={});let t=e.extra.session;t.use_ort_model_bytes_directly||=`1`,e.executionProviders&&e.executionProviders.some(e=>(typeof e==`string`?e:e.name)===`webgpu`)&&(e.enableMemPattern=!1)},_t=(e,t,n,r)=>{let i=lt(t,r),a=lt(n,r);F()._OrtAddSessionConfigEntry(e,i,a)!==0&&I(`Can't set a session config entry: ${t} - ${n}.`)},vt=async(e,t,n)=>{let r=t.executionProviders;for(let t of r){let r=typeof t==`string`?t:t.name,i=[];switch(r){case`webnn`:if(r=`WEBNN`,typeof t!=`string`){let r=t?.deviceType;r&&_t(e,`deviceType`,r,n)}break;case`webgpu`:if(r=`JS`,typeof t!=`string`){let r=t;if(r?.preferredLayout){if(r.preferredLayout!==`NCHW`&&r.preferredLayout!==`NHWC`)throw Error(`preferredLayout must be either 'NCHW' or 'NHWC': ${r.preferredLayout}`);_t(e,`preferredLayout`,r.preferredLayout,n)}}break;case`wasm`:case`cpu`:continue;default:throw Error(`not supported execution provider: ${r}`)}let a=lt(r,n),o=i.length,s=0,c=0;if(o>0){s=F()._malloc(o*F().PTR_SIZE),n.push(s),c=F()._malloc(o*F().PTR_SIZE),n.push(c);for(let e=0;e<o;e++)F().setValue(s+e*F().PTR_SIZE,i[e][0],`*`),F().setValue(c+e*F().PTR_SIZE,i[e][1],`*`)}await F()._OrtAppendExecutionProvider(e,a,s,c,o)!==0&&I(`Can't append execution provider: ${r}.`)}},yt=async e=>{let t=F(),n=0,r=[],i=e||{};gt(i);try{let e=mt(i.graphOptimizationLevel??`all`),a=ht(i.executionMode??`sequential`),o=typeof i.logId==`string`?lt(i.logId,r):0,s=i.logSeverityLevel??2;if(!Number.isInteger(s)||s<0||s>4)throw Error(`log severity level is not valid: ${s}`);let c=i.logVerbosityLevel??0;if(!Number.isInteger(c)||c<0||c>4)throw Error(`log verbosity level is not valid: ${c}`);let l=typeof i.optimizedModelFilePath==`string`?lt(i.optimizedModelFilePath,r):0;if(n=t._OrtCreateSessionOptions(e,!!i.enableCpuMemArena,!!i.enableMemPattern,a,!!i.enableProfiling,0,o,s,c,l),n===0&&I(`Can't create session options.`),i.executionProviders&&await vt(n,i,r),i.enableGraphCapture!==void 0){if(typeof i.enableGraphCapture!=`boolean`)throw Error(`enableGraphCapture must be a boolean value: ${i.enableGraphCapture}`);_t(n,`enableGraphCapture`,i.enableGraphCapture.toString(),r)}if(i.freeDimensionOverrides)for(let[e,a]of Object.entries(i.freeDimensionOverrides)){if(typeof e!=`string`)throw Error(`free dimension override name must be a string: ${e}`);if(typeof a!=`number`||!Number.isInteger(a)||a<0)throw Error(`free dimension override value must be a non-negative integer: ${a}`);let i=lt(e,r);t._OrtAddFreeDimensionOverride(n,i,a)!==0&&I(`Can't set a free dimension override: ${e} - ${a}.`)}return i.extra!==void 0&&ut(i.extra,``,new WeakSet,(e,t)=>{_t(n,e,t,r)}),[n,r]}catch(e){throw n!==0&&t._OrtReleaseSessionOptions(n)!==0&&I(`Can't release session options.`),r.forEach(e=>t._free(e)),e}}}),xt,St,Ct,wt,Tt,Et,Dt,Ot,L=o(()=>{xt=e=>{switch(e){case`int8`:return 3;case`uint8`:return 2;case`bool`:return 9;case`int16`:return 5;case`uint16`:return 4;case`int32`:return 6;case`uint32`:return 12;case`float16`:return 10;case`float32`:return 1;case`float64`:return 11;case`string`:return 8;case`int64`:return 7;case`uint64`:return 13;case`int4`:return 22;case`uint4`:return 21;default:throw Error(`unsupported data type: ${e}`)}},St=e=>{switch(e){case 3:return`int8`;case 2:return`uint8`;case 9:return`bool`;case 5:return`int16`;case 4:return`uint16`;case 6:return`int32`;case 12:return`uint32`;case 10:return`float16`;case 1:return`float32`;case 11:return`float64`;case 8:return`string`;case 7:return`int64`;case 13:return`uint64`;case 22:return`int4`;case 21:return`uint4`;default:throw Error(`unsupported data type: ${e}`)}},Ct=(e,t)=>{let n=[-1,4,1,1,2,2,4,8,-1,1,2,8,4,8,-1,-1,-1,-1,-1,-1,-1,.5,.5][e],r=typeof t==`number`?t:t.reduce((e,t)=>e*t,1);return n>0?Math.ceil(r*n):void 0},wt=e=>{switch(e){case`float16`:return typeof Float16Array<`u`&&Float16Array.from?Float16Array:Uint16Array;case`float32`:return Float32Array;case`uint8`:return Uint8Array;case`int8`:return Int8Array;case`uint16`:return Uint16Array;case`int16`:return Int16Array;case`int32`:return Int32Array;case`bool`:return Uint8Array;case`float64`:return Float64Array;case`uint32`:return Uint32Array;case`int64`:return BigInt64Array;case`uint64`:return BigUint64Array;default:throw Error(`unsupported type: ${e}`)}},Tt=e=>{switch(e){case`verbose`:return 0;case`info`:return 1;case`warning`:return 2;case`error`:return 3;case`fatal`:return 4;default:throw Error(`unsupported logging level: ${e}`)}},Et=e=>e===`float32`||e===`float16`||e===`int32`||e===`int64`||e===`uint32`||e===`uint8`||e===`bool`||e===`uint4`||e===`int4`,Dt=e=>e===`float32`||e===`float16`||e===`int32`||e===`int64`||e===`uint32`||e===`uint64`||e===`int8`||e===`uint8`||e===`bool`||e===`uint4`||e===`int4`,Ot=e=>{switch(e){case`none`:return 0;case`cpu`:return 1;case`cpu-pinned`:return 2;case`texture`:return 3;case`gpu-buffer`:return 4;case`ml-tensor`:return 5;default:throw Error(`unsupported data location: ${e}`)}}}),kt,At=o(()=>{ke(),kt=async e=>{if(typeof e==`string`){let t=await fetch(e);if(!t.ok)throw Error(`failed to load external data file: ${e}`);let n=t.headers.get(`Content-Length`),r=n?parseInt(n,10):0;if(r<1073741824)return new Uint8Array(await t.arrayBuffer());{if(!t.body)throw Error(`failed to load external data file: ${e}, no response body.`);let n=t.body.getReader(),i;try{i=new ArrayBuffer(r)}catch(e){if(e instanceof RangeError){let e=Math.ceil(r/65536);i=new WebAssembly.Memory({initial:e,maximum:e}).buffer}else throw e}let a=0;for(;;){let{done:e,value:t}=await n.read();if(e)break;let r=t.byteLength;new Uint8Array(i,a,r).set(t),a+=r}return new Uint8Array(i,0,r)}}else return e instanceof Blob?new Uint8Array(await e.arrayBuffer()):e instanceof Uint8Array?e:new Uint8Array(e)}}),jt,Mt,Nt,Pt,Ft,It,R,Lt=o(()=>{L(),jt=[`V`,`I`,`W`,`E`,`F`],Mt=(e,t)=>{console.log(`[${jt[e]},${new Date().toISOString()}]${t}`)},Ft=(e,t)=>{Nt=e,Pt=t},It=(e,t)=>{let n=Tt(e);n>=Tt(Nt)&&Mt(n,typeof t==`function`?t():t)},R=(...e)=>{Pt&&It(...e)}}),Rt,zt,z,Bt,Vt,Ht,Ut,B=o(()=>{Rt=class{static calcMatMulShape(e,t){return e[1]===t[0]?[e[0],t[1]]:void 0}},zt=class{static calcShape(e,t,n=!1){let r=e.length,i=t.length;if(r===0)return t;if(i===0)return e;let a=Math.max(e.length,t.length),o=Array(a);if(n){if(r<2||i<2)return;let n=Rt.calcMatMulShape([e[r-2],e[r-1]],[t[i-2],t[i-1]]);if(n===void 0)return;[o[a-2],o[a-1]]=n}for(let s=n?3:1;s<=a;s++){let n=r-s<0?1:e[r-s],c=i-s<0?1:t[i-s];if(n!==c&&n>1&&c>1)return;let l=Math.max(n,c);if(n&&c)o[a-s]=Math.max(n,c);else{if(l>1)return;o[a-s]=0}}return o}static isValidBroadcast(e,t){let n=e.length,r=t.length;if(n>r)return!1;for(let i=1;i<=n;i++)if(e[n-i]!==1&&e[n-i]!==t[r-i])return!1;return!0}},z=class e{static size(t){return e.getSizeFromDimensionRange(t,0,t.length)}static convertShape(e,t=4){let n=e.length;if(n===0)return[];let r=Array(n),i=n-1;for(;i>=0;){if(e[i]%t===0){r[i]=e[i]/t;break}if(t%e[i]!==0)throw Error(`cannot convert shape`);r[i]=1,t/=e[i],i--}for(i--;i>=0;i--)r[i]=e[i];return r}static sizeFromDimension(t,n){if(n<0||n>t.length)throw Error(`invalid dimension of ${n} for sizeFromDimension as Tensor has ${t.length} dimensions.`);return e.getSizeFromDimensionRange(t,n,t.length)}static sizeToDimension(t,n){if(n<0||n>t.length)throw Error(`invalid dimension of ${n} for sizeToDimension as Tensor has ${t.length} dimensions.`);return e.getSizeFromDimensionRange(t,0,n)}static getSizeFromDimensionRange(e,t,n){let r=1;for(let i=t;i<n;i++){if(e[i]<0)throw Error(`cannot get valid size from specified dimension range. Most likely the range contains negative values in them.`);r*=Number(e[i])}return r}static computeStrides(e){let t=e.length;if(t===0)return[];if(t===1)return[1];let n=Array(t);n[t-1]=1,n[t-2]=e[t-1];for(let r=t-3;r>=0;--r)n[r]=n[r+1]*e[r+1];return n}static normalizeAxis(e,t){if(e<-t&&e>=t)throw Error(`unsupported axis for this operation.`);return e<0?e+t:e}static normalizeAxes(e,t){return e.map(n=>this.normalizeAxis(n,t??e.length))}static sortBasedOnPerm(e,t){return t?t.map(t=>e[t]):e.slice().reverse()}static padShape(e,t){let n=e.length;return e.map((e,r)=>e+t[r]+t[r+n])}static areEqual(e,t){return e.length===t.length?e.every((e,n)=>e===t[n]):!1}},Bt=class e{static adjustPoolAttributes(e,t,n,r,i,a){if(!e&&n.length!==t.length-2)throw Error(`length of specified kernel shapes should be 2 less than length of input dimensions`);if(e)for(let e=0;e<t.length-2;e++)e>=n.length?n.push(t[e+2]):n[e]=t[e+2];for(let e=0;e<n.length;e++)if(e<r.length){if(r[e]<0)throw Error(`strides should be greater than or equal to 1`)}else r.push(1);for(let e=0;e<n.length;e++)if(e<i.length){if(i[e]<0)throw Error(`dilations should be greater than or equal to 1`)}else i.push(1);for(let e=0;e<n.length*2;e++)if(e<a.length){if(a[e]<0)throw Error(`pad should be greater than or equal to 1`)}else a.push(0);for(let e=0;e<n.length;e++){if(n[e]<=0)throw Error(`kernel shapes need to be greater than 0`);if(a[e]>=n[e]||a[e+n.length]>=n[e])throw Error(`pads should be smaller than kernel`)}}static adjustPadsBasedOnAutoPad(t,n,r,i,a,o,s){if(s){if(a.length!==2*(t.length-2))throw Error(`length of pads should be twice the length of data dimensions`);if(n.length!==t.length-2)throw Error(`length of strides should be the length of data dimensions`);if(i.length!==t.length-2)throw Error(`length of kernel shapes should be the length of data dimensions`);for(let c=0;c<t.length-2;c++)e.adjustPadAndReturnShape(t[c+(o?1:2)],n[c],r[c],i[c],a,c,c+t.length-2,s)}}static computePoolOutputShape(t,n,r,i,a,o,s){if(n.length<=0)throw Error(`input shape must be of size greater than 0`);let c=[n[0],n[1]];return e.computeShapeHelper(t,n,c,r,i,a,o,s),c}static computeConvOutputShape(t,n,r,i,a,o,s){if(t.length<=0||n.length<=0)throw Error(`invalid input tensor dims or invalid filter tensor dims`);let c=[t[0],n[0]];return e.computeShapeHelper(!1,t,c,r,i,a,o,s),c}static computeShapeHelper(t,n,r,i,a,o,s,c){if(t)for(let e=0;e<n.length-2;e++)r.push(1);else for(let t=0;t<n.length-2;t++)r.push(e.adjustPadAndReturnShape(n[t+2],i[t],a[t],o[t],s,t,t+n.length-2,c))}static adjustPadAndReturnShape(e,t,n,r,i,a,o,s){let c=n*(r-1)+1;if(s&&s!==`NOTSET`)switch(s){case`VALID`:return i[a]=0,i[o]=0,Math.floor((e-c)/t+1);case`SAME_LOWER`:case`SAME_UPPER`:if(n!==1)throw Error(`Dilation not supported for SAME_UPPER or SAME_LOWER`);{let n=((e+t-1)/t-1)*t+r-e;return i[a]=Math.floor(s===`SAME_LOWER`?(n+1)/2:n/2),i[o]=n-i[a],Math.floor((e+n-r)/t+1)}default:throw Error(`Unsupported AutoPad type`)}else return Math.floor((e+i[a]+i[o]-c)/t+1)}},Vt=class{static getShapeOfGemmResult(e,t,n,r,i){if(e.length!==2||n.length!==2)throw Error(`shape need to be of size 2`);let a,o,s;t?(a=e[1],o=e[0]):(a=e[0],o=e[1]);let c=-1;if(r?(s=n[0],c=1):(s=n[1],c=0),n[c]!==o)throw Error(`dimension mismatch`);if(a<=0||s<=0||o<=0)throw Error(`invalid shape specified`);if(i&&!zt.isValidBroadcast(i,[a,s]))throw Error(`gemm: invalid bias shape for broadcast`);return[a,s,o]}},Ht=-34028234663852886e22,Ut=34028234663852886e22}),Wt,Gt=o(()=>{L(),Wt=(e,t)=>new(wt(t))(e)}),Kt,qt,Jt,Yt,Xt,Zt,Qt,$t,en,tn,nn,rn=o(()=>{L(),Lt(),Kt=new Map([[`float32`,32],[`float16`,16],[`int32`,32],[`uint32`,32],[`int64`,64],[`uint64`,64],[`int8`,8],[`uint8`,8],[`int4`,4],[`uint4`,4]]),qt=(e,t)=>{if(t===`int32`)return e;let n=Kt.get(t);if(!n)throw Error(`WebNN backend does not support data type: ${t}`);let r=n/8;if(e.byteLength%r!==0)throw Error(`Invalid Uint8Array length - must be a multiple of ${r}.`);let i=e.byteLength/r,a=new(wt(t))(e.buffer,e.byteOffset,i);switch(t){case`int64`:case`uint64`:{let e=new Int32Array(i);for(let t=0;t<i;t++){let n=a[t];if(n>2147483647n||n<-2147483648n)throw Error(`Can not convert int64 data to int32 - value out of range.`);e[t]=Number(n)}return new Uint8Array(e.buffer)}case`int8`:case`uint8`:case`uint32`:{if(t===`uint32`&&a.some(e=>e>2147483647))throw Error(`Can not convert uint32 data to int32 - value out of range.`);let e=Int32Array.from(a,Number);return new Uint8Array(e.buffer)}default:throw Error(`Unsupported data conversion from ${t} to 'int32'`)}},Jt=(e,t)=>{if(t===`int32`)return e;if(e.byteLength%4!=0)throw Error(`Invalid Uint8Array length - must be a multiple of 4 (int32).`);let n=e.byteLength/4,r=new Int32Array(e.buffer,e.byteOffset,n);switch(t){case`int64`:{let e=BigInt64Array.from(r,BigInt);return new Uint8Array(e.buffer)}case`uint64`:{if(r.some(e=>e<0))throw Error(`Can not convert int32 data to uin64 - negative value found.`);let e=BigUint64Array.from(r,BigInt);return new Uint8Array(e.buffer)}case`int8`:{if(r.some(e=>e<-128||e>127))throw Error(`Can not convert int32 data to int8 - value out of range.`);let e=Int8Array.from(r,Number);return new Uint8Array(e.buffer)}case`uint8`:if(r.some(e=>e<0||e>255))throw Error(`Can not convert int32 data to uint8 - value out of range.`);return Uint8Array.from(r,Number);case`uint32`:{if(r.some(e=>e<0))throw Error(`Can not convert int32 data to uint32 - negative value found.`);let e=Uint32Array.from(r,Number);return new Uint8Array(e.buffer)}default:throw Error(`Unsupported data conversion from 'int32' to ${t}`)}},Yt=1,Xt=()=>Yt++,Zt=new Map([[`int8`,`int32`],[`uint8`,`int32`],[`uint32`,`int32`],[`int64`,`int32`]]),Qt=(e,t)=>{let n=Kt.get(e);if(!n)throw Error(`WebNN backend does not support data type: ${e}`);return t.length>0?Math.ceil(t.reduce((e,t)=>e*t)*n/8):0},$t=class{constructor(e){this.isDataConverted=!1;let{sessionId:t,context:n,tensor:r,dataType:i,shape:a,fallbackDataType:o}=e;this.sessionId=t,this.mlContext=n,this.mlTensor=r,this.dataType=i,this.tensorShape=a,this.fallbackDataType=o}get tensor(){return this.mlTensor}get type(){return this.dataType}get fallbackType(){return this.fallbackDataType}get shape(){return this.tensorShape}get byteLength(){return Qt(this.dataType,this.tensorShape)}destroy(){R(`verbose`,()=>`[WebNN] TensorWrapper.destroy`),this.mlTensor.destroy()}write(e){this.mlContext.writeTensor(this.mlTensor,e)}async read(e){if(this.fallbackDataType){let t=await this.mlContext.readTensor(this.mlTensor),n=Jt(new Uint8Array(t),this.dataType);if(e){(e instanceof ArrayBuffer?new Uint8Array(e):new Uint8Array(e.buffer,e.byteOffset,e.byteLength)).set(n);return}else return n.buffer}else return e?this.mlContext.readTensor(this.mlTensor,e):this.mlContext.readTensor(this.mlTensor)}canReuseTensor(e,t,n){return this.mlContext===e&&this.dataType===t&&this.tensorShape.length===n.length&&this.tensorShape.every((e,t)=>e===n[t])}setIsDataConverted(e){this.isDataConverted=e}},en=class{constructor(e,t){this.tensorManager=e,this.wrapper=t}get tensorWrapper(){return this.wrapper}releaseTensor(){this.tensorWrapper&&(this.tensorManager.releaseTensor(this.tensorWrapper),this.wrapper=void 0)}async ensureTensor(e,t,n,r){let i=this.tensorManager.getMLContext(e),a=this.tensorManager.getMLOpSupportLimits(e),o;if(!a?.input.dataTypes.includes(t)){if(o=Zt.get(t),!o||a?.input.dataTypes.includes(o))throw Error(`WebNN backend does not support data type: ${t}`);R(`verbose`,()=>`[WebNN] TensorIdTracker.ensureTensor: fallback dataType from ${t} to ${o}`)}if(this.wrapper){if(this.wrapper.canReuseTensor(i,t,n))return this.wrapper.tensor;if(r){if(this.wrapper.byteLength!==Qt(t,n))throw Error(`Unable to copy data to tensor with different size.`);this.activeUpload=new Uint8Array(await this.wrapper.read())}this.tensorManager.releaseTensor(this.wrapper)}let s=typeof MLTensorUsage>`u`?void 0:MLTensorUsage.READ|MLTensorUsage.WRITE;return this.wrapper=await this.tensorManager.getCachedTensor(e,t,n,s,!0,!0,o),r&&this.activeUpload&&(this.wrapper.write(this.activeUpload),this.activeUpload=void 0),this.wrapper.tensor}upload(e){let t=e;if(this.wrapper){if(this.wrapper.fallbackType)if(this.wrapper.fallbackType===`int32`)t=qt(e,this.wrapper.type),this.wrapper.setIsDataConverted(!0);else throw Error(`Unsupported fallback data type: ${this.wrapper.fallbackType}`);if(e.byteLength===this.wrapper.byteLength){this.wrapper.write(t);return}else R(`verbose`,()=>`Data size does not match tensor size. Releasing tensor.`),this.releaseTensor()}this.activeUpload?this.activeUpload.set(t):this.activeUpload=new Uint8Array(t)}async download(e){if(this.activeUpload){let t=this.wrapper?.isDataConverted?Jt(this.activeUpload,this.wrapper?.type):this.activeUpload;if(e){e instanceof ArrayBuffer?new Uint8Array(e).set(t):new Uint8Array(e.buffer,e.byteOffset,e.byteLength).set(t);return}else return t.buffer}if(!this.wrapper)throw Error(`Tensor has not been created.`);return e?this.wrapper.read(e):this.wrapper.read()}},tn=class{constructor(e){this.backend=e,this.tensorTrackersById=new Map,this.freeTensors=[],this.externalTensors=new Set}getMLContext(e){let t=this.backend.getMLContext(e);if(!t)throw Error(`MLContext not found for session.`);return t}getMLOpSupportLimits(e){return this.backend.getMLOpSupportLimits(e)}reserveTensorId(){let e=Xt();return this.tensorTrackersById.set(e,new en(this)),e}releaseTensorId(e){let t=this.tensorTrackersById.get(e);t&&(this.tensorTrackersById.delete(e),t.tensorWrapper&&this.releaseTensor(t.tensorWrapper))}async ensureTensor(e,t,n,r,i){R(`verbose`,()=>`[WebNN] TensorManager.ensureTensor {tensorId: ${t}, dataType: ${n}, shape: ${r}, copyOld: ${i}}`);let a=this.tensorTrackersById.get(t);if(!a)throw Error(`Tensor not found.`);return a.ensureTensor(e,n,r,i)}upload(e,t){let n=this.tensorTrackersById.get(e);if(!n)throw Error(`Tensor not found.`);n.upload(t)}async download(e,t){R(`verbose`,()=>`[WebNN] TensorManager.download {tensorId: ${e}, dstBuffer: ${t?.byteLength}}`);let n=this.tensorTrackersById.get(e);if(!n)throw Error(`Tensor not found.`);return n.download(t)}releaseTensorsForSession(e){for(let t of this.freeTensors)t.sessionId===e&&t.destroy();this.freeTensors=this.freeTensors.filter(t=>t.sessionId!==e)}registerTensor(e,t,n,r){let i=this.getMLContext(e),a=Xt(),o=new $t({sessionId:e,context:i,tensor:t,dataType:n,shape:r});return this.tensorTrackersById.set(a,new en(this,o)),this.externalTensors.add(o),a}async getCachedTensor(e,t,n,r,i,a,o){let s=this.getMLContext(e);for(let[r,i]of this.freeTensors.entries())if(i.canReuseTensor(s,t,n)){R(`verbose`,()=>`[WebNN] Reusing tensor {dataType: ${t}, ${o?`fallbackDataType: ${o},`:``} shape: ${n}`);let i=this.freeTensors.splice(r,1)[0];return i.sessionId=e,i}R(`verbose`,()=>`[WebNN] MLContext.createTensor {dataType: ${t}, ${o?`fallbackDataType: ${o},`:``} shape: ${n}}`);let c=await s.createTensor({dataType:o??t,shape:n,dimensions:n,usage:r,writable:i,readable:a});return new $t({sessionId:e,context:s,tensor:c,dataType:t,shape:n,fallbackDataType:o})}releaseTensor(e){this.externalTensors.has(e)&&this.externalTensors.delete(e),this.freeTensors.push(e)}},nn=(...e)=>new tn(...e)}),an,on,sn,cn=o(()=>{L(),ct(),Gt(),rn(),Lt(),an=new Map([[1,`float32`],[10,`float16`],[6,`int32`],[12,`uint32`],[7,`int64`],[13,`uint64`],[22,`int4`],[21,`uint4`],[3,`int8`],[2,`uint8`],[9,`uint8`]]),on=(e,t)=>{if(e===t)return!0;if(e===void 0||t===void 0)return!1;let n=Object.keys(e).sort(),r=Object.keys(t).sort();return n.length===r.length&&n.every((n,i)=>n===r[i]&&e[n]===t[n])},sn=class{constructor(e){this.tensorManager=nn(this),this.mlContextBySessionId=new Map,this.sessionIdsByMLContext=new Map,this.mlContextCache=[],this.sessionGraphInputs=new Map,this.sessionGraphOutputs=new Map,this.temporaryGraphInputs=[],this.temporaryGraphOutputs=[],this.temporarySessionTensorIds=new Map,this.mlOpSupportLimitsBySessionId=new Map,Ft(e.logLevel,!!e.debug)}get currentSessionId(){if(this.activeSessionId===void 0)throw Error(`No active session`);return this.activeSessionId}onRunStart(e){R(`verbose`,()=>`[WebNN] onRunStart {sessionId: ${e}}`),this.activeSessionId=e}onRunEnd(e){R(`verbose`,()=>`[WebNN] onRunEnd {sessionId: ${e}}`);let t=this.temporarySessionTensorIds.get(e);if(t){for(let e of t)R(`verbose`,()=>`[WebNN] releasing temporary tensor {tensorId: ${e}}`),this.tensorManager.releaseTensorId(e);this.temporarySessionTensorIds.delete(e),this.activeSessionId=void 0}}async createMLContext(e){if(e instanceof GPUDevice){let t=this.mlContextCache.findIndex(t=>t.gpuDevice===e);if(t!==-1)return this.mlContextCache[t].mlContext;{let t=await navigator.ml.createContext(e);return this.mlContextCache.push({gpuDevice:e,mlContext:t}),t}}else if(e===void 0){let e=this.mlContextCache.findIndex(e=>e.options===void 0&&e.gpuDevice===void 0);if(e!==-1)return this.mlContextCache[e].mlContext;{let e=await navigator.ml.createContext();return this.mlContextCache.push({mlContext:e}),e}}let t=this.mlContextCache.findIndex(t=>on(t.options,e));if(t!==-1)return this.mlContextCache[t].mlContext;{let t=await navigator.ml.createContext(e);return this.mlContextCache.push({options:e,mlContext:t}),t}}registerMLContext(e,t){this.mlContextBySessionId.set(e,t);let n=this.sessionIdsByMLContext.get(t);n||(n=new Set,this.sessionIdsByMLContext.set(t,n)),n.add(e),this.mlOpSupportLimitsBySessionId.has(e)||this.mlOpSupportLimitsBySessionId.set(e,t.opSupportLimits()),this.temporaryGraphInputs.length>0&&(this.sessionGraphInputs.set(e,this.temporaryGraphInputs),this.temporaryGraphInputs=[]),this.temporaryGraphOutputs.length>0&&(this.sessionGraphOutputs.set(e,this.temporaryGraphOutputs),this.temporaryGraphOutputs=[])}onReleaseSession(e){this.sessionGraphInputs.delete(e),this.sessionGraphOutputs.delete(e);let t=this.mlContextBySessionId.get(e);if(!t)return;this.tensorManager.releaseTensorsForSession(e),this.mlContextBySessionId.delete(e),this.mlOpSupportLimitsBySessionId.delete(e);let n=this.sessionIdsByMLContext.get(t);if(n.delete(e),n.size===0){this.sessionIdsByMLContext.delete(t);let e=this.mlContextCache.findIndex(e=>e.mlContext===t);e!==-1&&this.mlContextCache.splice(e,1)}}getMLContext(e){return this.mlContextBySessionId.get(e)}getMLOpSupportLimits(e){return this.mlOpSupportLimitsBySessionId.get(e)}reserveTensorId(){return this.tensorManager.reserveTensorId()}releaseTensorId(e){R(`verbose`,()=>`[WebNN] releaseTensorId {tensorId: ${e}}`),this.tensorManager.releaseTensorId(e)}async ensureTensor(e,t,n,r,i){let a=an.get(n);if(!a)throw Error(`Unsupported ONNX data type: ${n}`);return this.tensorManager.ensureTensor(e??this.currentSessionId,t,a,r,i)}async createTemporaryTensor(e,t,n){R(`verbose`,()=>`[WebNN] createTemporaryTensor {onnxDataType: ${t}, shape: ${n}}`);let r=an.get(t);if(!r)throw Error(`Unsupported ONNX data type: ${t}`);let i=this.tensorManager.reserveTensorId();await this.tensorManager.ensureTensor(e,i,r,n,!1);let a=this.temporarySessionTensorIds.get(e);return a?a.push(i):this.temporarySessionTensorIds.set(e,[i]),i}uploadTensor(e,t){if(!F().shouldTransferToMLTensor)throw Error(`Trying to upload to a MLTensor while shouldTransferToMLTensor is false`);R(`verbose`,()=>`[WebNN] uploadTensor {tensorId: ${e}, data: ${t.byteLength}}`),this.tensorManager.upload(e,t)}async downloadTensor(e,t){return this.tensorManager.download(e,t)}createMLTensorDownloader(e,t){return async()=>{let n=await this.tensorManager.download(e);return Wt(n,t)}}registerMLTensor(e,t,n,r){let i=an.get(n);if(!i)throw Error(`Unsupported ONNX data type: ${n}`);let a=this.tensorManager.registerTensor(e,t,i,r);return R(`verbose`,()=>`[WebNN] registerMLTensor {tensor: ${t}, dataType: ${i}, dimensions: ${r}} -> {tensorId: ${a}}`),a}registerMLConstant(e,t,n,r,i,a,o=!1){if(!a)throw Error(`External mounted files are not available.`);let s=e;e.startsWith(`./`)&&(s=e.substring(2));let c=a.get(s);if(!c)throw Error(`File with name ${s} not found in preloaded files.`);if(t+n>c.byteLength)throw Error(`Out of bounds: data offset and length exceed the external file data size.`);let l=c.slice(t,t+n).buffer,u;switch(i.dataType){case`float32`:u=new Float32Array(l);break;case`float16`:u=typeof Float16Array<`u`&&Float16Array.from?new Float16Array(l):new Uint16Array(l);break;case`int32`:u=new Int32Array(l);break;case`uint32`:u=new Uint32Array(l);break;case`int64`:if(o){let e=qt(new Uint8Array(l),`int64`);u=new Int32Array(e.buffer),i.dataType=`int32`}else u=new BigInt64Array(l);break;case`uint64`:u=new BigUint64Array(l);break;case`int8`:u=new Int8Array(l);break;case`int4`:case`uint4`:case`uint8`:u=new Uint8Array(l);break;default:throw Error(`Unsupported data type: ${i.dataType} in creating WebNN Constant from external data.`)}return R(`verbose`,()=>`[WebNN] registerMLConstant {dataType: ${i.dataType}, shape: ${i.shape}}} ${o?`(Note: it was int64 data type and registered to int32 as workaround)`:``}`),r.constant(i,u)}registerGraphInput(e){this.temporaryGraphInputs.push(e)}registerGraphOutput(e){this.temporaryGraphOutputs.push(e)}isGraphInput(e,t){let n=this.sessionGraphInputs.get(e);return n?n.includes(t):!1}isGraphOutput(e,t){let n=this.sessionGraphOutputs.get(e);return n?n.includes(t):!1}isGraphInputOutputTypeSupported(e,t,n=!0){let r=an.get(xt(t)),i=this.mlOpSupportLimitsBySessionId.get(e);return typeof r>`u`?!1:n?!!i?.input.dataTypes.includes(r):!!i?.output.dataTypes.includes(r)}flush(){}}}),ln=o(()=>{}),un,dn,fn,pn,mn,hn,gn,_n,vn,yn=o(()=>{Lt(),ln(),un=new Map([[64,250],[128,200],[256,200],[512,200],[2048,230],[4096,200],[8192,50],[16384,50],[32768,50],[65536,50],[131072,50],[262144,50],[524288,50],[1048576,50],[2097152,30],[4194304,20],[8388608,10],[12582912,10],[16777216,10],[26214400,15],[33554432,22],[44236800,2],[58982400,6],[67108864,6],[134217728,6],[167772160,6]]),dn=[],fn=e=>Math.ceil(Number(e)/16)*16,pn=e=>{for(let t=0;t<dn.length;t++){let n=dn[t];if(e<=n)return n}return Math.ceil(e/16)*16},mn=1,hn=()=>mn++,gn=async(e,t,n,r)=>{let i=fn(n),a=e.device.createBuffer({size:i,usage:GPUBufferUsage.COPY_DST|GPUBufferUsage.MAP_READ});try{let o=e.getCommandEncoder();e.endComputePass(),o.copyBufferToBuffer(t,0,a,0,i),e.flush(),await a.mapAsync(GPUMapMode.READ);let s=a.getMappedRange();if(r){let e=r();return e.set(new Uint8Array(s,0,n)),e}else return new Uint8Array(s.slice(0,n))}finally{a.destroy()}},_n=class{constructor(e){this.backend=e,this.storageCache=new Map,this.freeBuffers=new Map,this.freeUniformBuffers=new Map,this.buffersPending=[],this.capturedPendingBuffers=new Map;for(let[e]of un)dn.push(e),this.freeBuffers.set(e,[]),this.freeUniformBuffers.set(e,[]);this.sessionCount=0}upload(e,t){let n=t.buffer,r=t.byteOffset,i=t.byteLength,a=fn(i),o=this.storageCache.get(e);if(!o)throw Error(`gpu data for uploading does not exist`);if(Number(o.originalSize)!==i)throw Error(`inconsistent data size. gpu data size=${o.originalSize}, data size=${i}`);let s=this.backend.device.createBuffer({mappedAtCreation:!0,size:a,usage:GPUBufferUsage.MAP_WRITE|GPUBufferUsage.COPY_SRC}),c=s.getMappedRange();new Uint8Array(c).set(new Uint8Array(n,r,i)),s.unmap();let l=this.backend.device.createCommandEncoder();l.copyBufferToBuffer(s,0,o.gpuData.buffer,0,a),this.backend.device.queue.submit([l.finish()]),s.destroy(),R(`verbose`,()=>`[WebGPU] GpuDataManager.upload(id=${e})`)}memcpy(e,t){let n=this.storageCache.get(e);if(!n)throw Error(`source gpu data for memcpy does not exist`);let r=this.storageCache.get(t);if(!r)throw Error(`destination gpu data for memcpy does not exist`);if(n.originalSize!==r.originalSize)throw Error(`inconsistent source and destination gpu data size`);let i=fn(n.originalSize),a=this.backend.getCommandEncoder();this.backend.endComputePass(),a.copyBufferToBuffer(n.gpuData.buffer,0,r.gpuData.buffer,0,i)}registerExternalBuffer(e,t,n){let r;if(n){if(r=n[0],e===n[1])return R(`verbose`,()=>`[WebGPU] GpuDataManager.registerExternalBuffer(size=${t}) => id=${r}, buffer is the same, skip.`),r;if(this.backend.capturedCommandList.has(this.backend.currentSessionId))throw Error(`Registering a different external buffer under graph capture mode is not supported yet. | |
| Please use the previous external buffer!`)}else r=hn();return this.storageCache.set(r,{gpuData:{id:r,type:0,buffer:e},originalSize:t}),R(`verbose`,()=>`[WebGPU] GpuDataManager.registerExternalBuffer(size=${t}) => id=${r}, registered.`),r}unregisterExternalBuffer(e){e!==void 0&&(this.storageCache.delete(e),R(`verbose`,()=>`[WebGPU] GpuDataManager.unregisterExternalBuffer() => id=${e}`))}create(e,t=GPUBufferUsage.STORAGE|GPUBufferUsage.COPY_SRC|GPUBufferUsage.COPY_DST){let n=pn(e),r,i=(t&GPUBufferUsage.STORAGE)===GPUBufferUsage.STORAGE,a=(t&GPUBufferUsage.UNIFORM)===GPUBufferUsage.UNIFORM;if(i||a){let e=(i?this.freeBuffers:this.freeUniformBuffers).get(n);r=e&&e.length>0?e.pop():this.backend.device.createBuffer({size:n,usage:t})}else r=this.backend.device.createBuffer({size:n,usage:t});let o={id:hn(),type:0,buffer:r};return this.storageCache.set(o.id,{gpuData:o,originalSize:Number(e)}),R(`verbose`,()=>`[WebGPU] GpuDataManager.create(size=${e}) => id=${o.id}`),o}get(e){return this.storageCache.get(e)?.gpuData}release(e){let t=typeof e==`bigint`?Number(e):e,n=this.storageCache.get(t);if(!n){if(this.storageCache.size===0)return 0;throw Error(`releasing data does not exist`)}return R(`verbose`,()=>`[WebGPU] GpuDataManager.release(id=${t}), gpuDataId=${n.gpuData.id}`),this.storageCache.delete(t),this.buffersPending.push(n.gpuData.buffer),n.originalSize}async download(e,t){let n=this.storageCache.get(Number(e));if(!n)throw Error(`data does not exist`);await gn(this.backend,n.gpuData.buffer,n.originalSize,t)}refreshPendingBuffers(){if(this.buffersPending.length!==0)if(this.backend.sessionStatus===`default`){for(let e of this.buffersPending){let t=un.get(e.size);if((e.usage&GPUBufferUsage.STORAGE)===GPUBufferUsage.STORAGE){let n=this.freeBuffers.get(e.size)||[];t===void 0||n.length>=t?e.destroy():n.push(e)}else if((e.usage&GPUBufferUsage.UNIFORM)===GPUBufferUsage.UNIFORM){let n=this.freeUniformBuffers.get(e.size)||[];t===void 0||n.length>=t?e.destroy():n.push(e)}else e.destroy()}this.buffersPending=[]}else{let e=this.capturedPendingBuffers.get(this.backend.currentSessionId);e||(e=[],this.capturedPendingBuffers.set(this.backend.currentSessionId,e));for(let t of this.buffersPending)e.push(t);this.buffersPending=[]}}dispose(){this.freeBuffers.forEach(e=>{e.forEach(e=>{e.destroy()})}),this.freeUniformBuffers.forEach(e=>{e.forEach(e=>{e.destroy()})}),this.storageCache.forEach(e=>{e.gpuData.buffer.destroy()}),this.capturedPendingBuffers.forEach(e=>{e.forEach(e=>{e.destroy()})}),this.storageCache=new Map,this.freeBuffers=new Map,this.freeUniformBuffers=new Map,this.capturedPendingBuffers=new Map}onCreateSession(){this.sessionCount+=1}onReleaseSession(e){let t=this.capturedPendingBuffers.get(e);t&&(t.forEach(e=>{e.destroy()}),this.capturedPendingBuffers.delete(e)),--this.sessionCount,this.sessionCount===0&&(R(`warning`,()=>`[WebGPU] Clearing webgpu buffer cache`),this.storageCache.forEach(e=>{e.gpuData.buffer.destroy()}),this.storageCache=new Map)}},vn=(...e)=>new _n(...e)}),bn,V,H=o(()=>{bn=class{constructor(e){Object.assign(this,e)}get cacheKey(){return this.key||=Object.getOwnPropertyNames(this).sort().map(e=>`${this[e]}`).join(`;`),this.key}},V=e=>new bn(e)}),xn,Sn,U,Cn,W,G,wn,Tn,En,K,Dn,q,J,On,kn,An,jn,Y=o(()=>{L(),B(),xn=64,Sn=(e,t)=>{if(t===3)throw Error(`vec3 has same alignment as vec4, use vec4 instead`);switch(Number(e)){case 10:return t>1?`vec${t}<f16>`:`f16`;case 1:return t>1?`vec${t}<f32>`:`f32`;case 6:return t>1?`vec${t}<i32>`:`i32`;case 12:return t>1?`vec${t}<u32>`:`u32`;case 7:if(t>1)throw Error(`currently not supported vecX of uint64 yet`);return[`vec2<u32>`,`i32`];case 13:if(t>1)throw Error(`currently not supported vecX of uint64 yet`);return[`vec2<u32>`,`u32`];case 9:if(t!==4)throw Error(`bool must be vec4`);return[`u32`,`vec4<bool>`];case 22:return`i32`;case 21:return`u32`;default:throw Error(`Unknown data type: ${e}`)}},U=(e,t=1)=>{let n=Sn(e,t);return typeof n==`string`?n:n[0]},Cn=(e,t=1)=>{let n=Sn(e,t);return typeof n==`string`?n:n[1]},W=(...e)=>{let t=[];return e.forEach(e=>{e.length!==0&&t.push({type:12,data:e},{type:12,data:z.computeStrides(e)})}),t},G=e=>e%4==0?4:e%2==0?2:1,wn=(e=`f32`,t,n=`0`)=>!t||t===1?`${e}(${n})`:`vec${t}<${e}>(${n})`,Tn=(e,t,n)=>e===`f32`?n:t===1?`f32(${n})`:`vec${t}<f32>(${n})`,En=(e,t)=>t===4?`(${e}.x + ${e}.y + ${e}.z + ${e}.w)`:t===2?`(${e}.x + ${e}.y)`:t===3?`(${e}.x + ${e}.y + ${e}.z)`:e,K=(e,t,n,r)=>e.startsWith(`uniforms.`)&&n>4?typeof t==`string`?r===`f16`?`${e}[(${t}) / 8][(${t}) % 8 / 4][(${t}) % 8 % 4]`:`${e}[(${t}) / 4][(${t}) % 4]`:r===`f16`?`${e}[${Math.floor(t/8)}][${Math.floor(t%8/4)}][${t%8%4}]`:`${e}[${Math.floor(t/4)}][${t%4}]`:n>1?`${e}[${t}]`:e,Dn=(e,t,n,r,i)=>{let a=typeof n==`number`,o=a?n:n.length,s=[...Array(o).keys()],c=o<2?`u32`:o<=4?`vec${o}<u32>`:`array<u32, ${o}>`,l=Sn(t,i),u=typeof l==`string`?l:l[1],d={indices:c,value:u,storage:typeof l==`string`?l:l[0],tensor:t},f=e=>typeof e==`string`?e:`${e}u`,p={offsetToIndices:!1,indicesToOffset:!1,broadcastedIndicesToOffset:!1,set:!1,setByIndices:!1,get:!1,getByIndices:!1},m=a?`uniforms.`:``,h=`${m}${e}_shape`,g=`${m}${e}_strides`,_=``;for(let e=0;e<o-1;e++)_+=` | |
| let dim${e} = current / ${K(g,e,o)}; | |
| let rest${e} = current % ${K(g,e,o)}; | |
| indices[${e}] = dim${e}; | |
| current = rest${e}; | |
| `;_+=`indices[${o-1}] = current;`;let v=o<2?``:` | |
| fn o2i_${e}(offset: u32) -> ${d.indices} { | |
| var indices: ${d.indices}; | |
| var current = offset; | |
| ${_} | |
| return indices; | |
| }`,y=t=>(p.offsetToIndices=!0,o<2?t:`o2i_${e}(${t})`),b=[];if(o>=2)for(let e=o-1;e>=0;e--)b.push(`${K(g,e,o)} * (indices[${e}])`);let x=o<2?``:` | |
| fn i2o_${e}(indices: ${d.indices}) -> u32 { | |
| return ${b.join(`+`)}; | |
| }`,S=t=>(p.indicesToOffset=!0,o<2?t:`i2o_${e}(${t})`),C=(...e)=>o===0?`0u`:`${d.indices}(${e.map(f).join(`,`)})`,w=(e,t)=>o<2?`${e}`:`${K(e,t,o)}`,ee=(e,t,n)=>o<2?`${e}=${n};`:`${K(e,t,o)}=${n};`,T={},E=(t,n)=>{p.broadcastedIndicesToOffset=!0;let r=`${n.name}broadcastedIndicesTo${e}Offset`;if(r in T)return`${r}(${t})`;let i=[];for(let e=o-1;e>=0;e--){let t=n.indicesGet(`outputIndices`,e+n.rank-o);i.push(`${w(g,e)} * (${t} % ${w(h,e)})`)}return T[r]=`fn ${r}(outputIndices: ${n.type.indices}) -> u32 { | |
| return ${i.length>0?i.join(`+`):`0u`}; | |
| }`,`${r}(${t})`},D=(t,n)=>(()=>{if(d.storage===d.value)return`${e}[${t}]=${n};`;if(d.storage===`vec2<u32>`&&d.value===`i32`)return`${e}[${t}]=vec2<u32>(u32(${n}), select(0u, 0xFFFFFFFFu, ${n} < 0));`;if(d.storage===`vec2<u32>`&&d.value===`u32`)return`${e}[${t}]=vec2<u32>(u32(${n}), 0u);`;if(d.storage===`u32`&&d.value===`vec4<bool>`)return`${e}[${t}]=dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(${n}));`;throw Error(`not supported combination of storage type ${d.storage} and value type ${d.value} yet`)})(),O=t=>(()=>{if(d.storage===d.value)return`${e}[${t}]`;if(d.storage===`vec2<u32>`&&d.value===`i32`)return`i32(${e}[${t}].x)`;if(d.storage===`vec2<u32>`&&d.value===`u32`)return`u32(${e}[${t}].x)`;if(d.storage===`u32`&&d.value===`vec4<bool>`)return`vec4<bool>(bool(${e}[${t}] & 0xFFu), bool(${e}[${t}] & 0xFF00u), bool(${e}[${t}] & 0xFF0000u), bool(${e}[${t}] & 0xFF000000u))`;throw Error(`not supported combination of storage type ${d.storage} and value type ${d.value} yet`)})(),k=o<2?``:` | |
| fn get_${e}ByIndices(indices: ${d.indices}) -> ${u} { | |
| return ${O(`i2o_${e}(indices)`)}; | |
| }`,A=o<2?``:` | |
| fn get_${e}(${s.map(e=>`d${e}: u32`).join(`, `)}) -> ${u} { | |
| return get_${e}ByIndices(${C(s.map(e=>`d${e}`).join(`, `))}); | |
| }`,te=(...t)=>{if(t.length!==o)throw Error(`indices length must be ${o}`);let n=t.map(f).join(`,`);return o===0?O(`0u`):o===1?O(n[0]):(p.get=!0,p.getByIndices=!0,p.indicesToOffset=!0,`get_${e}(${n})`)},ne=t=>o<2?O(t):(p.getByIndices=!0,p.indicesToOffset=!0,`get_${e}ByIndices(${t})`),re=o<2?``:` | |
| fn set_${e}ByIndices(indices: ${d.indices}, value: ${u}) { | |
| ${D(`i2o_${e}(indices)`,`value`)} | |
| }`,ie=o<2?``:` | |
| fn set_${e}(${s.map(e=>`d${e}: u32`).join(`, `)}, value: ${u}) { | |
| set_${e}ByIndices(${C(s.map(e=>`d${e}`).join(`, `))}, value); | |
| }`;return{impl:()=>{let e=[],t=!1;return p.offsetToIndices&&(e.push(v),t=!0),p.indicesToOffset&&(e.push(x),t=!0),p.broadcastedIndicesToOffset&&(Object.values(T).forEach(t=>e.push(t)),t=!0),p.set&&(e.push(ie),t=!0),p.setByIndices&&(e.push(re),t=!0),p.get&&(e.push(A),t=!0),p.getByIndices&&(e.push(k),t=!0),!a&&t&&e.unshift(`const ${h} = ${d.indices}(${n.join(`,`)});`,`const ${g} = ${d.indices}(${z.computeStrides(n).join(`,`)});`),e.join(` | |
| `)},type:d,offsetToIndices:y,indicesToOffset:S,broadcastedIndicesToOffset:E,indices:C,indicesGet:w,indicesSet:ee,set:(...t)=>{if(t.length!==o+1)throw Error(`indices length must be ${o}`);let n=t[o];if(typeof n!=`string`)throw Error(`value must be string`);let r=t.slice(0,o).map(f).join(`,`);return o===0?D(`0u`,n):o===1?D(r[0],n):(p.set=!0,p.setByIndices=!0,p.indicesToOffset=!0,`set_${e}(${r}, ${n})`)},setByOffset:D,setByIndices:(t,n)=>o<2?D(t,n):(p.setByIndices=!0,p.indicesToOffset=!0,`set_${e}ByIndices(${t}, ${n});`),get:te,getByOffset:O,getByIndices:ne,usage:r,name:e,strides:g,shape:h,rank:o}},q=(e,t,n,r=1)=>Dn(e,t,n,`input`,r),J=(e,t,n,r=1)=>Dn(e,t,n,`output`,r),On=(e,t,n)=>Dn(e,t,n,`atomicOutput`,1),kn=(e,t,n,r=1)=>Dn(e,t,n,`internal`,r),An=class{constructor(e,t){this.normalizedDispatchGroup=e,this.limits=t,this.internalVariables=[],this.variables=[],this.uniforms=[],this.variableIndex=0}guardAgainstOutOfBoundsWorkgroupSizes(e){return`if (global_idx >= ${typeof e==`number`?`${e}u`:e}) { return; }`}mainStart(e=xn){let t=typeof e==`number`?e:e[0],n=typeof e==`number`?1:e[1],r=typeof e==`number`?1:e[2];if(t>this.limits.maxComputeWorkgroupSizeX||n>this.limits.maxComputeWorkgroupSizeY||r>this.limits.maxComputeWorkgroupSizeZ)throw Error(`workgroup size [${t}, ${n}, ${r}] exceeds the maximum workgroup size [${this.limits.maxComputeWorkgroupSizeX}, ${this.limits.maxComputeWorkgroupSizeY}, ${this.limits.maxComputeWorkgroupSizeZ}].`);if(t*n*r>this.limits.maxComputeInvocationsPerWorkgroup)throw Error(`workgroup size [${t}, ${n}, ${r}] exceeds the maximum workgroup invocations ${this.limits.maxComputeInvocationsPerWorkgroup}.`);let i=this.normalizedDispatchGroup[1]===1&&this.normalizedDispatchGroup[2]===1;return`@compute @workgroup_size(${t}, ${n}, ${r}) | |
| fn main(${i?`@builtin(global_invocation_id) global_id : vec3<u32>, | |
| @builtin(workgroup_id) workgroup_id : vec3<u32>, | |
| @builtin(local_invocation_index) local_idx : u32, | |
| @builtin(local_invocation_id) local_id : vec3<u32>`:`@builtin(global_invocation_id) global_id : vec3<u32>, | |
| @builtin(local_invocation_id) local_id : vec3<u32>, | |
| @builtin(local_invocation_index) local_idx : u32, | |
| @builtin(workgroup_id) workgroup_id : vec3<u32>, | |
| @builtin(num_workgroups) num_workgroups : vec3<u32>`}) { | |
| ${i?`let global_idx = global_id.x; | |
| let workgroup_index = workgroup_id.x;`:`let workgroup_index = workgroup_id.z * num_workgroups[0] * num_workgroups[1] + | |
| workgroup_id.y * num_workgroups[0] + workgroup_id.x; | |
| let global_idx = workgroup_index * ${t*n*r}u + local_idx;`} | |
| `}appendVariableUniforms(e){e.rank!==0&&(e.shape.startsWith(`uniforms.`)&&this.uniforms.push({name:e.shape.replace(`uniforms.`,``),type:`u32`,length:e.rank}),e.strides.startsWith(`uniforms.`)&&this.uniforms.push({name:e.strides.replace(`uniforms.`,``),type:`u32`,length:e.rank}))}declareVariable(e,t){if(e.usage===`internal`)throw Error(`cannot use internal variable with declareVariable(). use registerInternalVariables() instead.`);this.variables.push(e),this.appendVariableUniforms(e);let n=e.usage===`input`?`read`:`read_write`,r=e.usage===`atomicOutput`?`atomic<i32>`:e.type.storage;return`@group(0) @binding(${t}) var<storage, ${n}> ${e.name}: array<${r}>;`}declareVariables(...e){return e.map(e=>this.declareVariable(e,this.variableIndex++)).join(` | |
| `)}registerInternalVariable(e){if(e.usage!==`internal`)throw Error(`cannot use input or output variable with registerInternalVariable(). use declareVariables() instead.`);this.internalVariables.push(e),this.appendVariableUniforms(e)}registerInternalVariables(...e){return e.forEach(e=>this.registerInternalVariable(e)),this}registerUniform(e,t,n=1){return this.uniforms.push({name:e,type:t,length:n}),this}registerUniforms(e){return this.uniforms=this.uniforms.concat(e),this}uniformDeclaration(){if(this.uniforms.length===0)return``;let e=[];for(let{name:t,type:n,length:r}of this.uniforms)if(r&&r>4)n===`f16`?e.push(`@align(16) ${t}:array<mat2x4<${n}>, ${Math.ceil(r/8)}>`):e.push(`${t}:array<vec4<${n}>, ${Math.ceil(r/4)}>`);else{let i=r==null||r===1?n:`vec${r}<${n}>`;e.push(`${t}:${i}`)}return` | |
| struct Uniforms { ${e.join(`, `)} }; | |
| @group(0) @binding(${this.variableIndex}) var<uniform> uniforms: Uniforms;`}get additionalImplementations(){return this.uniformDeclaration()+this.variables.map(e=>e.impl()).join(` | |
| `)+this.internalVariables.map(e=>e.impl()).join(` | |
| `)}get variablesInfo(){if(this.uniforms.length===0)return;let e=e=>[12,10,1,6][[`u32`,`f16`,`f32`,`i32`].indexOf(e)];return this.uniforms.map(t=>[e(t.type),t.length??1])}},jn=(e,t)=>new An(e,t)}),Mn,Nn,Pn,Fn,In,Ln,Rn,zn,Bn,Vn=o(()=>{L(),B(),H(),Y(),Mn=(e,t)=>{if(!e||e.length!==1)throw Error(`Transpose requires 1 input.`);if(t.length!==0&&t.length!==e[0].dims.length)throw Error(`perm size ${t.length} does not match input rank ${e[0].dims.length}`)},Nn=(e,t)=>t.length===0?[...Array(e).keys()].reverse():t,Pn=(e,t)=>z.sortBasedOnPerm(e,Nn(e.length,t)),Fn=(e,t,n,r)=>{let i=`fn perm(i: ${r.type.indices}) -> ${n.type.indices} { | |
| var a: ${n.type.indices};`;for(let n=0;n<t;++n)i+=`a[${e[n]}]=i[${n}];`;return i+=`return a;}`},In=(e,t)=>{let n=[],r=[];for(let i=0;i<e.length;++i)e[i]!==1&&n.push(e[i]),e[t[i]]!==1&&r.push(t[i]);return{newShape:n,newPerm:r}},Ln=(e,t)=>{let n=0;for(let r=0;r<e.length;++r)if(t[e[r]]!==1){if(e[r]<n)return!1;n=e[r]}return!0},Rn=(e,t)=>{let n=e.dataType,r=e.dims.length,i=Nn(r,t),a=Pn(e.dims,i),o=e.dims,s=a,c=r<2||Ln(i,e.dims),l;if(c)return l=e=>{let t=q(`input`,n,o,4),r=J(`output`,n,s,4);return` | |
| ${e.registerUniform(`output_size`,`u32`).declareVariables(t,r)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.output_size`)} | |
| output[global_idx] = input[global_idx]; | |
| }`},{name:`TransposeCopy`,shaderCache:{inputDependencies:[`type`]},getRunData:()=>{let t=z.size(a);return{outputs:[{dims:a,dataType:e.dataType}],dispatchGroup:{x:Math.ceil(t/64/4)},programUniforms:[{type:12,data:Math.ceil(t/4)}]}},getShaderSource:l};let{newShape:u,newPerm:d}=In(e.dims,i),f=z.areEqual(d,[2,3,1]),p=z.areEqual(d,[3,1,2]);return u.length===2||f||p?(o=f?[u[0],u[1]*u[2]]:p?[u[0]*u[1],u[2]]:u,s=[o[1],o[0]],l=e=>{let t=q(`a`,n,o.length),r=J(`output`,n,s.length);return` | |
| ${e.registerUniform(`output_size`,`u32`).declareVariables(t,r)} | |
| var<workgroup> tile : array<array<${r.type.value}, 17>, 16>; | |
| ${e.mainStart([16,16,1])} | |
| let stride = (uniforms.output_shape[1] - 1) / 16 + 1; | |
| let workgroup_id_x = workgroup_index % stride; | |
| let workgroup_id_y = workgroup_index / stride; | |
| let input_col = workgroup_id_y * 16u + local_id.x; | |
| let input_row = workgroup_id_x * 16u + local_id.y; | |
| if (input_row < uniforms.a_shape[0] && input_col < uniforms.a_shape[1]) { | |
| tile[local_id.y][local_id.x] = ${t.getByIndices(`${t.type.indices}(input_row, input_col)`)}; | |
| } | |
| workgroupBarrier(); | |
| let output_col = workgroup_id_x * 16u + local_id.x; | |
| let output_row = workgroup_id_y * 16u + local_id.y; | |
| if (output_row < uniforms.output_shape[0] && output_col < uniforms.output_shape[1]) { | |
| ${r.setByIndices(`${r.type.indices}(output_row, output_col)`,`tile[local_id.x][local_id.y]`)} | |
| } | |
| }`},{name:`TransposeShared`,shaderCache:{inputDependencies:[`type`]},getRunData:()=>{let t=z.size(a);return{outputs:[{dims:a,dataType:e.dataType}],dispatchGroup:{x:Math.ceil(s[1]/16),y:Math.ceil(s[0]/16)},programUniforms:[{type:12,data:t},...W(o,s)]}},getShaderSource:l}):(l=e=>{let t=q(`a`,n,o.length),a=J(`output`,n,s.length);return` | |
| ${e.registerUniform(`output_size`,`u32`).declareVariables(t,a)} | |
| ${Fn(i,r,t,a)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.output_size`)} | |
| let indices = ${a.offsetToIndices(`global_idx`)}; | |
| let aIndices = perm(indices); | |
| ${a.setByOffset(`global_idx`,t.getByIndices(`aIndices`))} | |
| }`},{name:`Transpose`,shaderCache:{hint:`${t}`,inputDependencies:[`rank`]},getRunData:()=>{let t=z.size(a);return{outputs:[{dims:a,dataType:e.dataType}],dispatchGroup:{x:Math.ceil(t/64)},programUniforms:[{type:12,data:t},...W(o,s)]}},getShaderSource:l})},zn=(e,t)=>{Mn(e.inputs,t.perm),e.compute(Rn(e.inputs[0],t.perm))},Bn=e=>V({perm:e.perm})}),Hn,Un,Wn,Gn,Kn,qn,Jn,Yn,Xn,Zn,Qn,$n,er,tr,nr,rr,ir,ar,or,sr,cr,lr=o(()=>{L(),B(),Y(),Nr(),Vn(),Hn={max:`select(bestValue, candidate, candidate > bestValue)`,min:`select(bestValue, candidate, candidate < bestValue)`,mean:`bestValue + candidate`,sum:`bestValue + candidate`,prod:`bestValue * candidate`,sumSquare:`bestValue + candidate * candidate`,logSumExp:`bestValue + exp(candidate)`,l1:`bestValue + abs(candidate)`,l2:`bestValue + candidate * candidate`,logSum:`bestValue + candidate`},Un={max:`select(bestValue, candidate, candidate > bestValue)`,min:`select(bestValue, candidate, candidate < bestValue)`,mean:`bestValue + candidate`,sum:`bestValue + candidate`,prod:`bestValue * candidate`,sumSquare:`bestValue + candidate`,logSumExp:`bestValue + candidate`,l1:`bestValue + candidate`,l2:`bestValue + candidate`,logSum:`bestValue + candidate`},Wn={max:`_A[offset]`,min:`_A[offset]`,mean:`0`,sum:`0`,prod:`1`,sumSquare:`0`,logSumExp:`0`,l1:`0`,l2:`0`,logSum:`0`},Gn={max:`bestValue`,min:`bestValue`,sum:`bestValue`,prod:`bestValue`,sumSquare:`bestValue`,logSumExp:`log(bestValue)`,l1:`bestValue`,l2:`sqrt(bestValue)`,logSum:`log(bestValue)`},Kn=(e,t)=>{let n=[];for(let r=t-e;r<t;++r)n.push(r);return n},qn=(e,t)=>{let n=[],r=e.length;for(let i=0;i<r;i++)t.indexOf(i)===-1&&n.push(e[i]);return[n,t.map(t=>e[t])]},Jn=(e,t)=>{let n=e.length+t.length,r=[],i=0;for(let a=0;a<n;a++)t.indexOf(a)===-1?r.push(e[i++]):r.push(1);return r},Yn=(e,t)=>{for(let n=0;n<e.length;++n)if(e[e.length-n-1]!==t-1-n)return!1;return!0},Xn=(e,t)=>{let n=[];if(!Yn(e,t)){for(let r=0;r<t;++r)e.indexOf(r)===-1&&n.push(r);e.forEach(e=>n.push(e))}return n},Zn=(e,t,n,r,i,a,o)=>{let s=n[0].dims,c=z.size(a),l=z.size(o),u=q(`_A`,n[0].dataType,s),d=J(`output`,i,a),f=64;c===1&&(f=256);let p=` | |
| var<workgroup> aBestValues : array<f32, ${f}>; | |
| `;return{name:e,shaderCache:{hint:`${t};${f}`,inputDependencies:[`type`]},getShaderSource:e=>` | |
| ${e.registerUniform(`reduceSize`,`u32`).declareVariables(u,d)} | |
| ${p} | |
| fn DIV_CEIL(a : u32, b : u32) -> u32 { | |
| return ((a - 1u) / b + 1u); | |
| } | |
| ${e.mainStart(f)} | |
| let outputIndex = global_idx / ${f}; | |
| let offset = outputIndex * uniforms.reduceSize; | |
| var bestValue = f32(${Wn[r]}); | |
| let Length = uniforms.reduceSize; | |
| for (var k = local_idx; k < Length; k = k + ${f}) { | |
| let candidate = f32(${u.getByOffset(`offset + k`)}); | |
| bestValue = ${Hn[r]}; | |
| } | |
| aBestValues[local_idx] = bestValue; | |
| workgroupBarrier(); | |
| var reduceSize = min(Length, ${f}u); | |
| for (var currentSize = reduceSize / 2u; reduceSize > 1u; | |
| currentSize = reduceSize / 2u) { | |
| let interval = DIV_CEIL(reduceSize, 2u); | |
| if (local_idx < currentSize) { | |
| let candidate = aBestValues[local_idx + interval]; | |
| bestValue = ${Un[r]}; | |
| aBestValues[local_idx] = bestValue; | |
| } | |
| reduceSize = interval; | |
| workgroupBarrier(); | |
| } | |
| if (local_idx == 0u) { | |
| ${d.setByOffset(`outputIndex`,`${r===`mean`?`${d.type.storage}(bestValue / f32(uniforms.reduceSize))`:`${d.type.storage}(${Gn[r]})`}`)}; | |
| } | |
| }`,getRunData:()=>({outputs:[{dims:a,dataType:i}],dispatchGroup:{x:c},programUniforms:[{type:12,data:l}]})}},Qn=(e,t,n,r)=>{let i=e.inputs.length===1?n:pr(e.inputs,n),a=i.axes;a.length===0&&!i.noopWithEmptyAxes&&(a=e.inputs[0].dims.map((e,t)=>t));let o=z.normalizeAxes(a,e.inputs[0].dims.length),s=o,c=e.inputs[0],l=Xn(s,e.inputs[0].dims.length);l.length>0&&(c=e.compute(Rn(e.inputs[0],l),{inputs:[0],outputs:[-1]})[0],s=Kn(s.length,c.dims.length));let[u,d]=qn(c.dims,s),f=u;i.keepDims&&(f=Jn(u,o)),e.compute(Zn(t,i.cacheKey,[c],r,e.inputs[0].dataType,f,d),{inputs:[c]})},$n=(e,t)=>{Qn(e,`ReduceMeanShared`,t,`mean`)},er=(e,t)=>{Qn(e,`ReduceL1Shared`,t,`l1`)},tr=(e,t)=>{Qn(e,`ReduceL2Shared`,t,`l2`)},nr=(e,t)=>{Qn(e,`ReduceLogSumExpShared`,t,`logSumExp`)},rr=(e,t)=>{Qn(e,`ReduceMaxShared`,t,`max`)},ir=(e,t)=>{Qn(e,`ReduceMinShared`,t,`min`)},ar=(e,t)=>{Qn(e,`ReduceProdShared`,t,`prod`)},or=(e,t)=>{Qn(e,`ReduceSumShared`,t,`sum`)},sr=(e,t)=>{Qn(e,`ReduceSumSquareShared`,t,`sumSquare`)},cr=(e,t)=>{Qn(e,`ReduceLogSumShared`,t,`logSum`)}}),ur,dr,fr,pr,mr,hr,gr,_r,vr,yr,br,X,xr,Sr,Z,Cr,Q,wr,Tr,Er,Dr,Or,kr,Ar,jr,Mr,Nr=o(()=>{L(),B(),H(),Y(),lr(),ur=e=>{if(!e||e.length===0||e.length>2)throw Error(`Reduce op requires 1 or 2 inputs.`);if(e.length===2&&e[1].dims.length!==1)throw Error(`Invalid axes input dims.`)},dr=e=>[``,``,`var value = ${e.getByIndices(`input_indices`)};`,``],fr=(e,t,n,r,i,a,o=!1,s=!1)=>{let c=[],l=n[0].dims,u=l.length,d=z.normalizeAxes(i,u),f=!s&&d.length===0;l.forEach((e,t)=>{f||d.indexOf(t)>=0?o&&c.push(1):c.push(e)});let p=c.length,m=z.size(c);return{name:e,shaderCache:t,getShaderSource:e=>{let t=[],i=q(`_A`,n[0].dataType,u),s=J(`output`,a,p),c=r(i,s,d),m=c[2];for(let e=0,n=0;e<u;e++)f||d.indexOf(e)>=0?(o&&n++,m=`for(var j${e}: u32 = 0; j${e} < ${l[e]}; j${e}++) { | |
| ${c[2].includes(`last_index`)?`let last_index = j${e};`:``} | |
| ${i.indicesSet(`input_indices`,e,`j${e}`)} | |
| ${m} | |
| }`):(t.push(`${i.indicesSet(`input_indices`,e,s.indicesGet(`output_indices`,n))};`),n++);return` | |
| ${e.registerUniform(`output_size`,`u32`).declareVariables(i,s)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.output_size`)} | |
| var input_indices: ${i.type.indices}; | |
| let output_indices = ${s.offsetToIndices(`global_idx`)}; | |
| ${t.join(` | |
| `)} | |
| ${c[0]} // init ops for reduce max/min | |
| ${c[1]} | |
| ${m} | |
| ${c[3]} | |
| ${c.length===4?s.setByOffset(`global_idx`,`value`):c.slice(4).join(` | |
| `)} | |
| }`},getRunData:()=>({outputs:[{dims:c,dataType:a}],dispatchGroup:{x:Math.ceil(m/64)},programUniforms:[{type:12,data:m},...W(l,c)]})}},pr=(e,t)=>{let n=[];return e[1].dims[0]>0&&e[1].getBigInt64Array().forEach(e=>n.push(Number(e))),V({axes:n,keepDims:t.keepDims,noopWithEmptyAxes:t.noopWithEmptyAxes})},mr=(e,t,n,r)=>{let i=e.inputs,a=i.length===1?n:pr(i,n);e.compute(fr(t,{hint:a.cacheKey,inputDependencies:[`rank`]},[i[0]],a.noopWithEmptyAxes&&a.axes.length===0?dr:r,a.axes,i[0].dataType,a.keepDims,a.noopWithEmptyAxes),{inputs:[0]})},hr=(e,t)=>{ur(e.inputs),mr(e,`ReduceLogSum`,t,(e,t)=>[`var value = ${t.type.storage}(0);`,``,`value += ${e.getByIndices(`input_indices`)};`,`value = log(value);`])},gr=(e,t)=>{ur(e.inputs),mr(e,`ReduceL1`,t,(e,t)=>[`var value = ${t.type.storage}(0);`,``,`value += abs(${e.getByIndices(`input_indices`)});`,``])},_r=(e,t)=>{ur(e.inputs),mr(e,`ReduceL2`,t,(e,t)=>[`var t = ${t.type.value}(0); var value = ${t.type.value}(0);`,``,`t = ${e.getByIndices(`input_indices`)}; value += (t * t);`,`value = sqrt(value);`])},vr=(e,t)=>{ur(e.inputs),mr(e,`ReduceLogSumExp`,t,(e,t)=>[`var value = ${t.type.storage}(0);`,``,`value += exp(${e.getByIndices(`input_indices`)});`,`value = log(value);`])},yr=(e,t)=>{ur(e.inputs),mr(e,`ReduceMax`,t,(e,t,n)=>{let r=[];for(let t=0;t<e.rank;t++)(n.indexOf(t)>=0||n.length===0)&&r.push(e.indicesSet(`input_indices`,t,0));return[`${r.join(` | |
| `)}`,`var value = ${e.getByIndices(`input_indices`)};`,`value = max(value, ${e.getByIndices(`input_indices`)});`,``]})},br=(e,t)=>{ur(e.inputs),mr(e,`ReduceMean`,t,(t,n,r)=>{let i=1;for(let n=0;n<t.rank;n++)(r.indexOf(n)>=0||r.length===0)&&(i*=e.inputs[0].dims[n]);return[`var sum = f32(0);`,``,`sum += f32(${t.getByIndices(`input_indices`)});`,`let value = ${n.type.value}(sum / ${i});`]})},X=(e,t)=>{ur(e.inputs),mr(e,`ReduceMin`,t,(e,t,n)=>{let r=[];for(let t=0;t<e.rank;t++)(n.indexOf(t)>=0||n.length===0)&&r.push(`input_indices[${t}] = 0;`);return[`${r.join(` | |
| `)}`,`var value = ${e.getByIndices(`input_indices`)};`,`value = min(value, ${e.getByIndices(`input_indices`)});`,``]})},xr=(e,t)=>{ur(e.inputs),mr(e,`ReduceProd`,t,(e,t)=>[`var value = ${t.type.storage}(1);`,``,`value *= ${e.getByIndices(`input_indices`)};`,``])},Sr=(e,t)=>{ur(e.inputs),mr(e,`ReduceSum`,t,(e,t)=>[`var value = ${t.type.storage}(0);`,``,`value += ${e.getByIndices(`input_indices`)};`,``])},Z=(e,t)=>{ur(e.inputs),mr(e,`ReduceSumSquare`,t,(e,t)=>[`var t = ${t.type.value}(0); var value = ${t.type.value}(0);`,``,`t = ${e.getByIndices(`input_indices`)}; value += t * t;`,``])},Cr=(e,t,n)=>{if(t.length===0)return n;let r=1,i=1;for(let n=0;n<t.length;n++)t.indexOf(n)===-1?r*=e[n]:i*=e[n];return i<32&&r>1024},Q=(e,t)=>{Cr(e.inputs[0].dims,t.axes,t.noopWithEmptyAxes)?br(e,t):$n(e,t)},wr=(e,t)=>{Cr(e.inputs[0].dims,t.axes,t.noopWithEmptyAxes)?gr(e,t):er(e,t)},Tr=(e,t)=>{Cr(e.inputs[0].dims,t.axes,t.noopWithEmptyAxes)?_r(e,t):tr(e,t)},Er=(e,t)=>{Cr(e.inputs[0].dims,t.axes,t.noopWithEmptyAxes)?vr(e,t):nr(e,t)},Dr=(e,t)=>{Cr(e.inputs[0].dims,t.axes,t.noopWithEmptyAxes)?yr(e,t):rr(e,t)},Or=(e,t)=>{Cr(e.inputs[0].dims,t.axes,t.noopWithEmptyAxes)?X(e,t):ir(e,t)},kr=(e,t)=>{Cr(e.inputs[0].dims,t.axes,t.noopWithEmptyAxes)?xr(e,t):ar(e,t)},Ar=(e,t)=>{Cr(e.inputs[0].dims,t.axes,t.noopWithEmptyAxes)?Sr(e,t):or(e,t)},jr=(e,t)=>{Cr(e.inputs[0].dims,t.axes,t.noopWithEmptyAxes)?Z(e,t):sr(e,t)},Mr=(e,t)=>{Cr(e.inputs[0].dims,t.axes,t.noopWithEmptyAxes)?hr(e,t):cr(e,t)}}),Pr,Fr,Ir,Lr,Rr=o(()=>{L(),H(),Nr(),Pr=e=>{if(!e||e.length===0||e.length>2)throw Error(`ArgMinMaxOp op requires 1 or 2 inputs.`);if(e[0].dataType!==1)throw Error(`Invalid input type.`)},Fr=(e,t)=>{Pr(e.inputs),e.compute(fr(`ArgMin`,{hint:t.cacheKey,inputDependencies:[`rank`]},[e.inputs[0]],(e,n,r)=>{let i=[];for(let t=0;t<e.rank;t++)(r.indexOf(t)>=0||r.length===0)&&i.push(`input_indices[${t}] = 0;`);return[`${i.join(` | |
| `)}`,`var value = ${e.getByIndices(`input_indices`)}; | |
| var best_index : i32 = 0;`,`if (${e.getByIndices(`input_indices`)} ${t.selectLastIndex>0?`<=`:`<`} value) { | |
| value = ${e.getByIndices(`input_indices`)}; | |
| best_index = i32(last_index); | |
| }`,``,n.setByOffset(`global_idx`,`best_index`)]},[t.axis],7,t.keepDims),{inputs:[0]})},Ir=(e,t)=>{Pr(e.inputs),e.compute(fr(`argMax`,{hint:t.cacheKey,inputDependencies:[`rank`]},[e.inputs[0]],(e,n,r)=>{let i=[];for(let t=0;t<e.rank;t++)(r.indexOf(t)>=0||r.length===0)&&i.push(`input_indices[${t}] = 0;`);return[`${i.join(` | |
| `)}`,`var value = ${e.getByIndices(`input_indices`)}; | |
| var best_index : i32 = 0;`,`if (${e.getByIndices(`input_indices`)} ${t.selectLastIndex>0?`>=`:`>`} value) { | |
| value = ${e.getByIndices(`input_indices`)}; | |
| best_index = i32(last_index); | |
| }`,``,n.setByOffset(`global_idx`,`best_index`)]},[t.axis],7,t.keepDims),{inputs:[0]})},Lr=e=>V(e)}),zr,Br,Vr,Hr,Ur,Wr,Gr,Kr,qr=o(()=>{L(),B(),ln(),Y(),zr=(e,t)=>{let n=e[0],r=e[1],i=e[2],a=e[3],o=e[4],s=e[5];if(o&&s)throw Error(`Attention cannot have both past and attention_bias`);if(n.dims.length!==3)throw Error(`Input "input" must have 3 dimensions`);let c=n.dims[0],l=n.dims[1],u=n.dims[2];if(i.dims.length!==1)throw Error(`Input "bias" is expected to have 1 dimensions`);if(r.dims.length!==2)throw Error(`Input "weights" is expected to have 2 dimensions`);if(r.dims[0]!==u)throw Error(`Input 1 dimension 0 should have same length as dimension 2 of input 0`);if(i.dims[0]!==r.dims[1])throw Error(`Input "bias" dimension 0 should have same length as dimension 1 of input "weights"`);let d=i.dims[0]/3,f=d,p=f;if(t.qkvHiddenSizes.length>0){if(t.qkvHiddenSizes.length!==3)throw Error(`qkv_hidden_sizes attribute should have 3 elements`);for(let e of t.qkvHiddenSizes)if(e%t.numHeads!==0)throw Error(`qkv_hidden_sizes should be divisible by num_heads`);d=t.qkvHiddenSizes[0],f=t.qkvHiddenSizes[1],p=t.qkvHiddenSizes[2]}let m=l;if(d!==f)throw Error(`qkv_hidden_sizes first element should be same as the second`);if(i.dims[0]!==d+f+p)throw Error(`Input "bias" dimension 0 should have same length as sum of Q/K/V hidden sizes`);let h=0;if(o){if(f!==p)throw Error(`Input "past" expect k_hidden_size == v_hidden_size`);if(o.dims.length!==5)throw Error(`Input "past" must have 5 dimensions`);if(o.dims[0]!==2)throw Error(`Input "past" first dimension must be 2`);if(o.dims[1]!==c)throw Error(`Input "past" second dimension must be batch_size`);if(o.dims[2]!==t.numHeads)throw Error(`Input "past" third dimension must be num_heads`);if(o.dims[4]!==f/t.numHeads)throw Error(`Input "past" fifth dimension must be k_hidden_size / num_heads`);t.pastPresentShareBuffer||(h=o.dims[3])}let g=m+h;if(a)throw Error(`Mask not supported`);if(o)throw Error(`past is not supported`);if(s){if(s.dims.length!==4)throw Error(`Input "attention_bias" must have 4 dimensions`);if(s.dims[0]!==c||s.dims[1]!==t.numHeads||s.dims[2]!==l||s.dims[3]!==g)throw Error(`Expect "attention_bias" shape (batch_size, num_heads, sequence_length, total_sequence_length)`)}return{batchSize:c,sequenceLength:l,pastSequenceLength:h,kvSequenceLength:m,totalSequenceLength:g,maxSequenceLength:-1,inputHiddenSize:u,hiddenSize:d,vHiddenSize:p,headSize:Math.floor(d/t.numHeads),vHeadSize:Math.floor(p/t.numHeads),numHeads:t.numHeads,isUnidirectional:!1,pastPresentShareBuffer:!1,maskFilterValue:t.maskFilterValue,maskType:0,scale:t.scale,broadcastResPosBias:!1,passPastInKv:!1,qkvFormat:1}},Br=(e,t,n)=>t&&e?` | |
| let total_sequence_length_input = u32(${t.getByOffset(`0`)}); | |
| let present_sequence_length = max(total_sequence_length_input, uniforms.past_sequence_length); | |
| let is_subsequent_prompt: bool = sequence_length > 1 && sequence_length != total_sequence_length_input; | |
| let is_first_prompt: bool = is_subsequent_prompt == false && sequence_length == total_sequence_length_input; | |
| total_sequence_length = u32(${e?.getByOffset(`batchIdx`)}) + 1; | |
| var past_sequence_length: u32 = 0; | |
| if (is_first_prompt == false) { | |
| past_sequence_length = total_sequence_length - sequence_length; | |
| } | |
| `:` | |
| ${n?`let past_sequence_length = uniforms.past_sequence_length`:``}; | |
| let present_sequence_length = total_sequence_length; | |
| `,Vr=(e,t,n,r,i,a,o,s)=>{let c=G(o?1:a),l=64,u=a/c;u<l&&(l=32);let d=Math.ceil(a/c/l),f=[{type:12,data:t},{type:12,data:n},{type:12,data:r},{type:12,data:i},{type:12,data:u},{type:12,data:d}],p=U(e.dataType,c),m=Cn(1,c),h=[`type`];return o&&h.push(`type`),s&&h.push(`type`),{name:`AttentionProbsSoftmax`,shaderCache:{hint:`${l};${p};${c}`,inputDependencies:h},getShaderSource:t=>{let n=J(`x`,e.dataType,e.dims,c),r=[n],i=o?q(`seq_lens`,o.dataType,o.dims):void 0;i&&r.push(i);let a=s?q(`total_sequence_length_input`,s.dataType,s.dims):void 0;a&&r.push(a);let u=Cn(e.dataType);return` | |
| var<workgroup> thread_max: array<f32, ${l}>; | |
| var<workgroup> thread_sum: array<f32, ${l}>; | |
| ${t.registerUniforms([{name:`batch_size`,type:`u32`},{name:`num_heads`,type:`u32`},{name:`past_sequence_length`,type:`u32`},{name:`sequence_length`,type:`u32`},{name:`total_sequence_length`,type:`u32`},{name:`elements_per_thread`,type:`u32`}]).declareVariables(...r)} | |
| ${t.mainStart([l,1,1])} | |
| let batchIdx = workgroup_id.z / uniforms.num_heads; | |
| let headIdx = workgroup_id.z % uniforms.num_heads; | |
| let sequence_length = uniforms.sequence_length; | |
| var total_sequence_length = uniforms.total_sequence_length; | |
| ${Br(i,a,!1)} | |
| let local_offset = local_idx * uniforms.elements_per_thread; | |
| let offset = (global_idx / ${l}) * uniforms.total_sequence_length + local_offset; | |
| let seq_causal_length = ${o?`u32(past_sequence_length + workgroup_id.y + 1)`:`total_sequence_length`}; | |
| var thread_max_vector = ${m}(-3.4028234663852886e+38f); | |
| for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) { | |
| thread_max_vector = max(${m}(x[offset + i]), thread_max_vector); | |
| } | |
| thread_max[local_idx] = ${(()=>{switch(c){case 1:return`thread_max_vector`;case 2:return`max(thread_max_vector.x, thread_max_vector.y)`;case 4:return`max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))`;default:throw Error(`Unsupported components: ${c}`)}})()}; | |
| workgroupBarrier(); | |
| var max_value = f32(-3.4028234663852886e+38f); | |
| for (var i = 0u; i < ${l}; i++) { | |
| max_value = max(thread_max[i], max_value); | |
| } | |
| var sum_vector = ${m}(0); | |
| for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) { | |
| sum_vector += exp(${m}(x[offset + i]) - max_value); | |
| } | |
| thread_sum[local_idx] = ${(()=>{switch(c){case 1:return`sum_vector`;case 2:return`sum_vector.x + sum_vector.y`;case 4:return`sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w`;default:throw Error(`Unsupported components: ${c}`)}})()}; | |
| workgroupBarrier(); | |
| var sum: f32 = 0; | |
| for (var i = 0u; i < ${l}; i++) { | |
| sum += thread_sum[i]; | |
| } | |
| if (sum == 0) { | |
| for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) { | |
| x[offset + i] = ${n.type.value}(${u}(1.0) / ${u}(seq_causal_length)); | |
| } | |
| } else { | |
| for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) { | |
| var f32input = ${m}(x[offset + i]); | |
| x[offset + i] = ${n.type.value}(exp(f32input - max_value) / sum); | |
| } | |
| } | |
| ${o?` | |
| for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length; total_seq_id++) { | |
| x[offset + total_seq_id] = ${n.type.value}(${u}(0)); | |
| }`:``}; | |
| }`},getRunData:()=>({outputs:[],dispatchGroup:{x:1,y:i,z:t*n},programUniforms:f})}},Hr=(e,t,n,r,i,a,o,s,c)=>{let l=o+a.kvSequenceLength,u=[a.batchSize,a.numHeads,a.sequenceLength,l],d=e>1&&r,f=a.kvNumHeads?a.kvNumHeads:a.numHeads,p=d?[a.batchSize,f,l,a.headSize]:void 0,m=a.nReps?a.nReps:1,h=a.scale===0?1/Math.sqrt(a.headSize):a.scale,g=G(a.headSize),_=a.headSize/g,v={x:Math.ceil(l/12),y:Math.ceil(a.sequenceLength/12),z:a.batchSize*a.numHeads},y=[{type:12,data:a.sequenceLength},{type:12,data:_},{type:12,data:l},{type:12,data:a.numHeads},{type:12,data:a.headSize},{type:1,data:h},{type:12,data:o},{type:12,data:a.kvSequenceLength},{type:12,data:m}],b=d&&r&&z.size(r.dims)>0,x=[`type`,`type`];b&&x.push(`type`),i&&x.push(`type`),s&&x.push(`type`),c&&x.push(`type`);let S=[{dims:u,dataType:t.dataType,gpuDataType:0}];return d&&S.push({dims:p,dataType:t.dataType,gpuDataType:0}),{name:`AttentionProbs`,shaderCache:{hint:`${g};${i!==void 0};${r!==void 0};${e}`,inputDependencies:x},getRunData:()=>({outputs:S,dispatchGroup:v,programUniforms:y}),getShaderSource:e=>{let a=q(`q`,t.dataType,t.dims,g),o=[a,q(`key`,n.dataType,n.dims,g)];if(b){let e=q(`past_key`,r.dataType,r.dims,g);o.push(e)}i&&o.push(q(`attention_bias`,i.dataType,i.dims));let l=s?q(`seq_lens`,s.dataType,s.dims):void 0;l&&o.push(l);let f=c?q(`total_sequence_length_input`,c.dataType,c.dims):void 0;f&&o.push(f);let h=J(`output`,t.dataType,u),_=[h];d&&_.push(J(`present_key`,t.dataType,p,g));let v=Cn(1,g);return` | |
| const TILE_SIZE = 12u; | |
| var<workgroup> tileQ: array<${a.type.storage}, 144>; | |
| var<workgroup> tileK: array<${a.type.storage}, 144>; | |
| ${e.registerUniforms([{name:`M`,type:`u32`},{name:`K`,type:`u32`},{name:`N`,type:`u32`},{name:`num_heads`,type:`u32`},{name:`head_size`,type:`u32`},{name:`alpha`,type:`f32`},{name:`past_sequence_length`,type:`u32`},{name:`kv_sequence_length`,type:`u32`},{name:`n_reps`,type:`u32`}]).declareVariables(...o,..._)} | |
| ${e.mainStart([12,12,1])} | |
| // x holds the N and y holds the M | |
| let headIdx = workgroup_id.z % uniforms.num_heads; | |
| let kvHeadIdx = ${m===1?`headIdx`:`headIdx / uniforms.n_reps`}; | |
| let kv_num_heads = ${m===1?`uniforms.num_heads`:`uniforms.num_heads / uniforms.n_reps`}; | |
| let batchIdx = workgroup_id.z / uniforms.num_heads; | |
| let m = workgroup_id.y * TILE_SIZE; | |
| let n = workgroup_id.x * TILE_SIZE; | |
| let sequence_length = uniforms.M; | |
| var total_sequence_length = uniforms.N; | |
| ${Br(l,f,!0)} | |
| let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx; | |
| let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K; | |
| ${b&&d?`let pastKeyOffset = absKvHeadIdx * uniforms.past_sequence_length * uniforms.K;`:``}; | |
| let kOffset = absKvHeadIdx * uniforms.kv_sequence_length * uniforms.K; | |
| ${d?`let presentKeyOffset = absKvHeadIdx * uniforms.N * uniforms.K;`:``} | |
| var value = ${v}(0); | |
| for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { | |
| if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) { | |
| tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x]; | |
| } | |
| if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) { | |
| var idx = TILE_SIZE * local_id.y + local_id.x; | |
| ${b&&d?` | |
| if (n + local_id.y < past_sequence_length) { | |
| tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x]; | |
| } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) { | |
| tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x]; | |
| }`:` | |
| if (n + local_id.y < uniforms.kv_sequence_length) { | |
| tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x]; | |
| }`} | |
| ${d?`if (n + local_id.y < present_sequence_length) { | |
| present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx]; | |
| }`:``} | |
| } | |
| workgroupBarrier(); | |
| for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) { | |
| value += ${v}(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]); | |
| } | |
| workgroupBarrier(); | |
| } | |
| if (global_id.y < uniforms.M && global_id.x < total_sequence_length) { | |
| let headOffset = workgroup_id.z * uniforms.M * uniforms.N; | |
| let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x; | |
| var sum: f32 = ${(()=>{switch(g){case 1:return`value`;case 2:return`value.x + value.y`;case 4:return`value.x + value.y + value.z + value.w`;default:throw Error(`Unsupported components: ${g}`)}})()}; | |
| output[outputIdx] = ${h.type.value} (sum * uniforms.alpha) + ${i?`attention_bias[outputIdx]`:`0.0`}; | |
| } | |
| }`}}},Ur=(e,t,n,r,i,a,o=void 0,s=void 0)=>{let c=a+i.kvSequenceLength,l=i.nReps?i.nReps:1,u=i.vHiddenSize*l,d=e>1&&r,f=i.kvNumHeads?i.kvNumHeads:i.numHeads,p=d?[i.batchSize,f,c,i.headSize]:void 0,m=[i.batchSize,i.sequenceLength,u],h={x:Math.ceil(i.vHeadSize/12),y:Math.ceil(i.sequenceLength/12),z:i.batchSize*i.numHeads},g=[{type:12,data:i.sequenceLength},{type:12,data:c},{type:12,data:i.vHeadSize},{type:12,data:i.numHeads},{type:12,data:i.headSize},{type:12,data:u},{type:12,data:a},{type:12,data:i.kvSequenceLength},{type:12,data:l}],_=d&&r&&z.size(r.dims)>0,v=[`type`,`type`];_&&v.push(`type`),o&&v.push(`type`),s&&v.push(`type`);let y=[{dims:m,dataType:t.dataType,gpuDataType:0}];return d&&y.push({dims:p,dataType:t.dataType,gpuDataType:0}),{name:`AttentionScore`,shaderCache:{hint:`${r!==void 0};${e}`,inputDependencies:v},getRunData:()=>({outputs:y,dispatchGroup:h,programUniforms:g}),getShaderSource:e=>{let i=q(`probs`,t.dataType,t.dims),a=[i,q(`v`,n.dataType,n.dims)];_&&a.push(q(`past_value`,r.dataType,r.dims));let c=o?q(`seq_lens`,o.dataType,o.dims):void 0;o&&a.push(c);let u=s?q(`total_sequence_length_input`,s.dataType,s.dims):void 0;s&&a.push(u);let f=[J(`output`,t.dataType,m)];return d&&f.push(J(`present_value`,t.dataType,p)),` | |
| const TILE_SIZE = 12u; | |
| var<workgroup> tileQ: array<${i.type.value}, 144>; | |
| var<workgroup> tileV: array<${i.type.value}, 144>; | |
| ${e.registerUniforms([{name:`M`,type:`u32`},{name:`K`,type:`u32`},{name:`N`,type:`u32`},{name:`num_heads`,type:`u32`},{name:`head_size`,type:`u32`},{name:`v_hidden_size`,type:`u32`},{name:`past_sequence_length`,type:`u32`},{name:`kv_sequence_length`,type:`u32`},{name:`n_reps`,type:`u32`}]).declareVariables(...a,...f)} | |
| ${e.mainStart([12,12,1])} | |
| let headIdx = workgroup_id.z % uniforms.num_heads; | |
| let batchIdx = workgroup_id.z / uniforms.num_heads; | |
| let kvHeadIdx = ${l===1?`headIdx`:`headIdx / uniforms.n_reps`}; | |
| let kv_num_heads = ${l===1?`uniforms.num_heads`:`uniforms.num_heads / uniforms.n_reps`}; | |
| let m = global_id.y; | |
| let n = global_id.x; | |
| let sequence_length = uniforms.M; | |
| var total_sequence_length = uniforms.K; | |
| ${Br(c,u,!0)} | |
| let offsetA = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K; | |
| let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx; // kvHeadIdx is relative to the batch | |
| ${_&&d?`let pastValueOffset = absKvHeadIdx * uniforms.N * uniforms.past_sequence_length + n;`:``}; | |
| let vOffset = absKvHeadIdx * uniforms.N * uniforms.kv_sequence_length + n; | |
| ${d?`let presentValueOffset = absKvHeadIdx * uniforms.N * uniforms.K + n;`:``} | |
| var value = ${i.type.storage}(0); | |
| for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { | |
| if (m < uniforms.M && w + local_id.x < uniforms.K) { | |
| tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x]; | |
| } | |
| if (n < uniforms.N && w + local_id.y < uniforms.K) { | |
| var idx = TILE_SIZE * local_id.y + local_id.x; | |
| ${_&&d?` | |
| if (w + local_id.y < past_sequence_length) { | |
| tileV[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N]; | |
| } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) { | |
| tileV[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N]; | |
| } | |
| `:` | |
| if (w + local_id.y < uniforms.kv_sequence_length) { | |
| tileV[idx] = v[vOffset + (w + local_id.y) * uniforms.N]; | |
| }`} | |
| ${d?` | |
| if (w + local_id.y < present_sequence_length) { | |
| present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileV[idx]; | |
| }`:``} | |
| } | |
| workgroupBarrier(); | |
| for (var k: u32 = 0u; k < TILE_SIZE && w+k < total_sequence_length; k++) { | |
| value += tileQ[TILE_SIZE * local_id.y + k] * tileV[TILE_SIZE * k + local_id.x]; | |
| } | |
| workgroupBarrier(); | |
| } | |
| // we need to transpose output from BNSH_v to BSND_v | |
| if (m < uniforms.M && n < uniforms.N) { | |
| let outputIdx = batchIdx * uniforms.M * uniforms.v_hidden_size + m * uniforms.v_hidden_size | |
| + headIdx * uniforms.N + n; | |
| output[outputIdx] = value; | |
| } | |
| }`}}},Wr=(e,t,n,r,i,a,o,s,c,l,u=void 0,d=void 0)=>{let f=Math.min(e.outputCount,1+(o?1:0)+(s?1:0)),p=f>1?l.pastSequenceLength:0,m=p+l.kvSequenceLength,h=c&&z.size(c.dims)>0?c:void 0,g=[t,n];f>1&&o&&z.size(o.dims)>0&&g.push(o),h&&g.push(h),u&&g.push(u),d&&g.push(d);let _=e.compute(Hr(f,t,n,o,h,l,p,u,d),{inputs:g,outputs:f>1?[-1,1]:[-1]})[0];e.compute(Vr(_,l.batchSize,l.numHeads,p,l.sequenceLength,m,u,d),{inputs:u&&d?[_,u,d]:[_],outputs:[]});let v=[_,r];f>1&&s&&z.size(s.dims)>0&&v.push(s),u&&v.push(u),d&&v.push(d),e.compute(Ur(f,_,r,s,l,p,u,d),{inputs:v,outputs:f>1?[0,2]:[0]})},Gr=(e,t)=>{let n=[t.batchSize,t.numHeads,t.sequenceLength,t.headSize],r=t.sequenceLength,i=t.inputHiddenSize,a=t.headSize,o={x:Math.ceil(t.headSize/12),y:Math.ceil(t.sequenceLength/12),z:t.batchSize*t.numHeads},s=[e.inputs[0],e.inputs[1],e.inputs[2]],c=[{type:12,data:r},{type:12,data:i},{type:12,data:a},{type:12,data:t.numHeads},{type:12,data:t.headSize},{type:12,data:t.hiddenSize},{type:12,data:t.hiddenSize+t.hiddenSize+t.vHiddenSize}];return e.compute({name:`AttentionPrepare`,shaderCache:{inputDependencies:[`type`,`type`,`type`]},getRunData:()=>({outputs:[{dims:n,dataType:e.inputs[0].dataType,gpuDataType:0},{dims:n,dataType:e.inputs[0].dataType,gpuDataType:0},{dims:n,dataType:e.inputs[0].dataType,gpuDataType:0}],dispatchGroup:o,programUniforms:c}),getShaderSource:e=>{let t=J(`output_q`,s[0].dataType,n),r=J(`output_k`,s[0].dataType,n),i=J(`output_v`,s[0].dataType,n),a=q(`input`,s[0].dataType,s[0].dims),o=q(`weight`,s[1].dataType,s[1].dims),c=q(`bias`,s[2].dataType,s[2].dims),l=a.type.storage;return` | |
| const TILE_SIZE = 12u; | |
| var<workgroup> tileInput: array<${l}, 144>; | |
| var<workgroup> tileWeightQ: array<${l}, 144>; | |
| var<workgroup> tileWeightK: array<${l}, 144>; | |
| var<workgroup> tileWeightV: array<${l}, 144>; | |
| ${e.registerUniforms([{name:`M`,type:`u32`},{name:`K`,type:`u32`},{name:`N`,type:`u32`},{name:`num_heads`,type:`u32`},{name:`head_size`,type:`u32`},{name:`hidden_size`,type:`u32`},{name:`ldb`,type:`u32`}]).declareVariables(a,o,c,t,r,i)} | |
| ${e.mainStart([12,12,1])} | |
| let batchIndex = workgroup_id.z / uniforms.num_heads; | |
| let headNumber = workgroup_id.z % uniforms.num_heads; | |
| let m = global_id.y; | |
| let n = global_id.x; | |
| let inputOffset = batchIndex * (uniforms.M * uniforms.K) + m * uniforms.K; | |
| let biasOffsetQ = headNumber * uniforms.head_size; | |
| let biasOffsetK = uniforms.hidden_size + biasOffsetQ; | |
| let biasOffsetV = uniforms.hidden_size + biasOffsetK; | |
| var valueQ = ${l}(0); | |
| var valueK = ${l}(0); | |
| var valueV = ${l}(0); | |
| for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { | |
| if (m < uniforms.M && w + local_id.x < uniforms.K) { | |
| tileInput[TILE_SIZE * local_id.y + local_id.x] = input[inputOffset + w + local_id.x]; | |
| } | |
| if (n < uniforms.N && w + local_id.y < uniforms.K) { | |
| let offset = n + (w + local_id.y) * uniforms.ldb; | |
| tileWeightQ[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetQ + offset]; | |
| tileWeightK[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + offset]; | |
| tileWeightV[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetV + offset]; | |
| } | |
| workgroupBarrier(); | |
| for (var k: u32 = 0u; k<TILE_SIZE && w+k < uniforms.K; k++) { | |
| let inputTileOffset = TILE_SIZE * local_id.y + k; | |
| let weightTileOffset = TILE_SIZE * k + local_id.x; | |
| valueQ += tileInput[inputTileOffset] * tileWeightQ[weightTileOffset]; | |
| valueK += tileInput[inputTileOffset] * tileWeightK[weightTileOffset]; | |
| valueV += tileInput[inputTileOffset] * tileWeightV[weightTileOffset]; | |
| } | |
| workgroupBarrier(); | |
| } | |
| let headOffset = (m * uniforms.N + n) % uniforms.head_size; | |
| valueQ += bias[headOffset + biasOffsetQ]; | |
| valueK += bias[headOffset + biasOffsetK]; | |
| valueV += bias[headOffset + biasOffsetV]; | |
| let offset = workgroup_id.z * uniforms.M * uniforms.N; | |
| if (m < uniforms.M && n < uniforms.N) { | |
| let outputIdx = offset + m * uniforms.N + n; | |
| output_q[outputIdx] = valueQ; | |
| output_k[outputIdx] = valueK; | |
| output_v[outputIdx] = valueV; | |
| } | |
| }`}},{inputs:s,outputs:[-1,-1,-1]})},Kr=(e,t)=>{let n=zr(e.inputs,t),[r,i,a]=Gr(e,n);return Wr(e,r,i,a,e.inputs[4],void 0,void 0,void 0,e.inputs[5],n)}}),Jr,Yr,Xr,Zr,Qr=o(()=>{N(),L(),B(),H(),Y(),Jr=(e,t)=>{if(!e||e.length!==5)throw Error(`BatchNormalization requires 5 inputs`);let n=(e,t,n)=>{let r=t.length;if(r!==e.length)throw Error(`${n}: num dimensions != ${r}`);t.forEach((t,r)=>{if(t!==e[r])throw Error(`${n}: dim[${r}] do not match`)})};if(e[0].dims.length>1){let r=t.format===`NHWC`?t.spatial?e[0].dims.slice(-1):e[0].dims.slice(-1).concat(e[0].dims.slice(1,e[0].dims.length-1)):e[0].dims.slice(1,t.spatial?2:void 0);n(e[1].dims,r,`Invalid input scale`),n(e[2].dims,r,`Invalid input B`),n(e[3].dims,r,`Invalid input mean`),n(e[4].dims,r,`Invalid input var`)}else n(e[1].dims,[1],`Invalid input scale`),n(e[2].dims,[1],`Invalid input B`),n(e[3].dims,[1],`Invalid input mean`),n(e[4].dims,[1],`Invalid input var`)},Yr=(e,t)=>{let{epsilon:n,spatial:r,format:i}=t,a=e[0].dims,o=r?G(a[a.length-1]):1,s=i===`NHWC`&&a.length>1?o:1,c=z.size(a)/o,l=r,u=l?a.length:a,d=q(`x`,e[0].dataType,e[0].dims,o),f=q(`scale`,e[1].dataType,e[1].dims,s),p=q(`bias`,e[2].dataType,e[2].dims,s),m=q(`inputMean`,e[3].dataType,e[3].dims,s),h=q(`inputVar`,e[4].dataType,e[4].dims,s),g=J(`y`,e[0].dataType,u,o),_=()=>{let e=``;if(r)e=`let cOffset = ${a.length===1?`0u`:i===`NHWC`?`outputIndices[${a.length-1}] / ${o}`:`outputIndices[1]`};`;else if(i===`NCHW`)e=` | |
| ${g.indicesSet(`outputIndices`,`0`,`0`)} | |
| let cOffset = ${g.indicesToOffset(`outputIndices`)};`;else{e=`var cIndices = ${f.type.indices}(0); | |
| cIndices[0] = outputIndices[${a.length-1}];`;for(let t=1;t<f.rank;t++)e+=`cIndices[${t}] = outputIndices[${t}];`;e+=`let cOffset = ${f.indicesToOffset(`cIndices`)};`}return e};return{name:`BatchNormalization`,shaderCache:{hint:`${t.epsilon}_${t.format}_${r}_${o}`,inputDependencies:l?[`rank`,`type`,`type`,`type`,`type`]:void 0},getShaderSource:e=>` | |
| const epsilon = ${n}; | |
| ${e.registerUniform(`outputSize`,`u32`).declareVariables(d,f,p,m,h,g)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.outputSize`)} | |
| var outputIndices = ${g.offsetToIndices(`global_idx * ${o}`)}; | |
| ${_()} | |
| let scale = ${f.getByOffset(`cOffset`)}; | |
| let bias = ${p.getByOffset(`cOffset`)}; | |
| let inputMean = ${m.getByOffset(`cOffset`)}; | |
| let inputVar = ${h.getByOffset(`cOffset`)}; | |
| let x = ${d.getByOffset(`global_idx`)}; | |
| let value = (x - inputMean) * inverseSqrt(inputVar + epsilon) * scale + bias; | |
| ${g.setByOffset(`global_idx`,`value`)} | |
| }`,getRunData:()=>({outputs:[{dims:e[0].dims,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(c/64)},programUniforms:l?[{type:12,data:c},...W(a)]:[{type:12,data:c}]})}},Xr=e=>V(e),Zr=(e,t)=>{let{inputs:n,outputCount:r}=e,i=Xr({...t,outputCount:r});if(S.webgpu.validateInputContent&&Jr(n,i),t.trainingMode)throw Error(`BatchNormalization trainingMode is not supported yet.`);e.compute(Yr(n,i))}}),$r,ei,ti,ni=o(()=>{B(),Y(),$r=e=>{if(e[0].dims.length!==3)throw Error(`input should have 3 dimensions`);if(![320,640,1280].includes(e[0].dims[2]))throw Error(`number of channels should be 320, 640 or 1280`);if(e[1].dims.length!==1)throw Error(`bias is expected to have 1 dimensions`);if(e[0].dims[2]!==e[1].dims[0])throw Error(`last dimension of input and bias are not the same`)},ei=e=>{let t=e[0].dims,n=e[0].dims[2],r=z.size(t)/4,i=e[0].dataType,a=q(`input`,i,t,4),o=q(`bias`,i,[n],4),s=q(`residual`,i,t,4),c=J(`output`,i,t,4);return{name:`BiasAdd`,getRunData:()=>({outputs:[{dims:t,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(r/64)}}),getShaderSource:e=>` | |
| const channels = ${n}u / 4; | |
| ${e.declareVariables(a,o,s,c)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(r)} | |
| let value = ${a.getByOffset(`global_idx`)} | |
| + ${o.getByOffset(`global_idx % channels`)} + ${s.getByOffset(`global_idx`)}; | |
| ${c.setByOffset(`global_idx`,`value`)} | |
| }`}},ti=e=>{$r(e.inputs),e.compute(ei(e.inputs))}}),ri,$,ii,ai,oi,si,ci,li,ui,di,fi,pi,mi,hi,gi,_i,vi,yi,bi,xi,Si,Ci,wi,Ti,Ei,Di,Oi,ki,Ai,ji,Mi,Ni,Pi,Fi,Ii,Li,Ri,zi,Bi,Vi,Hi,Ui,Wi,Gi,Ki,qi=o(()=>{L(),B(),H(),Y(),ri=(e,t,n,r,i,a,o)=>{let s=Math.ceil(t/4),c=``;c=typeof i==`string`?`${i}(a)`:i(`a`);let l=q(`inputData`,n,[s],4),u=J(`outputData`,r,[s],4),d=[{name:`vec_size`,type:`u32`}];return o&&d.push(...o),` | |
| ${e.registerUniforms(d).declareVariables(l,u)} | |
| ${a??``} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.vec_size`)} | |
| let a = ${l.getByOffset(`global_idx`)}; | |
| ${u.setByOffset(`global_idx`,c)} | |
| }`},$=(e,t,n,r,i,a=e.dataType,o,s)=>{let c=[{type:12,data:Math.ceil(z.size(e.dims)/4)}];return o&&c.push(...o),{name:t,shaderCache:{hint:i,inputDependencies:[`type`]},getShaderSource:t=>ri(t,z.size(e.dims),e.dataType,a,n,r,s),getRunData:t=>({outputs:[{dims:e.dims,dataType:a}],dispatchGroup:{x:Math.ceil(z.size(t[0].dims)/64/4)},programUniforms:c})}},ii=e=>{e.compute($(e.inputs[0],`Abs`,`abs`))},ai=e=>{e.compute($(e.inputs[0],`Acos`,`acos`))},oi=e=>{e.compute($(e.inputs[0],`Acosh`,`acosh`))},si=e=>{e.compute($(e.inputs[0],`Asin`,`asin`))},ci=e=>{e.compute($(e.inputs[0],`Asinh`,`asinh`))},li=e=>{e.compute($(e.inputs[0],`Atan`,`atan`))},ui=e=>{e.compute($(e.inputs[0],`Atanh`,`atanh`))},di=e=>V(e),fi=(e,t)=>{let n;switch(t.to){case 10:n=`vec4<f16>`;break;case 1:n=`vec4<f32>`;break;case 12:n=`vec4<u32>`;break;case 6:n=`vec4<i32>`;break;case 9:n=`vec4<bool>`;break;default:throw RangeError(`not supported type (specified in attribute 'to' from 'Cast' operator): ${t.to}`)}e.compute($(e.inputs[0],`Cast`,n,void 0,t.cacheKey,t.to))},pi=e=>{let t,n,r=e.length>=2&&e[1].data!==0,i=e.length>=3&&e[2].data!==0;switch(e[0].dataType){case 1:t=r?e[1].getFloat32Array()[0]:-34028234663852886e22,n=i?e[2].getFloat32Array()[0]:34028234663852886e22;break;case 10:t=r?e[1].getUint16Array()[0]:64511,n=i?e[2].getUint16Array()[0]:31743;break;default:throw Error(`Unsupport data type`)}return V({min:t,max:n})},mi=(e,t)=>{let n=t||pi(e.inputs),r=Cn(e.inputs[0].dataType);e.compute($(e.inputs[0],`Clip`,e=>`clamp(${e}, vec4<${r}>(uniforms.min), vec4<${r}>(uniforms.max))`,void 0,n.cacheKey,void 0,[{type:e.inputs[0].dataType,data:n.min},{type:e.inputs[0].dataType,data:n.max}],[{name:`min`,type:r},{name:`max`,type:r}]),{inputs:[0]})},hi=e=>{e.compute($(e.inputs[0],`Ceil`,`ceil`))},gi=e=>{e.compute($(e.inputs[0],`Cos`,`cos`))},_i=e=>{e.compute($(e.inputs[0],`Cosh`,`cosh`))},vi=e=>V(e),yi=(e,t)=>{let n=Cn(e.inputs[0].dataType);e.compute($(e.inputs[0],`Elu`,e=>`elu_vf32(${e})`,` | |
| const elu_alpha_ = ${n}(${t.alpha}); | |
| fn elu_f32(a: ${n}) -> ${n} { | |
| return select((exp(a) - 1.0) * elu_alpha_, a, a >= 0.0); | |
| } | |
| fn elu_vf32(v: vec4<${n}>) -> vec4<${n}> { | |
| return vec4(elu_f32(v.x), elu_f32(v.y), elu_f32(v.z), elu_f32(v.w)); | |
| }`,t.cacheKey))},bi=(e=`f32`)=>` | |
| const r0: ${e} = 0.3275911; | |
| const r1: ${e} = 0.254829592; | |
| const r2: ${e} = -0.284496736; | |
| const r3: ${e} = 1.421413741; | |
| const r4: ${e} = -1.453152027; | |
| const r5: ${e} = 1.061405429; | |
| fn erf_vf32(v: vec4<${e}>) -> vec4<${e}> { | |
| let absv = abs(v); | |
| let x = 1.0 / (1.0 + r0 * absv); | |
| return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv)); | |
| }`,xi=e=>{let t=Cn(e.inputs[0].dataType);e.compute($(e.inputs[0],`Erf`,e=>`erf_vf32(${e})`,bi(t)))},Si=e=>{e.compute($(e.inputs[0],`Exp`,`exp`))},Ci=e=>{e.compute($(e.inputs[0],`Floor`,`floor`))},wi=e=>{let t=Cn(e.inputs[0].dataType);e.compute($(e.inputs[0],`Gelu`,e=>`0.5 * ${e} * (1.0 + erf_vf32(${e} * 0.7071067811865475))`,bi(t)))},Ti=(e,t)=>{let n=Cn(e.inputs[0].dataType);e.compute($(e.inputs[0],`LeakyRelu`,e=>`select(leaky_relu_alpha_ * ${e}, ${e}, ${e} >= vec4<${n}>(0.0))`,`const leaky_relu_alpha_ = ${n}(${t.alpha});`,t.cacheKey))},Ei=e=>{e.compute($(e.inputs[0],`Not`,e=>`!${e}`))},Di=e=>{e.compute($(e.inputs[0],`Neg`,e=>`-${e}`))},Oi=e=>{e.compute($(e.inputs[0],`Reciprocal`,e=>`1.0/${e}`))},ki=e=>{let t=Cn(e.inputs[0].dataType);e.compute($(e.inputs[0],`Relu`,e=>`select(vec4<${t}>(0.0), ${e}, ${e} > vec4<${t}>(0.0))`))},Ai=e=>{e.compute($(e.inputs[0],`Sigmoid`,e=>`(1.0 / (1.0 + exp(-${e})))`))},ji=e=>V(e),Mi=(e,t)=>{let n=Cn(e.inputs[0].dataType);e.compute($(e.inputs[0],`HardSigmoid`,e=>`max(vec4<${n}>(0.0), min(vec4<${n}>(1.0), ${t.alpha} * ${e} + vec4<${n}>(${t.beta})))`,void 0,t.cacheKey))},Ni=e=>{e.compute($(e.inputs[0],`Sin`,`sin`))},Pi=e=>{e.compute($(e.inputs[0],`Sinh`,`sinh`))},Fi=e=>{e.compute($(e.inputs[0],`Sqrt`,`sqrt`))},Ii=e=>{e.compute($(e.inputs[0],`Tan`,`tan`))},Li=e=>`sign(${e}) * (1 - exp(-2 * abs(${e}))) / (1 + exp(-2 * abs(${e})))`,Ri=e=>{e.compute($(e.inputs[0],`Tanh`,Li))},zi=(e=`f32`)=>` | |
| const fast_gelu_a: ${e} = 0.5; | |
| const fast_gelu_b: ${e} = 0.7978845608028654; | |
| const fast_gelu_c: ${e} = 0.035677408136300125; | |
| fn tanh_v(v: vec4<${e}>) -> vec4<${e}> { | |
| return ${Li(`v`)}; | |
| } | |
| `,Bi=e=>`(fast_gelu_a + fast_gelu_a * tanh_v(${e} * (fast_gelu_c * ${e} * ${e} + fast_gelu_b))) * ${e}`,Vi=e=>{let t=Cn(e.inputs[0].dataType);e.compute($(e.inputs[0],`FastGelu`,Bi,zi(t),void 0,e.inputs[0].dataType))},Hi=(e,t)=>{let n=Cn(e.inputs[0].dataType);return e.compute($(e.inputs[0],`ThresholdedRelu`,e=>`select(vec4<${n}>(0.0), ${e}, ${e} > thresholded_relu_alpha_)`,`const thresholded_relu_alpha_ = vec4<${n}>(${t.alpha});`,t.cacheKey)),0},Ui=e=>{e.compute($(e.inputs[0],`Log`,`log`))},Wi=(e,t)=>` | |
| const alpha = vec4<${e}>(${t}); | |
| const one = ${e}(1.0); | |
| const zero = ${e}(0.0); | |
| fn quick_gelu_impl(x: vec4<${e}>) -> vec4<${e}> { | |
| let v = x *alpha; | |
| var x1 : vec4<${e}>; | |
| for (var i = 0; i < 4; i = i + 1) { | |
| if (v[i] >= zero) { | |
| x1[i] = one / (one + exp(-v[i])); | |
| } else { | |
| x1[i] = one - one / (one + exp(v[i])); | |
| } | |
| } | |
| return x * x1; | |
| } | |
| `,Gi=e=>`quick_gelu_impl(${e})`,Ki=(e,t)=>{let n=Cn(e.inputs[0].dataType);e.compute($(e.inputs[0],`QuickGelu`,Gi,Wi(n,t.alpha),t.cacheKey,e.inputs[0].dataType))}}),Ji,Yi,Xi,Zi=o(()=>{B(),Y(),qi(),Ji=e=>{if(e[0].dims.length!==3)throw Error(`input should have 3 dimensions`);if(![2560,5120,10240].includes(e[0].dims[2]))throw Error(`hidden state should be 2560, 5120 or 10240`);if(e[1].dims.length!==1)throw Error(`bias is expected to have 1 dimensions`);if(e[0].dims[2]!==e[1].dims[0])throw Error(`last dimension of input and bias are not the same`)},Yi=e=>{let t=e[0].dims.slice();t[2]/=2;let n=q(`input`,e[0].dataType,e[0].dims,4),r=q(`bias`,e[0].dataType,[e[0].dims[2]],4),i=J(`output`,e[0].dataType,t,4),a=z.size(t)/4,o=U(e[0].dataType);return{name:`BiasSplitGelu`,getRunData:()=>({outputs:[{dims:t,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(a/64)}}),getShaderSource:t=>` | |
| const M_SQRT2 = sqrt(2.0); | |
| const halfChannels = ${e[0].dims[2]/4/2}u; | |
| ${t.declareVariables(n,r,i)} | |
| ${bi(o)} | |
| ${t.mainStart()} | |
| ${t.guardAgainstOutOfBoundsWorkgroupSizes(a)} | |
| let biasIdx = global_idx % halfChannels; | |
| let batchIndex = global_idx / halfChannels; | |
| let inputOffset = biasIdx + batchIndex * halfChannels * 2; | |
| let valueLeft = input[inputOffset] + bias[biasIdx]; | |
| let valueRight = input[inputOffset + halfChannels] + bias[biasIdx + halfChannels]; | |
| let geluRight = valueRight * 0.5 * (erf_vf32(valueRight / M_SQRT2) + 1); | |
| ${i.setByOffset(`global_idx`,`valueLeft * geluRight`)} | |
| }`}},Xi=e=>{Ji(e.inputs),e.compute(Yi(e.inputs))}}),Qi,$i,ea,ta,na,ra,ia,aa,oa,sa,ca,la,ua,da=o(()=>{L(),B(),Y(),Qi=(e,t,n,r,i,a,o,s,c,l,u,d)=>{let f,p;typeof s==`string`?f=p=(e,t)=>`${s}((${e}),(${t}))`:typeof s==`function`?f=p=s:(f=s.scalar,p=s.vector);let m=J(`outputData`,u,r.length,4),h=q(`aData`,c,t.length,4),g=q(`bData`,l,n.length,4),_;if(i)if(a){let e=z.size(t)===1,r=z.size(n)===1,i=t.length>0&&t[t.length-1]%4==0,a=n.length>0&&n[n.length-1]%4==0;_=e||r?m.setByOffset(`global_idx`,p(e?`${h.type.value}(${h.getByOffset(`0`)}.x)`:h.getByOffset(`global_idx`),r?`${g.type.value}(${g.getByOffset(`0`)}.x)`:g.getByOffset(`global_idx`))):` | |
| let outputIndices = ${m.offsetToIndices(`global_idx * 4u`)}; | |
| let offsetA = ${h.broadcastedIndicesToOffset(`outputIndices`,m)}; | |
| let offsetB = ${g.broadcastedIndicesToOffset(`outputIndices`,m)}; | |
| ${m.setByOffset(`global_idx`,p(o||i?h.getByOffset(`offsetA / 4u`):`${h.type.value}(${h.getByOffset(`offsetA / 4u`)}[offsetA % 4u])`,o||a?g.getByOffset(`offsetB / 4u`):`${g.type.value}(${g.getByOffset(`offsetB / 4u`)}[offsetB % 4u])`))} | |
| `}else _=m.setByOffset(`global_idx`,p(h.getByOffset(`global_idx`),g.getByOffset(`global_idx`)));else{if(!a)throw Error(`no necessary to use scalar implementation for element-wise binary op implementation.`);let e=(e,t,n=``)=>{let r=`aData[indexA${t}][componentA${t}]`,i=`bData[indexB${t}][componentB${t}]`;return` | |
| let outputIndices${t} = ${m.offsetToIndices(`global_idx * 4u + ${t}u`)}; | |
| let offsetA${t} = ${h.broadcastedIndicesToOffset(`outputIndices${t}`,m)}; | |
| let offsetB${t} = ${g.broadcastedIndicesToOffset(`outputIndices${t}`,m)}; | |
| let indexA${t} = offsetA${t} / 4u; | |
| let indexB${t} = offsetB${t} / 4u; | |
| let componentA${t} = offsetA${t} % 4u; | |
| let componentB${t} = offsetB${t} % 4u; | |
| ${e}[${t}] = ${n}(${f(r,i)}); | |
| `};_=u===9?` | |
| var data = vec4<u32>(0); | |
| ${e(`data`,0,`u32`)} | |
| ${e(`data`,1,`u32`)} | |
| ${e(`data`,2,`u32`)} | |
| ${e(`data`,3,`u32`)} | |
| outputData[global_idx] = dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(data));`:` | |
| ${e(`outputData[global_idx]`,0)} | |
| ${e(`outputData[global_idx]`,1)} | |
| ${e(`outputData[global_idx]`,2)} | |
| ${e(`outputData[global_idx]`,3)} | |
| `}return` | |
| ${e.registerUniform(`vec_size`,`u32`).declareVariables(h,g,m)} | |
| ${d??``} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.vec_size`)} | |
| ${_} | |
| }`},$i=(e,t,n,r,i,a,o=n.dataType)=>{let s=n.dims.map(Number),c=r.dims.map(Number),l=!z.areEqual(s,c),u=s,d=z.size(s),f=!1,p=!1,m=[l];if(l){let e=zt.calcShape(s,c,!1);if(!e)throw Error(`Can't perform binary op on the given tensors`);u=e.slice(),d=z.size(u);let t=z.size(s)===1,n=z.size(c)===1,r=s.length>0&&s[s.length-1]%4==0,i=c.length>0&&c[c.length-1]%4==0;m.push(t),m.push(n),m.push(r),m.push(i);let a=1;for(let e=1;e<u.length;e++){let t=s[s.length-e];if(t===c[c.length-e])a*=t;else break}a%4==0?(p=!0,f=!0):(t||n||r||i)&&(f=!0)}else f=!0;return m.push(f),{name:e,shaderCache:{hint:t+m.map(e=>e.toString()).join(`_`),inputDependencies:[`rank`,`rank`]},getShaderSource:e=>Qi(e,s,c,u,f,l,p,i,n.dataType,r.dataType,o,a),getRunData:()=>({outputs:[{dims:u,dataType:o}],dispatchGroup:{x:Math.ceil(d/64/4)},programUniforms:[{type:12,data:Math.ceil(z.size(u)/4)},...W(s,c,u)]})}},ea=(e,t,n,r,i,a)=>{e.compute($i(t,i??``,e.inputs[0],e.inputs[1],n,r,a))},ta=e=>{ea(e,`Add`,(e,t)=>`${e}+${t}`)},na=e=>{ea(e,`Div`,(e,t)=>`${e}/${t}`)},ra=e=>{ea(e,`Equal`,{scalar:(e,t)=>`u32(${e}==${t})`,vector:(e,t)=>`vec4<u32>(${e}==${t})`},void 0,void 0,9)},ia=e=>{ea(e,`Mul`,(e,t)=>`${e}*${t}`)},aa=e=>{let t=q(`input`,e.inputs[0].dataType,e.inputs[0].dims).type.value;ea(e,`Pow`,{scalar:(e,t)=>`pow_custom(${e},${t})`,vector:(e,t)=>`pow_vector_custom(${e},${t})`},` | |
| fn pow_custom(a : ${t}, b : ${t}) -> ${t} { | |
| if (b == ${t}(0.0)) { | |
| return ${t}(1.0); | |
| } else if (a < ${t}(0.0) && f32(b) != floor(f32(b))) { | |
| return ${t}(pow(f32(a), f32(b))); // NaN | |
| } | |
| return select(sign(a), ${t}(1.0), round(f32(abs(b) % ${t}(2.0))) != 1.0) * ${t}(${t===`i32`?`round`:``}(pow(f32(abs(a)), f32(b)))); | |
| } | |
| fn pow_vector_custom(a : vec4<${t}>, b : vec4<${t}>) -> vec4<${t}> { | |
| // TODO: implement vectorized pow | |
| return vec4<${t}>(pow_custom(a.x, b.x), pow_custom(a.y, b.y), pow_custom(a.z, b.z), pow_custom(a.w, b.w)); | |
| } | |
| `)},oa=e=>{ea(e,`Sub`,(e,t)=>`${e}-${t}`)},sa=e=>{ea(e,`Greater`,{scalar:(e,t)=>`u32(${e}>${t})`,vector:(e,t)=>`vec4<u32>(${e}>${t})`},void 0,void 0,9)},ca=e=>{ea(e,`Less`,{scalar:(e,t)=>`u32(${e}<${t})`,vector:(e,t)=>`vec4<u32>(${e}<${t})`},void 0,void 0,9)},la=e=>{ea(e,`GreaterOrEqual`,{scalar:(e,t)=>`u32(${e}>=${t})`,vector:(e,t)=>`vec4<u32>(${e}>=${t})`},void 0,void 0,9)},ua=e=>{ea(e,`LessOrEqual`,{scalar:(e,t)=>`u32(${e}<=${t})`,vector:(e,t)=>`vec4<u32>(${e}<=${t})`},void 0,void 0,9)}}),fa,pa,ma,ha,ga,_a,va=o(()=>{L(),B(),H(),Y(),fa=(e,t)=>{if(!e||e.length<1)throw Error(`too few inputs`);let n=e[0],r=n.dataType,i=n.dims.length;e.forEach((e,a)=>{if(a!==0){if(e.dataType!==r)throw Error(`input tensors should be one type`);if(e.dims.length!==i)throw Error(`input tensors should have the same shape`);e.dims.forEach((e,r)=>{if(r!==t&&e!==n.dims[r])throw Error(`non concat dimensions must match`)})}})},pa=(e,t)=>` | |
| fn calculateInputIndex(index: u32) -> u32 { | |
| let sizeInConcatAxis = array<u32, ${e}u>(${t}); | |
| for (var i: u32 = 0u; i < ${e}; i += 1u ) { | |
| if (index < sizeInConcatAxis[i]) { | |
| return i; | |
| } | |
| } | |
| return ${e}u; | |
| }`,ma=(e,t)=>{let n=e.length,r=[];for(let i=0;i<n;++i){let a=t.setByOffset(`global_idx`,e[i].getByIndices(`indices`));n===1?r.push(a):i===0?r.push(`if (inputIndex == ${i}u) { ${a} }`):i===n-1?r.push(`else { ${a} }`):r.push(`else if (inputIndex == ${i}) { ${a} }`)}return r.join(` | |
| `)},ha=(e,t,n,r)=>{let i=z.size(n),a=Array(e.length),o=Array(e.length),s=0,c=[],l=[],u=[{type:12,data:i}];for(let n=0;n<e.length;++n)s+=e[n].dims[t],a[n]=s,l.push(e[n].dims.length),o[n]=q(`input${n}`,r,l[n]),c.push(`rank`),u.push({type:12,data:a[n]});for(let t=0;t<e.length;++t)u.push(...W(e[t].dims));u.push(...W(n));let d=J(`output`,r,n.length),f=d.indicesGet(`indices`,t),p=Array.from(Array(a.length).keys()).map(e=>`uniforms.sizeInConcatAxis${e}`).join(`,`);return{name:`Concat`,shaderCache:{hint:`${t}`,inputDependencies:c},getRunData:()=>({outputs:[{dims:n,dataType:r}],dispatchGroup:{x:Math.ceil(i/64)},programUniforms:u}),getShaderSource:t=>` | |
| ${(()=>{t.registerUniform(`outputSize`,`u32`);for(let n=0;n<e.length;n++)t.registerUniform(`sizeInConcatAxis${n}`,`u32`);return t.declareVariables(...o,d)})()} | |
| ${pa(a.length,p)} | |
| ${t.mainStart()} | |
| ${t.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.outputSize`)} | |
| var indices = ${d.offsetToIndices(`global_idx`)}; | |
| let inputIndex = calculateInputIndex(${f}); | |
| if (inputIndex != 0u) { | |
| let sizeInConcatAxis = array<u32, ${a.length}u>(${p}); | |
| ${f} -= sizeInConcatAxis[inputIndex - 1u]; | |
| } | |
| ${ma(o,d)} | |
| }`}},ga=(e,t)=>{let n=e.inputs,r=n[0].dims,i=z.normalizeAxis(t.axis,r.length);fa(n,i);let a=r.slice();a[i]=n.reduce((e,t)=>e+(t.dims.length>i?t.dims[i]:0),0);let o=n.filter(e=>z.size(e.dims)>0);e.compute(ha(o,i,a,n[0].dataType),{inputs:o})},_a=e=>V({axis:e.axis})}),ya,ba,xa,Sa,Ca=o(()=>{L(),B(),ya=(e,t,n=`f32`)=>{switch(e.activation){case`Relu`:return`value = max(value, ${t}(0.0));`;case`Sigmoid`:return`value = (${t}(1.0) / (${t}(1.0) + exp(-value)));`;case`Clip`:return`value = clamp(value, ${t}(${n}(uniforms.clip_min)), ${t}(${n}(uniforms.clip_max)));`;case`HardSigmoid`:return`value = max(${t}(0.0), min(${t}(1.0), ${n}(uniforms.alpha) * value + ${n}(uniforms.beta)));`;case`LeakyRelu`:return`value = select(${n}(uniforms.alpha) * value, value, value >= ${t}(0.0));`;case`Tanh`:return`let e2x = exp(-2.0 * abs(value)); | |
| value = sign(value) * (1.0 - e2x) / (1.0 + e2x); | |
| `;case``:return``;default:throw Error(`Unsupported activation ${e.activation}`)}},ba=(e,t)=>{e.activation===`Clip`?t.push({type:1,data:e.clipMax},{type:1,data:e.clipMin}):e.activation===`HardSigmoid`?t.push({type:1,data:e.alpha},{type:1,data:e.beta}):e.activation===`LeakyRelu`&&t.push({type:1,data:e.alpha})},xa=(e,t)=>{e.activation===`Clip`?t.push({name:`clip_max`,type:`f32`},{name:`clip_min`,type:`f32`}):e.activation===`HardSigmoid`?t.push({name:`alpha`,type:`f32`},{name:`beta`,type:`f32`}):e.activation===`LeakyRelu`&&t.push({name:`alpha`,type:`f32`})},Sa=e=>{let t=e?.activation||``;if(t===`HardSigmoid`){let[n,r]=e?.activation_params||[.2,.5];return{activation:t,alpha:n,beta:r}}else if(t===`Clip`){let[n,r]=e?.activation_params||[Ht,Ut];return{activation:t,clipMax:r,clipMin:n}}else if(t===`LeakyRelu`){let[n]=e?.activation_params||[.01];return{activation:t,alpha:n}}return{activation:t}}}),wa,Ta,Ea=o(()=>{wa=(e,t)=>{switch(e){case 1:return t;case 2:return`vec2<${t}>`;case 3:return`vec3<${t}>`;case 4:return`vec4<${t}>`;default:throw Error(`${e}-component is not supported.`)}},Ta=e=>` | |
| ${e?`value = value + getBiasByOutputCoords(coords);`:``} | |
| `}),Da,Oa=o(()=>{Da=e=>` | |
| fn getIndexFromCoords4D(coords : vec4<i32>, shape : vec4<i32>) -> i32 { | |
| return dot(coords, vec4<i32>( | |
| shape.y * shape.z * shape.w, shape.z * shape.w, shape.w, 1)); | |
| } | |
| fn getOutputIndexFromCoords(coords : vec4<i32>) -> i32 { | |
| return dot(coords, vec4<i32>( | |
| i32(${e}.x), i32(${e}.y), i32(${e}.z), 1)); | |
| } | |
| `}),ka,Aa,ja=o(()=>{L(),B(),Y(),Ca(),ka=(e,t,n,r,i)=>{let a=r-n;return` | |
| ${Array.from({length:n}).map((n,o)=>` | |
| if (${K(t.shape,o,t.rank)} != 1) { | |
| ${t.indicesSet(e,o,K(i,o+a,r))} | |
| } else { | |
| ${t.indicesSet(e,o,0)} | |
| }`).join(``)} | |
| `},Aa=(e,t,n,r,i=!1,a)=>{let o=e[0].dims,s=e[1].dims,c=o[o.length-2],l=s[s.length-1],u=o[o.length-1],d=G(l),f=G(u),p=G(c),m=z.size(n)/d/p,h=e.length>2,g=r?r.slice(0,-2):n.slice(0,-2),_=[z.size(g),c,l],v=[{type:12,data:m},{type:12,data:c},{type:12,data:l},{type:12,data:u}];return ba(t,v),v.push(...W(g,o,s)),h&&v.push(...W(e[2].dims)),v.push(...W(_)),{name:`MatMulNaive`,shaderCache:{hint:`${t.activation};${d};${f};${p};${i}`,inputDependencies:h?[`rank`,`rank`,`rank`]:[`rank`,`rank`]},getRunData:()=>({outputs:[{dims:a?a(n):n,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(m/64)},programUniforms:v}),getShaderSource:r=>{let a=kn(`batch_dims`,e[0].dataType,g.length),c=q(`a`,e[0].dataType,o.length,f),l=q(`b`,e[1].dataType,s.length,d),u=J(`output`,e[0].dataType,_.length,d),m=U(u.type.tensor),v=ya(t,u.type.value,m),y=[c,l],b=``;if(h){let t=i?d:1;y.push(q(`bias`,e[2].dataType,e[2].dims.length,t)),b=`${i?`value += bias[col / ${t}];`:`value += ${u.type.value}(bias[row + i]);`}`}let x=[{name:`output_size`,type:`u32`},{name:`M`,type:`u32`},{name:`N`,type:`u32`},{name:`K`,type:`u32`}];xa(t,x);let S=()=>{let e=`var a_data: ${c.type.value};`;for(let t=0;t<f;t++)e+=` | |
| let b_data${t} = b[(b_offset + (k + ${t}) * uniforms.N + col) / ${d}];`;for(let t=0;t<p;t++){e+=`a_data = a[(a_offset + (row + ${t}) * uniforms.K + k) / ${f}];`;for(let n=0;n<f;n++)e+=` | |
| values[${t}] = fma(${l.type.value}(a_data${f===1?``:`[${n}]`}), b_data${n}, values[${t}]); | |
| `}return e};return` | |
| ${r.registerUniforms(x).registerInternalVariables(a).declareVariables(...y,u)} | |
| ${r.mainStart()} | |
| ${r.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.output_size`)} | |
| let col = (global_idx % (uniforms.N / ${d})) * ${d}; | |
| var index1 = global_idx / (uniforms.N / ${d}); | |
| let stride1 = uniforms.M / ${p}; | |
| let row = (index1 % stride1) * ${p}; | |
| let batch = index1 / stride1; | |
| ${n.length===2?``:`let batch_indices = ${a.offsetToIndices(`batch`)};`} | |
| var a_indices: ${c.type.indices}; | |
| ${ka(`a_indices`,c,c.rank-2,a.rank,`batch_indices`)} | |
| ${c.indicesSet(`a_indices`,c.rank-2,0)} | |
| ${c.indicesSet(`a_indices`,c.rank-1,0)} | |
| let a_offset = ${c.indicesToOffset(`a_indices`)}; | |
| var b_indices: ${l.type.indices}; | |
| ${ka(`b_indices`,l,l.rank-2,a.rank,`batch_indices`)} | |
| ${l.indicesSet(`b_indices`,l.rank-2,0)} | |
| ${l.indicesSet(`b_indices`,l.rank-1,0)} | |
| let b_offset = ${l.indicesToOffset(`b_indices`)}; | |
| var values: array<${u.type.value}, ${p}>; | |
| for (var k: u32 = 0u; k < uniforms.K; k = k + ${f}) { | |
| ${S()} | |
| } | |
| for (var i = 0u; i < ${p}u; i++) { | |
| var value = values[i]; | |
| ${b} | |
| ${v} | |
| let cur_indices = ${u.type.indices}(batch, row + i, col); | |
| let offset = ${u.indicesToOffset(`cur_indices`)}; | |
| ${u.setByOffset(`offset / ${d}`,`value`)}; | |
| } | |
| } | |
| `}}}}),Ma,Na,Pa,Fa,Ia,La,Ra,za,Ba=o(()=>{L(),B(),Y(),Ca(),ja(),Ea(),Ma=(e,t)=>e?` | |
| mm_Asub[inputRow][inputCol] = mm_readA(batch, | |
| kStart + inputRow, | |
| globalRowStart / innerElementSize + inputCol${t?`, batchIndices`:``}); | |
| `:` | |
| mm_Asub[inputRow][inputCol] = mm_readA(batch, | |
| globalRow + innerRow, | |
| kStart / innerElementSize + inputCol${t?`, batchIndices`:``}); | |
| `,Na=(e,t)=>e?` | |
| let ACached0 = mm_Asub[k * innerElementSize][localRow]; | |
| let ACached1 = mm_Asub[k * innerElementSize + 1][localRow]; | |
| let ACached2 = mm_Asub[k * innerElementSize + 2][localRow]; | |
| ${t===3?``:`let ACached3 = mm_Asub[k * innerElementSize + 3][localRow];`} | |
| for (var i = 0; i < rowPerThread; i = i + 1) { | |
| acc[i] = BCached0 * ACached0[i] + acc[i]; | |
| acc[i] = BCached1 * ACached1[i] + acc[i]; | |
| acc[i] = BCached2 * ACached2[i] + acc[i]; | |
| ${t===3?``:`acc[i] = BCached3 * ACached3[i] + acc[i];`} | |
| }`:` | |
| for (var i = 0; i < rowPerThread; i = i + 1) { | |
| let ACached = mm_Asub[tileRow + i][k]; | |
| acc[i] = BCached0 * ACached.x + acc[i]; | |
| acc[i] = BCached1 * ACached.y + acc[i]; | |
| acc[i] = BCached2 * ACached.z + acc[i]; | |
| ${t===3?``:`acc[i] = BCached3 * ACached.w + acc[i];`} | |
| }`,Pa=(e,t,n=`f32`,r,i=!1,a=32,o=!1,s=32)=>{let c=t[1]*e[1],l=t[0]*e[0],u=i?c:a,d=i?a:c,f=u/t[0],p=a/t[1];if(!((i&&f===4&&e[1]===4||!i&&(f===3||f===4))&&u%t[0]===0&&a%t[1]===0&&e[0]===4))throw Error(`If transposeA ${i} is true, innerElementSize ${f} and workPerThread[1] ${e[1]} must be 4. | |
| Otherwise, innerElementSize ${f} must be 3 or 4. | |
| tileAWidth ${u} must be divisible by workgroupSize[0]${t[0]}. tileInner ${a} must be divisible by workgroupSize[1] ${t[1]}. colPerThread ${e[0]} must be 4.`);return` | |
| var<workgroup> mm_Asub: array<array<vec${f}<${n}>, ${u/f}>, ${d}>; | |
| var<workgroup> mm_Bsub: array<array<vec4<${n}>, ${l/e[0]}>, ${a}>; | |
| const rowPerThread = ${e[1]}; | |
| const colPerThread = ${e[0]}; | |
| const innerElementSize = ${f}; | |
| const tileInner = ${a}; | |
| @compute @workgroup_size(${t[0]}, ${t[1]}, ${t[2]}) | |
| fn main(@builtin(local_invocation_id) localId : vec3<u32>, | |
| @builtin(global_invocation_id) globalId : vec3<u32>, | |
| @builtin(workgroup_id) workgroupId : vec3<u32>) { | |
| let localRow = i32(localId.y); | |
| let tileRow = localRow * rowPerThread; | |
| let tileCol = i32(localId.x); | |
| let globalRow =i32(globalId.y) * rowPerThread; | |
| let globalCol = i32(globalId.x); | |
| let batch = ${o?`0`:`i32(globalId.z)`}; | |
| ${r?`let batchIndices = ${r.offsetToIndices(`u32(batch)`)};`:``} | |
| let globalRowStart = i32(workgroupId.y) * ${c}; | |
| let num_tiles = ${o?`${Math.ceil(s/a)}`:`(uniforms.dim_inner - 1) / tileInner + 1`}; | |
| var kStart = ${o?`i32(globalId.z) * ${s}`:`0`}; | |
| var acc: array<vec4<${n}>, rowPerThread>; | |
| // Loop over shared dimension. | |
| let tileRowB = localRow * ${p}; | |
| for (var t = 0; t < num_tiles; t = t + 1) { | |
| // Load one tile of A into local memory. | |
| for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { | |
| let inputRow = tileRow + innerRow; | |
| let inputCol = tileCol; | |
| ${Ma(i,r)} | |
| } | |
| // Load one tile of B into local memory. | |
| for (var innerRow = 0; innerRow < ${p}; innerRow = innerRow + 1) { | |
| let inputRow = tileRowB + innerRow; | |
| let inputCol = tileCol; | |
| mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalCol${r?`, batchIndices`:``}); | |
| } | |
| kStart = kStart + tileInner; | |
| workgroupBarrier(); | |
| // Compute acc values for a single thread. | |
| for (var k = 0; k < tileInner / innerElementSize; k = k + 1) { | |
| let BCached0 = mm_Bsub[k * innerElementSize][tileCol]; | |
| let BCached1 = mm_Bsub[k * innerElementSize + 1][tileCol]; | |
| let BCached2 = mm_Bsub[k * innerElementSize + 2][tileCol]; | |
| ${f===3?``:`let BCached3 = mm_Bsub[k * innerElementSize + 3][tileCol];`} | |
| ${Na(i,f)} | |
| } | |
| workgroupBarrier(); | |
| } | |
| for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { | |
| mm_write(batch, globalRow + innerRow, globalCol, acc[innerRow]); | |
| } | |
| }`},Fa=(e,t)=>e?` | |
| mm_Asub[inputRow][inputCol] = mm_readA(batch, | |
| kStart + inputRow, | |
| globalRowStart + inputCol${t?`, batchIndices`:``}); | |
| `:` | |
| mm_Asub[inputRow][inputCol] = mm_readA(batch, | |
| globalRowStart + inputRow, | |
| kStart + inputCol${t?`, batchIndices`:``}); | |
| `,Ia=e=>e?`let ACached = mm_Asub[k][tileRow + innerRow];`:`let ACached = mm_Asub[tileRow + innerRow][k];`,La=(e,t,n=`f32`,r,i=!1,a=32,o=!1,s=32,c=!1)=>{let l=e[1]*t[1],u=e[0]*t[0],d=i?l:a,f=i?a:l;if(!(f%t[1]===0&&d%t[0]===0&&a%t[1]===0))throw Error(`tileAHight ${f} must be divisible by workgroupSize[1]${t[1]}, tileAWidth ${d} must be divisible by workgroupSize[0]${t[0]}, tileInner ${a} must be divisible by workgroupSize[1]${t[1]}`);let p=f/t[1],m=d/t[0],h=a/t[1],g=c?` | |
| let localRow = i32(localId.y); | |
| let localCol = i32(localId.x); | |
| let globalRowStart = i32(workgroupId.y) * ${l}; | |
| let globalColStart = i32(workgroupId.x) * ${u}; | |
| // Loop over shared dimension. | |
| for (var t = 0; t < num_tiles; t = t + 1) { | |
| // Load one tile of A into local memory. | |
| for (var inputRow = localRow; inputRow < ${f}; inputRow = inputRow + ${t[1]}) { | |
| for (var inputCol = localCol; inputCol < ${d}; inputCol = inputCol + ${t[0]}) { | |
| ${Fa(i,r)} | |
| } | |
| } | |
| // Load one tile of B into local memory. | |
| for (var inputRow = localRow; inputRow < ${a}; inputRow = inputRow + ${t[1]}) { | |
| for (var inputCol = localCol; inputCol < ${u}; inputCol = inputCol + ${t[0]}) { | |
| mm_Bsub[inputRow][inputCol] = mm_readB(batch, | |
| kStart + inputRow, | |
| globalColStart + inputCol${r?`, batchIndices`:``}); | |
| } | |
| } | |
| kStart = kStart + tileInner; | |
| workgroupBarrier(); | |
| // Compute acc values for a single thread. | |
| var BCached : array<${n}, colPerThread>; | |
| for (var k = 0; k < tileInner; k = k + 1) { | |
| for (var inner = 0; inner < colPerThread; inner = inner + 1) { | |
| BCached[inner] = mm_Bsub[k][localCol + inner * ${t[0]}]; | |
| } | |
| for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { | |
| let ACached = ${i?`mm_Asub[k][localRow + innerRow * ${t[1]}];`:`mm_Asub[localRow + innerRow * ${t[1]}][k];`} | |
| for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) { | |
| acc[innerRow][innerCol] = acc[innerRow][innerCol] + | |
| ACached * BCached[innerCol]; | |
| } | |
| } | |
| } | |
| workgroupBarrier(); | |
| } | |
| for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { | |
| let gRow = globalRowStart + localRow + innerRow * ${t[1]}; | |
| for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) { | |
| let gCol = globalColStart + localCol + innerCol * ${t[0]}; | |
| mm_write(batch, gRow, gCol, acc[innerRow][innerCol]); | |
| } | |
| } | |
| `:` | |
| let tileRow = i32(localId.y) * rowPerThread; | |
| let tileCol = i32(localId.x) * colPerThread; | |
| let globalRow = i32(globalId.y) * rowPerThread; | |
| let globalCol = i32(globalId.x) * colPerThread; | |
| let globalRowStart = i32(workgroupId.y) * ${l}; | |
| let tileRowA = i32(localId.y) * ${p}; | |
| let tileColA = i32(localId.x) * ${m}; | |
| let tileRowB = i32(localId.y) * ${h}; | |
| // Loop over shared dimension. | |
| for (var t = 0; t < num_tiles; t = t + 1) { | |
| // Load one tile of A into local memory. | |
| for (var innerRow = 0; innerRow < ${p}; innerRow = innerRow + 1) { | |
| for (var innerCol = 0; innerCol < ${m}; innerCol = innerCol + 1) { | |
| let inputRow = tileRowA + innerRow; | |
| let inputCol = tileColA + innerCol; | |
| ${Fa(i,r)} | |
| } | |
| } | |
| // Load one tile of B into local memory. | |
| for (var innerRow = 0; innerRow < ${h}; innerRow = innerRow + 1) { | |
| for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) { | |
| let inputRow = tileRowB + innerRow; | |
| let inputCol = tileCol + innerCol; | |
| mm_Bsub[inputRow][inputCol] = mm_readB(batch, | |
| kStart + inputRow, | |
| globalCol + innerCol${r?`, batchIndices`:``}); | |
| } | |
| } | |
| kStart = kStart + tileInner; | |
| workgroupBarrier(); | |
| // Compute acc values for a single thread. | |
| var BCached : array<${n}, colPerThread>; | |
| for (var k = 0; k < tileInner; k = k + 1) { | |
| for (var inner = 0; inner < colPerThread; inner = inner + 1) { | |
| BCached[inner] = mm_Bsub[k][tileCol + inner]; | |
| } | |
| for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { | |
| ${Ia(i)} | |
| for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) { | |
| acc[innerRow][innerCol] = acc[innerRow][innerCol] + ACached * BCached[innerCol]; | |
| } | |
| } | |
| } | |
| workgroupBarrier(); | |
| } | |
| for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { | |
| for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) { | |
| mm_write(batch, globalRow + innerRow, globalCol + innerCol, | |
| acc[innerRow][innerCol]); | |
| } | |
| } | |
| `;return` | |
| var<workgroup> mm_Asub : array<array<${n}, ${d}>, ${f}>; | |
| var<workgroup> mm_Bsub : array<array<${n}, ${u}>, ${a}>; | |
| const rowPerThread = ${e[1]}; | |
| const colPerThread = ${e[0]}; | |
| const tileInner = ${a}; | |
| @compute @workgroup_size(${t[0]}, ${t[1]}, ${t[2]}) | |
| fn main(@builtin(local_invocation_id) localId : vec3<u32>, | |
| @builtin(global_invocation_id) globalId : vec3<u32>, | |
| @builtin(workgroup_id) workgroupId : vec3<u32>) { | |
| let batch = ${o?`0`:`i32(globalId.z)`}; | |
| ${r?`let batchIndices = ${r.offsetToIndices(`u32(batch)`)};`:``} | |
| let num_tiles = ${o?`${Math.ceil(s/a)}`:`(uniforms.dim_inner - 1) / tileInner + 1`}; | |
| var kStart = ${o?`i32(globalId.z) * ${s}`:`0`}; | |
| var acc : array<array<${n}, colPerThread>, rowPerThread>; | |
| ${g} | |
| } | |
| `},Ra=(e,t,n,r,i=!1)=>{let[a,o,s,c]=r,l=U(r[0].type.tensor);return` | |
| fn mm_readA(batch: i32, row: i32, colIn: i32, batchIndices: ${a.type.indices}) -> ${wa(e,l)} { | |
| var value = ${wa(e,l)}(0.0); | |
| let col = colIn * ${e}; | |
| if(row < uniforms.dim_a_outer && col < uniforms.dim_inner) | |
| { | |
| var aIndices: ${o.type.indices}; | |
| ${ka(`aIndices`,o,o.rank-2,a.rank,`batchIndices`)} | |
| ${o.indicesSet(`aIndices`,o.rank-2,`u32(row)`)} | |
| ${o.indicesSet(`aIndices`,o.rank-1,`u32(colIn)`)} | |
| value = ${o.getByIndices(`aIndices`)}; | |
| } | |
| return value; | |
| } | |
| fn mm_readB(batch: i32, row: i32, colIn: i32, batchIndices: ${a.type.indices}) -> ${wa(e,l)} { | |
| var value = ${wa(e,l)}(0.0); | |
| let col = colIn * ${e}; | |
| if(row < uniforms.dim_inner && col < uniforms.dim_b_outer) | |
| { | |
| var bIndices: ${s.type.indices}; | |
| ${ka(`bIndices`,s,s.rank-2,a.rank,`batchIndices`)} | |
| ${s.indicesSet(`bIndices`,s.rank-2,`u32(row)`)} | |
| ${s.indicesSet(`bIndices`,s.rank-1,`u32(colIn)`)} | |
| value = ${s.getByIndices(`bIndices`)}; | |
| } | |
| return value; | |
| } | |
| fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${wa(e,l)}) { | |
| let col = colIn * ${e}; | |
| if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) { | |
| var value = valueIn; | |
| let coords = vec3<i32>(batch, row, colIn); | |
| ${t?`value = value + ${i?`bias[colIn]`:`${wa(e,l)}(bias[row])`};`:``} | |
| ${n} | |
| ${c.setByIndices(`vec3<u32>(coords)`,`value`)} | |
| } | |
| } | |
| `},za=(e,t,n,r,i=!1,a)=>{let o=e[0].dims,s=e[1].dims,c=o.slice(0,-2),l=s.slice(0,-2),u=r?r.slice(0,-2):n.slice(0,-2),d=z.size(u),f=o[o.length-2],p=o[o.length-1],m=s[s.length-1],h=p%4==0&&m%4==0,g=f<=8?[4,1,1]:[4,4,1],_=[8,8,1],v=[Math.ceil(m/_[0]/g[0]),Math.ceil(f/_[1]/g[1]),Math.ceil(d/_[2]/g[2])],y=h?4:1,b=[...c,f,p/y],x=b.length,S=[...l,p,m/y],C=S.length,w=[d,f,m/y],ee=[{type:6,data:f},{type:6,data:m},{type:6,data:p}];ba(t,ee),ee.push(...W(u,b,S));let T=[`rank`,`rank`],E=e.length>2;return E&&(ee.push(...W(e[2].dims)),T.push(`rank`)),ee.push(...W(w)),{name:`MatMul`,shaderCache:{hint:`${g};${t.activation};${h};${i}`,inputDependencies:T},getRunData:()=>({outputs:[{dims:a?a(n):n,dataType:e[0].dataType}],dispatchGroup:{x:v[0],y:v[1],z:v[2]},programUniforms:ee}),getShaderSource:n=>{let r=u.length,a=kn(`batchDims`,e[0].dataType,r,1),o=U(e[0].dataType),s=q(`a`,e[0].dataType,x,y),c=q(`b`,e[1].dataType,C,y),l=J(`result`,e[0].dataType,w.length,y),d=[s,c];if(E){let t=i?y:1;d.push(q(`bias`,e[2].dataType,e[2].dims.length,t))}let f=[{name:`dim_a_outer`,type:`i32`},{name:`dim_b_outer`,type:`i32`},{name:`dim_inner`,type:`i32`}];xa(t,f);let p=U(l.type.tensor),m=ya(t,l.type.value,p),v=Ra(y,E,m,[a,s,c,l],i);return` | |
| ${n.registerUniforms(f).registerInternalVariables(a).declareVariables(...d,l)} | |
| ${v} | |
| ${h?Pa(g,_,o,a):La(g,_,o,a)} | |
| `}}}}),Va,Ha,Ua=o(()=>{L(),Lt(),Y(),Ca(),Ea(),Oa(),Ba(),Va=(e,t,n,r,i=!1,a,o=4,s=4,c=4,l=`f32`)=>{let u=e=>{switch(e){case 1:return`resData = x[xIndex];`;case 3:return`resData = vec3<${l}>(x[xIndex], x[xIndex + 1], x[xIndex + 2]);`;case 4:return`resData = x[xIndex / 4];`;default:throw Error(`innerElementSize ${e} is not supported.`)}},d=e=>{switch(e){case 1:return`return w[row * i32(uniforms.w_shape[3]) + colIn];`;case 4:return`return w[row * i32(uniforms.w_shape[3]) / 4 + colIn];`;default:throw Error(`innerElementSize ${e} is not supported.`)}},f=e?` | |
| let coord = vec4<i32>(batch, xRow, xCol, xCh); | |
| `:` | |
| let coord = vec4<i32>(batch, xCh, xRow, xCol); | |
| `,p=e?` | |
| let coords = vec4<i32>( | |
| batch, | |
| row / outWidth, | |
| row % outWidth, | |
| col); | |
| `:` | |
| let coords = vec4<i32>( | |
| batch, | |
| row, | |
| col / outWidth, | |
| col % outWidth); | |
| `,m=e?`i32(uniforms.x_shape[1])`:`i32(uniforms.x_shape[2])`,h=e?`i32(uniforms.x_shape[2])`:`i32(uniforms.x_shape[3])`,g=e?`row`:`col`,_=e?`col`:`row`,v=` | |
| let inChannels = i32(uniforms.w_shape[2]); | |
| let outWidth = ${e?`i32(uniforms.result_shape[2])`:`i32(uniforms.result_shape[3])`}; | |
| let outRow = ${g} / outWidth; | |
| let outCol = ${g} % outWidth; | |
| let WRow = ${_} / (i32(uniforms.w_shape[1]) * inChannels); | |
| let WCol = ${_} / inChannels % i32(uniforms.w_shape[1]); | |
| let xRow = outRow * uniforms.stride[0] + uniforms.dilation[0] * WRow - uniforms.pad[0]; | |
| let xCol = outCol * uniforms.stride[1] + uniforms.dilation[1] * WCol - uniforms.pad[1]; | |
| let xCh = ${_} % inChannels; | |
| var resData = ${wa(o,l)}(0.0); | |
| // The bounds checking is always needed since we use it to pad zero for | |
| // the 'same' padding type. | |
| if (xRow >= 0 && xRow < ${m} && xCol >= 0 && xCol < ${h}) { | |
| ${f} | |
| let xIndex = getIndexFromCoords4D(coord, vec4<i32>(uniforms.x_shape)); | |
| ${u(o)} | |
| } | |
| return resData;`,y=e?t&&r?` | |
| let col = colIn * ${o}; | |
| ${v}`:` | |
| let col = colIn * ${o}; | |
| if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) { | |
| ${v} | |
| } | |
| return ${wa(o,l)}(0.0);`:r&&n?` | |
| let col = colIn * ${o}; | |
| ${v}`:` | |
| let col = colIn * ${o}; | |
| if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) { | |
| ${v} | |
| } | |
| return ${wa(o,l)}(0.0);`,b=e?r&&n?d(s):` | |
| let col = colIn * ${s}; | |
| if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) { | |
| ${d(s)} | |
| } | |
| return ${wa(s,l)}(0.0);`:` | |
| let col = colIn * ${s}; | |
| if (row < uniforms.dim_inner && col < uniforms.dim_a_outer) { | |
| ${d(s)} | |
| } | |
| return ${wa(s,l)}(0.0);`,x=wa(c,l),S=wa(e?o:s,l),C=wa(e?s:o,l),w=ya(a,x,l);return` | |
| fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${S} { | |
| ${e?y:b} | |
| } | |
| fn mm_readB(batch: i32, row : i32, colIn : i32) -> ${C} { | |
| ${e?b:y} | |
| } | |
| fn mm_write(batch: i32, row : i32, colIn : i32, valueIn : ${x}) { | |
| let col = colIn * ${c}; | |
| if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) | |
| { | |
| var value = valueIn; | |
| let outWidth = ${e?`i32(uniforms.result_shape[2])`:`i32(uniforms.result_shape[3])`}; | |
| ${p} | |
| ${Ta(i)} | |
| ${w} | |
| setOutputAtCoords(coords[0], coords[1], coords[2], coords[3], value); | |
| } | |
| }`},Ha=(e,t,n,r,i,a,o,s,c)=>{let l=t.format===`NHWC`,u=l?e[0].dims[3]:e[0].dims[1],d=n[0],f=l?n[2]:n[3],p=l?n[1]:n[2],m=l?n[3]:n[1],h=l&&(u%4==0||u%3==0)&&m%4==0,g=l?m:f*p,_=l?f*p:m,v=[8,8,1],y=r<=8?[4,1,1]:[4,4,1],b=[Math.ceil(g/v[0]/y[0]),Math.ceil(_/v[1]/y[1]),Math.ceil(d/v[2]/y[2])];R(`verbose`,()=>`[conv2d_mm_webgpu] dispatch = ${b}`);let x=h?l&&u%4!=0?3:4:1,S=v[1]*y[1],C=v[0]*y[0],w=Math.max(v[0]*x,v[1]),ee=r%S===0,T=i%C===0,E=a%w===0,D=h?[x,4,4]:[1,1,1],O=[{type:6,data:r},{type:6,data:i},{type:6,data:a},{type:6,data:[t.pads[0],t.pads[1]]},{type:6,data:t.strides},{type:6,data:t.dilations}];ba(t,O),O.push(...W(e[0].dims,e[1].dims));let k=[`rank`,`rank`];return o&&(O.push(...W(e[2].dims)),k.push(`rank`)),O.push(...W(n)),{name:`Conv2DMatMul`,shaderCache:{hint:`${t.cacheKey};${x};${h};${ee};${T};${E};${S};${C};${w}`,inputDependencies:k},getRunData:()=>({outputs:[{dims:c?c(n):n,dataType:e[0].dataType}],dispatchGroup:{x:b[0],y:b[1],z:b[2]},programUniforms:O}),getShaderSource:r=>{let i=[{name:`dim_a_outer`,type:`i32`},{name:`dim_b_outer`,type:`i32`},{name:`dim_inner`,type:`i32`},{name:`pad`,type:`i32`,length:2},{name:`stride`,type:`i32`,length:2},{name:`dilation`,type:`i32`,length:2}];xa(t,i);let a=h?4:1,c=U(e[0].dataType),u=` | |
| fn setOutputAtIndex(flatIndex : i32, value : ${h?`vec4<${c}>`:c}) { | |
| result[flatIndex] = ${h?`vec4<${c}>`:c}(value); | |
| } | |
| fn setOutputAtCoords(d0 : i32, d1 : i32, d2 : i32, d3 : i32, value : ${h?`vec4<${c}>`:c}) { | |
| let flatIndex = getOutputIndexFromCoords(vec4<i32>(d0, d1, d2, d3)); | |
| setOutputAtIndex(flatIndex ${h?`/ 4`:``}, value); | |
| }`,d=[q(`x`,e[0].dataType,e[0].dims.length,x===3?1:x),q(`w`,e[1].dataType,e[1].dims.length,a)],f=J(`result`,e[0].dataType,n.length,a);if(o){let t=q(`bias`,e[2].dataType,e[2].dims.length,a);d.push(t),u+=` | |
| fn getBiasByOutputCoords(coords : vec4<i32>) -> ${h?`vec4<${c}>`:c} { | |
| return bias[coords.${l?`w`:`y`}${h?`/ 4`:``}]; | |
| }`}return` | |
| ${Da(`uniforms.result_strides`)} | |
| //struct Uniforms { xShape : vec4<i32>, wShape : vec4<i32>, outShape : vec4<i32>, | |
| // outShapeStrides: vec3<i32>, filterDims : vec2<i32>, pad : vec2<i32>, stride : vec2<i32>, | |
| // dilation : vec2<i32>, dimAOuter : i32, dimBOuter : i32, dimInner : i32 }; | |
| ${r.registerUniforms(i).declareVariables(...d,f)} | |
| ${u} | |
| ${Va(l,ee,T,E,o,t,D[0],D[1],D[2],c)} | |
| ${h?Pa(y,v,c,void 0,!l,w):La(y,v,c,void 0,!l,w,!1,void 0,s)}`}}}}),Wa,Ga,Ka,qa,Ja,Ya,Xa,Za,Qa=o(()=>{L(),Lt(),B(),Y(),Ca(),Ea(),Wa=e=>{let t=1;for(let n=0;n<e.length;n++)t*=e[n];return t},Ga=e=>typeof e==`number`?[e,e,e]:e,Ka=(e,t)=>t<=1?e:e+(e-1)*(t-1),qa=(e,t,n,r=1)=>{let i=Ka(t,r);return Math.floor((e[0]*(n-1)-n+i)/2)},Ja=(e,t,n,r,i)=>{i??=qa(e,t[0],r[0]);let a=[0,0,0,n];for(let n=0;n<3;n++)e[n]+2*i>=t[n]&&(a[n]=Math.trunc((e[n]-t[n]+2*i)/r[n]+1));return a},Ya=(e,t,n,r,i,a,o,s,c,l)=>{let u,d,f,p;if(e===`VALID`&&(e=0),typeof e==`number`){u={top:e,bottom:e,left:e,right:e,front:e,back:e};let m=Ja([t,n,r,1],[s,c,l],1,[i,a,o],e);d=m[0],f=m[1],p=m[2]}else if(Array.isArray(e)){if(!e.every((e,t,n)=>e===n[0]))throw Error(`Unsupported padding parameter: ${e}`);u={top:e[0],bottom:e[1],left:e[2],right:e[3],front:e[4],back:e[5]};let m=Ja([t,n,r,1],[s,c,l],1,[i,a,o],e[0]);d=m[0],f=m[1],p=m[2]}else if(e===`SAME_UPPER`){d=Math.ceil(t/i),f=Math.ceil(n/a),p=Math.ceil(r/o);let e=(d-1)*i+s-t,m=(f-1)*a+c-n,h=(p-1)*o+l-r,g=Math.floor(e/2),_=e-g,v=Math.floor(m/2),y=m-v,b=Math.floor(h/2);u={top:v,bottom:y,left:b,right:h-b,front:g,back:_}}else throw Error(`Unknown padding parameter: ${e}`);return{padInfo:u,outDepth:d,outHeight:f,outWidth:p}},Xa=(e,t,n,r,i,a=!1,o=`channelsLast`)=>{let s,c,l,u,d;if(o===`channelsLast`)[s,c,l,u,d]=e;else if(o===`channelsFirst`)[s,d,c,l,u]=e;else throw Error(`Unknown dataFormat ${o}`);let[f,,p,m,h]=t,[g,_,v]=Ga(n),[y,b,x]=Ga(r),S=Ka(p,y),C=Ka(m,b),w=Ka(h,x),{padInfo:ee,outDepth:T,outHeight:E,outWidth:D}=Ya(i,c,l,u,g,_,v,S,C,w),O=a?f*d:f,k=[0,0,0,0,0];return o===`channelsFirst`?k=[s,O,T,E,D]:o===`channelsLast`&&(k=[s,T,E,D,O]),{batchSize:s,dataFormat:o,inDepth:c,inHeight:l,inWidth:u,inChannels:d,outDepth:T,outHeight:E,outWidth:D,outChannels:O,padInfo:ee,strideDepth:g,strideHeight:_,strideWidth:v,filterDepth:p,filterHeight:m,filterWidth:h,effectiveFilterDepth:S,effectiveFilterHeight:C,effectiveFilterWidth:w,dilationDepth:y,dilationHeight:b,dilationWidth:x,inShape:e,outShape:k,filterShape:t}},Za=(e,t,n,r,i,a)=>{let o=a===`channelsLast`;o?e[0].dims[3]:e[0].dims[1];let s=[64,1,1],c={x:n.map((e,t)=>t)},l=[Math.ceil(Wa(c.x.map(e=>n[e]))/s[0]),1,1];R(`verbose`,()=>`[conv3d_naive_webgpu] dispatch = ${l}`);let u=[{type:12,data:z.size(n)},{type:12,data:r},{type:12,data:i},{type:12,data:t.strides},{type:12,data:t.dilations}];ba(t,u),u.push(...W(e[0].dims,e[1].dims));let d=[`rank`,`rank`],f=e.length===3;return f&&(u.push(...W(e[2].dims)),d.push(`rank`)),u.push(...W(n)),{name:`Conv3DNaive`,shaderCache:{hint:`${t.cacheKey};${o};1;${f}`,inputDependencies:d},getRunData:()=>({outputs:[{dims:n,dataType:e[0].dataType}],dispatchGroup:{x:l[0],y:l[1],z:l[2]},programUniforms:u}),getShaderSource:a=>{let s=[{name:`output_size`,type:`u32`},{name:`filter_dims`,type:`u32`,length:r.length},{name:`pads`,type:`u32`,length:i.length},{name:`strides`,type:`u32`,length:t.strides.length},{name:`dilations`,type:`u32`,length:t.dilations.length}];xa(t,s);let c=U(e[0].dataType),l=q(`x`,e[0].dataType,e[0].dims.length,1),u=q(`W`,e[1].dataType,e[1].dims.length,1),d=[l,u],p=J(`result`,e[0].dataType,n.length,1),m=``;if(f){let t=q(`bias`,e[2].dataType,e[2].dims.length,1);d.push(t),m+=` | |
| fn getBiasByOutputCoords(coords : array<u32, 5>) -> ${c} { | |
| return bias[${o?K(`coords`,4,5):K(`coords`,1,5)}]; | |
| }`}let h=wa(1,c),g=ya(t,h,c);return` | |
| ${m} | |
| fn getX(d0 : u32, d1 : u32, d2 : u32, d3 : u32, d4 : u32) -> f32 { | |
| let aIndices = array<u32, 5>(d0, d1, d2, d3, d4); | |
| return ${l.getByIndices(`aIndices`)}; | |
| } | |
| fn getW(d0 : u32, d1 : u32, d2 : u32, d3 : u32, d4 : u32) -> f32 { | |
| let aIndices = array<u32, 5>(d0, d1, d2, d3, d4); | |
| return ${u.getByIndices(`aIndices`)}; | |
| } | |
| ${a.registerUniforms(s).declareVariables(...d,p)} | |
| ${a.mainStart()} | |
| ${a.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.output_size`)} | |
| let coords = ${p.offsetToIndices(`global_idx`)}; | |
| let batch = ${K(`coords`,0,l.rank)}; | |
| let d2 = ${o?K(`coords`,l.rank-1,l.rank):K(`coords`,1,l.rank)}; | |
| let xFRCCorner = vec3<u32>(${o?K(`coords`,1,l.rank):K(`coords`,2,l.rank)}, | |
| ${o?K(`coords`,2,l.rank):K(`coords`,3,l.rank)}, | |
| ${o?K(`coords`,3,l.rank):K(`coords`,4,l.rank)}) * uniforms.strides - uniforms.pads; | |
| let xFCorner = xFRCCorner.x; | |
| let xRCorner = xFRCCorner.y; | |
| let xCCorner = xFRCCorner.z; | |
| let xShapeY = ${o?K(`uniforms.x_shape`,1,l.rank):K(`uniforms.x_shape`,2,l.rank)}; | |
| let xShapeZ = ${o?K(`uniforms.x_shape`,2,l.rank):K(`uniforms.x_shape`,3,l.rank)}; | |
| let xShapeW = ${o?K(`uniforms.x_shape`,3,l.rank):K(`uniforms.x_shape`,4,l.rank)}; | |
| let xShapeU = ${o?K(`uniforms.x_shape`,4,l.rank):K(`uniforms.x_shape`,1,l.rank)}; | |
| let inputDepthNearestVec4 = (xShapeU / 4) * 4; | |
| let inputDepthVec4Remainder = xShapeU % 4; | |
| var value = 0.0; | |
| for (var wF = 0u; wF < uniforms.filter_dims[0]; wF++) { | |
| let xF = xFCorner + wF * uniforms.dilations[0]; | |
| if (xF < 0 || xF >= xShapeY) { | |
| continue; | |
| } | |
| for (var wR = 0u; wR < uniforms.filter_dims[1]; wR++) { | |
| let xR = xRCorner + wR * uniforms.dilations[1]; | |
| if (xR < 0 || xR >= xShapeZ) { | |
| continue; | |
| } | |
| for (var wC = 0u; wC < uniforms.filter_dims[2]; wC++) { | |
| let xC = xCCorner + wC * uniforms.dilations[2]; | |
| if (xC < 0 || xC >= xShapeW) { | |
| continue; | |
| } | |
| for (var d1 = 0u; d1 < inputDepthNearestVec4; d1 += 4) { | |
| ${o?`let xValues = vec4<f32>( | |
| getX(batch, xF, xR, xC, d1), | |
| getX(batch, xF, xR, xC, d1 + 1), | |
| getX(batch, xF, xR, xC, d1 + 2), | |
| getX(batch, xF, xR, xC, d1 + 3)); | |
| `:`let xValues = vec4<f32>( | |
| getX(batch, d1, xF, xR, xC), | |
| getX(batch, d1 + 1, xF, xR, xC), | |
| getX(batch, d1 + 2, xF, xR, xC), | |
| getX(batch, d1 + 3, xF, xR, xC)); | |
| `} | |
| let wValues = vec4<f32>( | |
| getW(d2, d1, wF, wR, wC), | |
| getW(d2, d1 + 1, wF, wR, wC), | |
| getW(d2, d1 + 2, wF, wR, wC), | |
| getW(d2, d1 + 3, wF, wR, wC)); | |
| value += dot(xValues, wValues); | |
| } | |
| if (inputDepthVec4Remainder == 1) { | |
| ${o?`value += getX(batch, xF, xR, xC, inputDepthNearestVec4) | |
| * getW(d2, inputDepthNearestVec4, wF, wR, wC);`:`value += getX(batch, inputDepthNearestVec4, xF, xR, xC) | |
| * getW(d2, inputDepthNearestVec4, wF, wR, wC);`} | |
| } else if (inputDepthVec4Remainder == 2) { | |
| ${o?`let xValues = vec2<f32>( | |
| getX(batch, xF, xR, xC, inputDepthNearestVec4), | |
| getX(batch, xF, xR, xC, inputDepthNearestVec4 + 1)); | |
| `:`let xValues = vec2<f32>( | |
| getX(batch, inputDepthNearestVec4, xF, xR, xC), | |
| getX(batch, inputDepthNearestVec4 + 1, xF, xR, xC)); | |
| `} | |
| let wValues = vec2<f32>( | |
| getW(d2, inputDepthNearestVec4, wF, wR, wC), | |
| getW(d2, inputDepthNearestVec4 + 1, wF, wR, wC)); | |
| value += dot(xValues, wValues); | |
| } else if (inputDepthVec4Remainder == 3) { | |
| ${o?`let xValues = vec3<f32>( | |
| getX(batch, xF, xR, xC, inputDepthNearestVec4), | |
| getX(batch, xF, xR, xC, inputDepthNearestVec4 + 1), | |
| getX(batch, xF, xR, xC, inputDepthNearestVec4 + 2)); | |
| `:`let xValues = vec3<f32>( | |
| getX(batch, inputDepthNearestVec4, xF, xR, xC), | |
| getX(batch, inputDepthNearestVec4 + 1, xF, xR, xC), | |
| getX(batch, inputDepthNearestVec4 + 2, xF, xR, xC)); | |
| `} | |
| let wValues = vec3<f32>( | |
| getW(d2, inputDepthNearestVec4, wF, wR, wC), | |
| getW(d2, inputDepthNearestVec4 + 1, wF, wR, wC), | |
| getW(d2, inputDepthNearestVec4 + 2, wF, wR, wC)); | |
| value += dot(xValues, wValues); | |
| } | |
| } | |
| } | |
| } | |
| ${f?`value = value + getBiasByOutputCoords(coords)`:``}; | |
| ${g} | |
| result[global_idx] = f32(value); | |
| }`}}}}),$a,eo,to=o(()=>{L(),B(),Y(),Ca(),$a=(e,t,n,r)=>{let i=e.length>2,a=i?`value += b[output_channel];`:``,o=e[0].dims,s=e[1].dims,c=t.format===`NHWC`,l=c?n[3]:n[1],u=l/t.group,d=c&&u>=4?G(l):1,f=z.size(n)/d,p=[{type:12,data:f},{type:12,data:t.dilations},{type:12,data:[t.strides[0],t.strides[1]]},{type:12,data:[t.pads[0],t.pads[1]]},{type:12,data:u}];ba(t,p),p.push(...W(o,[s[0],s[1],s[2],s[3]/d]));let m=i?[`rank`,`rank`,`rank`]:[`rank`,`rank`];return p.push(...W([n[0],n[1],n[2],n[3]/d])),{name:`GroupedConv`,shaderCache:{hint:`${t.cacheKey}_${d}`,inputDependencies:m},getRunData:()=>({outputs:[{dims:r?r(n):n,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(f/64)},programUniforms:p}),getShaderSource:r=>{let l=J(`output`,e[0].dataType,n.length,d),u=U(l.type.tensor),f=ya(t,l.type.value,u),p=q(`x`,e[0].dataType,o.length),m=q(`w`,e[1].dataType,s.length,d),h=[p,m];i&&h.push(q(`b`,e[2].dataType,e[2].dims,d));let g=[{name:`output_size`,type:`u32`},{name:`dilations`,type:`u32`,length:t.dilations.length},{name:`strides`,type:`u32`,length:2},{name:`pads`,type:`u32`,length:2},{name:`output_channels_per_group`,type:`u32`}];xa(t,g);let _=c?` | |
| for (var wHeight: u32 = 0u; wHeight < uniforms.w_shape[0]; wHeight++) { | |
| let xHeight = xRCCorner.x + wHeight * uniforms.dilations[0]; | |
| if (xHeight < 0u || xHeight >= uniforms.x_shape[1]) { | |
| continue; | |
| } | |
| for (var wWidth: u32 = 0u; wWidth < uniforms.w_shape[1]; wWidth++) { | |
| let xWidth = xRCCorner.y + wWidth * uniforms.dilations[1]; | |
| if (xWidth < 0u || xWidth >= uniforms.x_shape[2]) { | |
| continue; | |
| } | |
| for (var wInChannel: u32 = 0u; wInChannel < uniforms.w_shape[2]; wInChannel++) { | |
| let input_channel = in_channel_offset + wInChannel; | |
| let xVal = ${p.get(`batch`,`xHeight`,`xWidth`,`input_channel`)}; | |
| let wVal = ${m.get(`wHeight`,`wWidth`,`wInChannel`,`output_channel`)}; | |
| value += xVal * wVal; | |
| } | |
| } | |
| } | |
| `:` | |
| for (var wInChannel: u32 = 0u; wInChannel < uniforms.w_shape[1]; wInChannel++) { | |
| let input_channel = in_channel_offset + wInChannel; | |
| for (var wHeight: u32 = 0u; wHeight < uniforms.w_shape[2]; wHeight++) { | |
| let xHeight = xRCCorner.x + wHeight * uniforms.dilations[0]; | |
| if (xHeight < 0u || xHeight >= uniforms.x_shape[2]) { | |
| continue; | |
| } | |
| for (var wWidth: u32 = 0u; wWidth < uniforms.w_shape[3]; wWidth++) { | |
| let xWidth = xRCCorner.y + wWidth * uniforms.dilations[1]; | |
| if (xWidth < 0u || xWidth >= uniforms.x_shape[3]) { | |
| continue; | |
| } | |
| let xVal = ${p.get(`batch`,`input_channel`,`xHeight`,`xWidth`)}; | |
| let wVal = ${m.get(`output_channel`,`wInChannel`,`wHeight`,`wWidth`)}; | |
| value += xVal * wVal; | |
| } | |
| } | |
| } | |
| `;return` | |
| ${r.registerUniforms(g).declareVariables(...h,l)} | |
| ${r.mainStart()} | |
| ${r.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.output_size`)} | |
| let outputIndices = ${l.offsetToIndices(`global_idx`)}; | |
| let batch: u32 = outputIndices[0]; | |
| let output_channel: u32 = outputIndices[${c?3:1}]; | |
| let xRCCorner: vec2<u32> = vec2<u32>(outputIndices[${c?1:2}], outputIndices[${c?2:3}]) * uniforms.strides - uniforms.pads; | |
| let group_id: u32 = output_channel * ${d} / uniforms.output_channels_per_group; | |
| var in_channel_offset = group_id * uniforms.w_shape[${c?2:1}]; | |
| var value: ${l.type.value} = ${l.type.value}(0); | |
| ${_} | |
| ${a} | |
| ${f} | |
| ${l.setByOffset(`global_idx`,`value`)} | |
| }`}}},eo=(e,t,n,r)=>{let i=e.length>2,a=G(n[3]),o=G(n[2]),s=z.size(n)/a/o,c=[e[0].dims[0],e[0].dims[1],e[0].dims[2],e[0].dims[3]/a],l=[e[1].dims[0],e[1].dims[1],e[1].dims[2],e[1].dims[3]/a],u=[n[0],n[1],n[2],n[3]/a],d=[{type:12,data:s},{type:6,data:[t.strides[0],t.strides[1]]},{type:6,data:[t.pads[0],t.pads[1]]}];ba(t,d),d.push(...W(c,l,u));let f=(o-1)*t.strides[1]+l[1];return{name:`GroupedConv-Vectorize`,shaderCache:{hint:`${t.cacheKey};${a};${o};${f};${l[0]};${l[1]}`,inputDependencies:i?[`rank`,`rank`,`type`]:[`rank`,`rank`]},getRunData:()=>({outputs:[{dims:r?r(n):n,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(s/64)},programUniforms:d}),getShaderSource:n=>{let r=J(`output`,e[0].dataType,u.length,a),s=U(r.type.tensor),d=ya(t,r.type.value,s),p=q(`x`,e[0].dataType,c.length,a),m=q(`w`,e[1].dataType,l.length,a),h=[p,m];i&&h.push(q(`b`,e[2].dataType,e[2].dims,a));let g=i?`value += b[output_channel];`:``,_=[{name:`output_size`,type:`u32`},{name:`strides`,type:`i32`,length:2},{name:`pads`,type:`i32`,length:2}];return xa(t,_),` | |
| ${n.registerUniforms(_).declareVariables(...h,r)} | |
| ${n.mainStart()} | |
| ${n.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.output_size`)} | |
| let width0 = uniforms.output_shape[3]; | |
| let output_channel = global_idx % width0; | |
| var index1 = global_idx / width0; | |
| let width1 = uniforms.output_shape[2] / ${o}u; | |
| let col = (index1 % width1) * ${o}u; | |
| index1 = index1 / width1; | |
| let row = index1 % uniforms.output_shape[1]; | |
| let batch = index1 / uniforms.output_shape[1]; | |
| let x_corner = vec2<i32>(i32(row), i32(col)) * uniforms.strides - uniforms.pads; | |
| var x_vals: array<${p.type.value}, ${f}>; | |
| var values: array<${r.type.value}, ${o}>; | |
| let input_channel = output_channel; | |
| // Use constant instead of uniform can give better performance for w's height/width. | |
| for (var w_height: u32 = 0u; w_height < ${l[0]}; w_height++) { | |
| let x_height = x_corner.x + i32(w_height); | |
| if (x_height >= 0 && u32(x_height) < uniforms.x_shape[1]) { | |
| for (var i = 0; i < ${f}; i++) { | |
| let x_width = x_corner.y + i; | |
| if (x_width >= 0 && u32(x_width) < uniforms.x_shape[2]) { | |
| x_vals[i] = ${p.get(`batch`,`u32(x_height)`,`u32(x_width)`,`input_channel`)}; | |
| } else { | |
| x_vals[i] = ${p.type.value}(0); | |
| } | |
| } | |
| for (var w_width: u32 = 0u; w_width < ${l[1]}; w_width++) { | |
| let w_val = ${m.get(`w_height`,`w_width`,`0`,`output_channel`)}; | |
| for (var i = 0u; i < ${o}u; i++) { | |
| values[i] = fma(x_vals[i * u32(uniforms.strides[1]) + w_width], w_val, values[i]); | |
| } | |
| } | |
| } | |
| } | |
| for (var i = 0u; i < ${o}u; i++) { | |
| var value = values[i]; | |
| ${g} | |
| ${d} | |
| ${r.set(`batch`,`row`,`col + i`,`output_channel`,`value`)}; | |
| } | |
| }`}}}}),no,ro,io,ao,oo,so,co,lo,uo,fo=o(()=>{B(),Ua(),Qa(),Ba(),to(),Ca(),ja(),Vn(),no=(e,t,n,r,i,a)=>{let o=e[0],s=e.slice(a?1:2,a?3:4),c=s.length,l=t[0],u=t.slice(2).map((e,t)=>e+(e-1)*(n[t]-1)),d=s.map((e,t)=>e+r[t]+r[t+c]).map((e,t)=>Math.floor((e-u[t]+i[t])/i[t]));return d.splice(0,0,o),d.splice(a?3:1,0,l),d},ro=[2,3,1,0],io=(e,t)=>{if(!e||e.length!==2&&e.length!==3)throw Error(`Conv requires 2 or 3 inputs`);if(e[0].dims.length>5)throw Error(`greater than 5D is not supported`);if(e[0].dims.length!==e[1].dims.length)throw Error(`filter does not have same dimension as input`);if(e[0].dims[t.format===`NHWC`?e[0].dims.length-1:1]!==e[1].dims[1]*t.group)throw Error(`FILTER_IN_CHANNEL should be equal to DATA_CHANNEL`);if(e.length===3&&(e[2].dims.length!==1||e[1].dims[0]!==e[2].dims[0]))throw Error(`invalid bias`);let n=e[0].dims.length-2;if(t.dilations.length!==n)throw Error(`dilations should be ${n}D`);if(t.strides.length!==n)throw Error(`strides should be ${n}D`);if(t.pads.length!==n*2)throw Error(`pads should be ${n*2}D`);if(t.kernelShape.length!==0&&t.kernelShape.length!==e[1].dims.length-2)throw Error(`invalid kernel shape`)},ao=(e,t)=>{let n=e.kernelShape.slice();n.length<t[1].dims.length-2&&n.push(...Array(t[1].dims.length-2-n.length).fill(0));for(let e=2;e<t[1].dims.length;++e)n[e-2]===0&&(n[e-2]=t[1].dims[e]);let r=e.pads.slice();Bt.adjustPadsBasedOnAutoPad(t[0].dims,e.strides,e.dilations,n,r,e.format===`NHWC`,e.autoPad);let i=Object.assign({},e);return Object.assign(i,{kernelShape:n,pads:r}),i},oo=e=>{let t=Sa(e),n=e.format;return{autoPad:[`NOTSET`,`VALID`,`SAME_UPPER`,`SAME_LOWER`][e.auto_pad],format:n,dilations:e.dilations,group:e.group,kernelShape:e.kernel_shape,pads:e.pads,strides:e.strides,wIsConst:e.w_is_const(),...t,cacheKey:`${e.format};${t.activation};`}},so=(e,t,n,r)=>{let i=n.format===`NHWC`,a=no(t[0].dims,t[1].dims,n.dilations,n.pads,n.strides,i);if(n.group!==1){let o=[t[0]];if(i){let r=e.kernelCustomData.wT??e.compute(Rn(t[1],ro),{inputs:[1],outputs:[n.wIsConst?-2:-1]})[0];n.wIsConst&&!e.kernelCustomData.wT&&(e.kernelCustomData.wT=r),o.push(r)}else o.push(t[1]);t.length===3&&o.push(t[2]),!e.adapterInfo.isArchitecture(`ampere`)&&i&&t[1].dims[0]===n.group&&t[1].dims[1]===1&&n.dilations[0]===1&&n.dilations[1]===1?e.compute(eo(o,n,a,r),{inputs:o}):e.compute($a(o,n,a,r),{inputs:o});return}let o=t.length===3,s=t[0].dims[i?1:2],c=t[0].dims[i?2:3],l=t[0].dims[i?3:1],u=t[1].dims[2],d=t[1].dims[3],f=a[i?1:2],p=a[i?2:3],m=a[i?3:1],h=i&&u===s&&d===c&&n.pads[0]===0&&n.pads[1]===0;if(h||u===1&&d===1&&n.dilations[0]===1&&n.dilations[1]===1&&n.strides[0]===1&&n.strides[1]===1&&n.pads[0]===0&&n.pads[1]===0){let u=a[0],d,g,_,v=[];if(i){let r=e.kernelCustomData.wT??e.compute(Rn(t[1],ro),{inputs:[1],outputs:[n.wIsConst?-2:-1]})[0];if(n.wIsConst&&!e.kernelCustomData.wT&&(e.kernelCustomData.wT=r),h){let e=s*c*l;d=t[0].reshape([1,u,e]),g=r.reshape([1,e,m]),_=[1,u,m]}else d=t[0].reshape([u,s*c,l]),g=r.reshape([1,l,m]),_=[u,f*p,m];v.push(d),v.push(g)}else d=t[0].reshape([u,l,s*c]),g=t[1].reshape([1,m,l]),_=[u,m,f*p],v.push(g),v.push(d);o&&v.push(t[2]);let y=_[2],b=v[0].dims[v[0].dims.length-1];y<8&&b<8?e.compute(Aa(v,n,a,_,i,r),{inputs:v}):e.compute(za(v,n,a,_,i,r),{inputs:v});return}let g=e.kernelCustomData.wT??e.compute(Rn(t[1],ro),{inputs:[1],outputs:[n.wIsConst?-2:-1]})[0];n.wIsConst&&!e.kernelCustomData.wT&&(e.kernelCustomData.wT=g);let _=[t[0],g];o&&_.push(t[2]);let v=i?f*p:m,y=i?m:f*p,b=u*d*l;e.compute(Ha(_,n,a,v,y,b,o,!0,r),{inputs:_})},co=(e,t)=>{let n=t.format===`NHWC`,r=[e.inputs[0].reshape(n?[e.inputs[0].dims[0],1,e.inputs[0].dims[1],e.inputs[0].dims[2]]:[e.inputs[0].dims[0],e.inputs[0].dims[1],1,e.inputs[0].dims[2]]),e.inputs[1].reshape([e.inputs[1].dims[0],e.inputs[1].dims[1],1,e.inputs[1].dims[2]])];e.inputs.length===3&&r.push(e.inputs[2]);let i=[0,t.pads[0],0,t.pads[1]],a=[1].concat(t.strides),o=[1].concat(t.dilations),s=[1].concat(t.kernelShape),c=ao({...t,pads:i,strides:a,dilations:o,kernelShape:s},r);so(e,r,c,e=>n?[e[0],e[2],e[3]]:[e[0],e[1],e[3]])},lo=(e,t,n)=>{let r=n.format===`NHWC`?`channelsLast`:`channelsFirst`,i=ao(n,t),a=n.autoPad===`NOTSET`?n.pads:n.autoPad,o=Xa(t[0].dims,t[1].dims,n.strides,n.dilations,a,!1,r);e.compute(Za(t,i,o.outShape,[o.filterDepth,o.filterHeight,o.filterWidth],[o.padInfo.front,o.padInfo.top,o.padInfo.left],r))},uo=(e,t)=>{if(io(e.inputs,t),e.inputs[0].dims.length===3)co(e,t);else if(e.inputs[0].dims.length===5)lo(e,e.inputs,t);else{let n=ao(t,e.inputs);so(e,e.inputs,n)}}}),po,mo=o(()=>{L(),Lt(),B(),Y(),po=(e,t,n)=>{let r=e.length>2,i=t.outputShape,a=t.format===`NHWC`,o=t.group,s=e[1].dims,c=s[2]/o,l=s[3],u=a?G(c):1,d=a&&l===1&&c>=4,f=d?Math.floor(c/4)*4:Math.floor(c/u)*u,p=c-f,m=a?G(l):1,h=a?l===1?u:m:1,g=z.size(i)/m,_=[Math.ceil(g/64),1,1];R(`verbose`,()=>`[conv2d_backprop_webgpu] dispatch = ${_}`);let v=[`rank`,`rank`],y=[t.strides[0],t.strides[1]],b=[t.kernelShape[a?1:2],t.kernelShape[a?2:3]],x=[t.dilations[0],t.dilations[1]],S=[b[0]+(t.dilations[0]<=1?0:(t.kernelShape[a?1:2]-1)*(t.dilations[0]-1)),b[1]+(t.dilations[1]<=1?0:(t.kernelShape[a?2:3]-1)*(t.dilations[1]-1))],C=[S[0]-1-Math.floor((t.pads[0]+t.pads[2])/2),S[1]-1-Math.floor((t.pads[1]+t.pads[3])/2)],w=[{type:12,data:g},{type:12,data:y},{type:12,data:b},{type:12,data:x},{type:12,data:S},{type:6,data:C},{type:12,data:f},{type:12,data:c},{type:12,data:l},...W(e[0].dims,e[1].dims)];return r&&(w.push(...W(e[2].dims)),v.push(`rank`)),w.push(...W(i)),{name:`ConvTranspose2D`,shaderCache:{hint:`${t.cacheKey};${u}${h}${m}${d}${p}`,inputDependencies:v},getRunData:()=>({dispatchGroup:{x:_[0],y:_[1],z:_[2]},outputs:[{dims:n?n(i):i,dataType:e[0].dataType}],programUniforms:w}),getShaderSource:t=>{let n=[{name:`output_size`,type:`u32`},{name:`strides`,type:`u32`,length:y.length},{name:`filter_dims`,type:`u32`,length:b.length},{name:`dilations`,type:`u32`,length:b.length},{name:`effective_filter_dims`,type:`u32`,length:S.length},{name:`pads`,type:`i32`,length:C.length},{name:`input_channels_per_group_int`,type:`u32`},{name:`input_channels_per_group`,type:`u32`},{name:`output_channels_per_group`,type:`u32`}],o=U(e[0].dataType),s=a?1:2,c=a?2:3,l=a?3:1,f=q(`W`,e[1].dataType,e[1].dims.length,h),g=q(`Dy`,e[0].dataType,e[0].dims.length,u),_=[g,f];r&&_.push(q(`bias`,e[2].dataType,[i[l]].length,m));let v=J(`result`,e[0].dataType,i.length,m),x=` | |
| let outputIndices = ${v.offsetToIndices(`global_idx * ${m}`)}; | |
| let batch = ${v.indicesGet(`outputIndices`,0)}; | |
| let d1 = ${v.indicesGet(`outputIndices`,l)}; | |
| let r = ${v.indicesGet(`outputIndices`,s)}; | |
| let c = ${v.indicesGet(`outputIndices`,c)}; | |
| let dyCorner = vec2<i32>(i32(r), i32(c)) - uniforms.pads; | |
| let dyRCorner = dyCorner.x; | |
| let dyCCorner = dyCorner.y; | |
| let groupId = d1 / uniforms.output_channels_per_group; | |
| let wOutChannel = d1 - groupId * uniforms.output_channels_per_group; | |
| // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). | |
| // ? = to be determined. : = across all values in that axis. | |
| var dotProd = ${v.type.value}(0.0); | |
| var wR: u32 = 0; | |
| if (uniforms.dilations.x == 1) { | |
| // Minimum wR >= 0 that satisfies (dyRCorner + wR) % (uniforms.strides.x) == 0 | |
| wR = u32(((dyRCorner + i32(uniforms.strides.x) - 1) / i32(uniforms.strides.x)) * i32(uniforms.strides.x) - dyRCorner); | |
| } | |
| for (; wR < uniforms.effective_filter_dims.x; wR = wR + 1) { | |
| if (wR % uniforms.dilations.x != 0) { | |
| continue; | |
| } | |
| let dyR = (${o}(dyRCorner) + ${o}(wR)) / ${o}(uniforms.strides[0]); | |
| let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x; | |
| if (dyR < 0.0 || dyR >= ${o}(uniforms.Dy_shape[${s}]) || fract(dyR) > 0.0 || | |
| wRPerm < 0) { | |
| continue; | |
| } | |
| let idyR: u32 = u32(dyR); | |
| var wC: u32 = 0; | |
| if (uniforms.dilations.y == 1) { | |
| // Minimum wC >= 0 that satisfies (dyCCorner + wC) % (uniforms.strides.y) == 0 | |
| wC = u32(((dyCCorner + i32(uniforms.strides.y) - 1) / i32(uniforms.strides.y)) * i32(uniforms.strides.y) - dyCCorner); | |
| } | |
| for (; wC < uniforms.effective_filter_dims.y; wC = wC + 1) { | |
| if (wC % uniforms.dilations.y != 0) { | |
| continue; | |
| } | |
| let dyC = (${o}(dyCCorner) + ${o}(wC)) / ${o}(uniforms.strides.y); | |
| let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y; | |
| if (dyC < 0.0 || dyC >= ${o}(uniforms.Dy_shape[${c}]) || | |
| fract(dyC) > 0.0 || wCPerm < 0) { | |
| continue; | |
| } | |
| let idyC: u32 = u32(dyC); | |
| var inputChannel = groupId * uniforms.input_channels_per_group; | |
| ${d?` | |
| var x_offset = ${g.indicesToOffset(`${g.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${u}; | |
| var w_offset = ${f.indicesToOffset(`${f.type.indices}(wRPerm, wCPerm, inputChannel, wOutChannel)`)} / ${h}; | |
| `:``} | |
| for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group_int; d2 = d2 + ${d?4:u}) { | |
| ${(()=>{let e=``;if(d)u===4?e+=` | |
| let xValue = ${g.getByOffset(`x_offset`)}; | |
| let wValue = ${f.getByOffset(`w_offset`)}; | |
| dotProd = dotProd + dot(xValue, wValue); | |
| x_offset += 1u; | |
| w_offset += 1u;`:u===2?e+=` | |
| dotProd = dotProd + dot(vec4<${o}>(${g.getByOffset(`x_offset`)}, ${g.getByOffset(`x_offset + 1u`)}), vec4<${o}>(${f.getByOffset(`w_offset`)}, ${f.getByOffset(`w_offset + 1u`)})); | |
| x_offset += 2u; | |
| w_offset += 2u;`:u===1&&(e+=` | |
| dotProd = dotProd + dot(vec4<${o}>(${g.getByOffset(`x_offset`)}, ${g.getByOffset(`x_offset + 1u`)}, ${g.getByOffset(`x_offset + 2u`)}, ${g.getByOffset(`x_offset + 3u`)}), vec4<${o}>(${f.getByOffset(`w_offset`)}, ${f.getByOffset(`w_offset + 1u`)}, ${f.getByOffset(`w_offset + 2u`)}, ${f.getByOffset(`w_offset + 3u`)})); | |
| x_offset += 4u; | |
| w_offset += 4u;`);else if(e+=` | |
| let xValue = ${a?g.getByOffset(`${g.indicesToOffset(`${g.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${u}`):g.get(`batch`,`inputChannel`,`idyR`,`idyC`)}; | |
| `,u===1)e+=` | |
| let w_offset = ${f.indicesToOffset(`${f.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)}; | |
| let wValue = ${f.getByOffset(`w_offset / ${h}`)}; | |
| dotProd = dotProd + xValue * wValue;`;else for(let t=0;t<u;t++)e+=` | |
| let wValue${t} = ${f.getByOffset(`${f.indicesToOffset(`${f.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel + ${t}, wOutChannel)`)} / ${h}`)}; | |
| dotProd = dotProd + xValue[${t}] * wValue${t};`;return e})()} | |
| inputChannel = inputChannel + ${d?4:u}; | |
| } | |
| ${(()=>{if(p===0)return``;if(!d)throw Error(`packInputAs4 ${d} is not true.`);let e=``;if(u===1){e+=`dotProd = dotProd`;for(let t=0;t<p;t++)e+=` | |
| + ${g.getByOffset(`x_offset + ${t}`)} * ${f.getByOffset(`w_offset + ${t}`)}`;e+=`;`}else if(u===2){if(p!==2)throw Error(`Invalid inputChannelsRemainder ${p}.`);e+=` | |
| let xValue = ${g.getByOffset(`x_offset`)}; | |
| let wValue = ${f.getByOffset(`w_offset`)}; | |
| dotProd = dotProd + dot(xValue, wValue);`}return e})()} | |
| wC = wC + uniforms.strides.y - 1; | |
| } | |
| wR = wR + uniforms.strides[0] - 1; | |
| } | |
| let value = dotProd${r?` + bias[d1 / ${m}]`:``}; | |
| ${v.setByOffset(`global_idx`,`value`)}; | |
| `;return` | |
| ${t.registerUniforms(n).declareVariables(..._,v)} | |
| ${t.mainStart()} | |
| ${t.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.output_size`)}; | |
| ${x}}`}}}}),ho,go,_o,vo,yo,bo,xo,So,Co,wo=o(()=>{mo(),Ca(),Vn(),ho=(e,t,n,r,i,a)=>(e-1)*t+n+(r-1)*i+1-a,go=(e,t,n,r,i)=>{let a=Math.floor(e/2);t===`SAME_UPPER`?(n[r]=a,n[i]=e-a):t===`SAME_LOWER`&&(n[r]=e-a,n[i]=a)},_o=(e,t,n,r,i,a,o,s,c,l)=>{let u=e.length-2,d=l.length===0;c.length<u&&c.push(...Array(u-c.length).fill(0));let f=e[0],p=t[s?3:1]*i;for(let i=0,f=e.length-u-(s?1:0);i<u;++i,++f){let s=e[f],p=d?s*o[i]:l[i],m=ho(s,o[i],a[i],t[f],n[i],p);go(m,r,a,i,i+u),d&&l.push(o[i]*(s-1)+c[i]+(t[f]-1)*n[i]+1-a[i]-a[i+u])}l.splice(0,0,f),l.splice(s?3:1,0,p)},vo=(e,t)=>{let n=e.kernelShape.slice();if(e.kernelShape.length===0||e.kernelShape.reduce((e,t)=>e*t,1)===0){n.length=0;for(let e=2;e<t[1].dims.length;++e)n.push(t[1].dims[e])}let r=e.format===`NHWC`;n.splice(0,0,t[1].dims[0]),n.splice(r?3:1,0,t[1].dims[1]);let i=e.pads.slice(),a=e.outputShape.slice(),o=e.outputPadding.slice(),s=t[0].dims,c=e.dilations.slice();if(c.reduce((e,t)=>e+t,0)===0){let e=t[0].dims.length-2;c=Array(e).fill(1)}let l=e.strides.slice();if(l.reduce((e,t)=>e+t,0)===0){let e=t[0].dims.length-2;l=Array(e).fill(1)}_o(s,n,c,e.autoPad,e.group,i,l,r,o,a);let u=Object.assign({},e);return Object.assign(u,{kernelShape:n,pads:i,outputPadding:o,outputShape:a,dilations:c,strides:l}),u},yo=e=>{let t=Sa(e),n=e.format,r=[`NOTSET`,`VALID`,`SAME_UPPER`,`SAME_LOWER`][typeof e.autoPad>`u`?0:e.autoPad],i=e.dilations,a=e.group??1,o=e.kernelShape,s=e.pads,c=e.strides,l=e.wIsConst();return{autoPad:r,format:n,dilations:i,group:a,kernelShape:o,outputPadding:e.outputPadding,outputShape:e.outputShape,pads:s,strides:c,wIsConst:l,...t,cacheKey:`${e.format};${t.activation};`}},bo=(e,t)=>{if(!e||e.length!==2&&e.length!==3)throw Error(`Conv requires 2 or 3 inputs`);if(e[0].dims.length!==4&&e[0].dims.length!==3)throw Error(`currently only support 2-dimensional conv`);if(e[0].dims.length!==e[1].dims.length)throw Error(`filter does not have same dimension as input`);if(e[0].dims[t.format===`NHWC`?e[0].dims.length-1:1]!==e[1].dims[0])throw Error(`FILTER_IN_CHANNEL should be equal to DATA_CHANNEL`);let n=e[1].dims[1]*t.group;if(e.length===3&&(e[2].dims.length!==1||e[2].dims[0]!==n))throw Error(`invalid bias`);let r=e[0].dims.length-2;if(t.dilations.reduce((e,t)=>e+t,0)>0&&t.dilations.length!==r)throw Error(`dilations should be ${r}D`);if(t.strides.reduce((e,t)=>e+t,0)>0&&t.strides.length!==r)throw Error(`strides should be ${r}D`);if(t.pads.reduce((e,t)=>e+t,0)>0&&t.pads.length!==r*2)throw Error(`pads should be ${r*2}D`);if(t.outputPadding.length!==r&&t.outputPadding.length!==0)throw Error(`output_padding should be ${r}D`);if(t.kernelShape.reduce((e,t)=>e+t,0)>0&&t.kernelShape.length!==0&&t.kernelShape.length!==e[1].dims.length-2)throw Error(`invalid kernel shape`);if(t.outputShape.length!==0&&t.outputShape.length!==e[0].dims.length-2)throw Error(`invalid output shape`)},xo=(e,t,n,r)=>{let i=e.kernelCustomData.wT??e.compute(Rn(t[1],[2,3,0,1]),{inputs:[1],outputs:[n.wIsConst?-2:-1]})[0];n.wIsConst&&!e.kernelCustomData.wT&&(e.kernelCustomData.wT=i);let a=[t[0],i];t.length===3&&a.push(t[2]),e.compute(po(a,n,r),{inputs:a})},So=(e,t)=>{let n=t.format===`NHWC`,r=[e.inputs[0].reshape(n?[e.inputs[0].dims[0],1,e.inputs[0].dims[1],e.inputs[0].dims[2]]:[e.inputs[0].dims[0],e.inputs[0].dims[1],1,e.inputs[0].dims[2]]),e.inputs[1].reshape([e.inputs[1].dims[0],e.inputs[1].dims[1],1,e.inputs[1].dims[2]])];e.inputs.length===3&&r.push(e.inputs[2]);let i=t.kernelShape;(i.length===0||i[0]===0)&&(i=[e.inputs[1].dims[2]]);let a=t.dilations;(a.length===0||a[0]===0)&&(a=[1]);let o=t.strides;(o.length===0||o[0]===0)&&(o=[1]);let s=t.pads;s.length===0&&(s=[0,0]),s=[0,s[0],0,s[1]],o=[1].concat(o),a=[1].concat(a),i=[1].concat(i);let c=t.outputPadding;c=[0].concat(c);let l=vo({...t,pads:s,strides:o,dilations:a,kernelShape:i,outputPadding:c},r);xo(e,r,l,e=>n?[e[0],e[2],e[3]]:[e[0],e[1],e[3]])},Co=(e,t)=>{if(bo(e.inputs,t),e.inputs[0].dims.length===3)So(e,t);else{let n=vo(t,e.inputs);xo(e,e.inputs,n)}}}),To,Eo,Do,Oo=o(()=>{L(),B(),H(),Y(),To=(e,t,n,r)=>{let i=z.size(t),a=t.length,o=q(`input`,e,a),s=J(`output`,e,a),c=n.dataType===6?n.getInt32Array()[0]:Number(n.getBigInt64Array()[0]),l=z.normalizeAxis(c,a);return{name:`CumSum`,shaderCache:{hint:r.cacheKey,inputDependencies:[`rank`]},getRunData:()=>({outputs:[{dims:t,dataType:e}],dispatchGroup:{x:Math.ceil(i/64)},programUniforms:[{type:12,data:i},{type:12,data:l},...W(t,t)]}),getShaderSource:e=>{let t=` i32(${o.indicesGet(`inputIndices`,`uniforms.axis`)}) `,n=K(`uniforms.input_shape`,`uniforms.axis`,a),i=r.reverse?t+(r.exclusive?` + 1`:``):`0`,c=r.reverse?n:t+(r.exclusive?``:` + 1`);return` | |
| ${e.registerUniform(`outputSize`,`u32`).registerUniform(`axis`,`u32`).declareVariables(o,s)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.outputSize`)} | |
| var inputIndices = ${s.offsetToIndices(`global_idx`)}; | |
| var sum = ${s.type.value}(0); | |
| let first : i32 = ${i}; | |
| let last : i32 = ${c}; | |
| for (var i : i32 = first; i < last; i++) { | |
| ${o.indicesSet(`inputIndices`,`uniforms.axis`,`u32(i)`)}; | |
| sum = sum + ${o.getByIndices(`inputIndices`)}; | |
| } | |
| ${s.setByOffset(`global_idx`,`sum`)}; | |
| }`}}},Eo=(e,t)=>{let n=e.inputs[0].dims,r=e.inputs[0].dataType,i=e.inputs[1];e.compute(To(r,n,i,t),{inputs:[0]})},Do=e=>{let t=e.exclusive===1,n=e.reverse===1;return V({exclusive:t,reverse:n})}}),ko,Ao,jo,Mo,No,Po=o(()=>{L(),B(),H(),Y(),ko=e=>{if(!e||e.length!==1)throw Error(`DepthToSpace requires 1 input.`);if(e[0].dims.length!==4)throw Error(`DepthToSpace requires 4D input.`)},Ao=(e,t,n,r)=>{let i=[];i.push(`fn perm(i: ${r.type.indices}) -> ${n.type.indices} { | |
| var a: ${n.type.indices};`);for(let r=0;r<t;++r)i.push(n.indicesSet(`a`,e[r],`i[${r}]`));return i.push(`return a;}`),i.join(` | |
| `)},jo=(e,t)=>{let n,r,i,a,o,s,c=t.format===`NHWC`,l=t.blocksize,u=t.mode===`DCR`;c?([n,r,i,a]=e.dims,o=u?[n,r,i,l,l,a/l**2]:[n,r,i,a/l**2,l,l],s=u?[0,1,3,2,4,5]:[0,1,4,2,5,3]):([n,r,i,a]=[e.dims[0],e.dims[2],e.dims[3],e.dims[1]],o=u?[n,l,l,a/l**2,r,i]:[n,a/l**2,l,l,r,i],s=u?[0,3,4,1,5,2]:[0,1,4,2,5,3]);let d=e.reshape(o),f=d.dims.length,p=e.dataType,m=q(`a`,p,f),h=J(`output`,p,f);return{name:`DepthToSpace`,shaderCache:{hint:`${e.dims};${t.blocksize};${t.mode}`,inputDependencies:[`rank`]},getRunData:e=>{let t=c?[n,r*l,i*l,a/l**2]:[n,a/l**2,r*l,i*l],o=z.size(t),u=d.dims,f=z.sortBasedOnPerm(u,s);return{outputs:[{dims:t,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(o/64)},programUniforms:[{type:12,data:o},...W(u,f)]}},getShaderSource:e=>` | |
| ${e.registerUniform(`output_size`,`u32`).declareVariables(m,h)} | |
| ${Ao(s,f,m,h)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.output_size`)} | |
| let indices = ${h.offsetToIndices(`global_idx`)}; | |
| let aIndices = perm(indices); | |
| ${h.setByOffset(`global_idx`,m.getByIndices(`aIndices`))} | |
| }`}},Mo=(e,t)=>{ko(e.inputs),e.compute(jo(e.inputs[0],t))},No=e=>V({blocksize:e.blocksize,mode:e.mode,format:e.format})}),Fo,Io,Lo,Ro,zo,Bo,Vo,Ho,Uo,Wo,Go,Ko=o(()=>{L(),B(),H(),Y(),Fo=`[a-zA-Z]|\\.\\.\\.`,Io=`(`+Fo+`)+`,Lo=`^`+Io+`$`,Ro=`(`+Io+`,)*`+Io,zo=`^`+Ro+`$`,Bo=class{constructor(e=-1){this.symbolToIndices=new Map,this.inputIndex=e}addSymbol(e,t){let n=this.symbolToIndices.get(e);n===void 0?n=[t]:n.push(t),this.symbolToIndices.set(e,n)}},Vo=class{constructor(e,t){this.equation=t,this.hasEllipsis=!1,this.symbolToInfo=new Map,this.lhs=[],this.outputDims=[];let[n,r]=t.includes(`->`)?t.split(`->`,2):[t,``];if(!n.match(RegExp(zo)))throw Error(`Invalid LHS term`);if(n.split(`,`).forEach((t,n)=>{let r=e[n].dims.slice();if(!t.match(RegExp(Lo)))throw Error(`Invalid LHS term`);let i=this.processTerm(t,!0,r,n);this.lhs.push(i)}),r===``)r+=[...this.symbolToInfo.entries()].filter(([e,t])=>t.count===1||e===`...`).map(([e])=>e).join(``);else if(!r.match(RegExp(Io)))throw Error(`Invalid RHS`);r.match(RegExp(Fo,`g`))?.forEach(e=>{if(e===`...`)this.outputDims=this.outputDims.concat(this.ellipsisDims);else{let t=this.symbolToInfo.get(e);if(t===void 0)throw Error(`Invalid RHS symbol`);this.outputDims.push(t.dimValue)}}),this.rhs=this.processTerm(r,!1,this.outputDims)}addSymbol(e,t,n){let r=this.symbolToInfo.get(e);if(r!==void 0){if(r.dimValue!==t&&r.count!==1)throw Error(`Dimension mismatch`);r.count++,r.inputIndices.push(n)}else r={count:1,dimValue:t,inputIndices:[n]};this.symbolToInfo.set(e,r)}processTerm(e,t,n,r=-1){let i=n.length,a=!1,o=[],s=0;if(!e.match(RegExp(Lo))&&!t&&e!==``)throw Error(`Invalid LHS term`);let c=e.match(RegExp(Fo,`g`)),l=new Bo(r);return c?.forEach((e,u)=>{if(e===`...`){if(a)throw Error(`Only one ellipsis is allowed per input term`);a=!0;let e=i-c.length+1;if(e<0)throw Error(`Ellipsis out of bounds`);if(o=n.slice(s,s+e),this.hasEllipsis){if(this.ellipsisDims.length!==o.length||this.ellipsisDims.toString()!==o.toString())throw Error(`Ellipsis dimensions mismatch`)}else if(t)this.hasEllipsis=!0,this.ellipsisDims=o;else throw Error(`Ellipsis must be specified in the LHS`);for(let e=0;e<o.length;e++){let t=String.fromCharCode(48+e);l.addSymbol(t,u+e),this.addSymbol(t,n[s++],r)}}else l.addSymbol(e,u+(this.hasEllipsis?this.ellipsisDims.length-1:0)),this.addSymbol(e,n[s++],r)}),l}},Ho=e=>e+`_max`,Uo=(e,t,n,r)=>{let i=e.map(e=>e.length).map((e,n)=>q(`input${n}`,t,e)),a=z.size(r),o=J(`output`,t,r.length),s=[...n.symbolToInfo.keys()].filter(e=>!n.rhs.symbolToIndices.has(e));return{name:`Einsum`,shaderCache:{hint:n.equation,inputDependencies:e.map(()=>`rank`)},getRunData:()=>{let i=s.filter(e=>n.symbolToInfo.has(e)).map(e=>({type:12,data:n.symbolToInfo.get(e)?.dimValue||0}));i.push({type:12,data:a});let o=e.map((e,t)=>[...W(e)]).reduce((e,t)=>e.concat(t),i);return o.push(...W(r)),{outputs:[{dims:r,dataType:t}],dispatchGroup:{x:Math.ceil(a/64)},programUniforms:o}},getShaderSource:e=>{let t=[],r=[],a=[],c=[],l=[],u=n.symbolToInfo.size===n.rhs.symbolToIndices.size;n.symbolToInfo.forEach((e,s)=>{if(n.rhs.symbolToIndices.has(s)){let r=n.rhs.symbolToIndices.get(s)?.[0];r!==void 0&&n.lhs.forEach((n,a)=>{if(e.inputIndices.includes(a)){let e=n.symbolToIndices.get(s);if(e===void 0)throw Error(`Invalid symbol error`);e.forEach(e=>{t.push(`${i[a].indicesSet(`input${a}Indices`,e,o.indicesGet(`outputIndices`,r))}`)})}})}else n.lhs.forEach((t,n)=>{if(e.inputIndices.includes(n)){let e=t.symbolToIndices.get(s);if(e===void 0)throw Error(`Invalid symbol error`);e.forEach(e=>{r.push(`${i[n].indicesSet(`input${n}Indices`,e,`${s}`)}`)}),l.push(`prod *= ${i[n].getByIndices(`input${n}Indices`)};`)}}),a.push(`for(var ${s}: u32 = 0; ${s} < uniforms.${Ho(s)}; ${s}++) {`),c.push(`}`)});let d=u?[...t,`let sum = ${i.map((e,t)=>e.getByIndices(`input${t}Indices`)).join(` * `)};`]:[...t,`var sum = 0.0;`,...a,...r,`var prod = 1.0;`,...l,`sum += prod;`,...c];return` | |
| ${e.registerUniforms(s.map(e=>({name:`${Ho(e)}`,type:`u32`}))).registerUniform(`outputSize`,`u32`).declareVariables(...i,o)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.outputSize`)} | |
| var outputIndices = ${o.offsetToIndices(`global_idx`)}; | |
| ${i.map((e,t)=>`var input${t}Indices: ${i[t].type.indices};`).join(` | |
| `)} | |
| ${d.join(` | |
| `)}; | |
| ${o.setByOffset(`global_idx`,`sum`)}; | |
| }`}}},Wo=(e,t)=>{let n=new Vo(e.inputs,t.equation),r=n.outputDims,i=e.inputs.map((e,t)=>e.dims);e.compute(Uo(i,e.inputs[0].dataType,n,r))},Go=e=>{let t=e.equation.replace(/\s+/g,``);return V({equation:t})}}),qo,Jo,Yo,Xo,Zo,Qo=o(()=>{L(),B(),Y(),qo=e=>{if(!e||e.length!==2)throw Error(`Expand requires 2 input.`);let t=e[0].dims,n=Array.from(e[1].getBigInt64Array(),Number),r=n.length<t.length?0:n.length-t.length,i=t.length<n.length?0:t.length-n.length;for(;r<n.length&&i<t.length;++r,++i)if(n[r]!==t[i]&&n[r]!==1&&t[i]!==1)throw Error(`Expand requires shape to be broadcastable to input`)},Jo=(e,t)=>{let n=e.length-t.length,r=[];for(let t=0;t<n;++t)r.push(e[t]);for(let i=0;i<t.length;++i)r.push(t[i]===1?e[i+n]:t[i]);return r},Yo=(e,t)=>e.length>t.length?Jo(e,t):Jo(t,e),Xo=e=>{let t=e[0].dims,n=Array.from(e[1].getBigInt64Array(),Number),r=Yo(t,n),i=e[0].dataType,a=i===9||z.size(t)===1,o=i===9||t.length>0&&t[t.length-1]%4==0?4:1,s=a||r.length>0&&r[r.length-1]%4==0?4:1,c=Math.ceil(z.size(r)/s),l=e=>{let n=q(`input`,i,t.length,o),a=J(`output`,i,r.length,s),c;if(i===9){let e=(e,t,r=``)=>` | |
| let outputIndices${t} = ${a.offsetToIndices(`outputOffset + ${t}u`)}; | |
| let offset${t} = ${n.broadcastedIndicesToOffset(`outputIndices${t}`,a)}; | |
| let index${t} = offset${t} / 4u; | |
| let component${t} = offset${t} % 4u; | |
| ${e}[${t}] = ${r}(${n.getByOffset(`index${t}`)}[component${t}]); | |
| `;c=` | |
| let outputOffset = global_idx * ${s}; | |
| var data = vec4<u32>(0); | |
| ${e(`data`,0,`u32`)} | |
| ${e(`data`,1,`u32`)} | |
| ${e(`data`,2,`u32`)} | |
| ${e(`data`,3,`u32`)} | |
| ${a.setByOffset(`global_idx`,`data`)} | |
| }`}else c=` | |
| let outputIndices = ${a.offsetToIndices(`global_idx * ${s}`)}; | |
| let inputOffset = ${n.broadcastedIndicesToOffset(`outputIndices`,a)}; | |
| let data = ${a.type.value}(${n.getByOffset(`inputOffset / ${o}`)}); | |
| ${a.setByOffset(`global_idx`,`data`)} | |
| }`;return` | |
| ${e.registerUniform(`vec_size`,`u32`).declareVariables(n,a)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.vec_size`)} | |
| ${c}`},u=[{type:12,data:c},...W(t,r)];return{name:`Expand`,shaderCache:{hint:`${r.length};${o}${s}`,inputDependencies:[`rank`]},getShaderSource:l,getRunData:()=>({outputs:[{dims:r,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(c/64)},programUniforms:u})}},Zo=e=>{qo(e.inputs),e.compute(Xo(e.inputs),{inputs:[0]})}}),$o,es,ts=o(()=>{L(),B(),Y(),qi(),$o=e=>{let t=e[0].dataType,n=z.size(e[0].dims),r=z.size(e[1].dims),i=r%4==0;return{name:`FastGeluWithBias`,shaderCache:{hint:`${i}`,inputDependencies:[`type`,`type`]},getShaderSource:e=>{let n=q(`x`,t,[1],4),r=q(`bias`,t,[1],4),a=J(`y`,t,[1],4),o=[{name:`output_vec_size`,type:`u32`},{name:`bias_size`,type:`u32`}],s=e=>` | |
| let bias${e}_offset: u32 = (global_idx * 4 + ${e}) % uniforms.bias_size; | |
| let bias${e} = ${r.getByOffset(`bias${e}_offset / 4`)}[bias${e}_offset % 4];`,c=i?` | |
| let bias = ${r.getByOffset(`global_idx % (uniforms.bias_size / 4)`)};`:`${s(0)}${s(1)}${s(2)}${s(3)} | |
| let bias = ${n.type.value}(bias0, bias1, bias2, bias3);`;return`${e.registerUniforms(o).declareVariables(n,r,a)} | |
| ${zi(Cn(t))} | |
| ${e.mainStart(xn)} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.output_vec_size`)} | |
| let x = ${n.getByOffset(`global_idx`)}; | |
| ${c} | |
| let x_in = x + bias; | |
| ${a.setByOffset(`global_idx`,Bi(`x_in`))} | |
| }`},getRunData:e=>({outputs:[{dims:e[0].dims,dataType:e[0].dataType}],programUniforms:[{type:12,data:Math.ceil(n/4)},{type:12,data:r}],dispatchGroup:{x:Math.ceil(n/xn/4)}})}},es=e=>{e.inputs.length<2||z.size(e.inputs[1].dims)===0?Vi(e):e.compute($o(e.inputs))}}),ns,rs,is,as,os=o(()=>{L(),B(),H(),Y(),ns=e=>{if(!e||e.length!==2)throw Error(`Gather requires 2 inputs.`)},rs=(e,t)=>{let n=e[0].dims,r=e[1].dims,i=n.length,a=z.normalizeAxis(t.axis,i),o=n.slice(0);o.splice(a,1,...r);let s=n[a],c=e[0].dataType===9?4:1,l=Math.ceil(z.size(o)/c),u=[{type:12,data:l},{type:6,data:s},{type:12,data:a},...W(e[0].dims,e[1].dims,o)];return{name:`Gather`,shaderCache:{hint:t.cacheKey,inputDependencies:[`rank`,`rank`]},getRunData:()=>({outputs:[{dims:o,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(l/64)},programUniforms:u}),getShaderSource:t=>{let n=q(`data`,e[0].dataType,e[0].dims.length,c),s=q(`inputIndices`,e[1].dataType,e[1].dims.length),l=J(`output`,e[0].dataType,o.length,c),u=e=>{let t=r.length,c=`var indicesIndices${e} = ${s.type.indices}(0);`;for(let n=0;n<t;n++)c+=`${t>1?`indicesIndices${e}[${n}]`:`indicesIndices${e}`} = ${o.length>1?`outputIndices${e}[uniforms.axis + ${n}]`:`outputIndices${e}`};`;c+=` | |
| var idx${e} = ${s.getByIndices(`indicesIndices${e}`)}; | |
| if (idx${e} < 0) { | |
| idx${e} = idx${e} + uniforms.axisDimLimit; | |
| } | |
| var dataIndices${e} : ${n.type.indices}; | |
| `;for(let n=0,r=0;n<i;n++)n===a?(c+=`${i>1?`dataIndices${e}[${n}]`:`dataIndices${e}`} = u32(idx${e});`,r+=t):(c+=`${i>1?`dataIndices${e}[${n}]`:`dataIndices${e}`} = ${o.length>1?`outputIndices${e}[${r}]`:`outputIndices${e}`};`,r++);return c},d;if(e[0].dataType===9){let e=(e,t,r=``)=>` | |
| let outputIndices${t} = ${l.offsetToIndices(`outputOffset + ${t}u`)}; | |
| ${u(t)}; | |
| let offset${t} = ${n.indicesToOffset(`dataIndices${t}`)}; | |
| let index${t} = offset${t} / 4u; | |
| let component${t} = offset${t} % 4u; | |
| ${e}[${t}] = ${r}(${n.getByOffset(`index${t}`)}[component${t}]); | |
| `;d=` | |
| let outputOffset = global_idx * ${c}; | |
| var value = vec4<u32>(0); | |
| ${e(`value`,0,`u32`)} | |
| ${e(`value`,1,`u32`)} | |
| ${e(`value`,2,`u32`)} | |
| ${e(`value`,3,`u32`)} | |
| ${l.setByOffset(`global_idx`,`value`)} | |
| `}else d=` | |
| let outputIndices = ${l.offsetToIndices(`global_idx`)}; | |
| ${u(``)}; | |
| let value = ${n.getByIndices(`dataIndices`)}; | |
| ${l.setByOffset(`global_idx`,`value`)}; | |
| `;return` | |
| ${t.registerUniform(`outputSize`,`u32`).registerUniform(`axisDimLimit`,`i32`).registerUniform(`axis`,`u32`).declareVariables(n,s,l)} | |
| ${t.mainStart()} | |
| ${t.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.outputSize`)} | |
| ${d} | |
| }`}}},is=e=>V({axis:e.axis}),as=(e,t)=>{let n=e.inputs;ns(n),e.compute(rs(e.inputs,t))}}),ss,cs,ls,us=o(()=>{L(),B(),Y(),ss=(e,t,n,r,i,a,o,s,c)=>{let l=[{type:12,data:a},{type:12,data:r},{type:12,data:i},{type:12,data:n},{type:12,data:o},{type:12,data:s},{type:12,data:c}],u=[a];return l.push(...W(t.dims,u)),e.compute({name:`computeSliceOffsets`,shaderCache:{hint:`${i.length}_${n.length}`,inputDependencies:[`rank`]},getRunData:()=>({outputs:[{dims:u,dataType:e.inputs[1].dataType}],dispatchGroup:{x:Math.ceil(a/64)},programUniforms:l}),getShaderSource:e=>{let r=[q(`indices_data`,t.dataType,t.dims.length),J(`input_slice_offsets_data`,12,1,1)],a=[{name:`output_size`,type:`u32`},{name:`batch_dims`,type:`u32`},{name:`input_dims`,type:`u32`,length:i.length},{name:`sizes_from_slice_dims_data`,type:`u32`,length:n.length},{name:`num_slices_per_batch`,type:`u32`},{name:`input_batch_stride`,type:`u32`},{name:`num_slice_dims`,type:`u32`}];return` | |
| ${e.registerUniforms(a).declareVariables(...r)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.output_size`)} | |
| let batch_idx = global_idx / uniforms.num_slices_per_batch; | |
| let base_offset = batch_idx * uniforms.input_batch_stride; | |
| let slice_indices_base_offset = global_idx * uniforms.num_slice_dims; | |
| var relative_slice_offset = 0; | |
| for (var dim_idx = 0u; dim_idx < uniforms.num_slice_dims; dim_idx ++) { | |
| var index = i32(indices_data[dim_idx + slice_indices_base_offset].x); | |
| let input_dim_idx = uniforms.batch_dims + dim_idx; | |
| if (index < 0) { | |
| ${i.length===1?`index += i32(uniforms.input_dims);`:`index += i32(uniforms.input_dims[input_dim_idx]);`} | |
| } | |
| ${n.length===1?`relative_slice_offset += index * i32(uniforms.sizes_from_slice_dims_data);`:`relative_slice_offset += index * i32(uniforms.sizes_from_slice_dims_data[dim_idx]);`} | |
| } | |
| input_slice_offsets_data[global_idx] = base_offset + u32(relative_slice_offset); | |
| }`}},{inputs:[t],outputs:[-1]})[0]},cs=(e,t)=>{let n=e.inputs,r=n[0].dims,i=n[0].dataType,a=n[1].dims,o=a[a.length-1],s=z.sizeToDimension(a,a.length-1),c=z.sizeFromDimension(r,t.batchDims+o),l=z.sizeToDimension(r,t.batchDims),u=z.sizeFromDimension(r,t.batchDims),d=s/l,f=Array(o),p=c;for(let e=0;e<o;++e)f[o-1-e]=p,p*=r[t.batchDims+o-1-e];let m=ss(e,n[1],f,t.batchDims,r,s,d,u,o),h=t.batchDims+o;if(h>r.length)throw Error(`last dimension of indices must not be larger than rank of input tensor`);let g=a.slice(0,-1).concat(r.slice(h)),_=z.size(g),v=[{type:12,data:_},{type:12,data:c},...W(n[0].dims,m.dims,g)];e.compute({name:`GatherND`,shaderCache:{hint:t.cacheKey,inputDependencies:[`rank`,`rank`]},getRunData:()=>({outputs:[{dims:g,dataType:i}],dispatchGroup:{x:Math.ceil(_/64)},programUniforms:v}),getShaderSource:e=>{let t=q(`data`,n[0].dataType,n[0].dims.length),r=q(`slice_offsets`,12,m.dims.length),i=J(`output`,n[0].dataType,g.length);return` | |
| ${e.registerUniform(`output_size`,`u32`).registerUniform(`slice_size`,`u32`).declareVariables(t,r,i)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.output_size`)} | |
| let slice_offset = slice_offsets[global_idx / uniforms.slice_size]; | |
| output[global_idx] = data[u32(slice_offset) + global_idx % uniforms.slice_size]; | |
| }`}},{inputs:[n[0],m]})},ls=e=>({batchDims:e.batch_dims,cacheKey:``})}),ds,fs,ps,ms,hs=o(()=>{L(),B(),H(),Y(),ds=(e,t)=>{if(e.length<3||e.length>4)throw Error(`GatherBlockQuantized requires 3 or 4 inputs.`);let n=z.normalizeAxis(t.quantizeAxis,e[0].dims.length),r=t.blockSize,i=e[0],a=e[2],o=e.length===4?e[3]:void 0;if(a.dims.length!==i.dims.length||!i.dims.map((e,t)=>t===n?Math.ceil(e/r)===a.dims[t]:e===a.dims[t]).reduce((e,t)=>e&&t,!0))throw Error(`Scales must have the same rank as the input tensor and the dims should match except on gatherAxis.`);if(o){if(o.dataType!==i.dataType)throw Error(`Zero point must have the same data type as the input tensor.`);if(o.dims.length!==a.dims.length||!o.dims.map((e,t)=>e===a.dims[t]).reduce((e,t)=>e&&t,!0))throw Error(`Zero point must have the same rank as the input tensor and the dims should match except on quantizeAxis.`)}},fs=(e,t)=>{let n=e[0].dims,r=e[1].dims,i=n.length,a=z.normalizeAxis(t.gatherAxis,i),o=z.normalizeAxis(t.quantizeAxis,i),s=n.slice(0);s.splice(a,1,...r);let c=z.size(s),l=e[2].dataType,u=e[0].dataType===22,d=[{type:12,data:c},{type:12,data:o},{type:12,data:a},{type:12,data:t.blockSize},...W(...e.map((e,t)=>e.dims),s)];return{name:`GatherBlockQuantized`,shaderCache:{hint:`${t.cacheKey};${e.filter((e,t)=>t!==1).map(e=>e.dims.join(`_`)).join(`;`)}`,inputDependencies:Array.from({length:e.length},(e,t)=>`rank`)},getRunData:()=>({outputs:[{dims:s,dataType:l}],dispatchGroup:{x:Math.ceil(c/64)},programUniforms:d}),getShaderSource:t=>{let i=q(`data`,e[0].dataType,e[0].dims.length),o=q(`inputIndices`,e[1].dataType,e[1].dims.length),c=q(`scales`,e[2].dataType,e[2].dims.length),d=e.length>3?q(`zeroPoint`,e[3].dataType,e[3].dims.length):void 0,f=J(`output`,l,s.length),p=[i,o,c];return d&&p.push(d),` | |
| ${t.registerUniforms([{name:`output_size`,type:`u32`},{name:`quantize_axis`,type:`u32`},{name:`gather_axis`,type:`u32`},{name:`block_size`,type:`u32`}]).declareVariables(...p,f)} | |
| ${t.mainStart()} | |
| let output_indices = ${f.offsetToIndices(`global_idx`)}; | |
| var indices_indices = ${o.type.indices}(0); | |
| ${r.length>1?` | |
| for (var i: u32 = 0; i < ${r.length}; i++) { | |
| let index = ${f.indicesGet(`output_indices`,`uniforms.gather_axis + i`)}; | |
| ${o.indicesSet(`indices_indices`,`i`,`index`)}; | |
| }`:`indices_indices = ${f.indicesGet(`output_indices`,`uniforms.gather_axis`)};`}; | |
| var data_indices = ${i.type.indices}(0); | |
| for (var i: u32 = 0; i < uniforms.gather_axis; i++) { | |
| let index = ${f.indicesGet(`output_indices`,`i`)}; | |
| ${i.indicesSet(`data_indices`,`i`,`index`)}; | |
| } | |
| var index_from_indices = ${o.getByIndices(`indices_indices`)}; | |
| if (index_from_indices < 0) { | |
| index_from_indices += ${n[a]}; | |
| } | |
| ${i.indicesSet(`data_indices`,`uniforms.gather_axis`,`u32(index_from_indices)`)}; | |
| for (var i = uniforms.gather_axis + 1; i < ${s.length}; i++) { | |
| let index = ${f.indicesGet(`output_indices`,`i + ${r.length} - 1`)}; | |
| ${i.indicesSet(`data_indices`,`i`,`index`)}; | |
| } | |
| let data_offset = ${i.indicesToOffset(`data_indices`)}; | |
| let data_index = data_offset % 8; | |
| // Convert 4-bit packed data to 8-bit packed data. | |
| let packed_4bit_quantized_data = ${i.getByOffset(`data_offset / 8`)}; | |
| let packed_8bit_quantized_data = (packed_4bit_quantized_data >> (4 * (data_index % 2))) & 0x0f0f0f0f; | |
| let quantized_data_vec = ${u?`unpack4xI8`:`unpack4xU8`}(u32(packed_8bit_quantized_data)); | |
| let quantized_data = quantized_data_vec[data_index / 2]; | |
| var scale_indices = data_indices; | |
| let quantize_axis_index = ${c.indicesGet(`data_indices`,`uniforms.quantize_axis`)} / uniforms.block_size; | |
| ${c.indicesSet(`scale_indices`,`uniforms.quantize_axis`,`quantize_axis_index`)}; | |
| var scale = ${c.getByIndices(`scale_indices`)}; | |
| ${d?` | |
| let zero_point_indices = scale_indices; | |
| let zero_point_offset = ${d.indicesToOffset(`zero_point_indices`)}; | |
| let zero_point_index = zero_point_offset % 8; | |
| let packed_4bit_zero_points = ${d.getByOffset(`zero_point_offset / 8`)}; | |
| let packed_8bit_zero_points = (packed_4bit_zero_points >> (4 * (zero_point_index % 2))) & 0x0f0f0f0f; | |
| let zero_point_vec = ${u?`unpack4xI8`:`unpack4xU8`}(u32(packed_8bit_zero_points)); | |
| let zero_point = zero_point_vec[zero_point_index / 2];`:`var zero_point = 0`}; | |
| let dequantized_data = ${Cn(l)}(quantized_data - zero_point) * scale; | |
| ${f.setByOffset(`global_idx`,`dequantized_data`)}; | |
| }`}}},ps=(e,t)=>{let n=e.inputs;ds(n,t),e.compute(fs(e.inputs,t))},ms=e=>V({blockSize:e.blockSize,gatherAxis:e.gatherAxis,quantizeAxis:e.quantizeAxis})}),gs,_s,vs,ys,bs=o(()=>{L(),B(),H(),Y(),gs=e=>{if(!e||e.length!==2)throw Error(`GatherElements requires 2 inputs.`);if(e[0].dims.length<1)throw Error(`GatherElements requires that the data input be rank >= 1.`);if(e[0].dims.length!==e[1].dims.length)throw Error(`GatherElements requires that the data input and | |
| indices input tensors be of same rank.`)},_s=(e,t)=>{let n=e[0].dims,r=e[0].dataType,i=n.length,a=e[1].dims,o=e[1].dataType,s=z.normalizeAxis(t.axis,i),c=n[s],l=a.slice(0),u=z.size(l),d=q(`input`,r,i),f=q(`indicesInput`,o,a.length),p=J(`output`,r,l.length),m=[{type:12,data:u},{type:6,data:c},{type:12,data:s}];return m.push(...W(n,a,l)),{name:`GatherElements`,shaderCache:{inputDependencies:[`rank`,`rank`]},getRunData:()=>({outputs:[{dims:l,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(u/64)},programUniforms:m}),getShaderSource:e=>` | |
| ${e.registerUniform(`outputSize`,`u32`).registerUniform(`axisDimLimit`,`i32`).registerUniform(`axis`,`u32`).declareVariables(d,f,p)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.outputSize`)} | |
| let outputIndices = ${p.offsetToIndices(`global_idx`)}; | |
| var idx = ${f.getByOffset(`global_idx`)}; | |
| if (idx < 0) { | |
| idx = idx + uniforms.axisDimLimit; | |
| } | |
| var inputIndices = ${d.type.indices}(outputIndices); | |
| ${d.indicesSet(`inputIndices`,`uniforms.axis`,`u32(idx)`)}; | |
| let value = ${d.getByIndices(`inputIndices`)}; | |
| ${p.setByOffset(`global_idx`,`value`)}; | |
| }`}},vs=e=>V({axis:e.axis}),ys=(e,t)=>{let n=e.inputs;gs(n),e.compute(_s(e.inputs,t))}}),xs,Ss,Cs,ws,Ts=o(()=>{L(),B(),Y(),xs=e=>{if(!e)throw Error(`Input is missing`);if(e.length<2||e.length>3)throw Error(`Invaid input number.`);if(e.length===3&&e[2].dims.length>2)throw Error(`Invalid input shape of C`);if(e[0].dataType!==e[1].dataType||e.length===3&&e[0].dataType!==e[2].dataType)throw Error(`Input types are mismatched`)},Ss=(e,t)=>{let n=e[0].dims.slice(),r=e[1].dims.slice(),[i,a,o]=Vt.getShapeOfGemmResult(n,t.transA,r,t.transB,e.length===3?e[2].dims:void 0),s=[i,a];if(!s)throw Error(`Can't use gemm on the given tensors`);let c=Math.ceil(a/16),l=Math.ceil(i/16);z.size(s);let u=[{type:12,data:c},{type:12,data:i},{type:12,data:a},{type:12,data:o},{type:1,data:t.alpha},{type:1,data:t.beta}],d=[`type`,`type`];return e.length===3&&(u.push(...W(e[2].dims)),d.push(`rank`)),u.push(...W(s)),{name:`GemmShared`,shaderCache:{hint:`${t.cacheKey}`,inputDependencies:d},getRunData:()=>({outputs:[{dims:s,dataType:e[0].dataType}],dispatchGroup:{x:c*l},programUniforms:u}),getShaderSource:n=>{let r=q(`a`,e[0].dataType,e[0].dims),i=q(`b`,e[1].dataType,e[1].dims),a=null,o=[r,i];e.length===3&&(a=q(`c`,e[2].dataType,e[2].dims.length),o.push(a));let c=J(`output`,e[0].dataType,s.length);o.push(c);let l=[{name:`num_tile_n`,type:`u32`},{name:`M`,type:`u32`},{name:`N`,type:`u32`},{name:`K`,type:`u32`},{name:`alpha`,type:`f32`},{name:`beta`,type:`f32`}],u=``,d=``;t.transA&&t.transB?(d=` | |
| var col = tile_row_start + local_id.x; | |
| var row = k_start + local_id.y; | |
| if (col < uniforms.M && row < uniforms.K) { | |
| tile_a[local_id.y][local_id.x] = a[row * uniforms.M + col]; | |
| } else { | |
| tile_a[local_id.y][local_id.x] = ${r.type.value}(0); | |
| } | |
| col = k_start + local_id.x; | |
| row = tile_col_start + local_id.y; | |
| if (col < uniforms.K && row < uniforms.N) { | |
| tile_b[local_id.y][local_id.x] = b[row * uniforms.K + col]; | |
| } else { | |
| tile_b[local_id.y][local_id.x] = ${i.type.value}(0); | |
| } | |
| `,u=`value += tile_a[k][local_id.y] * tile_b[local_id.x][k];`):t.transA&&!t.transB?(d=` | |
| var col = tile_row_start + local_id.x; | |
| var row = k_start + local_id.y; | |
| if (col < uniforms.M && row < uniforms.K) { | |
| tile_a[local_id.y][local_id.x] = a[row * uniforms.M + col]; | |
| } else { | |
| tile_a[local_id.y][local_id.x] = ${r.type.value}(0); | |
| } | |
| col = tile_col_start + local_id.x; | |
| row = k_start + local_id.y; | |
| if (col < uniforms.N && row < uniforms.K) { | |
| tile_b[local_id.y][local_id.x] = b[row * uniforms.N + col]; | |
| } else { | |
| tile_b[local_id.y][local_id.x] = ${i.type.value}(0); | |
| } | |
| `,u=`value += tile_a[k][local_id.y] * tile_b[k][local_id.x];`):!t.transA&&t.transB?(d=` | |
| var col = k_start + local_id.x; | |
| var row = tile_row_start + local_id.y; | |
| if (col < uniforms.K && row < uniforms.M) { | |
| tile_a[local_id.y][local_id.x] = a[row * uniforms.K + col]; | |
| } else { | |
| tile_a[local_id.y][local_id.x] = ${r.type.value}(0); | |
| } | |
| col = k_start + local_id.x; | |
| row = tile_col_start + local_id.y; | |
| if (col < uniforms.K && row < uniforms.N) { | |
| tile_b[local_id.y][local_id.x] = b[row * uniforms.K + col]; | |
| } else { | |
| tile_b[local_id.y][local_id.x] = ${i.type.value}(0); | |
| } | |
| `,u=`value += tile_a[local_id.y][k] * tile_b[local_id.x][k];`):!t.transA&&!t.transB&&(d=` | |
| var col = k_start + local_id.x; | |
| var row = tile_row_start + local_id.y; | |
| if (col < uniforms.K && row < uniforms.M) { | |
| tile_a[local_id.y][local_id.x] = a[row * uniforms.K + col]; | |
| } else { | |
| tile_a[local_id.y][local_id.x] = ${r.type.value}(0); | |
| } | |
| col = tile_col_start + local_id.x; | |
| row = k_start + local_id.y; | |
| if (col < uniforms.N && row < uniforms.K) { | |
| tile_b[local_id.y][local_id.x] = b[row * uniforms.N + col]; | |
| } else { | |
| tile_b[local_id.y][local_id.x] = ${i.type.value}(0); | |
| } | |
| `,u=`value += tile_a[local_id.y][k] * tile_b[k][local_id.x];`);let f=t.alpha===1?``:`value *= uniforms.alpha;`;return` | |
| ${n.registerUniforms(l).declareVariables(...o)} | |
| var<workgroup> tile_a: array<array<${r.type.storage}, 16>, 16>; | |
| var<workgroup> tile_b: array<array<${i.type.storage}, 16>, 16>; | |
| ${n.mainStart([16,16,1])} | |
| let tile_col_start = (workgroup_index % uniforms.num_tile_n) * 16; | |
| let tile_row_start = (workgroup_index / uniforms.num_tile_n) * 16; | |
| let num_tiles = (uniforms.K - 1) / 16 + 1; | |
| var k_start = 0u; | |
| var value = ${c.type.value}(0); | |
| for (var t: u32 = 0u; t < num_tiles; t++) { | |
| ${d} | |
| k_start = k_start + 16; | |
| workgroupBarrier(); | |
| for (var k: u32 = 0u; k < 16; k++) { | |
| ${u} | |
| } | |
| workgroupBarrier(); | |
| } | |
| ${f} | |
| let m = tile_row_start + local_id.y; | |
| let n = tile_col_start + local_id.x; | |
| ${a==null?``:`let cOffset = ${a.broadcastedIndicesToOffset(`vec2(m, n)`,c)}; value += ${c.type.value}(uniforms.beta) * ${a.getByOffset(`cOffset`)};`} | |
| if (m < uniforms.M && n < uniforms.N) { | |
| output[m * uniforms.N + n] = value; | |
| } | |
| }`}}},Cs=e=>({transA:e.transA,transB:e.transB,alpha:e.alpha,beta:e.beta,cacheKey:`${e.transA};${e.transB};${e.alpha===1}`}),ws=(e,t)=>{xs(e.inputs),e.compute(Ss(e.inputs,t))}}),Es,Ds,Os,ks,As,js,Ms,Ns,Ps,Fs,Is,Ls,Rs,zs,Bs=o(()=>{L(),B(),H(),Y(),[Es,Ds,Os,ks]=[0,1,2,3],As=e=>{if(e[0].dims.length!==4)throw Error(`only 4-D tensor is supported.`);if(e[0].dims.length!==e[1].dims.length)throw Error(`input dimensions must be equal to grid dimensions`);if(e[0].dims.length-2!==e[1].dims[e[1].dims.length-1])throw Error(`last dimension of grid must be equal to ${e[0].dims.length-2}`);if(e[0].dims[0]!==e[1].dims[0])throw Error(`grid batch size must match input batch size`)},js=` | |
| fn gs_get_cubic_coeffs(x: f32) -> vec4<f32> { | |
| let cubic_alpha = -0.75f; | |
| let x_abs = abs(x); | |
| var coeffs: vec4<f32>; | |
| coeffs[0] = (((cubic_alpha * (x_abs + 1) - 5 * cubic_alpha) * (x_abs + 1) + 8 * cubic_alpha) * (x_abs + 1) - 4 * cubic_alpha); | |
| coeffs[1] = (((cubic_alpha + 2) * x_abs - (cubic_alpha + 3)) * x_abs * x_abs + 1); | |
| coeffs[2] = (((cubic_alpha + 2) * (1 - x_abs) - (cubic_alpha + 3)) * (1 - x_abs) * (1 - x_abs) + 1); | |
| coeffs[3] = (((cubic_alpha * (2 - x_abs) - 5 * cubic_alpha) * (2 - x_abs) + 8 * cubic_alpha) * (2 - x_abs) - 4 * cubic_alpha); | |
| return coeffs; | |
| } | |
| `,Ms=e=>` | |
| fn gs_bicubic_interpolate(p: mat4x4<${e}>, x: f32, y: f32) -> ${e} { | |
| var v: vec4<f32>; | |
| var coeffs = gs_get_cubic_coeffs(x); | |
| for (var i = 0; i < 4; i++) { | |
| v[i] = coeffs[0] * p[i][0] + coeffs[1] * p[i][1] + coeffs[2] * p[i][2] + coeffs[3] * p[i][3]; | |
| } | |
| coeffs = gs_get_cubic_coeffs(y); | |
| let pixel = ${e}(coeffs[0] * v[0] + coeffs[1] * v[1] + coeffs[2] * v[2] + coeffs[3] * v[3]); | |
| return pixel; | |
| } | |
| `,Ns=e=>` | |
| fn gs_denormalize(n: f32, length: i32) -> f32 { | |
| ${e.alignCorners===0?` | |
| // alignCorners: false => [-1, 1] to [-0.5, length - 0.5] | |
| return ((n + 1.0) * f32(length) - 1.0) / 2.0; | |
| `:` | |
| // alignCorners: true => [-1, 1] to [0, length - 1] | |
| return (n + 1.0) / 2.0 * (f32(length - 1)); | |
| `} | |
| } | |
| `,Ps=e=>` | |
| ${e.paddingMode===`reflection`?` | |
| fn gs_reflect(x: i32, x_min: f32, x_max: f32) -> u32 { | |
| var dx = 0.0; | |
| var fx = f32(x); | |
| let range = x_max - x_min; | |
| if (fx < x_min) { | |
| dx = x_min - fx; | |
| let n = u32(dx / range); | |
| let r = dx - f32(n) * range; | |
| if (n % 2 == 0) { | |
| fx = x_min + r; | |
| } else { | |
| fx = x_max - r; | |
| } | |
| } else if (fx > x_max) { | |
| dx = fx - x_max; | |
| let n = u32(dx / range); | |
| let r = dx - f32(n) * range; | |
| if (n % 2 == 0) { | |
| fx = x_max - r; | |
| } else { | |
| fx = x_min + r; | |
| } | |
| } | |
| return u32(fx); | |
| }`:``} | |
| `,Fs=(e,t,n)=>` | |
| fn pixel_at_grid(r: i32, c: i32, H: i32, W: i32, batch: u32, channel: u32, border: vec4<f32>) -> ${t} { | |
| var pixel = ${t}(0); | |
| var indices = vec4<u32>(0); | |
| indices[${Es}] = batch; | |
| indices[${Ds}] = channel;`+(()=>{switch(n.paddingMode){case`zeros`:return` | |
| if (r >= 0 && r < H && c >=0 && c < W) { | |
| indices[${Os}] = u32(r); | |
| indices[${ks}] = u32(c); | |
| } else { | |
| return ${t}(0); | |
| } | |
| `;case`border`:return` | |
| indices[${Os}] = u32(clamp(r, 0, H - 1)); | |
| indices[${ks}] = u32(clamp(c, 0, W - 1)); | |
| `;case`reflection`:return` | |
| indices[${Os}] = gs_reflect(r, border[1], border[3]); | |
| indices[${ks}] = gs_reflect(c, border[0], border[2]); | |
| `;default:throw Error(`padding mode ${n.paddingMode} is not supported`)}})()+` | |
| return ${e.getByIndices(`indices`)}; | |
| } | |
| `,Is=(e,t,n)=>(()=>{switch(n.mode){case`nearest`:return` | |
| let result = pixel_at_grid(i32(round(y)), i32(round(x)), H_in, W_in, indices[${Es}], indices[${Ds}], border); | |
| `;case`bilinear`:return` | |
| let x1 = i32(floor(x)); | |
| let y1 = i32(floor(y)); | |
| let x2 = x1 + 1; | |
| let y2 = y1 + 1; | |
| let p11 = pixel_at_grid(y1, x1, H_in, W_in, indices[${Es}], indices[${Ds}], border); | |
| let p12 = pixel_at_grid(y1, x2, H_in, W_in, indices[${Es}], indices[${Ds}], border); | |
| let p21 = pixel_at_grid(y2, x1, H_in, W_in, indices[${Es}], indices[${Ds}], border); | |
| let p22 = pixel_at_grid(y2, x2, H_in, W_in, indices[${Es}], indices[${Ds}], border); | |
| let dx2 = ${t}(f32(x2) - x); | |
| let dx1 = ${t}(x - f32(x1)); | |
| let dy2 = ${t}(f32(y2) - y); | |
| let dy1 = ${t}(y - f32(y1)); | |
| let result = dy2 * (dx2 * p11 + dx1 * p12) + dy1 * (dx2 * p21 + dx1 * p22); | |
| `;case`bicubic`:return` | |
| let x0 = i32(floor(x)) - 1; | |
| let y0 = i32(floor(y)) - 1; | |
| var p: mat4x4<${t}>; | |
| for (var h = 0; h < 4; h++) { | |
| for (var w = 0; w < 4; w++) { | |
| p[h][w] = pixel_at_grid(h + y0, w + x0, H_in, W_in, indices[${Es}], indices[${Ds}], border); | |
| } | |
| } | |
| let dx = x - f32(x0 + 1); | |
| let dy = y - f32(y0 + 1); | |
| let result = gs_bicubic_interpolate(p, dx, dy); | |
| `;default:throw Error(`mode ${n.mode} is not supported`)}})()+`${e.setByOffset(`global_idx`,`result`)}`,Ls=(e,t)=>{let n=q(`x`,e[0].dataType,e[0].dims.length),r=[e[1].dims[0],e[1].dims[1],e[1].dims[2]],i=q(`grid`,e[1].dataType,r.length,2),a=[e[0].dims[0],e[0].dims[1],e[1].dims[1],e[1].dims[2]];t.format===`NHWC`&&(a=[e[0].dims[0],e[1].dims[1],e[1].dims[2],e[0].dims[3]],[Es,Ds,Os,ks]=[0,3,1,2]);let o=J(`output`,e[0].dataType,a.length),s=n.type.value,c=[{type:12,data:z.size(a)},...W(e[0].dims,r,a)];return{name:`GridSample`,shaderCache:{hint:`${t.cacheKey}`,inputDependencies:[`type`,`type`]},getRunData:e=>{let t=z.size(a);return{outputs:[{dims:a,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(t/64)},programUniforms:c}},getShaderSource:e=>` | |
| ${e.registerUniform(`output_size`,`u32`).declareVariables(n,i,o)} | |
| ${js} | |
| ${Ms(s)} | |
| ${Ns(t)} | |
| ${Ps(t)} | |
| ${Fs(n,s,t)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.output_size`)} | |
| let H_in = i32(uniforms.x_shape[${Os}]); | |
| let W_in = i32(uniforms.x_shape[${ks}]); | |
| ${t.alignCorners===0?` | |
| let x_min = -0.5; | |
| let x_max = f32(W_in) - 0.5; | |
| let y_min = -0.5; | |
| let y_max = f32(H_in) - 0.5; | |
| `:` | |
| let x_min = 0.0; | |
| let x_max = f32(W_in) - 1.0; | |
| let y_min = 0.0; | |
| let y_max = f32(H_in) - 1.0; | |
| `}; | |
| let border = vec4<f32>(x_min, y_min, x_max, y_max); | |
| let indices = ${o.offsetToIndices(`global_idx`)}; | |
| var grid_indices = vec3<u32>(indices[${Es}], indices[${Os}], indices[${ks}]); | |
| let nxy = ${i.getByIndices(`grid_indices`)}; | |
| var x = gs_denormalize(f32(nxy[0]), W_in); | |
| var y = gs_denormalize(f32(nxy[1]), H_in); | |
| ${Is(o,s,t)} | |
| }`}},Rs=(e,t)=>{As(e.inputs),e.compute(Ls(e.inputs,t))},zs=e=>V({alignCorners:e.align_corners,mode:e.mode,paddingMode:e.padding_mode,format:e.format})}),Vs,Hs,Us,Ws,Gs,Ks,qs,Js=o(()=>{L(),B(),H(),ln(),qr(),Y(),Vn(),Vs=(e,t)=>e.length>t&&e[t].dims.length>0?e[t]:void 0,Hs=(e,t)=>{let n=e[0],r=Vs(e,1),i=Vs(e,2),a=Vs(e,3),o=Vs(e,4),s=Vs(e,5),c=Vs(e,6),l=Vs(e,7);if(n.dims.length!==3&&n.dims.length!==5)throw Error(`Input query is expected to have 3 or 5 dimensions`);let u=n.dims[0],d=n.dims[1],f=n.dims.length===3?n.dims[2]:t.numHeads*n.dims[4],p=d,m=0,h=0,g=Math.floor(f/t.numHeads);if(c&&l&&z.size(c.dims)&&z.size(l.dims)){if(c.dims.length!==4)throw Error(`Input "past_key" is expected to have 4 dimensions`);if(c.dims[0]!==u||c.dims[1]!==t.numHeads||c.dims[3]!==g)throw Error(`Input "past_key" shape (batch_size, num_heads, past_sequence_length, head_size)`);if(l.dims[0]!==u||l.dims[1]!==t.numHeads||l.dims[3]!==g)throw Error(`Input "past_value" shape (batch_size, num_heads, past_sequence_length, head_size)`);if(c.dims[2]!==l.dims[2])throw Error(`Input "past_key" and "past_value" shall have same dim 2 (past_sequence_length)`);if(l.dims.length!==4)throw Error(`Input "past_value" is expected to have 4 dimensions`);m=c.dims[2],h=c.dims[2]}else if(c&&z.size(c.dims)||l&&z.size(l.dims))throw Error(`Input "past_key" and "past_value" shall be both present or both absent`);let _;if(r&&z.size(r.dims)>0){if(n.dims.length!==3)throw Error(`Input "query" is expected to have 3 dimensions when key is given`);if(r.dims.length<3||r.dims.length>5)throw Error(`Input "key" is expected to have 3, 4, or 5 dimensions`);if(n.dims[0]!==r.dims[0])throw Error(`Input "query" and "key" shall have same dim 0 (batch size)`);if(r.dims.length===3){if(r.dims[2]!==n.dims[2])throw Error(`Input "query" and "key" shall have same dim 2 (hidden_size)`);_=2,p=r.dims[1]}else if(r.dims.length===5){if(r.dims[2]!==t.numHeads||r.dims[3]!==2||r.dims[4]!==g)throw Error(`Expect "key" shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv`);if(i)throw Error(`Expect "value" be none when "key" has packed kv format.`);_=5,p=r.dims[1]}else{if(r.dims[1]!==t.numHeads||r.dims[3]!==g)throw Error(`Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key`);_=0,p=r.dims[2]}}else{if(n.dims.length!==5)throw Error(`Input "query" is expected to have 5 dimensions when key is empty`);if(n.dims[2]!==t.numHeads||n.dims[3]!==3)throw Error(`Expect "query" shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv`);_=3}if(a&&z.size(a.dims)>0){if(a.dims.length!==1)throw Error(`Input "bias" is expected to have 1 dimension`);if(r&&r.dims.length===5&&r.dims[3]===2)throw Error(`bias is not allowed for packed kv.`)}let v=m+p,y=0;if(o&&z.size(o.dims)>0){y=8;let e=o.dims;throw e.length===1?e[0]===u?y=1:e[0]===3*u+2&&(y=3):e.length===2&&e[0]===u&&e[1]===v&&(y=5),Error(y===8?`Input "key_padding_mask" shape shall be (batch_size) or (batch_size, total_sequence_length)`:`Mask not supported`)}let b=!1,x=f;if(i&&z.size(i.dims)>0){if(i.dims.length!==3&&i.dims.length!==4)throw Error(`Input "value" is expected to have 3 or 4 dimensions`);if(n.dims[0]!==i.dims[0])throw Error(`Input "query" and "value" shall have same dim 0 (batch_size)`);if(i.dims.length===3){if(p!==i.dims[1])throw Error(`Input "key" and "value" shall have the same dim 1 (kv_sequence_length)`);x=i.dims[2]}else{if(p!==i.dims[2])throw Error(`Input "key" and "value" shall have the same dim 2 (kv_sequence_length)`);x=i.dims[1]*i.dims[3],b=!0}}if(o&&z.size(o.dims)>0)throw Error(`Key padding mask is not supported`);if(s&&z.size(s.dims)>0){if(s.dims.length!==4)throw Error(`Input "attention_bias" is expected to have 4 dimensions`);if(s.dims[0]!==u||s.dims[1]!==t.numHeads||s.dims[2]!==d||s.dims[3]!==v)throw Error(`Expect "attention_bias" shape (batch_size, num_heads, sequence_length, total_sequence_length)`)}return{batchSize:u,sequenceLength:d,pastSequenceLength:m,kvSequenceLength:p,totalSequenceLength:v,maxSequenceLength:h,inputHiddenSize:0,hiddenSize:f,vHiddenSize:x,headSize:g,vHeadSize:Math.floor(x/t.numHeads),numHeads:t.numHeads,isUnidirectional:!1,pastPresentShareBuffer:!1,maskFilterValue:t.maskFilterValue,maskType:y,scale:t.scale,broadcastResPosBias:!1,passPastInKv:b,qkvFormat:_}},Us=e=>V({...e}),Ws=V({perm:[0,2,1,3]}),Gs=(e,t,n,r,i,a,o)=>{let s=[r,i,a],c=z.size(s),l=[{type:12,data:c},{type:12,data:o},{type:12,data:a}];return e.compute({name:`MultiHeadAttentionAddBias`,shaderCache:{inputDependencies:[`type`,`type`]},getRunData:()=>({outputs:[{dims:s,dataType:t.dataType,gpuDataType:0}],dispatchGroup:{x:Math.ceil(c/64)},programUniforms:l}),getShaderSource:e=>{let r=J(`qkv_with_bias`,t.dataType,s),i=q(`qkv`,t.dataType,s),a=q(`bias`,n.dataType,s);return` | |
| ${e.registerUniforms([{name:`output_size`,type:`u32`},{name:`bias_offset`,type:`u32`},{name:`hidden_size`,type:`u32`}]).declareVariables(i,a,r)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.output_size`)} | |
| let bias_offset_idx = (global_idx % uniforms.hidden_size) + uniforms.bias_offset; | |
| qkv_with_bias[global_idx] = qkv[global_idx] + bias[bias_offset_idx]; | |
| }`}},{inputs:[t,n],outputs:[-1]})[0]},Ks=(e,t,n,r,i,a,o,s)=>{let c=a;if(o&&z.size(o.dims)>0){if(r===1)throw Error(`AddBiasReshape is not implemented. Please export your model with packed QKV or KV`);return c=Gs(e,a,o,t,r,n*i,s),c=c.reshape([t,r,n,i]),n===1||r===1?c:e.compute(Rn(c,Ws.perm),{inputs:[c],outputs:[-1]})[0]}else return a.dims.length===3&&(c=a.reshape([t,r,n,i])),n===1||r===1?c:e.compute(Rn(c,Ws.perm),{inputs:[c],outputs:[-1]})[0]},qs=(e,t)=>{let n=Hs(e.inputs,t),r=e.inputs[0],i=Vs(e.inputs,1),a=Vs(e.inputs,2),o=Vs(e.inputs,3),s=Vs(e.inputs,4),c=Vs(e.inputs,5),l=Vs(e.inputs,6),u=Vs(e.inputs,7);if(r.dims.length===5)throw Error(`Packed QKV is not implemented`);if(i?.dims.length===5)throw Error(`Packed KV is not implemented`);let d=i&&a&&i.dims.length===4&&a.dims.length===4,f=Ks(e,n.batchSize,n.numHeads,n.sequenceLength,n.headSize,r,o,0);if(d)return Wr(e,f,i,a,s,void 0,l,u,c,n);if(!i||!a)throw Error(`key and value must be provided`);let p=Ks(e,n.batchSize,n.numHeads,n.kvSequenceLength,n.headSize,i,o,n.hiddenSize),m=Ks(e,n.batchSize,n.numHeads,n.kvSequenceLength,n.vHeadSize,a,o,2*n.hiddenSize);Wr(e,f,p,m,s,void 0,l,u,c,n)}}),Ys,Xs,Zs,Qs,$s,ec,tc,nc=o(()=>{L(),B(),H(),Y(),Ys=e=>{if(!e||e.length<1)throw Error(`too few inputs`)},Xs=(e,t)=>{let n=[],r=t.numOutputs;return e[1].dims[0]>0&&(e[1].getBigInt64Array().forEach(e=>n.push(Number(e))),r=n.length),V({numOutputs:r,axis:t.axis,splitSizes:n})},Zs=e=>` | |
| fn calculateOutputIndex(index: u32) -> u32 { | |
| for (var i: u32 = 0u; i < ${e}u; i += 1u ) { | |
| if (index < ${K(`uniforms.size_in_split_axis`,`i`,e)}) { | |
| return i; | |
| } | |
| } | |
| return ${e}u; | |
| }`,Qs=e=>{let t=e.length,n=[];for(let r=0;r<t;++r){let i=e[r].setByIndices(`indices`,`input[global_idx]`);t===1?n.push(i):r===0?n.push(`if (output_number == ${r}u) { ${i} }`):r===t-1?n.push(`else { ${i} }`):n.push(`else if (output_number == ${r}) { ${i} }`)}return` | |
| fn writeBufferData(output_number: u32, indices: ${e[0].type.indices}, global_idx: u32) { | |
| ${n.join(` | |
| `)} | |
| }`},$s=(e,t)=>{let n=e[0].dims,r=z.size(n),i=e[0].dataType,a=z.normalizeAxis(t.axis,n.length),o=Array(t.numOutputs),s=q(`input`,i,n.length),c=Array(t.numOutputs),l=[],u=[],d=0,f=[{type:12,data:r}];for(let r=0;r<t.numOutputs;r++){d+=t.splitSizes[r],c[r]=d;let s=n.slice();s[a]=t.splitSizes[r],u.push(s),o[r]=J(`output${r}`,i,s.length),l.push({dims:u[r],dataType:e[0].dataType})}return f.push({type:12,data:c},...W(n,...u)),{name:`Split`,shaderCache:{hint:t.cacheKey,inputDependencies:[`rank`]},getShaderSource:e=>` | |
| ${e.registerUniform(`input_size`,`u32`).registerUniform(`size_in_split_axis`,`u32`,c.length).declareVariables(s,...o)} | |
| ${Zs(c.length)} | |
| ${Qs(o)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.input_size`)} | |
| var indices = ${s.offsetToIndices(`global_idx`)}; | |
| var index = ${s.indicesGet(`indices`,a)}; | |
| let output_number = calculateOutputIndex(index); | |
| if (output_number != 0) { | |
| index -= ${K(`uniforms.size_in_split_axis`,`output_number - 1u`,c.length)}; | |
| ${s.indicesSet(`indices`,a,`index`)}; | |
| } | |
| writeBufferData(output_number, indices, global_idx); | |
| }`,getRunData:()=>({outputs:l,dispatchGroup:{x:Math.ceil(r/64)},programUniforms:f})}},ec=(e,t)=>{Ys(e.inputs);let n=e.inputs.length===1?t:Xs(e.inputs,t);e.compute($s(e.inputs,n),{inputs:[0]})},tc=e=>{let t=e.axis,n=e.splitSizes,r=e.numOutputs<0?n.length:e.numOutputs;if(r!==n.length)throw Error(`numOutputs and splitSizes length must be equal`);return V({axis:t,numOutputs:r,splitSizes:n})}}),rc,ic,ac,oc=o(()=>{L(),B(),H(),Y(),rc=(e,t)=>{let[n,r,i,a]=e,{numHeads:o,rotaryEmbeddingDim:s}=t;if(n.dims.length!==3&&n.dims.length!==4)throw Error(`Input 'x' is expected to have 3 or 4 dimensions, got ${n.dims.length}`);if(!z.areEqual(r.dims,[])&&!z.areEqual(r.dims,[1])&&r.dims.length!==2)throw Error(`Input 'position_ids' is expected to have 0, 1, or 2 dimensions, got ${r.dims.length}`);if(i.dims.length!==2)throw Error(`Input 'cos_cache' is expected to have 2 dimensions, got ${i.dims.length}`);if(a.dims.length!==2)throw Error(`Input 'sin_cache' is expected to have 2 dimensions, got ${a.dims.length}`);if(!z.areEqual(i.dims,a.dims))throw Error(`Inputs 'cos_cache' and 'sin_cache' are expected to have the same shape`);if(s>0&&o===0)throw Error(`num_heads must be provided if rotary_embedding_dim is specified`);let c=n.dims[0],l=n.dims[n.dims.length-2],u=i.dims[0],d=z.sizeFromDimension(n.dims,1)/l,f=s===0?i.dims[1]*2:d/o;if(s>f)throw Error(`rotary_embedding_dim must be less than or equal to head_size`);if(r.dims.length===2){if(c!==r.dims[0])throw Error(`Input 'position_ids' dimension 0 should be of size batch_size, got ${r.dims[0]}`);if(l!==r.dims[1])throw Error(`Input 'position_ids' dimension 1 should be of size sequence_length, got ${r.dims[1]}`)}if(f/2!==i.dims[1]&&s/2!==i.dims[1])throw Error(`Input 'cos_cache' dimension 1 should be same as head_size / 2 or rotary_embedding_dim / 2, got ${i.dims[1]}`);if(l>u)throw Error(`Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported`)},ic=(e,t)=>{let{interleaved:n,numHeads:r,rotaryEmbeddingDim:i,scale:a}=t,o=e[0].dims[0],s=z.sizeFromDimension(e[0].dims,1),c=e[0].dims[e[0].dims.length-2],l=s/c,u=e[2].dims[1],d=i===0?u*2:l/r,f=[o,c,l/d,d-u],p=z.computeStrides(f),m=[{type:1,data:a},{type:12,data:f},{type:12,data:p},...e[0].dims.length===3?Array({type:12,data:[s,l,d,1]}):[],...e[0].dims.length===4?Array({type:12,data:[s,d,c*d,1]}):[],...W(e[0].dims,e[1].dims,e[2].dims,e[3].dims,e[0].dims)];return{name:`RotaryEmbedding`,shaderCache:{hint:V({interleaved:n}).cacheKey,inputDependencies:[`rank`,`rank`,`rank`,`rank`]},getShaderSource:t=>{let r=q(`input`,e[0].dataType,e[0].dims.length),i=q(`position_ids`,e[1].dataType,e[1].dims.length),a=q(`cos_cache`,e[2].dataType,e[2].dims.length),o=q(`sin_cache`,e[3].dataType,e[3].dims.length),s=J(`output`,e[0].dataType,e[0].dims.length);return t.registerUniforms([{name:`scale`,type:`f32`},{name:`global_shape`,type:`u32`,length:f.length},{name:`global_strides`,type:`u32`,length:p.length},{name:`input_output_strides`,type:`u32`,length:p.length}]),` | |
| ${t.declareVariables(r,i,a,o,s)} | |
| ${t.mainStart(xn)} | |
| let half_rotary_emb_dim = uniforms.${a.name}_shape[1]; | |
| let bsnh = global_idx / uniforms.global_strides % uniforms.global_shape; | |
| let size = uniforms.global_shape[0] * uniforms.global_strides[0]; | |
| ${t.guardAgainstOutOfBoundsWorkgroupSizes(`size`)} | |
| if (bsnh[3] < half_rotary_emb_dim) { | |
| let position_ids_idx = | |
| ${i.broadcastedIndicesToOffset(`bsnh.xy`,J(``,i.type.tensor,2))}; | |
| let position_id = | |
| u32(${i.getByOffset(`position_ids_idx`)}) + select(0, bsnh[1], position_ids_idx == 0); | |
| let i = dot(bsnh, uniforms.input_output_strides) + select(0, bsnh[3], ${n}); | |
| let j = i + select(half_rotary_emb_dim, 1, ${n}); | |
| let re = ${r.getByOffset(`i`)} * ${a.get(`position_id`,`bsnh[3]`)} - | |
| ${r.getByOffset(`j`)} * ${o.get(`position_id`,`bsnh[3]`)}; | |
| ${s.setByOffset(`i`,`re`)} | |
| let im = ${r.getByOffset(`i`)} * ${o.get(`position_id`,`bsnh[3]`)} + | |
| ${r.getByOffset(`j`)} * ${a.get(`position_id`,`bsnh[3]`)}; | |
| ${s.setByOffset(`j`,`im`)} | |
| } else { | |
| let k = dot(bsnh, uniforms.input_output_strides) + half_rotary_emb_dim; | |
| ${s.setByOffset(`k`,r.getByOffset(`k`))} | |
| } | |
| }`},getRunData:()=>({outputs:[{dims:e[0].dims,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(z.size(f)/xn)},programUniforms:m})}},ac=(e,t)=>{rc(e.inputs,t),e.compute(ic(e.inputs,t))}}),sc,cc,lc,uc,dc,fc=o(()=>{H(),L(),qr(),Js(),nc(),Vn(),oc(),Y(),sc=(e,t)=>{if(t.doRotary&&e.length<=7)throw Error(`cos_cache and sin_cache inputs are required if do_rotary is specified`);let n=e[0],r=e[1],i=e[2],a=e[3],o=e[4];if(t.doRotary!==0&&e.length<=7)throw Error(`cos_cast and sin_cache are expected if do_rotary attribute is non-zero`);if(t.localWindowSize!==-1)throw Error(`Local attention is not supported`);if(t.softcap!==0)throw Error(`Softcap is not supported`);if(t.rotaryInterleaved!==0)throw Error(`Rotary interleaved is not supported`);if(t.smoothSoftmax)throw Error(`Smooth softmax is not supported`);if(n.dims.length!==3&&n.dims.length!==5)throw Error(`Input query is expected to have 3 or 5 dimensions`);let s=n.dims[0],c=n.dims[1],l=n.dims.length===3?n.dims[2]:t.numHeads*n.dims[4],u=c,d=0,f=!r||r.dims.length===0,p=Math.floor(f?l/(t.numHeads+2*t.kvNumHeads):l/t.numHeads);f&&(l=p*t.numHeads);let m=a&&a.dims.length!==0,h=o&&o.dims.length!==0;if(m&&a.dims.length===4&&a.dims[0]===s&&a.dims[1]!==t.kvNumHeads&&a.dims[2]===t.kvNumHeads&&a.dims[3]===p)throw Error(`BSNH pastKey/pastValue is not supported`);if(m&&h){if(a.dims.length!==4)throw Error(`Input "past_key" is expected to have 4 dimensions`);if(o.dims.length!==4)throw Error(`Input "past_value" is expected to have 4 dimensions`);d=a.dims[2]}else if(m||h)throw Error(`Input "past_key" and "past_value" shall be both present or both absent`);let g=1;if(r&&r.dims.length>0){if(n.dims.length!==3)throw Error(`Input "query" is expected to have 3 dimensions when key is given`);if(r.dims.length<3||r.dims.length>5)throw Error(`Input "key" is expected to have 3, 4, or 5 dimensions`);if(n.dims[0]!==r.dims[0])throw Error(`Input "query" and "key" shall have same dim 0 (batch size)`);if(r.dims.length===3){if(n.dims[2]%r.dims[2]!==0)throw Error(`Dimension 2 of "query" should be a multiple of "key"`);u=r.dims[1]}else if(r.dims.length===5){if(r.dims[2]!==t.numHeads||r.dims[3]!==2||r.dims[4]!==p)throw Error(`Expect "key" shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv`);if(i)throw Error(`Expect "value" be none when "key" has packed kv format.`);u=r.dims[1]}else{if(r.dims[1]!==t.numHeads||r.dims[3]!==p)throw Error(`Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key`);u=r.dims[2]}}else{if(n.dims.length!==3&&n.dims.length!==5)throw Error(`Input "query" is expected to have 3 or 5 dimensions when key is empty`);if(n.dims.length===5&&(n.dims[2]!==t.numHeads||n.dims[3]!==3))throw Error(`Expect "query" shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv`);g=3}let _=!1,v=t.kvNumHeads?p*t.kvNumHeads:l;if(i&&i.dims.length>0){if(i.dims.length!==3&&i.dims.length!==4)throw Error(`Input "value" is expected to have 3 or 4 dimensions`);if(n.dims[0]!==i.dims[0])throw Error(`Input "query" and "value" shall have same dim 0 (batch_size)`);if(i.dims.length===3){if(u!==i.dims[1])throw Error(`Input "key" and "value" shall have the same dim 1 (kv_sequence_length)`);v=i.dims[2]}else{if(u!==i.dims[2])throw Error(`Input "past_key" and "past_value" shall have the same dim 2 (kv_sequence_length)`);v=i.dims[1]*i.dims[3],_=!0}}let y=e.length>4?e[5]:void 0;if(y&&y.dims.length!==1&&y.dims[0]!==s)throw Error(`Input "seqlens" is expected to have 1 dimension and the same dim 0 as batch_size`);return{batchSize:s,sequenceLength:c,pastSequenceLength:d,kvSequenceLength:u,totalSequenceLength:-1,maxSequenceLength:-1,inputHiddenSize:0,hiddenSize:l,vHiddenSize:v,headSize:p,vHeadSize:Math.floor(v/t.kvNumHeads),numHeads:t.numHeads,kvNumHeads:t.kvNumHeads,nReps:t.numHeads/t.kvNumHeads,pastPresentShareBuffer:!1,maskType:0,scale:t.scale,broadcastResPosBias:!1,passPastInKv:_,qkvFormat:g}},cc=V({perm:[0,2,1,3]}),lc=(e,t,n)=>{let r=t,i=n.kvNumHeads;return t.dims.length===3&&n.kvSequenceLength!==0&&(r=t.reshape([n.batchSize,n.kvSequenceLength,i,n.headSize]),r=e.compute(Rn(r,cc.perm),{inputs:[r],outputs:[-1]})[0]),r},uc=(e,t,n,r)=>{let i=[`type`,`type`],a=[e*t],o=e*t,s=[{type:12,data:o},{type:12,data:t},{type:12,data:e}];return{name:`GeneratePositionIds`,shaderCache:{hint:`${e};${t}`,inputDependencies:i},getRunData:()=>({outputs:[{dims:a,dataType:7}],dispatchGroup:{x:Math.ceil(o/64)},programUniforms:s}),getShaderSource:e=>{let t=q(`seq_lens`,n.dataType,n.dims),i=q(`total_seq_lens`,r.dataType,r.dims),o=J(`pos_ids`,7,a);return` | |
| ${e.registerUniforms([{name:`output_size`,type:`u32`},{name:`sequence_length`,type:`u32`},{name:`batch_size`,type:`u32`}]).declareVariables(t,i,o)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.output_size`)} | |
| let total_sequence_length = u32(${i.getByOffset(`0`)}); | |
| let is_subsequent_prompt = uniforms.sequence_length > 1 && uniforms.sequence_length != total_sequence_length; | |
| let is_first_prompt = !is_subsequent_prompt && uniforms.sequence_length == total_sequence_length; | |
| let batch_idx = global_idx / uniforms.sequence_length; | |
| let sequence_idx = i32(global_idx % uniforms.sequence_length); | |
| var pos_id: i32 = 0; | |
| let seqlen = ${t.getByOffset(`batch_idx`)}; | |
| let total_seqlen = seqlen + 1; | |
| if (is_first_prompt) { | |
| if (sequence_idx < total_seqlen) { | |
| pos_id = sequence_idx; | |
| } else { | |
| pos_id = 1; | |
| } | |
| ${o.setByOffset(`global_idx`,`pos_id`)} | |
| } else if (is_subsequent_prompt) { | |
| let past_seqlen = total_seqlen - i32(uniforms.sequence_length); | |
| if (past_seqlen + sequence_idx < total_seqlen) { | |
| pos_id = past_seqlen + sequence_idx; | |
| } else { | |
| pos_id = 1; | |
| } | |
| ${o.setByOffset(`global_idx`,`pos_id`)} | |
| } else if (global_idx < uniforms.batch_size) { | |
| ${o.setByOffset(`global_idx`,`seqlen`)} | |
| }; | |
| } | |
| `}}},dc=(e,t)=>{let n=sc(e.inputs,t);if(e.inputs[0].dims.length===5)throw Error(`Packed QKV is not implemented`);if(e.inputs[1]?.dims.length===5)throw Error(`Packed KV is not implemented`);let r=e.inputs[0],i=e.inputs[1]&&e.inputs[1].dims.length>0?e.inputs[1]:void 0,a=e.inputs[2]&&e.inputs[2].dims.length>0?e.inputs[2]:void 0,o=e.inputs[3]&&e.inputs[3].dims.length!==0?e.inputs[3]:void 0,s=e.inputs[4]&&e.inputs[4].dims.length!==0?e.inputs[4]:void 0,c=e.inputs.length>4?e.inputs[5]:void 0,l=e.inputs.length>5?e.inputs[6]:void 0,u=n.kvNumHeads?n.kvNumHeads:n.numHeads,d=V({axis:2,numOutputs:3,splitSizes:[n.numHeads*n.headSize,u*n.headSize,u*n.headSize]}),[f,p,m]=!i&&!a?e.compute($s([r],d),{inputs:[r],outputs:[-1,-1,-1]}):[r,i,a],h,g;if(t.doRotary){let r=e.compute(uc(n.batchSize,n.sequenceLength,c,l),{inputs:[c,l],outputs:[-1]})[0],i=e.inputs[7],a=e.inputs[8],o=V({interleaved:t.rotaryInterleaved!==0,numHeads:n.numHeads,rotaryEmbeddingDim:0,scale:t.scale}),s=[f,r,i,a],u=[-1];h=e.compute(ic(s,o),{inputs:s,outputs:u})[0],s.splice(0,1,p);let d=V({interleaved:t.rotaryInterleaved!==0,numHeads:n.kvNumHeads,rotaryEmbeddingDim:0,scale:t.scale});g=e.compute(ic(s,d),{inputs:s,outputs:u})[0]}let _=Ks(e,n.batchSize,n.numHeads,n.sequenceLength,n.headSize,t.doRotary?h:f,void 0,0),v=lc(e,t.doRotary?g:p,n),y=lc(e,m,n);Wr(e,_,v,y,void 0,void 0,o,s,void 0,n,c,l)}}),pc,mc,hc,gc,_c=o(()=>{L(),B(),Vn(),Y(),pc=(e,t,n,r,i,a,o,s)=>{let c=G(a),l=c===1?`f32`:`vec${c}f`,u=c===1?`vec2f`:`mat2x${c}f`,d=i*o,f=64;d===1&&(f=256);let p=[i,o,a/c],m=[i,o,2],h=[`rank`,`type`,`type`],g=[];return g.push(...W(p,m)),e.compute({name:`InstanceNormComputeChannelScaleShift`,shaderCache:{hint:`${c};${s};${f}`,inputDependencies:h},getRunData:()=>({outputs:[{dims:m,dataType:1}],dispatchGroup:{x:d},programUniforms:g}),getShaderSource:e=>{let i=q(`x`,t.dataType,3,c),a=[i,q(`scale`,n.dataType,n.dims),q(`bias`,r.dataType,r.dims),J(`output`,1,3,2)];return` | |
| var<workgroup> workgroup_shared : array<${u}, ${f}>; | |
| const workgroup_size = ${f}u; | |
| ${e.declareVariables(...a)} | |
| ${e.mainStart(f)} | |
| let batch = workgroup_index / uniforms.x_shape[1]; | |
| let channel = workgroup_index % uniforms.x_shape[1]; | |
| let hight = uniforms.x_shape[2]; | |
| // initialize workgroup memory | |
| var sum = ${l}(0); | |
| var squared_sum = ${l}(0); | |
| for (var h = local_idx; h < hight; h += workgroup_size) { | |
| let value = ${l}(${i.get(`batch`,`channel`,`h`)}); | |
| sum += value; | |
| squared_sum += value * value; | |
| } | |
| workgroup_shared[local_idx] = ${u}(sum, squared_sum); | |
| workgroupBarrier(); | |
| for (var currSize = workgroup_size >> 1; currSize > 0; currSize = currSize >> 1) { | |
| if (local_idx < currSize) { | |
| workgroup_shared[local_idx] = workgroup_shared[local_idx] + workgroup_shared[local_idx + currSize]; | |
| } | |
| workgroupBarrier(); | |
| } | |
| if (local_idx == 0) { | |
| let sum_final = ${En(`workgroup_shared[0][0]`,c)} / f32(hight * ${c}); | |
| let squared_sum_final = ${En(`workgroup_shared[0][1]`,c)} / f32(hight * ${c}); | |
| let inv_std_dev = inverseSqrt(squared_sum_final - sum_final * sum_final + f32(${s})); | |
| let channel_scale = inv_std_dev * f32(scale[channel]); | |
| let channel_shift = f32(bias[channel]) - sum_final * channel_scale; | |
| output[workgroup_index] = vec2f(channel_scale, channel_shift); | |
| } | |
| }`}},{inputs:[t,n,r],outputs:[-1]})[0]},mc=(e,t,n)=>{let r=t[0].dims,i=r,a=r[0],o=r[1],s=z.sizeFromDimension(r,2),c=G(s),l=z.size(i)/c,u=pc(e,t[0],t[1],t[2],a,s,o,n.epsilon),d=[a,o,s/c],f=[a,o];e.compute({name:`InstanceNormalization`,shaderCache:{hint:`${c}`,inputDependencies:[`type`,`none`]},getRunData:()=>({outputs:[{dims:i,dataType:t[0].dataType}],dispatchGroup:{x:Math.ceil(l/64)},programUniforms:[{type:12,data:l},...W(d,f,d)]}),getShaderSource:e=>{let n=q(`x`,t[0].dataType,d.length,c),r=q(`scale_shift`,1,f.length,2),i=J(`output`,t[0].dataType,d.length,c),a=[n,r,i];return` | |
| ${e.registerUniform(`output_size`,`u32`).declareVariables(...a)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.output_size`)} | |
| let outputIndices = ${i.offsetToIndices(`global_idx`)}; | |
| let batch = outputIndices[0]; | |
| let channel = outputIndices[1]; | |
| let scale_shift = ${r.getByIndices(`vec2<u32>(batch, channel)`)}; | |
| let value = ${n.getByOffset(`global_idx`)} * ${i.type.value}(scale_shift.x) + ${i.type.value}(scale_shift.y); | |
| ${i.setByOffset(`global_idx`,`value`)}; | |
| }`}},{inputs:[t[0],u]})},hc=(e,t,n)=>{let r=t[0].dims,i=r,a=r[0],o=r[r.length-1],s=z.sizeFromDimension(r,1)/o,c=G(o),l=z.size(i)/c,u=[{type:12,data:s},{type:12,data:Math.floor(o/c)}],d=[`type`,`type`],f=!1,p=[0,r.length-1];for(let e=0;e<r.length-2;e++)f||=r[e+1]!==1,p.push(e+1);f&&=r[r.length-1]!==1;let m=f?e.compute(Rn(e.inputs[0],p),{inputs:[e.inputs[0]],outputs:[-1]})[0]:e.inputs[0].reshape(Array.from({length:r.length},(e,t)=>r[p[t]])),h=pc(e,m,t[1],t[2],a,s,o,n.epsilon);e.compute({name:`InstanceNormalizationNHWC`,shaderCache:{hint:`${c}`,inputDependencies:d},getRunData:()=>({outputs:[{dims:i,dataType:t[0].dataType}],dispatchGroup:{x:Math.ceil(l/64)},programUniforms:u}),getShaderSource:e=>{let n=U(t[0].dataType),r=c===1?`vec2f`:`mat${c}x2f`,a=e=>{let t=e===0?`x`:`y`,r=c===1?`f32`:`vec${c}f`;switch(c){case 1:return`${n}(${r}(scale.${t}))`;case 2:return`vec2<${n}>(${r}(scale[0].${t}, scale[1].${t}))`;case 4:return`vec4<${n}>(${r}(scale[0].${t}, scale[1].${t}, scale[2].${t}, scale[3].${t}))`;default:throw Error(`Not supported compoents ${c}`)}},o=q(`input`,t[0].dataType,t[0].dims,c),s=J(`output`,t[0].dataType,i,c);return` | |
| @group(0) @binding(0) var<storage, read> input : array<${o.type.storage}>; | |
| @group(0) @binding(1) var<storage, read> scale_input : array<${r}>; | |
| @group(0) @binding(2) var<storage, read_write> output : array<${s.type.storage}>; | |
| struct Uniforms {H: u32, C : u32}; | |
| @group(0) @binding(3) var<uniform> uniforms: Uniforms; | |
| ${e.mainStart()} | |
| let current_image_number = global_idx / (uniforms.C * uniforms.H); | |
| let current_channel_number = global_idx % uniforms.C; | |
| let scale_offset = current_image_number * uniforms.C + current_channel_number; | |
| let scale = scale_input[scale_offset]; | |
| output[global_idx] = fma(input[global_idx], ${a(0)}, ${a(1)}); | |
| }`}},{inputs:[t[0],h]})},gc=(e,t)=>{t.format===`NHWC`?hc(e,e.inputs,t):mc(e,e.inputs,t)}}),vc,yc,bc,xc=o(()=>{L(),B(),Y(),vc=e=>{if(!e||e.length<2)throw Error(`layerNorm requires at least 2 inputs.`)},yc=(e,t,n)=>{let r=t.simplified,i=e[0].dims,a=e[1],o=!r&&e[2],s=i,c=z.normalizeAxis(t.axis,i.length),l=z.sizeToDimension(i,c),u=z.sizeFromDimension(i,c),d=z.size(a.dims),f=o?z.size(o.dims):0;if(d!==u||o&&f!==u)throw Error(`Size of X.shape()[axis:] == ${u}. | |
| Size of scale and bias (if provided) must match this. | |
| Got scale size of ${d} and bias size of ${f}`);let p=[];for(let e=0;e<i.length;++e)e<c?p.push(i[e]):p.push(1);let m=G(u),h=[`type`,`type`],g=[{type:12,data:l},{type:1,data:u},{type:12,data:Math.floor(u/m)},{type:1,data:t.epsilon}];o&&h.push(`type`);let _=n>1,v=n>2,y=t=>{let n=U(e[0].dataType),i=[q(`x`,e[0].dataType,e[0].dims,m),q(`scale`,a.dataType,a.dims,m)];return o&&i.push(q(`bias`,o.dataType,o.dims,m)),i.push(J(`output`,e[0].dataType,s,m)),_&&i.push(J(`mean_data_output`,1,p)),v&&i.push(J(`inv_std_output`,1,p)),` | |
| ${t.registerUniforms([{name:`norm_count`,type:`u32`},{name:`norm_size`,type:`f32`},{name:`norm_size_vectorized`,type:`u32`},{name:`epsilon`,type:`f32`}]).declareVariables(...i)} | |
| ${t.mainStart()} | |
| ${t.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.norm_count`)} | |
| let offset = global_idx * uniforms.norm_size_vectorized; | |
| var mean_vector = ${wn(`f32`,m)}; | |
| var mean_square_vector = ${wn(`f32`,m)}; | |
| for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) { | |
| let value = ${Tn(n,m,`x[h + offset]`)}; | |
| mean_vector += value; | |
| mean_square_vector += value * value; | |
| } | |
| let mean = ${En(`mean_vector`,m)} / uniforms.norm_size; | |
| let inv_std_dev = inverseSqrt(${En(`mean_square_vector`,m)} / uniforms.norm_size ${r?``:`- mean * mean`} + uniforms.epsilon); | |
| for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) { | |
| let f32input = ${Tn(n,m,`x[j + offset]`)}; | |
| let f32scale = ${Tn(n,m,`scale[j]`)}; | |
| output[j + offset] = ${i[0].type.value}((f32input ${r?``:`- mean`}) * inv_std_dev * f32scale | |
| ${o?`+ ${Tn(n,m,`bias[j]`)}`:``} | |
| ); | |
| } | |
| ${_?`mean_data_output[global_idx] = mean`:``}; | |
| ${v?`inv_std_output[global_idx] = inv_std_dev`:``}; | |
| }`},b=[{dims:s,dataType:e[0].dataType}];return _&&b.push({dims:p,dataType:1}),v&&b.push({dims:p,dataType:1}),{name:`LayerNormalization`,shaderCache:{hint:`${m};${n};${r}`,inputDependencies:h},getRunData:()=>({outputs:b,dispatchGroup:{x:Math.ceil(l/64)},programUniforms:g}),getShaderSource:y}},bc=(e,t)=>{vc(e.inputs),e.compute(yc(e.inputs,t,e.outputCount))}}),Sc,Cc,wc=o(()=>{B(),ja(),Ba(),Sc=e=>{if(!e||e.length!==2)throw Error(`MatMul requires 2 inputs.`);if(e[0].dims[e[0].dims.length-1]!==e[1].dims[e[1].dims.length-2])throw Error(`shared dimension does not match.`)},Cc=e=>{Sc(e.inputs);let t=zt.calcShape(e.inputs[0].dims,e.inputs[1].dims,!0);if(!t)throw Error(`Can't use matmul on the given tensors`);let n=t[t.length-1],r=e.inputs[0].dims[e.inputs[0].dims.length-1];if(n<8&&r<8)e.compute(Aa(e.inputs,{activation:``},t));else{let i=t[t.length-2],a=z.size(e.inputs[0].dims.slice(0,-2)),o=z.size(e.inputs[1].dims.slice(0,-2));if(a!==1&&i===1&&o===1){let i=e.inputs[0].reshape([1,a,r]),o=e.inputs[1].reshape([1,r,n]),s=[1,a,n],c=[i,o];e.compute(za(c,{activation:``},t,s),{inputs:c})}else e.compute(za(e.inputs,{activation:``},t))}}}),Tc,Ec,Dc,Oc,kc,Ac=o(()=>{L(),B(),H(),Y(),Tc=(e,t)=>{if(e.length<3||e.length>4)throw Error(`MatMulNBits requires 3 or 4 inputs`);let n=e[0],r=n.dims.length;if(n.dims[r-1]!==t.k)throw Error(`The last dim of input shape does not match the k value`);let i=Math.floor((t.k+t.blockSize-1)/t.blockSize),a=t.blockSize/8*t.bits,o=e[1];if(!z.areEqual(o.dims,[t.n,i,a]))throw Error(`The second inputs must be 3D tensor with shape N X nBlocksPerCol X blobSize`);let s=e[2].dims;if(z.size(s)!==t.n*i)throw Error(`scales input size error.`);if(e.length===4){let n=e[3].dims,r=t.n*(t.bits===8?i:Math.floor((i*t.bits+7)/8));if(z.size(n)!==r)throw Error(`zeroPoints input size error.`)}},Ec=(e,t)=>{let n=e[0].dims,r=n.length,i=n[r-2],a=t.k,o=t.n,s=n.slice(0,r-2),c=z.size(s),l=e[1].dims[2]/4,u=e[0].dataType,d=G(t.k),f=G(l),p=G(o),m=s.concat([i,o]),h=i>1&&o/p%2==0?2:1,g=z.size(m)/p/h,_=[],v=[c,i,a/d],y=z.convertShape(e[1].dims).slice();y.splice(-1,1,l/f),_.push(...W(v)),_.push(...W(y)),_.push(...W(e[2].dims)),e.length===4&&_.push(...W(z.convertShape(e[3].dims)));let b=[c,i,o/p];return _.push(...W(b)),{name:`MatMulNBits`,shaderCache:{hint:`${t.blockSize};${t.bits};${d};${f};${p};${h};64`,inputDependencies:Array(e.length).fill(`rank`)},getRunData:()=>({outputs:[{dims:m,dataType:u}],dispatchGroup:{x:g},programUniforms:_}),getShaderSource:n=>{let r=v.length,i=q(`a`,e[0].dataType,r,d),a=q(`b`,12,y.length,f),o=q(`scales`,e[2].dataType,e[2].dims.length),s=[i,a,o],c=e.length===4?q(`zero_points`,12,e[3].dims.length):void 0;c&&s.push(c);let u=b.length,m=J(`output`,e[0].dataType,u,p),g=U(e[0].dataType),_=(()=>{switch(d){case 1:return`array<${g}, 8>`;case 2:return`mat4x2<${g}>`;case 4:return`mat2x4<${g}>`;default:throw Error(`${d}-component is not supported.`)}})(),x=Math.floor(32/t.bits),S=Math.floor(x/8),C=()=>{let e=``;for(let n=0;n<S;n++){let r=n*t.bits*4,a=r+t.bits;e+=` | |
| // reuse a data (pass ${n}) | |
| var input_offset${n>0?n:``} = ${n===0?i.indicesToOffset(`${i.type.indices}(batch, row, word_offset)`):`input_offset`}; | |
| var a_data${n>0?n:``}: ${_}; | |
| for (var j${n>0?n:``}: u32 = 0; j${n>0?n:``} < ${8/d}; j${n>0?n:``}++) { | |
| a_data${n>0?n:``}[j${n>0?n:``}] = ${i.getByOffset(`input_offset${n>0?n:``}`)}; | |
| input_offset${n>0?n:``}++; | |
| } | |
| `;for(let i=0;i<p*h;i++)e+=` | |
| b_value = ${f===1?`b${i}_data`:`b${i}_data[i]`}; | |
| ${t.bits===2?`{ | |
| let half_word = b_value >> ${n*16}u; | |
| let byte_lo = half_word & 0xFFu; | |
| let byte_hi = (half_word >> 8u) & 0xFFu; | |
| let spread_word = (byte_lo & 0xFu) | ((byte_lo >> 4u) << 8u) | ((byte_hi & 0xFu) << 16u) | ((byte_hi >> 4u) << 24u); | |
| b_value_lower = unpack4xU8(spread_word & b_mask); | |
| b_value_upper = unpack4xU8((spread_word >> 2u) & b_mask); | |
| }`:`b_value_lower = unpack4xU8((b_value >> ${r}u) & b_mask); | |
| b_value_upper = unpack4xU8((b_value >> ${a}u) & b_mask);`} | |
| b_quantized_values = ${_}(${Array.from({length:4},(e,t)=>`${g}(b_value_lower[${t}]), ${g}(b_value_upper[${t}])`).join(`, `)}); | |
| b_dequantized_values = ${d===1?`${_}(${Array.from({length:8},(e,t)=>`(b_quantized_values[${t}] - ${c?`zero_point${i}`:`zero_point`}) * scale${i}`).join(`, `)});`:`(b_quantized_values - ${_}(${Array(8).fill(`${c?`zero_point${i}`:`zero_point`}`).join(`,`)})) * scale${i};`}; | |
| workgroup_shared[local_id.x * ${h} + ${Math.floor(i/p)}]${p>1?`[${i%p}]`:``} += ${Array.from({length:8/d},(e,t)=>`${d===1?`a_data${n>0?n:``}[${t}] * b_dequantized_values[${t}]`:`dot(a_data${n>0?n:``}[${t}], b_dequantized_values[${t}])`}`).join(` + `)}; | |
| `}return e},w=()=>{let e=` | |
| var col_index = col * ${p}; | |
| ${c?` | |
| let zero_point_values_per_byte: u32 = ${Math.floor(8/t.bits)}u; | |
| let zero_point_bytes_per_col = (nBlocksPerCol + zero_point_values_per_byte - 1u) / zero_point_values_per_byte; | |
| var zero_point_byte_count: u32; | |
| var zero_point_word_index: u32; | |
| var zero_point_byte_offset: u32; | |
| let zero_point_sub_offset: u32 = block % zero_point_values_per_byte; | |
| var zero_point_bits_offset: u32; | |
| var zero_point_word: u32;`:` | |
| // The default zero point is ${2**(t.bits-1)} for unsigned ${t.bits}-bit quantization. | |
| let zero_point = ${g}(${(2**(t.bits-1)).toFixed(1)});`} | |
| `;for(let n=0;n<p*h;n++)e+=` | |
| let scale${n} = ${o.getByOffset(`col_index * nBlocksPerCol + block`)}; | |
| ${c?` | |
| zero_point_byte_count = col_index * zero_point_bytes_per_col + (block / zero_point_values_per_byte); | |
| zero_point_word_index = zero_point_byte_count >> 0x2u; | |
| zero_point_byte_offset = zero_point_byte_count & 0x3u; | |
| zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_sub_offset * ${t.bits}u); | |
| zero_point_word = ${c.getByOffset(`zero_point_word_index`)} >> zero_point_bits_offset; | |
| let zero_point${n} = ${g}((zero_point_word) & ${t.bits===2?`0x3u`:`0xFu`});`:``} | |
| col_index += 1;`;return e},ee=()=>{let e=`col_index = col * ${p};`;for(let t=0;t<p*h;t++)e+=` | |
| let b${t}_data = ${a.getByIndices(`${a.type.indices}(col_index, block, word)`)}; | |
| col_index += 1;`;return e+=` | |
| var b_value: u32; | |
| let b_mask: u32 = ${t.bits===2?`0x03030303u`:`0x0F0F0F0Fu`}; | |
| var b_value_lower: vec4<u32>; | |
| var b_value_upper: vec4<u32>; | |
| var b_quantized_values: ${_}; | |
| var b_dequantized_values: ${_};`,e};return` | |
| var<workgroup> workgroup_shared: array<${m.type.value}, ${h*64}>; | |
| ${n.declareVariables(...s,m)} | |
| ${n.mainStart([64,1,1])} | |
| let output_indices = ${m.offsetToIndices(`(global_idx / 64) * ${h}`)}; | |
| let col = output_indices[2]; | |
| let row = output_indices[1]; | |
| let batch = output_indices[0]; | |
| let nBlocksPerCol = uniforms.b_shape[1]; | |
| for (var block = local_id.x; block < nBlocksPerCol; block += 64) { | |
| //process one block | |
| var word_offset: u32 = block * ${t.blockSize/d}; | |
| ${w()} | |
| for (var word: u32 = 0; word < ${l}; word += ${f}) { | |
| ${ee()} | |
| for (var i: u32 = 0; i < ${f}; i++) { | |
| ${C()} | |
| word_offset += ${x/d}; | |
| } | |
| } | |
| } | |
| workgroupBarrier(); | |
| if (local_id.x < ${h}) { | |
| var output_value: ${m.type.value} = ${m.type.value}(0); | |
| var workgroup_shared_offset: u32 = local_id.x; | |
| for (var b: u32 = 0u; b < 64u; b++) { | |
| output_value += workgroup_shared[workgroup_shared_offset]; | |
| workgroup_shared_offset += ${h}; | |
| } | |
| ${m.setByIndices(`${m.type.indices}(batch, row, col + local_id.x)`,`output_value`)}; | |
| } | |
| }`}}},Dc=(e,t)=>{let n=e[0].dims,r=n.length,i=n[r-2],a=t.k,o=t.n,s=n.slice(0,r-2),c=z.size(s),l=e[1].dims[2]/4,u=e[0].dataType,d=G(t.k),f=G(l),p=s.concat([i,o]),m=o%8==0?8:o%4==0?4:1,h=128/m,g=Math.floor(32/t.bits),_=h*f*g,v=_/d,y=_/t.blockSize,b=z.size(p)/m,x=[],S=[c,i,a/d],C=z.convertShape(e[1].dims).slice();C.splice(-1,1,l/f),x.push(...W(S)),x.push(...W(C)),x.push(...W(e[2].dims)),e.length===4&&x.push(...W(z.convertShape(e[3].dims)));let w=[c,i,o];return x.push(...W(w)),{name:`BlockwiseMatMulNBits32`,shaderCache:{hint:`${t.blockSize};${d};${f};${h};${m}`,inputDependencies:Array(e.length).fill(`rank`)},getRunData:()=>({outputs:[{dims:p,dataType:u}],dispatchGroup:{x:b},programUniforms:x}),getShaderSource:n=>{let r=S.length,i=q(`a`,e[0].dataType,r,d),a=q(`b`,12,C.length,f),o=q(`scales`,e[2].dataType,e[2].dims.length),s=[i,a,o],c=e.length===4?q(`zero_points`,12,e[3].dims.length):void 0;c&&s.push(c);let l=w.length,u=J(`output`,e[0].dataType,l),p=U(e[0].dataType),_=()=>{switch(d){case 1:return` | |
| let a_data0 = vec4<${p}>(sub_a[word_offset], sub_a[word_offset + 1], sub_a[word_offset + 2], sub_a[word_offset + 3]); | |
| let a_data1 = vec4<${p}>(sub_a[word_offset + 4], sub_a[word_offset + 5], sub_a[word_offset + 6], sub_a[word_offset + 7]);`;case 2:return` | |
| let a_data0 = vec4<${p}>(sub_a[word_offset], sub_a[word_offset + 1]); | |
| let a_data1 = vec4<${p}>(sub_a[word_offset + 2], sub_a[word_offset + 3]);`;case 4:return` | |
| let a_data0 = sub_a[word_offset]; | |
| let a_data1 = sub_a[word_offset + 1];`;default:throw Error(`${d}-component is not supported.`)}};return` | |
| var<workgroup> sub_a: array<${i.type.value}, ${v}>; | |
| var<workgroup> inter_results: array<array<${u.type.value}, ${h}>, ${m}>; | |
| ${n.declareVariables(...s,u)} | |
| ${n.mainStart([h,m,1])} | |
| let output_indices = ${u.offsetToIndices(`workgroup_index * ${m}`)}; | |
| let col = output_indices[2]; | |
| let row = output_indices[1]; | |
| let batch = output_indices[0]; | |
| let n_blocks_per_col = uniforms.b_shape[1]; | |
| let num_tiles = (n_blocks_per_col - 1) / ${y} + 1; | |
| // Loop over shared dimension. | |
| for (var tile: u32 = 0; tile < num_tiles; tile += 1) { | |
| let a_col_start = tile * ${v}; | |
| // load one tile A data into shared memory. | |
| for (var a_offset = local_idx; a_offset < ${v}; a_offset += 128) | |
| { | |
| let a_col = a_col_start + a_offset; | |
| if (a_col < uniforms.a_shape[2]) | |
| { | |
| sub_a[a_offset] = ${i.getByIndices(`${i.type.indices}(batch, row, a_col)`)}; | |
| } else { | |
| sub_a[a_offset] = ${i.type.value}(0); | |
| } | |
| } | |
| workgroupBarrier(); | |
| // each thread process one block | |
| let b_row = col + local_id.y; | |
| let block = tile * ${y} + local_id.x; | |
| ${c?` | |
| let zero_point_values_per_byte: u32 = ${Math.floor(8/t.bits)}u; | |
| let zero_point_bytes_per_col = (n_blocks_per_col + zero_point_values_per_byte - 1u) / zero_point_values_per_byte; | |
| let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block / zero_point_values_per_byte); | |
| let zero_point_word_index = zero_point_byte_count >> 0x2u; | |
| let zero_point_byte_offset = zero_point_byte_count & 0x3u; | |
| let zero_point_sub_offset: u32 = block % zero_point_values_per_byte; | |
| let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_sub_offset * ${t.bits}u); | |
| let zero_point_word = ${c.getByOffset(`zero_point_word_index`)} >> zero_point_bits_offset; | |
| let zero_point = ${p}((zero_point_word) & ${t.bits===2?`0x3u`:`0xFu`});`:` | |
| // The default zero point is ${2**(t.bits-1)} for unsigned ${t.bits}-bit quantization. | |
| let zero_point = ${p}(${(2**(t.bits-1)).toFixed(1)});`} | |
| let scale = ${o.getByOffset(`b_row * n_blocks_per_col + block`)}; | |
| let b_data = ${a.getByIndices(`${a.type.indices}(b_row, block, 0)`)}; | |
| var word_offset = local_id.x * ${t.blockSize/d}; | |
| for (var i: u32 = 0; i < ${f}; i++) { | |
| let b_value = ${f===1?`b_data`:`b_data[i]`}; | |
| ${(()=>{let e=Math.floor(g/8),n=``;for(let r=0;r<e;r++){let e=r*t.bits*4,i=e+t.bits;n+=` | |
| ${_()} | |
| {${t.bits===2?` | |
| let half_word = b_value >> ${r*16}u; | |
| let byte_lo = half_word & 0xFFu; | |
| let byte_hi = (half_word >> 8u) & 0xFFu; | |
| let spread_word = (byte_lo & 0xFu) | ((byte_lo >> 4u) << 8u) | ((byte_hi & 0xFu) << 16u) | ((byte_hi >> 4u) << 24u); | |
| let b_value_lower = unpack4xU8(spread_word & 0x03030303u); | |
| let b_value_upper = unpack4xU8((spread_word >> 2u) & 0x03030303u);`:` | |
| let b_value_lower = unpack4xU8((b_value >> ${e}u) & 0x0F0F0F0Fu); | |
| let b_value_upper = unpack4xU8((b_value >> ${i}u) & 0x0F0F0F0Fu);`} | |
| let b_quantized_values = mat2x4<${p}>(${Array.from({length:4},(e,t)=>`${p}(b_value_lower[${t}]), ${p}(b_value_upper[${t}])`).join(`, `)}); | |
| let b_dequantized_values = (b_quantized_values - mat2x4<${p}>(${Array(8).fill(`zero_point`).join(`,`)})) * scale; | |
| inter_results[local_id.y][local_id.x] += ${Array.from({length:2},(e,t)=>`${`dot(a_data${t}, b_dequantized_values[${t}])`}`).join(` + `)}; | |
| } | |
| word_offset += ${8/d};`}return n})()} | |
| } | |
| workgroupBarrier(); | |
| } | |
| if (local_idx < ${m}) { | |
| var output_value: ${u.type.value} = ${u.type.value}(0); | |
| for (var b = 0u; b < ${h}; b++) { | |
| output_value += inter_results[local_idx][b]; | |
| } | |
| if (col + local_idx < uniforms.output_shape[2]) | |
| { | |
| ${u.setByIndices(`${u.type.indices}(batch, row, col + local_idx)`,`output_value`)} | |
| } | |
| } | |
| }`}}},Oc=(e,t)=>{Tc(e.inputs,t),t.blockSize===32&&e.adapterInfo.isVendor(`intel`)&&e.adapterInfo.isArchitecture(`gen-12lp`)?e.compute(Dc(e.inputs,t)):e.compute(Ec(e.inputs,t))},kc=e=>V(e)}),jc,Mc,Nc,Pc,Fc,Ic,Lc,Rc,zc,Bc=o(()=>{L(),B(),Y(),jc=e=>{if(!e||e.length<1)throw Error(`Too few inputs`);if(e[0].dataType!==1&&e[0].dataType!==10)throw Error(`Input type must be float or float16.`);if(e.length>=2){let t=e[0].dims.length*2===e[1].dims[0];if(e.length===4&&(t=e[3].dims[0]*2===e[1].dims[0]),!t)throw Error(`The pads should be a 1D tensor of shape [2 * input_rank] or [2 * num_axes].`)}},Mc=(e,t,n)=>{let r=``;for(let i=t-1;i>=0;--i)r+=` | |
| k = i32(${e.indicesGet(`indices`,i)}) - ${K(`uniforms.pads`,i,n)}; | |
| if (k < 0) { | |
| break; | |
| } | |
| if (k >= i32(${K(`uniforms.x_shape`,i,t)})) { | |
| break; | |
| } | |
| offset += k * i32(${K(`uniforms.x_strides`,i,t)}); | |
| `;return` | |
| value = ${e.type.value}(uniforms.constant_value); | |
| for (var i = 0; i < 1; i++) { | |
| var offset = 0; | |
| var k = 0; | |
| ${r} | |
| value = x[offset]; | |
| } | |
| `},Nc=(e,t,n)=>{let r=``;for(let i=t-1;i>=0;--i)r+=` | |
| k = i32(${e.indicesGet(`indices`,i)}) - ${K(`uniforms.pads`,i,n)}; | |
| if (k < 0) { | |
| k = -k; | |
| } | |
| { | |
| let _2n_1 = 2 * (i32(${K(`uniforms.x_shape`,i,t)}) - 1); | |
| k = k % _2n_1; | |
| if(k >= i32(${K(`uniforms.x_shape`,i,t)})) { | |
| k = _2n_1 - k; | |
| } | |
| } | |
| offset += k * i32(${K(`uniforms.x_strides`,i,t)}); | |
| `;return` | |
| var offset = 0; | |
| var k = 0; | |
| ${r} | |
| value = x[offset]; | |
| `},Pc=(e,t,n)=>{let r=``;for(let i=t-1;i>=0;--i)r+=` | |
| k = i32(${e.indicesGet(`indices`,i)}) - ${K(`uniforms.pads`,i,n)}; | |
| if (k < 0) { | |
| k = 0; | |
| } | |
| if (k >= i32(${K(`uniforms.x_shape`,i,t)})) { | |
| k = i32(${K(`uniforms.x_shape`,i,t)}) - 1; | |
| } | |
| offset += k * i32(${K(`uniforms.x_strides`,i,t)}); | |
| `;return` | |
| var offset = 0; | |
| var k = 0; | |
| ${r} | |
| value = x[offset]; | |
| `},Fc=(e,t,n)=>{let r=``;for(let i=t-1;i>=0;--i)r+=` | |
| k = i32(${e.indicesGet(`indices`,i)}) - ${K(`uniforms.pads`,i,n)}; | |
| if (k < 0) { | |
| k += i32(${K(`uniforms.x_shape`,i,t)}]); | |
| } | |
| if (k >= i32(${K(`uniforms.x_shape`,i,t)})) { | |
| k -= i32(${K(`uniforms.x_shape`,i,t)}); | |
| } | |
| offset += k * i32(${K(`uniforms.x_strides`,i,t)}); | |
| `;return` | |
| var offset = 0; | |
| var k = 0; | |
| ${r} | |
| value = x[offset]; | |
| `},Ic=(e,t,n)=>{switch(n.mode){case 0:return Mc(e,t,n.pads.length);case 1:return Nc(e,t,n.pads.length);case 2:return Pc(e,t,n.pads.length);case 3:return Fc(e,t,n.pads.length);default:throw Error(`Invalid mode`)}},Lc=(e,t)=>{let n=z.padShape(e[0].dims.slice(),t.pads),r=e[0].dims,i=[{type:12,data:z.size(n)},{type:6,data:t.pads}],a=e.length>=3&&e[2].data;return t.mode===0&&i.push({type:a?e[2].dataType:1,data:t.value}),i.push(...W(e[0].dims,n)),{name:`Pad`,shaderCache:{hint:`${t.mode}${a}`,inputDependencies:[`rank`]},getRunData:()=>({outputs:[{dims:n,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(z.size(n)/64)},programUniforms:i}),getShaderSource:i=>{let o=J(`output`,e[0].dataType,n.length),s=q(`x`,e[0].dataType,r.length),c=s.type.value,l=Ic(o,r.length,t),u=[{name:`output_size`,type:`u32`},{name:`pads`,type:`i32`,length:t.pads.length}];return t.mode===0&&u.push({name:`constant_value`,type:a?c:`f32`}),` | |
| ${i.registerUniforms(u).declareVariables(s,o)} | |
| ${i.mainStart()} | |
| ${i.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.output_size`)} | |
| let indices = ${o.offsetToIndices(`global_idx`)}; | |
| var value = ${c}(0); | |
| ${l} | |
| output[global_idx] = value; | |
| }`}}},Rc=(e,t)=>{if(e.length>1){let n=e[1].getBigInt64Array(),r=e.length>=3&&e[2].data?e[2].dataType===10?e[2].getUint16Array()[0]:e[2].getFloat32Array()[0]:0,i=e[0].dims.length,a=new Int32Array(2*i).fill(0);if(e.length>=4){let t=e[3].getBigInt64Array();for(let e=0;e<t.length;e++)a[Number(t[e])]=Number(n[e]),a[Number(t[e])+i]=Number(n[e+t.length])}else n.forEach((e,t)=>a[Number(t)]=Number(e));let o=[];return a.forEach(e=>o.push(e)),{mode:t.mode,value:r,pads:o}}else return t},zc=(e,t)=>{jc(e.inputs);let n=Rc(e.inputs,t);e.compute(Lc(e.inputs,n),{inputs:[0]})}}),Vc,Hc,Uc,Wc,Gc,Kc,qc,Jc,Yc,Xc,Zc,Qc,$c,el,tl,nl,rl,il,al,ol=o(()=>{N(),L(),B(),Y(),Vc=e=>{if(S.webgpu.validateInputContent&&(!e||e.length!==1))throw Error(`Pool ops requires 1 input.`)},Hc=(e,t,n)=>{let r=t.format===`NHWC`,i=e.dims.slice();r&&i.splice(1,0,i.pop());let a=Object.hasOwnProperty.call(t,`dilations`),o=t.kernelShape.slice(),s=t.strides.slice(),c=a?t.dilations.slice():[],l=t.pads.slice();Bt.adjustPoolAttributes(n,i,o,s,c,l);let u=Bt.computePoolOutputShape(n,i,s,c,o,l,t.autoPad),d=Object.assign({},t);a?Object.assign(d,{kernelShape:o,strides:s,pads:l,dilations:c,cacheKey:t.cacheKey}):Object.assign(d,{kernelShape:o,strides:s,pads:l,cacheKey:t.cacheKey});let f=u.slice();return f.push(f.splice(1,1)[0]),[d,r?f:u]},Uc=(e,t)=>{let n=t.format===`NHWC`,r=z.size(e),i=z.size(t.kernelShape),a=[{type:12,data:r},{type:12,data:i}],o=[{name:`outputSize`,type:`u32`},{name:`kernelSize`,type:`u32`}];if(t.kernelShape.length<=2){let e=t.kernelShape[t.kernelShape.length-1],n=t.strides[t.strides.length-1],r=t.pads[t.pads.length/2-1],i=t.pads[t.pads.length-1],s=!!(r+i);a.push({type:12,data:e},{type:12,data:n},{type:12,data:r},{type:12,data:i}),o.push({name:`kw`,type:`u32`},{name:`sw`,type:`u32`},{name:`pwStart`,type:`u32`},{name:`pwEnd`,type:`u32`});let c=!1;if(t.kernelShape.length===2){let e=t.kernelShape[t.kernelShape.length-2],n=t.strides[t.strides.length-2],r=t.pads[t.pads.length/2-2],i=t.pads[t.pads.length-2];c=!!(r+i),a.push({type:12,data:e},{type:12,data:n},{type:12,data:r},{type:12,data:i}),o.push({name:`kh`,type:`u32`},{name:`sh`,type:`u32`},{name:`phStart`,type:`u32`},{name:`phEnd`,type:`u32`})}return[a,o,!0,s,c]}else{if(n)throw Error(`Pooling with kernelShape.length > 2 is not supported for NHWC format.`);let e=z.computeStrides(t.kernelShape);return a.push({type:12,data:e},{type:12,data:t.pads},{type:12,data:t.strides}),o.push({name:`kernelStrides`,type:`u32`,length:e.length},{name:`pads`,type:`u32`,length:t.pads.length},{name:`strides`,type:`u32`,length:t.strides.length}),[a,o,!!t.pads.reduce((e,t)=>e+t),!1,!1]}},Wc=(e,t,n,r,i,a,o,s,c,l,u,d)=>{let f=i.format===`NHWC`,p=t.type.value,m=J(`output`,t.type.tensor,r);if(i.kernelShape.length<=2){let r=``,l=``,h=``,g=n-(f?2:1);if(r=u?` | |
| for (var i: u32 = 0u; i < uniforms.kw; i++) { | |
| xIndices[${g}] = indices[${g}] * uniforms.sw - uniforms.pwStart + i; | |
| if (xIndices[${g}] < 0 || xIndices[${g}] | |
| >= uniforms.x_shape[${g}]) { | |
| pad++; | |
| continue; | |
| } | |
| let x_val = x[${t.indicesToOffset(`xIndices`)}]; | |
| ${a} | |
| }`:` | |
| for (var i: u32 = 0u; i < uniforms.kw; i++) { | |
| xIndices[${g}] = indices[${g}] * uniforms.sw - uniforms.pwStart + i; | |
| let x_val = x[${t.indicesToOffset(`xIndices`)}]; | |
| ${a} | |
| }`,i.kernelShape.length===2){let e=n-(f?3:2);l=d?` | |
| for (var j: u32 = 0u; j < uniforms.kh; j++) { | |
| xIndices[${e}] = indices[${e}] * uniforms.sh - uniforms.phStart + j; | |
| if (xIndices[${e}] < 0 || xIndices[${e}] >= uniforms.x_shape[${e}]) { | |
| pad += i32(uniforms.kw); | |
| continue; | |
| } | |
| `:` | |
| for (var j: u32 = 0u; j < uniforms.kh; j++) { | |
| xIndices[${e}] = indices[${e}] * uniforms.sh - uniforms.phStart + j; | |
| `,h=` | |
| } | |
| `}return` | |
| ${e.registerUniforms(c).declareVariables(t,m)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.outputSize`)} | |
| let indices = ${m.offsetToIndices(`global_idx`)}; | |
| var xIndices = ${m.offsetToIndices(`global_idx`)}; | |
| var value = ${p}(${s}); | |
| var pad = 0; | |
| ${l} | |
| ${r} | |
| ${h} | |
| ${o} | |
| output[global_idx] = value; | |
| }`}else{if(f)throw Error(`Pooling with kernelShape.length > 2 is not supported for NHWC format.`);let r=i.kernelShape.length,u=i.pads.length,d=``;return d=l?` | |
| if (xIndices[j] >= uniforms.x_shape[j]) { | |
| pad++; | |
| isPad = true; | |
| break; | |
| } | |
| } | |
| if (!isPad) { | |
| let x_val = x[${t.indicesToOffset(`xIndices`)}]; | |
| ${a} | |
| }`:` | |
| } | |
| let x_val = x[${t.indicesToOffset(`xIndices`)}]; | |
| ${a} | |
| `,` | |
| ${e.registerUniforms(c).declareVariables(t,m)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.outputSize`)} | |
| let indices = ${m.offsetToIndices(`global_idx`)}; | |
| var xIndices = ${m.offsetToIndices(`global_idx`)}; | |
| var offsets: array<u32, ${r}>; | |
| var value = ${p}(${s}); | |
| var pad = 0; | |
| var isPad = false; | |
| for (var i: u32 = 0u; i < uniforms.kernelSize; i++) { | |
| var offset = i; | |
| for (var j = 0u; j < ${r-1}u; j++) { | |
| offsets[j] = offset / ${K(`uniforms.kernelStrides`,`j`,r)}; | |
| offset -= offsets[j] * ${K(`uniforms.kernelStrides`,`j`,r)}; | |
| } | |
| offsets[${r-1}] = offset; | |
| isPad = false; | |
| for (var j = ${n-r}u; j < ${n}u; j++) { | |
| xIndices[j] = indices[j] * ${K(`uniforms.strides`,`j - ${n-r}u`,r)} | |
| + offsets[j - ${n-r}u] - ${K(`uniforms.pads`,`j - 2u`,u)}; | |
| ${d} | |
| } | |
| ${o} | |
| output[global_idx] = value; | |
| }`}},Gc=e=>`${e.format};${e.ceilMode};${e.autoPad};${e.kernelShape.length}`,Kc=e=>`${Gc(e)};${e.countIncludePad}`,qc=e=>`${Gc(e)};${e.storageOrder};${e.dilations}`,Jc=e=>({format:e.format,autoPad:[`NOTSET`,`VALID`,`SAME_UPPER`,`SAME_LOWER`][e.auto_pad],ceilMode:e.ceil_mode,kernelShape:e.kernel_shape,strides:e.strides,pads:e.pads}),Yc=(e,t,n,r)=>{let[i,a]=Hc(t,r,n),o=q(`x`,t.dataType,t.dims.length),s=o.type.value,c=``;i.countIncludePad?c+=`value /= ${s}(uniforms.kernelSize);`:c+=`value /= ${s}(i32(uniforms.kernelSize) - pad);`;let[l,u,d,f,p]=Uc(a,i);return l.push(...W(t.dims,a)),{name:e,shaderCache:{hint:`${r.cacheKey};${d};${f};${p}`,inputDependencies:[`rank`]},getRunData:()=>({outputs:[{dims:a,dataType:t.dataType}],dispatchGroup:{x:Math.ceil(z.size(a)/64)},programUniforms:l}),getShaderSource:e=>Wc(e,o,t.dims.length,a.length,i,`value += x_val;`,c,0,u,d,f,p)}},Xc=e=>{let t=e.count_include_pad!==0,n=Jc(e);if(n.ceilMode!==0)throw Error(`using ceil() in shape computation is not yet supported for AveragePool`);let r={countIncludePad:t,...n,cacheKey:``};return{...r,cacheKey:Kc(r)}},Zc=(e,t)=>{Vc(e.inputs),e.compute(Yc(`AveragePool`,e.inputs[0],!1,t))},Qc={autoPad:``,ceilMode:0,countIncludePad:!1,kernelShape:[],strides:[],pads:[],storageOrder:0,dilations:[]},$c=e=>{let t=e.format;return{format:t,...Qc,cacheKey:t}},el=(e,t)=>{Vc(e.inputs),e.compute(Yc(`GlobalAveragePool`,e.inputs[0],!0,t))},tl=(e,t,n,r)=>{let[i,a]=Hc(t,r,n),o=q(`x`,t.dataType,t.dims.length),s=[`rank`],[c,l,u,d,f]=Uc(a,i);return c.push(...W(t.dims,a)),{name:e,shaderCache:{hint:`${r.cacheKey};${u};${d};${f}`,inputDependencies:s},getRunData:()=>({outputs:[{dims:a,dataType:t.dataType}],dispatchGroup:{x:Math.ceil(z.size(a)/64)},programUniforms:c}),getShaderSource:e=>Wc(e,o,t.dims.length,a.length,i,` | |
| value = max(x_val, value); | |
| `,``,t.dataType===10?-65504:-1e5,l,u,d,f)}},nl=(e,t)=>{Vc(e.inputs),e.compute(tl(`MaxPool`,e.inputs[0],!1,t))},rl=e=>{let t=e.storage_order,n=e.dilations,r=Jc(e);if(t!==0)throw Error(`column major storage order is not yet supported for MaxPool`);if(r.ceilMode!==0)throw Error(`using ceil() in shape computation is not yet supported for MaxPool`);let i={storageOrder:t,dilations:n,...r,cacheKey:``};return{...i,cacheKey:qc(i)}},il=e=>{let t=e.format;return{format:t,...Qc,cacheKey:t}},al=(e,t)=>{Vc(e.inputs),e.compute(tl(`GlobalMaxPool`,e.inputs[0],!0,t))}}),sl,cl,ll,ul,dl=o(()=>{L(),B(),H(),Y(),sl=(e,t)=>{if(e.length<2||e.length>3)throw Error(`DequantizeLinear requires 2 or 3 inputs.`);if(e.length===3&&e[1].dims===e[2].dims)throw Error(`x-scale and x-zero-point must have the same shape.`);if(e.length===3&&e[0].dataType!==e[2].dataType)throw Error(`x and x-zero-point must have the same data type.`);if(e[0].dataType===6&&e.length>2)throw Error(`In the case of dequantizing int32 there is no zero point.`);if(e[1].dims.length!==0&&e[1].dims.length!==1&&e[1].dims.length!==e[0].dims.length)throw Error(`scale input must be a scalar, a 1D tensor, or have the same rank as the input tensor.`);if(e.length>2){if(e[0].dataType!==e[2].dataType)throw Error(`x and x-zero-point must have the same data type.`);if(e[1].dims.length!==e[2].dims.length)throw Error(`scale and zero-point inputs must have the same rank.`);if(!e[1].dims.map((t,n)=>t===e[2].dims[n]).reduce((e,t)=>e&&t,!0))throw Error(`scale and zero-point inputs must have the same shape.`)}if(t.blockSize>0){if(e[1].dims.length===0||e[1].dims.length===1&&e[1].dims[0]===1)throw Error(`blockSize must be set only for block quantization.`);if(!e[1].dims.map((n,r)=>r===t.axis||n===e[0].dims[r]).reduce((e,t)=>e&&t,!0))throw Error(`For block qunatization, scale input shape to match the input shape except for the axis`);if(e[1].dims.length!==e[0].dims.length)throw Error(`For block qunatization the scale input rank must be the same as the x rank.`);let n=e[0].dims[t.axis],r=e[1].dims[t.axis];if(t.blockSize<Math.ceil(n/r)||t.blockSize>Math.ceil(n/(r-1)-1))throw Error(`blockSize must be with in the range [ceil(dI / Si), ceil(dI / (Si - 1) - 1)].`)}},cl=(e,t)=>{let n=z.normalizeAxis(t.axis,e[0].dims.length),r=e[0].dataType,i=r===3,a=e[0].dims,o=e[1].dataType,s=z.size(a),c=r===3||r===2,l=c?[Math.ceil(z.size(e[0].dims)/4)]:e[0].dims,u=e[1].dims,d=e.length>2?e[2]:void 0,f=d?c?[Math.ceil(z.size(d.dims)/4)]:d.dims:void 0,p=u.length===0||u.length===1&&u[0]===1,m=p===!1&&u.length===1,h=G(s),g=p&&(!c||h===4),_=g?h:1,v=g&&!c?h:1,y=q(`input`,c?12:r,l.length,v),b=q(`scale`,o,u.length),x=d?q(`zero_point`,c?12:r,f.length):void 0,S=J(`output`,o,a.length,_),C=[y,b];x&&C.push(x);let w=[l,u];d&&w.push(f);let ee=[{type:12,data:s/_},{type:12,data:n},{type:12,data:t.blockSize},...W(...w,a)];return{name:`DequantizeLinear`,shaderCache:{hint:t.cacheKey,inputDependencies:x?[`rank`,`rank`,`rank`]:[`rank`,`rank`]},getShaderSource:e=>` | |
| ${e.registerUniforms([{name:`output_size`,type:`u32`},{name:`axis`,type:`u32`},{name:`block_size`,type:`u32`}]).declareVariables(...C,S)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.output_size`)} | |
| let output_indices = ${S.offsetToIndices(`global_idx`)}; | |
| // Set input x | |
| ${c?` | |
| let input = ${y.getByOffset(`global_idx / 4`)}; | |
| let x_vec = ${i?`unpack4xI8(input)`:`unpack4xU8(input)`}; | |
| let x_value = ${_===1?`x_vec[global_idx % 4]`:`x_vec`};`:`let x_value = ${y.getByOffset(`global_idx`)};`}; | |
| // Set scale input | |
| ${p?`let scale_value= ${b.getByOffset(`0`)}`:m?` | |
| let scale_index = ${S.indicesGet(`output_indices`,`uniforms.axis`)}; | |
| let scale_value= ${b.getByOffset(`scale_index`)};`:` | |
| var scale_indices: ${b.type.indices} = output_indices; | |
| let index = ${b.indicesGet(`scale_indices`,`uniforms.axis`)} / uniforms.block_size; | |
| ${b.indicesSet(`scale_indices`,`uniforms.axis`,`index`)}; | |
| let scale_value= ${b.getByIndices(`scale_indices`)};`}; | |
| // Set zero-point input | |
| ${x?p?c?` | |
| let zero_point_input = ${x.getByOffset(`0`)}; | |
| let zero_point_vec = ${i?`unpack4xI8(zero_point_input)`:`unpack4xU8(zero_point_input)`}; | |
| let zero_point_value= zero_point_vec[0]`:`let zero_point_value = ${x.getByOffset(`0`)}`:m?c?` | |
| let zero_point_index = ${S.indicesGet(`output_indices`,`uniforms.axis`)}; | |
| let zero_point_input = ${x.getByOffset(`zero_point_index / 4`)}; | |
| let zero_point_vec = ${i?`unpack4xI8(zero_point_input)`:`unpack4xU8(zero_point_input)`}; | |
| let zero_point_value = zero_point_vec[zero_point_index % 4]`:` | |
| let zero_point_index = ${S.indicesGet(`output_indices`,`uniforms.axis`)}; | |
| let zero_point_value = ${x.getByOffset(`zero_point_index`)};`:c?` | |
| let zero_point_offset = ${b.indicesToOffset(`scale_indices`)}; | |
| let zero_point_input = ${x.getByOffset(`zero_point_offset / 4`)}; | |
| let zero_point_vec = ${i?`unpack4xI8(zero_point_input)`:`unpack4xU8(zero_point_input)`}; | |
| let zero_point_value = zero_point_vec[zero_point_offset % 4];`:`let zero_point_value = ${x.getByIndices(`scale_indices`)};`:`let zero_point_value = ${c?i?`i32`:`u32`:y.type.value}(0);`}; | |
| // Compute and write output | |
| ${S.setByOffset(`global_idx`,`${S.type.value}(x_value - zero_point_value) * scale_value`)}; | |
| }`,getRunData:()=>({outputs:[{dims:a,dataType:o}],dispatchGroup:{x:Math.ceil(s/_/64),y:1,z:1},programUniforms:ee})}},ll=(e,t)=>{sl(e.inputs,t),e.compute(cl(e.inputs,t))},ul=e=>V({axis:e.axis,blockSize:e.blockSize})}),fl,pl,ml,hl=o(()=>{N(),L(),Y(),fl=(e,t,n)=>{if(e===t||e<t&&n<0||e>t&&n>0)throw Error(`Range these inputs' contents are invalid.`)},pl=(e,t,n,r)=>{let i=Math.abs(Math.ceil((t-e)/n)),a=[i],o=i,s=[{type:12,data:o},{type:r,data:e},{type:r,data:n},...W(a)];return{name:`Range`,shaderCache:{hint:`${r}`},getShaderSource:e=>{let t=J(`output`,r,a.length),n=t.type.value,i=[{name:`outputSize`,type:`u32`},{name:`start`,type:n},{name:`delta`,type:n}];return` | |
| ${e.registerUniforms(i).declareVariables(t)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.outputSize`)} | |
| output[global_idx] = uniforms.start + ${n}(global_idx) * uniforms.delta; | |
| }`},getRunData:()=>({outputs:[{dims:a,dataType:r}],dispatchGroup:{x:Math.ceil(o/64)},programUniforms:s})}},ml=e=>{let t=0,n=0,r=0;e.inputs[0].dataType===6?(t=e.inputs[0].getInt32Array()[0],n=e.inputs[1].getInt32Array()[0],r=e.inputs[2].getInt32Array()[0]):e.inputs[0].dataType===1&&(t=e.inputs[0].getFloat32Array()[0],n=e.inputs[1].getFloat32Array()[0],r=e.inputs[2].getFloat32Array()[0]),S.webgpu.validateInputContent&&fl(t,n,r),e.compute(pl(t,n,r,e.inputs[0].dataType),{inputs:[]})}}),gl,_l,vl,yl,bl=o(()=>{L(),B(),H(),Y(),gl=(e,t,n,r)=>{if(e!==`none`&&r!==`i32`&&r!==`u32`&&r!==`f32`)throw Error(`Input ${r} is not supported with reduction ${e}.`);let i=`{ | |
| var oldValue = 0; | |
| loop { | |
| let newValueF32 =`,a=`; | |
| let newValue = bitcast<i32>(newValueF32); | |
| let res = atomicCompareExchangeWeak(&${t}, oldValue, newValue); | |
| if res.exchanged { | |
| break; | |
| } | |
| oldValue = res.old_value; | |
| } | |
| }`;switch(e){case`none`:return`${t}=${n};`;case`add`:return r===`i32`||r===`u32`?`atomicAdd(&${t}, bitcast<${r}>(${n}));`:` | |
| ${i}bitcast<${r}>(oldValue) + (${n})${a}`;case`max`:return r===`i32`||r===`u32`?`atomicMax(&${t}, bitcast<${r}>(${n}));`:` | |
| ${i}max(bitcast<f32>(oldValue), (${n}))${a}`;case`min`:return r===`i32`||r===`u32`?`atomicMin(&${t}, bitcast<${r}>(${n}));`:`${i}min(bitcast<${r}>(oldValue), (${n}))${a}`;case`mul`:return`${i}(bitcast<${r}>(oldValue) * (${n}))${a}`;default:throw Error(`Reduction ${e} is not supported.`)}},_l=(e,t)=>{let n=e[0].dims,r=e[1].dims,i=n,a=Math.ceil(z.sizeToDimension(r,r.length-1)/1),o=r[r.length-1],s=z.sizeFromDimension(n,o),c=[{type:12,data:a},{type:12,data:o},{type:12,data:s},...W(e[1].dims,e[2].dims,i)];return{name:`ScatterND`,shaderCache:{hint:`${t.cacheKey}_${t.reduction}`,inputDependencies:[`rank`,`rank`]},getRunData:()=>({outputs:[{dims:i,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(a/64)},programUniforms:c}),getShaderSource:n=>{let r=q(`indices`,e[1].dataType,e[1].dims.length),a=q(`updates`,e[2].dataType,e[2].dims.length,1),o=t.reduction!==`none`&&t.reduction!==``?On(`output`,e[0].dataType,i.length):J(`output`,e[0].dataType,i.length,1);return` | |
| ${n.registerUniform(`output_size`,`u32`).registerUniform(`last_index_dimension`,`u32`).registerUniform(`num_updates_elements`,`u32`).declareVariables(r,a,o)} | |
| ${n.mainStart()} | |
| ${n.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.output_size`)} | |
| var data_offset = 0u; | |
| let indices_start = uniforms.last_index_dimension * global_idx; | |
| let indices_end = indices_start + uniforms.last_index_dimension; | |
| for (var i = indices_start; i < indices_end; i++) { | |
| var index = i32(indices[i].x); | |
| ${e[0].dims.length===1?` | |
| let element_count_dim = uniforms.output_strides; | |
| let dim_value = uniforms.output_shape;`:` | |
| let element_count_dim = uniforms.output_strides[i - indices_start]; | |
| let dim_value = uniforms.output_shape[i - indices_start];`} | |
| if (index >= 0) { | |
| if (index >= i32(dim_value)) { | |
| index = i32(dim_value - 1); | |
| } | |
| } else { | |
| if (index < -i32(dim_value)) { | |
| index = 0; | |
| } else { | |
| index += i32(dim_value); | |
| } | |
| } | |
| data_offset += u32((u32(index) * element_count_dim)); | |
| } | |
| for (var i = 0u; i < uniforms.num_updates_elements; i++) { | |
| let value = updates[uniforms.num_updates_elements * global_idx + i]; | |
| ${gl(t.reduction,`output[data_offset + i]`,`value`,o.type.value)} | |
| } | |
| }`}}},vl=e=>V({reduction:e.reduction}),yl=(e,t)=>{e.compute(_l(e.inputs,t),{inputs:[e.inputs[1],e.inputs[2]],outputs:[]})}}),xl,Sl,Cl,wl,Tl,El,Dl,Ol,kl,Al,jl,Ml,Nl,Pl,Fl,Il,Ll,Rl,zl,Bl,Vl=o(()=>{L(),B(),H(),Y(),xl=(e,t)=>{if(e.every(e=>e>0||(()=>{throw Error(`Resize requires scales input values to be positive`)})),e.length>0){if(t.mode===`linear`){if(!(e.length===2||e.length===3||e.length===4&&e[0]===1&&e[1]===1||e.length===4&&e[0]===1&&e[3]===1||e.length===5&&e[0]===1&&e[1]===1))throw Error(`For linear mode, Resize requires scales to be 2D, 3D, 4D with either two outermost or one innermost and | |
| one outermost scale values equal to 1, or 5D with two outermost scale values equal to 1`)}else if(t.mode===`cubic`&&!(e.length===2||e.length===4&&e[0]===1&&e[1]===1||e.length===4&&e[0]===1&&e[3]===1))throw Error(`Resize requires scales input size to be 2 or 4 for cubic mode`)}},Sl=(e,t,n)=>{t.every(e=>e>=0&&e<n||(()=>{throw Error(`Resize requires axes input values to be positive and less than rank`)}));let r=Array(n).fill(1);return t.forEach((t,n)=>r[t]=e[n]),r},Cl=(e,t,n,r,i,a)=>{let[o,s,c]=n>10?[1,2,3]:[-1,e.length>1?1:-1,-1],l=e[0].dims.length;if(o>0&&e.length>o&&e[o].dims.length>0)e[o].getFloat32Array().forEach(e=>a.push(e));else if(t.coordinateTransformMode===`tf_crop_and_resize`)throw Error(`Resize requires RoI input to be specified when coordinateTransformMode is tfCropAndResize`);if(s>0&&e.length>s&&e[s].dims.length===1&&e[s].dims[0]>0){if(e[s].getFloat32Array().forEach(e=>r.push(e)),r.length!==0&&r.length!==l&&n>=18&&r.length!==t.axes.length)throw Error(`Resize requires scales input size to be same as input rank or axes size for opset 18 and up`);xl(r,t),t.axes.length>0&&Sl(r,t.axes,l).forEach((e,t)=>r[t]=e)}if(c>0&&e.length>c&&e[c].dims.length===1&&e[c].dims[0]>0&&(e[c].getBigInt64Array().forEach(e=>i.push(Number(e))),i.length!==0&&i.length!==l&&n>=18&&i.length!==t.axes.length))throw Error(`Resize requires sizes input size to be same as input rank or axes size for opset 18 and up`);if(t.axes.length>0){if(r.length!==0&&r.length!==t.axes.length)throw Error(`Resize requires "scales" input size to be of axes rank when axes attributes is specified`);if(i.length!==0&&i.length!==t.axes.length)throw Error(`Resize requires "sizes" input size to be of rank axes rank when axes attributes is specified`)}if(typeof r<`u`&&typeof i<`u`&&r.length>0&&i.length>l)throw Error(`Resize requires only of scales or sizes to be specified`)},wl=(e,t,n,r)=>` | |
| // The whole part and the fractional part are calculated separately due to inaccuracy of floating | |
| // point division. As an example, f32(21) / f32(7) may evaluate to 2.99... instead of 3, causing an | |
| // offset-by-one error later in floor(). | |
| let big = (${e}) * (${t}); | |
| let whole = ${r}(big / (${n})); | |
| let fract = ${r}(big % (${n})) / ${r}(${n}); | |
| return whole + fract; | |
| `,Tl=(e,t)=>`fn getOriginalCoordinateFromResizedCoordinate(xResized: u32, xScale: f32, lengthResized: u32, | |
| lengthOriginal: u32, roiStart: f32, roiEnd: f32) -> ${t} { `+(()=>{switch(e){case`asymmetric`:return` | |
| if (xScale < 1.0 || floor(xScale) != xScale) { | |
| return ${t}(xResized) / ${t}(xScale); | |
| } else { | |
| ${wl(`xResized`,`lengthOriginal`,`lengthResized`,t)} | |
| } | |
| `;case`pytorch_half_pixel`:return`if (lengthResized > 1) { | |
| return (${t}(xResized) + 0.5) / ${t}(xScale) - 0.5; | |
| } else { | |
| return 0.0; | |
| }`;case`tf_half_pixel_for_nn`:return`return (${t}(xResized) + 0.5) / ${t}(xScale);`;case`align_corners`:return`if (lengthResized == 1) { | |
| return 0.0; | |
| } else { | |
| ${wl(`xResized`,`lengthOriginal - 1`,`lengthResized - 1`,t)} | |
| }`;case`tf_crop_and_resize`:return`if (lengthResized > 1) { | |
| return ${t}(roiStart) * ${t}(lengthOriginal - 1) + | |
| (${t}(xResized) * ${t}(roiEnd - roiStart) * ${t}(lengthOriginal - 1)) / | |
| ${t}(lengthResized - 1); | |
| } else { | |
| return 0.5 * ${t}(roiStart + roiEnd) * ${t}(lengthOriginal - 1); | |
| }`;case`half_pixel_symmetric`:return`const outputWidth = ${t}xScale * ${t}(lengthResized); | |
| const adjustment = ${t}(lengthResized) / outputWidth; | |
| const center = ${t}(lengthOriginal) / 2; | |
| const offset = center * (1 - adjustment); | |
| return offset + ((${t}(xResized) + 0.5) / ${t}(xScale)) - 0.5;`;case`half_pixel`:return`return ((${t}(xResized) + 0.5) / ${t}(xScale)) - 0.5;`;default:throw Error(`Coordinate transform mode ${e} is not supported`)}})()+`}`,El=(e,t,n)=>`fn getNearestPixelFromOriginal(xOriginal: ${n}, isDownSample: bool) -> ${n} {`+(()=>{switch(e){case`round_prefer_ceil`:return`if (fract(xOriginal) == 0.5) { return ceil(xOriginal); } else { return round(xOriginal); }`;case`floor`:return`return floor(xOriginal);`;case`ceil`:return`return ceil(xOriginal);`;case`round_prefer_floor`:return`if (fract(xOriginal) == 0.5) { return floor(xOriginal); } else { return round(xOriginal); }`;default:if(t<11)return`if (isDownSample) { return ceil(xOriginal); } else { return xOriginal; }`;throw Error(`Nearest mode ${e} is not supported`)}})()+`}`,Dl=(e,t,n)=>{let r=Array(n).fill(0).concat(Array(n).fill(1)),i=e.length===0?r:e.slice();return t.length>0?(t.forEach((e,a)=>{r[e]=i[a],r[a+n]=i[t.length+a]}),r):i},Ol=(e,t,n,r)=>{let i=[];if(n.length>0)if(r.length>0){if(e.forEach(e=>i.push(e)),Math.max(...r)>e.length)throw Error(`axes is out of bound`);r.forEach((e,t)=>i[e]=n[t])}else n.forEach(e=>i.push(e));else{if(t.length===0)throw Error(`Resize requires either scales or sizes.`);i=e.map((e,n)=>Math.round(e*t[n]))}return i},kl=(e,t,n)=>{let r=(()=>{switch(n.keepAspectRatioPolicy){case`not_larger`:return n.axes.length>0?Math.min(...n.axes.map(e=>t[e]),Number.MAX_VALUE):Math.min(...t,Number.MAX_VALUE);case`not_smaller`:return n.axes.length>0?Math.max(...n.axes.map(e=>t[e]),Number.MIN_VALUE):Math.max(...t,Number.MIN_VALUE);default:throw Error(`Keep aspect ratio policy ${n.keepAspectRatioPolicy} is not supported`)}})();t.fill(1,0,t.length);let i=e.slice();return n.axes.length>0?(n.axes.forEach(e=>t[e]=r),n.axes.forEach(n=>i[n]=Math.round(e[n]*t[n]))):(t.fill(r,0,t.length),i.forEach((e,n)=>i[n]=Math.round(e*t[n]))),i},Al=(e,t,n,r,i)=>` | |
| fn calculateOriginalIndicesFromOutputIndices(output_indices: ${e.type.indices}) -> array<${e.type.value}, ${n.length}> { | |
| var original_indices: array<${e.type.value}, ${n.length}>; | |
| for (var i:u32 = 0; i < ${n.length}; i++) { | |
| var output_index = ${e.indicesGet(`output_indices`,`i`)}; | |
| var scale = ${K(`uniforms.scales`,`i`,r)}; | |
| var roi_low = ${K(`uniforms.roi`,`i`,i)}; | |
| var roi_hi = ${K(`uniforms.roi`,`i + ${t.length}`,i)}; | |
| if (scale == 1.0) { | |
| original_indices[i] = ${e.type.value}(output_index); | |
| } else { | |
| var input_shape_i = ${K(`uniforms.input_shape`,`i`,t.length)}; | |
| var output_shape_i = ${K(`uniforms.output_shape`,`i`,n.length)}; | |
| original_indices[i] = getOriginalCoordinateFromResizedCoordinate(output_index, scale, output_shape_i, | |
| input_shape_i, roi_low, roi_hi); | |
| } | |
| } | |
| return original_indices; | |
| }`,jl=(e,t,n,r,i,a,o)=>` | |
| fn calculateInputIndicesFromOutputIndices(output_indices: ${t.type.indices}) -> ${e.type.indices} { | |
| var input_indices: ${e.type.indices}; | |
| for (var i:u32 = 0; i < ${r.length}; i++) { | |
| var output_index = ${t.indicesGet(`output_indices`,`i`)}; | |
| var input_index: u32; | |
| var scale = ${K(`uniforms.scales`,`i`,i)}; | |
| if (scale == 1.0) { | |
| input_index = output_index; | |
| } else { | |
| var roi_low = ${K(`uniforms.roi`,`i`,a)}; | |
| var roi_hi = ${K(`uniforms.roi`,`i + ${n.length}`,a)}; | |
| var input_shape_i = ${K(`uniforms.input_shape`,`i`,n.length)}; | |
| var output_shape_i = ${K(`uniforms.output_shape`,`i`,r.length)}; | |
| var original_idx = getOriginalCoordinateFromResizedCoordinate(output_index, scale, output_shape_i, | |
| input_shape_i, roi_low, roi_hi); | |
| if (!${o} || (original_idx >= 0 && original_idx < ${t.type.value}(input_shape_i))) { | |
| if (original_idx < 0) { | |
| input_index = 0; | |
| } else if (original_idx > ${t.type.value}(input_shape_i - 1)) { | |
| input_index = input_shape_i - 1; | |
| } else { | |
| input_index = u32(getNearestPixelFromOriginal(original_idx, scale < 1)); | |
| } | |
| } else { | |
| input_index = u32(original_idx); | |
| } | |
| } | |
| ${e.indicesSet(`input_indices`,`i`,`input_index`)} | |
| } | |
| return input_indices; | |
| }`,Ml=(e,t)=>` | |
| fn checkInputIndices(input_indices: ${e.type.indices}) -> bool { | |
| for (var i:u32 = 0; i < ${t.length}; i++) { | |
| var input_index = ${e.indicesGet(`input_indices`,`i`)}; | |
| if (input_index < 0 || input_index >= ${K(`uniforms.input_shape`,`i`,t.length)}) { | |
| return false; | |
| } | |
| } | |
| return true; | |
| }`,Nl=(e,t,n,r)=>e.rank>r?` | |
| ${e.indicesSet(`input_indices`,t,`channel`)}; | |
| ${e.indicesSet(`input_indices`,n,`batch`)}; | |
| `:``,Pl=(e,t,n,r,i)=>{let[a,o,s,c]=n.length===2?[-1,0,1,-1]:[0,2,3,1],l=e.type.value;return` | |
| fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> ${l} { | |
| var input_indices: ${e.type.indices}; | |
| ${e.indicesSet(`input_indices`,o,`max(0, min(row, ${n[o]} - 1))`)}; | |
| ${e.indicesSet(`input_indices`,s,`max(0, min(col, ${n[s]} - 1))`)}; | |
| ${Nl(e,c,a,2)} | |
| return ${e.getByIndices(`input_indices`)}; | |
| } | |
| fn bilinearInterpolation(output_indices: ${t.type.indices}) -> ${l} { | |
| var originalIndices = calculateOriginalIndicesFromOutputIndices(output_indices); | |
| var row:${l} = originalIndices[${o}]; | |
| var col:${l} = originalIndices[${s}]; | |
| ${r?`if (row < 0 || row > (${n[o]} - 1) || col < 0 || col > (${n[s]} - 1)) { | |
| return ${i}; | |
| }`:``}; | |
| row = max(0, min(row, ${n[o]} - 1)); | |
| col = max(0, min(col, ${n[s]} - 1)); | |
| var row1: u32 = u32(row); | |
| var col1: u32 = u32(col); | |
| var row2: u32 = u32(row + 1); | |
| var col2: u32 = u32(col + 1); | |
| var channel: u32 = ${n.length>2?`u32(originalIndices[${c}])`:`0`}; | |
| var batch: u32 = ${n.length>2?`u32(originalIndices[${a}])`:`0`}; | |
| var x11: ${l} = getInputValue(batch, channel, row1, col1); | |
| var x12: ${l} = getInputValue(batch, channel, row1, col2); | |
| var x21: ${l} = getInputValue(batch, channel, row2, col1); | |
| var x22: ${l} = getInputValue(batch, channel, row2, col2); | |
| var dx1: ${l} = abs(row - ${l}(row1)); | |
| var dx2: ${l} = abs(${l}(row2) - row); | |
| var dy1: ${l} = abs(col - ${l}(col1)); | |
| var dy2: ${l} = abs(${l}(col2) - col); | |
| if (row1 == row2) { | |
| dx1 = 0.5; | |
| dx2 = 0.5; | |
| } | |
| if (col1 == col2) { | |
| dy1 = 0.5; | |
| dy2 = 0.5; | |
| } | |
| return (x11 * dx2 * dy2 + x12 * dx2 * dy1 + x21 * dx1 * dy2 + x22 * dx1 * dy1); | |
| }`},Fl=(e,t,n,r,i,a,o,s,c,l)=>{let[u,d]=n.length===2?[0,1]:[2,3],f=e.type.value,p=o=>{let d=o===u?`row`:`col`;return` | |
| fn ${d}CubicInterpolation(input_indices: ${e.type.indices}, output_indices: ${t.type.indices}) -> ${f} { | |
| var output_index = ${t.indicesGet(`output_indices`,o)}; | |
| var originalIdx: ${f} = getOriginalCoordinateFromResizedCoordinate(output_index, ${i[o]}, | |
| ${r[o]}, ${n[o]}, ${a[o]}, ${a[o]} + ${n.length}); | |
| var fractOriginalIdx: ${f} = originalIdx - floor(originalIdx); | |
| var coefs = getCubicInterpolationCoefs(fractOriginalIdx); | |
| if (${s} && (originalIdx < 0 || originalIdx > (${n[o]} - 1))) { | |
| return ${c}; | |
| } | |
| var data: array<${f}, 4> = array<${f}, 4>(0.0, 0.0, 0.0, 0.0); | |
| for (var i: i32 = -1; i < 3; i++) { | |
| var ${d}: ${f} = originalIdx + ${f}(i); | |
| if (${d} < 0 || ${d} >= ${n[o]}) { | |
| ${l?`coefs[i + 1] = 0.0; | |
| continue;`:s?`return ${c};`:`${d} = max(0, min(${d}, ${n[o]} - 1));`}; | |
| } | |
| var input_indices_copy: ${e.type.indices} = input_indices; | |
| ${e.indicesSet(`input_indices_copy`,o,`u32(${d})`)}; | |
| data[i + 1] = ${o===u?e.getByIndices(`input_indices_copy`):`rowCubicInterpolation(input_indices_copy, output_indices)`}; | |
| } | |
| return cubicInterpolation1D(data, coefs); | |
| }`};return` | |
| ${p(u)}; | |
| ${p(d)}; | |
| fn getCubicInterpolationCoefs(s: ${f}) -> array<${f}, 4> { | |
| var absS = abs(s); | |
| var coeffs: array<${f}, 4> = array<${f}, 4>(0.0, 0.0, 0.0, 0.0); | |
| var oneMinusAbsS: ${f} = 1.0 - absS; | |
| var twoMinusAbsS: ${f} = 2.0 - absS; | |
| var onePlusAbsS: ${f} = 1.0 + absS; | |
| coeffs[0] = ((${o} * onePlusAbsS - 5 * ${o}) * onePlusAbsS + 8 * ${o}) * onePlusAbsS - 4 * ${o}; | |
| coeffs[1] = ((${o} + 2) * absS - (${o} + 3)) * absS * absS + 1; | |
| coeffs[2] = ((${o} + 2) * oneMinusAbsS - (${o} + 3)) * oneMinusAbsS * oneMinusAbsS + 1; | |
| coeffs[3] = ((${o} * twoMinusAbsS - 5 * ${o}) * twoMinusAbsS + 8 * ${o}) * twoMinusAbsS - 4 * ${o}; | |
| return coeffs; | |
| } | |
| fn cubicInterpolation1D(x: array<${f}, 4>, coefs: array<${f}, 4>) -> ${f} { | |
| var coefsSum: ${f} = coefs[0] + coefs[1] + coefs[2] + coefs[3]; | |
| return (x[0] * coefs[0] + x[1] * coefs[1]+ x[2] * coefs[2]+ x[3] * coefs[3]) / coefsSum; | |
| } | |
| fn bicubicInterpolation(output_indices: ${t.type.indices}) -> ${f} { | |
| var input_indices: ${e.type.indices} = output_indices; | |
| return colCubicInterpolation(input_indices, output_indices); | |
| } | |
| `},Il=(e,t,n,r,i)=>{let[a,o,s,c,l]=n.length===3?[-1,0,1,2,-1]:[0,2,3,4,1],u=e.type.value;return` | |
| fn getInputValue(batch: u32, channel: u32, depth:u32, height: u32, width: u32) -> ${u} { | |
| var input_indices: ${e.type.indices}; | |
| ${e.indicesSet(`input_indices`,o,`max(0, min(depth, ${n[o]} - 1))`)}; | |
| ${e.indicesSet(`input_indices`,s,`max(0, min(height, ${n[s]} - 1))`)}; | |
| ${e.indicesSet(`input_indices`,c,`max(0, min(width, ${n[c]} - 1))`)}; | |
| ${Nl(e,l,a,3)} | |
| return ${e.getByIndices(`input_indices`)}; | |
| } | |
| fn trilinearInterpolation(output_indices: ${t.type.indices}) -> ${u} { | |
| var originalIndices = calculateOriginalIndicesFromOutputIndices(output_indices); | |
| var depth:${u} = originalIndices[${o}]; | |
| var height:${u} = originalIndices[${s}]; | |
| var width:${u} = originalIndices[${c}]; | |
| ${r?`if (depth < 0 || depth > (${n[o]} - 1) || height < 0 || height > (${n[s]} - 1) || width < 0 || (width > ${n[c]} - 1)) { | |
| return ${i}; | |
| }`:``}; | |
| depth = max(0, min(depth, ${n[o]} - 1)); | |
| height = max(0, min(height, ${n[s]} - 1)); | |
| width = max(0, min(width, ${n[c]} - 1)); | |
| var depth1: u32 = u32(depth); | |
| var height1: u32 = u32(height); | |
| var width1: u32 = u32(width); | |
| var depth2: u32 = u32(depth + 1); | |
| var height2: u32 = u32(height + 1); | |
| var width2: u32 = u32(width + 1); | |
| var channel: u32 = ${n.length>3?`u32(originalIndices[${l}])`:`0`}; | |
| var batch: u32 = ${n.length>3?`u32(originalIndices[${a}])`:`0`}; | |
| var x111: ${u} = getInputValue(batch, channel, depth1, height1, width1); | |
| var x112: ${u} = getInputValue(batch, channel, depth1, height1, width2); | |
| var x121: ${u} = getInputValue(batch, channel, depth1, height2, width1); | |
| var x122: ${u} = getInputValue(batch, channel, depth1, height2, width2); | |
| var x211: ${u} = getInputValue(batch, channel, depth2, height1, width1); | |
| var x212: ${u} = getInputValue(batch, channel, depth2, height1, width2); | |
| var x221: ${u} = getInputValue(batch, channel, depth2, height2, width1); | |
| var x222: ${u} = getInputValue(batch, channel, depth2, height2, width2); | |
| var dx1: ${u} = abs(depth - ${u}(depth1)); | |
| var dx2: ${u} = abs(${u}(depth2) - depth); | |
| var dy1: ${u} = abs(height - ${u}(height1)); | |
| var dy2: ${u} = abs(${u}(height2) - height); | |
| var dz1: ${u} = abs(width - ${u}(width1)); | |
| var dz2: ${u} = abs(${u}(width2) - width); | |
| if (depth1 == depth2) { | |
| dx1 = 0.5; | |
| dx2 = 0.5; | |
| } | |
| if (height1 == height2) { | |
| dy1 = 0.5; | |
| dy2 = 0.5; | |
| } | |
| if (width1 == width2) { | |
| dz1 = 0.5; | |
| dz2 = 0.5; | |
| } | |
| return (x111 * dx2 * dy2 * dz2 + x112 * dx2 * dy2 * dz1 + x121 * dx2 * dy1 *dz2 + x122 * dx2 * dy1 * dz1 + | |
| x211 * dx1 * dy2 * dz2 + x212 * dx1 * dy2 * dz1 + x221 * dx1 * dy1 *dz2 + x222 * dx1 * dy1 * dz1); | |
| }`},Ll=(e,t,n,r,i,a)=>{let o=e.dims,s=Dl(a,t.axes,o.length),c=Ol(o,r,i,t.axes),l=r.slice();r.length===0&&(l=o.map((e,t)=>e===0?1:c[t]/e),t.keepAspectRatioPolicy!==`stretch`&&(c=kl(o,l,t)));let u=J(`output`,e.dataType,c.length),d=q(`input`,e.dataType,o.length),f=z.size(c),p=o.length===c.length&&o.every((e,t)=>e===c[t]),m=t.coordinateTransformMode===`tf_crop_and_resize`,h=t.extrapolationValue,g=d.type.value;return{name:`Resize`,shaderCache:{hint:`${t.cacheKey}|${n}|${l.length>0?t.mode===`cubic`?l:l.length:``}|${i.length>0?i:``}|${s.length>0?s:``}|${p}|${t.mode===`nearest`?o.length:o}`,inputDependencies:[`rank`]},getShaderSource:e=>` | |
| ${p?``:` | |
| ${Tl(t.coordinateTransformMode,g)}; | |
| ${(()=>{switch(t.mode){case`nearest`:return` | |
| ${Ml(d,o)}; | |
| ${El(t.nearestMode,n,g)}; | |
| ${jl(d,u,o,c,l.length,s.length,m)}; | |
| `;case`linear`:return` | |
| ${Al(u,o,c,l.length,s.length)}; | |
| ${(()=>{if(o.length===2||o.length===4)return`${Pl(d,u,o,m,h)}`;if(o.length===3||o.length===5)return`${Il(d,u,o,m,h)}`;throw Error(`Linear mode only supports input dims 2, 3, 4 and 5 are supported in linear mode.`)})()}; | |
| `;case`cubic`:return` | |
| ${(()=>{if(o.length===2||o.length===4)return`${Fl(d,u,o,c,l,s,t.cubicCoeffA,m,t.extrapolationValue,t.excludeOutside)}`;throw Error(`Cubic mode only supports input dims 2 and 4 are supported in linear mode.`)})()}; | |
| `;default:throw Error(`Invalid resize mode`)}})()}; | |
| `} | |
| ${e.registerUniform(`output_size`,`u32`).registerUniform(`scales`,`f32`,l.length).registerUniform(`roi`,`f32`,s.length).declareVariables(d,u)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.output_size`)} | |
| ${p?`output[global_idx] = input[global_idx];`:` | |
| let output_indices = ${u.offsetToIndices(`global_idx`)}; | |
| var input_indices: ${d.type.indices}; | |
| ${(()=>{switch(t.mode){case`nearest`:return`input_indices = calculateInputIndicesFromOutputIndices(output_indices); | |
| if (checkInputIndices(input_indices)) { | |
| output[global_idx] = ${d.getByIndices(`input_indices`)}; | |
| } else { | |
| output[global_idx] = ${t.extrapolationValue}; | |
| }`;case`linear`:return`output[global_idx] = ${o.length===2||o.length===4?`bilinearInterpolation`:`trilinearInterpolation`}(output_indices);`;case`cubic`:return`output[global_idx] = bicubicInterpolation(output_indices);`;default:throw Error(`Unsupported resize mode: ${t.mode}`)}})()}; | |
| `} | |
| }`,getRunData:()=>({outputs:[{dims:c,dataType:e.dataType}],dispatchGroup:{x:Math.ceil(f/64)},programUniforms:[{type:12,data:f},{type:1,data:l},{type:1,data:s},...W(o,c)]})}},Rl=e=>{let t=e.customDataBuffer;return new Uint32Array(t,t.byteOffset,1)[0]},zl=(e,t)=>{let n=[],r=[],i=[],a=Rl(e);if(t.antialias!==0)throw Error(`Only default value (0) for Antialias attribute is supported`);Cl(e.inputs,t,a,n,r,i),e.compute(Ll(e.inputs[0],t,a,n,r,i),{inputs:[0]})},Bl=e=>{let t=e.antialias,n=e.axes,r=e.coordinateTransformMode,i=e.cubicCoeffA,a=e.excludeOutside!==0,o=e.extrapolationValue,s=e.keepAspectRatioPolicy,c=e.mode,l=e.nearestMode===``?`simple`:e.nearestMode;return V({antialias:t,axes:n,coordinateTransformMode:r,cubicCoeffA:i,excludeOutside:a,extrapolationValue:o,keepAspectRatioPolicy:s,mode:c,nearestMode:l})}}),Hl,Ul,Wl,Gl=o(()=>{L(),B(),Y(),Hl=e=>{if(!e||e.length<3)throw Error(`layerNorm requires at least 3 inputs.`);let t=e[0],n=e[1],r=e[2];if(t.dataType!==n.dataType||t.dataType!==r.dataType)throw Error(`All inputs must have the same data type`);if(t.dims.length!==3&&t.dims.length!==2)throw Error(`Input must be 2D or 3D`);if(n.dims.length!==3&&n.dims.length!==2)throw Error(`Skip must be 2D or 3D`);let i=t.dims[t.dims.length-1],a=t.dims[t.dims.length-2];if(n.dims[n.dims.length-1]!==i)throw Error(`Skip must have the same hidden size as input`);if(n.dims[n.dims.length-2]!==a)throw Error(`Skip must have the same sequence length as input`);if(r.dims.length!==1)throw Error(`Gamma must be 1D`);if(r.dims[r.dims.length-1]!==i)throw Error(`Gamma must have the same hidden size as input`);if(e.length>3){let t=e[3];if(t.dims.length!==1)throw Error(`Beta must be 1D`);if(t.dims[t.dims.length-1]!==i)throw Error(`Beta must have the same hidden size as input`)}if(e.length>4){let t=e[4];if(t.dims.length!==1)throw Error(`Bias must be 1D`);if(t.dims[t.dims.length-1]!==i)throw Error(`Bias must have the same hidden size as input`)}},Ul=(e,t,n,r)=>{let i=t.simplified,a=e[0].dims,o=z.size(a),s=a,c=o,l=a.slice(-1)[0],u=r?a.slice(0,-1).concat(1):[],d=!i&&e.length>3,f=e.length>4,p=r&&n>1,m=r&&n>2,h=n>3,g=G(l),_=[{type:12,data:c},{type:12,data:g},{type:12,data:l},{type:1,data:t.epsilon}],v=t=>{let n=[{name:`output_size`,type:`u32`},{name:`components`,type:`u32`},{name:`hidden_size`,type:`u32`},{name:`epsilon`,type:`f32`}],r=[q(`x`,e[0].dataType,e[0].dims,g),q(`skip`,e[1].dataType,e[1].dims,g),q(`gamma`,e[2].dataType,e[2].dims,g)];d&&r.push(q(`beta`,e[3].dataType,e[3].dims,g)),f&&r.push(q(`bias`,e[4].dataType,e[4].dims,g)),r.push(J(`output`,e[0].dataType,s,g)),p&&r.push(J(`mean_output`,1,u)),m&&r.push(J(`inv_std_output`,1,u)),h&&r.push(J(`input_skip_bias_sum`,e[0].dataType,s,g));let a=U(e[0].dataType),o=U(1,g);return` | |
| ${t.registerUniforms(n).declareVariables(...r)} | |
| var<workgroup> sum_shared : array<${o}, 64>; | |
| var<workgroup> sum_squared_shared : array<${o}, 64>; | |
| ${t.mainStart([64,1,1])} | |
| let ix = local_id.x; | |
| let iy = global_id.x / 64; | |
| let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components; | |
| var stride = hidden_size_vectorized / 64; | |
| let offset = ix * stride + iy * hidden_size_vectorized; | |
| let offset1d = stride * ix; | |
| if (ix == 63) { | |
| stride = hidden_size_vectorized - stride * ix; | |
| } | |
| for (var i: u32 = 0; i < stride; i++) { | |
| let skip_value = skip[offset + i]; | |
| let bias_value = ${f?`bias[offset1d + i]`:a+`(0.0)`}; | |
| let input_value = x[offset + i]; | |
| let value = input_value + skip_value + bias_value; | |
| ${h?`input_skip_bias_sum[offset + i] = value;`:``} | |
| output[offset + i] = value; | |
| let f32_value = ${Tn(a,g,`value`)}; | |
| sum_shared[ix] += f32_value; | |
| sum_squared_shared[ix] += f32_value * f32_value; | |
| } | |
| workgroupBarrier(); | |
| var reduce_size : u32 = 64; | |
| for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) { | |
| reduce_size = curr_size + (reduce_size & 1); | |
| if (ix < curr_size) { | |
| sum_shared[ix] += sum_shared[ix + reduce_size]; | |
| sum_squared_shared[ix] += sum_squared_shared[ix + reduce_size]; | |
| } | |
| workgroupBarrier(); | |
| } | |
| let sum = sum_shared[0]; | |
| let square_sum = sum_squared_shared[0]; | |
| let mean = ${En(`sum`,g)} / f32(uniforms.hidden_size); | |
| let inv_std_dev = inverseSqrt(${En(`square_sum`,g)} / f32(uniforms.hidden_size) ${i?``:`- mean * mean`} + uniforms.epsilon); | |
| ${p?`mean_output[global_idx] = mean;`:``} | |
| ${m?`inv_std_output[global_idx] = inv_std_dev;`:``} | |
| for (var i: u32 = 0; i < stride; i++) { | |
| output[offset + i] = (output[offset + i] ${i?``:`- ${a}(mean)`}) * | |
| ${a}(inv_std_dev) * gamma[offset1d + i] | |
| ${d?`+ beta[offset1d + i]`:``}; | |
| } | |
| }`},y=[{dims:s,dataType:e[0].dataType}];return n>1&&y.push({dims:u,dataType:1}),n>2&&y.push({dims:u,dataType:1}),n>3&&y.push({dims:a,dataType:e[0].dataType}),{name:`SkipLayerNormalization`,shaderCache:{hint:`${g};${p};${m};${h}`,inputDependencies:e.map((e,t)=>`type`)},getShaderSource:v,getRunData:()=>({outputs:y,dispatchGroup:{x:Math.ceil(c/l)},programUniforms:_})}},Wl=(e,t)=>{Hl(e.inputs);let n=[0];e.outputCount>1&&n.push(-3),e.outputCount>2&&n.push(-3),e.outputCount>3&&n.push(3),e.compute(Ul(e.inputs,t,e.outputCount,!1),{outputs:n})}}),Kl,ql,Jl,Yl,Xl,Zl,Ql,$l,eu=o(()=>{L(),B(),H(),Y(),Kl=(e,t)=>{if(!e||e.length<1)throw Error(`too few inputs`);if(t.axes.length!==0){if(t.axes.length!==t.starts.length||t.axes.length!==t.ends.length)throw Error(`axes, starts and ends must have the same length`)}else if(t.starts.length!==t.ends.length)throw Error(`starts and ends must have the same length`);e.slice(1).forEach((t,n)=>{if(e[n+1].dataType!==6&&e[n+1].dataType!==7)throw Error(`Input ${n} must be an array of int32 or int64`)})},ql=(e,t)=>{let n=[];if(e.length>t)if(e[t].dataType===7)e[t].getBigInt64Array().forEach(e=>n.push(Number(e)));else if(e[t].dataType===6)e[t].getInt32Array().forEach(e=>n.push(Number(e)));else throw Error(`Input ${t} must be an array of int32 or int64`);return n},Jl=(e,t)=>{if(e.length>1){let t=ql(e,1),n=ql(e,2),r=ql(e,3);return r.length===0&&(r=[...Array(e[0].dims.length).keys()]),V({starts:t,ends:n,axes:r})}else return t},Yl=(e,t,n,r,i)=>{let a=e;return e<0&&(a+=n[r[t]]),i[t]<0?Math.max(0,Math.min(a,n[r[t]]-1)):Math.max(0,Math.min(a,n[r[t]]))},Xl=(e,t,n)=>`fn calculateInputIndices(output_indices: ${t.type.indices}) -> ${e.type.indices} { | |
| var input_indices: ${e.type.indices}; | |
| var carry = 0u; | |
| for (var i = ${n.length-1}; i >= 0; i--) { | |
| let input_shape_i = ${K(`uniforms.input_shape`,`i`,n.length)}; | |
| let steps_i = ${K(`uniforms.steps`,`i`,n.length)}; | |
| let signs_i = ${K(`uniforms.signs`,`i`,n.length)}; | |
| let starts_i = ${K(`uniforms.starts`,`i`,n.length)}; | |
| var output_index = ${t.indicesGet(`output_indices`,`i`)}; | |
| var input_index = output_index * steps_i + starts_i + carry; | |
| carry = input_index / input_shape_i; | |
| input_index = input_index % input_shape_i; | |
| if (signs_i < 0) { | |
| input_index = input_shape_i - input_index - 1u + starts_i; | |
| } | |
| ${e.indicesSet(`input_indices`,`i`,`input_index`)}; | |
| } | |
| return input_indices; | |
| }`,Zl=(e,t)=>{let n=e[0].dims,r=z.size(n),i=t.axes.length>0?z.normalizeAxes(t.axes,n.length):[...Array(n.length).keys()],a=ql(e,4);a.forEach(e=>e!==0||(()=>{throw Error(`step cannot be 0`)})),a.length===0&&(a=Array(i.length).fill(1));let o=t.starts.map((e,t)=>Yl(e,t,n,i,a)),s=t.ends.map((e,t)=>Yl(e,t,n,i,a));if(i.length!==o.length||i.length!==s.length)throw Error(`start, ends and axes should have the same number of elements`);if(i.length!==n.length)for(let e=0;e<n.length;++e)i.includes(e)||(o.splice(e,0,0),s.splice(e,0,n[e]),a.splice(e,0,1));let c=a.map(e=>Math.sign(e));a.forEach((e,t,n)=>{if(e<0){let r=(s[t]-o[t])/e,i=o[t];o[t]=i+r*a[t],s[t]=i,n[t]=-e}});let l=n.slice(0);i.forEach((e,t)=>{l[e]=Math.ceil((s[e]-o[e])/a[e])});let u={dims:l,dataType:e[0].dataType},d=J(`output`,e[0].dataType,l.length),f=q(`input`,e[0].dataType,e[0].dims.length),p=z.size(l),m=[{name:`outputSize`,type:`u32`},{name:`starts`,type:`u32`,length:o.length},{name:`signs`,type:`i32`,length:c.length},{name:`steps`,type:`u32`,length:a.length}],h=[{type:12,data:p},{type:12,data:o},{type:6,data:c},{type:12,data:a},...W(e[0].dims,l)];return{name:`Slice`,shaderCache:{hint:`${c.length}_${o.length}_${a.length}`,inputDependencies:[`rank`]},getShaderSource:e=>` | |
| ${e.registerUniforms(m).declareVariables(f,d)} | |
| ${Xl(f,d,n)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.outputSize`)} | |
| let output_indices = ${d.offsetToIndices(`global_idx`)}; | |
| let input_indices = calculateInputIndices(output_indices); | |
| ${d.setByOffset(`global_idx`,f.getByIndices(`input_indices`))} | |
| }`,getRunData:()=>({outputs:[u],dispatchGroup:{x:Math.ceil(r/64)},programUniforms:h})}},Ql=(e,t)=>{Kl(e.inputs,t);let n=Jl(e.inputs,t);e.compute(Zl(e.inputs,n),{inputs:[0]})},$l=e=>{let t=e.starts,n=e.ends,r=e.axes;return V({starts:t,ends:n,axes:r})}}),tu,nu,ru,iu,au=o(()=>{L(),B(),H(),Vn(),Y(),tu=e=>{if(!e||e.length!==1)throw Error(`Softmax op requires 1 input.`)},nu=(e,t)=>{let n=e.inputs[0],r=n.dims,i=z.size(r),a=r.length,o=z.normalizeAxis(t.axis,a),s=o<r.length-1,c,l=[];s?(l=Array.from({length:a},(e,t)=>t),l[o]=a-1,l[a-1]=o,c=e.compute(Rn(n,l),{inputs:[n],outputs:[-1]})[0]):c=n;let u=c.dims,d=u[a-1],f=i/d,p=G(d),m=d/p,h=64;f===1&&(h=256);let g=(e,t)=>t===4?`max(max(${e}.x, ${e}.y), max(${e}.z, ${e}.w))`:t===2?`max(${e}.x, ${e}.y)`:t===3?`max(max(${e}.x, ${e}.y), ${e}.z)`:e,_=q(`x`,c.dataType,c.dims,p),v=J(`result`,c.dataType,c.dims,p),y=_.type.value,b=U(c.dataType)===`f32`?`var threadMax = ${y}(-3.4028234663852886e+38f);`:`var threadMax = ${y}(-65504.0h);`,x=e.compute({name:`Softmax`,shaderCache:{hint:`${p};${h}`,inputDependencies:[`type`]},getRunData:()=>({outputs:[{dims:u,dataType:c.dataType}],dispatchGroup:{x:f},programUniforms:[{type:6,data:m}]}),getShaderSource:e=>` | |
| var<workgroup> rowMaxShared : ${y}; | |
| var<workgroup> rowSumShared : ${y}; | |
| var<workgroup> threadShared : array<${y}, ${h}>; | |
| fn getValue(row: i32, col: i32, row_stride: i32) -> ${y} { | |
| let index = row * row_stride + col; | |
| return x[index]; | |
| } | |
| fn setValue(row: i32, col: i32, row_stride: i32, value: ${y}) { | |
| let index = row * row_stride + col; | |
| result[index] = value; | |
| } | |
| ${e.registerUniform(`packedCols`,`i32`).declareVariables(_,v)} | |
| ${e.mainStart(h)} | |
| let gindex = i32(global_idx); | |
| let lindex = i32(local_idx); | |
| const wg = ${h}; | |
| let row = gindex / wg; | |
| let cols = uniforms.packedCols; | |
| let row_stride : i32 = uniforms.packedCols; | |
| // find the rows max | |
| ${b} | |
| for (var col = lindex; col < cols; col += wg) { | |
| let value = getValue(row, col, row_stride); | |
| threadMax = max(threadMax, value); | |
| } | |
| if (lindex < cols) { | |
| threadShared[lindex] = threadMax; | |
| } | |
| workgroupBarrier(); | |
| var reduceSize = min(cols, wg); | |
| for (var currSize = reduceSize >> 1; currSize > 0; currSize = reduceSize >> 1) { | |
| reduceSize = currSize + (reduceSize & 1); | |
| if (lindex < currSize) { | |
| threadShared[lindex] = max(threadShared[lindex], threadShared[lindex + reduceSize]); | |
| } | |
| workgroupBarrier(); | |
| } | |
| if (lindex == 0) { | |
| rowMaxShared = ${y}(${g(`threadShared[0]`,p)}); | |
| } | |
| workgroupBarrier(); | |
| // find the rows sum | |
| var threadSum = ${y}(0.0); | |
| for (var col = lindex; col < cols; col += wg) { | |
| let subExp = exp(getValue(row, col, row_stride) - rowMaxShared); | |
| threadSum += subExp; | |
| } | |
| threadShared[lindex] = threadSum; | |
| workgroupBarrier(); | |
| for (var currSize = wg >> 1; currSize > 0; currSize = currSize >> 1) { | |
| if (lindex < currSize) { | |
| threadShared[lindex] = threadShared[lindex] + threadShared[lindex + currSize]; | |
| } | |
| workgroupBarrier(); | |
| } | |
| if (lindex == 0) { | |
| rowSumShared = ${y}(${En(`threadShared[0]`,p)}); | |
| } | |
| workgroupBarrier(); | |
| // calculate final value for each element in the row | |
| for (var col = lindex; col < cols; col += wg) { | |
| var value = exp(getValue(row, col, row_stride) - rowMaxShared) / rowSumShared; | |
| // max operation protects against NaN since all values should be >=0 | |
| value = max(value, ${y}(0.0)); | |
| setValue(row, col, row_stride, value); | |
| } | |
| }`},{inputs:[c],outputs:[s?-1:0]})[0];s&&e.compute(Rn(x,l),{inputs:[x]})},ru=(e,t)=>{tu(e.inputs),nu(e,t)},iu=e=>V({axis:e.axis})}),ou,su,cu,lu,uu,du=o(()=>{L(),B(),Y(),ou=e=>Array.from(e.getBigInt64Array(),Number),su=e=>{if(!e||e.length!==2)throw Error(`Tile requires 2 inputs.`);if(e[0].dataType!==1&&e[0].dataType!==10&&e[0].dataType!==6&&e[0].dataType!==12)throw Error(`Tile only support float, float16, int32, and uint32 data types`);if(e[1].dataType!==7)throw Error("Tile `repeats` input should be of int64 data type");if(e[1].dims.length!==1)throw Error("Tile `repeats` input should be 1-D");if(ou(e[1]).length!==e[0].dims.length)throw Error("Tile `repeats` input should have same number of elements as rank of input data tensor")},cu=(e,t)=>{let n=[];for(let r=0;r<e.length;++r)n.push(e[r]*t[r]);return n},lu=(e,t)=>{let n=e[0].dims,r=t??ou(e[1]),i=cu(n,r),a=z.size(i),o=e[0].dataType,s=q(`input`,o,n.length),c=J(`output`,o,i.length);return{name:`Tile`,shaderCache:{hint:`${r}`,inputDependencies:[`rank`]},getRunData:()=>({outputs:[{dims:i,dataType:e[0].dataType}],dispatchGroup:{x:Math.ceil(a/64)},programUniforms:[{type:12,data:a},...W(e[0].dims,i)]}),getShaderSource:e=>` | |
| const inputShape = ${s.indices(...n)}; | |
| ${e.registerUniform(`output_size`,`u32`).declareVariables(s,c)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.output_size`)} | |
| let output_indices = ${c.offsetToIndices(`global_idx`)}; | |
| var input_indices: ${s.type.indices}; | |
| for (var i = 0; i < ${n.length}; i++) { | |
| let input_dim_i = ${s.indicesGet(`uniforms.input_shape`,`i`)}; | |
| let input_dim_value = ${c.indicesGet(`output_indices`,`i`)} % input_dim_i; | |
| ${s.indicesSet(`input_indices`,`i`,`input_dim_value`)} | |
| } | |
| ${c.setByOffset(`global_idx`,s.getByIndices(`input_indices`))} | |
| }`}},uu=e=>{su(e.inputs),e.compute(lu(e.inputs),{inputs:[0]})}}),fu,pu,mu,hu=o(()=>{L(),B(),Y(),fu=(e,t,n,r,i)=>{let a=J(`output_data`,i,n.length,4),o=q(`a_data`,t[1].dataType,t[1].dims.length,4),s=q(`b_data`,t[2].dataType,t[2].dims.length,4),c=q(`c_data`,t[0].dataType,t[0].dims.length,4),l,u=(e,t,n)=>`select(${t}, ${e}, ${n})`;if(!r)l=a.setByOffset(`global_idx`,u(o.getByOffset(`global_idx`),s.getByOffset(`global_idx`),c.getByOffset(`global_idx`)));else{let e=(e,t,n=``)=>{let r=`a_data[index_a${t}][component_a${t}]`,i=`b_data[index_b${t}][component_b${t}]`,l=`bool(c_data[index_c${t}] & (0xffu << (component_c${t} * 8)))`;return` | |
| let output_indices${t} = ${a.offsetToIndices(`global_idx * 4u + ${t}u`)}; | |
| let offset_a${t} = ${o.broadcastedIndicesToOffset(`output_indices${t}`,a)}; | |
| let offset_b${t} = ${s.broadcastedIndicesToOffset(`output_indices${t}`,a)}; | |
| let offset_c${t} = ${c.broadcastedIndicesToOffset(`output_indices${t}`,a)}; | |
| let index_a${t} = offset_a${t} / 4u; | |
| let index_b${t} = offset_b${t} / 4u; | |
| let index_c${t} = offset_c${t} / 4u; | |
| let component_a${t} = offset_a${t} % 4u; | |
| let component_b${t} = offset_b${t} % 4u; | |
| let component_c${t} = offset_c${t} % 4u; | |
| ${e}[${t}] = ${n}(${u(r,i,l)}); | |
| `};l=i===9?` | |
| var data = vec4<u32>(0); | |
| ${e(`data`,0,`u32`)} | |
| ${e(`data`,1,`u32`)} | |
| ${e(`data`,2,`u32`)} | |
| ${e(`data`,3,`u32`)} | |
| output_data[global_idx] = dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(data));`:` | |
| ${e(`output_data[global_idx]`,0)} | |
| ${e(`output_data[global_idx]`,1)} | |
| ${e(`output_data[global_idx]`,2)} | |
| ${e(`output_data[global_idx]`,3)} | |
| `}return` | |
| ${e.registerUniform(`vec_size`,`u32`).declareVariables(c,o,s,a)} | |
| ${e.mainStart()} | |
| ${e.guardAgainstOutOfBoundsWorkgroupSizes(`uniforms.vec_size`)} | |
| ${l} | |
| }`},pu=e=>{let t=e[1].dims,n=e[2].dims,r=e[0].dims,i=e[1].dataType,a=!(z.areEqual(t,n)&&z.areEqual(n,r)),o=t,s=z.size(t);if(a){let e=zt.calcShape(zt.calcShape(t,n,!1),r,!1);if(!e)throw Error(`Can't perform where op on the given tensors`);o=e,s=z.size(o)}let c=Math.ceil(s/4);return{name:`Where`,shaderCache:{inputDependencies:[`rank`,`rank`,`rank`]},getShaderSource:t=>fu(t,e,o,a,i),getRunData:()=>({outputs:[{dims:o,dataType:i}],dispatchGroup:{x:Math.ceil(s/64/4)},programUniforms:[{type:12,data:c},...W(r,t,n,o)]})}},mu=e=>{e.compute(pu(e.inputs))}}),gu,_u=o(()=>{Rr(),qr(),Qr(),ni(),Zi(),da(),va(),fo(),wo(),Oo(),Po(),Ko(),Qo(),ts(),os(),us(),hs(),bs(),Ts(),Bs(),fc(),_c(),xc(),wc(),Ac(),Js(),Bc(),ol(),dl(),hl(),bl(),Nr(),Vl(),oc(),Gl(),eu(),au(),nc(),du(),Vn(),qi(),hu(),gu=new Map([[`Abs`,[ii]],[`Acos`,[ai]],[`Acosh`,[oi]],[`Add`,[ta]],[`ArgMax`,[Ir,Lr]],[`ArgMin`,[Fr,Lr]],[`Asin`,[si]],[`Asinh`,[ci]],[`Atan`,[li]],[`Atanh`,[ui]],[`Attention`,[Kr]],[`AveragePool`,[Zc,Xc]],[`BatchNormalization`,[Zr]],[`BiasAdd`,[ti]],[`BiasSplitGelu`,[Xi]],[`Cast`,[fi,di]],[`Ceil`,[hi]],[`Clip`,[mi]],[`Concat`,[ga,_a]],[`Conv`,[uo,oo]],[`ConvTranspose`,[Co,yo]],[`Cos`,[gi]],[`Cosh`,[_i]],[`CumSum`,[Eo,Do]],[`DepthToSpace`,[Mo,No]],[`DequantizeLinear`,[ll,ul]],[`Div`,[na]],[`Einsum`,[Wo,Go]],[`Elu`,[yi,vi]],[`Equal`,[ra]],[`Erf`,[xi]],[`Exp`,[Si]],[`Expand`,[Zo]],[`FastGelu`,[es]],[`Floor`,[Ci]],[`FusedConv`,[uo,oo]],[`Gather`,[as,is]],[`GatherElements`,[ys,vs]],[`GatherBlockQuantized`,[ps,ms]],[`GatherND`,[cs,ls]],[`Gelu`,[wi]],[`Gemm`,[ws,Cs]],[`GlobalAveragePool`,[el,$c]],[`GlobalMaxPool`,[al,il]],[`Greater`,[sa]],[`GreaterOrEqual`,[la]],[`GridSample`,[Rs,zs]],[`GroupQueryAttention`,[dc]],[`HardSigmoid`,[Mi,ji]],[`InstanceNormalization`,[gc]],[`LayerNormalization`,[bc]],[`LeakyRelu`,[Ti,vi]],[`Less`,[ca]],[`LessOrEqual`,[ua]],[`Log`,[Ui]],[`MatMul`,[Cc]],[`MatMulNBits`,[Oc,kc]],[`MaxPool`,[nl,rl]],[`Mul`,[ia]],[`MultiHeadAttention`,[qs,Us]],[`Neg`,[Di]],[`Not`,[Ei]],[`Pad`,[zc]],[`Pow`,[aa]],[`QuickGelu`,[Ki,vi]],[`Range`,[ml]],[`Reciprocal`,[Oi]],[`ReduceMin`,[Or]],[`ReduceMean`,[Q]],[`ReduceMax`,[Dr]],[`ReduceSum`,[Ar]],[`ReduceProd`,[kr]],[`ReduceL1`,[wr]],[`ReduceL2`,[Tr]],[`ReduceLogSum`,[Mr]],[`ReduceLogSumExp`,[Er]],[`ReduceSumSquare`,[jr]],[`Relu`,[ki]],[`Resize`,[zl,Bl]],[`RotaryEmbedding`,[ac]],[`ScatterND`,[yl,vl]],[`Sigmoid`,[Ai]],[`Sin`,[Ni]],[`Sinh`,[Pi]],[`Slice`,[Ql,$l]],[`SkipLayerNormalization`,[Wl]],[`Split`,[ec,tc]],[`Sqrt`,[Fi]],[`Softmax`,[ru,iu]],[`Sub`,[oa]],[`Tan`,[Ii]],[`Tanh`,[Ri]],[`ThresholdedRelu`,[Hi,vi]],[`Tile`,[uu]],[`Transpose`,[zn,Bn]],[`Where`,[mu]]])}),vu,yu=o(()=>{N(),Lt(),Y(),vu=class{constructor(e){this.backend=e,this.repo=new Map,this.attributesBound=!1}getArtifact(e){return this.repo.get(e)}setArtifact(e,t){this.repo.set(e,t)}run(e,t,n,r,i){_e(e.programInfo.name);let a=this.backend.device,o=this.backend.getComputePassEncoder();this.backend.writeTimestamp(this.backend.pendingDispatchNumber*2);let s=[];for(let e of t)s.push({binding:s.length,resource:{buffer:e.buffer}});for(let e of n)s.push({binding:s.length,resource:{buffer:e.buffer}});i&&s.push({binding:s.length,resource:i});let c=a.createBindGroup({layout:e.computePipeline.getBindGroupLayout(0),entries:s,label:e.programInfo.name});if(this.backend.sessionStatus===`capturing`){let t={kernelId:this.backend.currentKernelId,computePipeline:e.computePipeline,bindGroup:c,dispatchGroup:r};this.backend.capturedCommandList.get(this.backend.currentSessionId).push(t)}o.setPipeline(e.computePipeline),o.setBindGroup(0,c),o.dispatchWorkgroups(...r),this.backend.writeTimestamp(this.backend.pendingDispatchNumber*2+1),this.backend.pendingDispatchNumber++,(this.backend.pendingDispatchNumber>=this.backend.maxDispatchNumber||this.backend.queryType===`at-passes`)&&this.backend.endComputePass(),this.backend.pendingDispatchNumber>=this.backend.maxDispatchNumber&&this.backend.flush(),j(e.programInfo.name)}dispose(){}build(e,t){_e(e.name);let n=this.backend.device,r=[];[{feature:`shader-f16`,extension:`f16`},{feature:`subgroups`,extension:`subgroups`}].forEach(e=>{n.features.has(e.feature)&&r.push(`enable ${e.extension};`)});let i=jn(t,this.backend.device.limits),a=e.getShaderSource(i),o=`${r.join(` | |
| `)} | |
| ${i.additionalImplementations} | |
| ${a}`,s=n.createShaderModule({code:o,label:e.name});R(`verbose`,()=>`[WebGPU] ${e.name} shader code: ${o}`);let c=n.createComputePipeline({compute:{module:s,entryPoint:`main`},layout:`auto`,label:e.name});return j(e.name),{programInfo:e,computePipeline:c,uniformVariablesInfo:i.variablesInfo}}normalizeDispatchGroupSize(e){let t=typeof e==`number`?e:e.x,n=typeof e==`number`?1:e.y||1,r=typeof e==`number`?1:e.z||1,i=this.backend.device.limits.maxComputeWorkgroupsPerDimension;if(t<=i&&n<=i&&r<=i)return[t,n,r];let a=t*n*r,o=Math.ceil(Math.sqrt(a));if(o>i){if(o=Math.ceil(Math.cbrt(a)),o>i)throw Error(`Total dispatch size exceeds WebGPU maximum.`);return[o,o,o]}else return[o,o,1]}}}),bu={};s(bu,{WebGpuBackend:()=>wu});var xu,Su,Cu,wu,Tu=o(()=>{N(),L(),Lt(),Gt(),yn(),_u(),yu(),xu=(e,t)=>{if(t.length!==e.length)throw Error(`inputDependencies length ${t.length} is not equal to inputTensors length ${e.length}.`);let n=[];for(let r=0;r<e.length;++r){let i=e[r].dataType;switch(t[r]){case`none`:n.push(``);break;case`type`:n.push(`${i}`);break;case`rank`:{let t=e[r].dims.length;n.push(`${i};${t}`);break}case`dims`:{let t=e[r].dims.join(`,`);n.push(`${i};${t}`);break}default:throw Error(`unsupported input dependency: ${t[r]}`)}}return n.join(`|`)},Su=(e,t,n)=>{let r=e.name;return e.shaderCache?.hint&&(r+=`[`+e.shaderCache.hint+`]`),r+=`:`+n+`:${xu(t,e.shaderCache?.inputDependencies??Array(t.length).fill(`dims`))}`,r},Cu=class{constructor(e){e&&(this.architecture=e.architecture,this.vendor=e.vendor)}isArchitecture(e){return this.architecture===e}isVendor(e){return this.vendor===e}},wu=class{constructor(){this.currentSessionId=null,this.currentKernelId=null,this.commandEncoder=null,this.computePassEncoder=null,this.maxDispatchNumber=16,this.pendingDispatchNumber=0,this.pendingKernels=[],this.pendingQueries=new Map,this.sessionStatus=`default`,this.capturedCommandList=new Map,this.capturedPendingKernels=new Map,this.sessionExternalDataMapping=new Map}get currentKernelCustomData(){if(this.currentKernelId===null)throw Error(`currentKernelCustomData(): currentKernelId is null. (should not happen)`);let e=this.kernelCustomData.get(this.currentKernelId);return e||(e={},this.kernelCustomData.set(this.currentKernelId,e)),e}async initialize(e,t){this.env=e;let n=[],r={requiredLimits:{maxComputeWorkgroupStorageSize:t.limits.maxComputeWorkgroupStorageSize,maxComputeWorkgroupsPerDimension:t.limits.maxComputeWorkgroupsPerDimension,maxStorageBufferBindingSize:t.limits.maxStorageBufferBindingSize,maxBufferSize:t.limits.maxBufferSize,maxComputeInvocationsPerWorkgroup:t.limits.maxComputeInvocationsPerWorkgroup,maxComputeWorkgroupSizeX:t.limits.maxComputeWorkgroupSizeX,maxComputeWorkgroupSizeY:t.limits.maxComputeWorkgroupSizeY,maxComputeWorkgroupSizeZ:t.limits.maxComputeWorkgroupSizeZ},requiredFeatures:n},i=e=>t.features.has(e)&&n.push(e)&&!0;i(`chromium-experimental-timestamp-query-inside-passes`)||i(`timestamp-query`),i(`shader-f16`),i(`subgroups`),this.device=await t.requestDevice(r),this.adapterInfo=new Cu(t.info||await t.requestAdapterInfo()),this.gpuDataManager=vn(this),this.programManager=new vu(this),this.kernels=new Map,this.kernelPersistentData=new Map,this.kernelCustomData=new Map,Ft(e.logLevel,!!e.debug),this.device.onuncapturederror=e=>{e.error instanceof GPUValidationError&&console.error(`An uncaught WebGPU validation error was raised: ${e.error.message}`)},Object.defineProperty(this.env.webgpu,`device`,{value:this.device,writable:!1,enumerable:!0,configurable:!1}),Object.defineProperty(this.env.webgpu,`adapter`,{value:t,writable:!1,enumerable:!0,configurable:!1}),this.setQueryType()}dispose(){typeof this.querySet<`u`&&this.querySet.destroy(),this.gpuDataManager.dispose()}getCommandEncoder(){return this.commandEncoder||=this.device.createCommandEncoder(),this.commandEncoder}getComputePassEncoder(){if(!this.computePassEncoder){let e=this.getCommandEncoder(),t={};this.queryType===`at-passes`&&(t.timestampWrites={querySet:this.querySet,beginningOfPassWriteIndex:this.pendingDispatchNumber*2,endOfPassWriteIndex:this.pendingDispatchNumber*2+1}),this.computePassEncoder=e.beginComputePass(t)}return this.computePassEncoder}endComputePass(){this.computePassEncoder&&=(this.computePassEncoder.end(),null)}flush(){if(!this.commandEncoder)return;_e(),this.endComputePass();let e;this.queryType!==`none`&&(this.commandEncoder.resolveQuerySet(this.querySet,0,this.pendingDispatchNumber*2,this.queryResolveBuffer,0),e=this.device.createBuffer({size:this.pendingDispatchNumber*2*8,usage:GPUBufferUsage.MAP_READ|GPUBufferUsage.COPY_DST}),this.pendingQueries.set(e,this.pendingKernels),this.pendingKernels=[],this.commandEncoder.copyBufferToBuffer(this.queryResolveBuffer,0,e,0,this.pendingDispatchNumber*2*8)),this.device.queue.submit([this.commandEncoder.finish()]),this.gpuDataManager.refreshPendingBuffers(),this.commandEncoder=null,this.pendingDispatchNumber=0,this.queryType!==`none`&&e.mapAsync(GPUMapMode.READ).then(()=>{let t=new BigUint64Array(e.getMappedRange()),n=this.pendingQueries.get(e);for(let e=0;e<t.length/2;e++){let r=n[e],i=r.kernelId,a=this.kernels.get(i),o=a.kernelType,s=a.kernelName,c=r.programName,l=r.inputTensorViews,u=r.outputTensorViews,d=t[e*2],f=t[e*2+1];typeof this.queryTimeBase>`u`&&(this.queryTimeBase=d);let p=Number(d-this.queryTimeBase),m=Number(f-this.queryTimeBase);if(!Number.isSafeInteger(p)||!Number.isSafeInteger(m))throw RangeError(`incorrect timestamp range`);if(this.env.webgpu.profiling?.ondata)this.env.webgpu.profiling.ondata({version:1,inputsMetadata:l.map(e=>({dims:e.dims,dataType:St(e.dataType)})),outputsMetadata:u.map(e=>({dims:e.dims,dataType:St(e.dataType)})),kernelId:i,kernelType:o,kernelName:s,programName:c,startTime:p,endTime:m});else{let e=``;l.forEach((t,n)=>{e+=`input[${n}]: [${t.dims}] | ${St(t.dataType)}, `});let t=``;u.forEach((e,n)=>{t+=`output[${n}]: [${e.dims}] | ${St(e.dataType)}, `}),console.log(`[profiling] kernel "${i}|${o}|${s}|${c}" ${e}${t}start time: ${p} ns, execution time: ${m-p} ns`)}he(`GPU`,`${c}::${d}::${f}`)}e.unmap(),this.pendingQueries.delete(e)}),j()}run(e,t,n,r,i,a){_e(e.name);let o=[];for(let e=0;e<t.length;++e){let n=t[e].data;if(n===0)continue;let r=this.gpuDataManager.get(n);if(!r)throw Error(`no GPU data for input: ${n}`);o.push(r)}let{outputs:s,dispatchGroup:c,programUniforms:l}=e.getRunData(t),u=n.length===0?s.map((e,t)=>t):n;if(u.length!==s.length)throw Error(`Output size ${u.length} must be equal to ${s.length}.`);let d=[],f=[];for(let e=0;e<s.length;++e){if(!Number.isInteger(u[e])||u[e]<-3||u[e]>=a)throw Error(`Invalid output index: ${u[e]}`);if(u[e]===-3)continue;let t=u[e]===-1,n=u[e]===-2,o=t||n?i(s[e].dataType,s[e].dims):r(u[e],s[e].dataType,s[e].dims);if(d.push(o),o.data===0)continue;let c=this.gpuDataManager.get(o.data);if(!c)throw Error(`no GPU data for output: ${o.data}`);if(t&&this.temporaryData.push(c),n){let e=this.kernelPersistentData.get(this.currentKernelId);e||(e=[],this.kernelPersistentData.set(this.currentKernelId,e)),e.push(c)}f.push(c)}if(o.length!==t.length||f.length!==d.length){if(f.length===0)return j(e.name),d;throw Error(`Program ${e.name} has zero-sized tensor(s) in inputs or outputs. This is not supported now.`)}let p;if(l){let e=0,t=[];l.forEach(n=>{let r=typeof n.data==`number`?[n.data]:n.data;if(r.length===0)return;let i=n.type===10?2:4,a,o;n.type===10?(o=r.length>4?16:r.length>2?8:r.length*i,a=r.length>4?16:i*r.length):(o=r.length<=2?r.length*i:16,a=16),e=Math.ceil(e/o)*o,t.push(e);let s=n.type===10?8:4;e+=r.length>4?Math.ceil(r.length/s)*a:r.length*i}),e=Math.ceil(e/16)*16;let n=new ArrayBuffer(e);l.forEach((e,r)=>{let i=t[r],a=typeof e.data==`number`?[e.data]:e.data;if(e.type===6)new Int32Array(n,i,a.length).set(a);else if(e.type===12)new Uint32Array(n,i,a.length).set(a);else if(e.type===10)new Uint16Array(n,i,a.length).set(a);else if(e.type===1)new Float32Array(n,i,a.length).set(a);else throw Error(`Unsupported uniform type: ${St(e.type)}`)});let r=this.gpuDataManager.create(e,GPUBufferUsage.COPY_DST|GPUBufferUsage.UNIFORM);this.device.queue.writeBuffer(r.buffer,0,n,0,e),this.gpuDataManager.release(r.id),p={offset:0,size:e,buffer:r.buffer}}let m=this.programManager.normalizeDispatchGroupSize(c),h=m[1]===1&&m[2]===1,g=Su(e,t,h),_=this.programManager.getArtifact(g);if(_||(_=this.programManager.build(e,m),this.programManager.setArtifact(g,_),R(`info`,()=>`[artifact] key: ${g}, programName: ${e.name}`)),l&&_.uniformVariablesInfo){if(l.length!==_.uniformVariablesInfo.length)throw Error(`Uniform variables count mismatch: expect ${_.uniformVariablesInfo.length}, got ${l.length} in program "${_.programInfo.name}".`);for(let e=0;e<l.length;e++){let t=l[e],n=t.type,r=typeof t.data==`number`?1:t.data.length,[i,a]=_.uniformVariablesInfo[e];if(n!==i||r!==a)throw Error(`Uniform variable ${e} mismatch: expect type ${i} with size ${a}, got type ${n} with size ${r} in program "${_.programInfo.name}".`)}}if(R(`info`,()=>`[ProgramManager] run "${e.name}" (key=${g}) with ${m[0]}x${m[1]}x${m[2]}`),this.queryType!==`none`||this.sessionStatus===`capturing`){let e={kernelId:this.currentKernelId,programName:_.programInfo.name,inputTensorViews:t,outputTensorViews:d};this.pendingKernels.push(e),this.sessionStatus===`capturing`&&this.capturedPendingKernels.get(this.currentSessionId).push(e)}return this.programManager.run(_,o,f,m,p),j(e.name),d}upload(e,t){this.gpuDataManager.upload(e,t)}memcpy(e,t){this.gpuDataManager.memcpy(e,t)}async download(e,t){await this.gpuDataManager.download(e,t)}alloc(e){return this.gpuDataManager.create(e).id}free(e){return this.gpuDataManager.release(e)}createKernel(e,t,n,r){let i=gu.get(e);if(!i)throw Error(`kernel not implemented: ${e}`);let a={kernelType:e,kernelName:r,kernelEntry:i[0],attributes:[i[1],n]};this.kernels.set(t,a)}releaseKernel(e){let t=this.kernelPersistentData.get(e);if(t){for(let e of t)this.gpuDataManager.release(e.id);this.kernelPersistentData.delete(e)}this.kernelCustomData.delete(e),this.kernels.delete(e)}computeKernel(e,t,n){let r=this.kernels.get(e);if(!r)throw Error(`kernel not created: ${e}`);let i=r.kernelType,a=r.kernelName,o=r.kernelEntry,s=r.attributes;if(this.currentKernelId!==null)throw Error(`kernel "[${i}] ${a}" is not allowed to be called recursively`);this.currentKernelId=e,s[0]&&=(s[1]=s[0](s[1]),void 0),R(`info`,()=>`[WebGPU] Start to run kernel "[${i}] ${a}"...`);let c=this.env.debug;this.temporaryData=[];try{return c&&this.device.pushErrorScope(`validation`),o(t,s[1]),0}catch(e){return n.push(Promise.resolve(`[WebGPU] Kernel "[${i}] ${a}" failed. ${e}`)),1}finally{c&&n.push(this.device.popErrorScope().then(e=>e?`GPU validation error for kernel "[${i}] ${a}": ${e.message}`:null));for(let e of this.temporaryData)this.gpuDataManager.release(e.id);this.temporaryData=[],this.currentKernelId=null}}registerBuffer(e,t,n,r){let i=this.sessionExternalDataMapping.get(e);i||(i=new Map,this.sessionExternalDataMapping.set(e,i));let a=i.get(t),o=this.gpuDataManager.registerExternalBuffer(n,r,a);return i.set(t,[o,n]),o}unregisterBuffers(e){let t=this.sessionExternalDataMapping.get(e);t&&(t.forEach(e=>this.gpuDataManager.unregisterExternalBuffer(e[0])),this.sessionExternalDataMapping.delete(e))}getBuffer(e){let t=this.gpuDataManager.get(e);if(!t)throw Error(`no GPU data for buffer: ${e}`);return t.buffer}createDownloader(e,t,n){return async()=>{let r=await gn(this,e,t);return Wt(r.buffer,n)}}writeTimestamp(e){this.queryType===`inside-passes`&&this.computePassEncoder.writeTimestamp(this.querySet,e)}setQueryType(){this.queryType=`none`,(this.env.webgpu.profiling?.mode===`default`||(typeof this.env.trace>`u`?this.env.wasm.trace:this.env.trace))&&(this.device.features.has(`chromium-experimental-timestamp-query-inside-passes`)?this.queryType=`inside-passes`:this.device.features.has(`timestamp-query`)&&(this.queryType=`at-passes`),this.queryType!==`none`&&typeof this.querySet>`u`&&(this.querySet=this.device.createQuerySet({type:`timestamp`,count:this.maxDispatchNumber*2}),this.queryResolveBuffer=this.device.createBuffer({size:this.maxDispatchNumber*2*8,usage:GPUBufferUsage.COPY_SRC|GPUBufferUsage.QUERY_RESOLVE})))}captureBegin(){R(`info`,`captureBegin`),this.capturedCommandList.get(this.currentSessionId)||this.capturedCommandList.set(this.currentSessionId,[]),this.capturedPendingKernels.get(this.currentSessionId)||this.capturedPendingKernels.set(this.currentSessionId,[]),this.flush(),this.sessionStatus=`capturing`}captureEnd(){R(`info`,`captureEnd`),this.flush(),this.sessionStatus=`default`}replay(){R(`info`,`replay`),this.sessionStatus=`replaying`;let e=this.capturedCommandList.get(this.currentSessionId),t=this.capturedPendingKernels.get(this.currentSessionId),n=e.length;this.pendingKernels=[];for(let r=0;r<n;r++){let n=this.getComputePassEncoder(),i=e[r];this.writeTimestamp(this.pendingDispatchNumber*2),n.setPipeline(i.computePipeline),n.setBindGroup(0,i.bindGroup),n.dispatchWorkgroups(...i.dispatchGroup),this.writeTimestamp(this.pendingDispatchNumber*2+1),this.pendingDispatchNumber++,this.queryType!==`none`&&this.pendingKernels.push(t[r]),(this.pendingDispatchNumber>=this.maxDispatchNumber||this.queryType===`at-passes`)&&this.endComputePass(),this.pendingDispatchNumber>=this.maxDispatchNumber&&this.flush()}this.flush(),this.sessionStatus=`default`}onCreateSession(){this.gpuDataManager.onCreateSession()}onReleaseSession(e){this.unregisterBuffers(e),this.capturedCommandList.has(e)&&this.capturedCommandList.delete(e),this.capturedPendingKernels.has(e)&&this.capturedPendingKernels.delete(e),this.gpuDataManager.onReleaseSession(e)}onRunStart(e){this.currentSessionId=e,this.setQueryType()}}}),Eu={};s(Eu,{init:()=>ku});var Du,Ou,ku,Au=o(()=>{L(),Lt(),B(),cn(),Du=class e{constructor(e,t,n,r){this.module=e,this.dataType=t,this.data=n,this.dims=r}getFloat32Array(){if(this.dataType!==1)throw Error(`Invalid data type`);let e=z.size(this.dims);return e===0?new Float32Array:new Float32Array(this.module.HEAP8.buffer,this.data,e)}getBigInt64Array(){if(this.dataType!==7)throw Error(`Invalid data type`);let e=z.size(this.dims);return e===0?new BigInt64Array:new BigInt64Array(this.module.HEAP8.buffer,this.data,e)}getInt32Array(){if(this.dataType!==6)throw Error(`Invalid data type`);let e=z.size(this.dims);return e===0?new Int32Array:new Int32Array(this.module.HEAP8.buffer,this.data,e)}getUint16Array(){if(this.dataType!==10&&this.dataType!==4)throw Error(`Invalid data type`);let e=z.size(this.dims);return e===0?new Uint16Array:new Uint16Array(this.module.HEAP8.buffer,this.data,e)}reshape(t){if(z.size(t)!==z.size(this.dims))throw Error(`Invalid new shape`);return new e(this.module,this.dataType,this.data,t)}},Ou=class{constructor(e,t,n){this.module=e,this.backend=t,this.customDataOffset=0,this.customDataSize=0,this.adapterInfo=t.adapterInfo;let r=e.PTR_SIZE,i=n/e.PTR_SIZE,a=r===4?`i32`:`i64`;this.opKernelContext=Number(e.getValue(r*i++,a));let o=Number(e.getValue(r*i++,a));this.outputCount=Number(e.getValue(r*i++,a)),this.customDataOffset=Number(e.getValue(r*i++,`*`)),this.customDataSize=Number(e.getValue(r*i++,a));let s=[];for(let t=0;t<o;t++){let t=Number(e.getValue(r*i++,a)),n=Number(e.getValue(r*i++,`*`)),o=Number(e.getValue(r*i++,a)),c=[];for(let t=0;t<o;t++)c.push(Number(e.getValue(r*i++,a)));s.push(new Du(e,t,n,c))}this.inputs=s}get kernelCustomData(){return this.backend.currentKernelCustomData}get customDataBuffer(){return this.module.HEAPU8.subarray(this.customDataOffset,this.customDataOffset+this.customDataSize)}compute(e,t){let n=t?.inputs?.map(e=>typeof e==`number`?this.inputs[e]:e)??this.inputs,r=t?.outputs??[];return this.backend.run(e,n,r,(e,t,n)=>new Du(this.module,t,this.output(e,n),n),(e,t)=>{let n=Ct(e,t);if(!n)throw Error(`Unsupported data type: ${e}`);let r=n>0?this.backend.gpuDataManager.create(n).id:0;return new Du(this.module,e,r,t)},this.outputCount)}output(e,t){let n=this.module.stackSave();try{let n=this.module.PTR_SIZE,r=n===4?`i32`:`i64`,i=this.module.stackAlloc((1+t.length)*n);this.module.setValue(i,t.length,r);for(let e=0;e<t.length;e++)this.module.setValue(i+n*(e+1),t[e],r);return this.module._JsepOutput(this.opKernelContext,e,i)}catch(n){throw Error(`Failed to generate kernel's output[${e}] with dims [${t}]. If you are running with pre-allocated output, please make sure the output type/dims are correct. Error: ${n}`)}finally{this.module.stackRestore(n)}}},ku=async(e,t,n,r)=>{let i=t.jsepInit;if(!i)throw Error(`Failed to initialize JSEP. The WebAssembly module is not built with JSEP support.`);if(e===`webgpu`){let e=(Tu(),l(bu)).WebGpuBackend,a=new e;await a.initialize(n,r),i(`webgpu`,[a,e=>a.alloc(Number(e)),e=>a.free(e),(e,n,r,i=!1)=>{if(i)R(`verbose`,()=>`[WebGPU] jsepCopyGpuToGpu: src=${Number(e)}, dst=${Number(n)}, size=${Number(r)}`),a.memcpy(Number(e),Number(n));else{R(`verbose`,()=>`[WebGPU] jsepCopyCpuToGpu: dataOffset=${Number(e)}, gpuDataId=${Number(n)}, size=${Number(r)}`);let i=t.HEAPU8.subarray(Number(e>>>0),Number(e>>>0)+Number(r));a.upload(Number(n),i)}},async(e,n,r)=>{R(`verbose`,()=>`[WebGPU] jsepCopyGpuToCpu: gpuDataId=${e}, dataOffset=${n}, size=${r}`),await a.download(Number(e),()=>t.HEAPU8.subarray(Number(n)>>>0,Number(n+r)>>>0))},(e,n,r)=>a.createKernel(e,Number(n),r,t.UTF8ToString(t._JsepGetNodeName(Number(n)))),e=>a.releaseKernel(e),(e,n,r,i)=>{R(`verbose`,()=>`[WebGPU] jsepRun: sessionHandle=${r}, kernel=${e}, contextDataOffset=${n}`);let o=new Ou(t,a,Number(n));return a.computeKernel(Number(e),o,i)},()=>a.captureBegin(),()=>a.captureEnd(),()=>a.replay()])}else{let e=new sn(n);i(`webnn`,[e,()=>e.reserveTensorId(),t=>e.releaseTensorId(t),async(t,n,r,i,a)=>e.ensureTensor(t,n,r,i,a),(t,n)=>{e.uploadTensor(t,n)},async(t,n)=>e.downloadTensor(t,n),(t,n)=>e.registerMLContext(t,n),!!n.trace])}}}),ju,Mu,Nu,Pu,Fu,Iu,Lu,Ru,zu,Bu,Vu,Hu,Uu,Wu=o(()=>{N(),pt(),bt(),L(),ct(),dt(),At(),ju=(e,t)=>{F()._OrtInit(e,t)!==0&&I(`Can't initialize onnxruntime.`)},Mu=async e=>{ju(e.wasm.numThreads,Tt(e.logLevel))},Nu=async(e,t)=>{F().asyncInit?.();let n=e.webgpu.adapter;if(t===`webgpu`){if(typeof navigator>`u`||!navigator.gpu)throw Error(`WebGPU is not supported in current environment`);if(n){if(typeof n.limits!=`object`||typeof n.features!=`object`||typeof n.requestDevice!=`function`)throw Error("Invalid GPU adapter set in `env.webgpu.adapter`. It must be a GPUAdapter object.")}else{let t=e.webgpu.powerPreference;if(t!==void 0&&t!==`low-power`&&t!==`high-performance`)throw Error(`Invalid powerPreference setting: "${t}"`);let r=e.webgpu.forceFallbackAdapter;if(r!==void 0&&typeof r!=`boolean`)throw Error(`Invalid forceFallbackAdapter setting: "${r}"`);if(n=await navigator.gpu.requestAdapter({powerPreference:t,forceFallbackAdapter:r}),!n)throw Error(`Failed to get GPU adapter. You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.`)}}if(t===`webnn`&&(typeof navigator>`u`||!navigator.ml))throw Error(`WebNN is not supported in current environment`);{let r=(Au(),l(Eu)).init;t===`webgpu`&&await r(`webgpu`,F(),e,n),t===`webnn`&&await r(`webnn`,F(),e)}},Pu=new Map,Fu=e=>{let t=F(),n=t.stackSave();try{let n=t.PTR_SIZE,r=t.stackAlloc(2*n);t._OrtGetInputOutputCount(e,r,r+n)!==0&&I(`Can't get session input/output count.`);let i=n===4?`i32`:`i64`;return[Number(t.getValue(r,i)),Number(t.getValue(r+n,i))]}finally{t.stackRestore(n)}},Iu=(e,t)=>{let n=F(),r=n.stackSave(),i=0;try{let r=n.PTR_SIZE,a=n.stackAlloc(2*r);n._OrtGetInputOutputMetadata(e,t,a,a+r)!==0&&I(`Can't get session input/output metadata.`);let o=Number(n.getValue(a,`*`));i=Number(n.getValue(a+r,`*`));let s=n.HEAP32[i/4];if(s===0)return[o,0];let c=n.HEAPU32[i/4+1],l=[];for(let e=0;e<c;e++){let t=Number(n.getValue(i+8+e*r,`*`));l.push(t===0?Number(n.getValue(i+8+(e+c)*r,`*`)):n.UTF8ToString(t))}return[o,s,l]}finally{n.stackRestore(r),i!==0&&n._OrtFree(i)}},Lu=e=>{let t=F(),n=t._malloc(e.byteLength);if(n===0)throw Error(`Can't create a session. failed to allocate a buffer of size ${e.byteLength}.`);return t.HEAPU8.set(e,n),[n,e.byteLength]},Ru=async(e,t)=>{let n,r,i=F();Array.isArray(e)?[n,r]=e:e.buffer===i.HEAPU8.buffer?[n,r]=[e.byteOffset,e.byteLength]:[n,r]=Lu(e);let a=0,o=0,s=0,c=[],l=[],u=[];try{if([o,c]=await yt(t),t?.externalData&&i.mountExternalData){let e=[];for(let n of t.externalData){let t=typeof n==`string`?n:n.path;e.push(kt(typeof n==`string`?n:n.data).then(e=>{i.mountExternalData(t,e)}))}await Promise.all(e)}for(let e of t?.executionProviders??[])if((typeof e==`string`?e:e.name)===`webnn`){if(i.shouldTransferToMLTensor=!1,typeof e!=`string`){let t=e,n=t?.context,r=t?.gpuDevice,a=t?.deviceType,o=t?.powerPreference;n?i.currentContext=n:r?i.currentContext=await i.webnnCreateMLContext(r):i.currentContext=await i.webnnCreateMLContext({deviceType:a,powerPreference:o})}else i.currentContext=await i.webnnCreateMLContext();break}a=await i._OrtCreateSession(n,r,o),i.webgpuOnCreateSession?.(a),a===0&&I(`Can't create a session.`),i.jsepOnCreateSession?.(),i.currentContext&&(i.webnnRegisterMLContext(a,i.currentContext),i.currentContext=void 0,i.shouldTransferToMLTensor=!0);let[e,d]=Fu(a),f=!!t?.enableGraphCapture,p=[],m=[],h=[],g=[],_=[];for(let t=0;t<e;t++){let[e,n,r]=Iu(a,t);e===0&&I(`Can't get an input name.`),l.push(e);let o=i.UTF8ToString(e);p.push(o),h.push(n===0?{name:o,isTensor:!1}:{name:o,isTensor:!0,type:St(n),shape:r})}for(let n=0;n<d;n++){let[r,o,s]=Iu(a,n+e);r===0&&I(`Can't get an output name.`),u.push(r);let c=i.UTF8ToString(r);m.push(c),g.push(o===0?{name:c,isTensor:!1}:{name:c,isTensor:!0,type:St(o),shape:s});{if(f&&t?.preferredOutputLocation===void 0){_.push(`gpu-buffer`);continue}let e=typeof t?.preferredOutputLocation==`string`?t.preferredOutputLocation:t?.preferredOutputLocation?.[c]??`cpu`,n=i.webnnIsGraphOutput;if(e===`cpu`&&n&&n(a,c)){_.push(`ml-tensor-cpu-output`);continue}if(e!==`cpu`&&e!==`cpu-pinned`&&e!==`gpu-buffer`&&e!==`ml-tensor`)throw Error(`Not supported preferred output location: ${e}.`);if(f&&e!==`gpu-buffer`)throw Error(`Not supported preferred output location: ${e}. Only 'gpu-buffer' location is supported when enableGraphCapture is true.`);_.push(e)}}let v=null;return _.some(e=>e===`gpu-buffer`||e===`ml-tensor`||e===`ml-tensor-cpu-output`)&&(s=i._OrtCreateBinding(a),s===0&&I(`Can't create IO binding.`),v={handle:s,outputPreferredLocations:_,outputPreferredLocationsEncoded:_.map(e=>e===`ml-tensor-cpu-output`?`ml-tensor`:e).map(e=>Ot(e))}),Pu.set(a,[a,l,u,v,f,!1]),[a,p,m,h,g]}catch(e){throw l.forEach(e=>i._OrtFree(e)),u.forEach(e=>i._OrtFree(e)),s!==0&&i._OrtReleaseBinding(s)!==0&&I(`Can't release IO binding.`),a!==0&&i._OrtReleaseSession(a)!==0&&I(`Can't release session.`),e}finally{i._free(n),o!==0&&i._OrtReleaseSessionOptions(o)!==0&&I(`Can't release session options.`),c.forEach(e=>i._free(e)),i.unmountExternalData?.()}},zu=e=>{let t=F(),n=Pu.get(e);if(!n)throw Error(`cannot release session. invalid session id: ${e}`);let[r,i,a,o,s]=n;o&&(s&&t._OrtClearBoundOutputs(o.handle)!==0&&I(`Can't clear bound outputs.`),t._OrtReleaseBinding(o.handle)!==0&&I(`Can't release IO binding.`)),t.jsepOnReleaseSession?.(e),t.webnnOnReleaseSession?.(e),t.webgpuOnReleaseSession?.(e),i.forEach(e=>t._OrtFree(e)),a.forEach(e=>t._OrtFree(e)),t._OrtReleaseSession(r)!==0&&I(`Can't release session.`),Pu.delete(e)},Bu=async(e,t,n,r,i,a,o=!1)=>{if(!e){t.push(0);return}let s=F(),c=s.PTR_SIZE,l=e[0],u=e[1],d=e[3],f=d,p,m;if(l===`string`&&(d===`gpu-buffer`||d===`ml-tensor`))throw Error(`String tensor is not supported on GPU.`);if(o&&d!==`gpu-buffer`)throw Error(`External buffer must be provided for input/output index ${a} when enableGraphCapture is true.`);if(d===`gpu-buffer`){let t=e[2].gpuBuffer;m=Ct(xt(l),u);{let e=s.jsepRegisterBuffer;if(!e)throw Error(`Tensor location "gpu-buffer" is not supported without using WebGPU.`);p=e(r,a,t,m)}}else if(d===`ml-tensor`){let t=e[2].mlTensor;m=Ct(xt(l),u);let n=s.webnnRegisterMLTensor;if(!n)throw Error(`Tensor location "ml-tensor" is not supported without using WebNN.`);p=n(r,t,xt(l),u)}else{let t=e[2];if(Array.isArray(t)){m=c*t.length,p=s._malloc(m),n.push(p);for(let e=0;e<t.length;e++){if(typeof t[e]!=`string`)throw TypeError(`tensor data at index ${e} is not a string`);s.setValue(p+e*c,lt(t[e],n),`*`)}}else{let e=s.webnnIsGraphInput,a=s.webnnIsGraphOutput;if(l!==`string`&&e&&a){let o=s.UTF8ToString(i);if(e(r,o)||a(r,o)){let e=xt(l);m=Ct(e,u),f=`ml-tensor`;let n=s.webnnCreateTemporaryTensor,i=s.webnnUploadTensor;if(!n||!i)throw Error(`Tensor location "ml-tensor" is not supported without using WebNN.`);let a=await n(r,e,u);i(a,new Uint8Array(t.buffer,t.byteOffset,t.byteLength)),p=a}else m=t.byteLength,p=s._malloc(m),n.push(p),s.HEAPU8.set(new Uint8Array(t.buffer,t.byteOffset,m),p)}else m=t.byteLength,p=s._malloc(m),n.push(p),s.HEAPU8.set(new Uint8Array(t.buffer,t.byteOffset,m),p)}}let h=s.stackSave(),g=s.stackAlloc(4*u.length);try{u.forEach((e,t)=>s.setValue(g+t*c,e,c===4?`i32`:`i64`));let e=s._OrtCreateTensor(xt(l),p,m,g,u.length,Ot(f));e===0&&I(`Can't create tensor for input/output. session=${r}, index=${a}.`),t.push(e)}finally{s.stackRestore(h)}},Vu=async(e,t,n,r,i,a)=>{let o=F(),s=o.PTR_SIZE,c=Pu.get(e);if(!c)throw Error(`cannot run inference. invalid session id: ${e}`);let l=c[0],u=c[1],d=c[2],f=c[3],p=c[4],m=c[5],h=t.length,g=r.length,_=0,v=[],y=[],b=[],x=[],S=[],C=o.stackSave(),w=o.stackAlloc(h*s),ee=o.stackAlloc(h*s),T=o.stackAlloc(g*s),E=o.stackAlloc(g*s);try{[_,v]=ft(a),M(`wasm prepareInputOutputTensor`);for(let r=0;r<h;r++)await Bu(n[r],y,x,e,u[t[r]],t[r],p);for(let t=0;t<g;t++)await Bu(i[t],b,x,e,d[r[t]],h+r[t],p);ve(`wasm prepareInputOutputTensor`);for(let e=0;e<h;e++)o.setValue(w+e*s,y[e],`*`),o.setValue(ee+e*s,u[t[e]],`*`);for(let e=0;e<g;e++)o.setValue(T+e*s,b[e],`*`),o.setValue(E+e*s,d[r[e]],`*`);if(f&&!m){let{handle:n,outputPreferredLocations:a,outputPreferredLocationsEncoded:s}=f;if(u.length!==h)throw Error(`input count from feeds (${h}) is expected to be always equal to model's input count (${u.length}).`);M(`wasm bindInputsOutputs`);for(let r=0;r<h;r++){let i=t[r];await o._OrtBindInput(n,u[i],y[r])!==0&&I(`Can't bind input[${r}] for session=${e}.`)}for(let t=0;t<g;t++){let c=r[t];i[t]?.[3]?(S.push(b[t]),o._OrtBindOutput(n,d[c],b[t],0)!==0&&I(`Can't bind pre-allocated output[${t}] for session=${e}.`)):o._OrtBindOutput(n,d[c],0,s[c])!==0&&I(`Can't bind output[${t}] to ${a[t]} for session=${e}.`)}ve(`wasm bindInputsOutputs`),Pu.set(e,[l,u,d,f,p,!0])}o.jsepOnRunStart?.(l),o.webnnOnRunStart?.(l);let c;c=f?await o._OrtRunWithBinding(l,f.handle,g,T,_):await o._OrtRun(l,ee,w,h,E,g,T,_),c!==0&&I(`failed to call OrtRun().`);let C=[],D=[];M(`wasm ProcessOutputTensor`);for(let t=0;t<g;t++){let n=Number(o.getValue(T+t*s,`*`));if(n===b[t]||S.includes(b[t])){C.push(i[t]),n!==b[t]&&o._OrtReleaseTensor(n)!==0&&I(`Can't release tensor.`);continue}let a=o.stackSave(),c=o.stackAlloc(4*s),l=!1,u,d=0;try{o._OrtGetTensorData(n,c,c+s,c+2*s,c+3*s)!==0&&I(`Can't access output tensor data on index ${t}.`);let i=s===4?`i32`:`i64`,a=Number(o.getValue(c,i));d=o.getValue(c+s,`*`);let p=o.getValue(c+s*2,`*`),m=Number(o.getValue(c+s*3,i)),h=[];for(let e=0;e<m;e++)h.push(Number(o.getValue(p+e*s,i)));o._OrtFree(p)!==0&&I(`Can't free memory for tensor dims.`);let g=h.reduce((e,t)=>e*t,1);u=St(a);let _=f?.outputPreferredLocations[r[t]];if(u===`string`){if(_===`gpu-buffer`||_===`ml-tensor`)throw Error(`String tensor is not supported on GPU.`);let e=[];for(let t=0;t<g;t++){let n=o.getValue(d+t*s,`*`),r=o.getValue(d+(t+1)*s,`*`),i=t===g-1?void 0:r-n;e.push(o.UTF8ToString(n,i))}C.push([u,h,e,`cpu`])}else if(_===`gpu-buffer`&&g>0){let e=o.jsepGetBuffer;if(!e)throw Error(`preferredLocation "gpu-buffer" is not supported without using WebGPU.`);let t=e(d),r=Ct(a,g);if(r===void 0||!Et(u))throw Error(`Unsupported data type: ${u}`);l=!0,C.push([u,h,{gpuBuffer:t,download:o.jsepCreateDownloader(t,r,u),dispose:()=>{o._OrtReleaseTensor(n)!==0&&I(`Can't release tensor.`)}},`gpu-buffer`])}else if(_===`ml-tensor`&&g>0){let t=o.webnnEnsureTensor,r=o.webnnIsGraphInputOutputTypeSupported;if(!t||!r)throw Error(`preferredLocation "ml-tensor" is not supported without using WebNN.`);if(Ct(a,g)===void 0||!Dt(u))throw Error(`Unsupported data type: ${u}`);if(!r(e,u,!1))throw Error(`preferredLocation "ml-tensor" for ${u} output is not supported by current WebNN Context.`);let i=await t(e,d,a,h,!1);l=!0,C.push([u,h,{mlTensor:i,download:o.webnnCreateMLTensorDownloader(d,u),dispose:()=>{o.webnnReleaseTensorId(d),o._OrtReleaseTensor(n)}},`ml-tensor`])}else if(_===`ml-tensor-cpu-output`&&g>0){let e=o.webnnCreateMLTensorDownloader(d,u)(),t=C.length;l=!0,D.push((async()=>{let r=[t,await e];return o.webnnReleaseTensorId(d),o._OrtReleaseTensor(n),r})()),C.push([u,h,[],`cpu`])}else{let e=new(wt(u))(g);new Uint8Array(e.buffer,e.byteOffset,e.byteLength).set(o.HEAPU8.subarray(d,d+e.byteLength)),C.push([u,h,e,`cpu`])}}finally{o.stackRestore(a),u===`string`&&d&&o._free(d),l||o._OrtReleaseTensor(n)}}f&&!p&&(o._OrtClearBoundOutputs(f.handle)!==0&&I(`Can't clear bound outputs.`),Pu.set(e,[l,u,d,f,p,!1]));for(let[e,t]of await Promise.all(D))C[e][2]=t;return ve(`wasm ProcessOutputTensor`),C}finally{o.webnnOnRunEnd?.(l),o.stackRestore(C),y.forEach(e=>o._OrtReleaseTensor(e)),b.forEach(e=>o._OrtReleaseTensor(e)),x.forEach(e=>o._free(e)),_!==0&&o._OrtReleaseRunOptions(_),v.forEach(e=>o._free(e))}},Hu=e=>{let t=F(),n=Pu.get(e);if(!n)throw Error(`invalid session id`);let r=n[0],i=t._OrtEndProfiling(r);i===0&&I(`Can't get an profile file name.`),t._OrtFree(i)},Uu=e=>{let t=[];for(let n of e){let e=n[2];!Array.isArray(e)&&`buffer`in e&&t.push(e.buffer)}return t}}),Gu,Ku,qu,Ju,Yu,Xu,Zu,Qu,$u,ed,td,nd,rd,id,ad,od,sd,cd,ld=o(()=>{N(),Wu(),ct(),$e(),Gu=()=>!!S.wasm.proxy&&typeof document<`u`,qu=!1,Ju=!1,Yu=!1,Qu=new Map,$u=(e,t)=>{let n=Qu.get(e);n?n.push(t):Qu.set(e,[t])},ed=()=>{if(qu||!Ju||Yu||!Ku)throw Error(`worker not ready`)},td=e=>{switch(e.data.type){case`init-wasm`:qu=!1,e.data.err?(Yu=!0,Zu[1](e.data.err)):(Ju=!0,Zu[0]()),Xu&&=(URL.revokeObjectURL(Xu),void 0);break;case`init-ep`:case`copy-from`:case`create`:case`release`:case`run`:case`end-profiling`:{let t=Qu.get(e.data.type);e.data.err?t.shift()[1](e.data.err):t.shift()[0](e.data.out);break}default:}},nd=async()=>{if(!Ju){if(qu)throw Error(`multiple calls to 'initWasm()' detected.`);if(Yu)throw Error(`previous call to 'initWasm()' failed.`);if(qu=!0,Gu())return new Promise((e,t)=>{Ku?.terminate(),P().then(([n,r])=>{try{Ku=r,Ku.onerror=e=>t(e),Ku.onmessage=td,Zu=[e,t];let i={type:`init-wasm`,in:S};!i.in.wasm.wasmPaths&&(n||Ve)&&(i.in.wasm.wasmPaths={wasm:new URL(`/assets/ort-wasm-simd-threaded.jsep-Bzonanhp.wasm`,``+import.meta.url).href}),Ku.postMessage(i),Xu=n}catch(e){t(e)}},t)});try{await st(S.wasm),await Mu(S),Ju=!0}catch(e){throw Yu=!0,e}finally{qu=!1}}},rd=async e=>{if(Gu())return ed(),new Promise((t,n)=>{$u(`init-ep`,[t,n]);let r={type:`init-ep`,in:{epName:e,env:S}};Ku.postMessage(r)});await Nu(S,e)},id=async e=>Gu()?(ed(),new Promise((t,n)=>{$u(`copy-from`,[t,n]);let r={type:`copy-from`,in:{buffer:e}};Ku.postMessage(r,[e.buffer])})):Lu(e),ad=async(e,t)=>{if(Gu()){if(t?.preferredOutputLocation)throw Error(`session option "preferredOutputLocation" is not supported for proxy.`);return ed(),new Promise((n,r)=>{$u(`create`,[n,r]);let i={type:`create`,in:{model:e,options:{...t}}},a=[];e instanceof Uint8Array&&a.push(e.buffer),Ku.postMessage(i,a)})}else return Ru(e,t)},od=async e=>{if(Gu())return ed(),new Promise((t,n)=>{$u(`release`,[t,n]);let r={type:`release`,in:e};Ku.postMessage(r)});zu(e)},sd=async(e,t,n,r,i,a)=>{if(Gu()){if(n.some(e=>e[3]!==`cpu`))throw Error(`input tensor on GPU is not supported for proxy.`);if(i.some(e=>e))throw Error(`pre-allocated output tensor is not supported for proxy.`);return ed(),new Promise((i,o)=>{$u(`run`,[i,o]);let s=n,c={type:`run`,in:{sessionId:e,inputIndices:t,inputs:s,outputIndices:r,options:a}};Ku.postMessage(c,Uu(s))})}else return Vu(e,t,n,r,i,a)},cd=async e=>{if(Gu())return ed(),new Promise((t,n)=>{$u(`end-profiling`,[t,n]);let r={type:`end-profiling`,in:e};Ku.postMessage(r)});Hu(e)}}),ud,dd,fd,pd=o(()=>{N(),ld(),L(),ke(),At(),ud=(e,t)=>{switch(e.location){case`cpu`:return[e.type,e.dims,e.data,`cpu`];case`gpu-buffer`:return[e.type,e.dims,{gpuBuffer:e.gpuBuffer},`gpu-buffer`];case`ml-tensor`:return[e.type,e.dims,{mlTensor:e.mlTensor},`ml-tensor`];default:throw Error(`invalid data location: ${e.location} for ${t()}`)}},dd=e=>{switch(e[3]){case`cpu`:return new pe(e[0],e[2],e[1]);case`gpu-buffer`:{let t=e[0];if(!Et(t))throw Error(`not supported data type: ${t} for deserializing GPU tensor`);let{gpuBuffer:n,download:r,dispose:i}=e[2];return pe.fromGpuBuffer(n,{dataType:t,dims:e[1],download:r,dispose:i})}case`ml-tensor`:{let t=e[0];if(!Dt(t))throw Error(`not supported data type: ${t} for deserializing MLTensor tensor`);let{mlTensor:n,download:r,dispose:i}=e[2];return pe.fromMLTensor(n,{dataType:t,dims:e[1],download:r,dispose:i})}default:throw Error(`invalid data location: ${e[3]}`)}},fd=class{async fetchModelAndCopyToWasmMemory(e){return id(await kt(e))}async loadModel(e,t){_e();let n;n=typeof e==`string`?await this.fetchModelAndCopyToWasmMemory(e):e,[this.sessionId,this.inputNames,this.outputNames,this.inputMetadata,this.outputMetadata]=await ad(n,t),j()}async dispose(){return od(this.sessionId)}async run(e,t,n){_e();let r=[],i=[];Object.entries(e).forEach(e=>{let t=e[0],n=e[1],a=this.inputNames.indexOf(t);if(a===-1)throw Error(`invalid input '${t}'`);r.push(n),i.push(a)});let a=[],o=[];Object.entries(t).forEach(e=>{let t=e[0],n=e[1],r=this.outputNames.indexOf(t);if(r===-1)throw Error(`invalid output '${t}'`);a.push(n),o.push(r)});let s=r.map((e,t)=>ud(e,()=>`input "${this.inputNames[i[t]]}"`)),c=a.map((e,t)=>e?ud(e,()=>`output "${this.outputNames[o[t]]}"`):null),l=await sd(this.sessionId,i,s,o,c,n),u={};for(let e=0;e<l.length;e++)u[this.outputNames[o[e]]]=a[e]??dd(l[e]);return j(),u}startProfiling(){}endProfiling(){cd(this.sessionId)}}}),md={};s(md,{OnnxruntimeWebAssemblyBackend:()=>gd,initializeFlags:()=>hd,wasmBackend:()=>_d});var hd,gd,_d,vd=o(()=>{N(),ld(),pd(),hd=()=>{(typeof S.wasm.initTimeout!=`number`||S.wasm.initTimeout<0)&&(S.wasm.initTimeout=0);let e=S.wasm.simd;if(typeof e!=`boolean`&&e!==void 0&&e!==`fixed`&&e!==`relaxed`&&(console.warn(`Property "env.wasm.simd" is set to unknown value "${e}". Reset it to \`false\` and ignore SIMD feature checking.`),S.wasm.simd=!1),typeof S.wasm.proxy!=`boolean`&&(S.wasm.proxy=!1),typeof S.wasm.trace!=`boolean`&&(S.wasm.trace=!1),typeof S.wasm.numThreads!=`number`||!Number.isInteger(S.wasm.numThreads)||S.wasm.numThreads<=0)if(typeof self<`u`&&!self.crossOriginIsolated)S.wasm.numThreads=1;else{let e=typeof navigator>`u`?a(`node:os`).cpus().length:navigator.hardwareConcurrency;S.wasm.numThreads=Math.min(4,Math.ceil((e||1)/2))}},gd=class{async init(e){hd(),await nd(),await rd(e)}async createInferenceSessionHandler(e,t){let n=new fd;return await n.loadModel(e,t),n}},_d=new gd});N(),N(),N();var yd=`1.25.0-dev.20260307-d626b568e0`,bd=Oe;{let e=(vd(),l(md)).wasmBackend;f(`webgpu`,e,5),f(`webnn`,e,5),f(`cpu`,e,10),f(`wasm`,e,10)}Object.defineProperty(S.versions,`web`,{value:yd,enumerable:!0}); | |
| /** | |
| * @license | |
| * Copyright 2021 Google LLC. All Rights Reserved. | |
| * Licensed under the Apache License, Version 2.0 (the "License"); | |
| * you may not use this file except in compliance with the License. | |
| * You may obtain a copy of the License at | |
| * | |
| * http://www.apache.org/licenses/LICENSE-2.0 | |
| * | |
| * Unless required by applicable law or agreed to in writing, software | |
| * distributed under the License is distributed on an "AS IS" BASIS, | |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| * See the License for the specific language governing permissions and | |
| * limitations under the License. | |
| * ============================================================================= | |
| */ | |
| /** | |
| * @license | |
| * Copyright 2020 Google LLC. All Rights Reserved. | |
| * Licensed under the Apache License, Version 2.0 (the "License"); | |
| * you may not use this file except in compliance with the License. | |
| * You may obtain a copy of the License at | |
| * | |
| * http://www.apache.org/licenses/LICENSE-2.0 | |
| * | |
| * Unless required by applicable law or agreed to in writing, software | |
| * distributed under the License is distributed on an "AS IS" BASIS, | |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| * See the License for the specific language governing permissions and | |
| * limitations under the License. | |
| * ============================================================================= | |
| */ | |
| /** | |
| * @license | |
| * Copyright 2019 Google LLC. All Rights Reserved. | |
| * Licensed under the Apache License, Version 2.0 (the "License"); | |
| * you may not use this file except in compliance with the License. | |
| * You may obtain a copy of the License at | |
| * | |
| * http://www.apache.org/licenses/LICENSE-2.0 | |
| * | |
| * Unless required by applicable law or agreed to in writing, software | |
| * distributed under the License is distributed on an "AS IS" BASIS, | |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| * See the License for the specific language governing permissions and | |
| * limitations under the License. | |
| * ============================================================================= | |
| */ | |
| export{Se as InferenceSession,he as TRACE,M as TRACE_EVENT_BEGIN,ve as TRACE_EVENT_END,_e as TRACE_FUNC_BEGIN,j as TRACE_FUNC_END,pe as Tensor,bd as default,S as env,f as registerBackend}; |