mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-22 14:14:32 +00:00
Gigachat: CUDA FA (needs 192 x 192 for MLA = 3)
This commit is contained in:
@@ -43,6 +43,9 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 192:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
|
||||
break;
|
||||
@@ -88,5 +91,5 @@ bool ggml_cuda_fattn_mma_f16_is_supported([[maybe_unused]] ggml_backend_cuda_con
|
||||
auto K = dst->src[1];
|
||||
auto V = dst->src[1];
|
||||
if (K->ne[0] != V->ne[0]) return false;
|
||||
return K->ne[0] == 64 || K->ne[0] == 80 || K->ne[0] == 96 || K->ne[0] == 112 || K->ne[0] == 128 || K->ne[0] == 256;
|
||||
return K->ne[0] == 64 || K->ne[0] == 80 || K->ne[0] == 96 || K->ne[0] == 112 || K->ne[0] == 128 || K->ne[0] == 192 || K->ne[0] == 256;
|
||||
}
|
||||
|
||||
@@ -315,6 +315,38 @@ struct fattn_mma_f16_config<192, 128> {
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct fattn_mma_f16_config<192, 192> {
|
||||
static constexpr int nbatch_fa = 64;
|
||||
static constexpr int nwarps_max = 4;
|
||||
static constexpr bool Q_in_reg = true;
|
||||
static constexpr int nstages_target = 1;
|
||||
|
||||
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
return 64;
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
||||
return 32;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct fattn_mma_f16_config<576, 512> {
|
||||
static constexpr int nbatch_fa = 32;
|
||||
@@ -2119,6 +2151,12 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 1>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
if (K->ne[0] == 192 && V->ne[0] == 192) {
|
||||
GGML_ASSERT(Q->ne[0] == 192);
|
||||
GGML_ASSERT(gqa_ratio == 1);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 192, 1>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
|
||||
GGML_ASSERT(gqa_ratio % 16 == 0);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
|
||||
Reference in New Issue
Block a user