Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-01-26 17:20:01 +00:00)
Better CPU prompt processing performance for SWA models (#696)
* This does the trick for PP
* Compute mask bounds when creating the mask
* Set mask bounds for all supported SWA models

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
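In outline: when the KQ mask is built, the first and last unmasked KV position of each query row is recorded in an I32 [2, n_tokens] input tensor, and the CPU flash-attention code uses these bounds to skip the masked-out part of the KV cache, which at long context is most of it for sliding-window (SWA) layers. The span is kept aligned to a multiple of 32 so the kernels can still process K in whole blocks. Below is a minimal standalone sketch of that idea; all names in it are illustrative and not the repository's API (the patch itself uses mask_range()/reduce_k_range() and fills the bounds in llama_set_inputs()):

```cpp
#include <algorithm>
#include <cstdio>
#include <limits>
#include <utility>
#include <vector>

// Illustrative sketch only: [first, last) of unmasked KV positions for one query
// row of a causal/SWA mask, where masked entries are -infinity.
static std::pair<int, int> row_bounds(const float * mask_row, int n_kv) {
    const float neg_inf = -std::numeric_limits<float>::infinity();
    int first = 0, last = n_kv;
    while (first < last && mask_row[first]    == neg_inf) ++first;
    while (last > first && mask_row[last - 1] == neg_inf) --last;
    return {first, last};
}

// Round the span up to a multiple of 32, extending to the left when possible,
// in the spirit of reduce_k_range() in the diff below.
static void round_to_block(int n_kv, int & first, int & last) {
    int nk = last - first;
    if (nk >= n_kv || nk % 32 == 0) return;
    int nk32 = 32 * ((nk + 31) / 32);
    first = std::max(0, first - (nk32 - nk));
    last  = std::min(n_kv, first + nk32);
}

int main() {
    // Toy sliding-window mask: 256 KV cells, only [96, 200) visible for this row.
    const int n_kv = 256;
    std::vector<float> mask(n_kv, -std::numeric_limits<float>::infinity());
    for (int j = 96; j < 200; ++j) mask[j] = 0.0f;

    auto fl = row_bounds(mask.data(), n_kv);
    int first = fl.first, last = fl.second;
    round_to_block(n_kv, first, last);
    // The attention kernel would now only visit K/V rows in [first, last)
    // instead of all n_kv rows.
    printf("visit K/V rows [%d, %d) of %d\n", first, last, n_kv);
    return 0;
}
```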
@@ -2043,6 +2043,10 @@ extern "C" {
             struct ggml_tensor * a,
             struct ggml_tensor * sinks);

+    GGML_API void ggml_flash_attn_ext_add_bounds(
+            struct ggml_tensor * a,
+            struct ggml_tensor * bounds);
+
     // TODO: needs to be adapted to ggml_flash_attn_ext
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
             struct ggml_context * ctx,
@@ -8993,6 +8993,22 @@ void ggml_flash_attn_ext_add_sinks(
     a->src[4] = sinks;
 }

+void ggml_flash_attn_ext_add_bounds(
+        struct ggml_tensor * a,
+        struct ggml_tensor * bounds) {
+    if (!bounds) {
+        a->src[5] = NULL;
+        return;
+    }
+
+    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+    GGML_ASSERT(bounds->type == GGML_TYPE_I32);
+    GGML_ASSERT(bounds->ne[0] == 2);
+    GGML_ASSERT(bounds->ne[1] >= a->src[0]->ne[1]);
+
+    a->src[5] = bounds;
+}
+
 // ggml_flash_attn_back

 struct ggml_tensor * ggml_flash_attn_back(
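For orientation, a sketch of how a graph builder is expected to use the API added above (the helper name build_fa_with_bounds and the 0.0f bias/softcap values are illustrative, not part of the patch; the real call sites are in the llama.cpp hunks further down, which also pad the second dimension of the bounds tensor to GGML_KQ_MASK_PAD):

```cpp
// Illustrative only, assuming ggml.h from this repo. `bounds` holds per-row
// [first, last) KV indices and is filled later, as llama_set_inputs() does.
#include "ggml.h"

static struct ggml_tensor * build_fa_with_bounds(
        struct ggml_context * ctx,
        struct ggml_tensor  * q,
        struct ggml_tensor  * k,
        struct ggml_tensor  * v,
        struct ggml_tensor  * kq_mask,
        struct ggml_tensor  * sinks,     // may be NULL
        float                 kq_scale,
        int64_t               n_tokens) {
    // I32 [2, n_tokens]: per query row, first and one-past-last visible KV cell
    struct ggml_tensor * bounds = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 2, n_tokens);
    ggml_set_input(bounds);

    struct ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale,
                                                   /*max_bias=*/0.0f, /*softcap=*/0.0f);
    ggml_flash_attn_ext_add_sinks (cur, sinks);   // stored as src[4], may be NULL
    ggml_flash_attn_ext_add_bounds(cur, bounds);  // stored as src[5], NULL clears it
    return cur;
}
```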
@@ -18661,6 +18677,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const struct ggml_tensor * v = dst->src[2];
     const struct ggml_tensor * mask = dst->src[3];
     const struct ggml_tensor * sinks = dst->src[4];
+    const struct ggml_tensor * bounds= dst->src[5];

     GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
     GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
@@ -18739,7 +18756,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 dst->ne[2], dst->ne[1], dst->nb[1],
                 k->type, v->type,
                 Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1],
-                q->data, k->data, v->data, mask->data, sinks ? sinks->data : NULL,
+                q->data, k->data, v->data, mask->data,
+                sinks ? sinks->data : NULL,
+                bounds ? bounds->data : NULL,
                 scale, softcap, (float *)dst->data,
                 params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth)) return;
@@ -43,6 +43,27 @@ inline void accumulate_qkv(int Dv, float& M, float& S, float Mj, float Sj, float
         for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i];
     }
 }
+inline std::pair<int, int> mask_range(int nek1, const uint16_t * umask) {
+    int first_k = 0, last_k = nek1;
+    for (; first_k < last_k; ++first_k) {
+        if (umask[first_k] == 0) break;
+    }
+    for (; last_k > first_k; --last_k) {
+        if (umask[last_k-1] == 0) break;
+    }
+    return { first_k, last_k };
+}
+inline bool reduce_k_range(int nek1, int& first_k, int& last_k) {
+    int nk = last_k - first_k;
+    if (nk >= nek1) return false;
+    if (nk%32) {
+        int nk32 = 32*((nk + 31)/32);
+        int diff = nk32 - nk;
+        first_k = std::max(0, first_k - diff);
+        last_k = first_k + nk32;
+    }
+    return last_k - first_k < nek1;
+}
 }

 // TODO: get the ggml_type enum here without polution
@@ -66,7 +87,8 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
         const void * k, // k matrix. Assumed to be fp16, nq x nk elements
         const void * v, // v matrix. Assumed to be fp16, nq x nk elements
         const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
-        const void * sinks, // mask. If not null, assumed to be fp16. nq x nk elements
+        const void * sinks, // attention sinks
+        const void * bounds, // attention mask bounds
         float scale, // scale applied before softmax
         float softcap, // if > 0, a "soft-cap" operation is applied before softmax
         float * qkv, // v*softmax(scale*(k*q))
@@ -80,22 +102,13 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
     int rk3 = neq3/nek3;
     int rv3 = neq3/nev3;

-    int first_k = 0, last_k = nek1;
-    if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nek1 > 256) {
-        // This is a quick hack for SWA models.
-        // Given that the mask is the same for all layers, ideally we should determinbe the
-        // cache bounds once, and reuse for the whole graph. But even with this simple hack
-        // we get non-negligible performance gains for SWA models and long context.
-        auto umask = (const uint16_t *)mask;
-        for (; first_k < last_k; ++first_k) {
-            if (umask[first_k] == 0) break;
-        }
-        for (; last_k > first_k; --last_k) {
-            if (umask[last_k-1] == 0) break;
-        }
-        //printf("nek1 = %d, first = %d, last = %d\n", nek1, first, last);
-        if (last_k - first_k <= 3*nek1/4 && (last_k - first_k)%32 == 0) {
-            //printf("Reducing from %d to %d\n", nek1, last_k - first_k);
+    bool range_found = false;
+    if (neq3 == 1 && rk2 > 1 && neq1 == 1 && bounds && nek1 > 32) {
+        range_found = true;
+        auto b = (const int32_t *)bounds;
+        int first_k = b[0];
+        int last_k = b[1];
+        if ((last_k - first_k)%32 == 0) { // why is this not better? : if (reduce_k_range(nek1, first_k, last_k)) {
             k = (const void *)((const char *)k + first_k*stride_k);
             v = (const void *)((const char *)v + first_k*stride_v);
             mask = (const void *)((const uint16_t *)mask + first_k);
@@ -105,7 +118,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float

     int int_type_k = int_type_k_in;
     auto work_buffer = work_buffer_in;
-    if (neq1 >= 8 || (rk2 >= 8 && nek2 > 1)) {
+    if (neq1 >= 8 || (false && rk2 >= 8 && nek2 > 1)) {
         uint64_t row_size = 0;
         work_buffer = iqk_repack_k(int_type_k, Dk, nek1, nek2, nek3, stride_k, nbk2, nbk3, k, work_buffer_in, ith, nth, int_type_k, row_size);
         if (int_type_k != int_type_k_in) {
@@ -299,6 +312,25 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
         if (counter++ % (nth/ntg) == ith/ntg) {
             int iq1 = (ith%ntg)*neq1g;
             int this_neq1 = std::min(neq1g, neq1-iq1);
+            if (bounds && !range_found) {
+                auto b = (const int32_t *)bounds + 2*iq1;
+                int kmin = nek1, kmax = 0;
+                for (int i = 0; i < this_neq1; ++i) {
+                    kmin = std::min(kmin, b[2*i+0]);
+                    kmax = std::max(kmax, b[2*i+1]);
+                }
+                if (reduce_k_range(nek1, kmin, kmax)) {
+                    if (!iqk_flash_attn_impl(int_type_k, int_type_v,
+                            Dk, Dv, this_neq1, kmax-kmin, stride_q, stride_k, stride_v, stride_m, ne1*nb1/sizeof(float),
+                            (const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q),
+                            (const void *)((const char *)k + iq2/rk2*nbk2 + iq3/rk3*nbk3 + kmin*stride_k),
+                            (const void *)((const char *)v + iq2/rv2*nbv2 + iq3/rv3*nbv3 + kmin*stride_v),
+                            (const void *)((const char *)mask + iq1*stride_m + kmin*sizeof(uint16_t)), sinksf, 1,
+                            scale, softcap,
+                            (float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false;
+                    continue;
+                }
+            }
             if (!iqk_flash_attn_impl(int_type_k, int_type_v,
                     Dk, Dv, this_neq1, nek1, stride_q, stride_k, stride_v, stride_m, ne1*nb1/sizeof(float),
                     (const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q),
@@ -58,7 +58,8 @@ IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
         const void * k, // k matrix. Assumed to be fp16, nq x nk elements
         const void * v, // v matrix. Assumed to be fp16, nq x nk elements
         const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
-        const void * sinks, // mask. If not null, assumed to be fp16. nq x nk elements
+        const void * sinks, // attention sinks
+        const void * bounds, // attention mask bounds
         float scale, // scale applied before softmax
         float softcap, // if > 0, a "soft-cap" operation is applied before softmax
         float * qkv, // v*softmax(scale*(k*q))
@@ -2513,6 +2513,8 @@ struct llama_context {
     struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
     struct ggml_tensor * inp_scale = nullptr; // F32 [n_tokens]
+    struct ggml_tensor * inp_mask_bounds = nullptr; // I32 [2, n_tokens]
+    struct ggml_tensor * inp_mask_bounds_swa = nullptr; // I32 [2, n_tokens]
 };

 struct llama_lora_weight {
@@ -7943,7 +7945,8 @@ static struct ggml_tensor * llm_build_kqv(
         float kq_scale,
         const llm_build_cb & cb,
         int il,
-        ggml_tensor * sinks = nullptr) {
+        ggml_tensor * sinks = nullptr,
+        ggml_tensor * bounds = nullptr) {
     const llama_model & model = lctx.model;
     const llama_hparams & hparams = lctx.model.hparams;
     const llama_cparams & cparams = lctx.cparams;
@@ -7990,7 +7993,8 @@ static struct ggml_tensor * llm_build_kqv(

         cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                 hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
-        ggml_flash_attn_ext_add_sinks(cur, sinks);
+        ggml_flash_attn_ext_add_sinks (cur, sinks);
+        ggml_flash_attn_ext_add_bounds(cur, bounds);

         // Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
         // For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when uding fp16 for TG.
@@ -8148,7 +8152,8 @@ static struct ggml_tensor * llm_build_kv(
         float kq_scale,
         const llm_build_cb & cb,
         int il,
-        ggml_tensor * sinks = nullptr) {
+        ggml_tensor * sinks = nullptr,
+        ggml_tensor * bounds = nullptr) {
     const llama_hparams & hparams = lctx.model.hparams;
     const llama_cparams & cparams = lctx.cparams;

@@ -8163,7 +8168,7 @@ static struct ggml_tensor * llm_build_kv(
     struct ggml_tensor * cur;

     cur = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b,
-            q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks);
+            q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks, bounds);
     cb(cur, "kqv_out", il);

     return cur;
@@ -8298,6 +8303,8 @@ struct llm_build_context {
         lctx.inp_pos_bucket = nullptr;
         lctx.inp_embd_enc = nullptr;
         lctx.inp_KQ_mask_cross = nullptr;
+        lctx.inp_mask_bounds = nullptr;
+        lctx.inp_mask_bounds_swa = nullptr;
     }

     void free() {
@@ -8478,6 +8485,9 @@ struct llm_build_context {
         cb(lctx.inp_KQ_mask, "KQ_mask", -1);
         ggml_set_input(lctx.inp_KQ_mask);

+        lctx.inp_mask_bounds = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, 2, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        ggml_set_input(lctx.inp_mask_bounds);
+
         return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
     }

@@ -8490,6 +8500,9 @@ struct llm_build_context {
         cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(lctx.inp_KQ_mask_swa);

+        lctx.inp_mask_bounds_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, 2, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        ggml_set_input(lctx.inp_mask_bounds_swa);
+
         return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask_swa, GGML_TYPE_F16) : lctx.inp_KQ_mask_swa;
     }

@@ -8658,6 +8671,7 @@ struct llm_build_context {
             bool use_rope = model.arch == LLM_ARCH_LLAMA4 ? (il + 1) % hparams.n_no_rope_layer_step != 0 : true;
             auto this_KQ_mask = hparams.n_swa > 0 && hparams.n_swa_pattern > 0 && il % hparams.n_swa_pattern < (hparams.n_swa_pattern - 1) ?
                 KQ_mask_swa : KQ_mask;
+            auto bounds = this_KQ_mask == KQ_mask_swa ? lctx.inp_mask_bounds_swa : lctx.inp_mask_bounds;

             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
@@ -8722,7 +8736,7 @@ struct llm_build_context {

                 cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+                        Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il, nullptr, bounds);
             }

             if (il == n_layer - 1) {
@@ -11223,7 +11237,7 @@ struct llm_build_context {

                 cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask_swa, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask_swa, n_tokens, kv_head, n_kv, 1.0f, cb, il, nullptr, lctx.inp_mask_bounds_swa);
             }

             if (il == n_layer - 1) {
@@ -12112,6 +12126,7 @@ struct llm_build_context {
         for (int il = 0; il < n_layer; ++il) {
             // (il % 2) layers use SWA
             struct ggml_tensor * KQ_mask_l = (il % 2 == 0) ? KQ_mask_swa : KQ_mask;
+            auto bounds = KQ_mask_l == KQ_mask_swa ? lctx.inp_mask_bounds_swa : lctx.inp_mask_bounds;

             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
@@ -12154,7 +12169,7 @@ struct llm_build_context {

                 cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il, nullptr, bounds);
             }

             cur = llm_build_norm(ctx0, cur, hparams,
@@ -12257,6 +12272,7 @@ struct llm_build_context {
             const float freq_base_l = is_sliding ? 10000.0f : freq_base;
             const float freq_scale_l = is_sliding ? 1.0f : freq_scale;
             struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
+            auto bounds = is_sliding ? lctx.inp_mask_bounds_swa : lctx.inp_mask_bounds;

             // norm
             cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
@@ -12291,7 +12307,7 @@ struct llm_build_context {
                 cb(Kcur, "Kcur", il);

                 cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il, nullptr, bounds);
             }

             cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, cb, il);
@@ -14303,6 +14319,7 @@ struct llm_build_context {
             // fourth layer uses global attention without positional embeddings
             const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
             struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
+            auto bounds = is_sliding ? lctx.inp_mask_bounds_swa : lctx.inp_mask_bounds;

             // norm
             cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM, cb, il);
@@ -14356,7 +14373,7 @@ struct llm_build_context {
                 }

                 cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur,
-                        KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il);
+                        KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il, nullptr, bounds);
             }

             if (il == n_layer - 1) {
@@ -15407,6 +15424,7 @@ struct llm_build_context {
             ggml_tensor * inpSA = inpL;

             struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
+            auto bounds = is_sliding ? lctx.inp_mask_bounds_swa : lctx.inp_mask_bounds;

             // norm
             cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, cb, il);
@@ -15446,7 +15464,7 @@ struct llm_build_context {
                 cb(Kcur, "Kcur", il);

                 cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, kq_scale, cb, il, model.layers[il].attn_sinks);
+                        Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, kq_scale, cb, il, model.layers[il].attn_sinks, bounds);

                 cb(cur, "attn_out", il);
             }
@@ -15965,16 +15983,26 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {

         float * data = nullptr;
         float * data_swa = nullptr;
+        int32_t * bounds = nullptr;
+        int32_t * bounds_swa = nullptr;

         if (lctx.inp_KQ_mask) {
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
             data = (float *) lctx.inp_KQ_mask->data;
         }
+        if (lctx.inp_mask_bounds) {
+            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mask_bounds->buffer));
+            bounds = (int32_t *)lctx.inp_mask_bounds->data;
+        }

         if (lctx.inp_KQ_mask_swa) {
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer));
             data_swa = (float *) lctx.inp_KQ_mask_swa->data;
         }
+        if (lctx.inp_mask_bounds_swa) {
+            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mask_bounds_swa->buffer));
+            bounds_swa = (int32_t *)lctx.inp_mask_bounds_swa->data;
+        }

         // For causal attention, use only the previous KV cells
         // of the correct sequence for each token of the batch.
@@ -16023,6 +16051,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                             data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                         }
                     }
+                    if (h == 0 && bounds) {
+                        for (int i = 0; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                            int min = n_kv, max = 0;
+                            for (int j = 0; j < n_kv; ++j) {
+                                if (data[i*n_kv + j] > -INFINITY) {
+                                    min = std::min(min, j);
+                                    max = std::max(max, j);
+                                }
+                            }
+                            bounds[2*i + 0] = min;
+                            bounds[2*i + 1] = max+1;
+                        }
+                    }
                 }

                 if (data_swa) {
@@ -16031,6 +16072,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                             data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                         }
                     }
+                    if (h == 0 && bounds_swa) {
+                        for (int i = 0; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                            int min = n_kv, max = 0;
+                            for (int j = 0; j < n_kv; ++j) {
+                                if (data_swa[i*n_kv + j] > -INFINITY) {
+                                    min = std::min(min, j);
+                                    max = std::max(max, j);
+                                }
+                            }
+                            bounds_swa[2*i + 0] = min;
+                            bounds_swa[2*i + 1] = max+1;
+                        }
+                    }
                 }
             }
         } else {