Kawrakow
2026-01-27 14:30:26 +00:00
parent f0d7efed43
commit dc23be32a2


@@ -6,6 +6,7 @@
#include "fattn-wmma-f16.cuh"
#include "fattn.cuh"
#include "fattn-compat.cuh"
//#include "fattn-prev-mma-f16-interface.cuh"
template <int DKQ, int DV, int ncols2>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -87,19 +88,24 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
    if constexpr (DKQ == 128 && DV == 128) {
        if (use_gqa_opt && gqa_ratio == 12) {
            if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 1) {
                ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 1, 16>(ctx, dst);
            if (Q->ne[1] <= 8) {
                if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] == 1) {
                    ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 1, 16>(ctx, dst);
                    return;
                }
                if (ggml_cuda_highest_compiled_arch(cc) == CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 2) {
                    ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 2, 16>(ctx, dst);
                    return;
                }
                ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 4, 16>(ctx, dst);
                return;
            }
            if (ggml_cuda_highest_compiled_arch(cc) == CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 2) {
                ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 2, 16>(ctx, dst);
                return;
            }
            ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 4, 16>(ctx, dst);
            //ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 16>(ctx, dst);
            return;
            //else if (ggml_cuda_fattn_prev_mma_f16_is_supported(ctx, dst)) {
            //    ggml_cuda_flash_attn_ext_prev_mma_f16(ctx, dst);
            //    return;
            //}
        }
    }
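
For context, the hunk above boils down to a small heuristic for the DKQ == DV == 128, gqa_ratio == 12 path: pick an ncols1 tile width of 1, 2, or 4 from the number of query columns (Q->ne[1]) and from the available architecture features. Below is a minimal standalone sketch of that selection, not part of the commit; pick_ncols1() is a hypothetical helper, n_q stands in for Q->ne[1], and the two boolean flags stand in for the turing_mma_available(cc) || amd_wmma_available(cc) and ggml_cuda_highest_compiled_arch(cc) == CC_TURING checks.

// Standalone sketch (illustration only, not part of the commit).
#include <cstdio>

static int pick_ncols1(int n_q, bool turing_or_amd_wmma, bool turing_only_build) {
    if (n_q > 8) {
        return 0;                         // larger batches are dispatched elsewhere, not covered here
    }
    if (turing_or_amd_wmma && n_q == 1) {
        return 1;                         // a single query column gets the narrowest tile
    }
    if (turing_only_build || n_q <= 2) {
        return 2;                         // Turing-only builds or tiny batches use 2 columns
    }
    return 4;                             // otherwise use the wider 4-column tile
}

int main() {
    const int qs[] = {1, 2, 3, 4, 8, 16};
    for (int n_q : qs) {
        printf("n_q = %2d -> ncols1 = %d\n", n_q, pick_ncols1(n_q, true, false));
    }
    return 0;
}
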
@@ -485,7 +491,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
    ggml_cuda_set_device(ctx.device);
    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const int32_t precision = KQV->op_params[3];
    const int32_t n_swa = KQV->op_params[4];
    ggml_tensor local_dst, Kl, Vl, Ml;