commit ffc9e48a6f (parent cb4b0ebb11)
Author: Kawrakow
Date:   2026-01-29 17:06:47 +00:00

2 changed files with 59 additions and 3 deletions

@@ -378,6 +378,47 @@ struct fattn_mma_f16_config<576, 512> {
     }
 };
 
+template <>
+struct fattn_mma_f16_config<1088, 1024> {
+    static constexpr int nbatch_fa = 32;
+    static constexpr int nwarps_max = 8;
+    static constexpr bool Q_in_reg = false;
+    static constexpr int nstages_target = 1;
+
+    static int get_nbatch_K2_host([[maybe_unused]] const int cc, [[maybe_unused]] const int ncols) {
+        return 64;
+    }
+
+    static constexpr __device__ int get_nbatch_K2_device([[maybe_unused]] int ncols) {
+        return 64;
+    }
+
+    static int get_nbatch_V2_host([[maybe_unused]] const int cc, [[maybe_unused]] const int ncols) {
+        return 64;
+        //if (ggml_cuda_highest_compiled_arch(cc) == CC_TURING) {
+        //    return ncols <= 16 ? 64 : 128;
+        //}
+        //return ncols <= 16 ? 256 : 128;
+    }
+
+    static constexpr __device__ int get_nbatch_V2_device([[maybe_unused]] int ncols) {
+        return 64;
+//#if __CUDA_ARCH__ == CC_TURING
+//        return ncols <= 16 ? 64 : 128;
+//#else
+//        return ncols <= 16 ? 256 : 128;
+//#endif // __CUDA_ARCH__ == CC_TURING
+    }
+
+    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+        return 64; //128;
+    }
+
+    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+        return 64; //128;
+    }
+};
+
 // ------------------------------------------------------------------------------------------------------------------
 // The compiler is always able to unroll loops if they contain continue expressions.
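
Editorial note on the new config, not part of the patch: the K2/V2 suffix in these getters refers to half2 granularity elsewhere in this kernel family, so, assuming that convention holds here too, nbatch_K2 = 64 means each inner iteration touches 64 half2 = 128 of the 1088 fp16 values in a K row. A back-of-envelope sketch of the implied K/V tile sizes:

    #include <cstdio>

    // Rough per-stage K/V tile sizes implied by fattn_mma_f16_config<1088, 1024>,
    // assuming nbatch_K2/nbatch_V2 count half2 (2 x fp16 = 4 byte) chunks.
    int main() {
        constexpr int nbatch_fa = 32; // KV rows processed per inner iteration
        constexpr int nbatch_K2 = 64; // half2 chunks per K row per batch
        constexpr int nbatch_V2 = 64; // half2 chunks per V row per batch

        constexpr int k_tile_bytes = nbatch_fa * nbatch_K2 * 4; // 8192 bytes
        constexpr int v_tile_bytes = nbatch_fa * nbatch_V2 * 4; // 8192 bytes
        std::printf("K tile: %d B, V tile: %d B per stage\n", k_tile_bytes, v_tile_bytes);
        return 0;
    }

The conservative choices (a single pipeline stage via nstages_target = 1, and Q_in_reg = false so the 1088-wide Q rows stay out of registers) presumably reflect the limited register and shared-memory headroom at this unusually large head size.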
@@ -2165,6 +2206,20 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
         ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 192, 1>(ctx, dst);
         return;
     }
+    if (Q->ne[0] == 1088 && K->ne[0] == 1088 && V->ne[0] == 1024) {
+        GGML_ASSERT(gqa_ratio == 20);
+        if (Q->ne[1] <= 4) {
+            if (ggml_cuda_info().devices[ctx.device].cc >= CC_ADA_LOVELACE) {
+                ggml_cuda_flash_attn_ext_mma_f16_case<1088, 1024, 1, 16>(ctx, dst);
+            } else {
+                ggml_cuda_flash_attn_ext_mma_f16_case<1088, 1024, 1, 32>(ctx, dst);
+            }
+            return;
+        }
+        //ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<1088, 1024, 4>(ctx, dst);
+        ggml_cuda_flash_attn_ext_mma_f16_case<1088, 1024, 4, 4>(ctx, dst);
+        return;
+    }
     GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
     if (gqa_ratio == 20 && Q->ne[1] <= 4 && K->ne[1] >= 2048) {
         if (ggml_cuda_info().devices[ctx.device].cc >= CC_ADA_LOVELACE) {
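
For orientation (this sketch is editorial, not code from the patch): the last two template arguments of ggml_cuda_flash_attn_ext_mma_f16_case are the ncols1/ncols2 tile-shape parameters, and the new branch picks them from the Q batch size and the GPU architecture. Restated as a standalone function with hypothetical names (TileShape and choose_tile_1088 do not exist in the codebase):

    #include <cstdint>

    struct TileShape { int ncols1, ncols2; }; // hypothetical helper type

    // Mirrors the dispatch above for the 1088/1024 head sizes.
    TileShape choose_tile_1088(int64_t n_q_tokens, bool ada_or_newer) {
        if (n_q_tokens <= 4) {
            // Decode-sized Q batches: Ada Lovelace and newer get a 1x16 shape,
            // older MMA-capable GPUs get 1x32.
            return ada_or_newer ? TileShape{1, 16} : TileShape{1, 32};
        }
        // Larger Q batches use a fixed 4x4 shape; the commented-out
        // switch_ncols1 call suggests a generic ncols1 switch was also tried.
        return TileShape{4, 4};
    }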

@@ -114,7 +114,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     // so no other implementation works.
     //
-    if (new_mma_available(cc) && ((K->ne[0] == 576 && V->ne[0] == 512) || (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) {
+    if (new_mma_available(cc) && ((K->ne[0] == 576 && V->ne[0] == 512) || (K->ne[0] == 1088 && V->ne[0] == 1024) ||
+                                  (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) {
         //printf("Using ggml_cuda_flash_attn_ext_mma_new\n");
         ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
         return;
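
Isolated as a predicate (a sketch with a hypothetical name, not code from the patch), the routing condition now reads:

    #include <cstdint>

    // Hypothetical restatement of the if-condition above: the (K, V) head-size
    // pairs routed to ggml_cuda_flash_attn_ext_mma_new.
    bool routed_to_mma_new(bool new_mma, int64_t k_dim, int64_t v_dim, bool better_than_turing) {
        return new_mma && ((k_dim ==  576 && v_dim ==  512) ||
                           (k_dim == 1088 && v_dim == 1024) || // added by this commit
                           (k_dim ==  192 && v_dim ==  128 && better_than_turing));
    }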
@@ -185,8 +186,8 @@ bool ggml_cuda_fattn_is_supported(ggml_backend_cuda_context & ctx, const ggml_te
         return ggml_cuda_fattn_vec_f32_is_supported(ctx, dst);
     }
-    if (new_mma_available(cc) && (Q->ne[0] == 576 || (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) {
-        if (Q->ne[0] == 576) {
+    if (new_mma_available(cc) && (Q->ne[0] == 576 || Q->ne[0] == 1088 || (K->ne[0] == 192 && V->ne[0] == 128 && mma_better_than_turing(cc)))) {
+        if (Q->ne[0] == 576 || Q->ne[0] == 1088) {
             int gqa_ratio = Q->ne[2]/K->ne[2];
             return (gqa_ratio % 4) == 0;
         }
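
The support check treats the new 1088 head size exactly like the existing 576 one: the GQA ratio Q->ne[2]/K->ne[2] (query heads per KV head) must be divisible by 4, which the gqa_ratio == 20 asserted in the dispatch above satisfies. As a standalone sketch (mla_like_head_supported is a hypothetical name, not a function in the codebase):

    #include <cstdint>

    // Support rule for head sizes 576 and 1088 in the new MMA path.
    bool mla_like_head_supported(int64_t n_q_heads, int64_t n_kv_heads) {
        const int64_t gqa_ratio = n_q_heads / n_kv_heads; // Q->ne[2] / K->ne[2]
        return gqa_ratio % 4 == 0;                        // e.g. 20 passes, 6 does not
    }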