mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-14 07:48:16 +00:00
FlashMLA(CUDA) - allow q8_0 for KV cache
This works, and PP is not bad, but TG is still quite a bit slower.
This commit is contained in:
@@ -2296,9 +2296,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
for (int64_t id = 0; id < n_ids; id++) {
|
||||
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
|
||||
|
||||
if (i02 < 0 || i02 >= n_as) continue;
|
||||
//GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
|
||||
|
||||
if (row_id_i != i02) {
|
||||
continue;
|
||||
}
|
||||
@@ -3458,6 +3455,14 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
|
||||
return true;
|
||||
}
|
||||
if (ggml_is_contiguous(op->src[0]) && ggml_are_same_shape(op->src[0], op->src[1])) {
|
||||
if (src1_type == GGML_TYPE_F16 || src1_type == GGML_TYPE_BF16 || src1_type == GGML_TYPE_F32) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if (ggml_are_same_shape(op->src[0], op->src[1]) && op->src[0]->type == GGML_TYPE_Q8_0 && op->src[1]->type == GGML_TYPE_Q8_0) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
} break;
|
||||
case GGML_OP_DUP:
|
||||
|
||||
@@ -66,6 +66,26 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
|
||||
cpy_1(cx + x_offset, cdst + dst_offset);
|
||||
}
|
||||
|
||||
static __global__ void cpy_q8_0_f32(const char * cx, float * dst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb01, const int nb02, const int nb03) {
|
||||
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t i03 = i/(ne00 * ne01 * ne02);
|
||||
const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
|
||||
const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00;
|
||||
const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00;
|
||||
|
||||
const block_q8_0 * q8 = (const block_q8_0 *)(cx + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
const int ib = i00/QK8_0;
|
||||
const int iq = i00%QK8_0;
|
||||
|
||||
dst[i00*ne01 + i01 + i02*ne00*ne01 + i03*ne00*ne01*ne02] = __half2float(q8[ib].d)*q8[ib].qs[iq];
|
||||
}
|
||||
|
||||
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
|
||||
const float * xi = (const float *) cxi;
|
||||
block_q8_0 * dsti = (block_q8_0 *) cdsti;
|
||||
@@ -465,6 +485,26 @@ static void ggml_cpy_f16_f16_cuda(
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void transpose_q8_0(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) {
|
||||
auto stream = ctx.stream();
|
||||
auto ne = ggml_nelements(dst);
|
||||
ggml_cuda_pool_alloc<float> dst_f32(ctx.pool(), ne);
|
||||
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
auto aux_src = *dst;
|
||||
aux_src.nb[0] = sizeof(float);
|
||||
aux_src.nb[1] = aux_src.nb[0]*aux_src.ne[0];
|
||||
aux_src.nb[2] = aux_src.nb[1]*aux_src.ne[1];
|
||||
aux_src.nb[3] = aux_src.nb[2]*aux_src.ne[2];
|
||||
cpy_q8_0_f32<<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
((const char *)src->data, dst_f32.get(), ne,
|
||||
src->ne[1], src->ne[0], src->ne[2], src->nb[0], src->nb[2], src->nb[3]);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
aux_src.type = GGML_TYPE_F32;
|
||||
ggml_cpy_f32_q8_0_cuda((const char *)dst_f32.get(), (char *)dst->data, ne, dst->ne[0], dst->ne[1], dst->ne[2],
|
||||
aux_src.nb[0], aux_src.nb[1], aux_src.nb[2], aux_src.nb[3],
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
const int64_t ne = ggml_nelements(src0);
|
||||
GGML_ASSERT(ne == ggml_nelements(src1));
|
||||
@@ -542,9 +582,14 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
to_bf16(src0->data, (nv_bfloat16 *)src1->data, ggml_nrows(src0), src0->ne[0], main_stream);
|
||||
}
|
||||
}
|
||||
} else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) {
|
||||
transpose_q8_0(ctx, src0, src1);
|
||||
} else {
|
||||
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
fprintf(stderr, "%s: %ld x %ld x %ld; %zu x %zu %zu -> %ld x %ld x %ld; %zu x %zu x %zu\n", __func__,
|
||||
src0->ne[0], src0->ne[1], src0->ne[2], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
src1->ne[0], src1->ne[1], src1->ne[2], src1->nb[1], src1->nb[2], src1->nb[3]);
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
@@ -593,7 +638,13 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
if (to_bf16) return (void*)to_bf16;
|
||||
}
|
||||
}
|
||||
else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) {
|
||||
return (void *)transpose_q8_0;
|
||||
}
|
||||
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
fprintf(stderr, "%s: %ld x %ld x %ld; %zu x %zu %zu -> %ld x %ld x %ld; %zu x %zu x %zu\n", __func__,
|
||||
src0->ne[0], src0->ne[1], src0->ne[2], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
src1->ne[0], src1->ne[1], src1->ne[2], src1->nb[1], src1->nb[2], src1->nb[3]);
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user