From 7cafafc69e08fc3ed5b90a4722f4bd6eeabd0b6c Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Thu, 7 Nov 2024 14:20:34 +0200
Subject: [PATCH] iq2_kt: faster f16 CUDA dot product

We arrive at 139 t/s (no FA), and 149 t/s (FA). My RTX-4080 is ~20%
slower than the RTX-6000 quoted in the QTIP repository, so with FA
(which I'm sure they also used) we are at around ~180 t/s on their
GPU, so almost matching their performance.
---
 ggml/src/ggml-cuda/dmmv.cu | 43 +++++++++++++++++++++++++-------------
 1 file changed, 28 insertions(+), 15 deletions(-)

diff --git a/ggml/src/ggml-cuda/dmmv.cu b/ggml/src/ggml-cuda/dmmv.cu
index 0edfece5..088b2e8c 100644
--- a/ggml/src/ggml-cuda/dmmv.cu
+++ b/ggml/src/ggml-cuda/dmmv.cu
@@ -34,32 +34,45 @@ static __global__ void dequantize_mul_mat_vec_iq2_kt(const void * __restrict__ v
 
     dfloat2 tmp = {0, 0};
 
-    const int it = threadIdx.x;
+    const int it = threadIdx.x/2;
+    const int ix = threadIdx.x%2;
 
-    uint32_t s[2];
+    uint32_t s[4];
     const half * h = (const half *)&s;
 
-    for (int i = 0; i < num_blocks_per_row; ++i) {
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
         const dfloat2 * y = (const dfloat2 *)(yy + i * QK_K + 8*it);
         const uint16_t * ql = (const uint16_t *)x[i].ql;
-        const dfloat scale = iq4k_values[(x[i].scales[(it/4)%4] >> 4*(it/16)) & 0xf];
-        const dfloat2 dl = {scale, scale};
-        dfloat2 bdot = {0, 0};
-        uint32_t val = ql[it] + 4096;
+        const dfloat scale1 = iq4k_values[x[i].scales[it/4] & 0xf];
+        const dfloat scale2 = iq4k_values[x[i].scales[it/4] >> 4];
+        const dfloat2 dl1 = {scale1, scale1};
+        const dfloat2 dl2 = {scale2, scale2};
+        dfloat2 bdot1 = {0, 0};
+        dfloat2 bdot2 = {0, 0};
+        uint32_t val1 = ql[it+ 0] + 4096;
+        uint32_t val2 = ql[it+16] + 4096;
         for (int k = 0; k < 4; ++k) {
-            val = ka*val + kb;
-            s[0] = (val & kmask) ^ km32;
-            val = ka*val + kb;
-            s[1] = (val & kmask) ^ km32;
+            val1 = ka*val1 + kb;
+            s[0] = (val1 & kmask) ^ km32;
+            val1 = ka*val1 + kb;
+            s[1] = (val1 & kmask) ^ km32;
+            val2 = ka*val2 + kb;
+            s[2] = (val2 & kmask) ^ km32;
+            val2 = ka*val2 + kb;
+            s[3] = (val2 & kmask) ^ km32;
 #ifdef GGML_CUDA_F16
-            bdot += __hmul2(y[k], {h[0]+h[1], h[2]+h[3]});
+            bdot1 += __hmul2(y[k+ 0], {h[0]+h[1], h[2]+h[3]});
+            bdot2 += __hmul2(y[k+64], {h[4]+h[5], h[6]+h[7]});
 #else
-            bdot.x += y[k].x * (float)(h[0] + h[1]);
-            bdot.y += y[k].y * (float)(h[2] + h[3]);
+            bdot.x += y[k+ 0].x * (float)(h[0] + h[1]);
+            bdot.y += y[k+ 0].y * (float)(h[2] + h[3]);
+            bdot.x += y[k+64].x * (float)(h[4] + h[5]);
+            bdot.y += y[k+64].y * (float)(h[6] + h[7]);
 #endif
         }
 #ifdef GGML_CUDA_F16
-        tmp += __hmul2(dl, bdot);
+        tmp += __hmul2(dl1, bdot1);
+        tmp += __hmul2(dl2, bdot2);
 #else
         tmp.x += dl.x * bdot.x;
         tmp.y += dl.y * bdot.y;