WIP KQ binary mask

This commit is contained in:
Iwan Kawrakow
2024-08-28 17:27:48 +03:00
parent a8b762ddd9
commit 316345c535
2 changed files with 97 additions and 49 deletions

View File

@@ -231,6 +231,7 @@ struct cmd_params {
std::vector<int> main_gpu;
std::vector<bool> no_kv_offload;
std::vector<bool> flash_attn;
std::vector<bool> binary_kq;
std::vector<std::vector<float>> tensor_split;
std::vector<bool> use_mmap;
std::vector<bool> embeddings;
@@ -258,6 +259,7 @@ static const cmd_params cmd_params_defaults = {
/* main_gpu */ {0},
/* no_kv_offload */ {false},
/* flash_attn */ {false},
/* binary_kq */ {false},
/* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
/* use_mmap */ {true},
/* embeddings */ {false},
@@ -289,6 +291,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
printf(" -bkq, --binary-kq <0|1> (default: %s)\n", join(cmd_params_defaults.binary_kq, ",").c_str());
printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
printf(" --numa <distribute|isolate|numactl> (default: disabled)\n");
printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
@@ -503,6 +506,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
}
auto p = string_split<bool>(argv[i], split_delim);
params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end());
} else if (arg == "-bkq" || arg == "--binary-kq") {
if (++i >= argc) {
invalid_param = true;
break;
}
auto p = string_split<bool>(argv[i], split_delim);
params.binary_kq.insert(params.binary_kq.end(), p.begin(), p.end());
} else if (arg == "-mmp" || arg == "--mmap") {
if (++i >= argc) {
invalid_param = true;
@@ -591,6 +601,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; }
if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; }
if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; }
if (params.binary_kq.empty()) { params.binary_kq = cmd_params_defaults.binary_kq; }
if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; }
if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; }
@@ -614,6 +625,7 @@ struct cmd_params_instance {
int main_gpu;
bool no_kv_offload;
bool flash_attn;
bool binary_kq;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -653,6 +665,7 @@ struct cmd_params_instance {
cparams.type_v = type_v;
cparams.offload_kqv = !no_kv_offload;
cparams.flash_attn = flash_attn;
cparams.binary_kq = binary_kq;
cparams.embeddings = embeddings;
return cparams;
@@ -677,6 +690,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
for (const auto & tv : params.type_v)
for (const auto & nkvo : params.no_kv_offload)
for (const auto & fa : params.flash_attn)
for (const auto & bkq : params.binary_kq)
for (const auto & nt : params.n_threads) {
for (const auto & n_prompt : params.n_prompt) {
if (n_prompt == 0) {
@@ -697,6 +711,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .binary_kq = */ bkq,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -723,6 +738,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .binary_kq = */ bkq,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -749,6 +765,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .binary_kq = */ bkq,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -787,6 +804,7 @@ struct test {
int main_gpu;
bool no_kv_offload;
bool flash_attn;
bool binary_kq;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -813,6 +831,7 @@ struct test {
main_gpu = inst.main_gpu;
no_kv_offload = inst.no_kv_offload;
flash_attn = inst.flash_attn;
binary_kq = inst.binary_kq;
tensor_split = inst.tensor_split;
use_mmap = inst.use_mmap;
embeddings = inst.embeddings;
@@ -884,7 +903,7 @@ struct test {
"n_batch", "n_ubatch",
"n_threads", "type_k", "type_v",
"n_gpu_layers", "split_mode",
"main_gpu", "no_kv_offload", "flash_attn",
"main_gpu", "no_kv_offload", "flash_attn", "binary-kq",
"tensor_split", "use_mmap", "embeddings",
"n_prompt", "n_gen", "test_time",
"avg_ns", "stddev_ns",
@@ -906,7 +925,7 @@ struct test {
}
if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
field == "flash_attn" || field == "use_mmap" || field == "embeddings") {
field == "flash_attn" || field == "binary-kq" || field == "use_mmap" || field == "embeddings") {
return BOOL;
}
if (field == "avg_ts" || field == "stddev_ts") {
@@ -940,7 +959,7 @@ struct test {
std::to_string(n_batch), std::to_string(n_ubatch),
std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
std::to_string(n_gpu_layers), split_mode_str(split_mode),
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(binary_kq),
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
std::to_string(n_prompt), std::to_string(n_gen), test_time,
std::to_string(avg_ns()), std::to_string(stdev_ns()),
@@ -1103,6 +1122,9 @@ struct markdown_printer : public printer {
if (field == "flash_attn") {
return 2;
}
if (field == "binary-kq") {
return 3;
}
if (field == "use_mmap") {
return 4;
}
@@ -1134,6 +1156,9 @@ struct markdown_printer : public printer {
if (field == "flash_attn") {
return "fa";
}
if (field == "binary-kq") {
return "bkq";
}
if (field == "use_mmap") {
return "mmap";
}
@@ -1183,6 +1208,9 @@ struct markdown_printer : public printer {
if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) {
fields.emplace_back("flash_attn");
}
if (params.binary_kq.size() > 1 || params.binary_kq != cmd_params_defaults.binary_kq) {
fields.emplace_back("binary-kq");
}
if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
fields.emplace_back("tensor_split");
}

View File

@@ -2074,18 +2074,6 @@ static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y
}
return max;
}
// Applies a packed binary attention mask to y in place and returns the running maximum.
// x is a bit mask with one bit per element of y (n bits total, consumed 16 at a time
// as __mmask16 lanes); n must be a multiple of 16.
// For each 16-wide lane: where the mask bit is SET, the element is overwritten with
// -INFINITY; where it is CLEAR, the existing y value is kept (that is the
// _mm512_mask_blend_ps(k, a, b) convention: set bits pick the second operand).
// NOTE(review): this polarity is the OPPOSITE of ggml_vec_cpy_soft_mask_f32 below,
// where a SET bit keeps the (scaled) value and a CLEAR bit yields -INFINITY —
// confirm which convention the binary KQ mask producer uses.
// Returns the max over the masked result, for use as the softmax max.
static inline float ggml_vec_add_f32_infmask(const int n, const uint32_t * x, float * y) {
    GGML_ASSERT(n%16 == 0);
    __m512 vmax = _mm512_set1_ps(-INFINITY);  // running per-lane maximum
    __m512 vinf = _mm512_set1_ps(-INFINITY);  // value written to masked-out positions
    // Reinterpret the 32-bit mask words as a stream of 16-bit AVX512 predicate masks.
    const __mmask16 * mm16 = (const __mmask16 *)x;
    for (int j = 0; j < n/16; ++j) {
        __m512 v = _mm512_mask_blend_ps(mm16[j], _mm512_loadu_ps(y + 16*j), vinf);
        _mm512_storeu_ps(y + 16*j, v);
        vmax = _mm512_max_ps(vmax, v);
    }
    // Horizontal reduction of the 16 running maxima to a single scalar.
    return _mm512_reduce_max_ps(vmax);
}
#else
// TODO
static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float * y, float slope) {
@@ -2093,6 +2081,7 @@ static inline float ggml_vec_add_f32_f16(const int n, const ggml_half * x, float
GGML_UNUSED(x);
GGML_UNUSED(y);
GGML_UNUSED(slope);
GGML_ASSERT(false);
return 0.f;
}
static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y, float slope) {
@@ -2100,12 +2089,7 @@ static inline float ggml_vec_add_f32_f32(const int n, const float * x, float * y
GGML_UNUSED(x);
GGML_UNUSED(y);
GGML_UNUSED(slope);
return 0.f;
}
// Fallback stub for platforms without the SIMD implementation of the
// binary-mask add: never expected to be reached, so it traps via assert.
// All parameters are intentionally unused.
static inline float ggml_vec_add_f32_infmask(const int n, const uint32_t * x, float * y) {
    GGML_UNUSED(y);
    GGML_UNUSED(x);
    GGML_UNUSED(n);
    // No scalar path implemented yet — abort rather than silently return garbage.
    GGML_ASSERT(false);
    return 0.f;
}
#endif
@@ -2925,7 +2909,7 @@ static void ggml_vec_cpy_softcap_f32(const int n, const float * x, float * y, fl
}
}
#ifdef __AVX512__
#ifdef __AVX512F__
static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) {
const __mmask16 * m16 = (const __mmask16 *)mask;
__m512 vinf = _mm512_set1_ps(-INFINITY);
@@ -2939,6 +2923,18 @@ static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float *
}
return _mm512_reduce_max_ps(vmax);
}
// Fused scale + binary-mask copy for softmax: writes y[i] = scale * x[i] where the
// corresponding mask bit is SET, and y[i] = -INFINITY where it is CLEAR, returning
// the maximum of the written values (the softmax max).
// mask is a packed bit mask, one bit per element, consumed 16 bits at a time as
// __mmask16 predicates; n is assumed to be a multiple of 16 (the loop drops any
// remainder — NOTE(review): unlike ggml_vec_add_f32_infmask there is no
// GGML_ASSERT(n%16 == 0) here; confirm callers guarantee it.
static float ggml_vec_cpy_soft_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float scale) {
    // Reinterpret the 32-bit mask words as a stream of 16-bit AVX512 predicate masks.
    const __mmask16 * m16 = (const __mmask16 *)mask;
    __m512 vinf = _mm512_set1_ps(-INFINITY);  // value for masked-out (clear-bit) positions
    __m512 vmax = vinf;                       // running per-lane maximum
    __m512 vscale = _mm512_set1_ps(scale);
    for (int i = 0; i < n/16; ++i) {
        // _mm512_mask_mul_ps: computes vscale*x where the bit is set, takes vinf otherwise.
        __m512 v = _mm512_mask_mul_ps(vinf, m16[i], vscale, _mm512_loadu_ps(x + 16*i));
        _mm512_storeu_ps(y + 16*i, v);
        vmax = _mm512_max_ps(vmax, v);
    }
    // Horizontal reduction of the 16 running maxima to a single scalar.
    return _mm512_reduce_max_ps(vmax);
}
//#elif __ARM_NEON
//static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float s_before, float s_after) {
// //const uint16_t * mask16 = (const uint16_t *)mask;
@@ -3004,6 +3000,15 @@ static float ggml_vec_cpy_softcap_mask_f32(const int n, const float * x, float *
GGML_ASSERT(false);
return 0.f;
}
// Fallback stub for platforms without the SIMD implementation of the fused
// scale + binary-mask copy: never expected to be reached, so it traps via
// assert. All parameters are intentionally unused.
static float ggml_vec_cpy_soft_mask_f32(const int n, const float * x, float * y, const uint32_t * mask, float scale) {
    GGML_UNUSED(scale);
    GGML_UNUSED(mask);
    GGML_UNUSED(y);
    GGML_UNUSED(x);
    GGML_UNUSED(n);
    // No scalar path implemented yet — abort rather than silently return garbage.
    GGML_ASSERT(false);
    return 0.f;
}
#endif
static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s_after) {
@@ -13952,16 +13957,9 @@ static void ggml_compute_forward_softcap_max_f32(
if (use_f16) {
ggml_fp16_t * mp_f16 = (ggml_fp16_t *)((char *) src1->data) + mask_row*ne00;
max = ggml_vec_add_f32_f16(nc, mp_f16, wp, slope);
} else if (use_i32) {
int n32 = (ne00 + 31)/32;
const uint32_t * mp_u32 = (const uint32_t *)src1->data + mask_row*n32;
max = ggml_vec_add_f32_infmask(nc, mp_u32, wp);
} else {
float * mp_f32 = (float *)((char *) src1->data) + mask_row*ne00;
max = ggml_vec_add_f32_f32(nc, mp_f32, wp, slope);
//for (int i = 0; i < nc; ++i) {
// wp[i] += slope*mp_f32[i];
//}
}
}
else {
@@ -14745,6 +14743,7 @@ static void ggml_compute_forward_soft_max_f32(
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
const bool use_u32 = (src1 && src1->type == GGML_TYPE_I32);
for (int i1 = ir0; i1 < ir1; i1++) {
// ALiBi
@@ -14754,33 +14753,54 @@ static void ggml_compute_forward_soft_max_f32(
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
// broadcast the mask across rows
ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
float max = -INFINITY;
if (use_u32) {
int n32 = ne00/32;
const uint32_t * mp_u32 = (const uint32_t *)src1->data + (i1%ne01)*n32;
max = ggml_vec_cpy_soft_mask_f32(nc, sp, wp, mp_u32, scale);
} else {
ggml_vec_cpy_f32 (nc, wp, sp);
ggml_vec_scale_f32(nc, wp, scale);
if (mp_f32) {
if (use_f16) {
for (int i = 0; i < nc; ++i) {
wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
}
} else {
for (int i = 0; i < nc; ++i) {
wp[i] += slope*mp_f32[i];
ggml_vec_cpy_f32 (nc, wp, sp);
ggml_vec_scale_f32(nc, wp, scale);
if (src1) {
// broadcast the mask across rows
if (use_f16) {
ggml_fp16_t * mp_f16 = (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00;
max = ggml_vec_add_f32_f16(nc, mp_f16, wp, slope);
} else {
float * mp_f32 = (float *)((char *) src1->data) + (i1%ne01)*ne00;
max = ggml_vec_add_f32_f32(nc, mp_f32, wp, slope);
}
}
}
else {
ggml_vec_max_f32(nc, &max, wp);
}
//// broadcast the mask across rows
//ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
//float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
//if (mp_f32) {
// if (use_f16) {
// for (int i = 0; i < nc; ++i) {
// wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
// }
// } else {
// for (int i = 0; i < nc; ++i) {
// wp[i] += slope*mp_f32[i];
// }
// }
//}
#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
//printf("p[%d] = %f\n", i, p[i]);
assert(!isnan(wp[i]));
}
for (int i = 0; i < nc; ++i) {
//printf("p[%d] = %f\n", i, p[i]);
assert(!isnan(wp[i]));
}
#endif
float max = -INFINITY;
ggml_vec_max_f32(nc, &max, wp);
ggml_vec_max_f32(nc, &max, wp);
}
ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
assert(sum > 0.0);