This is slightly better

This commit is contained in:
Iwan Kawrakow
2025-09-04 09:06:31 +03:00
parent bf0b5088e0
commit b02e137f60

View File

@@ -16,6 +16,8 @@
#include <cstdint>
#define FATTN_KQ_STRIDE 256
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
@@ -28,6 +30,24 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
const int32_t precision = KQV->op_params[3];
const int32_t n_swa = KQV->op_params[4];
ggml_tensor local_dst, Kl, Vl, Ml;
if (n_swa > 0) {
int ntokens = std::max(FATTN_KQ_STRIDE, int(Q->ne[1]));
int nton = FATTN_KQ_STRIDE*((ntokens + n_swa + FATTN_KQ_STRIDE - 1)/FATTN_KQ_STRIDE);
int first = K->ne[1] - nton;
if (first > 0) {
local_dst = *dst;
Kl = *K; Kl.ne[1] = nton; Kl.data = (char *)K->data + K->nb[1]*first;
Vl = *V; Vl.ne[1] = nton; Vl.data = (char *)V->data + V->nb[1]*first;
Ml = *mask; Ml.ne[0] = nton; Ml.data = (char *)mask->data + mask->nb[0]*first;
local_dst.src[1] = &Kl;
local_dst.src[2] = &Vl;
local_dst.src[3] = &Ml;
local_dst.op_params[4] = 0;
dst = &local_dst;
}
}
// On AMD the tile kernels perform poorly, use the vec kernel instead:
if (cc >= CC_OFFSET_AMD) {
if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {