mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-03 10:30:27 +00:00
WIP KQ binary mask
This commit is contained in:
@@ -231,6 +231,7 @@ struct cmd_params {
|
||||
std::vector<int> main_gpu;
|
||||
std::vector<bool> no_kv_offload;
|
||||
std::vector<bool> flash_attn;
|
||||
std::vector<bool> binary_kq;
|
||||
std::vector<std::vector<float>> tensor_split;
|
||||
std::vector<bool> use_mmap;
|
||||
std::vector<bool> embeddings;
|
||||
@@ -258,6 +259,7 @@ static const cmd_params cmd_params_defaults = {
|
||||
/* main_gpu */ {0},
|
||||
/* no_kv_offload */ {false},
|
||||
/* flash_attn */ {false},
|
||||
/* binary_kq */ {false},
|
||||
/* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
|
||||
/* use_mmap */ {true},
|
||||
/* embeddings */ {false},
|
||||
@@ -289,6 +291,7 @@ static void print_usage(int /* argc */, char ** argv) {
|
||||
printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
|
||||
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
|
||||
printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
|
||||
printf(" -bkq, --binary-kq <0|1> (default: %s)\n", join(cmd_params_defaults.binary_kq, ",").c_str());
|
||||
printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
|
||||
printf(" --numa <distribute|isolate|numactl> (default: disabled)\n");
|
||||
printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
|
||||
@@ -503,6 +506,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||
}
|
||||
auto p = string_split<bool>(argv[i], split_delim);
|
||||
params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end());
|
||||
} else if (arg == "-bkq" || arg == "--binary-kq") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
auto p = string_split<bool>(argv[i], split_delim);
|
||||
params.binary_kq.insert(params.binary_kq.end(), p.begin(), p.end());
|
||||
} else if (arg == "-mmp" || arg == "--mmap") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
@@ -591,6 +601,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||
if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; }
|
||||
if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; }
|
||||
if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; }
|
||||
if (params.binary_kq.empty()) { params.binary_kq = cmd_params_defaults.binary_kq; }
|
||||
if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
|
||||
if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; }
|
||||
if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; }
|
||||
@@ -614,6 +625,7 @@ struct cmd_params_instance {
|
||||
int main_gpu;
|
||||
bool no_kv_offload;
|
||||
bool flash_attn;
|
||||
bool binary_kq;
|
||||
std::vector<float> tensor_split;
|
||||
bool use_mmap;
|
||||
bool embeddings;
|
||||
@@ -653,6 +665,7 @@ struct cmd_params_instance {
|
||||
cparams.type_v = type_v;
|
||||
cparams.offload_kqv = !no_kv_offload;
|
||||
cparams.flash_attn = flash_attn;
|
||||
cparams.binary_kq = binary_kq;
|
||||
cparams.embeddings = embeddings;
|
||||
|
||||
return cparams;
|
||||
@@ -677,6 +690,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||
for (const auto & tv : params.type_v)
|
||||
for (const auto & nkvo : params.no_kv_offload)
|
||||
for (const auto & fa : params.flash_attn)
|
||||
for (const auto & bkq : params.binary_kq)
|
||||
for (const auto & nt : params.n_threads) {
|
||||
for (const auto & n_prompt : params.n_prompt) {
|
||||
if (n_prompt == 0) {
|
||||
@@ -697,6 +711,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||
/* .main_gpu = */ mg,
|
||||
/* .no_kv_offload= */ nkvo,
|
||||
/* .flash_attn = */ fa,
|
||||
/* .binary_kq = */ bkq,
|
||||
/* .tensor_split = */ ts,
|
||||
/* .use_mmap = */ mmp,
|
||||
/* .embeddings = */ embd,
|
||||
@@ -723,6 +738,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||
/* .main_gpu = */ mg,
|
||||
/* .no_kv_offload= */ nkvo,
|
||||
/* .flash_attn = */ fa,
|
||||
/* .binary_kq = */ bkq,
|
||||
/* .tensor_split = */ ts,
|
||||
/* .use_mmap = */ mmp,
|
||||
/* .embeddings = */ embd,
|
||||
@@ -749,6 +765,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||
/* .main_gpu = */ mg,
|
||||
/* .no_kv_offload= */ nkvo,
|
||||
/* .flash_attn = */ fa,
|
||||
/* .binary_kq = */ bkq,
|
||||
/* .tensor_split = */ ts,
|
||||
/* .use_mmap = */ mmp,
|
||||
/* .embeddings = */ embd,
|
||||
@@ -787,6 +804,7 @@ struct test {
|
||||
int main_gpu;
|
||||
bool no_kv_offload;
|
||||
bool flash_attn;
|
||||
bool binary_kq;
|
||||
std::vector<float> tensor_split;
|
||||
bool use_mmap;
|
||||
bool embeddings;
|
||||
@@ -813,6 +831,7 @@ struct test {
|
||||
main_gpu = inst.main_gpu;
|
||||
no_kv_offload = inst.no_kv_offload;
|
||||
flash_attn = inst.flash_attn;
|
||||
binary_kq = inst.binary_kq;
|
||||
tensor_split = inst.tensor_split;
|
||||
use_mmap = inst.use_mmap;
|
||||
embeddings = inst.embeddings;
|
||||
@@ -884,7 +903,7 @@ struct test {
|
||||
"n_batch", "n_ubatch",
|
||||
"n_threads", "type_k", "type_v",
|
||||
"n_gpu_layers", "split_mode",
|
||||
"main_gpu", "no_kv_offload", "flash_attn",
|
||||
"main_gpu", "no_kv_offload", "flash_attn", "binary-kq",
|
||||
"tensor_split", "use_mmap", "embeddings",
|
||||
"n_prompt", "n_gen", "test_time",
|
||||
"avg_ns", "stddev_ns",
|
||||
@@ -906,7 +925,7 @@ struct test {
|
||||
}
|
||||
if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
|
||||
field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
|
||||
field == "flash_attn" || field == "use_mmap" || field == "embeddings") {
|
||||
field == "flash_attn" || field == "binary-kq" || field == "use_mmap" || field == "embeddings") {
|
||||
return BOOL;
|
||||
}
|
||||
if (field == "avg_ts" || field == "stddev_ts") {
|
||||
@@ -940,7 +959,7 @@ struct test {
|
||||
std::to_string(n_batch), std::to_string(n_ubatch),
|
||||
std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
|
||||
std::to_string(n_gpu_layers), split_mode_str(split_mode),
|
||||
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
|
||||
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(binary_kq),
|
||||
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
|
||||
std::to_string(n_prompt), std::to_string(n_gen), test_time,
|
||||
std::to_string(avg_ns()), std::to_string(stdev_ns()),
|
||||
@@ -1103,6 +1122,9 @@ struct markdown_printer : public printer {
|
||||
if (field == "flash_attn") {
|
||||
return 2;
|
||||
}
|
||||
if (field == "binary-kq") {
|
||||
return 3;
|
||||
}
|
||||
if (field == "use_mmap") {
|
||||
return 4;
|
||||
}
|
||||
@@ -1134,6 +1156,9 @@ struct markdown_printer : public printer {
|
||||
if (field == "flash_attn") {
|
||||
return "fa";
|
||||
}
|
||||
if (field == "binary-kq") {
|
||||
return "bkq";
|
||||
}
|
||||
if (field == "use_mmap") {
|
||||
return "mmap";
|
||||
}
|
||||
@@ -1183,6 +1208,9 @@ struct markdown_printer : public printer {
|
||||
if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) {
|
||||
fields.emplace_back("flash_attn");
|
||||
}
|
||||
if (params.binary_kq.size() > 1 || params.binary_kq != cmd_params_defaults.binary_kq) {
|
||||
fields.emplace_back("binary-kq");
|
||||
}
|
||||
if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
|
||||
fields.emplace_back("tensor_split");
|
||||
}
|
||||
|
||||
112
ggml/src/ggml.c
112
ggml/src/ggml.c
@@ -2074,18 +2074,6 @@ static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y
|
||||
}
|
||||
return max;
|
||||
}
|
||||
// Applies a packed bit mask to y in place and returns the maximum of the result.
//   n - number of floats in y; asserted to be a multiple of 16
//   x - mask storage, reinterpreted below as consecutive 16-bit lane masks
//   y - values, overwritten in place
// Despite the "add" in the name nothing is added: an element whose mask bit is
// set is replaced by -INFINITY, an element whose mask bit is clear is kept.
static inline float ggml_vec_add_f32_infmask(const int n, const uint32_t * x, float * y) {
|
||||
GGML_ASSERT(n%16 == 0);
|
||||
// running maximum; starts at -INFINITY so that is also the result for n == 0
__m512 vmax = _mm512_set1_ps(-INFINITY);
|
||||
__m512 vinf = _mm512_set1_ps(-INFINITY);
|
||||
// one __mmask16 per group of 16 floats
const __mmask16 * mm16 = (const __mmask16 *)x;
|
||||
for (int j = 0; j < n/16; ++j) {
|
||||
// blend: lanes with a set mask bit take vinf, the others keep y
__m512 v = _mm512_mask_blend_ps(mm16[j], _mm512_loadu_ps(y + 16*j), vinf);
|
||||
_mm512_storeu_ps(y + 16*j, v);
|
||||
vmax = _mm512_max_ps(vmax, v);
|
||||
}
|
||||
// horizontal reduction of the 16 lane maxima
return _mm512_reduce_max_ps(vmax);
|
||||
}
|
||||
#else
|
||||
// TODO
|
||||
static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float * y, float slope) {
|
||||
@@ -2093,6 +2081,7 @@ static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float
|
||||
GGML_UNUSED(x);
|
||||
GGML_UNUSED(y);
|
||||
GGML_UNUSED(slope);
|
||||
GGML_ASSERT(false);
|
||||
return 0.f;
|
||||
}
|
||||
static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y, float slope) {
|
||||
@@ -2100,12 +2089,7 @@ static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y
|
||||
GGML_UNUSED(x);
|
||||
GGML_UNUSED(y);
|
||||
GGML_UNUSED(slope);
|
||||
return 0.f;
|
||||
}
|
||||
// Placeholder variant in the #else branch: it has no implementation.
// All arguments are ignored; calling it trips GGML_ASSERT(false).
static inline float ggml_vec_add_f32_infmask(const int n, const uint32_t * x, float * y) {
|
||||
GGML_UNUSED(n);
|
||||
GGML_UNUSED(x);
|
||||
GGML_UNUSED(y);
|
||||
GGML_ASSERT(false);
|
||||
// only reached if the assertion above does not stop execution
return 0.f;
|
||||
}
|
||||
#endif
|
||||
@@ -2925,7 +2909,7 @@ static void ggml_vec_cpy_softcap_f32(const int n, const float * x, float * y, fl
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __AVX512__
|
||||
#ifdef __AVX512F__
|
||||
static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) {
|
||||
const __mmask16 * m16 = (const __mmask16 *)mask;
|
||||
__m512 vinf = _mm512_set1_ps(-INFINITY);
|
||||
@@ -2939,6 +2923,18 @@ static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float *
|
||||
}
|
||||
return _mm512_reduce_max_ps(vmax);
|
||||
}
|
||||
// Scaled, masked copy of x into y that also returns the maximum value written.
//   n     - number of floats; only n/16 full groups of 16 are processed
//   x     - input values (not modified)
//   y     - output buffer
//   mask  - mask storage, reinterpreted below as consecutive 16-bit lane masks
//   scale - factor applied to the elements that are kept
// For each lane: mask bit set -> y = scale*x, mask bit clear -> y = -INFINITY.
// There is no tail loop: if n is not a multiple of 16 the last n%16 elements of y
// are not written and do not take part in the returned maximum.
static float ggml_vec_cpy_soft_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float scale) {
|
||||
const __mmask16 * m16 = (const __mmask16 *)mask;
|
||||
__m512 vinf = _mm512_set1_ps(-INFINITY);
|
||||
// running maximum; -INFINITY is returned when no group is processed (n < 16)
__m512 vmax = vinf;
|
||||
__m512 vscale = _mm512_set1_ps(scale);
|
||||
for (int i = 0; i < n/16; ++i) {
|
||||
// masked multiply: lanes with a clear mask bit take the pass-through value vinf
__m512 v = _mm512_mask_mul_ps(vinf, m16[i], vscale, _mm512_loadu_ps(x + 16*i));
|
||||
_mm512_storeu_ps(y + 16*i, v);
|
||||
vmax = _mm512_max_ps(vmax, v);
|
||||
}
|
||||
// horizontal reduction of the 16 lane maxima
return _mm512_reduce_max_ps(vmax);
|
||||
}
|
||||
//#elif __ARM_NEON
|
||||
//static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) {
|
||||
// //const uint16_t * mask16 = (const uint16_t *)mask;
|
||||
@@ -3004,6 +3000,15 @@ static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float *
|
||||
GGML_ASSERT(false);
|
||||
return 0.f;
|
||||
}
|
||||
// Placeholder variant with no implementation: every argument is ignored and
// calling it trips GGML_ASSERT(false).
static float ggml_vec_cpy_soft_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float scale) {
|
||||
GGML_UNUSED(n);
|
||||
GGML_UNUSED(x);
|
||||
GGML_UNUSED(y);
|
||||
GGML_UNUSED(mask);
|
||||
GGML_UNUSED(scale);
|
||||
GGML_ASSERT(false);
|
||||
// only reached if the assertion above does not stop execution
return 0.f;
|
||||
}
|
||||
#endif
|
||||
|
||||
static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s_after) {
|
||||
@@ -13952,16 +13957,9 @@ static void ggml_compute_forward_softcap_max_f32(
|
||||
if (use_f16) {
|
||||
ggml_fp16_t * mp_f16 = (ggml_fp16_t *)((char *) src1->data) + mask_row*ne00;
|
||||
max = ggml_vec_add_f32_f16(nc, mp_f16, wp, slope);
|
||||
} else if (use_i32) {
|
||||
int n32 = (ne00 + 31)/32;
|
||||
const uint32_t * mp_u32 = (const uint32_t *)src1->data + mask_row*n32;
|
||||
max = ggml_vec_add_f32_infmask(nc, mp_u32, wp);
|
||||
} else {
|
||||
float * mp_f32 = (float *)((char *) src1->data) + mask_row*ne00;
|
||||
max = ggml_vec_add_f32_f32(nc, mp_f32, wp, slope);
|
||||
//for (int i = 0; i < nc; ++i) {
|
||||
// wp[i] += slope*mp_f32[i];
|
||||
//}
|
||||
}
|
||||
}
|
||||
else {
|
||||
@@ -14745,6 +14743,7 @@ static void ggml_compute_forward_soft_max_f32(
|
||||
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
|
||||
|
||||
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
||||
const bool use_u32 = (src1 && src1->type == GGML_TYPE_I32);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
// ALiBi
|
||||
@@ -14754,33 +14753,54 @@ static void ggml_compute_forward_soft_max_f32(
|
||||
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
|
||||
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
|
||||
|
||||
// broadcast the mask across rows
|
||||
ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
|
||||
float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
|
||||
float max = -INFINITY;
|
||||
if (use_u32) {
|
||||
int n32 = ne00/32;
|
||||
const uint32_t * mp_u32 = (const uint32_t *)src1->data + (i1%ne01)*n32;
|
||||
max = ggml_vec_cpy_soft_mask_f32(nc, sp, wp, mp_u32, scale);
|
||||
} else {
|
||||
|
||||
ggml_vec_cpy_f32 (nc, wp, sp);
|
||||
ggml_vec_scale_f32(nc, wp, scale);
|
||||
if (mp_f32) {
|
||||
if (use_f16) {
|
||||
for (int i = 0; i < nc; ++i) {
|
||||
wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < nc; ++i) {
|
||||
wp[i] += slope*mp_f32[i];
|
||||
ggml_vec_cpy_f32 (nc, wp, sp);
|
||||
ggml_vec_scale_f32(nc, wp, scale);
|
||||
if (src1) {
|
||||
// broadcast the mask across rows
|
||||
if (use_f16) {
|
||||
ggml_fp16_t * mp_f16 = (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00;
|
||||
max = ggml_vec_add_f32_f16(nc, mp_f16, wp, slope);
|
||||
} else {
|
||||
float * mp_f32 = (float *)((char *) src1->data) + (i1%ne01)*ne00;
|
||||
max = ggml_vec_add_f32_f32(nc, mp_f32, wp, slope);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
ggml_vec_max_f32(nc, &max, wp);
|
||||
}
|
||||
|
||||
//// broadcast the mask across rows
|
||||
//ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
|
||||
//float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
|
||||
|
||||
//if (mp_f32) {
|
||||
// if (use_f16) {
|
||||
// for (int i = 0; i < nc; ++i) {
|
||||
// wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
|
||||
// }
|
||||
// } else {
|
||||
// for (int i = 0; i < nc; ++i) {
|
||||
// wp[i] += slope*mp_f32[i];
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int i = 0; i < nc; ++i) {
|
||||
//printf("p[%d] = %f\n", i, p[i]);
|
||||
assert(!isnan(wp[i]));
|
||||
}
|
||||
for (int i = 0; i < nc; ++i) {
|
||||
//printf("p[%d] = %f\n", i, p[i]);
|
||||
assert(!isnan(wp[i]));
|
||||
}
|
||||
#endif
|
||||
|
||||
float max = -INFINITY;
|
||||
ggml_vec_max_f32(nc, &max, wp);
|
||||
ggml_vec_max_f32(nc, &max, wp);
|
||||
}
|
||||
|
||||
ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
|
||||
assert(sum > 0.0);
|
||||
|
||||
Reference in New Issue
Block a user