diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index 68dfe511..73d71c11 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -6489,17 +6489,6 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
         auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.k_l[il], kv_lora_rank, n_kv, kv_self.k_l[il]->nb[1],
                 ggml_row_size(kv_self.k_l[il]->type, n_embd_head_qk_rope));
 
-        auto kv_f32_size = model.layers[il].wkv_b->ne[1] * kv_cache_nope->ne[1] * sizeof(float) / (1024*1024);
-        int n_max_head = n_head;
-        if (cparams.attn_max_batch > 0 && kv_f32_size > cparams.attn_max_batch) {
-            while (n_max_head%2 == 0 && kv_f32_size > cparams.attn_max_batch) {
-                n_max_head /= 2; kv_f32_size /= 2;
-            }
-        }
-        GGML_ASSERT(n_head % n_max_head == 0);
-
-        auto n_per_head = model.layers[il].wkv_b->ne[1] / n_head;
-
         auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.k_l[il], n_embd_head_qk_rope, n_kv, 1,
                 kv_self.k_l[il]->nb[1], kv_self.k_l[il]->nb[2], 0); //ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank));
 
@@ -6509,6 +6498,109 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
         // if the build is with CUDA enabled.
         auto kv_type = lctx.backends.size() == 1 && lctx.backends.front() == lctx.backend_cpu ? kv_self.k_l[il]->type : GGML_TYPE_F16;
 
+        //if (kv_cache_rope->type != kv_type) {
+        //    kv_cache_rope = ggml_cast(ctx0, kv_cache_rope, GGML_TYPE_F16);
+        //}
+
+        //auto q = ggml_concat(ctx0, q_nope, q_rope, 0);
+        auto q = ggml_concat(ctx0, q_rope, q_nope, 0);
+        q = ggml_permute(ctx0, q, 0, 2, 1, 3);
+        cb(q, "q_concat", il);
+        ggml_build_forward_expand(gf, q);
+
+        auto n_per_head = model.layers[il].wkv_b->ne[1] / n_head;
+
+        if (auto & wkv_b_per_device = model.layers[il].wkv_b_per_device; wkv_b_per_device.size() > 1) {
+            ggml_tensor repeater;
+            repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[3] = 1;
+            repeater.ne[2] = wkv_b_per_device[0]->ne[1] / n_per_head;
+            ggml_tensor * k_rope;
+            if (kv_cache_rope->type == kv_type) {
+                k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater);
+            } else {
+                auto kv_cache_rope_f16 = ggml_cast(ctx0, kv_cache_rope, GGML_TYPE_F16);
+                k_rope = ggml_repeat(ctx0, kv_cache_rope_f16, &repeater);
+            }
+            cb(k_rope, "k_rope", il);
+            std::vector<ggml_tensor *> results(wkv_b_per_device.size());
+            for (int id = 0; id < int(wkv_b_per_device.size()); ++id) {
+                int il_id = 1000*il + id;
+                auto wkv_b = wkv_b_per_device[id].get();
+                GGML_ASSERT(wkv_b);
+                GGML_ASSERT(wkv_b->ne[1] % n_per_head == 0);
+                auto this_nhead = wkv_b->ne[1] / n_per_head;
+
+                auto kv_f32 = ggml_mul_mat(ctx0, wkv_b, kv_cache_nope);
+                cb(kv_f32, "kv_f32", il_id);
+                kv_f32->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff;
+
+                auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, this_nhead,
+                        ggml_row_size(kv_f32->type, this_nhead*(n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
+                        ggml_row_size(kv_f32->type, n_embd_head_qk_nope));
+                cb(v_f32, "v_f32", il);
+
+                auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, this_nhead,
+                        ggml_row_size(kv_f32->type, this_nhead*(n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
+                cb(k_nope_f32, "k_nope_f32", il_id);
+
+                auto v = ggml_cast(ctx0, v_f32, kv_type);
+                cb(v, "v", il_id);
+
+                auto k_nope = ggml_cast(ctx0, k_nope_f32, kv_type);
+                cb(k_nope, "k_nope", il_id);
+
+                //ggml_build_forward_expand(gf, k_nope);
+                //ggml_build_forward_expand(gf, v);
+
+                if (repeater.ne[2] != this_nhead) {
+                    repeater.ne[2] = this_nhead;
+                    if (kv_cache_rope->type == kv_type) {
+                        k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater);
+                    } else {
+                        auto kv_cache_rope_f16 = ggml_cast(ctx0, kv_cache_rope, GGML_TYPE_F16);
+                        k_rope = ggml_repeat(ctx0, kv_cache_rope_f16, &repeater);
+                    }
+                    cb(k_rope, "k_rope", il_id);
+                }
+
+                auto k = ggml_concat(ctx0, k_rope, k_nope, 0);
+                cb(k, "k", il_id);
+
+                //ggml_build_forward_expand(gf, k);
+
+                auto q_iter = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], this_nhead,
+                        q->nb[1], q->nb[2], q->nb[2]*this_nhead*id);
+
+                kqv = ggml_flash_attn_ext(ctx0, q_iter, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
+                if (use_f32_attn_precision || q->ne[1] <= 8) {
+                    ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
+                }
+                cb(kqv, "kqv", il_id);
+
+                results[id] = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*this_nhead, n_tokens);
+                //results[id] = ggml_cast(ctx0, results[id], GGML_TYPE_F16);
+                ggml_build_forward_expand(gf, results[id]);
+
+            }
+            cur = ggml_concat(ctx0, results[0], results[1], 0);
+            cur->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff;
+            for (int id = 2; id < int(wkv_b_per_device.size()); ++id) {
+                cur = ggml_concat(ctx0, cur, results[id], 0);
+            }
+
+        } else {
+
+        auto kv_f32_size = model.layers[il].wkv_b->ne[1] * kv_cache_nope->ne[1] * sizeof(float) / (1024*1024);
+        int n_max_head = n_head;
+        if (cparams.attn_max_batch > 0 && kv_f32_size > cparams.attn_max_batch) {
+            while (n_max_head%2 == 0 && kv_f32_size > cparams.attn_max_batch) {
+                n_max_head /= 2; kv_f32_size /= 2;
+            }
+        }
+        GGML_ASSERT(n_head % n_max_head == 0);
+
         ggml_tensor repeater;
         repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_max_head; repeater.ne[3] = 1;
         ggml_tensor * k_rope;
@@ -6520,13 +6612,6 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
         }
         cb(k_rope, "k_rope", il);
-        //auto q = ggml_concat(ctx0, q_nope, q_rope, 0);
-        auto q = ggml_concat(ctx0, q_rope, q_nope, 0);
-        q = ggml_permute(ctx0, q, 0, 2, 1, 3);
-        cb(q, "q_concat", il);
-
-        ggml_build_forward_expand(gf, q);
-
         for (int iter = 0; iter < n_head/n_max_head; ++iter) {
             auto wkv_b = ggml_view_2d(ctx0, model.layers[il].wkv_b, model.layers[il].wkv_b->ne[0], n_per_head*n_max_head,
@@ -6577,6 +6662,7 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
             }
         }
+        }
     } else {
diff --git a/src/llama-model.h b/src/llama-model.h
index 4667193e..c1af3885 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -317,6 +317,8 @@ struct llama_layer {
     std::unique_ptr<ggml_tensor> computed_wk_b;
     std::unique_ptr<ggml_tensor> computed_wv_b;
     std::unique_ptr<ggml_tensor> computed_wkv_b;
+
+    std::vector<std::unique_ptr<ggml_tensor>> wkv_b_per_device;
 };
 
 struct llama_lora_adapter;
diff --git a/src/llama.cpp b/src/llama.cpp
index 54952e25..d6f4ad90 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -1687,6 +1687,34 @@ static void llm_prepare_mla(llama_model & model, int mla) {
         printf("Computed %s as %ld x %ld and stored in buffer %s\n", name.c_str(), wkv_b->ne[0], wkv_b->ne[1],
                 ggml_backend_buffer_name(l.computed_wkv_b->buffer));
 
+        if (int n_device = model.devices.size(); n_device > 1) {
+            l.wkv_b_per_device.reserve(n_device);
+            int nh_per_device = (n_head + n_device - 1)/n_device;
+            int n_per_head = wkv_b->ne[1] / n_head;
+            auto ptr = (const char *)wkv_b->data;
+            for (int id = 0; id < int(model.devices.size()); ++id) {
+                int this_nh = std::min(nh_per_device, n_head - id*nh_per_device);
+                if (this_nh <= 0) break;
+                auto wkv_b_id = std::make_unique<ggml_tensor>(*wkv_b);
+                wkv_b_id->ne[1] = this_nh * n_per_head;
+                wkv_b_id->nb[2] = wkv_b_id->nb[3] = wkv_b_id->ne[1]*wkv_b_id->nb[1];
+                auto buft = llama_default_buffer_type_offload(model, model.devices[id]);
+                auto nbytes = ggml_nbytes(wkv_b_id.get());
+                wkv_b_id->buffer = ggml_backend_buft_alloc_buffer(buft, nbytes);
+                wkv_b_id->data = ggml_backend_buffer_get_base(wkv_b_id->buffer);
+                wkv_b_id->op = GGML_OP_NONE;
+                for (int j = 0; j < GGML_MAX_SRC; ++j) wkv_b_id->src[j] = nullptr;
+                auto this_name = name + '.' + std::to_string(id);
+                ggml_set_name(wkv_b_id.get(), this_name.c_str());
+                ggml_backend_buffer_set_usage(wkv_b_id->buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+                ggml_backend_tensor_set(wkv_b_id.get(), ptr, 0, nbytes);
+                printf(" Stored %s (%ld x %ld) on device %s\n", wkv_b_id->name, wkv_b_id->ne[0], wkv_b_id->ne[1],
+                        ggml_backend_buffer_name(wkv_b_id->buffer));
+                l.wkv_b_per_device.emplace_back(std::move(wkv_b_id));
+                ptr += nbytes;
+            }
+        }
+
         ggml_graph_clear(graph);
     }
     ggml_free(ctx);
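
A note on the stack-allocated `repeater` in build_deepseek2(): it never enters the graph. `ggml_repeat` only inspects the target's ne[] dimensions when building the node, so filling in the four ne fields is enough to broadcast the shared RoPE slice (ne[2] == 1) across a device's head count, without the patch having to allocate a real target tensor. A minimal standalone sketch of the same broadcast, here with a properly allocated target; all dimensions are made up, only the shape-target semantics of `ggml_repeat` are the point:

    #include "ggml.h"
    #include <cstdio>

    int main() {
        // Small scratch context; size is illustrative only.
        ggml_init_params params = { 16*1024*1024, nullptr, false };
        ggml_context * ctx = ggml_init(params);

        // a has a single "head" plane: [rope_dim, n_kv, 1].
        ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 8, 1);
        // target only carries the broadcast shape: [rope_dim, n_kv, n_head_on_device].
        ggml_tensor * target = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 8, 16);
        // ggml_repeat() reads target->ne[] and nothing else from target, which is
        // why the patch can get away with an uninitialized stack tensor instead.
        ggml_tensor * r = ggml_repeat(ctx, a, target);

        std::printf("repeat: [%lld, %lld, %lld]\n",
                    (long long)r->ne[0], (long long)r->ne[1], (long long)r->ne[2]);
        ggml_free(ctx);
        return 0;
    }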
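The single-device fallback keeps the old `attn_max_batch` behavior: it halves the number of heads processed per pass until the f32 intermediate `kv_f32` fits the configured budget (interpreted here as MiB), which only works while the head count stays even. The same loop in isolation, with invented sizes; nothing below comes from the patch except the halving logic:

    #include <cassert>
    #include <cstdio>

    int main() {
        // Made-up sizes standing in for wkv_b->ne[1], n_kv, etc.
        long long wkv_b_rows     = 128 * (128 + 128); // n_head * (qk_nope + v) rows
        long long n_kv           = 65536;             // cached tokens
        int       n_head         = 128;
        long long attn_max_batch = 1024;              // budget in MiB

        // Halve the per-pass head count until the f32 scratch fits the budget.
        auto kv_f32_size = wkv_b_rows * n_kv * (long long)sizeof(float) / (1024*1024);
        int n_max_head = n_head;
        if (attn_max_batch > 0 && kv_f32_size > attn_max_batch) {
            while (n_max_head % 2 == 0 && kv_f32_size > attn_max_batch) {
                n_max_head /= 2; kv_f32_size /= 2;
            }
        }
        assert(n_head % n_max_head == 0);
        std::printf("%d heads per pass, ~%lld MiB scratch\n", n_max_head, kv_f32_size);
        return 0;
    }

With these numbers the 8192 MiB scratch is halved three times, landing on 16 heads per pass at 1024 MiB.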
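Finally, llm_prepare_mla() splits wkv_b across devices by whole heads with a plain ceil-divide: every device except possibly the last gets ceil(n_head / n_device) heads, and surplus devices receive nothing. A standalone sketch of that arithmetic; `HeadSlice` and `split_heads` are illustrative names, not part of the patch:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Mirrors the nh_per_device / this_nh computation in llm_prepare_mla.
    struct HeadSlice { int first_head, n_heads; };

    static std::vector<HeadSlice> split_heads(int n_head, int n_device) {
        std::vector<HeadSlice> slices;
        int nh_per_device = (n_head + n_device - 1) / n_device;  // ceil divide
        for (int id = 0; id < n_device; ++id) {
            int this_nh = std::min(nh_per_device, n_head - id*nh_per_device);
            if (this_nh <= 0) break;  // more devices than head groups
            slices.push_back({id*nh_per_device, this_nh});
        }
        return slices;
    }

    int main() {
        // e.g. 128 heads on 3 devices -> 43 + 43 + 42
        for (auto [first, n] : split_heads(128, 3)) {
            std::printf("heads [%d, %d)\n", first, first + n);
        }
        return 0;
    }

Because wkv_b's rows are laid out head-major, each slice is a contiguous byte range of the source tensor, which is what lets the copy loop simply advance `ptr` by `ggml_nbytes()` per device.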