Add special FA handling for dense Qwen3.5

This commit is contained in:
Kawrakow
2026-02-26 08:16:31 +00:00
parent 0aa6f7e7cd
commit 7340745572
2 changed files with 54 additions and 40 deletions

View File

@@ -227,46 +227,46 @@ struct fattn_mma_f16_config<128, 128> {
return 64;
}
};
//
//template <>
//struct fattn_mma_f16_config<256, 256> {
// static constexpr int nbatch_fa = 32;
// static constexpr int nwarps_max = 4;
// static constexpr bool Q_in_reg = true;
// static constexpr int nstages_target = 2;
//
// static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
// return 128;
// }
//
// static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
// return 128;
// }
//
// static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
// return 128;
// }
//
// static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
// return 128;
// }
//
// static int get_nbatch_combine_host(const int cc, const int ncols) {
// if (ggml_cuda_highest_compiled_arch(cc) == CC_TURING) {
// return ncols <= 16 ? 128 : 64;
// }
// return 64;
// }
//
// static constexpr __device__ int get_nbatch_combine_device(int ncols) {
//#if __CUDA_ARCH__ == CC_TURING
// return ncols <= 16 ? 128 : 64;
//#else
// GGML_UNUSED(ncols);
// return 128;
//#endif // __CUDA_ARCH__ == CC_TURING
// }
//};
template <>
struct fattn_mma_f16_config<256, 256> {
static constexpr int nbatch_fa = 32;
static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2;
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
return 128;
}
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
return 128;
}
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
return 128;
}
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
return 128;
}
static int get_nbatch_combine_host(const int cc, const int ncols) {
if (ggml_cuda_highest_compiled_arch(cc) == CC_TURING) {
return ncols <= 16 ? 128 : 64;
}
return 64;
}
static constexpr __device__ int get_nbatch_combine_device(int ncols) {
#if __CUDA_ARCH__ == CC_TURING
return ncols <= 16 ? 128 : 64;
#else
GGML_UNUSED(ncols);
return 128;
#endif // __CUDA_ARCH__ == CC_TURING
}
};
template <>
struct fattn_mma_f16_config<192, 128> {
@@ -2149,6 +2149,15 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
}
return;
}
if (K->ne[0] == 256) {
GGML_ASSERT(Q->ne[0] == 256 && V->ne[0] == 256);
if (gqa_ratio == 6) {
ggml_cuda_flash_attn_ext_mma_f16_case<256, 256, 1, 8>(ctx, dst);
} else {
GGML_ABORT("Not implemented");
}
return;
}
if (K->ne[0] == 192 && V->ne[0] == 128) {
GGML_ASSERT(Q->ne[0] == 192);
//GGML_ASSERT(gqa_ratio == 1); // Haha, this assert was for DeepSeek. But now we have Mimo2, which has GQA > 1

View File

@@ -95,6 +95,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
return;
}
if (new_mma_available(cc) && K->ne[0] == 256 && V->ne[0] == 256 && Q->ne[0] == 256 && Q->ne[1] == 1 && Q->ne[2] / K->ne[2] == 6) {
ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
return;
}
const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
// So, not sure why in mainline they thought that for CC_ADA_LOVELACE or when KV cache is not f16 the vector kernels are faster.
// On my GPU (RTX-4080) MMA is efinitely faster for GQA, both for f16 and for quantized KV cache.