iq2_kt: f16 CUDA dot product

We arrive at 112 t/s.
This commit is contained in:
Iwan Kawrakow
2024-11-07 12:32:10 +02:00
parent aed3910dfa
commit b354392c77

View File

@@ -15,7 +15,7 @@
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
#endif
static __global__ void dequantize_mul_mat_vec_iq2_kt(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst,
static __global__ void dequantize_mul_mat_vec_iq2_kt(const void * __restrict__ vx, const dfloat * __restrict__ yy, float * __restrict__ dst,
const int ncols, int nrows, int64_t row_size) {
constexpr uint32_t ka = 89226354;
@@ -32,7 +32,7 @@ static __global__ void dequantize_mul_mat_vec_iq2_kt(const void * __restrict__ v
const int num_blocks_per_row = ncols / QK_K;
float tmp = 0; // partial sum for thread in warp
dfloat2 tmp = {0, 0};
const int it = threadIdx.x;
@@ -40,26 +40,37 @@ static __global__ void dequantize_mul_mat_vec_iq2_kt(const void * __restrict__ v
const half * h = (const half *)&s;
for (int i = 0; i < num_blocks_per_row; ++i) {
const float * y = yy + i * QK_K + 8*it;
const float dl = iq4k_values[(x[i].scales[(it/4)%4] >> 4*(it/16)) & 0xf];
const dfloat2 * y = (const dfloat2 *)(yy + i * QK_K + 8*it);
const uint16_t * ql = (const uint16_t *)x[i].ql;
float bdot = 0;
const dfloat scale = iq4k_values[(x[i].scales[(it/4)%4] >> 4*(it/16)) & 0xf];
const dfloat2 dl = {scale, scale};
dfloat2 bdot = {0, 0};
uint32_t val = ql[it] + 4096;
for (int k = 0; k < 8; k += 2) {
for (int k = 0; k < 4; ++k) {
val = ka*val + kb;
s[0] = (val & kmask) ^ km32;
val = ka*val + kb;
s[1] = (val & kmask) ^ km32;
bdot += y[k+0] * (float)(h[0] + h[1]) + y[k+1] * (float)(h[2] + h[3]);
#ifdef GGML_CUDA_F16
bdot += __hmul2(y[k], {h[0]+h[1], h[2]+h[3]});
#else
bdot.x += y[k].x * (float)(h[0] + h[1]);
bdot.y += y[k].y * (float)(h[2] + h[3]);
#endif
}
tmp += dl*bdot;
#ifdef GGML_CUDA_F16
tmp += __hmul2(dl, bdot);
#else
tmp.x += dl.x * bdot.x;
tmp.y += dl.y * bdot.y;
#endif
}
// sum up partial sums and write back result
tmp = warp_reduce_sum(tmp);
if (threadIdx.x == 0) {
dst[row] = tmp*d;
dst[row] = d * (float)(tmp.x + tmp.y);
}
}
@@ -609,7 +620,7 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f
dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_iq2_kt_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
static void dequantize_mul_mat_vec_iq2_kt_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int ny = 1;
const int block_num_y = (nrows + ny - 1) / ny;
@@ -680,7 +691,8 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
bool src1_convert_f16 =
src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16 ||
src0->type == GGML_TYPE_IQ2_KT;
if (src1_convert_f16) {
src1_dfloat = src1_dfloat_a.alloc(ne00);
@@ -712,7 +724,7 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
dequantize_mul_mat_vec_q2_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_IQ2_KT:
dequantize_mul_mat_vec_iq2_kt_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
dequantize_mul_mat_vec_iq2_kt_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_Q3_K:
dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);