From 1322c3f3e58b0668e7a377adb1364ce8427e313b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 21 Oct 2024 11:06:24 +0300 Subject: [PATCH] Add IQ4_NL + IQ4_NL to FA This is a better alternative than Q4_0 + Q4_0 for the VRAM poor. --- Makefile | 1 + ggml/src/CMakeLists.txt | 2 + ggml/src/ggml-cuda/fattn-common.cuh | 82 +++++++++++++++---- ggml/src/ggml-cuda/fattn.cu | 6 +- ...tn-vec-f16-instance-hs128-iq4_nl-iq4_nl.cu | 5 ++ ...attn-vec-f16-instance-hs128-iq4_nl-q4_0.cu | 5 ++ ...tn-vec-f32-instance-hs128-iq4_nl-iq4_nl.cu | 5 ++ 7 files changed, 86 insertions(+), 20 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-iq4_nl-iq4_nl.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-iq4_nl-q4_0.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-iq4_nl-iq4_nl.cu diff --git a/Makefile b/Makefile index 6aaf1d5b..ae636f50 100644 --- a/Makefile +++ b/Makefile @@ -599,6 +599,7 @@ else OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu)) OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*f16-f16.cu)) OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-iq4_nl.cu)) + OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*:iq4_nl-iq4_nl.cu)) endif # GGML_CUDA_FA_ALL_QUANTS ifdef GGML_CUDA diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 5bcceae3..eb6d457c 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -330,6 +330,8 @@ if (GGML_CUDA) list(APPEND GGML_SOURCES_CUDA ${SRCS}) file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q8_0-iq4_nl.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*iq4_nl-iq4_nl.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) endif() list(APPEND GGML_CDEF_PUBLIC GGML_USE_CUDA) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 51151ca2..1984c838 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -136,6 +136,49 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( return sum; } +static __device__ __forceinline__ int get_one_int_from_table_16(const int & q4) { + const uint8_t * q0_8 = (const uint8_t *) &q4; + const char4 val0_8 = make_char4(kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]); + return *((const int *) &val0_8); +} + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_iq4_nl( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_iq4_nl * K_iq4_nl = (const block_iq4_nl *) K_c; + GGML_UNUSED(Q_v); + + T sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI4_NL; + const int shift = k_KQ & (QI8_1/2); + + const int v = get_one_int_from_table_16((get_int_b2(K_iq4_nl[ib].qs, iqs4) >> shift) & 0x0F0F0F0F); + const int u = Q_q8[k_KQ_0/WARP_SIZE]; + + const int sumi = ggml_cuda_dp4a(v, u, 0); + +#ifdef FP16_AVAILABLE + if (std::is_same::value) { + const half2 * Q_ds = (const half2 *) Q_ds_v; + sum += (T) (((half)sumi) * K_iq4_nl[ib].d * Q_ds[k_KQ_0/WARP_SIZE].x); + } else +#endif // FP16_AVAILABLE + { + const float2 * Q_ds = (const float2 *) Q_ds_v; + sum += (T) ((float)sumi * __half2float(K_iq4_nl[ib].d) * Q_ds[k_KQ_0/WARP_SIZE].x); + } + } + + return sum; +} + template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -495,24 +538,26 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v template constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) { - return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : - type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : - type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 : - type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : - type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : - type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : - nullptr; + return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : + type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : + type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl : + type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 : + type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : + type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : + type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : + nullptr; } template constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) { - return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : - type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : - type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 : - type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : - type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : - type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : - nullptr; + return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : + type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : + type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl : + type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 : + type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : + type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : + type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : + nullptr; } constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) { @@ -590,10 +635,11 @@ static void on_no_fattn_vec_case(const int D) { } else if (D == 128) { fprintf(stderr, "Unsupported KV type combination for head_size 128.\n"); fprintf(stderr, "Supported combinations:\n"); - fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n"); - fprintf(stderr, " - K == q8_0, V == iq4_nl, 6.50 BPV\n"); - fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n"); - fprintf(stderr, " - K == f16, V == f16, 16.00 BPV\n"); + fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n"); + fprintf(stderr, " - K == iq4_nl, V == iq4_nl, 4.50 BPV\n"); + fprintf(stderr, " - K == q8_0, V == iq4_nl, 6.50 BPV\n"); + fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n"); + fprintf(stderr, " - K == f16, V == f16, 16.00 BPV\n"); fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, iq4_nl, q5_0, q5_1, q8_0, and f16.\n"); GGML_ABORT("fatal error"); } else { diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 595ba3df..ae491aea 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -212,7 +212,8 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg //FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL) //FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL) //FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL) + FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL) //FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL) #else FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) @@ -227,7 +228,8 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg //FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL) //FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL) //FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL) - FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL) + FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL) //FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL) #endif // GGML_CUDA_FA_ALL_QUANTS diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-iq4_nl-iq4_nl.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-iq4_nl-iq4_nl.cu new file mode 100644 index 00000000..672a39d0 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-iq4_nl-iq4_nl.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-iq4_nl-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-iq4_nl-q4_0.cu new file mode 100644 index 00000000..9836aa10 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-iq4_nl-q4_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-iq4_nl-iq4_nl.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-iq4_nl-iq4_nl.cu new file mode 100644 index 00000000..286c9e20 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-iq4_nl-iq4_nl.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL);