mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-23 22:54:10 +00:00
cuda: re-add q8_0 -> q8_0 transpose
so mla = 2 can be used with CUDA graphs and q8_0 cache.
This commit is contained in:
@@ -2941,29 +2941,6 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
return fuse_down;
|
||||
}
|
||||
|
||||
static void ggml_cuda_cpy_wrapper(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
|
||||
auto src0 = dst->src[0];
|
||||
auto src1 = dst->src[1];
|
||||
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
||||
CUDA_CHECK(cudaMemcpyAsync((char *)src1->data, (char *)src0->data, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, ctx.stream()));
|
||||
return;
|
||||
}
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
if (ctx.cuda_graph->use_cpy_indirection) {
|
||||
GGML_ASSERT(ctx.cuda_graph->graph_cpynode_index < (int)ctx.cuda_graph->cpy_dest_ptrs.size());
|
||||
auto dest_ptr = ctx.cuda_graph->cpy_dest_ptrs[ctx.cuda_graph->graph_cpynode_index];
|
||||
ggml_tensor aux_src1 = *src1;
|
||||
aux_src1.data = dest_ptr;
|
||||
ggml_cuda_cpy(ctx, src0, &aux_src1);
|
||||
++ctx.cuda_graph->graph_cpynode_index;
|
||||
} else {
|
||||
ggml_cuda_cpy(ctx, dst->src[0], dst->src[1]);
|
||||
}
|
||||
#else
|
||||
ggml_cuda_cpy(ctx, dst->src[0], dst->src[1]);
|
||||
#endif
|
||||
}
|
||||
|
||||
static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, struct ggml_tensor * next, bool& skip_next) {
|
||||
// why is this here instead of mul_mat?
|
||||
if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) {
|
||||
@@ -2985,7 +2962,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
ggml_cuda_dup(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CPY:
|
||||
//ggml_cuda_cpy_wrapper(ctx, dst);
|
||||
ggml_cuda_cpy(ctx, dst->src[0], dst->src[1]);
|
||||
break;
|
||||
case GGML_OP_CONT:
|
||||
@@ -3269,20 +3245,6 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
//static void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs,
|
||||
// const int host_dest_ptrs_size, cudaStream_t stream) {
|
||||
// if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
|
||||
// CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
// if (cuda_graph->dest_ptrs_d != nullptr) {
|
||||
// CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d));
|
||||
// }
|
||||
// CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *)));
|
||||
// cuda_graph->dest_ptrs_size = host_dest_ptrs_size;
|
||||
// }
|
||||
// // copy destination pointers to GPU
|
||||
// CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream));
|
||||
// cuda_graph->graph_cpynode_index = 0; // reset index
|
||||
//}
|
||||
|
||||
static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
|
||||
bool use_cuda_graph) {
|
||||
|
||||
@@ -325,6 +325,54 @@ static void ggml_cpy_q6_0_f32_cuda(
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
}
|
||||
|
||||
static __global__ void k_transpose_q8_0(const char * cx, char * cdst,
|
||||
const int ne10, const int ne11, const int ne12,
|
||||
const int nb01, const int nb02, const int nb03,
|
||||
const int nb11, const int nb12, const int nb13) {
|
||||
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
const int64_t i13 = i/(ne10 * ne11 * ne12);
|
||||
const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
|
||||
const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
|
||||
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
|
||||
|
||||
//const int64_t ne00 = ne11;
|
||||
//const int64_t ne01 = ne10;
|
||||
//const int64_t ne02 = ne12;
|
||||
const int64_t i03 = i13;
|
||||
const int64_t i02 = i12;
|
||||
const int64_t i01 = i10; //(i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00;
|
||||
const int64_t i00 = i11; //i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00;
|
||||
|
||||
const block_q8_0 * q8 = (const block_q8_0 *)(cx + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
const int ib0 = i00/QK8_0;
|
||||
const int iq0 = i00%QK8_0;
|
||||
|
||||
float xi = __half2float(q8[ib0].d)*q8[ib0].qs[iq0];
|
||||
float amax = fabsf(xi);
|
||||
amax = warp_reduce_max(amax);
|
||||
|
||||
float d = amax/127;
|
||||
int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
|
||||
|
||||
block_q8_0 * dst = (block_q8_0 *)(cdst + i11*nb11 + i12*nb12 + i13*nb13);
|
||||
dst[i10 / QK8_0].qs[i10 % QK8_0] = q;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
dst[i10 / QK8_0].d = __float2half(d);
|
||||
}
|
||||
}
|
||||
|
||||
static void transpose_q8_0(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) {
|
||||
auto stream = ctx.stream();
|
||||
auto num_blocks = ggml_nelements(dst)/QK8_0;
|
||||
k_transpose_q8_0<<<num_blocks, QK8_0, 0, stream>>>(
|
||||
(const char *)src->data, (char *)dst->data,
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], src->nb[0], src->nb[2], src->nb[3],
|
||||
dst->nb[1], dst->nb[2], dst->nb[3]);
|
||||
}
|
||||
|
||||
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
|
||||
const int64_t ne = ggml_nelements(src0);
|
||||
GGML_ASSERT(ne == ggml_nelements(src1));
|
||||
@@ -428,6 +476,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) {
|
||||
// This is needed for MLA with mla=2 when using q8_0 cache.
|
||||
transpose_q8_0(ctx, src0, src1);
|
||||
} else {
|
||||
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
@@ -497,6 +548,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
|
||||
} else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) {
|
||||
return (void *)transpose_q8_0;
|
||||
} else {
|
||||
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
|
||||
Reference in New Issue
Block a user