Support GigaChat3 (#995)

* Fixing Gigachat support

* Gigachat: CUDA FA (needs 192 x 192 for MLA = 3)

* Gigachat: CPU FA (needs 192 x 192 for MLA = 3)

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-11-24 06:55:14 +01:00
committed by GitHub
parent 1feccd4174
commit f1191036b2
11 changed files with 103 additions and 4 deletions

View File

@@ -265,6 +265,7 @@ if (GGML_IQK_MUL_MAT)
iqk/iqk_flash_attn.cpp
iqk/fa/iqk_fa_576_512.cpp
iqk/fa/iqk_fa_192_128.cpp
iqk/fa/iqk_fa_192_192.cpp
iqk/fa/iqk_fa_256_256.cpp
iqk/fa/iqk_fa_128_128.cpp
iqk/fa/iqk_fa_96_96.cpp

View File

@@ -43,6 +43,9 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context
case 128:
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
break;
case 192:
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, ncols2>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
break;
@@ -88,5 +91,5 @@ bool ggml_cuda_fattn_mma_f16_is_supported([[maybe_unused]] ggml_backend_cuda_con
auto K = dst->src[1];
auto V = dst->src[1];
if (K->ne[0] != V->ne[0]) return false;
return K->ne[0] == 64 || K->ne[0] == 80 || K->ne[0] == 96 || K->ne[0] == 112 || K->ne[0] == 128 || K->ne[0] == 256;
return K->ne[0] == 64 || K->ne[0] == 80 || K->ne[0] == 96 || K->ne[0] == 112 || K->ne[0] == 128 || K->ne[0] == 192 || K->ne[0] == 256;
}

View File

@@ -315,6 +315,38 @@ struct fattn_mma_f16_config<192, 128> {
}
};
template <>
struct fattn_mma_f16_config<192, 192> {
static constexpr int nbatch_fa = 64;
static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 1;
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
return 64;
}
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
return 64;
}
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
return 64;
}
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
return 64;
}
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
return 32;
}
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
return 32;
}
};
template <>
struct fattn_mma_f16_config<576, 512> {
static constexpr int nbatch_fa = 32;
@@ -2119,6 +2151,12 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 1>(ctx, dst);
return;
}
if (K->ne[0] == 192 && V->ne[0] == 192) {
GGML_ASSERT(Q->ne[0] == 192);
GGML_ASSERT(gqa_ratio == 1);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 192, 1>(ctx, dst);
return;
}
GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
GGML_ASSERT(gqa_ratio % 16 == 0);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);

View File

@@ -0,0 +1,45 @@
#include "iqk/iqk_config.h"
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
#include "iqk/fa/iqk_fa_templates.h"
IQK_FA_CASE(iqk_fa_192_192) {
auto type_k = ggml_type(int_type_k);
auto type_v = ggml_type(int_type_v);
stride_q /= sizeof(float); // q stride as float
auto ck = (const char *)k;
auto cv = (const char *)v;
auto cm = (const char *)mask;
#ifdef __AVX512BF16__
if (type_k == GGML_TYPE_BF16) {
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
if (nk%64 == 0) {
iqk_flash_helper_T<192, 192, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
return true;
}
iqk_flash_helper_T<192, 192, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
return true;
}
#endif
if (nk%128 == 0) {
return iqk_flash_helper_T<192, 192, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}
if (nk%64 == 0) {
return iqk_flash_helper_T<192, 192, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}
return iqk_flash_helper_T<192, 192, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}
#endif

View File

@@ -2235,6 +2235,7 @@ inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
IQK_FA_CASE(iqk_fa_576_512);
IQK_FA_CASE(iqk_fa_192_128);
IQK_FA_CASE(iqk_fa_192_192);
IQK_FA_CASE(iqk_fa_256_256);
IQK_FA_CASE(iqk_fa_128_128);
IQK_FA_CASE(iqk_fa_96_96);

View File

@@ -1349,6 +1349,11 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
}
if (Dk == 192 && Dv == 192) {
return iqk_fa_192_192(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
}
if (Dk == 256 && Dv == 256) {
return iqk_fa_256_256(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);

View File

@@ -142,6 +142,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" },
{ LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" },
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },

View File

@@ -135,6 +135,8 @@ enum llm_kv {
LLM_KV_ATTENTION_SCALE,
LLM_KV_ATTENTION_OUTPUT_SCALE,
LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_DIMENSION_SECTIONS,

View File

@@ -5931,7 +5931,7 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
// mutable variable, needed during the last layer of the computation to skip unused tokens
int32_t n_tokens = this->n_tokens;
bool is_lite = (hparams.n_layer == 27);
bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
// We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.

View File

@@ -762,8 +762,10 @@ void llm_load_hparams(
for (auto& item : hparams.n_head_kv_arr) item = n_nead_kv;
hparams.n_embd_head_k = 192;
hparams.n_embd_head_v = 128;
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k);
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v);
}
bool is_lite = (hparams.n_layer == 27);
bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
if (!is_lite) {

View File

@@ -1617,7 +1617,7 @@ bool create_tensors_helper::create_arctix_tensors(const LLM_TN & tn) {
bool create_tensors_helper::create_deepseek2_tensors(const LLM_TN & tn) {
LOADING_PRELUDE
const bool is_lite = (hparams.n_layer == 27);
const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
const int64_t n_embd_head_qk_rope = hparams.n_rot;
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;