Gigachat: CUDA FA (needs 192 x 192 for MLA = 3)

Iwan Kawrakow
2025-11-21 11:27:10 +02:00
parent e0bc495792
commit 360c8c6fd4
2 changed files with 42 additions and 1 deletion
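
This commit extends the CUDA FlashAttention MMA kernels with a 192 x 192 K/V head-size combination, which GigaChat requires when run with MLA = 3. As a quick orientation aid, the following hand-written summary (not repository code; the struct and array names are invented) lists the <DKQ, DV> template pairs that appear in the two changed files:

// Illustrative summary only, not code from the repository: the <DKQ, DV>
// head-size pairs touched or referenced by this commit. The struct and
// array names below are made up for readability.
struct mma_f16_head_sizes { int dkq; int dv; const char * note; };

static constexpr mma_f16_head_sizes k_mla_fa_cases[] = {
    {192, 128, "pre-existing specialization (context lines in the diff)"},
    {192, 192, "added here: needed by GigaChat with MLA = 3, gqa_ratio == 1"},
    {576, 512, "pre-existing specialization, dispatched with gqa_ratio % 16 == 0"},
};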


@@ -43,6 +43,9 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context
         case 128:
             ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
             break;
+        case 192:
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, ncols2>(ctx, dst);
+            break;
         case 256:
             ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
             break;
@@ -88,5 +91,5 @@ bool ggml_cuda_fattn_mma_f16_is_supported([[maybe_unused]] ggml_backend_cuda_con
     auto K = dst->src[1];
     auto V = dst->src[2];
     if (K->ne[0] != V->ne[0]) return false;
-    return K->ne[0] == 64 || K->ne[0] == 80 || K->ne[0] == 96 || K->ne[0] == 112 || K->ne[0] == 128 || K->ne[0] == 256;
+    return K->ne[0] == 64 || K->ne[0] == 80 || K->ne[0] == 96 || K->ne[0] == 112 || K->ne[0] == 128 || K->ne[0] == 192 || K->ne[0] == 256;
 }
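
For reference, the head-size gate extended by the hunk above can be read as a small standalone predicate. The sketch below is a paraphrase, not the repository's code: the helper name is invented, and it assumes K and V are the head sizes taken from dst->src[1] and dst->src[2].

// Minimal sketch of the check in ggml_cuda_fattn_mma_f16_is_supported after
// this commit (paraphrased; the helper name is hypothetical).
#include <cstdint>
#include <initializer_list>

static bool mma_f16_head_size_supported(int64_t k_head_size, int64_t v_head_size) {
    // K and V must have the same head size to pass this particular check.
    if (k_head_size != v_head_size) {
        return false;
    }
    // 192 is the size added by this commit.
    for (int64_t hs : {64, 80, 96, 112, 128, 192, 256}) {
        if (k_head_size == hs) {
            return true;
        }
    }
    return false;
}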


@@ -315,6 +315,38 @@ struct fattn_mma_f16_config<192, 128> {
     }
 };
 
 template <>
+struct fattn_mma_f16_config<192, 192> {
+    static constexpr int nbatch_fa = 64;
+    static constexpr int nwarps_max = 4;
+    static constexpr bool Q_in_reg = true;
+    static constexpr int nstages_target = 1;
+
+    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+        return 64;
+    }
+
+    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+        return 64;
+    }
+
+    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+        return 64;
+    }
+
+    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+        return 64;
+    }
+
+    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+        return 32;
+    }
+
+    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+        return 32;
+    }
+};
+
+template <>
 struct fattn_mma_f16_config<576, 512> {
     static constexpr int nbatch_fa = 32;
@@ -2119,6 +2151,12 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
         ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 1>(ctx, dst);
         return;
     }
+    if (K->ne[0] == 192 && V->ne[0] == 192) {
+        GGML_ASSERT(Q->ne[0] == 192);
+        GGML_ASSERT(gqa_ratio == 1);
+        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 192, 1>(ctx, dst);
+        return;
+    }
     GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
     GGML_ASSERT(gqa_ratio % 16 == 0);
     ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
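
Putting the final hunk in context: within the branches visible here, ggml_cuda_flash_attn_ext_mma_new now separates three head-size pairs before launching the corresponding template instantiation. The sketch below reduces that branch structure to plain checks; it is illustrative only, the 192/128 condition is inferred from the context lines, and the enum and function names are invented.

// Illustrative reduction of the branch structure shown in the last hunk.
// The real code launches templated kernels; this sketch only returns a tag.
#include <cassert>

enum class mma_fa_case { hs_192_128, hs_192_192, hs_576_512 };

static mma_fa_case pick_mma_fa_case(int q_head, int k_head, int v_head, int gqa_ratio) {
    if (k_head == 192 && v_head == 128) {        // condition inferred from the context lines
        return mma_fa_case::hs_192_128;          // -> switch_ncols1<192, 128, 1>
    }
    if (k_head == 192 && v_head == 192) {        // new in this commit (GigaChat, MLA = 3)
        assert(q_head == 192);
        assert(gqa_ratio == 1);
        return mma_fa_case::hs_192_192;          // -> switch_ncols1<192, 192, 1>
    }
    assert(q_head == 576 && k_head == 576 && v_head == 512);
    assert(gqa_ratio % 16 == 0);
    return mma_fa_case::hs_576_512;              // -> switch_ncols1<576, 512, 16>
}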