| |
| |
|
|
| |
|
|
| template <typename F> |
| __global__ void kernel_forward(const int B, const int T, const int C, |
| const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, |
| F *__restrict__ const _y) { |
| const int idx = blockIdx.x * blockDim.x + threadIdx.x; |
| const int _b = idx / C; |
| const int _c = idx % C; |
| const int _offset = _b * T * C + _c; |
|
|
| F u = _u[_c]; |
| F w = _w[_c]; |
| const F *__restrict__ const k = _k + _offset; |
| const F *__restrict__ const v = _v + _offset; |
| F *__restrict__ const y = _y + _offset; |
|
|
| // aa and bb are running sums divided by exp(pp) (to avoid overflow) |
| F aa = 0, bb = 0, pp = MIN_VALUE; |
| for (int i = 0; i < T; i++) { |
| const int ii = i * C; |
| const F kk = k[ii]; |
| const F vv = v[ii]; |
|
|
| F ww = u + kk; |
| F p = max(pp, ww); |
| F e1 = exp(pp - p); |
| F e2 = exp(ww - p); |
| y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2); |
| |
| ww = w + pp; |
| p = max(ww, kk); |
| e1 = exp(ww - p); |
| e2 = exp(kk - p); |
| aa = e1 * aa + e2 * vv; |
| bb = e1 * bb + e2; |
| pp = p; |
| } |
| } |
|
|
| template <typename F> |
| __global__ void kernel_backward(const int B, const int T, const int C, |
| const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, |
| const F *__restrict__ const _y, const F *__restrict__ const _gy, |
| F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) { |
| const int idx = blockIdx.x * blockDim.x + threadIdx.x; |
| const int _b = idx / C; |
| const int _c = idx % C; |
| const int _offset = _b * T * C + _c; |
|
|
| F u = _u[_c]; |
| F w = _w[_c]; |
| const F *__restrict__ const k = _k + _offset; |
| const F *__restrict__ const v = _v + _offset; |
| const F *__restrict__ const y = _y + _offset; |
| const F *__restrict__ const gy = _gy + _offset; |
| F *__restrict__ const gk = _gk + _offset; |
| F *__restrict__ const gv = _gv + _offset; |
|
|
| F q[Tmax], r[Tmax]; |
|
|
| F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE; |
| for (int i = 0; i < T; i++) { |
| const int ii = i * C; |
| const F kk = k[ii]; |
| const F vv = v[ii]; |
| const F yy = y[ii]; |
|
|
| F ww = u + kk; |
| F p = max(pp, ww); |
| F e1 = exp(pp - p); |
| F e2 = exp(ww - p); |
| const F qq = gy[ii] / (e1 * bb + e2); |
| gw += (ga - gb * yy) * e1 * qq; |
| gu += (vv - yy) * e2 * qq; |
| q[i] = qq; |
| r[i] = ww - p; |
|
|
| ww = w + pp; |
| p = max(ww, kk); |
| e1 = exp(ww - p); |
| e2 = exp(kk - p); |
| ga = e1 * (aa + ga); |
| gb = e1 * (bb + gb); |
| aa = e1 * aa + e2 * vv; |
| bb = e1 * bb + e2; |
| pp = p; |
| } |
| const int _offsetBC = _b * C + _c; |
| _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward() |
| _gu[_offsetBC] = gu; |
|
|
| aa = 0, bb = 0, pp = MIN_VALUE; |
| for (int i = T - 1; i >= 0; i--) { |
| const int ii = i * C; |
| const F kk = k[ii]; |
| const F vv = v[ii]; |
| const F yy = y[ii]; |
| const F qq = q[i]; |
| const F rr = r[i]; |
|
|
| F e1 = qq * exp(rr); |
| F e2 = exp(kk + pp); |
| gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb); |
| gv[ii] = e1 + e2 * aa; |
|
|
| const F ww = w + pp; |
| const F www = rr - u - kk; |
| const F p = max(ww, www); |
| e1 = exp(ww - p); |
| e2 = qq * exp(www - p); |
| aa = e1 * aa + e2; |
| bb = e1 * bb - e2 * yy; |
| pp = p; |
| } |
| } |
|
|
| void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) { |
| dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance |
| assert(B * C % threadsPerBlock.x == 0); |
| dim3 numBlocks(B * C / threadsPerBlock.x); |
| kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y); |
| } |
|
|
| void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) { |
| dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance |
| assert(B * C % threadsPerBlock.x == 0); |
| dim3 numBlocks(B * C / threadsPerBlock.x); |
| kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv); |
| } |
|
|