iq4_xxs: CUDA dot product
We get TG-128 = 126 t/s for LLaMA-3.1-8B, compared to 123 t/s for q4_0.
@@ -529,6 +529,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ4_K> {
     static constexpr int qi = QI4_XS;
 };
 
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XXS> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR4_XS;
+    static constexpr int qi = QI4_XS;
+};
+
 template<>
 struct ggml_cuda_type_traits<GGML_TYPE_IQ5_K> {
     static constexpr int qk = QK_K;
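Editor's note (not part of the diff): qk is the number of weights per quantized block, and QK_K is 256 in ggml's k-quant layout; the new IQ4_XXS specialization simply reuses the IQ4_XS constants for qr and qi. A minimal sketch of how such traits are typically consumed, assuming only what the hunk above declares (blocks_per_row itself is illustrative, not part of the commit):

// Sketch only: derive the per-row block count from the traits above.
template <ggml_type type>
constexpr int blocks_per_row(int ncols) {
    // qk = weights per quantized block; QK_K = 256 for all k-quant types
    return ncols / ggml_cuda_type_traits<type>::qk;
}
// e.g. blocks_per_row<GGML_TYPE_IQ4_XXS>(4096) == 16 blocks per row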
@@ -220,30 +220,24 @@ __device__ __forceinline__ float vec_dot_iq4_k_q8_1(
 // TODO
 __device__ __forceinline__ float vec_dot_iq4_xxs_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
-    return 0.f;
-    //
-    // const block_iq4_k * bq4 = (const block_iq4_k *) vbq + kbx;
-    // const uint8_t * all_values = (const uint8_t *)iq4k_values;
-    //
-    // // iqs is 0...28
-    // const int ib32 = iqs/4;
-    // // Why iqs/4 ?
-    // const int32_t * q8 = (const int *)bq8_1[ib32].qs;
-    // const uint16_t * q4 = (const uint16_t *)bq4->qs + 8*ib32;
-    // const uint16_t extra = bq4->extra >> 2*ib32;
-    // int v1, v2;
-    // int sumi1 = 0, sumi2 = 0;
-    // for (int j = 0; j < 4; ++j) {
-    //     const uint32_t aux32 = q4[2*j+0] | (q4[2*j+1] << 16);
-    //     get_int_from_table_16_shift(aux32, extra, all_values, v1, v2);
-    //     sumi1 = ggml_cuda_dp4a(v1, q8[j+0], sumi1);
-    //     sumi2 = ggml_cuda_dp4a(v2, q8[j+4], sumi2);
-    // }
-    // const float d = __half2float(bq4->d) * __low2float(bq8_1[ib32].ds);
-    // const uint8_t sh = bq4->scales_h[ib32/2] >> 4*(ib32%2);
-    // const int ls1 = ((bq4->scales_l[ib32] & 0xf) | ((sh << 4) & 0x30)) - 32;
-    // const int ls2 = ((bq4->scales_l[ib32] >> 4) | ((sh << 2) & 0x30)) - 32;
-    // return d * (sumi1 * ls1 + sumi2 * ls2);
+
+    float scale = *(const float *)vbq;
+    const block_iq4_xxs * bq4 = (const block_iq4_xxs *)((const char *)vbq + sizeof(float)) + kbx;
+    const uint8_t * all_values = (const uint8_t *)iq4k_values;
+
+    // iqs is 0...28
+    const int ib32 = iqs/4; // Why iqs/4 ?
+    const int32_t * q8 = (const int *)bq8_1[ib32].qs;
+    const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
+    const float dl = scale * ((bq4->scales[ib32] & 254) - 127);
+    int v1, v2;
+    int sumi = 0;
+    for (int j = 0; j < 4; ++j) {
+        get_int_from_table_16_shift(q4[j], bq4->scales[ib32] & 1, all_values, v1, v2);
+        sumi = ggml_cuda_dp4a(v1, q8[j+0], sumi);
+        sumi = ggml_cuda_dp4a(v2, q8[j+4], sumi);
+    }
+    return dl * __low2float(bq8_1[ib32].ds) * sumi;
 }
 
 #define VDR_IQ5_K_Q8_1_MMVQ 4
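Editor's note on the scale decoding above: each 32-weight sub-block stores one byte in bq4->scales. The upper seven bits hold the block scale with an offset of 127, hence (scales[ib32] & 254) - 127, and the low bit is passed to get_int_from_table_16_shift to pick a variant of the non-linear value table. A scalar reference for one sub-block, mirroring the dp4a loop; it assumes the shift bit selects all_values versus all_values + 16, which is an assumption for illustration since the helper's body is not shown in this diff:

// Scalar sketch of one 32-weight sub-block, for illustration only.
// Assumption: shift selects the second half of a 32-entry value table.
static int subblock_dot_ref(const uint32_t * q4,      // 4 words = 32 nibbles
                            const int8_t   * q8,      // 32 int8 activations
                            const uint8_t  * all_values, int shift) {
    const uint8_t * values = all_values + 16*shift;
    int sumi = 0;
    for (int j = 0; j < 4; ++j) {
        for (int k = 0; k < 4; ++k) {
            const uint8_t b = (q4[j] >> 8*k) & 0xff;
            sumi += (int)(int8_t)values[b & 0xf] * q8[4*j + k];      // v1 lanes: low nibbles vs q8[j+0]
            sumi += (int)(int8_t)values[b >> 4]  * q8[16 + 4*j + k]; // v2 lanes: high nibbles vs q8[j+4]
        }
    }
    return sumi;
}

As a worked example of the decoding: with scale = 0.01f and scales[ib32] = 0x85, (0x85 & 254) - 127 = 132 - 127 = 5, so dl = 0.05f, and the low bit (1) selects the shifted table.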
@@ -648,7 +642,7 @@ void mul_mat_vec_iq4_xxs_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
 
-    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_K, VDR_IQ4_K_Q8_1_MMVQ, vec_dot_iq4_xxs_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XXS, VDR_IQ4_XXS_Q8_1_MMVQ, vec_dot_iq4_xxs_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }
 
 void mul_mat_vec_iq5_k_q8_1_cuda(
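Editor's note on this last hunk: the only functional change is the pair of template arguments. The launcher appears to have been copied from the IQ4_K version and still instantiated the generic MMVQ kernel with GGML_TYPE_IQ4_K and VDR_IQ4_K_Q8_1_MMVQ; it now uses GGML_TYPE_IQ4_XXS and VDR_IQ4_XXS_Q8_1_MMVQ, so the generic kernel indexes blocks with the IQ4_XXS layout declared in the first hunk while calling the dot product implemented in the second.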