With FA on, create mask as f16 directly

Iwan Kawrakow
2025-10-12 14:23:05 +03:00
parent 9b02dd0405
commit 95419ed393
2 changed files with 117 additions and 17 deletions


@@ -276,6 +276,12 @@ ggml_tensor * llm_build_context::build_inp_out_ids() {
}
ggml_tensor * llm_build_context::build_inp_KQ_mask(bool causal) {
if (causal && flash_attn) {
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
cb(lctx.inp_KQ_mask, "KQ_mask", -1);
ggml_set_input(lctx.inp_KQ_mask);
return lctx.inp_KQ_mask;
}
lctx.inp_KQ_mask = causal
? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
: ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
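Note: with flash attention on, the causal KQ mask is now allocated as GGML_TYPE_F16 up front; otherwise the F32 path is kept, with the non-causal mask using n_tokens columns instead of n_kv. In all cases the row count is padded to a multiple of GGML_KQ_MASK_PAD so kernels can process query rows in fixed-size blocks; the extra rows are filled with -inf later in this commit. A minimal sketch of the padding arithmetic, using ggml's GGML_PAD macro (the pad value of 32 is an assumption for illustration):

#include <cstdio>
#include <initializer_list>

// GGML_PAD is ggml's rounding-up macro; n must be a power of two.
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
#define GGML_KQ_MASK_PAD 32 // assumed pad value for this sketch

int main() {
    for (int n_tokens : {1, 17, 32, 33, 512}) {
        // rows actually allocated for the mask of an n_tokens batch
        printf("n_tokens = %3d -> padded rows = %d\n",
               n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
    }
    return 0;
}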
@@ -287,6 +293,12 @@ ggml_tensor * llm_build_context::build_inp_KQ_mask(bool causal) {
ggml_tensor * llm_build_context::build_inp_KQ_mask_swa(bool causal) {
GGML_ASSERT(hparams.n_swa > 0);
if (causal && flash_attn) {
lctx.inp_KQ_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
ggml_set_input(lctx.inp_KQ_mask_swa);
return lctx.inp_KQ_mask_swa;
}
lctx.inp_KQ_mask_swa = causal
? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))


@@ -2051,21 +2051,55 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
float * data = nullptr;
float * data_swa = nullptr;
ggml_half * data_f16 = nullptr;
ggml_half * data_swa_f16 = nullptr;
if (lctx.inp_KQ_mask) {
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
data = (float *) lctx.inp_KQ_mask->data;
if (cparams.flash_attn) {
data_f16 = (ggml_half *) lctx.inp_KQ_mask->data;
} else {
data = (float *) lctx.inp_KQ_mask->data;
}
}
if (lctx.inp_KQ_mask_swa) {
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer));
data_swa = (float *) lctx.inp_KQ_mask_swa->data;
if (cparams.flash_attn) {
data_swa_f16 = (ggml_half *) lctx.inp_KQ_mask_swa->data;
} else {
data_swa = (float *) lctx.inp_KQ_mask_swa->data;
}
}
auto noalibi_f16 = [&lctx, &hparams, n_kv, data_f16, data_swa_f16] (int j, llama_pos pos, llama_seq_id seq_id, int first, int last) {
ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY);
ggml_half h_zero = ggml_fp32_to_fp16(0.f);
for (int i = first; i < last; ++i) {
ggml_half h = !lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos ? h_inf : h_zero;
if (data_f16) data_f16[j*n_kv + i] = h;
if (data_swa_f16) {
if (h != h_inf) {
if (hparams.n_attn_chunk) {
llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
h = h_inf;
}
} else {
if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
h = h_inf;
}
}
}
data_swa_f16[j*n_kv + i] = h;
}
}
};
if (n_kv >= 1024 && n_tokens >= 32) {
int n_thread = std::max(1, int(std::thread::hardware_concurrency()/2));
int npt = (n_kv + n_thread - 1)/n_thread;
auto compute = [&batch, &lctx, &hparams, n_tokens, n_kv, npt, data, data_swa] (int ith) {
auto compute = [&batch, &lctx, &hparams, &cparams, &noalibi_f16, n_tokens, n_kv, npt, data, data_swa, data_f16, data_swa_f16] (int ith) {
int first = ith * npt;
int last = std::min(int(n_kv), first + npt);
if (last <= first) return;
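Note: the mask fill is parallelized when the KV cache is large (n_kv >= 1024) and the batch is non-trivial (n_tokens >= 32). The n_kv cells are split into contiguous slices of npt = ceil(n_kv / n_thread) cells, one slice per thread. A self-contained sketch of that partitioning (the fill_slice helper is illustrative, not from the commit):

#include <algorithm>
#include <thread>
#include <vector>

// Stand-in for the per-thread mask fill over KV cells [first, last).
static void fill_slice(int first, int last) {
    (void) first; (void) last;
}

int main() {
    const int n_kv     = 4096;
    const int n_thread = std::max(1, int(std::thread::hardware_concurrency()/2));
    const int npt      = (n_kv + n_thread - 1)/n_thread; // cells per thread, rounded up

    std::vector<std::thread> workers;
    for (int ith = 0; ith < n_thread; ++ith) {
        const int first = ith*npt;
        const int last  = std::min(n_kv, first + npt);
        if (last <= first) break; // trailing threads may get no work
        workers.emplace_back(fill_slice, first, last);
    }
    for (auto & w : workers) w.join();
    return 0;
}

Each thread writes a disjoint [first, last) column slice of every mask row, so no synchronization beyond the final join is needed.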
@@ -2073,6 +2107,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j][0];
if (!hparams.use_alibi && cparams.flash_attn) {
noalibi_f16(j, pos, seq_id, first, last);
continue;
}
for (int i = first; i < last; ++i) {
float f;
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
@@ -2088,9 +2127,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
if (data) {
data[j*n_kv + i] = f;
}
if (data_f16) {
data_f16[j*n_kv + i] = ggml_fp32_to_fp16(f);
}
// may need to cut off old tokens for sliding window
if (data_swa) {
if (data_swa || data_swa_f16) {
if (f > -INFINITY) {
if (hparams.n_attn_chunk) {
llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
@@ -2103,7 +2145,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
}
data_swa[j*n_kv + i] = f;
if (data_swa) {
data_swa[j*n_kv + i] = f;
}
if (data_swa_f16) {
data_swa_f16[j*n_kv + i] = ggml_fp32_to_fp16(f);
}
}
}
}
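Note: for the sliding-window (SWA) mask, a cell that survives the causal check is additionally cut off when it falls outside the attention window: with chunked attention (hparams.n_attn_chunk > 0) everything before the start of the query's chunk is masked, otherwise anything more than n_swa positions behind the query is masked. A sketch of that predicate (names illustrative; the diff also guards pos < pos_chunk_start, which can only trigger for negative positions and is omitted here):

#include <cstdint>
#include <cstdio>

// true if the KV cell at cell_pos must be masked out for a query at pos
static bool swa_masked(int32_t pos, int32_t cell_pos, uint32_t n_swa, uint32_t n_attn_chunk) {
    if (n_attn_chunk) {
        // chunked attention: only cells inside the query's chunk are visible
        const int32_t chunk = (int32_t) n_attn_chunk;
        const int32_t pos_chunk_start = (pos / chunk) * chunk;
        return cell_pos < pos_chunk_start;
    }
    // plain sliding window: only the last n_swa positions are visible
    return pos - cell_pos >= (int32_t) n_swa;
}

int main() {
    // query at position 8, window of 4: cells 0..4 are dropped, 5..8 kept
    for (int32_t cell_pos = 0; cell_pos <= 8; ++cell_pos) {
        printf("cell %d -> %s\n", cell_pos,
               swa_masked(8, cell_pos, /*n_swa=*/4, /*n_attn_chunk=*/0) ? "masked" : "kept");
    }
    return 0;
}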
@@ -2120,6 +2167,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
}
if (data_f16) {
ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY);
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data_f16[i*n_kv + j] = h_inf;
}
}
}
if (data_swa) {
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
@@ -2128,6 +2183,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
}
if (data_swa_f16) {
ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY);
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data_swa_f16[i*n_kv + j] = h_inf;
}
}
}
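Note: rows beyond n_tokens, up to GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), are padding that the flash-attention kernels may still read, so they are filled with -inf. For the F16 mask the half-precision -inf is computed once with ggml_fp32_to_fp16 (as in the noalibi_f16 lambda above) rather than converting inside the loop. A sketch of the padding fill, assuming ggml_fp16_t is the public alias of the ggml_half used in the diff; the buffer is a stand-in for the tensor's data:

#include <cmath>
#include <vector>
#include "ggml.h" // ggml_fp16_t, ggml_fp32_to_fp16

// Fill the padding rows [n_tokens, n_rows_padded) of an f16 mask with -inf.
static void pad_mask_f16(ggml_fp16_t * mask, int n_kv, int n_tokens, int n_rows_padded) {
    const ggml_fp16_t h_inf = ggml_fp32_to_fp16(-INFINITY); // convert once, not per element
    for (int i = n_tokens; i < n_rows_padded; ++i) {
        for (int j = 0; j < n_kv; ++j) {
            mask[i*n_kv + j] = h_inf; // padded query rows attend to nothing
        }
    }
}

int main() {
    const int n_kv = 8, n_tokens = 3, n_rows_padded = 4; // toy sizes
    std::vector<ggml_fp16_t> mask(n_kv*n_rows_padded, ggml_fp32_to_fp16(0.f));
    pad_mask_f16(mask.data(), n_kv, n_tokens, n_rows_padded);
    return 0;
}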
}
else {
@@ -2135,14 +2198,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
// of the correct sequence for each token of the batch.
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
for (int h = 0; h < 1; ++h) {
auto data_h = data ? data + h*(n_kv*n_tokens) : nullptr;
auto data_swa_h = data_swa ? data_swa + h*(n_kv*n_tokens) : nullptr;
for (int j = 0; j < n_tokens; ++j) {
auto data_j = data_h ? data_h + j*n_kv : nullptr;
auto data_swa_j = data_swa_h ? data_swa_h + j*n_kv : nullptr;
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j][0];
if (!hparams.use_alibi && cparams.flash_attn) {
noalibi_f16(j, pos, seq_id, 0, n_kv);
continue;
}
for (int i = 0; i < n_kv; ++i) {
float f;
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
@@ -2155,12 +2219,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
if (data_j) {
data_j[i] = f;
if (data) {
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
}
if (data_f16) {
data_f16[h*(n_kv*n_tokens) + j*n_kv + i] = ggml_fp32_to_fp16(f);
}
// may need to cut off old tokens for sliding window
if (data_swa_j) {
if (data_swa || data_swa_f16) {
if (hparams.n_attn_chunk) {
llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
@@ -2171,23 +2238,44 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
f = -INFINITY;
}
}
data_swa_j[i] = f;
if (data_swa) {
data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
}
if (data_swa_f16) {
data_swa_f16[h*(n_kv*n_tokens) + j*n_kv + i] = ggml_fp32_to_fp16(f);
}
}
}
}
if (data_h) {
if (data) {
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data_h[i*n_kv + j] = -INFINITY;
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
}
}
if (data_f16) {
ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY);
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data_f16[h*(n_kv*n_tokens) + i*n_kv + j] = h_inf;
}
}
}
if (data_swa_h) {
if (data_swa) {
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data_swa_h[i*n_kv + j] = -INFINITY;
data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
}
}
if (data_swa_f16) {
ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY);
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data_swa_f16[h*(n_kv*n_tokens) + i*n_kv + j] = h_inf;
}
}
}
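The payoff: ggml's flash-attention op consumes an F16 mask (true of ggml_flash_attn_ext at the time of writing, though treat the exact requirement as an assumption here), so an F32 mask would need an extra conversion node in every graph. Building and filling the mask as F16 directly removes that pass. A hedged before/after sketch using ggml's real ggml_init, ggml_new_tensor_2d, and ggml_cast; sizes are illustrative:

#include "ggml.h"

int main() {
    struct ggml_init_params params = { /*mem_size=*/ 16*1024*1024, /*mem_buffer=*/ nullptr, /*no_alloc=*/ false };
    ggml_context * ctx0 = ggml_init(params);

    const int n_kv = 256, n_tokens = 7, pad = 32;
    const int n_rows = ((n_tokens + pad - 1)/pad)*pad; // GGML_PAD(n_tokens, 32)

    // before: F32 mask, converted for the flash-attention path
    ggml_tensor * mask_f32  = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, n_rows);
    ggml_tensor * mask_cast = ggml_cast(ctx0, mask_f32, GGML_TYPE_F16); // extra pass over the whole mask per graph

    // after: F16 mask allocated directly, no cast node needed
    ggml_tensor * mask_f16 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, n_rows);

    (void) mask_cast; (void) mask_f16;
    ggml_free(ctx0);
    return 0;
}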