diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index 05678108..c1b83cfc 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -310,6 +310,32 @@ static void ggml_cpy_q4_1_f16_cuda(
         ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
+static void ggml_cpy_iq4_nl_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q_f32<dequantize_iq4_nl, QK4_NL>, QK4_NL><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+}
+
+static void ggml_cpy_iq4_nl_f16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q_f16<dequantize_iq4_nl, QK4_NL>, QK4_NL><<<num_blocks, 1, 0, stream>>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+}
+
 static void ggml_cpy_f32_q5_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -584,6 +610,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         ggml_cpy_q5_0_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
         ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_IQ4_NL && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_iq4_nl_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_IQ4_NL && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_iq4_nl_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
         ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
@@ -669,14 +700,18 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>;
     } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F16) {
         return (void*) cpy_q_f32<cpy_blck_q_f16<dequantize_q4_1, QK4_1>, QK4_1>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
+        return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
+    } else if (src0->type == GGML_TYPE_IQ4_NL && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_iq4_nl, QK4_NL>, QK4_NL>;
+    } else if (src0->type == GGML_TYPE_IQ4_NL && src1->type == GGML_TYPE_F16) {
+        return (void*) cpy_q_f32<cpy_blck_q_f16<dequantize_iq4_nl, QK4_NL>, QK4_NL>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
     } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>;
     } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F16) {
         return (void*) cpy_q_f32<cpy_blck_q_f16<dequantize_q5_0, QK5_0>, QK5_0>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
-        return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
         return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
diff --git a/ggml/src/ggml-cuda/dequantize.cuh b/ggml/src/ggml-cuda/dequantize.cuh
index 52649f27..6f6a7b10 100644
--- a/ggml/src/ggml-cuda/dequantize.cuh
+++ b/ggml/src/ggml-cuda/dequantize.cuh
@@ -19,6 +19,24 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in
 #endif // GGML_CUDA_F16
 }
 
+static __device__ __forceinline__ void dequantize_iq4_nl(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+    const block_iq4_nl * x = (const block_iq4_nl *) vx;
+
+    const dfloat d = x[ib].d;
+
+    const int vui = x[ib].qs[iqs];
+
+    v.x = kvalues_iq4nl[vui & 0xF];
+    v.y = kvalues_iq4nl[vui >> 4];
+
+#ifdef GGML_CUDA_F16
+    v = __hmul2(v, {d, d});
+#else
+    v.x = v.x * d;
+    v.y = v.y * d;
+#endif // GGML_CUDA_F16
+}
+
 static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const block_q4_1 * x = (const block_q4_1 *) vx;
 
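Note on the decode: unlike dequantize_q4_0, which maps each nibble linearly (q - 8),
dequantize_iq4_nl looks each nibble up in the non-uniform kvalues_iq4nl codebook before
scaling by the block scale d. The standalone host-side sketch below reproduces that
mapping for one block; it is illustrative only and not part of the patch. The
kvalues_iq4nl table and QK4_NL follow their definitions in ggml-common.h, while
host_block_iq4_nl and dequant_block are hypothetical stand-ins (the real block_iq4_nl
stores d as ggml_half, not float).

#include <cstdint>
#include <cstdio>

// Non-uniform 4-bit codebook used by IQ4_NL (values from ggml-common.h).
static const int8_t kvalues_iq4nl[16] = {
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
};

constexpr int QK4_NL = 32; // quantized values per block

// Host-side stand-in for block_iq4_nl: a per-block scale followed by
// QK4_NL/2 bytes, each packing two 4-bit codebook indices.
struct host_block_iq4_nl {
    float   d;
    uint8_t qs[QK4_NL / 2];
};

// Mirrors dequantize_iq4_nl: value j is decoded from the low nibble of
// qs[j], value j + QK4_NL/2 from the high nibble, each mapped through the
// codebook and scaled by d.
static void dequant_block(const host_block_iq4_nl & b, float * out) {
    for (int j = 0; j < QK4_NL / 2; ++j) {
        const int vui = b.qs[j];
        out[j]              = kvalues_iq4nl[vui & 0xF] * b.d;
        out[j + QK4_NL / 2] = kvalues_iq4nl[vui >> 4]  * b.d;
    }
}

int main() {
    host_block_iq4_nl b = { 0.5f, {} };
    // Fill the block so the low nibbles count up and the high nibbles count down.
    for (int j = 0; j < QK4_NL / 2; ++j) {
        b.qs[j] = (uint8_t) (j | ((15 - j) << 4));
    }
    float out[QK4_NL];
    dequant_block(b, out);
    printf("%g %g\n", out[0], out[16]); // -63.5 (kvalues_iq4nl[0] * 0.5), 56.5 (kvalues_iq4nl[15] * 0.5)
    return 0;
}

The split addressing (j and j + QK4_NL/2) matches how cpy_blck_q_f32 writes the two
halves of each dfloat2, so a full block decodes to its QK4_NL floats in order.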