mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 19:31:48 +00:00
Adding q6_0: CUDA cpy, so Q6_0 can be used for KV-cache
This commit is contained in:
@@ -2242,6 +2242,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
|
|||||||
if (s == "q5_1") {
|
if (s == "q5_1") {
|
||||||
return GGML_TYPE_Q5_1;
|
return GGML_TYPE_Q5_1;
|
||||||
}
|
}
|
||||||
|
if (s == "q6_0") {
|
||||||
|
return GGML_TYPE_Q6_0;
|
||||||
|
}
|
||||||
|
|
||||||
throw std::runtime_error("Invalid cache type: " + s);
|
throw std::runtime_error("Invalid cache type: " + s);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -221,6 +221,41 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
|
|||||||
memcpy(dsti->qh, &qh, sizeof(qh));
|
memcpy(dsti->qh, &qh, sizeof(qh));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Quantize one block of QK6_0 float values at cxi into a block_q6_0 at cdsti.
// The scale d is chosen so that the max-magnitude input maps to -32, using the
// full signed 6-bit range [-32, 31]. Each quant q (0..63 after the +32 bias) is
// stored as its 4 low bits in qs (two values per byte) and its 2 high bits
// packed into qh (four values per byte).
// Preconditions: cxi points to QK6_0 floats; cdsti points to a block_q6_0.
static __device__ void cpy_blck_f32_q6_0(const char * cxi, char * cdsti) {
    const float * xi = (const float *) cxi;
    block_q6_0 * dsti = (block_q6_0 *) cdsti;

    // Track the max-magnitude element, keeping its sign in vmax so the
    // scale direction matches the dominant value.
    float amax = 0.0f;
    float vmax = 0.0f;

    for (int j = 0; j < QK6_0; ++j) {
        const float v = xi[j];
        const float av = fabsf(xi[j]);
        if (amax < av) {
            amax = av;
            vmax = v;
        }
    }

    const float d = vmax / -32;                // map vmax to the quant value -32
    const float id = d ? 1.0f/d : 0.0f;        // guard all-zero blocks (d == 0)

    dsti->d = d;
    memset(dsti->qh, 0, QK6_0/4);              // qh is OR-accumulated below, so clear it first

    for (int j = 0; j < QK6_0/2; ++j) {
        const float x0 = xi[0       + j]*id;
        // Fixed: was xi[QK4_0/2 + j] — numerically identical since QK4_0 == QK6_0,
        // but the q6_0 constant is the correct one and keeps this robust if the
        // block sizes ever diverge.
        const float x1 = xi[QK6_0/2 + j]*id;

        // +32.5f: add the bias and round to nearest; clamp to the 6-bit max.
        const uint8_t xi0 = min(63, (int8_t)(x0 + 32.5f));
        const uint8_t xi1 = min(63, (int8_t)(x1 + 32.5f));

        dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
        // Two high bits of each quant, packed pairwise into a nibble, then
        // placed in the low or high nibble of qh depending on the half of j.
        const uint8_t h = (xi0 >> 4) | ((xi1 >> 4) << 2);
        dsti->qh[j%(QK6_0/4)] |= (h << 4*(j/(QK6_0/4)));
    }
}
|
||||||
|
|
||||||
static __device__ const int8_t iq4nl_index[241] = {
|
static __device__ const int8_t iq4nl_index[241] = {
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
1, 17, 17, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 18, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
|
1, 17, 17, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 18, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
|
||||||
@@ -397,6 +432,17 @@ static void ggml_cpy_f32_q5_1_cuda(
|
|||||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Host launcher: quantize-copy ne f32 elements from cx into q6_0 blocks at cdst
// on the given stream. Launches one CUDA block (of one thread) per block_q6_0;
// ne must be an exact multiple of QK6_0.
static void ggml_cpy_f32_q6_0_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

    GGML_ASSERT(ne % QK6_0 == 0);

    // One q6_0 block covers QK6_0 source floats.
    const int n_blocks = ne / QK6_0;

    cpy_f32_q<cpy_blck_f32_q6_0, QK6_0>
        <<<n_blocks, 1, 0, stream>>>(cx, cdst, ne,
                                     ne00, ne01, ne02, nb00, nb01, nb02, nb03,
                                     ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
|
||||||
|
|
||||||
static void ggml_cpy_f32_iq4_nl_cuda(
|
static void ggml_cpy_f32_iq4_nl_cuda(
|
||||||
const char * cx, char * cdst, const int ne,
|
const char * cx, char * cdst, const int ne,
|
||||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||||
@@ -466,6 +512,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
|||||||
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
||||||
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q6_0) {
|
||||||
|
ggml_cpy_f32_q6_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
||||||
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
||||||
@@ -505,6 +553,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
|||||||
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
|
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
||||||
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
|
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
|
||||||
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q6_0) {
|
||||||
|
return (void*) cpy_f32_q<cpy_blck_f32_q6_0, QK6_0>;
|
||||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||||
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
|
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
|
||||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||||
|
|||||||
Reference in New Issue
Block a user