mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-26 01:19:20 +00:00
FlashMLA(CUDA) - allow q8_0 for KV cache
This is better. ~9% slower than f16 cache for short contexts, nearly on par at 16k tokens.
This commit is contained in:
@@ -66,24 +66,64 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
|
||||
cpy_1(cx + x_offset, cdst + dst_offset);
|
||||
}
|
||||
|
||||
static __global__ void cpy_q8_0_f32(const char * cx, float * dst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb01, const int nb02, const int nb03) {
|
||||
//static __global__ void cpy_q8_0_f32(const char * cx, float * dst, const int ne,
|
||||
// const int ne00, const int ne01, const int ne02, const int nb01, const int nb02, const int nb03) {
|
||||
// const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
//
|
||||
// if (i >= ne) {
|
||||
// return;
|
||||
// }
|
||||
//
|
||||
// const int64_t i03 = i/(ne00 * ne01 * ne02);
|
||||
// const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
|
||||
// const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00;
|
||||
// const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00;
|
||||
//
|
||||
// const block_q8_0 * q8 = (const block_q8_0 *)(cx + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
// const int ib = i00/QK8_0;
|
||||
// const int iq = i00%QK8_0;
|
||||
//
|
||||
// dst[i00*ne01 + i01 + i02*ne00*ne01 + i03*ne00*ne01*ne02] = __half2float(q8[ib].d)*q8[ib].qs[iq];
|
||||
//}
|
||||
|
||||
static __global__ void k_transpose_q8_0(const char * cx, char * cdst,
|
||||
const int ne10, const int ne11, const int ne12,
|
||||
const int nb01, const int nb02, const int nb03,
|
||||
const int nb11, const int nb12, const int nb13) {
|
||||
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= ne) {
|
||||
return;
|
||||
}
|
||||
const int64_t i13 = i/(ne10 * ne11 * ne12);
|
||||
const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
|
||||
const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
|
||||
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
|
||||
|
||||
const int64_t i03 = i/(ne00 * ne01 * ne02);
|
||||
const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
|
||||
const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00;
|
||||
const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00;
|
||||
//const int64_t ne00 = ne11;
|
||||
//const int64_t ne01 = ne10;
|
||||
//const int64_t ne02 = ne12;
|
||||
const int64_t i03 = i13;
|
||||
const int64_t i02 = i12;
|
||||
const int64_t i01 = i10; //(i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00;
|
||||
const int64_t i00 = i11; //i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00;
|
||||
|
||||
const block_q8_0 * q8 = (const block_q8_0 *)(cx + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
const int ib = i00/QK8_0;
|
||||
const int iq = i00%QK8_0;
|
||||
const int ib0 = i00/QK8_0;
|
||||
const int iq0 = i00%QK8_0;
|
||||
|
||||
dst[i00*ne01 + i01 + i02*ne00*ne01 + i03*ne00*ne01*ne02] = __half2float(q8[ib].d)*q8[ib].qs[iq];
|
||||
float xi = __half2float(q8[ib0].d)*q8[ib0].qs[iq0];
|
||||
float amax = fabsf(xi);
|
||||
amax = warp_reduce_max(amax);
|
||||
|
||||
//printf("%d, %d, %d: i = %ld, i11 = %ld i10 = %ld, xi = %g, amax = %g\n", blockDim.x, blockIdx.x, threadIdx.x, i, i11, i10, xi, amax);
|
||||
|
||||
float d = amax/127;
|
||||
int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
|
||||
|
||||
block_q8_0 * dst = (block_q8_0 *)(cdst + i11*nb11 + i12*nb12 + i13*nb13);
|
||||
dst[i10 / QK8_0].qs[i10 % QK8_0] = q;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
dst[i10 / QK8_0].d = __float2half(d);
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
|
||||
@@ -487,22 +527,28 @@ static void ggml_cpy_f16_f16_cuda(
|
||||
|
||||
static void transpose_q8_0(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) {
|
||||
auto stream = ctx.stream();
|
||||
auto ne = ggml_nelements(dst);
|
||||
ggml_cuda_pool_alloc<float> dst_f32(ctx.pool(), ne);
|
||||
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
auto aux_src = *dst;
|
||||
aux_src.nb[0] = sizeof(float);
|
||||
aux_src.nb[1] = aux_src.nb[0]*aux_src.ne[0];
|
||||
aux_src.nb[2] = aux_src.nb[1]*aux_src.ne[1];
|
||||
aux_src.nb[3] = aux_src.nb[2]*aux_src.ne[2];
|
||||
cpy_q8_0_f32<<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
((const char *)src->data, dst_f32.get(), ne,
|
||||
src->ne[1], src->ne[0], src->ne[2], src->nb[0], src->nb[2], src->nb[3]);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
aux_src.type = GGML_TYPE_F32;
|
||||
ggml_cpy_f32_q8_0_cuda((const char *)dst_f32.get(), (char *)dst->data, ne, dst->ne[0], dst->ne[1], dst->ne[2],
|
||||
aux_src.nb[0], aux_src.nb[1], aux_src.nb[2], aux_src.nb[3],
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], stream);
|
||||
auto num_blocks = ggml_nelements(dst)/QK8_0;
|
||||
k_transpose_q8_0<<<num_blocks, QK8_0, 0, stream>>>(
|
||||
(const char *)src->data, (char *)dst->data,
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], src->nb[0], src->nb[2], src->nb[3],
|
||||
dst->nb[1], dst->nb[2], dst->nb[3]);
|
||||
|
||||
//auto ne = ggml_nelements(dst);
|
||||
//ggml_cuda_pool_alloc<float> dst_f32(ctx.pool(), ne);
|
||||
//const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
//auto aux_src = *dst;
|
||||
//aux_src.nb[0] = sizeof(float);
|
||||
//aux_src.nb[1] = aux_src.nb[0]*aux_src.ne[0];
|
||||
//aux_src.nb[2] = aux_src.nb[1]*aux_src.ne[1];
|
||||
//aux_src.nb[3] = aux_src.nb[2]*aux_src.ne[2];
|
||||
//cpy_q8_0_f32<<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
// ((const char *)src->data, dst_f32.get(), ne,
|
||||
// src->ne[1], src->ne[0], src->ne[2], src->nb[0], src->nb[2], src->nb[3]);
|
||||
//CUDA_CHECK(cudaGetLastError());
|
||||
//aux_src.type = GGML_TYPE_F32;
|
||||
//ggml_cpy_f32_q8_0_cuda((const char *)dst_f32.get(), (char *)dst->data, ne, dst->ne[0], dst->ne[1], dst->ne[2],
|
||||
// aux_src.nb[0], aux_src.nb[1], aux_src.nb[2], aux_src.nb[3],
|
||||
// dst->ne[0], dst->ne[1], dst->ne[2], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
|
||||
Reference in New Issue
Block a user