mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-03 02:20:01 +00:00
CUDA: iq4_k_r4 GEMV
~10% slower than iq4_k.
This commit is contained in:
@@ -759,52 +759,19 @@ static __global__ void dequantize_block_iq4_k_r4(const void * __restrict__ vx, d
|
||||
|
||||
int64_t ii = blockIdx.x;
|
||||
|
||||
//int64_t nblock = n_per_row/256;
|
||||
//int64_t row = ii/nblock;
|
||||
//int64_t ibl = ii - row*nblock;
|
||||
//int64_t row4 = row/4;
|
||||
//int64_t ir = row4%4;
|
||||
|
||||
//const block_iq4_k_r4 * x = (const block_iq4_k_r4 *)vx + row4*nblock;
|
||||
|
||||
//int64_t row4 = (256*ii)/(4*n_per_row); // rows of 4 index
|
||||
//int64_t ibl = ii - row4*n_per_row/64; // block index within the rows of 4
|
||||
//int64_t ir = row4%4; // row
|
||||
|
||||
//int64_t ibl = ii/4;
|
||||
//int ir = ii%4;
|
||||
int64_t nblock = n_per_row/256;
|
||||
int64_t row = ii/nblock;
|
||||
int64_t row4 = row/4;
|
||||
int64_t ir = row%4;
|
||||
int64_t ibl = row4*nblock + ii%nblock;
|
||||
// ii = 0 -> row = 0, row4 = 0, ir = 0, ibl should be 0
|
||||
// ii = 1 -> row = 0, row4 = 0, ir = 0, ibl should be 1
|
||||
// ii = 2 -> row = 0, row4 = 0, ir = 0, ibl should be 2
|
||||
// ...
|
||||
// ii = 16 -> row = 1, row4 = 0, ir = 1, ibl should be 0
|
||||
// ..
|
||||
// ii = 64 -> row = 4, row4 = 1, ir = 0, ibl should be 16
|
||||
|
||||
|
||||
const block_iq4_k_r4 * x = (const block_iq4_k_r4 *)vx;
|
||||
////const block_iq4_k_r4 * x = (const block_iq4_k_r4 *)((const char *)vx + 4*row4*row_size);
|
||||
|
||||
// Say, we have rows of 4096, and we have 8 rows -> 4096*8/256 = 128 blocks of 256, 16 blocks per row
|
||||
// ii = 0 -> ibl = 0, ir = 0 -> warp processes 0...255 in row 0
|
||||
// ii = 1 -> ibl = 0, ir = 1 -> warp processes 0...255 in row 1
|
||||
// ii = 2 -> ibl = 0, ir = 2 -> warp processes 0...255 in row 2
|
||||
// ii = 3 -> ibl = 0, ir = 3 -> warp processes 0...255 in row 3
|
||||
// ii = 4 -> ibl = 1, ir = 0 -> warp processes 256...511 in row 0
|
||||
// ii = 5 -> ibl = 1, ir = 1 -> warp processes 256...511 in row 1
|
||||
// ii = 6 -> ibl = 1, ir = 2 -> warp processes 256...511 in row 2
|
||||
// ii = 7 -> ibl = 1, ir = 3 -> warp processes 256...511 in row 3
|
||||
// ...
|
||||
// ii = 63 -> ibl = 15, ir = 3 -> warp processes 3840...4096 in row 3
|
||||
// ii = 64 -> ibl = 16, ir = 0 -> warp processes 0...255 in row 4, so offset is 4*4096 = 4*16*256
|
||||
const int tid = threadIdx.x;
|
||||
const int il = tid/8; // 0...3
|
||||
const int ib = tid%8; // 0...7
|
||||
|
||||
const block_iq4_k_r4 * x = (const block_iq4_k_r4 *)vx;
|
||||
dst_t * y = yy + 256*ii + 32*ib;
|
||||
|
||||
const float d = __half2float(x[ibl].d[ir]);
|
||||
int is = 8*ib + ir;
|
||||
float dl1 = d * ((((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) | (((x[ibl].scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32);
|
||||
@@ -812,9 +779,6 @@ static __global__ void dequantize_block_iq4_k_r4(const void * __restrict__ vx, d
|
||||
float dl2 = d * ((((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) | (((x[ibl].scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32);
|
||||
auto values1 = iq4k_values + (((x[ibl].extra[ir+0] >> ib) & 1) << 4);
|
||||
auto values2 = iq4k_values + (((x[ibl].extra[ir+4] >> ib) & 1) << 4);
|
||||
dst_t * y = yy + 256*ii + 32*ib;
|
||||
//dst_t * y = yy + (4*row4 + ir)*n_per_row + ibl*QK_K + 32*ib;
|
||||
//dst_t * y = yy + ir*n_per_row + 4*ibl*QK_K + 32*ib;
|
||||
auto qs = x[ibl].qs + 64*ib + 4*ir;
|
||||
if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
|
||||
y[il+ 0] = __float2bfloat16(dl1 * values1[qs[il+ 0] & 0xf]);
|
||||
|
||||
@@ -6,7 +6,15 @@
|
||||
|
||||
#include "iqk_mmvq.cuh"
|
||||
|
||||
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
|
||||
typedef void (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float *);
|
||||
|
||||
template<>
|
||||
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_K_R4> {
|
||||
static constexpr int qk = QK_K;
|
||||
static constexpr int qr = QR4_XS;
|
||||
static constexpr int qi = QI4_XS;
|
||||
};
|
||||
|
||||
|
||||
// Reminder:
|
||||
// constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||
@@ -14,7 +22,7 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_
|
||||
// constexpr int vdr = get_vdr_mmvq(type);
|
||||
|
||||
namespace {
|
||||
template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y>
|
||||
template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y, int n_interleaved = 1>
|
||||
__device__ void iqk_mul_mat_vec_q(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size) {
|
||||
@@ -24,10 +32,10 @@ __device__ void iqk_mul_mat_vec_q(
|
||||
|
||||
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
|
||||
constexpr int nwarps = 1;
|
||||
constexpr int rows_per_cuda_block = 1;
|
||||
constexpr int rows_per_cuda_block = n_interleaved;
|
||||
#else
|
||||
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
|
||||
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
|
||||
constexpr int rows_per_cuda_block = n_interleaved == 1 ? ncols_y == 1 ? 1 : 2 : n_interleaved;
|
||||
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
|
||||
|
||||
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
||||
@@ -49,10 +57,15 @@ __device__ void iqk_mul_mat_vec_q(
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_y; ++j) {
|
||||
if constexpr (n_interleaved == 1) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < rows_per_cuda_block; ++i) {
|
||||
tmp[j][i] += vec_dot_q_cuda((const void *)((const char *)vx + (row0 + i)*row_size),
|
||||
&y[j*blocks_per_col_y + kby], kbx, kqs);
|
||||
for (int i = 0; i < rows_per_cuda_block; ++i) {
|
||||
vec_dot_q_cuda((const void *)((const char *)vx + (row0 + i)*row_size),
|
||||
&y[j*blocks_per_col_y + kby], kbx, kqs, &tmp[j][i]);
|
||||
}
|
||||
} else {
|
||||
vec_dot_q_cuda((const void *)((const char *)vx + row0*row_size),
|
||||
&y[j*blocks_per_col_y + kby], kbx, kqs, tmp[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -90,7 +103,7 @@ __device__ void iqk_mul_mat_vec_q(
|
||||
}
|
||||
}
|
||||
|
||||
template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y>
|
||||
template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y, int n_interleaved = 1>
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
// tell the compiler to use as many registers as it wants, see nwarps definition below
|
||||
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
|
||||
@@ -105,10 +118,10 @@ __global__ void iqk_mul_mat_vec_q(
|
||||
const char * cx = (const char *)vx + i02*nb02;
|
||||
const char * cy = (const char *)vy + i2*nb12;
|
||||
char * cdst = (char *)dst + i2*nb2;
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, ncols_y>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, ncols_y, n_interleaved>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
|
||||
}
|
||||
|
||||
template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
|
||||
template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int n_interleaved = 1>
|
||||
void iqk_mul_mat_vec_q_cuda(
|
||||
const void * vx, const void * vy, float * dst, const char * ids_data,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
|
||||
@@ -120,26 +133,26 @@ void iqk_mul_mat_vec_q_cuda(
|
||||
int id = ggml_cuda_get_device();
|
||||
|
||||
int64_t nwarps = 1;
|
||||
int64_t rows_per_cuda_block = 1;
|
||||
int64_t rows_per_cuda_block = n_interleaved;
|
||||
|
||||
if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
|
||||
switch(ncols_y) {
|
||||
case 1:
|
||||
nwarps = 4;
|
||||
rows_per_cuda_block = 1;
|
||||
rows_per_cuda_block = n_interleaved == 1 ? 1 : n_interleaved;
|
||||
break;
|
||||
case 2:
|
||||
case 3:
|
||||
case 4:
|
||||
nwarps = 4;
|
||||
rows_per_cuda_block = 2;
|
||||
rows_per_cuda_block = n_interleaved == 1 ? 2 : n_interleaved;
|
||||
break;
|
||||
case 5:
|
||||
case 6:
|
||||
case 7:
|
||||
case 8:
|
||||
nwarps = 2;
|
||||
rows_per_cuda_block = 2;
|
||||
rows_per_cuda_block = n_interleaved == 1 ? 2 : n_interleaved;
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
@@ -154,28 +167,28 @@ void iqk_mul_mat_vec_q_cuda(
|
||||
|
||||
switch (ncols_y) {
|
||||
case 1:
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0);
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 1, n_interleaved><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0);
|
||||
break;
|
||||
case 2:
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0);
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 2, n_interleaved><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0);
|
||||
break;
|
||||
case 3:
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0);
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 3, n_interleaved><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0);
|
||||
break;
|
||||
case 4:
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0);
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 4, n_interleaved><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0);
|
||||
break;
|
||||
case 5:
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0);
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 5, n_interleaved><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0);
|
||||
break;
|
||||
case 6:
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0);
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 6, n_interleaved><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0);
|
||||
break;
|
||||
case 7:
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0);
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 7, n_interleaved><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0);
|
||||
break;
|
||||
case 8:
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0);
|
||||
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 8, n_interleaved><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0);
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
@@ -202,8 +215,8 @@ __device__ __forceinline__ void get_int_from_table_16_shift(const uint32_t & q4,
|
||||
#define VDR_IQ4_K_Q8_1_MMVQ 4
|
||||
#define VDR_IQ4_K_Q8_1_MMQ 4
|
||||
|
||||
__device__ __forceinline__ float vec_dot_iq4_k_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||
__device__ __forceinline__ void vec_dot_iq4_k_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
|
||||
|
||||
const block_iq4_k * bq4 = (const block_iq4_k *) vbq + kbx;
|
||||
const uint8_t * all_values = (const uint8_t *)iq4k_values;
|
||||
@@ -226,44 +239,60 @@ __device__ __forceinline__ float vec_dot_iq4_k_q8_1(
|
||||
const uint8_t sh = bq4->scales_h[ib32/2] >> 4*(ib32%2);
|
||||
const int ls1 = ((bq4->scales_l[ib32] & 0xf) | ((sh << 4) & 0x30)) - 32;
|
||||
const int ls2 = ((bq4->scales_l[ib32] >> 4) | ((sh << 2) & 0x30)) - 32;
|
||||
return d * (sumi1 * ls1 + sumi2 * ls2);
|
||||
*result += d * (sumi1 * ls1 + sumi2 * ls2);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * values) {
|
||||
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
|
||||
const int8_t * q0_8 = (const int8_t *) &q0_32;
|
||||
const char4 val0_8 = make_char4(values[q0_8[0]], values[q0_8[1]], values[q0_8[2]], values[q0_8[3]]);
|
||||
|
||||
const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
|
||||
const int8_t * q1_8 = (const int8_t *) &q1_32;
|
||||
const char4 val1_8 = make_char4(values[q1_8[0]], values[q1_8[1]], values[q1_8[2]], values[q1_8[3]]);
|
||||
|
||||
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
|
||||
}
|
||||
|
||||
// TODO
|
||||
__device__ __forceinline__ float vec_dot_iq4_k_r4_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||
__device__ __forceinline__ void vec_dot_iq4_k_r4_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
|
||||
|
||||
return 0.f;
|
||||
|
||||
const block_iq4_k * bq4 = (const block_iq4_k *) vbq + kbx;
|
||||
const uint8_t * all_values = (const uint8_t *)iq4k_values;
|
||||
const block_iq4_k_r4 * bq4 = (const block_iq4_k_r4 *)vbq + kbx;
|
||||
|
||||
// iqs is 0...28
|
||||
const int ib32 = iqs/4;
|
||||
// Why iqs/4 ?
|
||||
const float d8 = __low2float(bq8_1[ib32].ds);
|
||||
const int32_t * q8 = (const int *)bq8_1[ib32].qs;
|
||||
const uint16_t * q4 = (const uint16_t *)bq4->qs + 8*ib32;
|
||||
const uint16_t extra = bq4->extra >> 2*ib32;
|
||||
int v1, v2;
|
||||
int sumi1 = 0, sumi2 = 0;
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
const uint32_t aux32 = q4[2*j+0] | (q4[2*j+1] << 16);
|
||||
get_int_from_table_16_shift(aux32, extra, all_values, v1, v2);
|
||||
sumi1 = ggml_cuda_dp4a(v1, q8[j+0], sumi1);
|
||||
sumi2 = ggml_cuda_dp4a(v2, q8[j+4], sumi2);
|
||||
|
||||
int2 val1, val2;
|
||||
const int * q4 = (const int *)bq4->qs + 16*ib32;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
auto values1 = iq4k_values + (((bq4->extra[i+0] >> ib32) & 1) << 4);
|
||||
auto values2 = iq4k_values + (((bq4->extra[i+4] >> ib32) & 1) << 4);
|
||||
int sumi1 = 0, sumi2 = 0;
|
||||
val1 = get_int_from_table_16(q4[i+ 0], values1);
|
||||
sumi1 = ggml_cuda_dp4a(val1.x, q8[0], ggml_cuda_dp4a(val1.y, q8[2], sumi1));
|
||||
val2 = get_int_from_table_16(q4[i+ 4], values2);
|
||||
sumi2 = ggml_cuda_dp4a(val2.x, q8[4], ggml_cuda_dp4a(val2.y, q8[6], sumi2));
|
||||
val1 = get_int_from_table_16(q4[i+ 8], values1);
|
||||
sumi1 = ggml_cuda_dp4a(val1.x, q8[1], ggml_cuda_dp4a(val1.y, q8[3], sumi1));
|
||||
val2 = get_int_from_table_16(q4[i+12], values2);
|
||||
sumi2 = ggml_cuda_dp4a(val2.x, q8[5], ggml_cuda_dp4a(val2.y, q8[7], sumi2));
|
||||
int is = 8*ib32 + i;
|
||||
int ls1 = (((bq4->scales_l[is%32] >> 4*(is/32)) & 0xf) | (((bq4->scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32;
|
||||
is += 4;
|
||||
int ls2 = (((bq4->scales_l[is%32] >> 4*(is/32)) & 0xf) | (((bq4->scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32;
|
||||
const float d = __half2float(bq4->d[i]) * d8;
|
||||
result[i] += d * (sumi1 * ls1 + sumi2 * ls2);
|
||||
}
|
||||
const float d = __half2float(bq4->d) * __low2float(bq8_1[ib32].ds);
|
||||
const uint8_t sh = bq4->scales_h[ib32/2] >> 4*(ib32%2);
|
||||
const int ls1 = ((bq4->scales_l[ib32] & 0xf) | ((sh << 4) & 0x30)) - 32;
|
||||
const int ls2 = ((bq4->scales_l[ib32] >> 4) | ((sh << 2) & 0x30)) - 32;
|
||||
return d * (sumi1 * ls1 + sumi2 * ls2);
|
||||
}
|
||||
|
||||
#define VDR_IQ4_KS_Q8_1_MMVQ 4
|
||||
#define VDR_IQ4_KS_Q8_1_MMQ 4
|
||||
|
||||
__device__ __forceinline__ float vec_dot_iq4_ks_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||
__device__ __forceinline__ void vec_dot_iq4_ks_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
|
||||
|
||||
float scale = *(const float *)vbq;
|
||||
const block_iq4_ks * bq4 = (const block_iq4_ks *)((const char *)vbq + sizeof(float)) + kbx;
|
||||
@@ -281,14 +310,14 @@ __device__ __forceinline__ float vec_dot_iq4_ks_q8_1(
|
||||
sumi = ggml_cuda_dp4a(v1, q8[j+0], sumi);
|
||||
sumi = ggml_cuda_dp4a(v2, q8[j+4], sumi);
|
||||
}
|
||||
return dl * __low2float(bq8_1[ib32].ds) * sumi;
|
||||
*result += dl * __low2float(bq8_1[ib32].ds) * sumi;
|
||||
}
|
||||
|
||||
#define VDR_IQ4_KSS_Q8_1_MMVQ 4
|
||||
#define VDR_IQ4_KSS_Q8_1_MMQ 4
|
||||
|
||||
__device__ __forceinline__ float vec_dot_iq4_kss_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||
__device__ __forceinline__ void vec_dot_iq4_kss_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
|
||||
|
||||
float scale = *(const float *)vbq;
|
||||
const block_iq4_kss * bq4 = (const block_iq4_kss *)((const char *)vbq + sizeof(float)) + kbx;
|
||||
@@ -310,7 +339,7 @@ __device__ __forceinline__ float vec_dot_iq4_kss_q8_1(
|
||||
sumi = ggml_cuda_dp4a(v1, q8[j+0], sumi);
|
||||
sumi = ggml_cuda_dp4a(v2, q8[j+4], sumi);
|
||||
}
|
||||
return dl * __low2float(bq8_1[ib32].ds) * sumi;
|
||||
*result += dl * __low2float(bq8_1[ib32].ds) * sumi;
|
||||
}
|
||||
|
||||
#define VDR_IQ5_K_Q8_1_MMVQ 4
|
||||
@@ -322,9 +351,8 @@ __device__ __forceinline__ int int_from_table(const uint8_t * a8, const uint8_t
|
||||
return v1 | (v2 << 16);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float vec_dot_iq5_k_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||
|
||||
__device__ __forceinline__ void vec_dot_iq5_k_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
|
||||
|
||||
const block_iq5_k * bq5 = (const block_iq5_k *) vbq + kbx;
|
||||
const uint8_t * all_values = (const uint8_t *)iq5nl_values;
|
||||
@@ -355,11 +383,11 @@ __device__ __forceinline__ float vec_dot_iq5_k_q8_1(
|
||||
const uint8_t sh = bq5->scales_h[i4/2] >> 2*(i4%2);
|
||||
const int ls1 = (((bq5->scales_l[2*(i4/2)+0] >> 4*(i4%2)) & 0xf) | ((sh << 4) & 0x30)) - 32;
|
||||
const int ls2 = (((bq5->scales_l[2*(i4/2)+1] >> 4*(i4%2)) & 0xf) | ((sh << 0) & 0x30)) - 32;
|
||||
return d5 * (__low2float(bq8_1[2*(i4/2)+0].ds) * sumi1 * ls1 + __low2float(bq8_1[2*(i4/2)+1].ds) * sumi2 * ls2);
|
||||
*result += d5 * (__low2float(bq8_1[2*(i4/2)+0].ds) * sumi1 * ls1 + __low2float(bq8_1[2*(i4/2)+1].ds) * sumi2 * ls2);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float vec_dot_iq5_ks_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||
__device__ __forceinline__ void vec_dot_iq5_ks_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
|
||||
|
||||
float scale = *(const float *)vbq;
|
||||
const block_iq5_ks * bq5 = (const block_iq5_ks *)((const char *)vbq + sizeof(float)) + kbx;
|
||||
@@ -388,15 +416,14 @@ __device__ __forceinline__ float vec_dot_iq5_ks_q8_1(
|
||||
}
|
||||
const int ls1 = (bq5->scales[2*(i4/2)+0] & 254) - 127;
|
||||
const int ls2 = (bq5->scales[2*(i4/2)+1] & 254) - 127;
|
||||
return scale * (__low2float(bq8_1[2*(i4/2)+0].ds) * sumi1 * ls1 + __low2float(bq8_1[2*(i4/2)+1].ds) * sumi2 * ls2);
|
||||
*result += scale * (__low2float(bq8_1[2*(i4/2)+0].ds) * sumi1 * ls1 + __low2float(bq8_1[2*(i4/2)+1].ds) * sumi2 * ls2);
|
||||
}
|
||||
|
||||
#define VDR_IQ6_K_Q8_1_MMVQ 4
|
||||
#define VDR_IQ6_K_Q8_1_MMQ 4
|
||||
|
||||
__device__ __forceinline__ float vec_dot_iq6_k_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||
|
||||
__device__ __forceinline__ void vec_dot_iq6_k_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
|
||||
|
||||
const block_iq6_k * bq6 = (const block_iq6_k *) vbq + kbx;
|
||||
const uint8_t * all_values = (const uint8_t *)iq6nl_values;
|
||||
@@ -425,7 +452,7 @@ __device__ __forceinline__ float vec_dot_iq6_k_q8_1(
|
||||
sumi2 = ggml_cuda_dp4a(v2, q8_2[j], sumi2);
|
||||
}
|
||||
const float d6 = __half2float(bq6->d);
|
||||
return d6 * (__low2float(bq8_1[2*(i4/2)+0].ds) * sumi1 * bq6->scales[4*(i4/2)+(i4%2)] + __low2float(bq8_1[2*(i4/2)+1].ds) * sumi2 * bq6->scales[4*(i4/2)+(i4%2)+2]);
|
||||
*result += d6 * (__low2float(bq8_1[2*(i4/2)+0].ds) * sumi1 * bq6->scales[4*(i4/2)+(i4%2)] + __low2float(bq8_1[2*(i4/2)+1].ds) * sumi2 * bq6->scales[4*(i4/2)+(i4%2)+2]);
|
||||
}
|
||||
|
||||
static const __device__ uint32_t iq2k_table[512] = {
|
||||
@@ -502,8 +529,8 @@ __device__ __forceinline__ int int_from_table_4(const uint8_t * a8, const int *
|
||||
#define VDR_IQ2_K_Q8_1_MMVQ 4
|
||||
#define VDR_IQ2_K_Q8_1_MMQ 4
|
||||
|
||||
__device__ __forceinline__ float vec_dot_iq2_k_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||
__device__ __forceinline__ void vec_dot_iq2_k_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
|
||||
|
||||
// iqs is 0, 4, 8, 12, 16, 20, 24, 28
|
||||
// we have 16 packed quants (when cast to int)
|
||||
@@ -554,18 +581,17 @@ __device__ __forceinline__ float vec_dot_iq2_k_q8_1(
|
||||
v2 = int_from_table_4(a8 + 4, values);
|
||||
int sumi4 = ggml_cuda_dp4a(v2, q8_4[1], ggml_cuda_dp4a(v1, q8_4[0], 0)) * s8[3];
|
||||
|
||||
return __half2float(bq2->d) * (__low2float(bq8_1[4*(i4/4)+0].ds) * sumi1
|
||||
+ __low2float(bq8_1[4*(i4/4)+1].ds) * sumi2
|
||||
+ __low2float(bq8_1[4*(i4/4)+2].ds) * sumi3
|
||||
+ __low2float(bq8_1[4*(i4/4)+3].ds) * sumi4);
|
||||
|
||||
*result += __half2float(bq2->d) * (__low2float(bq8_1[4*(i4/4)+0].ds) * sumi1
|
||||
+ __low2float(bq8_1[4*(i4/4)+1].ds) * sumi2
|
||||
+ __low2float(bq8_1[4*(i4/4)+2].ds) * sumi3
|
||||
+ __low2float(bq8_1[4*(i4/4)+3].ds) * sumi4);
|
||||
}
|
||||
|
||||
#define VDR_IQ2_KS_Q8_1_MMVQ 4
|
||||
#define VDR_IQ2_KS_Q8_1_MMQ 4
|
||||
|
||||
__device__ __forceinline__ float vec_dot_iq2_ks_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||
__device__ __forceinline__ void vec_dot_iq2_ks_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
|
||||
|
||||
float scale = *(const half *)vbq;
|
||||
const block_iq2_ks * bq2 = (const block_iq2_ks *)((const char *)vbq + sizeof(half)) + kbx;
|
||||
@@ -614,10 +640,10 @@ __device__ __forceinline__ float vec_dot_iq2_ks_q8_1(
|
||||
v2 = int_from_table_4(a8 + 4, values);
|
||||
int sumi4 = ggml_cuda_dp4a(v2, q8_4[1], ggml_cuda_dp4a(v1, q8_4[0], 0)) * s8[3];
|
||||
|
||||
return scale * (__low2float(bq8_1[4*(i4/4)+0].ds) * sumi1
|
||||
+ __low2float(bq8_1[4*(i4/4)+1].ds) * sumi2
|
||||
+ __low2float(bq8_1[4*(i4/4)+2].ds) * sumi3
|
||||
+ __low2float(bq8_1[4*(i4/4)+3].ds) * sumi4);
|
||||
*result += scale * (__low2float(bq8_1[4*(i4/4)+0].ds) * sumi1
|
||||
+ __low2float(bq8_1[4*(i4/4)+1].ds) * sumi2
|
||||
+ __low2float(bq8_1[4*(i4/4)+2].ds) * sumi3
|
||||
+ __low2float(bq8_1[4*(i4/4)+3].ds) * sumi4);
|
||||
}
|
||||
|
||||
#define VDR_IQ3_K_Q8_1_MMVQ 4
|
||||
@@ -638,8 +664,8 @@ __device__ __forceinline__ int int_from_table_2(const uint8_t * a8, const uint16
|
||||
return values[a8[0] | (a8[1] << 3)] | (values[a8[2] | (a8[3] << 3)] << 16);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float vec_dot_iq3_k_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iiqs) {
|
||||
__device__ __forceinline__ void vec_dot_iq3_k_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iiqs, float * result) {
|
||||
const block_iq3_k * bq3 = (const block_iq3_k *) vbq + kbx;
|
||||
|
||||
int iqs = iiqs/4;
|
||||
@@ -697,15 +723,15 @@ __device__ __forceinline__ float vec_dot_iq3_k_q8_1(
|
||||
const float d = __half2float(bq3->d);
|
||||
const uint16_t * sl16 = (const uint16_t *)bq3->scales_l + 2*ib128;
|
||||
aux32 = ((((sl16[0] | (sl16[1] << 16)) >> shift) & 0x0f0f0f0f) << 1) | 0x01010101;
|
||||
return d * (__low2float(bq8_1[4*ib128+0].ds) * aux8[0] * (sh & 0x01 ? -1 : 1) * sumi[0] +
|
||||
__low2float(bq8_1[4*ib128+1].ds) * aux8[1] * (sh & 0x04 ? -1 : 1) * sumi[1] +
|
||||
__low2float(bq8_1[4*ib128+2].ds) * aux8[2] * (sh & 0x10 ? -1 : 1) * sumi[2] +
|
||||
__low2float(bq8_1[4*ib128+3].ds) * aux8[3] * (sh & 0x40 ? -1 : 1) * sumi[3]);
|
||||
*result += d * (__low2float(bq8_1[4*ib128+0].ds) * aux8[0] * (sh & 0x01 ? -1 : 1) * sumi[0] +
|
||||
__low2float(bq8_1[4*ib128+1].ds) * aux8[1] * (sh & 0x04 ? -1 : 1) * sumi[1] +
|
||||
__low2float(bq8_1[4*ib128+2].ds) * aux8[2] * (sh & 0x10 ? -1 : 1) * sumi[2] +
|
||||
__low2float(bq8_1[4*ib128+3].ds) * aux8[3] * (sh & 0x40 ? -1 : 1) * sumi[3]);
|
||||
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||
__device__ __forceinline__ void vec_dot_iq1_bn_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
|
||||
|
||||
half d16; memcpy(&d16, vbq, sizeof(d16));
|
||||
float scale = d16;
|
||||
@@ -739,7 +765,7 @@ __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
|
||||
sumi = __dp4a(val[0], q8[4*l+0], __dp4a(val[1], q8[4*l+1], __dp4a(val[2], q8[4*l+2], __dp4a(val[3], q8[4*l+3], sumi))));
|
||||
}
|
||||
float2 d8 = __half22float2(bq8_1[iqs].ds);
|
||||
return scale * (d8.x * sumi - d8.y);
|
||||
*result += scale * (d8.x * sumi - d8.y);
|
||||
#else
|
||||
static const uint16_t k_mult[5] = {81, 27, 9, 3, 1};
|
||||
const int8_t * q8 = bq8_1[iqs].qs;
|
||||
@@ -759,12 +785,12 @@ __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
|
||||
sumi += q8[0]*(vs - 1);
|
||||
q8++;
|
||||
}
|
||||
return scale * __low2float(bq8_1[iqs].ds) * sumi;
|
||||
*result += scale * __low2float(bq8_1[iqs].ds) * sumi;
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float vec_dot_iq2_bn_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||
__device__ __forceinline__ void vec_dot_iq2_bn_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
|
||||
|
||||
float scale = *(const float *)vbq;
|
||||
const block_iq2_bn * bq2 = (const block_iq2_bn *)((const char *)vbq + sizeof(float)) + kbx;
|
||||
@@ -786,7 +812,7 @@ __device__ __forceinline__ float vec_dot_iq2_bn_q8_1(
|
||||
}
|
||||
auto d8l = __half22float2(bq8_1[0].ds);
|
||||
auto d8h = __half22float2(bq8_1[1].ds);
|
||||
return scale * (d8l.x * (sumi1 + 0.25f*sumi2) + d8h.x * (sumi3 + 0.25f * sumi4) - 0.5f*d8l.y - 0.5f*d8h.y);
|
||||
*result += scale * (d8l.x * (sumi1 + 0.25f*sumi2) + d8h.x * (sumi3 + 0.25f * sumi4) - 0.5f*d8l.y - 0.5f*d8h.y);
|
||||
#else
|
||||
int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
|
||||
auto q8l = bq8_1[0].qs + 8*iqs;
|
||||
@@ -800,7 +826,7 @@ __device__ __forceinline__ float vec_dot_iq2_bn_q8_1(
|
||||
}
|
||||
auto d8l = __half22float2(bq8_1[0].ds);
|
||||
auto d8h = __half22float2(bq8_1[1].ds);
|
||||
return scale * (d8l.x * (sumi1 + 0.25f*sumi2) + 0.0625f * d8h.x*(sumi3 + 0.25f*sumi4) - 0.5f*d8l.y - 0.5f*d8h.y);
|
||||
*result += scale * (d8l.x * (sumi1 + 0.25f*sumi2) + 0.0625f * d8h.x*(sumi3 + 0.25f*sumi4) - 0.5f*d8l.y - 0.5f*d8h.y);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -835,7 +861,7 @@ void mul_mat_vec_iq4_k_r4_q8_1_cuda(
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
|
||||
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
|
||||
|
||||
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_K, VDR_IQ4_K_Q8_1_MMVQ, vec_dot_iq4_k_r4_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
|
||||
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_K_R4, VDR_IQ4_K_Q8_1_MMVQ, vec_dot_iq4_k_r4_q8_1, 4>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
|
||||
}
|
||||
|
||||
void mul_mat_vec_iq4_ks_q8_1_cuda(
|
||||
|
||||
@@ -542,6 +542,9 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm
|
||||
case GGML_TYPE_IQ3_S:
|
||||
mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_K_R4:
|
||||
mul_mat_vec_iq4_k_r4_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
@@ -655,6 +658,7 @@ bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) {
|
||||
case GGML_TYPE_IQ5_KS:
|
||||
case GGML_TYPE_IQ6_K:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ4_K_R4:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
||||
Reference in New Issue
Block a user