mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-24 08:29:29 +00:00
Add special FA handling for dense Qwen3.5
This commit is contained in:
@@ -227,46 +227,46 @@ struct fattn_mma_f16_config<128, 128> {
|
||||
return 64;
|
||||
}
|
||||
};
|
||||
//
|
||||
//template <>
|
||||
//struct fattn_mma_f16_config<256, 256> {
|
||||
// static constexpr int nbatch_fa = 32;
|
||||
// static constexpr int nwarps_max = 4;
|
||||
// static constexpr bool Q_in_reg = true;
|
||||
// static constexpr int nstages_target = 2;
|
||||
//
|
||||
// static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
// return 128;
|
||||
// }
|
||||
//
|
||||
// static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
// return 128;
|
||||
// }
|
||||
//
|
||||
// static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
// return 128;
|
||||
// }
|
||||
//
|
||||
// static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
// return 128;
|
||||
// }
|
||||
//
|
||||
// static int get_nbatch_combine_host(const int cc, const int ncols) {
|
||||
// if (ggml_cuda_highest_compiled_arch(cc) == CC_TURING) {
|
||||
// return ncols <= 16 ? 128 : 64;
|
||||
// }
|
||||
// return 64;
|
||||
// }
|
||||
//
|
||||
// static constexpr __device__ int get_nbatch_combine_device(int ncols) {
|
||||
//#if __CUDA_ARCH__ == CC_TURING
|
||||
// return ncols <= 16 ? 128 : 64;
|
||||
//#else
|
||||
// GGML_UNUSED(ncols);
|
||||
// return 128;
|
||||
//#endif // __CUDA_ARCH__ == CC_TURING
|
||||
// }
|
||||
//};
|
||||
|
||||
template <>
|
||||
struct fattn_mma_f16_config<256, 256> {
|
||||
static constexpr int nbatch_fa = 32;
|
||||
static constexpr int nwarps_max = 4;
|
||||
static constexpr bool Q_in_reg = true;
|
||||
static constexpr int nstages_target = 2;
|
||||
|
||||
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 128;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
||||
return 128;
|
||||
}
|
||||
|
||||
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
||||
return 128;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
||||
return 128;
|
||||
}
|
||||
|
||||
static int get_nbatch_combine_host(const int cc, const int ncols) {
|
||||
if (ggml_cuda_highest_compiled_arch(cc) == CC_TURING) {
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
}
|
||||
return 64;
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_nbatch_combine_device(int ncols) {
|
||||
#if __CUDA_ARCH__ == CC_TURING
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
#else
|
||||
GGML_UNUSED(ncols);
|
||||
return 128;
|
||||
#endif // __CUDA_ARCH__ == CC_TURING
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct fattn_mma_f16_config<192, 128> {
|
||||
@@ -2149,6 +2149,15 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (K->ne[0] == 256) {
|
||||
GGML_ASSERT(Q->ne[0] == 256 && V->ne[0] == 256);
|
||||
if (gqa_ratio == 6) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<256, 256, 1, 8>(ctx, dst);
|
||||
} else {
|
||||
GGML_ABORT("Not implemented");
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (K->ne[0] == 192 && V->ne[0] == 128) {
|
||||
GGML_ASSERT(Q->ne[0] == 192);
|
||||
//GGML_ASSERT(gqa_ratio == 1); // Haha, this assert was for DeepSeek. But now we have Mimo2, which has GQA > 1
|
||||
|
||||
@@ -95,6 +95,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
return;
|
||||
}
|
||||
|
||||
if (new_mma_available(cc) && K->ne[0] == 256 && V->ne[0] == 256 && Q->ne[0] == 256 && Q->ne[1] == 1 && Q->ne[2] / K->ne[2] == 6) {
|
||||
ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
|
||||
// So, not sure why in mainline they thought that for CC_ADA_LOVELACE or when KV cache is not f16 the vector kernels are faster.
|
||||
// On my GPU (RTX-4080) MMA is efinitely faster for GQA, both for f16 and for quantized KV cache.
|
||||
|
||||
Reference in New Issue
Block a user