diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index d6350f6e..7d20869f 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2043,6 +2043,10 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * sinks); + GGML_API void ggml_flash_attn_ext_add_bounds( + struct ggml_tensor * a, + struct ggml_tensor * bounds); + // TODO: needs to be adapted to ggml_flash_attn_ext GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 695dc722..620069f9 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -8993,6 +8993,22 @@ void ggml_flash_attn_ext_add_sinks( a->src[4] = sinks; } +void ggml_flash_attn_ext_add_bounds( + struct ggml_tensor * a, + struct ggml_tensor * bounds) { + if (!bounds) { + a->src[5] = NULL; + return; + } + + GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT); + GGML_ASSERT(bounds->type == GGML_TYPE_I32); + GGML_ASSERT(bounds->ne[0] == 2); + GGML_ASSERT(bounds->ne[1] >= a->src[0]->ne[1]); + + a->src[5] = bounds; +} + // ggml_flash_attn_back struct ggml_tensor * ggml_flash_attn_back( @@ -18661,6 +18677,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( const struct ggml_tensor * v = dst->src[2]; const struct ggml_tensor * mask = dst->src[3]; const struct ggml_tensor * sinks = dst->src[4]; + const struct ggml_tensor * bounds= dst->src[5]; GGML_TENSOR_LOCALS(int64_t, neq, q, ne) GGML_TENSOR_LOCALS(size_t, nbq, q, nb) @@ -18739,7 +18756,9 @@ static void ggml_compute_forward_flash_attn_ext_f16( dst->ne[2], dst->ne[1], dst->nb[1], k->type, v->type, Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], - q->data, k->data, v->data, mask->data, sinks ? sinks->data : NULL, + q->data, k->data, v->data, mask->data, + sinks ? sinks->data : NULL, + bounds ? bounds->data : NULL, scale, softcap, (float *)dst->data, params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth)) return; diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp index ccd81079..00791ba3 100644 --- a/ggml/src/iqk/iqk_flash_attn.cpp +++ b/ggml/src/iqk/iqk_flash_attn.cpp @@ -43,6 +43,27 @@ inline void accumulate_qkv(int Dv, float& M, float& S, float Mj, float Sj, float for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i]; } } +inline std::pair mask_range(int nek1, const uint16_t * umask) { + int first_k = 0, last_k = nek1; + for (; first_k < last_k; ++first_k) { + if (umask[first_k] == 0) break; + } + for (; last_k > first_k; --last_k) { + if (umask[last_k-1] == 0) break; + } + return { first_k, last_k }; +} +inline bool reduce_k_range(int nek1, int& first_k, int& last_k) { + int nk = last_k - first_k; + if (nk >= nek1) return false; + if (nk%32) { + int nk32 = 32*((nk + 31)/32); + int diff = nk32 - nk; + first_k = std::max(0, first_k - diff); + last_k = first_k + nk32; + } + return last_k - first_k < nek1; +} } // TODO: get the ggml_type enum here without polution @@ -66,7 +87,8 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float const void * k, // k matrix. Assumed to be fp16, nq x nk elements const void * v, // v matrix. Assumed to be fp16, nq x nk elements const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements - const void * sinks, // mask. If not null, assumed to be fp16. nq x nk elements + const void * sinks, // attention sinks + const void * bounds, // attention mask bounds float scale, // scale applied before softmax float softcap, // if > 0, a "soft-cap" operation is applied before softmax float * qkv, // v*softmax(scale*(k*q)) @@ -80,22 +102,13 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float int rk3 = neq3/nek3; int rv3 = neq3/nev3; - int first_k = 0, last_k = nek1; - if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nek1 > 256) { - // This is a quick hack for SWA models. - // Given that the mask is the same for all layers, ideally we should determinbe the - // cache bounds once, and reuse for the whole graph. But even with this simple hack - // we get non-negligible performance gains for SWA models and long context. - auto umask = (const uint16_t *)mask; - for (; first_k < last_k; ++first_k) { - if (umask[first_k] == 0) break; - } - for (; last_k > first_k; --last_k) { - if (umask[last_k-1] == 0) break; - } - //printf("nek1 = %d, first = %d, last = %d\n", nek1, first, last); - if (last_k - first_k <= 3*nek1/4 && (last_k - first_k)%32 == 0) { - //printf("Reducing from %d to %d\n", nek1, last_k - first_k); + bool range_found = false; + if (neq3 == 1 && rk2 > 1 && neq1 == 1 && bounds && nek1 > 32) { + range_found = true; + auto b = (const int32_t *)bounds; + int first_k = b[0]; + int last_k = b[1]; + if ((last_k - first_k)%32 == 0) { // why is this not better? : if (reduce_k_range(nek1, first_k, last_k)) { k = (const void *)((const char *)k + first_k*stride_k); v = (const void *)((const char *)v + first_k*stride_v); mask = (const void *)((const uint16_t *)mask + first_k); @@ -105,7 +118,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float int int_type_k = int_type_k_in; auto work_buffer = work_buffer_in; - if (neq1 >= 8 || (rk2 >= 8 && nek2 > 1)) { + if (neq1 >= 8 || (false && rk2 >= 8 && nek2 > 1)) { uint64_t row_size = 0; work_buffer = iqk_repack_k(int_type_k, Dk, nek1, nek2, nek3, stride_k, nbk2, nbk3, k, work_buffer_in, ith, nth, int_type_k, row_size); if (int_type_k != int_type_k_in) { @@ -299,6 +312,25 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float if (counter++ % (nth/ntg) == ith/ntg) { int iq1 = (ith%ntg)*neq1g; int this_neq1 = std::min(neq1g, neq1-iq1); + if (bounds && !range_found) { + auto b = (const int32_t *)bounds + 2*iq1; + int kmin = nek1, kmax = 0; + for (int i = 0; i < this_neq1; ++i) { + kmin = std::min(kmin, b[2*i+0]); + kmax = std::max(kmax, b[2*i+1]); + } + if (reduce_k_range(nek1, kmin, kmax)) { + if (!iqk_flash_attn_impl(int_type_k, int_type_v, + Dk, Dv, this_neq1, kmax-kmin, stride_q, stride_k, stride_v, stride_m, ne1*nb1/sizeof(float), + (const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q), + (const void *)((const char *)k + iq2/rk2*nbk2 + iq3/rk3*nbk3 + kmin*stride_k), + (const void *)((const char *)v + iq2/rv2*nbv2 + iq3/rv3*nbv3 + kmin*stride_v), + (const void *)((const char *)mask + iq1*stride_m + kmin*sizeof(uint16_t)), sinksf, 1, + scale, softcap, + (float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false; + continue; + } + } if (!iqk_flash_attn_impl(int_type_k, int_type_v, Dk, Dv, this_neq1, nek1, stride_q, stride_k, stride_v, stride_m, ne1*nb1/sizeof(float), (const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q), diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h index b131095b..bcb7b91f 100644 --- a/ggml/src/iqk/iqk_mul_mat.h +++ b/ggml/src/iqk/iqk_mul_mat.h @@ -58,7 +58,8 @@ IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, const void * k, // k matrix. Assumed to be fp16, nq x nk elements const void * v, // v matrix. Assumed to be fp16, nq x nk elements const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements - const void * sinks, // mask. If not null, assumed to be fp16. nq x nk elements + const void * sinks, // attention sinks + const void * bounds, // attention mask bounds float scale, // scale applied before softmax float softcap, // if > 0, a "soft-cap" operation is applied before softmax float * qkv, // v*softmax(scale*(k*q)) diff --git a/src/llama.cpp b/src/llama.cpp index 4d7254c4..c810e511 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2513,6 +2513,8 @@ struct llama_context { struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] struct ggml_tensor * inp_scale = nullptr; // F32 [n_tokens] + struct ggml_tensor * inp_mask_bounds = nullptr; // I32 [2, n_tokens] + struct ggml_tensor * inp_mask_bounds_swa = nullptr; // I32 [2, n_tokens] }; struct llama_lora_weight { @@ -7943,7 +7945,8 @@ static struct ggml_tensor * llm_build_kqv( float kq_scale, const llm_build_cb & cb, int il, - ggml_tensor * sinks = nullptr) { + ggml_tensor * sinks = nullptr, + ggml_tensor * bounds = nullptr) { const llama_model & model = lctx.model; const llama_hparams & hparams = lctx.model.hparams; const llama_cparams & cparams = lctx.cparams; @@ -7990,7 +7993,8 @@ static struct ggml_tensor * llm_build_kqv( cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias, hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f); - ggml_flash_attn_ext_add_sinks(cur, sinks); + ggml_flash_attn_ext_add_sinks (cur, sinks); + ggml_flash_attn_ext_add_bounds(cur, bounds); // Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA // For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when uding fp16 for TG. @@ -8148,7 +8152,8 @@ static struct ggml_tensor * llm_build_kv( float kq_scale, const llm_build_cb & cb, int il, - ggml_tensor * sinks = nullptr) { + ggml_tensor * sinks = nullptr, + ggml_tensor * bounds = nullptr) { const llama_hparams & hparams = lctx.model.hparams; const llama_cparams & cparams = lctx.cparams; @@ -8163,7 +8168,7 @@ static struct ggml_tensor * llm_build_kv( struct ggml_tensor * cur; cur = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b, - q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks); + q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks, bounds); cb(cur, "kqv_out", il); return cur; @@ -8298,6 +8303,8 @@ struct llm_build_context { lctx.inp_pos_bucket = nullptr; lctx.inp_embd_enc = nullptr; lctx.inp_KQ_mask_cross = nullptr; + lctx.inp_mask_bounds = nullptr; + lctx.inp_mask_bounds_swa = nullptr; } void free() { @@ -8478,6 +8485,9 @@ struct llm_build_context { cb(lctx.inp_KQ_mask, "KQ_mask", -1); ggml_set_input(lctx.inp_KQ_mask); + lctx.inp_mask_bounds = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, 2, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + ggml_set_input(lctx.inp_mask_bounds); + return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask; } @@ -8490,6 +8500,9 @@ struct llm_build_context { cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1); ggml_set_input(lctx.inp_KQ_mask_swa); + lctx.inp_mask_bounds_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, 2, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + ggml_set_input(lctx.inp_mask_bounds_swa); + return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask_swa, GGML_TYPE_F16) : lctx.inp_KQ_mask_swa; } @@ -8658,6 +8671,7 @@ struct llm_build_context { bool use_rope = model.arch == LLM_ARCH_LLAMA4 ? (il + 1) % hparams.n_no_rope_layer_step != 0 : true; auto this_KQ_mask = hparams.n_swa > 0 && hparams.n_swa_pattern > 0 && il % hparams.n_swa_pattern < (hparams.n_swa_pattern - 1) ? KQ_mask_swa : KQ_mask; + auto bounds = this_KQ_mask == KQ_mask_swa ? lctx.inp_mask_bounds_swa : lctx.inp_mask_bounds; // norm cur = llm_build_norm(ctx0, inpL, hparams, @@ -8722,7 +8736,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il, nullptr, bounds); } if (il == n_layer - 1) { @@ -11223,7 +11237,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask_swa, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask_swa, n_tokens, kv_head, n_kv, 1.0f, cb, il, nullptr, lctx.inp_mask_bounds_swa); } if (il == n_layer - 1) { @@ -12112,6 +12126,7 @@ struct llm_build_context { for (int il = 0; il < n_layer; ++il) { // (il % 2) layers use SWA struct ggml_tensor * KQ_mask_l = (il % 2 == 0) ? KQ_mask_swa : KQ_mask; + auto bounds = KQ_mask_l == KQ_mask_swa ? lctx.inp_mask_bounds_swa : lctx.inp_mask_bounds; // norm cur = llm_build_norm(ctx0, inpL, hparams, @@ -12154,7 +12169,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il, nullptr, bounds); } cur = llm_build_norm(ctx0, cur, hparams, @@ -12257,6 +12272,7 @@ struct llm_build_context { const float freq_base_l = is_sliding ? 10000.0f : freq_base; const float freq_scale_l = is_sliding ? 1.0f : freq_scale; struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask; + auto bounds = is_sliding ? lctx.inp_mask_bounds_swa : lctx.inp_mask_bounds; // norm cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); @@ -12291,7 +12307,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il); + Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il, nullptr, bounds); } cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, cb, il); @@ -14303,6 +14319,7 @@ struct llm_build_context { // fourth layer uses global attention without positional embeddings const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1); struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask; + auto bounds = is_sliding ? lctx.inp_mask_bounds_swa : lctx.inp_mask_bounds; // norm cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM, cb, il); @@ -14356,7 +14373,7 @@ struct llm_build_context { } cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, - KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il); + KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il, nullptr, bounds); } if (il == n_layer - 1) { @@ -15407,6 +15424,7 @@ struct llm_build_context { ggml_tensor * inpSA = inpL; struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask; + auto bounds = is_sliding ? lctx.inp_mask_bounds_swa : lctx.inp_mask_bounds; // norm cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il); @@ -15446,7 +15464,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, kq_scale, cb, il, model.layers[il].attn_sinks); + Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, kq_scale, cb, il, model.layers[il].attn_sinks, bounds); cb(cur, "attn_out", il); } @@ -15965,16 +15983,26 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { float * data = nullptr; float * data_swa = nullptr; + int32_t * bounds = nullptr; + int32_t * bounds_swa = nullptr; if (lctx.inp_KQ_mask) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); data = (float *) lctx.inp_KQ_mask->data; } + if (lctx.inp_mask_bounds) { + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mask_bounds->buffer)); + bounds = (int32_t *)lctx.inp_mask_bounds->data; + } if (lctx.inp_KQ_mask_swa) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer)); data_swa = (float *) lctx.inp_KQ_mask_swa->data; } + if (lctx.inp_mask_bounds_swa) { + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mask_bounds_swa->buffer)); + bounds_swa = (int32_t *)lctx.inp_mask_bounds_swa->data; + } // For causal attention, use only the previous KV cells // of the correct sequence for each token of the batch. @@ -16023,6 +16051,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; } } + if (h == 0 && bounds) { + for (int i = 0; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + int min = n_kv, max = 0; + for (int j = 0; j < n_kv; ++j) { + if (data[i*n_kv + j] > -INFINITY) { + min = std::min(min, j); + max = std::max(max, j); + } + } + bounds[2*i + 0] = min; + bounds[2*i + 1] = max+1; + } + } } if (data_swa) { @@ -16031,6 +16072,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; } } + if (h == 0 && bounds_swa) { + for (int i = 0; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + int min = n_kv, max = 0; + for (int j = 0; j < n_kv; ++j) { + if (data_swa[i*n_kv + j] > -INFINITY) { + min = std::min(min, j); + max = std::max(max, j); + } + } + bounds_swa[2*i + 0] = min; + bounds_swa[2*i + 1] = max+1; + } + } } } } else {