Use MMA for TG also when quantized

This commit is contained in:
Iwan Kawrakow
2025-05-03 15:34:56 +03:00
parent dd46438d25
commit 056f08182a

View File

@@ -502,8 +502,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
}
const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < CC_ADA_LOVELACE && !mma_needs_data_conversion;
// So, not sure why in mainline they thought that for CC_ADA_LOVELACE or when KV cache is not f16 the vector kernels are faster.
// On my GPU (RTX-4080) MMA is efinitely faster for GQA, both for f16 and for quantized KV cache.
//const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
//const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < CC_ADA_LOVELACE && !mma_needs_data_conversion;
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies;
const bool can_use_vector_kernel = Q->ne[0] % (2*WARP_SIZE) == 0;
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
if (precision == GGML_PREC_DEFAULT) {
@@ -514,6 +517,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
return;
}
// We need this because I haven't adapted the MMA kernels to work for different
// K and V head sizes.
if (K->ne[0] != V->ne[0]) {
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
}