cuda q4_0_r4: dot product works

I get basically the same token generation (TG) performance as Q4_0.
Iwan Kawrakow
2024-12-06 18:17:40 +02:00
parent 85a0730447
commit d9589d82cb


@@ -14,7 +14,7 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_
 // constexpr int vdr = get_vdr_mmvq(type);
 namespace {
-template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y>
+template <int qk, int qi, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y>
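+// qk and qi are now explicit template parameters, resolved by the caller from
+// ggml_cuda_type_traits, instead of being derived from a ggml_type parameter.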
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 // tell the compiler to use as many registers as it wants, see nwarps definition below
 __launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
@@ -23,9 +23,6 @@ __global__ void iqk_mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size) {
-    constexpr int qk = ggml_cuda_type_traits<type>::qk;
-    constexpr int qi = ggml_cuda_type_traits<type>::qi;
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
     constexpr int nwarps = 1;
     constexpr int rows_per_cuda_block = 1;
@@ -137,30 +134,33 @@ void iqk_mul_mat_vec_q_cuda(
     const int64_t row_size = ggml_row_size(type, ncols_x);
+    constexpr int qk = ggml_cuda_type_traits<type>::qk;
+    constexpr int qi = ggml_cuda_type_traits<type>::qi;
     switch (ncols_y) {
         case 1:
-            iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+            iqk_mul_mat_vec_q<qk, qi, vdr, vec_dot_q_cuda, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
             break;
         case 2:
-            iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+            iqk_mul_mat_vec_q<qk, qi, vdr, vec_dot_q_cuda, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
             break;
         case 3:
-            iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+            iqk_mul_mat_vec_q<qk, qi, vdr, vec_dot_q_cuda, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
             break;
         case 4:
-            iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+            iqk_mul_mat_vec_q<qk, qi, vdr, vec_dot_q_cuda, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
             break;
         case 5:
-            iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+            iqk_mul_mat_vec_q<qk, qi, vdr, vec_dot_q_cuda, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
             break;
         case 6:
-            iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+            iqk_mul_mat_vec_q<qk, qi, vdr, vec_dot_q_cuda, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
             break;
         case 7:
-            iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+            iqk_mul_mat_vec_q<qk, qi, vdr, vec_dot_q_cuda, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
             break;
         case 8:
-            iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+            iqk_mul_mat_vec_q<qk, qi, vdr, vec_dot_q_cuda, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
             break;
         default:
             GGML_ASSERT(false);
@@ -168,11 +168,194 @@ void iqk_mul_mat_vec_q_cuda(
     }
 }
-__device__ __forceinline__ float vec_dot_q4_0_r4_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
-    return 0;
-}
+//template<>
+//struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
+//    static constexpr int qk = QK4_0 = 32
+//    static constexpr int qr = QR4_0 = 2
+//    static constexpr int qi = QI4_0 = 4
+//};
+// #define VDR_Q4_0_Q8_1_MMVQ 2
+// #define VDR_Q4_0_Q8_1_MMQ 4
+// constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi = 2*nwarps*32/4 = 16*nwarps
+using block_q4_0_r4 = block_iq4_nl_x4;
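+// block_iq4_nl_x4 packs 4 interleaved rows per block of 32 columns: four half
+// scales in d[4] followed by 64 bytes of 4-bit quants (here holding q4_0 data,
+// i.e. unsigned nibbles with an implicit -8 offset).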
+__device__ __forceinline__ void vec_dot_q4_0_r4_q8_1_x(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ y, const int & kbx, const int & l, float * __restrict__ result) {
+    // Each thread processes 32 quants: 8 quants in each of the 4 interleaved rows.
+    const block_q4_0_r4 * x = (const block_q4_0_r4 *)vbq + kbx;
+    //const int l = kbx%4;
+    const half2 * d4h = (const half2 *)x->d;
+    float2 d4[2];
+    const float * d = (const float *)d4;
+    d4[0] = __half22float2(d4h[0]);
+    d4[1] = __half22float2(d4h[1]);
+    //const float d8 = __low2float(y->ds);
+    const float2 d8 = __half22float2(y->ds);
+    const int * q8 = (const int *)y->qs + 4*(l%2) + l/2;
+    const int * q4 = (const int *)x->qs + 4*l;
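+    // y->ds = {d8, s8} with s8 = d8 * sum(q8 quants). The offset 4*(l%2) + l/2
+    // places each thread on the two ints (q8[0] and q8[2]) holding the q8 values
+    // that correspond to the nibbles it reads from the interleaved q4 data.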
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    for (int k = 0; k < 4; ++k) {
+        // The __vsub4 variant below is kept for reference; the live code instead
+        // removes the +8 bias via the sum stored in the Q8_1 block.
+        //int v1 = __vsub4(q4[k] & 0x0f0f0f0f, 0x08080808);
+        //int v2 = __vsub4((q4[k] >> 4) & 0x0f0f0f0f, 0x08080808);
+        //int dot = __dp4a(v1, q8[0], __dp4a(v2, q8[2], 0));
+        //result[k] += d[k]*d8*dot;
+        int v1 = q4[k] & 0x0f0f0f0f;
+        int v2 = (q4[k] >> 4) & 0x0f0f0f0f;
+        int dot = __dp4a(v1, q8[0], __dp4a(v2, q8[2], 0));
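+        // dot = sum((q+8)*y8) over this thread's 8 value pairs. The +8 bias is
+        // removed through the stored sum: the 4 threads covering this block each
+        // subtract 2*d8.y = 2*d8.x*sum(y8), which totals exactly -8*d8.x*sum(y8).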
+        result[k] += d[k]*(d8.x*dot - 2.f*d8.y);
+    }
+#else
+    NO_DEVICE_CODE;
+#endif
+}
+template <int ncols_y>
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+// tell the compiler to use as many registers as it wants, see nwarps definition below
+__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__global__ void iqk_mul_mat_vec_q4_0_r4(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size) {
+    // constexpr int nwarps = 1;
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
+    constexpr int nwarps = 1;
+#else
+    constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
+    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    const int row0 = 4*blockIdx.x;
+    const int blocks_per_row_x = ncols_x / 32;
+    const int blocks_per_col_y = nrows_y / 32;
+    constexpr int blocks_per_iter = nwarps*WARP_SIZE/4;
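+    // Each 32-column block is split across 4 consecutive threads (tid%4), so one
+    // pass of the kbx loop covers nwarps*WARP_SIZE/4 blocks of the row group.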
+    // partial sums for each thread: ncols_y columns x 4 interleaved rows
+    float tmp[ncols_y][4] = {0.0f};
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+    for (int kbx = tid/4; kbx < blocks_per_row_x; kbx += blocks_per_iter) {
+#pragma unroll
+        for (int j = 0; j < ncols_y; ++j) {
+            vec_dot_q4_0_r4_q8_1_x((const void *)((const char *)vx + row0*row_size),
+                &y[j*blocks_per_col_y + kbx], kbx, tid%4, tmp[j]);
+        }
+    }
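+    // Warps 1..nwarps-1 stage their partial sums in shared memory; warp 0 then
+    // accumulates them and finishes with an intra-warp reduction.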
+    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][4][WARP_SIZE];
+    if (threadIdx.y > 0) {
+#pragma unroll
+        for (int j = 0; j < ncols_y; ++j) {
+#pragma unroll
+            for (int i = 0; i < 4; ++i) {
+                tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
+            }
+        }
+    }
+    __syncthreads();
+    if (threadIdx.y > 0) {
+        return;
+    }
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int j = 0; j < ncols_y; ++j) {
+#pragma unroll
+        for (int i = 0; i < 4; ++i) {
+#pragma unroll
+            for (int l = 0; l < nwarps-1; ++l) {
+                tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
+            }
+            tmp[j][i] = warp_reduce_sum(tmp[j][i]);
+        }
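+        // After warp_reduce_sum every lane holds the full sum, so the first four
+        // lanes each write one of the 4 interleaved rows of this row group.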
+        if (threadIdx.x < 4 && (row0 + threadIdx.x < nrows_dst)) {
+            dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
+        }
+    }
+}
+void iqk_mul_mat_vec_q4_0_r4_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    GGML_ASSERT(ncols_x % 32 == 0);
+    GGML_ASSERT(nrows_x % 4 == 0);
+    int id = ggml_cuda_get_device();
+    int nwarps = 1;
+    if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
+        switch (ncols_y) {
+            case 1:
+            case 2:
+            case 3:
+            case 4:
+                nwarps = 4;
+                break;
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                nwarps = 2;
+                break;
+            default:
+                GGML_ASSERT(false);
+                break;
+        }
+    }
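+    // Must match the compile-time nwarps selection inside the kernel.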
+    const int64_t nblocks = nrows_x/4;
+    const dim3 block_nums(nblocks, 1, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+    const int64_t row_size = ggml_row_size(GGML_TYPE_Q4_0_R4, ncols_x);
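+    // One CUDA block handles one group of 4 interleaved rows, hence nrows_x/4 blocks.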
+    switch (ncols_y) {
+        case 1:
+            iqk_mul_mat_vec_q4_0_r4<1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+            break;
+        case 2:
+            iqk_mul_mat_vec_q4_0_r4<2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+            break;
+        case 3:
+            iqk_mul_mat_vec_q4_0_r4<3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+            break;
+        case 4:
+            iqk_mul_mat_vec_q4_0_r4<4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+            break;
+        case 5:
+            iqk_mul_mat_vec_q4_0_r4<5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+            break;
+        case 6:
+            iqk_mul_mat_vec_q4_0_r4<6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+            break;
+        case 7:
+            iqk_mul_mat_vec_q4_0_r4<7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+            break;
+        case 8:
+            iqk_mul_mat_vec_q4_0_r4<8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+}
 __device__ __forceinline__ void get_int_from_table_16_shift(const uint32_t & q4, uint16_t shift, const uint8_t * all_values,
     int & val1, int & val2) {
@@ -736,8 +919,7 @@ __device__ __forceinline__ float vec_dot_iq2_bn_q8_1(
 void mul_mat_vec_q4_0_r4_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-    iqk_mul_mat_vec_q_cuda<GGML_TYPE_Q4_0_R4, VDR_IQ4_K_Q8_1_MMVQ, vec_dot_q4_0_r4_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    iqk_mul_mat_vec_q4_0_r4_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
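+    // q4_0_r4 now routes to its dedicated kernel instead of the generic
+    // iqk_mul_mat_vec_q path, whose vec_dot stub for this type returned 0.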
 }
 void mul_mat_vec_iq2_k_q8_1_cuda(