diff --git a/ggml/src/ggml-cuda/dmmv.cu b/ggml/src/ggml-cuda/dmmv.cu index f7b1c827..53fcc661 100644 --- a/ggml/src/ggml-cuda/dmmv.cu +++ b/ggml/src/ggml-cuda/dmmv.cu @@ -15,13 +15,34 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); #endif -static __global__ void dequantize_mul_mat_vec_iq2_kt(const void * __restrict__ vx, const dfloat * __restrict__ yy, float * __restrict__ dst, - const int ncols, int nrows, int64_t row_size) { - +static __device__ __forceinline__ uint32_t trellis_next(uint32_t& val) { constexpr uint32_t ka = 89226354; constexpr uint32_t kb = 64248484; constexpr uint32_t kmask = 0x8fff8fff; constexpr uint32_t km32 = 0x3b603b60; + val = ka*val + kb; + return (val & kmask) ^ km32; +} + +static __device__ __forceinline__ void trellis_accum(uint32_t& val1, uint32_t& val2, uint32_t* s, const dfloat2* y, dfloat2& bdot1, dfloat2& bdot2) { + const half * h = (const half *)s; + s[0] = trellis_next(val1); + s[1] = trellis_next(val1); + s[2] = trellis_next(val2); + s[3] = trellis_next(val2); +#ifdef GGML_CUDA_F16 + bdot1 = __hfma2(y[ 0], {h[0]+h[1], h[2]+h[3]}, bdot1); + bdot2 = __hfma2(y[64], {h[4]+h[5], h[6]+h[7]}, bdot2); +#else + bdot1.x += y[ 0].x * (float)(h[0] + h[1]); + bdot1.y += y[ 0].y * (float)(h[2] + h[3]); + bdot2.x += y[64].x * (float)(h[4] + h[5]); + bdot2.y += y[64].y * (float)(h[6] + h[7]); +#endif +} + +static __global__ void dequantize_mul_mat_vec_iq2_kt(const void * __restrict__ vx, const dfloat * __restrict__ yy, float * __restrict__ dst, + const int ncols, int nrows, int64_t row_size) { const int row = blockIdx.x*blockDim.y + threadIdx.y; if (row > nrows) return; @@ -38,7 +59,6 @@ static __global__ void dequantize_mul_mat_vec_iq2_kt(const void * __restrict__ v const int ix = threadIdx.x%2; uint32_t s[4]; - const half * h = (const half *)&s; for (int i = ix; i < num_blocks_per_row; i += 2) { const dfloat2 * y = (const dfloat2 *)(yy + i * QK_K + 8*it); @@ -52,19 +72,7 @@ static __global__ void dequantize_mul_mat_vec_iq2_kt(const void * __restrict__ v uint32_t val1 = ql[it+ 0] + 4096; uint32_t val2 = ql[it+16] + 4096; for (int k = 0; k < 4; ++k) { - val1 = ka*val1 + kb; s[0] = (val1 & kmask) ^ km32; - val1 = ka*val1 + kb; s[1] = (val1 & kmask) ^ km32; - val2 = ka*val2 + kb; s[2] = (val2 & kmask) ^ km32; - val2 = ka*val2 + kb; s[3] = (val2 & kmask) ^ km32; -#ifdef GGML_CUDA_F16 - bdot1 = __hfma2(y[k+ 0], {h[0]+h[1], h[2]+h[3]}, bdot1); - bdot2 = __hfma2(y[k+64], {h[4]+h[5], h[6]+h[7]}, bdot2); -#else - bdot1.x += y[k+ 0].x * (float)(h[0] + h[1]); - bdot1.y += y[k+ 0].y * (float)(h[2] + h[3]); - bdot2.x += y[k+64].x * (float)(h[4] + h[5]); - bdot2.y += y[k+64].y * (float)(h[6] + h[7]); -#endif + trellis_accum(val1, val2, s, y+k, bdot1, bdot2); } #ifdef GGML_CUDA_F16 tmp = __hfma2(dl1, bdot1, tmp); @@ -86,11 +94,6 @@ static __global__ void dequantize_mul_mat_vec_iq2_kt(const void * __restrict__ v static __global__ void dequantize_mul_mat_vec_iq3_kt(const void * __restrict__ vx, const dfloat * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows, int64_t row_size) { - constexpr uint32_t ka = 89226354; - constexpr uint32_t kb = 64248484; - constexpr uint32_t kmask = 0x8fff8fff; - constexpr uint32_t km32 = 0x3b603b60; - const int row = blockIdx.x*blockDim.y + threadIdx.y; if (row > nrows) return; @@ -106,7 +109,6 @@ static __global__ void dequantize_mul_mat_vec_iq3_kt(const void * __restrict__ v const int ix = threadIdx.x%2; uint32_t s[4]; - const half * h = (const half *)s; for (int i = ix; i < num_blocks_per_row; i += 2) { const dfloat2 * y = (const dfloat2 *)(yy + i * QK_K + 8*it); @@ -121,36 +123,12 @@ static __global__ void dequantize_mul_mat_vec_iq3_kt(const void * __restrict__ v uint32_t val1 = ql[2*it+ 0] + ((qh[2*it+0] << 8) & 0xf00) + 4096; uint32_t val2 = ql[2*it+32] + ((qh[2*it+0] << 4) & 0xf00) + 4096; for (int k = 0; k < 2; ++k) { - val1 = ka*val1 + kb; s[0] = (val1 & kmask) ^ km32; - val1 = ka*val1 + kb; s[1] = (val1 & kmask) ^ km32; - val2 = ka*val2 + kb; s[2] = (val2 & kmask) ^ km32; - val2 = ka*val2 + kb; s[3] = (val2 & kmask) ^ km32; -#ifdef GGML_CUDA_F16 - bdot1 = __hfma2(y[k+ 0], {h[0]+h[1], h[2]+h[3]}, bdot1); - bdot2 = __hfma2(y[k+64], {h[4]+h[5], h[6]+h[7]}, bdot2); -#else - bdot1.x += y[k+ 0].x * (float)(h[0] + h[1]); - bdot1.y += y[k+ 0].y * (float)(h[2] + h[3]); - bdot2.x += y[k+64].x * (float)(h[4] + h[5]); - bdot2.y += y[k+64].y * (float)(h[6] + h[7]); -#endif + trellis_accum(val1, val2, s, y+k, bdot1, bdot2); } val1 = ql[2*it+ 1] + ((qh[2*it+1] << 8) & 0xf00) + 4096; val2 = ql[2*it+33] + ((qh[2*it+1] << 4) & 0xf00) + 4096; for (int k = 2; k < 4; ++k) { - val1 = ka*val1 + kb; s[0] = (val1 & kmask) ^ km32; - val1 = ka*val1 + kb; s[1] = (val1 & kmask) ^ km32; - val2 = ka*val2 + kb; s[2] = (val2 & kmask) ^ km32; - val2 = ka*val2 + kb; s[3] = (val2 & kmask) ^ km32; -#ifdef GGML_CUDA_F16 - bdot1 = __hfma2(y[k+ 0], {h[0]+h[1], h[2]+h[3]}, bdot1); - bdot2 = __hfma2(y[k+64], {h[4]+h[5], h[6]+h[7]}, bdot2); -#else - bdot1.x += y[k+ 0].x * (float)(h[0] + h[1]); - bdot1.y += y[k+ 0].y * (float)(h[2] + h[3]); - bdot2.x += y[k+64].x * (float)(h[4] + h[5]); - bdot2.y += y[k+64].y * (float)(h[6] + h[7]); -#endif + trellis_accum(val1, val2, s, y+k, bdot1, bdot2); } #ifdef GGML_CUDA_F16 tmp = __hfma2(dl1, bdot1, tmp); @@ -172,10 +150,6 @@ static __global__ void dequantize_mul_mat_vec_iq3_kt(const void * __restrict__ v static __global__ void dequantize_mul_mat_vec_iq4_kt(const void * __restrict__ vx, const dfloat * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows, int64_t row_size) { - constexpr uint32_t ka = 89226354; - constexpr uint32_t kb = 64248484; - constexpr uint32_t kmask = 0x8fff8fff; - constexpr uint32_t km32 = 0x3b603b60; constexpr int kNumGroups = 64; const int row = blockIdx.x*blockDim.y + threadIdx.y; @@ -198,7 +172,6 @@ static __global__ void dequantize_mul_mat_vec_iq4_kt(const void * __restrict__ v const int jj = ib32*8 + 2*ig; // 0...30 in steps of 2 uint32_t s[4]; - const half * h = (const half *)s; for (int i = ix; i < num_blocks_per_row; i += 2) { const dfloat2 * y = (const dfloat2 *)(yy + i * QK_K + 8*it); @@ -216,42 +189,12 @@ static __global__ void dequantize_mul_mat_vec_iq4_kt(const void * __restrict__ v uint32_t val1 = ql[jj+ 0] + ((qh[jj] << 8) & 0xf00) + (((shb[ib32+0] >> (8 + 6*ig+0)) & 7) << 12) + offset1; uint32_t val2 = ql[jj+32] + ((qh[jj] << 4) & 0xf00) + (((shb[ib32+4] >> (8 + 6*ig+0)) & 7) << 12) + offset2; for (int k = 0; k < 2; ++k) { - val1 = ka*val1 + kb; s[0] = (val1 & kmask) ^ km32; - val1 = ka*val1 + kb; s[1] = (val1 & kmask) ^ km32; - val2 = ka*val2 + kb; s[2] = (val2 & kmask) ^ km32; - val2 = ka*val2 + kb; s[3] = (val2 & kmask) ^ km32; -#ifdef GGML_CUDA_F16 - bdot1 = __hfma2(y[k+ 0], {h[0]+h[1], h[2]+h[3]}, bdot1); - bdot2 = __hfma2(y[k+64], {h[4]+h[5], h[6]+h[7]}, bdot2); - tmp2 += y[k] + y[k+64]; -#else - bdot1.x += y[k+ 0].x * (float)(h[0] + h[1]); - bdot1.y += y[k+ 0].y * (float)(h[2] + h[3]); - bdot2.x += y[k+64].x * (float)(h[4] + h[5]); - bdot2.y += y[k+64].y * (float)(h[6] + h[7]); - tmp2.x += y[k].x + y[k+64].x; - tmp2.y += y[k].y + y[k+64].y; -#endif + trellis_accum(val1, val2, s, y+k, bdot1, bdot2); } val1 = ql[jj+ 1] + ((qh[jj+1] << 8) & 0xf00) + (((shb[ib32+0] >> (8 + 6*ig+3)) & 7) << 12) + offset1; val2 = ql[jj+33] + ((qh[jj+1] << 4) & 0xf00) + (((shb[ib32+4] >> (8 + 6*ig+3)) & 7) << 12) + offset2; for (int k = 2; k < 4; ++k) { - val1 = ka*val1 + kb; s[0] = (val1 & kmask) ^ km32; - val1 = ka*val1 + kb; s[1] = (val1 & kmask) ^ km32; - val2 = ka*val2 + kb; s[2] = (val2 & kmask) ^ km32; - val2 = ka*val2 + kb; s[3] = (val2 & kmask) ^ km32; -#ifdef GGML_CUDA_F16 - bdot1 = __hfma2(y[k+ 0], {h[0]+h[1], h[2]+h[3]}, bdot1); - bdot2 = __hfma2(y[k+64], {h[4]+h[5], h[6]+h[7]}, bdot2); - tmp2 += y[k] + y[k+64]; -#else - bdot1.x += y[k+ 0].x * (float)(h[0] + h[1]); - bdot1.y += y[k+ 0].y * (float)(h[2] + h[3]); - bdot2.x += y[k+64].x * (float)(h[4] + h[5]); - bdot2.y += y[k+64].y * (float)(h[6] + h[7]); - tmp2.x += y[k].x + y[k+64].x; - tmp2.y += y[k].y + y[k+64].y; -#endif + trellis_accum(val1, val2, s, y+k, bdot1, bdot2); } #ifdef GGML_CUDA_F16 tmp1 = __hfma2(dl1, bdot1, tmp1);