diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index b0bd3778..c9acf1fc 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -265,6 +265,7 @@ if (GGML_IQK_MUL_MAT)
         iqk/iqk_flash_attn.cpp
         iqk/fa/iqk_fa_576_512.cpp
         iqk/fa/iqk_fa_192_128.cpp
+        iqk/fa/iqk_fa_192_192.cpp
         iqk/fa/iqk_fa_256_256.cpp
         iqk/fa/iqk_fa_128_128.cpp
         iqk/fa/iqk_fa_96_96.cpp
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cu b/ggml/src/ggml-cuda/fattn-mma-f16.cu
index 01e63541..539d7728 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cu
@@ -43,6 +43,9 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context
         case 128:
             ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
             break;
+        case 192:
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, ncols2>(ctx, dst);
+            break;
         case 256:
             ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
             break;
@@ -88,5 +91,5 @@ bool ggml_cuda_fattn_mma_f16_is_supported([[maybe_unused]] ggml_backend_cuda_con
     auto K = dst->src[1];
-    auto V = dst->src[1];
+    auto V = dst->src[2]; // V is src[2] (operand order Q, K, V); reading src[1] again compared K with itself, so the check below never rejected anything
     if (K->ne[0] != V->ne[0]) return false;
-    return K->ne[0] == 64 || K->ne[0] == 80 || K->ne[0] == 96 || K->ne[0] == 112 || K->ne[0] == 128 || K->ne[0] == 256;
+    return K->ne[0] == 64 || K->ne[0] == 80 || K->ne[0] == 96 || K->ne[0] == 112 || K->ne[0] == 128 || K->ne[0] == 192 || K->ne[0] == 256;
 }
diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu
index b2285fdd..ef557209 100644
--- a/ggml/src/ggml-cuda/fattn-new-mma.cu
+++ b/ggml/src/ggml-cuda/fattn-new-mma.cu
@@ -315,6 +315,38 @@ struct fattn_mma_f16_config<192, 128> {
     }
 };
 
+template <>
+struct fattn_mma_f16_config<192, 192> { // specialization for the (192, 192) head-size pair; every getter below ignores its arguments and returns a fixed value
+    static constexpr int nbatch_fa = 64;
+    static constexpr int nwarps_max = 4;
+    static constexpr bool Q_in_reg = true;
+    static constexpr int nstages_target = 1;
+
+    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+        return 64; // keep equal to get_nbatch_K2_device()
+    }
+
+    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+        return 64;
+    }
+
+    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+        return 64; // keep equal to get_nbatch_V2_device()
+    }
+
+    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+        return 64;
+    }
+
+    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+        return 32; // keep equal to get_nbatch_combine_device()
+    }
+
+    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+        return 32;
+    }
+};
+
 template <>
 struct fattn_mma_f16_config<576, 512> {
     static constexpr int nbatch_fa = 32;
@@ -2119,6 +2151,12 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
         ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 1>(ctx, dst);
         return;
     }
+    if (K->ne[0] == 192 && V->ne[0] == 192) { // K and V heads are both 192 wide
+        GGML_ASSERT(Q->ne[0] == 192);
+        GGML_ASSERT(gqa_ratio == 1); // NOTE(review): appears tied to the third template argument (1) of the call below - confirm
+        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 192, 1>(ctx, dst);
+        return;
+    }
     GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
     GGML_ASSERT(gqa_ratio % 16 == 0);
     ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
diff --git a/ggml/src/iqk/fa/iqk_fa_192_192.cpp b/ggml/src/iqk/fa/iqk_fa_192_192.cpp
new file mode 100644
index 00000000..21fe033c
--- /dev/null
+++ b/ggml/src/iqk/fa/iqk_fa_192_192.cpp
@@ -0,0 +1,45 @@
+#include "iqk/iqk_config.h"
+
+#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
+
+#include "iqk/fa/iqk_fa_templates.h"
+
+IQK_FA_CASE(iqk_fa_192_192) { // called by iqk_flash_attn_impl when Dk == 192 && Dv == 192
+
+    auto type_k = ggml_type(int_type_k);
+    auto type_v = ggml_type(int_type_v);
+
+    stride_q /= sizeof(float); // q stride as float
+    auto ck = (const char *)k;
+    auto cv = (const char *)v;
+    auto cm = (const char *)mask;
+
+#ifdef __AVX512BF16__
+    if (type_k == GGML_TYPE_BF16) { // bf16 K: uses the helper overload without type arguments and always returns true after it runs
+        if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
+        if (nk%64 == 0) { // third template argument is 64 when nk is a multiple of 64, otherwise 32
+            iqk_flash_helper_T<192, 192, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+                    q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
+            return true;
+        }
+        iqk_flash_helper_T<192, 192, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+                q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
+        return true;
+    }
+#endif
+
+    if (nk%128 == 0) { // third template argument: 128 if nk is a multiple of 128, else 64 if a multiple of 64, else 32
+        return iqk_flash_helper_T<192, 192, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+                q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
+    }
+    if (nk%64 == 0) {
+        return iqk_flash_helper_T<192, 192, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+                q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
+    }
+
+    return iqk_flash_helper_T<192, 192, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+            q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
+
+}
+
+#endif
diff --git a/ggml/src/iqk/fa/iqk_fa_templates.h b/ggml/src/iqk/fa/iqk_fa_templates.h
index 3a0b7248..8e96844e 100644
--- a/ggml/src/iqk/fa/iqk_fa_templates.h
+++ b/ggml/src/iqk/fa/iqk_fa_templates.h
@@ -2235,6 +2235,7 @@ inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
 
 IQK_FA_CASE(iqk_fa_576_512);
 IQK_FA_CASE(iqk_fa_192_128);
+IQK_FA_CASE(iqk_fa_192_192);
 IQK_FA_CASE(iqk_fa_256_256);
 IQK_FA_CASE(iqk_fa_128_128);
 IQK_FA_CASE(iqk_fa_96_96);
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 7876d199..35573e49 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -1349,6 +1349,11 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
                 q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
     }
 
+    if (Dk == 192 && Dv == 192) { // K and V heads are both 192 wide
+        return iqk_fa_192_192(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+                q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
+    }
+
     if (Dk == 256 && Dv == 256) {
         return iqk_fa_256_256(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                 q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 3c717873..83280c3b 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -142,6 +142,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
     { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" },
     { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" },
+    { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
+    { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
 
     { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
     { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
diff --git a/src/llama-arch.h b/src/llama-arch.h
index b9f06f9d..a872e3ce 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -135,6 +135,8 @@ enum llm_kv {
     LLM_KV_ATTENTION_SCALE,
     LLM_KV_ATTENTION_OUTPUT_SCALE,
     LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
+    LLM_KV_ATTENTION_KEY_LENGTH_MLA,
+    LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_DIMENSION_SECTIONS,
diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index 9dc72999..9588d1b9 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -5931,7 +5931,7 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
     // mutable variable, needed during the last layer of the computation to skip unused tokens
     int32_t n_tokens = this->n_tokens;
 
-    bool is_lite = (hparams.n_layer == 27);
+    bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); // same layer-count test as llm_load_hparams() and create_deepseek2_tensors(); keep the three in sync
 
     // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
     // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index e14167c1..fb6c2708 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -762,8 +762,10 @@ void llm_load_hparams(
                     for (auto& item : hparams.n_head_kv_arr) item = n_nead_kv;
                     hparams.n_embd_head_k = 192;
                     hparams.n_embd_head_v = 128;
+                    ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k); // NOTE(review): presumably replaces the 192 assigned above with the model's "%s.attention.key_length_mla" value - confirm what get_key does when the key is absent
+                    ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v); // NOTE(review): likewise for the 128 default and "%s.attention.value_length_mla"
                 }
-                bool is_lite = (hparams.n_layer == 27);
+                bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); // "lite" is decided from the layer count alone (26 or 27); build_deepseek2() and create_deepseek2_tensors() repeat this test
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
                 if (!is_lite) {
diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp
index 04db3d90..5f671fe7 100644
--- a/src/llama-load-tensors.cpp
+++ b/src/llama-load-tensors.cpp
@@ -1617,7 +1617,7 @@ bool create_tensors_helper::create_arctix_tensors(const LLM_TN & tn) {
 bool create_tensors_helper::create_deepseek2_tensors(const LLM_TN & tn) {
     LOADING_PRELUDE
 
-    const bool is_lite = (hparams.n_layer == 27);
+    const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); // same layer-count test as llm_load_hparams() and build_deepseek2(); keep the three in sync
 
     const int64_t n_embd_head_qk_rope = hparams.n_rot;
     const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;