From f0d7efed435e858212f84ee499e2a27d84189ec1 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 27 Jan 2026 12:24:22 +0000 Subject: [PATCH] Restore SWA trick --- ggml/src/ggml-cuda/fattn.cu | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index d569f9b9..4f53ef48 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -476,6 +476,36 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_set_device(ctx.device); + + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + + ggml_cuda_set_device(ctx.device); + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const int32_t precision = KQV->op_params[3]; + const int32_t n_swa = KQV->op_params[4]; + + ggml_tensor local_dst, Kl, Vl, Ml; + if (n_swa > 0) { + int ntokens = std::max(FATTN_KQ_STRIDE, int(Q->ne[1])); + int nton = FATTN_KQ_STRIDE*((ntokens + n_swa + FATTN_KQ_STRIDE - 1)/FATTN_KQ_STRIDE); + int first = K->ne[1] - nton; + if (first > 0) { + local_dst = *dst; + Kl = *K; Kl.ne[1] = nton; Kl.data = (char *)K->data + K->nb[1]*first; + Vl = *V; Vl.ne[1] = nton; Vl.data = (char *)V->data + V->nb[1]*first; + Ml = *mask; Ml.ne[0] = nton; Ml.data = (char *)mask->data + mask->nb[0]*first; + local_dst.src[1] = &Kl; + local_dst.src[2] = &Vl; + local_dst.src[3] = &Ml; + local_dst.op_params[4] = 0; + dst = &local_dst; + } + } + switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) { case BEST_FATTN_KERNEL_NONE: GGML_ABORT("fatal error");