From 05f95229a7eb1a0625daf87ef2ba2da7d3e8a915 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 28 Aug 2024 15:01:02 +0200 Subject: [PATCH] WIP KQ binary mask: make it a parameter, turn on via command line It is a pain to implement binary mask to 32-bit value conversion on NEON and AVX2, so I decided to make the binary mask optional. There is also a commented out (and not working) attempt for NEON in this commit. --- common/common.cpp | 7 ++++++ common/common.h | 1 + ggml/src/ggml.c | 55 +++++++++++++++++++++++++++++++++++++++++++++++ include/llama.h | 1 + src/llama.cpp | 10 +++++++-- 5 files changed, 72 insertions(+), 2 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 3b45d066..85baa5e2 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -808,6 +808,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.flash_attn = true; return true; } + if (arg == "-bkq" || arg == "--binary-kq") { + params.binary_kq = true; + return true; + } if (arg == "-co" || arg == "--color") { params.use_color = true; return true; } @@ -1442,6 +1446,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --keep N", "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep }); options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks }); options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" }); + options.push_back({ "*", "-bkq, --binary-kq", "enable binary KQ mask (default: %s)", params.binary_kq ? 
"enabled" : "disabled" }); options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" "in conversation mode, this will be used as system prompt\n" "(default: '%s')", params.prompt.c_str() }); @@ -2265,6 +2270,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.cb_eval_user_data = params.cb_eval_user_data; cparams.offload_kqv = !params.no_kv_offload; cparams.flash_attn = params.flash_attn; + cparams.binary_kq = params.binary_kq; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_v = kv_cache_type_from_str(params.cache_type_v); @@ -3261,6 +3267,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); + fprintf(stream, "binary_kq: %s # default: false\n", params.binary_kq ? 
"true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); diff --git a/common/common.h b/common/common.h index 50035897..28b56471 100644 --- a/common/common.h +++ b/common/common.h @@ -173,6 +173,7 @@ struct gpt_params { bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly bool flash_attn = false; // flash attention + bool binary_kq = false; // use binary KQ mask (if allowed in the given context) bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 09b6c0b4..e45732aa 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2939,6 +2939,60 @@ static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * } return _mm512_reduce_max_ps(vmax); } +//#elif __ARM_NEON +//static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) { +// //const uint16_t * mask16 = (const uint16_t *)mask; +// const uint8_t * mask8 = (const uint8_t *)mask; +// float32x4_t vinf = vdupq_n_f32(-INFINITY); +// float32x4x4_t vmax = { vinf, vinf, vinf, vinf }; +// float32x4_t vs_before = vdupq_n_f32(s_before); +// float32x4_t vs_after = vdupq_n_f32(s_after ); +// const uint8x16_t vmask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201)); +// //const uint8x8_t vmask = vreinterpret_u8_u64(vdup_n_u64(0x8040201008040201)); +// //static const uint32_t k_shuffle[8] = { 0x00000000, 0x01010101, 0x02020202, 0x03030303, +// // 0x04040404, 0x05050505, 0x06060606, 0x07070707 }; +// //const uint8x8x4_t vtab = vld1_u8_x4((const uint8_t *)k_shuffle); +// //for (int i = 0; i < n/16; ++i) { +// // float32x4x4_t vx = vld1q_f32_x4(x + 16*i); +// // 
uint8x8_t m1 = vceq_u8(vand_u8(vdup_n_u8(mask8[2*i+0]), vmask), vmask); +// // uint8x8_t m2 = vceq_u8(vand_u8(vdup_n_u8(mask8[2*i+1]), vmask), vmask); +// // uint8x16x4_t mk = { vcombine_u8(vqtbl1_u8(m1, vtab.val[0]), vqtbl1_u8(m1, vtab.val[1])), +// // for (int k = 0; k < 4; ++k) { +// // vx.val[k] = ggml_v_softcap(vx.val[k], vs_before, vs_after); +// // //uint8x16_t mk = vcombine(vqtbl1_u8(m1, vtab.val[k]), +// // uint8x16_t v_on = vandq_u8(vreinterpretq_u8_f32(vx.val[k]), mk); +// // uint8x16_t v_off = vandq_u8(vreinterpretq_u8_f32(vinf), mk); +// // vx.val[k] = vreinterpretq_f32_u8(vorrq_u8(v_on, v_off)); +// // vmax.val[k] = vmaxq_f32(vmax.val[k], vx.val[k]); +// // vst1q_f32(y + 16*i + 4*k, vx.val[k]); +// // } +// //} +// static const uint32_t k_shuffle[16] = { 0x00000000, 0x01010101, 0x02020202, 0x03030303, +// 0x04040404, 0x05050505, 0x06060606, 0x07070707, +// 0x08080808, 0x09090909, 0x0a0a0a0a, 0x0b0b0b0b, +// 0x0c0c0c0c, 0x0d0d0d0d, 0x0e0e0e0e, 0x0f0f0f0f}; +// const uint8x16x4_t vtab = vld1q_u8_x4((const uint8_t *)k_shuffle); +// for (int i = 0; i < n/16; ++i) { +// float32x4x4_t vx = vld1q_f32_x4(x + 16*i); +// uint8x16_t m = vcombine_u8(vdup_n_u8(mask8[2*i+0]), vdup_n_u8(mask8[2*i+1])); +// m = vceqq_u8(vandq_u8(m, vmask), vmask); +// for (int k = 0; k < 4; ++k) { +// vx.val[k] = ggml_v_softcap(vx.val[k], vs_before, vs_after); +// uint8x16_t mk = vqtbl1q_u8(m, vtab.val[k]); +// uint8x16_t v_on = vandq_u8(vreinterpretq_u8_f32(vx.val[k]), mk); +// uint8x16_t v_off = vandq_u8(vreinterpretq_u8_f32(vinf), mk); +// vx.val[k] = vreinterpretq_f32_u8(vorrq_u8(v_on, v_off)); +// vmax.val[k] = vmaxq_f32(vmax.val[k], vx.val[k]); +// vst1q_f32(y + 16*i + 4*k, vx.val[k]); +// } +// } +// float max = vmaxvq_f32(vmax.val[0]); +// for (int k = 1; k < 4; ++k) { +// float maxk = vmaxvq_f32(vmax.val[k]); +// max = MAX(max, maxk); +// } +// return max; +//} #else static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, 
float s_before, float s_after) { GGML_UNUSED(n); @@ -2947,6 +3001,7 @@ static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * GGML_UNUSED(mask); GGML_UNUSED(s_before); GGML_UNUSED(s_after); + GGML_ASSERT(false); return 0.f; } #endif diff --git a/include/llama.h b/include/llama.h index a9af4c48..dd13d657 100644 --- a/include/llama.h +++ b/include/llama.h @@ -340,6 +340,7 @@ extern "C" { bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool flash_attn; // whether to use flash attention [EXPERIMENTAL] + bool binary_kq; // whether to use binary KQ mask [EXPERIMENTAL] // Abort callback // if it returns true, execution of llama_decode() will be aborted diff --git a/src/llama.cpp b/src/llama.cpp index 2a6edf36..83aac3da 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2348,6 +2348,7 @@ struct llama_cparams { bool causal_attn; bool offload_kqv; bool flash_attn; + bool binary_kq; enum llama_pooling_type pooling_type; @@ -8446,6 +8447,7 @@ struct llm_build_context { const int32_t n_ctx_orig; const bool flash_attn; + const bool binary_kq; const enum llama_pooling_type pooling_type; const enum llama_rope_type rope_type; @@ -8495,6 +8497,7 @@ struct llm_build_context { kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), n_ctx_orig (cparams.n_ctx_orig_yarn), flash_attn (cparams.flash_attn), + binary_kq (cparams.binary_kq), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), cb (cb), @@ -8689,7 +8692,7 @@ struct llm_build_context { struct ggml_tensor * build_inp_KQ_mask(bool causal = true) { auto nx = causal ? n_kv : n_tokens; // Note: we only use a binary mask when nx%32 == 0 because otherwise the CUDA implementation becomes way more messy - auto type = !lctx.is_encoding ? flash_attn || hparams.use_alibi || (nx%32 != 0) ? 
GGML_TYPE_F16 : GGML_TYPE_I32 : GGML_TYPE_F32; + auto type = !lctx.is_encoding ? !binary_kq || flash_attn || hparams.use_alibi || (nx%32 != 0) ? GGML_TYPE_F16 : GGML_TYPE_I32 : GGML_TYPE_F32; //auto type = flash_attn || hparams.use_alibi || (nx%32 != 0) ? GGML_TYPE_F16 : GGML_TYPE_I32; if (type == GGML_TYPE_I32) nx /= 32; lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, type, nx, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); @@ -8702,7 +8705,7 @@ struct llm_build_context { GGML_ASSERT(hparams.n_swa > 0); auto nx = causal ? n_kv : n_tokens; // Note: we only use a binary mask when nx%32 == 0 because otherwise the CUDA implementation becomes way more messy - auto type = !lctx.is_encoding ? flash_attn || hparams.use_alibi || (nx%32 != 0) ? GGML_TYPE_F16 : GGML_TYPE_I32 : GGML_TYPE_F32; + auto type = !lctx.is_encoding ? !binary_kq || flash_attn || hparams.use_alibi || (nx%32 != 0) ? GGML_TYPE_F16 : GGML_TYPE_I32 : GGML_TYPE_F32; if (type == GGML_TYPE_I32) nx /= 32; lctx.inp_KQ_mask_swa = ggml_new_tensor_2d(ctx0, type, nx, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1); @@ -16727,6 +16730,7 @@ struct llama_context_params llama_context_default_params() { /*.embeddings =*/ false, /*.offload_kqv =*/ true, /*.flash_attn =*/ false, + /*.binary_kq =*/ false, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, }; @@ -16917,6 +16921,7 @@ struct llama_context * llama_new_context_with_model( cparams.embeddings = params.embeddings; cparams.offload_kqv = params.offload_kqv; cparams.flash_attn = params.flash_attn; + cparams.binary_kq = params.binary_kq; cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? 
hparams.n_ctx_train : params.n_ctx; @@ -16983,6 +16988,7 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); + LLAMA_LOG_INFO("%s: binary_kq = %d\n", __func__, cparams.binary_kq); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);