Mirror of https://github.com/ikawrakow/ik_llama.cpp.git, synced 2026-05-02 12:21:42 +00:00
Reduce size of compute buffers (#237)
* This reduces compute buffer size for MLA
* This should accomplish it for standard attention
* Much better
* Better concat for contiguous tensors

  If all the op does is concatenate the second tensor to the first, why would we want a loop?

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
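The idea behind the concat change, as a minimal standalone sketch (hypothetical helper, not the repository code): when both inputs and the output are contiguous and the concatenation runs along the slowest-varying dimension that is actually used, the output is simply the bytes of the first tensor followed by the bytes of the second, so two memcpy calls replace the per-row loop.

// Minimal sketch of the contiguous-concat fast path (hypothetical helper,
// not the repository code): dst = src0 followed by src1, as flat byte copies.
#include <cstdio>
#include <cstring>
#include <vector>

static void concat_contiguous(const float * src0, size_t n0,
                              const float * src1, size_t n1, float * dst) {
    std::memcpy(dst,      src0, n0 * sizeof(float)); // first tensor, verbatim
    std::memcpy(dst + n0, src1, n1 * sizeof(float)); // second tensor, right after
}

int main() {
    std::vector<float> a = {1, 2, 3, 4}; // e.g. a 2x2 tensor
    std::vector<float> b = {5, 6};       // e.g. a 2x1 tensor appended along the last used dim
    std::vector<float> out(a.size() + b.size());
    concat_contiguous(a.data(), a.size(), b.data(), b.size(), out.data());
    for (float x : out) std::printf("%g ", x); // prints: 1 2 3 4 5 6
    std::printf("\n");
    return 0;
}

The CUDA and CPU hunks below apply this observation; the `dim == 2 && dst->ne[3] == 1` and `dim == 1 && dst->ne[2]*dst->ne[3] == 1` cases cover concatenation along the highest dimension that is actually used.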
@@ -855,6 +855,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.mla_attn = std::stoi(argv[i]);
         return true;
     }
+    if (arg == "-amb" || arg == "--attention-max-batch") {
+        CHECK_ARG
+        params.attn_max_batch = std::stoi(argv[i]);
+        return true;
+    }
     if (arg == "-fmoe" || arg == "--fused-moe") {
         params.fused_moe_up_gate = true;
         return true;

@@ -1516,6 +1521,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks });
     options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" });
     options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn });
+    options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch});
     options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" });
     options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
                         "in conversation mode, this will be used as system prompt\n"

@@ -2360,6 +2366,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.offload_kqv = !params.no_kv_offload;
     cparams.flash_attn = params.flash_attn;
     cparams.mla_attn = params.mla_attn;
+    cparams.attn_max_batch = params.attn_max_batch;
     cparams.fused_moe_up_gate = params.fused_moe_up_gate;

     cparams.type_k = kv_cache_type_from_str(params.cache_type_k);

@@ -3359,6 +3366,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
     fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
     fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn);
+    fprintf(stream, "attn_max_batch: %d # default: 0\n", params.attn_max_batch);
     fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false");
     fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);

@@ -175,7 +175,8 @@ struct gpt_params {
     bool simple_io = false; // improves compatibility with subprocesses and limited consoles
     bool cont_batching = true; // insert new sequences for decoding on-the-fly
     bool flash_attn = false; // flash attention
-    int mla_attn = false; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache
+    int mla_attn = 0; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache
+    int attn_max_batch = 0; // Max batch size to use when computing attention (only applicable if flash_attn = false)
     bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models

     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix

@@ -233,6 +233,7 @@ struct cmd_params {
     std::vector<bool> no_kv_offload;
     std::vector<bool> flash_attn;
     std::vector<int> mla_attn;
+    std::vector<int> attn_max_batch;
     std::vector<std::vector<float>> tensor_split;
     std::vector<bool> use_mmap;
     std::vector<bool> embeddings;

@@ -265,6 +266,7 @@ static const cmd_params cmd_params_defaults = {
     /* no_kv_offload */ {false},
     /* flash_attn    */ {false},
     /* mla_attn      */ {0},
+    /* attn_max_batch*/ {0},
     /* tensor_split  */ {std::vector<float>(llama_max_devices(), 0.0f)},
     /* use_mmap      */ {true},
     /* embeddings    */ {false},

@@ -301,6 +303,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
     printf("  -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
     printf("  -mla, --mla-attn <0|1|2> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str());
+    printf("  -amb, --attn-max-batch <i> (default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str());
     printf("  -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
     printf("  --numa <distribute|isolate|numactl> (default: disabled)\n");
     printf("  -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());

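As an illustration only (hypothetical model path and values), the new flag combines with the existing ones and, like the other llama-bench options, accepts a comma-separated list of values to sweep:

    llama-bench -m model.gguf -fa 0 -mla 0,2 -amb 0,512
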
@@ -578,6 +581,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
         }
         auto p = string_split<int>(argv[i], split_delim);
         params.mla_attn.insert(params.mla_attn.end(), p.begin(), p.end());
+    } else if (arg == "-amb" || arg == "--attn-max-batch") {
+        if (++i >= argc) {
+            invalid_param = true;
+            break;
+        }
+        auto p = string_split<int>(argv[i], split_delim);
+        params.attn_max_batch.insert(params.attn_max_batch.end(), p.begin(), p.end());
     } else if (arg == "-mmp" || arg == "--mmap") {
         if (++i >= argc) {
             invalid_param = true;

@@ -690,6 +700,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; }
     if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; }
     if (params.mla_attn.empty()) { params.mla_attn = cmd_params_defaults.mla_attn; }
+    if (params.attn_max_batch.empty()){ params.attn_max_batch = cmd_params_defaults.attn_max_batch; }
     if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
     if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; }
     if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; }

@@ -727,6 +738,7 @@ struct cmd_params_instance {
     bool no_kv_offload;
     bool flash_attn;
     int mla_attn;
+    int attn_max_batch;
     std::vector<float> tensor_split;
     bool use_mmap;
     bool embeddings;

@@ -773,6 +785,7 @@ struct cmd_params_instance {
         cparams.offload_kqv = !no_kv_offload;
         cparams.flash_attn = flash_attn;
         cparams.mla_attn = mla_attn;
+        cparams.attn_max_batch = attn_max_batch;
         cparams.fused_moe_up_gate = fmoe;
         cparams.embeddings = embeddings;

@@ -799,6 +812,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
     for (const auto & nkvo : params.no_kv_offload)
     for (const auto & fa : params.flash_attn)
     for (const auto & mla : params.mla_attn)
+    for (const auto & amb : params.attn_max_batch)
     for (const auto & nt : params.n_threads) {
         for (const auto & n_prompt : params.n_prompt) {
             if (n_prompt == 0) {

@@ -821,6 +835,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
             /* .no_kv_offload= */ nkvo,
             /* .flash_attn   = */ fa,
             /* .mla_attn     = */ mla,
+            /* .attn_max_b   = */ amb,
             /* .tensor_split = */ ts,
             /* .use_mmap     = */ mmp,
             /* .embeddings   = */ embd,

@@ -852,6 +867,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
             /* .no_kv_offload= */ nkvo,
             /* .flash_attn   = */ fa,
             /* .mla_attn     = */ mla,
+            /* .attn_max_b   = */ amb,
             /* .tensor_split = */ ts,
             /* .use_mmap     = */ mmp,
             /* .embeddings   = */ embd,

@@ -883,6 +899,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
             /* .no_kv_offload= */ nkvo,
             /* .flash_attn   = */ fa,
             /* .mla_attn     = */ mla,
+            /* .attn_max_b   = */ amb,
             /* .tensor_split = */ ts,
             /* .use_mmap     = */ mmp,
             /* .embeddings   = */ embd,

@@ -914,6 +931,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
             /* .no_kv_offload= */ nkvo,
             /* .flash_attn   = */ fa,
             /* .mla_attn     = */ mla,
+            /* .attn_max_b   = */ amb,
             /* .tensor_split = */ ts,
             /* .use_mmap     = */ mmp,
             /* .embeddings   = */ embd,

@@ -956,6 +974,7 @@ struct test {
     bool no_kv_offload;
     bool flash_attn;
     int mla_attn;
+    int attn_max_batch;
     std::vector<float> tensor_split;
     bool use_mmap;
     bool embeddings;

@@ -987,6 +1006,7 @@ struct test {
         no_kv_offload = inst.no_kv_offload;
         flash_attn = inst.flash_attn;
         mla_attn = inst.mla_attn;
+        attn_max_batch = inst.attn_max_batch;
         tensor_split = inst.tensor_split;
         use_mmap = inst.use_mmap;
         embeddings = inst.embeddings;

@@ -1081,7 +1101,7 @@ struct test {
         "n_batch", "n_ubatch",
         "n_threads", "type_k", "type_v",
         "n_gpu_layers", "split_mode",
-        "main_gpu", "no_kv_offload", "flash_attn", "mla_attn",
+        "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch",
         "tensor_split", "use_mmap", "embeddings", "repack", "fused_moe",
         "n_prompt", "n_gen", "test_time",
         "avg_ns", "stddev_ns",

@@ -1097,7 +1117,7 @@ struct test {
             field == "n_threads" ||
             field == "model_size" || field == "model_n_params" ||
             field == "n_gpu_layers" || field == "main_gpu" ||
-            field == "n_prompt" || field == "n_gen" || field == "mla_attn" ||
+            field == "n_prompt" || field == "n_gen" || field == "mla_attn" || field == "attn_max_batch" ||
             field == "avg_ns" || field == "stddev_ns") {
             return INT;
         }

||||
@@ -1138,7 +1158,7 @@ struct test {
|
||||
std::to_string(n_batch), std::to_string(n_ubatch),
|
||||
std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
|
||||
std::to_string(n_gpu_layers), split_mode_str(split_mode),
|
||||
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn),
|
||||
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn), std::to_string(attn_max_batch),
|
||||
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), std::to_string(fmoe),
|
||||
std::to_string(n_prompt), std::to_string(n_gen), test_time,
|
||||
std::to_string(avg_ns()), std::to_string(stdev_ns()),
|
||||
@@ -1305,6 +1325,9 @@ struct markdown_printer : public printer {
|
||||
if (field == "mla_attn") {
|
||||
return 3;
|
||||
}
|
||||
if (field == "attn_max_batch") {
|
||||
return 5;
|
||||
}
|
||||
if (field == "use_mmap") {
|
||||
return 4;
|
||||
}
|
||||
@@ -1345,6 +1368,9 @@ struct markdown_printer : public printer {
         if (field == "mla_attn") {
             return "mla";
         }
+        if (field == "attn_max_batch") {
+            return "amb";
+        }
         if (field == "use_mmap") {
             return "mmap";
         }

@@ -1403,6 +1429,9 @@ struct markdown_printer : public printer {
         if (params.mla_attn.size() > 1 || params.mla_attn != cmd_params_defaults.mla_attn) {
             fields.emplace_back("mla_attn");
         }
+        if (params.attn_max_batch.size() > 1 || params.attn_max_batch != cmd_params_defaults.attn_max_batch) {
+            fields.emplace_back("attn_max_batch");
+        }
         if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
             fields.emplace_back("tensor_split");
         }

@@ -164,7 +164,12 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

     float * dst_d = (float *)dst->data;

-        if (dim != 3) {
+        if (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1)) {
+            const size_t size0 = ggml_nbytes(src0);
+            const size_t size1 = ggml_nbytes(src1);
+            CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
+            CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
+        } else {
             for (int i3 = 0; i3 < dst->ne[3]; i3++) {
                 concat_f32_cuda(
                     src0_d + i3 * (src0->nb[3] / 4),

@@ -173,13 +178,24 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
                     src0->ne[0], src0->ne[1], src0->ne[2],
                     dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
             }
-        } else {
-            const size_t size0 = ggml_nbytes(src0);
-            const size_t size1 = ggml_nbytes(src1);
-
-            CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
-            CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
-        }
+        }
+
+        //if (dim != 3) {
+        //    for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+        //        concat_f32_cuda(
+        //            src0_d + i3 * (src0->nb[3] / 4),
+        //            src1_d + i3 * (src1->nb[3] / 4),
+        //            dst_d + i3 * ( dst->nb[3] / 4),
+        //            src0->ne[0], src0->ne[1], src0->ne[2],
+        //            dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
+        //    }
+        //} else {
+        //    const size_t size0 = ggml_nbytes(src0);
+        //    const size_t size1 = ggml_nbytes(src1);
+
+        //    CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
+        //    CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
+        //}
     } else {
         dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
         concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(

@@ -12627,6 +12627,26 @@ static void ggml_compute_forward_concat_f32(

    GGML_ASSERT(dim >= 0 && dim < 4);

+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst) &&
+        (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1))) {
+        // simply copy the data
+        const int64_t size_src_0 = ggml_nbytes(src0);
+        const int64_t size_src_1 = ggml_nbytes(src1);
+        const int64_t block_size = 4096;
+        const int64_t num_blocks = (size_src_0 + size_src_1 + block_size - 1)/block_size;
+        for (int64_t i_block = ith; i_block < num_blocks; i_block += nth) {
+            const int64_t start = i_block*block_size;
+            if (start < size_src_0) {
+                int64_t copy_size = MIN(block_size, size_src_0 - start);
+                memcpy((char *)dst->data + start, (char *)src0->data + start, copy_size);
+            } else {
+                int64_t copy_size = MIN(block_size, size_src_0 + size_src_1 - start);
+                memcpy((char *)dst->data + start, (char *)src1->data + start - size_src_0, copy_size);
+            }
+        }
+        return;
+    }
+
    int64_t o[4] = {0, 0, 0, 0};
    o[dim] = src0->ne[dim];

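To see why the block arithmetic in this hunk works, here is a standalone sketch (hypothetical names, not the ggml source) that stripes 4096-byte blocks across nth workers and picks the source tensor by where a block starts; the sketch additionally splits a block that straddles the src0/src1 boundary into two copies, which keeps it self-checking for arbitrary sizes:

// Standalone sketch of a thread-striped blocked copy (hypothetical names,
// not the ggml source). Worker ith of nth handles every nth 4096-byte
// block; a block's start offset selects the source tensor.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

static void copy_blocks(const std::vector<char> & src0, const std::vector<char> & src1,
                        std::vector<char> & dst, int ith, int nth) {
    const int64_t size0 = (int64_t) src0.size();
    const int64_t size1 = (int64_t) src1.size();
    const int64_t block_size = 4096;
    const int64_t num_blocks = (size0 + size1 + block_size - 1) / block_size;
    for (int64_t i_block = ith; i_block < num_blocks; i_block += nth) {
        const int64_t start = i_block * block_size;
        if (start < size0) {
            // block begins inside src0; clamp to src0's end ...
            const int64_t n = std::min(block_size, size0 - start);
            std::memcpy(dst.data() + start, src0.data() + start, (size_t) n);
            // ... and, if the block runs past that end, take the rest from src1
            const int64_t rest = std::min(block_size - n, size1);
            if (rest > 0) std::memcpy(dst.data() + start + n, src1.data(), (size_t) rest);
        } else {
            // block begins inside src1; shift the read offset back by size0
            const int64_t n = std::min(block_size, size0 + size1 - start);
            std::memcpy(dst.data() + start, src1.data() + (start - size0), (size_t) n);
        }
    }
}

int main() {
    std::vector<char> a(10000, 'a'), b(5000, 'b'), out(15000, '?');
    const int nth = 4; // simulate four worker threads sequentially
    for (int ith = 0; ith < nth; ++ith) copy_blocks(a, b, out, ith, nth);
    assert(out[9999] == 'a' && out[10000] == 'b' && out[14999] == 'b');
    return 0;
}

With size0 = 10000, for example, block 2 starts at 8192 and copies 1808 bytes from src0 (plus, in this sketch, the following 2288 bytes from src1), while block 3 starts at 12288 and reads src1 at offset 2288.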
@@ -384,6 +384,7 @@ extern "C" {
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
         bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
         int  mla_attn;    // whether to use MLA attention [EXPERIMENTAL]
+        int  attn_max_batch; // maximum batch size for attention computations [EXPERIMENTAL]
         bool fused_moe_up_gate; // whether to use fused MoE up/down op [EXPERIMENTAL]

         // Abort callback

src/llama.cpp: 122 changed lines
@@ -2511,6 +2511,7 @@ struct llama_cparams {
     bool offload_kqv;
     bool flash_attn;
     int mla_attn;
+    int attn_max_batch;
     bool fused_moe_up_gate;

     enum llama_pooling_type pooling_type;

@@ -8774,6 +8775,18 @@ static struct ggml_tensor * llm_build_kqv(

         cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
     } else {
+
+        // split cached v into n_head heads
+        struct ggml_tensor * v =
+            ggml_view_3d(ctx, kv.v_l[il],
+                    n_kv, n_embd_head_v, n_head_kv,
+                    ggml_element_size(kv.v_l[il])*n_ctx,
+                    ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
+                    0);
+        cb(v, "v", il);
+
+        auto kq_size = k->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024);
+        if (cparams.attn_max_batch == 0 || cparams.attn_max_batch >= kq_size || k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2]) {
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);

@@ -8812,15 +8825,6 @@ static struct ggml_tensor * llm_build_kqv(

         GGML_ASSERT(kv.size == n_ctx);

-        // split cached v into n_head heads
-        struct ggml_tensor * v =
-            ggml_view_3d(ctx, kv.v_l[il],
-                    n_kv, n_embd_head_v, n_head_kv,
-                    ggml_element_size(kv.v_l[il])*n_ctx,
-                    ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
-                    0);
-        cb(v, "v", il);
-
         struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
         cb(kqv, "kqv", il);

@@ -8830,6 +8834,50 @@ static struct ggml_tensor * llm_build_kqv(
         cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens);
         cb(cur, "kqv_merged_cont", il);
     }
+    else {
+        // For now we will not support this option if k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2];
+        GGML_ASSERT(k->ne[2] == v->ne[2] && k->ne[2] == q->ne[2]);
+        int n_step = (kq_size + cparams.attn_max_batch - 1)/cparams.attn_max_batch;
+        n_step = std::min(n_step, int(k->ne[2]));
+        int n_per_step = (q->ne[2] + n_step - 1)/n_step;
+        auto r2k = q->ne[2] / k->ne[2];
+        auto r2v = q->ne[2] / v->ne[2];
+        n_step = q->ne[2];
+        n_per_step = 1;
+        ggml_tensor * kqv;
+        for (int i12 = 0; i12 < q->ne[2]; i12 += n_per_step) {
+            int this_ne12 = i12 + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i12;
+            int i02 = i12/r2k;
+            auto k_i = ggml_view_3d(ctx, k, k->ne[0], k->ne[1], this_ne12, k->nb[1], k->nb[2], k->nb[2]*i02);
+            auto q_i = ggml_view_3d(ctx, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i12);
+            auto kq_i = ggml_mul_mat(ctx, k_i, q_i);
+            if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
+                ggml_mul_mat_set_prec(kq_i, GGML_PREC_F32);
+            }
+            if (model.arch == LLM_ARCH_GROK) {
+                kq_i = ggml_softcap(ctx, kq_i, 0.08838834764831845f/30.0f, 30.f);
+            }
+            if (hparams.attn_soft_cap) {
+                kq_i = ggml_softcap_max(ctx, kq_i, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+                        1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
+            } else {
+                kq_i = ggml_soft_max_ext(ctx, kq_i, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+            }
+            i02 = i12 / r2v;
+            auto v_i = ggml_view_3d(ctx, v, v->ne[0], v->ne[1], this_ne12, v->nb[1], v->nb[2], v->nb[2]*i02);
+            auto kqv_i = ggml_mul_mat(ctx, v_i, kq_i);
+            if (i12 == 0) {
+                kqv = kqv_i;
+            } else {
+                kqv = ggml_concat(ctx, kqv, kqv_i, 2);
+            }
+        }
+        ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
+        cb(kqv_merged, "kqv_merged", il);
+        cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens);
+        cb(cur, "kqv_merged_cont", il);
+    }
     }

     ggml_build_forward_expand(graph, cur);

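To make the splitting arithmetic above concrete, a worked example with assumed shapes (not taken from the commit): the f32 K*Q tensor holds n_kv * n_tokens * n_head floats, and -amb (in MiB) caps how much of it is materialized per step. Note that after computing n_step from the budget, this version of the code pins n_per_step to 1, i.e. one head per step; the same computation appears again in the MLA branch further below.

// Worked illustration (hypothetical sizes) of how attn_max_batch ("amb")
// bounds the K*Q compute buffer and turns into a step count.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t n_kv = 32768, n_tokens = 1024, n_head = 32; // assumed shapes
    const int64_t amb = 512;                                  // attn_max_batch in MiB

    // size of the full f32 K*Q tensor in MiB
    const int64_t kq_size = n_kv * n_tokens * n_head * (int64_t) sizeof(float) / (1024 * 1024);
    int n_step = (int) ((kq_size + amb - 1) / amb); // ceil(kq_size / amb)
    n_step = std::min(n_step, (int) n_head);        // cannot split finer than one head
    const int n_per_step = (int) ((n_head + n_step - 1) / n_step);

    std::printf("K*Q would take %lld MiB -> %d steps of %d head(s) each\n",
                (long long) kq_size, n_step, n_per_step);
    // prints: K*Q would take 4096 MiB -> 8 steps of 4 head(s) each
    return 0;
}
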
@@ -8924,6 +8972,7 @@ struct llm_build_context {

     const bool flash_attn;
     const int mla_attn;
+    const int attn_max_batch;
     const bool fused_moe_up_gate;

     const enum llama_pooling_type pooling_type;

@@ -8976,6 +9025,7 @@ struct llm_build_context {
         n_ctx_orig (cparams.n_ctx_orig_yarn),
         flash_attn (cparams.flash_attn),
         mla_attn (cparams.mla_attn),
+        attn_max_batch (cparams.attn_max_batch),
         fused_moe_up_gate(cparams.fused_moe_up_gate),
         pooling_type (cparams.pooling_type),
         rope_type (hparams.rope_type),

@@ -13572,10 +13622,25 @@ struct llm_build_context {

         ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0);
         cb(q, "q", il);

+        if (lctx.cparams.mla_attn > 1) {
+            ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il],
+                    kv_lora_rank, n_kv,
+                    ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0);
+            cb(kv_cache, "kv_cache_lora", il);
+
+            kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora));
+            cb(kv_cache_trans, "kv_cache_trans", il);
+        }
+
+        ggml_tensor * kqv_compressed;
+        auto kq_size = kv_cache->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); // K*Q in MiB
+        if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kq_size) {
         if (!pp_opt) {
             q = ggml_permute(ctx0, q, 0, 2, 1, 3);
             cb(q, "q_perm", il);
         }

         ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
         cb(kq, "kq", il);

@@ -13592,17 +13657,7 @@ struct llm_build_context {
             cb(kq, "kq_soft_max_ext_perm", il);
         }

-        if (lctx.cparams.mla_attn > 1) {
-            ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il],
-                    kv_lora_rank, n_kv,
-                    ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0);
-            cb(kv_cache, "kv_cache_lora", il);
-
-            kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora));
-            cb(kv_cache_trans, "kv_cache_trans", il);
-        }
-
-        struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
+        kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
         cb(kqv_compressed, "kqv_compressed", il);

         if (!pp_opt) {

@@ -13610,6 +13665,30 @@ struct llm_build_context {
             cb(kqv_compressed, "kqv_compressed_perm", il);
         }

+        } else {
+
+            int n_step = (kq_size + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch;
+            n_step = std::min(n_step, int(q->ne[2]));
+            int n_per_step = (q->ne[2] + n_step - 1)/n_step;
+
+            //printf("kq size would be %ld MiB -> splitting kqv computation into %d steps\n", kq_size, n_step);
+
+            for (int i_head = 0; i_head < q->ne[2]; i_head += n_per_step) {
+                int this_ne12 = i_head + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i_head;
+                ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i_head);
+                ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i);
+                kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
+                ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i);
+                if (i_head == 0) {
+                    kqv_compressed = kqv_i;
+                } else {
+                    kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2);
+                }
+                ggml_build_forward_expand(gf, kqv_compressed);
+            }
+            cb(kqv_compressed, "kqv_compressed", il);
+        }

         struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head,
                 ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank),
                 ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank)*n_embd_head_v, 0);

@@ -17644,6 +17723,7 @@ struct llama_context_params llama_context_default_params() {
         /*.offload_kqv =*/ true,
         /*.flash_attn =*/ false,
         /*.mla_attn =*/ 0,
+        /*.attn_max_batch =*/ 0,
         /*.fused_moe_up_gate =*/ false,
         /*.abort_callback =*/ nullptr,
         /*.abort_callback_data =*/ nullptr,

@@ -17844,6 +17924,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.offload_kqv = params.offload_kqv;
     cparams.flash_attn = params.flash_attn;
     cparams.mla_attn = params.mla_attn;
+    cparams.attn_max_batch = params.attn_max_batch;
     cparams.fused_moe_up_gate= params.fused_moe_up_gate;
     cparams.pooling_type = params.pooling_type;

@@ -17912,6 +17993,7 @@ struct llama_context * llama_new_context_with_model(
     LLAMA_LOG_INFO("%s: n_ubatch   = %u\n", __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
     LLAMA_LOG_INFO("%s: mla_attn   = %d\n", __func__, cparams.mla_attn);
+    LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch);
     LLAMA_LOG_INFO("%s: fused_moe  = %d\n", __func__, cparams.fused_moe_up_gate);
     LLAMA_LOG_INFO("%s: freq_base  = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);