Adding q6_0: CUDA cpy, so Q6_0 can be used for KV-cache

This commit is contained in:
Iwan Kawrakow
2024-10-02 10:50:37 +03:00
parent 4cdf9b333f
commit c255a14a45
2 changed files with 53 additions and 0 deletions

View File

@@ -2242,6 +2242,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
if (s == "q5_1") {
return GGML_TYPE_Q5_1;
}
if (s == "q6_0") {
return GGML_TYPE_Q6_0;
}
throw std::runtime_error("Invalid cache type: " + s);
}

View File

@@ -221,6 +221,41 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
memcpy(dsti->qh, &qh, sizeof(qh));
}
static __device__ void cpy_blck_f32_q6_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q6_0 * dsti = (block_q6_0 *) cdsti;
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK6_0; ++j) {
const float v = xi[j];
const float av = fabsf(xi[j]);
if (amax < av) {
amax = av;
vmax = v;
}
}
const float d = vmax / -32;
const float id = d ? 1.0f/d : 0.0f;
dsti->d = d;
memset(dsti->qh, 0, QK6_0/4);
for (int j = 0; j < QK6_0/2; ++j) {
const float x0 = xi[0 + j]*id;
const float x1 = xi[QK4_0/2 + j]*id;
const uint8_t xi0 = min(63, (int8_t)(x0 + 32.5f));
const uint8_t xi1 = min(63, (int8_t)(x1 + 32.5f));
dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
const uint8_t h = (xi0 >> 4) | ((xi1 >> 4) << 2);
dsti->qh[j%(QK6_0/4)] |= (h << 4*(j/(QK6_0/4)));
}
}
static __device__ const int8_t iq4nl_index[241] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 17, 17, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 18, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
@@ -397,6 +432,17 @@ static void ggml_cpy_f32_q5_1_cuda(
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q6_0_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK6_0 == 0);
const int num_blocks = ne / QK6_0;
cpy_f32_q<cpy_blck_f32_q6_0, QK6_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_iq4_nl_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -466,6 +512,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q6_0) {
ggml_cpy_f32_q6_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
@@ -505,6 +553,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q6_0) {
return (void*) cpy_f32_q<cpy_blck_f32_q6_0, QK6_0>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {