From 316345c5353dded20f0ffa3371d02c5277e8f323 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 28 Aug 2024 17:27:48 +0300 Subject: [PATCH] WIP KQ binary mask --- examples/llama-bench/llama-bench.cpp | 34 +++++++- ggml/src/ggml.c | 112 ++++++++++++++++----------- 2 files changed, 97 insertions(+), 49 deletions(-) diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 813d7bae..0736e393 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -231,6 +231,7 @@ struct cmd_params { std::vector main_gpu; std::vector no_kv_offload; std::vector flash_attn; + std::vector binary_kq; std::vector> tensor_split; std::vector use_mmap; std::vector embeddings; @@ -258,6 +259,7 @@ static const cmd_params cmd_params_defaults = { /* main_gpu */ {0}, /* no_kv_offload */ {false}, /* flash_attn */ {false}, + /* binary_kq */ {false}, /* tensor_split */ {std::vector(llama_max_devices(), 0.0f)}, /* use_mmap */ {true}, /* embeddings */ {false}, @@ -289,6 +291,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -mg, --main-gpu (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str()); printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str()); + printf(" -bkq, --binary-kq <0|1> (default: %s)\n", join(cmd_params_defaults.binary_kq, ",").c_str()); printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); printf(" --numa (default: disabled)\n"); printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); @@ -503,6 +506,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = string_split(argv[i], split_delim); params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end()); + } else if (arg == "-bkq" || arg == "--binary-kq") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.binary_kq.insert(params.binary_kq.end(), p.begin(), p.end()); } else if (arg == "-mmp" || arg == "--mmap") { if (++i >= argc) { invalid_param = true; @@ -591,6 +601,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; } if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; } if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; } + if (params.binary_kq.empty()) { params.binary_kq = cmd_params_defaults.binary_kq; } if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; } if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; } if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; } @@ -614,6 +625,7 @@ struct cmd_params_instance { int main_gpu; bool no_kv_offload; bool flash_attn; + bool binary_kq; std::vector tensor_split; bool use_mmap; bool embeddings; @@ -653,6 +665,7 @@ struct cmd_params_instance { cparams.type_v = type_v; cparams.offload_kqv = !no_kv_offload; cparams.flash_attn = flash_attn; + cparams.binary_kq = binary_kq; cparams.embeddings = embeddings; return cparams; @@ -677,6 +690,7 @@ static std::vector get_cmd_params_instances(const cmd_param for (const auto & tv : params.type_v) for (const auto & nkvo : params.no_kv_offload) for (const auto & fa : params.flash_attn) + for (const auto & bkq : params.binary_kq) for (const auto & nt : params.n_threads) { for (const auto & n_prompt : params.n_prompt) { if (n_prompt == 0) { @@ -697,6 +711,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, + /* .binary_kq = */ bkq, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -723,6 +738,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, + /* .binary_kq = */ bkq, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -749,6 +765,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, + /* .binary_kq = */ bkq, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -787,6 +804,7 @@ struct test { int main_gpu; bool no_kv_offload; bool flash_attn; + bool binary_kq; std::vector tensor_split; bool use_mmap; bool embeddings; @@ -813,6 +831,7 @@ struct test { main_gpu = inst.main_gpu; no_kv_offload = inst.no_kv_offload; flash_attn = inst.flash_attn; + binary_kq = inst.binary_kq; tensor_split = inst.tensor_split; use_mmap = inst.use_mmap; embeddings = inst.embeddings; @@ -884,7 +903,7 @@ struct test { "n_batch", "n_ubatch", "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", - "main_gpu", "no_kv_offload", "flash_attn", + "main_gpu", "no_kv_offload", "flash_attn", "binary-kq", "tensor_split", "use_mmap", "embeddings", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", @@ -906,7 +925,7 @@ struct test { } if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" || field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" || - field == "flash_attn" || field == "use_mmap" || field == "embeddings") { + field == "flash_attn" || field == "binary-kq" || field == "use_mmap" || field == "embeddings") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -940,7 +959,7 @@ struct test { std::to_string(n_batch), std::to_string(n_ubatch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), split_mode_str(split_mode), - std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), + std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(binary_kq), tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), @@ -1103,6 +1122,9 @@ struct markdown_printer : public printer { if (field == "flash_attn") { return 2; } + if (field == "binary-kq") { + return 3; + } if (field == "use_mmap") { return 4; } @@ -1134,6 +1156,9 @@ struct markdown_printer : public printer { if (field == "flash_attn") { return "fa"; } + if (field == "binary-kq") { + return "bkq"; + } if (field == "use_mmap") { return "mmap"; } @@ -1183,6 +1208,9 @@ struct markdown_printer : public printer { if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) { fields.emplace_back("flash_attn"); } + if (params.binary_kq.size() > 1 || params.binary_kq != cmd_params_defaults.binary_kq) { + fields.emplace_back("binary-kq"); + } if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) { fields.emplace_back("tensor_split"); } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e45732aa..39987217 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2074,18 +2074,6 @@ static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y } return max; } -static inline float ggml_vec_add_f32_infmask(const int n, const uint32_t * x, float * y) { - GGML_ASSERT(n%16 == 0); - __m512 vmax = _mm512_set1_ps(-INFINITY); - __m512 vinf = _mm512_set1_ps(-INFINITY); - const __mmask16 * mm16 = (const __mmask16 *)x; - for (int j = 0; j < n/16; ++j) { - __m512 v = _mm512_mask_blend_ps(mm16[j], _mm512_loadu_ps(y + 16*j), vinf); - _mm512_storeu_ps(y + 16*j, v); - vmax = _mm512_max_ps(vmax, v); - } - return _mm512_reduce_max_ps(vmax); -} #else // TODO static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float * y, float slope) { @@ -2093,6 +2081,7 @@ static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(slope); + GGML_ASSERT(false); return 0.f; } static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y, float slope) { @@ -2100,12 +2089,7 @@ static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(slope); - return 0.f; -} -static inline float ggml_vec_add_f32_infmask(const int n, const uint32_t * x, float * y) { - GGML_UNUSED(n); - GGML_UNUSED(x); - GGML_UNUSED(y); + GGML_ASSERT(false); return 0.f; } #endif @@ -2925,7 +2909,7 @@ static void ggml_vec_cpy_softcap_f32(const int n, const float * x, float * y, fl } } -#ifdef __AVX512__ +#ifdef __AVX512F__ static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) { const __mmask16 * m16 = (const __mmask16 *)mask; __m512 vinf = _mm512_set1_ps(-INFINITY); @@ -2939,6 +2923,18 @@ static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * } return _mm512_reduce_max_ps(vmax); } +static float ggml_vec_cpy_soft_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float scale) { + const __mmask16 * m16 = (const __mmask16 *)mask; + __m512 vinf = _mm512_set1_ps(-INFINITY); + __m512 vmax = vinf; + __m512 vscale = _mm512_set1_ps(scale); + for (int i = 0; i < n/16; ++i) { + __m512 v = _mm512_mask_mul_ps(vinf, m16[i], vscale, _mm512_loadu_ps(x + 16*i)); + _mm512_storeu_ps(y + 16*i, v); + vmax = _mm512_max_ps(vmax, v); + } + return _mm512_reduce_max_ps(vmax); +} //#elif __ARM_NEON //static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) { // //const uint16_t * mask16 = (const uint16_t *)mask; @@ -3004,6 +3000,15 @@ static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * GGML_ASSERT(false); return 0.f; } +static float ggml_vec_cpy_soft_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float scale) { + GGML_UNUSED(n); + GGML_UNUSED(x); + GGML_UNUSED(y); + GGML_UNUSED(mask); + GGML_UNUSED(scale); + GGML_ASSERT(false); + return 0.f; +} #endif static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s_after) { @@ -13952,16 +13957,9 @@ static void ggml_compute_forward_softcap_max_f32( if (use_f16) { ggml_fp16_t * mp_f16 = (ggml_fp16_t *)((char *) src1->data) + mask_row*ne00; max = ggml_vec_add_f32_f16(nc, mp_f16, wp, slope); - } else if (use_i32) { - int n32 = (ne00 + 31)/32; - const uint32_t * mp_u32 = (const uint32_t *)src1->data + mask_row*n32; - max = ggml_vec_add_f32_infmask(nc, mp_u32, wp); } else { float * mp_f32 = (float *)((char *) src1->data) + mask_row*ne00; max = ggml_vec_add_f32_f32(nc, mp_f32, wp, slope); - //for (int i = 0; i < nc; ++i) { - // wp[i] += slope*mp_f32[i]; - //} } } else { @@ -14745,6 +14743,7 @@ static void ggml_compute_forward_soft_max_f32( float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + const bool use_u32 = (src1 && src1->type == GGML_TYPE_I32); for (int i1 = ir0; i1 < ir1; i1++) { // ALiBi @@ -14754,33 +14753,54 @@ static void ggml_compute_forward_soft_max_f32( float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); - // broadcast the mask across rows - ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; - float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + float max = -INFINITY; + if (use_u32) { + int n32 = ne00/32; + const uint32_t * mp_u32 = (const uint32_t *)src1->data + (i1%ne01)*n32; + max = ggml_vec_cpy_soft_mask_f32(nc, sp, wp, mp_u32, scale); + } else { - ggml_vec_cpy_f32 (nc, wp, sp); - ggml_vec_scale_f32(nc, wp, scale); - if (mp_f32) { - if (use_f16) { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]); - } - } else { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*mp_f32[i]; + ggml_vec_cpy_f32 (nc, wp, sp); + ggml_vec_scale_f32(nc, wp, scale); + if (src1) { + // broadcast the mask across rows + if (use_f16) { + ggml_fp16_t * mp_f16 = (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00; + max = ggml_vec_add_f32_f16(nc, mp_f16, wp, slope); + } else { + float * mp_f32 = (float *)((char *) src1->data) + (i1%ne01)*ne00; + max = ggml_vec_add_f32_f32(nc, mp_f32, wp, slope); } } - } + else { + ggml_vec_max_f32(nc, &max, wp); + } + + //// broadcast the mask across rows + //ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + //float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + + //if (mp_f32) { + // if (use_f16) { + // for (int i = 0; i < nc; ++i) { + // wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]); + // } + // } else { + // for (int i = 0; i < nc; ++i) { + // wp[i] += slope*mp_f32[i]; + // } + // } + //} #ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(wp[i])); - } + for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(wp[i])); + } #endif - float max = -INFINITY; - ggml_vec_max_f32(nc, &max, wp); + ggml_vec_max_f32(nc, &max, wp); + } ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max); assert(sum > 0.0);