mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-11 06:20:09 +00:00
Whith FA on, create mask as f16 directly
This commit is contained in:
@@ -276,6 +276,12 @@ ggml_tensor * llm_build_context::build_inp_out_ids() {
|
||||
}
|
||||
|
||||
ggml_tensor * llm_build_context::build_inp_KQ_mask(bool causal) {
|
||||
if (causal && flash_attn) {
|
||||
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
cb(lctx.inp_KQ_mask, "KQ_mask", -1);
|
||||
ggml_set_input(lctx.inp_KQ_mask);
|
||||
return lctx.inp_KQ_mask;
|
||||
}
|
||||
lctx.inp_KQ_mask = causal
|
||||
? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
|
||||
: ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
@@ -287,6 +293,12 @@ ggml_tensor * llm_build_context::build_inp_KQ_mask(bool causal) {
|
||||
|
||||
ggml_tensor * llm_build_context::build_inp_KQ_mask_swa(bool causal) {
|
||||
GGML_ASSERT(hparams.n_swa > 0);
|
||||
if (causal && flash_attn) {
|
||||
lctx.inp_KQ_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
|
||||
ggml_set_input(lctx.inp_KQ_mask_swa);
|
||||
return lctx.inp_KQ_mask_swa;
|
||||
}
|
||||
|
||||
lctx.inp_KQ_mask_swa = causal
|
||||
? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
|
||||
|
||||
122
src/llama.cpp
122
src/llama.cpp
@@ -2051,21 +2051,55 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
|
||||
float * data = nullptr;
|
||||
float * data_swa = nullptr;
|
||||
ggml_half * data_f16 = nullptr;
|
||||
ggml_half * data_swa_f16 = nullptr;
|
||||
|
||||
if (lctx.inp_KQ_mask) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
|
||||
data = (float *) lctx.inp_KQ_mask->data;
|
||||
if (cparams.flash_attn) {
|
||||
data_f16 = (ggml_half *)lctx.inp_KQ_mask->data;
|
||||
} else {
|
||||
data = (float *) lctx.inp_KQ_mask->data;
|
||||
}
|
||||
}
|
||||
|
||||
if (lctx.inp_KQ_mask_swa) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer));
|
||||
data_swa = (float *) lctx.inp_KQ_mask_swa->data;
|
||||
if (cparams.flash_attn) {
|
||||
data_swa_f16 = (ggml_half *) lctx.inp_KQ_mask_swa->data;
|
||||
} else {
|
||||
data_swa = (float *) lctx.inp_KQ_mask_swa->data;
|
||||
}
|
||||
}
|
||||
|
||||
auto noalibi_f16 = [&lctx, &hparams, n_kv, data_f16, data_swa_f16] (int j, llama_pos pos, llama_seq_id seq_id, int first, int last) {
|
||||
ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY);
|
||||
ggml_half h_zero = ggml_fp32_to_fp16(0.f);
|
||||
for (int i = first; i < last; ++i) {
|
||||
ggml_half h = !lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos ? h_inf : h_zero;
|
||||
if (data_f16) data_f16[j*n_kv + i] = h;
|
||||
if (data_swa_f16) {
|
||||
if (h != h_inf) {
|
||||
if (hparams.n_attn_chunk) {
|
||||
llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
|
||||
if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
|
||||
h = h_inf;
|
||||
}
|
||||
} else {
|
||||
if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
|
||||
h = h_inf;
|
||||
}
|
||||
}
|
||||
}
|
||||
data_swa_f16[j*n_kv + i] = h;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if (n_kv >= 1024 && n_tokens >= 32) {
|
||||
int n_thread = std::max(1, int(std::thread::hardware_concurrency()/2));
|
||||
int npt = (n_kv + n_thread - 1)/n_thread;
|
||||
auto compute = [&batch, &lctx, &hparams, n_tokens, n_kv, npt, data, data_swa] (int ith) {
|
||||
auto compute = [&batch, &lctx, &hparams, &cparams, &noalibi_f16, n_tokens, n_kv, npt, data, data_swa, data_f16, data_swa_f16] (int ith) {
|
||||
int first = ith * npt;
|
||||
int last = std::min(int(n_kv), first + npt);
|
||||
if (last <= first) return;
|
||||
@@ -2073,6 +2107,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
const llama_pos pos = batch.pos[j];
|
||||
const llama_seq_id seq_id = batch.seq_id[j][0];
|
||||
|
||||
if (!hparams.use_alibi && cparams.flash_attn) {
|
||||
noalibi_f16(j, pos, seq_id, first, last);
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int i = first; i < last; ++i) {
|
||||
float f;
|
||||
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
|
||||
@@ -2088,9 +2127,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
if (data) {
|
||||
data[j*n_kv + i] = f;
|
||||
}
|
||||
if (data_f16) {
|
||||
data_f16[j*n_kv + i] = ggml_fp32_to_fp16(f);
|
||||
}
|
||||
|
||||
// may need to cut off old tokens for sliding window
|
||||
if (data_swa) {
|
||||
if (data_swa || data_swa_f16) {
|
||||
if (f > -INFINITY) {
|
||||
if (hparams.n_attn_chunk) {
|
||||
llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
|
||||
@@ -2103,7 +2145,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
}
|
||||
}
|
||||
}
|
||||
data_swa[j*n_kv + i] = f;
|
||||
if (data_swa) {
|
||||
data_swa[j*n_kv + i] = f;
|
||||
}
|
||||
if (data_swa_f16) {
|
||||
data_swa_f16[j*n_kv + i] = ggml_fp32_to_fp16(f);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2120,6 +2167,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
}
|
||||
}
|
||||
}
|
||||
if (data_f16) {
|
||||
ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY);
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_kv; ++j) {
|
||||
data_f16[i*n_kv + j] = h_inf;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (data_swa) {
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
@@ -2128,6 +2183,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
}
|
||||
}
|
||||
}
|
||||
if (data_swa_f16) {
|
||||
ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY);
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_kv; ++j) {
|
||||
data_swa_f16[i*n_kv + j] = h_inf;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
@@ -2135,14 +2198,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
// of the correct sequence for each token of the batch.
|
||||
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
auto data_h = data ? data + h*(n_kv*n_tokens) : nullptr;
|
||||
auto data_swa_h = data_swa ? data_swa + h*(n_kv*n_tokens) : nullptr;
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
auto data_j = data_h ? data_h + j*n_kv : nullptr;
|
||||
auto data_swa_j = data_swa_h ? data_swa_h + j*n_kv : nullptr;
|
||||
const llama_pos pos = batch.pos[j];
|
||||
const llama_seq_id seq_id = batch.seq_id[j][0];
|
||||
|
||||
if (!hparams.use_alibi && cparams.flash_attn) {
|
||||
noalibi_f16(j, pos, seq_id, 0, n_kv);
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
float f;
|
||||
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
|
||||
@@ -2155,12 +2219,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
}
|
||||
}
|
||||
|
||||
if (data_j) {
|
||||
data_j[i] = f;
|
||||
if (data) {
|
||||
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
|
||||
}
|
||||
if (data_f16) {
|
||||
data_f16[h*(n_kv*n_tokens) + j*n_kv + i] = ggml_fp32_to_fp16(f);
|
||||
}
|
||||
|
||||
// may need to cut off old tokens for sliding window
|
||||
if (data_swa_j) {
|
||||
if (data_swa || data_swa_f16) {
|
||||
if (hparams.n_attn_chunk) {
|
||||
llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
|
||||
if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
|
||||
@@ -2171,23 +2238,44 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
f = -INFINITY;
|
||||
}
|
||||
}
|
||||
data_swa_j[i] = f;
|
||||
if (data_swa) {
|
||||
data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
|
||||
}
|
||||
if (data_swa_f16) {
|
||||
data_swa_f16[h*(n_kv*n_tokens) + j*n_kv + i] = ggml_fp32_to_fp16(f);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (data_h) {
|
||||
if (data) {
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_kv; ++j) {
|
||||
data_h[i*n_kv + j] = -INFINITY;
|
||||
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (data_f16) {
|
||||
ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY);
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_kv; ++j) {
|
||||
data_f16[h*(n_kv*n_tokens) + i*n_kv + j] = h_inf;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (data_swa_h) {
|
||||
if (data_swa) {
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_kv; ++j) {
|
||||
data_swa_h[i*n_kv + j] = -INFINITY;
|
||||
data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (data_swa_f16) {
|
||||
ggml_half h_inf = ggml_fp32_to_fp16(-INFINITY);
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_kv; ++j) {
|
||||
data_swa_f16[h*(n_kv*n_tokens) + i*n_kv + j] = h_inf;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user