WIP: play with KQ mask - make it fp16

Author: Iwan Kawrakow
Date:   2024-08-27 19:08:31 +03:00
Parent: c7e99c88a2
Commit: 0f301124b1


@@ -8687,25 +8687,40 @@ struct llm_build_context {
     }
 
     struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
+        auto type = hparams.use_alibi ? GGML_TYPE_F32 : GGML_TYPE_F16;
         lctx.inp_KQ_mask = causal
-            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
-            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+            ? ggml_new_tensor_2d(ctx0, type, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+            : ggml_new_tensor_2d(ctx0, type, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         cb(lctx.inp_KQ_mask, "KQ_mask", -1);
         ggml_set_input(lctx.inp_KQ_mask);
+        return flash_attn && type == GGML_TYPE_F32 ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
+        //lctx.inp_KQ_mask = causal
+        //    ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+        //    : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(lctx.inp_KQ_mask, "KQ_mask", -1);
+        //ggml_set_input(lctx.inp_KQ_mask);
 
-        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
+        //return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
     }
 
     struct ggml_tensor * build_inp_KQ_mask_swa(bool causal = true) {
         GGML_ASSERT(hparams.n_swa > 0);
 
+        auto type = hparams.use_alibi ? GGML_TYPE_F32 : GGML_TYPE_F16;
         lctx.inp_KQ_mask_swa = causal
-            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
-            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+            ? ggml_new_tensor_2d(ctx0, type, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+            : ggml_new_tensor_2d(ctx0, type, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(lctx.inp_KQ_mask_swa);
 
-        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask_swa, GGML_TYPE_F16) : lctx.inp_KQ_mask_swa;
+        return flash_attn && type == GGML_TYPE_F32 ? ggml_cast(ctx0, lctx.inp_KQ_mask_swa, GGML_TYPE_F16) : lctx.inp_KQ_mask_swa;
+        //lctx.inp_KQ_mask_swa = causal
+        //    ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+        //    : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
+        //ggml_set_input(lctx.inp_KQ_mask_swa);
+        //return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask_swa, GGML_TYPE_F16) : lctx.inp_KQ_mask_swa;
     }
 
     struct ggml_tensor * build_inp_mean() {
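
The net effect of the hunk above: when ALiBi is not used, the KQ mask tensors are created directly in F16, and the explicit ggml_cast to F16 for flash attention is only kept for the F32 (ALiBi) case. A minimal sketch of that selection logic, with made-up helper names that are not part of the commit:

```cpp
#include "ggml.h"

// Sketch only: mirrors the type-selection idea in the hunk above with
// hypothetical helper names. ALiBi stores real-valued position biases in the
// mask, so that case has to stay F32; otherwise the mask only ever holds
// 0 or -inf and can be allocated as F16 directly.
static enum ggml_type choose_kq_mask_type(bool use_alibi) {
    return use_alibi ? GGML_TYPE_F32 : GGML_TYPE_F16;
}

// Flash attention consumes an F16 mask, so a cast node is only added when the
// mask had to remain F32 (the ALiBi case).
static struct ggml_tensor * kq_mask_for_attn(struct ggml_context * ctx,
                                             struct ggml_tensor  * mask,
                                             bool                  flash_attn) {
    return flash_attn && mask->type == GGML_TYPE_F32
        ? ggml_cast(ctx, mask, GGML_TYPE_F16)
        : mask;
}
```

Allocating the mask as F16 up front removes the cast node from the graph on the common non-ALiBi path, which appears to be the point of this WIP change.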
@@ -14259,71 +14274,122 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         const int64_t n_kv     = kv_self.n;
         const int64_t n_tokens = batch.n_tokens;
 
         float * data     = nullptr;
         float * data_swa = nullptr;
 
+        if (lctx.inp_KQ_mask && lctx.inp_KQ_mask_swa) {
+            GGML_ASSERT(lctx.inp_KQ_mask->type == lctx.inp_KQ_mask_swa->type);
+        }
+
         if (lctx.inp_KQ_mask) {
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
-            data = (float *) lctx.inp_KQ_mask->data;
         }
 
         if (lctx.inp_KQ_mask_swa) {
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer));
-            data_swa = (float *) lctx.inp_KQ_mask_swa->data;
         }
 
-        // For causal attention, use only the previous KV cells
-        // of the correct sequence for each token of the batch.
-        // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-        for (int h = 0; h < 1; ++h) {
+        auto float_type = lctx.inp_KQ_mask ? lctx.inp_KQ_mask->type : lctx.inp_KQ_mask_swa->type;
+        GGML_ASSERT(float_type == GGML_TYPE_F16 || float_type == GGML_TYPE_F32);
+
+        if (float_type == GGML_TYPE_F16) {
+            // in order this to be true, we are not using alibi
+            GGML_ASSERT(!hparams.use_alibi);
+            auto h_zero = ggml_fp32_to_fp16(0.0f);
+            auto h_inf  = ggml_fp32_to_fp16(-INFINITY);
+            ggml_fp16_t * h_data     = lctx.inp_KQ_mask     ? (ggml_fp16_t *)lctx.inp_KQ_mask->data     : nullptr;
+            ggml_fp16_t * h_data_swa = lctx.inp_KQ_mask_swa ? (ggml_fp16_t *)lctx.inp_KQ_mask_swa->data : nullptr;
             for (int j = 0; j < n_tokens; ++j) {
                 const llama_pos    pos    = batch.pos[j];
                 const llama_seq_id seq_id = batch.seq_id[j][0];
 
                 for (int i = 0; i < n_kv; ++i) {
-                    float f;
-                    if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
-                        f = -INFINITY;
-                    } else {
-                        if (hparams.use_alibi) {
-                            f = -std::abs(lctx.kv_self.cells[i].pos - pos);
-                        } else {
-                            f = 0.0f;
-                        }
-                    }
-
-                    if (data) {
-                        data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
-                    }
-
-                    // may need to cut off old tokens for sliding window
-                    if (data_swa) {
-                        if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
-                            f = -INFINITY;
-                        }
-                        data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
-                    }
+                    auto f = lctx.kv_self.cells[i].pos <= pos && lctx.kv_self.cells[i].has_seq_id(seq_id) ? h_zero : h_inf;
+                    if (h_data) h_data[j*n_kv + i] = f;
+                    if (h_data_swa) {
+                        if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) f = h_inf;
+                        h_data_swa[j*n_kv + i] = f;
+                    }
                 }
             }
-
-            if (data) {
+            if (h_data) {
                 for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                    for (int j = 0; j < n_kv; ++j) {
-                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                    }
+                    for (int j = 0; j < n_kv; ++j) h_data[i*n_kv + j] = h_inf;
                 }
             }
-
-            if (data_swa) {
+            if (h_data_swa) {
                 for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                    for (int j = 0; j < n_kv; ++j) {
-                        data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                    }
+                    for (int j = 0; j < n_kv; ++j) h_data_swa[i*n_kv + j] = h_inf;
                 }
             }
         }
+        else {
+            if (lctx.inp_KQ_mask) {
+                data = (float *) lctx.inp_KQ_mask->data;
+            }
+            if (lctx.inp_KQ_mask_swa) {
+                data_swa = (float *) lctx.inp_KQ_mask_swa->data;
+            }
+            // For causal attention, use only the previous KV cells
+            // of the correct sequence for each token of the batch.
+            // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
+            for (int h = 0; h < 1; ++h) {
+                for (int j = 0; j < n_tokens; ++j) {
+                    const llama_pos    pos    = batch.pos[j];
+                    const llama_seq_id seq_id = batch.seq_id[j][0];
+                    for (int i = 0; i < n_kv; ++i) {
+                        float f;
+                        if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
+                            f = -INFINITY;
+                        } else {
+                            if (hparams.use_alibi) {
+                                f = -std::abs(lctx.kv_self.cells[i].pos - pos);
+                            } else {
+                                f = 0.0f;
+                            }
+                        }
+                        if (data) {
+                            data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                        }
+                        // may need to cut off old tokens for sliding window
+                        if (data_swa) {
+                            if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
+                                f = -INFINITY;
+                            }
+                            data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                        }
+                    }
+                }
+                if (data) {
+                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                        for (int j = 0; j < n_kv; ++j) {
+                            data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                        }
+                    }
+                }
+                if (data_swa) {
+                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                        for (int j = 0; j < n_kv; ++j) {
+                            data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                        }
+                    }
+                }
+            }
+        }
     } else {
+        // TODO
+        GGML_ASSERT(false);
         // when using kv cache, the mask needs to match the kv cache size
         const int64_t n_tokens = batch.n_tokens;
         const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens;
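
The hunk above gives the mask fill code a dedicated F16 path: 0 and -INFINITY are converted once with ggml_fp32_to_fp16 and written per cell, the SWA mask additionally drops cells that are n_swa or more positions behind the current token, and the original F32 path (still needed for ALiBi) moves into the else branch. Below is a self-contained sketch of that fill loop under simplified assumptions, where flat position/sequence arrays and the parameter names stand in for llama.cpp's batch and KV-cell structures:

```cpp
#include <cmath>
#include <cstdint>
#include "ggml.h"

// Sketch only: a simplified stand-in for the F16 fill path in the hunk above.
// cell_pos/cell_seq describe the KV cells, tok_pos/tok_seq the batch tokens;
// real llama.cpp cells can belong to several sequences, which is ignored here.
static void fill_kq_mask_f16(ggml_fp16_t * mask, ggml_fp16_t * mask_swa,
                             int n_kv, int n_tokens, int n_pad,
                             const int32_t * cell_pos, const int32_t * cell_seq,
                             const int32_t * tok_pos,  const int32_t * tok_seq,
                             int32_t n_swa) {
    // convert the two sentinel values once instead of per element
    const ggml_fp16_t h_zero = ggml_fp32_to_fp16(0.0f);
    const ggml_fp16_t h_inf  = ggml_fp32_to_fp16(-INFINITY);

    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_kv; ++i) {
            // causal + same-sequence check: 0 keeps the cell, -inf masks it out
            ggml_fp16_t f = (cell_pos[i] <= tok_pos[j] && cell_seq[i] == tok_seq[j]) ? h_zero : h_inf;
            if (mask) {
                mask[j*n_kv + i] = f;
            }
            if (mask_swa) {
                // sliding window: additionally drop cells older than n_swa positions
                if (tok_pos[j] - cell_pos[i] >= n_swa) {
                    f = h_inf;
                }
                mask_swa[j*n_kv + i] = f;
            }
        }
    }
    // rows that exist only for padding (up to n_pad tokens) are fully masked
    for (int j = n_tokens; j < n_pad; ++j) {
        for (int i = 0; i < n_kv; ++i) {
            if (mask)     mask[j*n_kv + i]     = h_inf;
            if (mask_swa) mask_swa[j*n_kv + i] = h_inf;
        }
    }
}
```

Converting only the two sentinel values keeps the loop free of per-element fp32 to fp16 conversions, and it is also why this path asserts !hparams.use_alibi: ALiBi needs arbitrary float biases in the mask, which the 0/-inf scheme cannot represent.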