mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-14 07:48:16 +00:00
Support GigaChat3 (#995)
* Fixing Gigachat support
* Gigachat: CUDA FA (needs 192 x 192 for MLA = 3)
* Gigachat: CPU FA (needs 192 x 192 for MLA = 3)

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -265,6 +265,7 @@ if (GGML_IQK_MUL_MAT)
|
||||
iqk/iqk_flash_attn.cpp
|
||||
iqk/fa/iqk_fa_576_512.cpp
|
||||
iqk/fa/iqk_fa_192_128.cpp
|
||||
iqk/fa/iqk_fa_192_192.cpp
|
||||
iqk/fa/iqk_fa_256_256.cpp
|
||||
iqk/fa/iqk_fa_128_128.cpp
|
||||
iqk/fa/iqk_fa_96_96.cpp
|
||||
|
||||
@@ -43,6 +43,9 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 192:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
|
||||
break;
|
||||
@@ -88,5 +91,5 @@ bool ggml_cuda_fattn_mma_f16_is_supported([[maybe_unused]] ggml_backend_cuda_con
|
||||
auto K = dst->src[1];
|
||||
auto V = dst->src[1];
|
||||
if (K->ne[0] != V->ne[0]) return false;
|
||||
return K->ne[0] == 64 || K->ne[0] == 80 || K->ne[0] == 96 || K->ne[0] == 112 || K->ne[0] == 128 || K->ne[0] == 256;
|
||||
return K->ne[0] == 64 || K->ne[0] == 80 || K->ne[0] == 96 || K->ne[0] == 112 || K->ne[0] == 128 || K->ne[0] == 192 || K->ne[0] == 256;
|
||||
}
|
||||
|
||||
@@ -315,6 +315,38 @@ struct fattn_mma_f16_config<192, 128> {
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct fattn_mma_f16_config<192, 192> {
|
||||
static constexpr int nbatch_fa = 64;
|
||||
static constexpr int nwarps_max = 4;
|
||||
static constexpr bool Q_in_reg = true;
|
||||
static constexpr int nstages_target = 1;
|
||||
|
||||
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct fattn_mma_f16_config<576, 512> {
|
||||
static constexpr int nbatch_fa = 32;
|
||||
@@ -2119,6 +2151,12 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 1>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
if (K->ne[0] == 192 && V->ne[0] == 192) {
|
||||
GGML_ASSERT(Q->ne[0] == 192);
|
||||
GGML_ASSERT(gqa_ratio == 1);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 192, 1>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
|
||||
GGML_ASSERT(gqa_ratio % 16 == 0);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
|
||||
45
ggml/src/iqk/fa/iqk_fa_192_192.cpp
Normal file
45
ggml/src/iqk/fa/iqk_fa_192_192.cpp
Normal file
@@ -0,0 +1,45 @@
|
||||
#include "iqk/iqk_config.h"
|
||||
|
||||
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
|
||||
|
||||
#include "iqk/fa/iqk_fa_templates.h"
|
||||
|
||||
// CPU flash-attention entry point for the 192 x 192 case.
//
// Converts the integer type tags to ggml_type, rescales the q stride from
// bytes to floats, and forwards to iqk_flash_helper_T<192, 192, STEP>, where
// STEP is chosen from the divisibility of nk (128, 64, otherwise 32; the
// BF16 path only uses 64 or 32).
//
// Returns false when a BF16 K cache is paired with a non-BF16 V cache
// (AVX512-BF16 builds only); otherwise returns true on the BF16 path, or
// whatever the typed helper returns on the generic path.
IQK_FA_CASE(iqk_fa_192_192) {

    auto k_type = ggml_type(int_type_k);
    auto v_type = ggml_type(int_type_v);

    // From here on the q stride is expressed in floats rather than bytes.
    stride_q /= sizeof(float);

    // Byte-addressed views of K, V and the mask for the helpers.
    auto k_bytes    = (const char *)k;
    auto v_bytes    = (const char *)v;
    auto mask_bytes = (const char *)mask;

#ifdef __AVX512BF16__
    if (k_type == GGML_TYPE_BF16) {
        // we do not support mixing bf16 k-cache with other types
        if (v_type != GGML_TYPE_BF16) {
            return false;
        }
        if (nk%64 == 0) {
            iqk_flash_helper_T<192, 192, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                    q, k_bytes, v_bytes, mask_bytes, scale, softcap, qkv, sinkf, M, S);
        } else {
            iqk_flash_helper_T<192, 192, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                    q, k_bytes, v_bytes, mask_bytes, scale, softcap, qkv, sinkf, M, S);
        }
        return true;
    }
#endif

    // Pick the largest of 128 / 64 that divides nk; anything else uses 32.
    const int step = nk%128 == 0 ? 128 : nk%64 == 0 ? 64 : 32;

    switch (step) {
        case 128:
            return iqk_flash_helper_T<192, 192, 128>(k_type, v_type, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                    q, k_bytes, v_bytes, mask_bytes, scale, softcap, qkv, sinkf, M, S);
        case 64:
            return iqk_flash_helper_T<192, 192, 64>(k_type, v_type, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                    q, k_bytes, v_bytes, mask_bytes, scale, softcap, qkv, sinkf, M, S);
        default:
            return iqk_flash_helper_T<192, 192, 32>(k_type, v_type, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
                    q, k_bytes, v_bytes, mask_bytes, scale, softcap, qkv, sinkf, M, S);
    }

}
|
||||
|
||||
#endif
|
||||
@@ -2235,6 +2235,7 @@ inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
|
||||
|
||||
IQK_FA_CASE(iqk_fa_576_512);
|
||||
IQK_FA_CASE(iqk_fa_192_128);
|
||||
IQK_FA_CASE(iqk_fa_192_192);
|
||||
IQK_FA_CASE(iqk_fa_256_256);
|
||||
IQK_FA_CASE(iqk_fa_128_128);
|
||||
IQK_FA_CASE(iqk_fa_96_96);
|
||||
|
||||
@@ -1349,6 +1349,11 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
|
||||
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
|
||||
}
|
||||
|
||||
if (Dk == 192 && Dv == 192) {
|
||||
return iqk_fa_192_192(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
|
||||
}
|
||||
|
||||
if (Dk == 256 && Dv == 256) {
|
||||
return iqk_fa_256_256(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
|
||||
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
|
||||
|
||||
@@ -142,6 +142,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
||||
{ LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" },
|
||||
{ LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" },
|
||||
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
|
||||
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
|
||||
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
||||
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
|
||||
|
||||
@@ -135,6 +135,8 @@ enum llm_kv {
|
||||
LLM_KV_ATTENTION_SCALE,
|
||||
LLM_KV_ATTENTION_OUTPUT_SCALE,
|
||||
LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
|
||||
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
|
||||
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
|
||||
|
||||
LLM_KV_ROPE_DIMENSION_COUNT,
|
||||
LLM_KV_ROPE_DIMENSION_SECTIONS,
|
||||
|
||||
@@ -5931,7 +5931,7 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
|
||||
// mutable variable, needed during the last layer of the computation to skip unused tokens
|
||||
int32_t n_tokens = this->n_tokens;
|
||||
|
||||
bool is_lite = (hparams.n_layer == 27);
|
||||
bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
|
||||
|
||||
// We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
|
||||
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
|
||||
|
||||
@@ -762,8 +762,10 @@ void llm_load_hparams(
|
||||
for (auto& item : hparams.n_head_kv_arr) item = n_nead_kv;
|
||||
hparams.n_embd_head_k = 192;
|
||||
hparams.n_embd_head_v = 128;
|
||||
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k);
|
||||
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v);
|
||||
}
|
||||
bool is_lite = (hparams.n_layer == 27);
|
||||
bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
|
||||
if (!is_lite) {
|
||||
|
||||
@@ -1617,7 +1617,7 @@ bool create_tensors_helper::create_arctix_tensors(const LLM_TN & tn) {
|
||||
bool create_tensors_helper::create_deepseek2_tensors(const LLM_TN & tn) {
|
||||
LOADING_PRELUDE
|
||||
|
||||
const bool is_lite = (hparams.n_layer == 27);
|
||||
const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
|
||||
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
|
||||
|
||||
Reference in New Issue
Block a user