Much better

This commit is contained in:
Iwan Kawrakow
2025-02-28 16:53:02 +02:00
parent efc757c95c
commit 285b97b6bb
3 changed files with 80 additions and 53 deletions

View File

@@ -233,6 +233,7 @@ struct cmd_params {
std::vector<bool> no_kv_offload;
std::vector<bool> flash_attn;
std::vector<int> mla_attn;
std::vector<int> attn_max_batch;
std::vector<std::vector<float>> tensor_split;
std::vector<bool> use_mmap;
std::vector<bool> embeddings;
@@ -265,6 +266,7 @@ static const cmd_params cmd_params_defaults = {
/* no_kv_offload */ {false},
/* flash_attn */ {false},
/* mla_attn */ {0},
/* attn_max_batch */ {0},
/* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
/* use_mmap */ {true},
/* embeddings */ {false},
@@ -301,6 +303,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
printf(" -mla, --mla-attn <0|1|2> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str());
printf(" -amb, --attn-max-batch <i> (default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str());
printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
printf(" --numa <distribute|isolate|numactl> (default: disabled)\n");
printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
@@ -578,6 +581,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
}
auto p = string_split<int>(argv[i], split_delim);
params.mla_attn.insert(params.mla_attn.end(), p.begin(), p.end());
} else if (arg == "-amb" || arg == "--attn-max-batch") {
if (++i >= argc) {
invalid_param = true;
break;
}
auto p = string_split<int>(argv[i], split_delim);
params.attn_max_batch.insert(params.attn_max_batch.end(), p.begin(), p.end());
} else if (arg == "-mmp" || arg == "--mmap") {
if (++i >= argc) {
invalid_param = true;
@@ -690,6 +700,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; }
if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; }
if (params.mla_attn.empty()) { params.mla_attn = cmd_params_defaults.mla_attn; }
if (params.attn_max_batch.empty()){ params.attn_max_batch = cmd_params_defaults.attn_max_batch; }
if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; }
if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; }
@@ -727,6 +738,7 @@ struct cmd_params_instance {
bool no_kv_offload;
bool flash_attn;
int mla_attn;
int attn_max_batch;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -773,6 +785,7 @@ struct cmd_params_instance {
cparams.offload_kqv = !no_kv_offload;
cparams.flash_attn = flash_attn;
cparams.mla_attn = mla_attn;
cparams.attn_max_batch = attn_max_batch;
cparams.fused_moe_up_gate = fmoe;
cparams.embeddings = embeddings;
@@ -799,6 +812,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
for (const auto & nkvo : params.no_kv_offload)
for (const auto & fa : params.flash_attn)
for (const auto & mla : params.mla_attn)
for (const auto & amb : params.attn_max_batch)
for (const auto & nt : params.n_threads) {
for (const auto & n_prompt : params.n_prompt) {
if (n_prompt == 0) {
@@ -821,6 +835,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .mla_attn = */ mla,
/* .attn_max_b = */ amb,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -852,6 +867,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .mla_attn = */ mla,
/* .attn_max_b = */ amb,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -883,6 +899,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .mla_attn = */ mla,
/* .attn_max_b = */ amb,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -914,6 +931,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .mla_attn = */ mla,
/* .attn_max_b = */ amb,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -956,6 +974,7 @@ struct test {
bool no_kv_offload;
bool flash_attn;
int mla_attn;
int attn_max_batch;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -987,6 +1006,7 @@ struct test {
no_kv_offload = inst.no_kv_offload;
flash_attn = inst.flash_attn;
mla_attn = inst.mla_attn;
attn_max_batch = inst.attn_max_batch;
tensor_split = inst.tensor_split;
use_mmap = inst.use_mmap;
embeddings = inst.embeddings;
@@ -1081,7 +1101,7 @@ struct test {
"n_batch", "n_ubatch",
"n_threads", "type_k", "type_v",
"n_gpu_layers", "split_mode",
"main_gpu", "no_kv_offload", "flash_attn", "mla_attn",
"main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch",
"tensor_split", "use_mmap", "embeddings", "repack", "fused_moe",
"n_prompt", "n_gen", "test_time",
"avg_ns", "stddev_ns",
@@ -1097,7 +1117,7 @@ struct test {
field == "n_threads" ||
field == "model_size" || field == "model_n_params" ||
field == "n_gpu_layers" || field == "main_gpu" ||
field == "n_prompt" || field == "n_gen" || field == "mla_attn" ||
field == "n_prompt" || field == "n_gen" || field == "mla_attn" || field == "attn_max_batch" ||
field == "avg_ns" || field == "stddev_ns") {
return INT;
}
@@ -1138,7 +1158,7 @@ struct test {
std::to_string(n_batch), std::to_string(n_ubatch),
std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
std::to_string(n_gpu_layers), split_mode_str(split_mode),
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn),
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn), std::to_string(attn_max_batch),
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), std::to_string(fmoe),
std::to_string(n_prompt), std::to_string(n_gen), test_time,
std::to_string(avg_ns()), std::to_string(stdev_ns()),
@@ -1305,6 +1325,9 @@ struct markdown_printer : public printer {
if (field == "mla_attn") {
return 3;
}
if (field == "attn_max_batch") {
return 5;
}
if (field == "use_mmap") {
return 4;
}
@@ -1345,6 +1368,9 @@ struct markdown_printer : public printer {
if (field == "mla_attn") {
return "mla";
}
if (field == "attn_max_batch") {
return "amb";
}
if (field == "use_mmap") {
return "mmap";
}
@@ -1403,6 +1429,9 @@ struct markdown_printer : public printer {
if (params.mla_attn.size() > 1 || params.mla_attn != cmd_params_defaults.mla_attn) {
fields.emplace_back("mla_attn");
}
if (params.attn_max_batch.size() > 1 || params.attn_max_batch != cmd_params_defaults.mla_attn) {
fields.emplace_back("attn_max_batch");
}
if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
fields.emplace_back("tensor_split");
}

View File

@@ -164,7 +164,12 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float * dst_d = (float *)dst->data;
if (dim != 3) {
if (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1)) {
const size_t size0 = ggml_nbytes(src0);
const size_t size1 = ggml_nbytes(src1);
CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
} else {
for (int i3 = 0; i3 < dst->ne[3]; i3++) {
concat_f32_cuda(
src0_d + i3 * (src0->nb[3] / 4),
@@ -173,13 +178,24 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
src0->ne[0], src0->ne[1], src0->ne[2],
dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
}
} else {
const size_t size0 = ggml_nbytes(src0);
const size_t size1 = ggml_nbytes(src1);
CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
}
//if (dim != 3) {
// for (int i3 = 0; i3 < dst->ne[3]; i3++) {
// concat_f32_cuda(
// src0_d + i3 * (src0->nb[3] / 4),
// src1_d + i3 * (src1->nb[3] / 4),
// dst_d + i3 * ( dst->nb[3] / 4),
// src0->ne[0], src0->ne[1], src0->ne[2],
// dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
// }
//} else {
// const size_t size0 = ggml_nbytes(src0);
// const size_t size1 = ggml_nbytes(src1);
// CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
// CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
//}
} else {
dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(

View File

@@ -8785,7 +8785,8 @@ static struct ggml_tensor * llm_build_kqv(
0);
cb(v, "v", il);
if (cparams.attn_max_batch == 0 || cparams.attn_max_batch >= k->ne[1] || q->ne[2] == 1) {
auto kq_size = k->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024);
if (cparams.attn_max_batch == 0 || cparams.attn_max_batch >= kq_size || k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2]) {
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);
@@ -8834,13 +8835,21 @@ static struct ggml_tensor * llm_build_kqv(
cb(cur, "kqv_merged_cont", il);
}
else {
// For now we will not support this option if k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2];
GGML_ASSERT(k->ne[2] == v->ne[2] && k->ne[2] == q->ne[2]);
int n_step = (kq_size + cparams.attn_max_batch - 1)/cparams.attn_max_batch;
n_step = std::min(n_step, int(k->ne[2]));
int n_per_step = (q->ne[2] + n_step - 1)/n_step;
auto r2k = q->ne[2] / k->ne[2];
auto r2v = q->ne[2] / n_head_kv;
auto r2v = q->ne[2] / v->ne[2];
n_step = q->ne[2];
n_per_step = 1;
ggml_tensor * kqv;
for (int i12 = 0; i12 < q->ne[2]; ++i12) {
for (int i12 = 0; i12 < q->ne[2]; i12 += n_per_step) {
int this_ne12 = i12 + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i12;
int i02 = i12/r2k;
auto k_i = ggml_view_2d(ctx, k, k->ne[0], k->ne[1], k->nb[1], k->nb[2]*i02);
auto q_i = ggml_view_2d(ctx, q, q->ne[0], q->ne[1], q->nb[1], q->nb[2]*i12);
auto k_i = ggml_view_3d(ctx, k, k->ne[0], k->ne[1], this_ne12, k->nb[1], k->nb[2], k->nb[2]*i02);
auto q_i = ggml_view_3d(ctx, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i12);
auto kq_i = ggml_mul_mat(ctx, k_i, q_i);
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
ggml_mul_mat_set_prec(kq_i, GGML_PREC_F32);
@@ -8855,7 +8864,7 @@ static struct ggml_tensor * llm_build_kqv(
kq_i = ggml_soft_max_ext(ctx, kq_i, kq_mask, kq_scale, hparams.f_max_alibi_bias);
}
i02 = i12 / r2v;
auto v_i = ggml_view_2d(ctx, v, v->ne[0], v->ne[1], v->nb[1], v->nb[2]*i02);
auto v_i = ggml_view_3d(ctx, v, v->ne[0], v->ne[1], this_ne12, v->nb[1], v->nb[2], v->nb[2]*i02);
auto kqv_i = ggml_mul_mat(ctx, v_i, kq_i);
if (i12 == 0) {
kqv = kqv_i;
@@ -13625,7 +13634,8 @@ struct llm_build_context {
}
ggml_tensor * kqv_compressed;
if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kv_cache->ne[1]) {
auto kq_size = kv_cache->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); // K*Q in MiB
if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kq_size) {
if (!pp_opt) {
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
cb(q, "q_perm", il);
@@ -13634,9 +13644,6 @@ struct llm_build_context {
ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
cb(kq, "kq", il);
//printf("kq (%ld x %ld x %ld x %ld) = kv_cache (%ld x %ld x %ld x %ld) * q (%ld x %ld x %ld x %ld)\n", kq->ne[0], kq->ne[1], kq->ne[2], kq->ne[3],
// kv_cache->ne[0], kv_cache->ne[1], kv_cache->ne[2], kv_cache->ne[3], q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
if (!pp_opt) {
kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
cb(kq, "kq_perm", il);
@@ -13653,10 +13660,6 @@ struct llm_build_context {
kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
cb(kqv_compressed, "kqv_compressed", il);
//printf("kqv (%ld x %ld x %ld x %ld) = kv_cache_trans (%ld x %ld x %ld x %ld) * kq (%ld x %ld x %ld x %ld)\n",
// kqv_compressed->ne[0], kqv_compressed->ne[1], kqv_compressed->ne[2], kqv_compressed->ne[3],
// kv_cache_trans->ne[0], kv_cache_trans->ne[1], kv_cache_trans->ne[2], kv_cache_trans->ne[3], kq->ne[0], kq->ne[1], kq->ne[2], kq->ne[3]);
if (!pp_opt) {
kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
cb(kqv_compressed, "kqv_compressed_perm", il);
@@ -13664,46 +13667,25 @@ struct llm_build_context {
} else {
int n_step = (q->ne[1] + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch;
int n_step = (kq_size + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch;
n_step = std::min(n_step, int(q->ne[2]));
int n_per_step = (q->ne[2] + n_step - 1)/n_step;
//kqv_compressed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, kv_cache_trans->ne[1], q->ne[1], q->ne[2]);
//printf("q->ne[1] = %ld -> need %d steps\n", q->ne[1], n_step);
//printf("Created kqv_compressed = %ld x %ld x %ld\n", kqv_compressed->ne[0], kqv_compressed->ne[1], kqv_compressed->ne[2]);
//printf("kq size would be %ld MiB -> splitting kqv computation into %d steps\n", kq_size, n_step);
for (int i_head = 0; i_head < q->ne[2]; ++i_head) {
ggml_tensor * q_i = ggml_view_2d(ctx0, q, q->ne[0], q->ne[1], q->nb[1], q->nb[2]*i_head);
for (int i_head = 0; i_head < q->ne[2]; i_head += n_per_step) {
int this_ne12 = i_head + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i_head;
ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i_head);
ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i);
kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i);
if (i_head == 0) {
kqv_compressed = ggml_view_3d(ctx0, kqv_i, kqv_i->ne[0], kqv_i->ne[1], 1, kqv_i->nb[1], kqv_i->nb[2], 0);
kqv_compressed = kqv_i;
} else {
kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2);
}
ggml_build_forward_expand(gf, kqv_compressed);
//ggml_tensor * kqv_compressed_i = ggml_view_1d(ctx0, kqv_compressed, ggml_nelements(kqv_i), kqv_compressed->nb[2]*i_head);
//ggml_tensor * head_i = ggml_cpy(ctx0, kqv_i, kqv_compressed_i);
//ggml_build_forward_expand(gf, head_i);
}
//for (int i_step = 0; i_step < n_step; ++i_step) {
// int i_start = i_step * lctx.cparams.attn_max_batch;
// int this_batch = i_start + lctx.cparams.attn_max_batch <= q->ne[1] ? lctx.cparams.attn_max_batch : q->ne[1] - i_start;
// ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], this_batch, q->ne[2], q->nb[1], q->nb[2], i_start*q->nb[1]);
// cb(q_i, "q_i", il);
// ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i);
// cb(kq_i, "kq_i", il);
// ggml_tensor * mask_i = ggml_view_2d(ctx0, KQ_mask, KQ_mask->ne[0], this_batch, KQ_mask->nb[1], i_start*KQ_mask->nb[1]);
// kq_i = ggml_soft_max_ext(ctx0, kq_i, mask_i, kq_scale, hparams.f_max_alibi_bias);
// cb(kq_i, "kq_i_softmwax", il);
// ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i);
// cb(kqv_i, "kqv_i", il);
// ggml_tensor * kqv_compressed_i = ggml_view_3d(ctx0, kqv_compressed, kqv_compressed->ne[0], this_batch, kqv_compressed->ne[2],
// kqv_compressed->nb[1], kqv_compressed->nb[2], i_start*kqv_compressed->nb[1]);
// printf("step %d (%d tokens): kqv_i = %ld x %ld x %ld, kqv_compressed_i = %ld x %ld x %ld\n", i_step, this_batch,
// kqv_i->ne[0], kqv_i->ne[1], kqv_i->ne[2], kqv_compressed_i->ne[0], kqv_compressed_i->ne[1], kqv_compressed_i->ne[2]);
// ggml_cpy(ctx0, kqv_i, kqv_compressed_i);
//}
cb(kqv_compressed, "kqv_compressed", il);
}