Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-02-22 14:14:32 +00:00)
Restore SWA trick
This commit is contained in:
@@ -476,6 +476,36 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
|
||||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_set_device(ctx.device);
|
||||
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
|
||||
ggml_cuda_set_device(ctx.device);
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
const int32_t precision = KQV->op_params[3];
|
||||
const int32_t n_swa = KQV->op_params[4];
|
||||
|
||||
ggml_tensor local_dst, Kl, Vl, Ml;
|
||||
if (n_swa > 0) {
|
||||
int ntokens = std::max(FATTN_KQ_STRIDE, int(Q->ne[1]));
|
||||
int nton = FATTN_KQ_STRIDE*((ntokens + n_swa + FATTN_KQ_STRIDE - 1)/FATTN_KQ_STRIDE);
|
||||
int first = K->ne[1] - nton;
|
||||
if (first > 0) {
|
||||
local_dst = *dst;
|
||||
Kl = *K; Kl.ne[1] = nton; Kl.data = (char *)K->data + K->nb[1]*first;
|
||||
Vl = *V; Vl.ne[1] = nton; Vl.data = (char *)V->data + V->nb[1]*first;
|
||||
Ml = *mask; Ml.ne[0] = nton; Ml.data = (char *)mask->data + mask->nb[0]*first;
|
||||
local_dst.src[1] = &Kl;
|
||||
local_dst.src[2] = &Vl;
|
||||
local_dst.src[3] = &Ml;
|
||||
local_dst.op_params[4] = 0;
|
||||
dst = &local_dst;
|
||||
}
|
||||
}
|
||||
|
||||
switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) {
|
||||
case BEST_FATTN_KERNEL_NONE:
|
||||
GGML_ABORT("fatal error");
|
||||
|
||||
Reference in New Issue
Block a user