// Patch (unified diff) against ggml/src/ggml-cuda/cpy.cu. Segments introduced by `+` are added lines.
//
// NOTE(review): this copy of the patch looks damaged by text extraction and should not be applied or
// compiled as-is:
//   - line breaks are collapsed, so hunk bodies run together on a few very long lines;
//   - text inside angle brackets appears to be missing: `template` has no parameter list,
//     `cpy_q_f32, QK4_0><<>>(` has no inner template argument and an empty launch configuration,
//     and `cpy_f32_q;` / `cpy_flt>;` have no template arguments.
//   Presumably each new launch instantiates cpy_q_f32 with cpy_blck_q_f16 (it is defined here and not
//   otherwise referenced in this text) and is configured from `num_blocks` and `stream` -- verify
//   against the repository rather than repairing by hand.
//   - the ggml_cuda_cpy hunk header states 26 old / 36 new lines; assuming one line per `else if` and
//     one per call, fewer context lines than that are present here -- TODO confirm against the
//     original patch.
//
// What the added lines do, as far as this text shows:
//   * cpy_blck_q_f16(cxi, cdsti): device helper. Treats cdsti as half*. For each j in [0, qk/2) it
//     calls dequant(cxi, 0, j, dq) and stores __float2half(dq.x) at index j and __float2half(dq.y)
//     at index j + qk/2, i.e. it writes qk half values per call.
//   * ggml_cpy_{q4_0,q4_1,q5_0,q5_1,q6_0}_f16_cuda: host wrappers taking the same parameter list as
//     the neighbouring *_f32_cuda wrappers. Each sets num_blocks = ne and launches cpy_q_f32 with
//     the QK constant of its source type, forwarding all shape/stride arguments and cdst_indirect.
//     graph_cpynode_index is taken by reference and post-incremented once per call.
//   * ggml_cuda_cpy: five new `else if` branches send src0 types Q4_0, Q4_1, Q5_0, Q5_1 and Q6_0 with
//     src1 type F16 to the wrappers above.
//   * ggml_cuda_cpy_fn: five matching branches return a cpy_q_f32 kernel pointer for the same
//     type pairs, keeping it in step with the dispatch in ggml_cuda_cpy.
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index d3194b61..05678108 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -94,6 +94,19 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) { } } +template +static __device__ void cpy_blck_q_f16(const char * cxi, char * cdsti) { + half * dsth = (half *)(cdsti); + +#pragma unroll + for (int j = 0; j < qk/2; j++) { + dfloat2 dq; + dequant(cxi, 0, j, dq); + *(dsth + j + 0) = __float2half(dq.x); + *(dsth + j + qk/2) = __float2half(dq.y); + } +} + template static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, @@ -247,6 +260,19 @@ static void ggml_cpy_q4_0_f32_cuda( ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } +static void ggml_cpy_q4_0_f16_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, + const int nb10, const int nb11, const int nb12, const int nb13, + cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int num_blocks = ne; + cpy_q_f32, QK4_0><<>>( + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); +} + static void ggml_cpy_f32_q4_1_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, @@ -271,6 +297,19 @@ static void ggml_cpy_q4_1_f32_cuda( ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } +static void ggml_cpy_q4_1_f16_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, + const int nb03, 
const int ne10, const int ne11, const int ne12, + const int nb10, const int nb11, const int nb12, const int nb13, + cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int num_blocks = ne; + cpy_q_f32, QK4_1><<>>( + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); +} + static void ggml_cpy_f32_q5_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, @@ -295,6 +334,19 @@ static void ggml_cpy_q5_0_f32_cuda( ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } +static void ggml_cpy_q5_0_f16_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, + const int nb10, const int nb11, const int nb12, const int nb13, + cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int num_blocks = ne; + cpy_q_f32, QK5_0><<>>( + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); +} + static void ggml_cpy_f32_q5_1_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, @@ -319,6 +371,19 @@ static void ggml_cpy_q5_1_f32_cuda( ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } +static void ggml_cpy_q5_1_f16_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, + const int nb10, const int nb11, const int nb12, const int nb13, + cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) 
{ + const int num_blocks = ne; + cpy_q_f32, QK5_1><<>>( + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); +} + static void ggml_cpy_f32_iq4_nl_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, @@ -354,6 +419,19 @@ static void ggml_cpy_q6_0_f32_cuda( ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } +static void ggml_cpy_q6_0_f16_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, + const int nb10, const int nb11, const int nb12, const int nb13, + cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { + const int num_blocks = ne; + cpy_q_f32, QK6_0><<>>( + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); +} + static __global__ void k_transpose_q8_0(const char * cx, char * cdst, const int ne10, const int ne11, const int ne12, const int nb01, const int nb02, const int nb03, @@ -488,26 +566,36 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) { ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F16) { + ggml_cpy_q4_0_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { ggml_cpy_f32_q4_1_cuda(src0_ddc, 
src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) { ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F16) { + ggml_cpy_q4_1_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F16) { + ggml_cpy_q5_0_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == 
GGML_TYPE_F32) { ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F16) { + ggml_cpy_q5_1_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q6_0) { ggml_cpy_f32_q6_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_Q6_0 && src1->type == GGML_TYPE_F32) { ggml_cpy_q6_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_Q6_0 && src1->type == GGML_TYPE_F16) { + ggml_cpy_q6_0_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { @@ -573,24 +661,34 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { return (void*) cpy_f32_q; } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) { return (void*) cpy_q_f32, QK4_0>; + } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F16) { + return (void*) cpy_q_f32, QK4_0>; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { return (void*) cpy_f32_q; } else if (src0->type 
== GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) { return (void*) cpy_q_f32, QK4_1>; + } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F16) { + return (void*) cpy_q_f32, QK4_1>; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { return (void*) cpy_f32_q; } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { return (void*) cpy_q_f32, QK5_0>; + } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F16) { + return (void*) cpy_q_f32, QK5_0>; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { return (void*) cpy_f32_q; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { return (void*) cpy_f32_q; } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { return (void*) cpy_q_f32, QK5_1>; + } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F16) { + return (void*) cpy_q_f32, QK5_1>; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q6_0) { return (void*) cpy_f32_q; } else if (src0->type == GGML_TYPE_Q6_0 && src1->type == GGML_TYPE_F32) { return (void*) cpy_q_f32, QK6_0>; + } else if (src0->type == GGML_TYPE_Q6_0 && src1->type == GGML_TYPE_F16) { + return (void*) cpy_q_f32, QK6_0>; } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { return (void*) cpy_flt>; } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {