diff --git a/common/common.cpp b/common/common.cpp
index 71acd0f4..9f1ce736 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1276,6 +1276,9 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         else if (arg_next == "layer") {
             params.split_mode = LLAMA_SPLIT_MODE_LAYER;
         }
+        else if (arg_next == "attn") {
+            params.split_mode = LLAMA_SPLIT_MODE_ATTN;
+        }
         else if (arg_next == "graph") {
             params.split_mode = LLAMA_SPLIT_MODE_GRAPH;
         }
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index ddcfbc38..2372efdb 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2989,6 +2989,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
                 ggml_is_contiguous(dst->src[0]) &&
                 ggml_is_contiguous(dst->src[1]) &&
+                dst->src[0]->type == GGML_TYPE_F32 && // with split mode "attn" we can end up having f16
                 ggml_are_same_shape(dst->src[0], dst->src[1]) &&
                 dst == cgraph->nodes[i+1]->src[0] &&
                 ggml_is_contiguous(cgraph->nodes[i+1]->src[1]) &&
diff --git a/include/llama.h b/include/llama.h
index 6c4cb042..3c9b331c 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -275,7 +275,8 @@ extern "C" {

    enum llama_split_mode {
        LLAMA_SPLIT_MODE_NONE  = 0, // single GPU
        LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
-       LLAMA_SPLIT_MODE_GRAPH = 2, // splits computations across GPUs
+       LLAMA_SPLIT_MODE_ATTN  = 2, // splits self-attention computations across GPUs
+       LLAMA_SPLIT_MODE_GRAPH = 3, // splits computations across GPUs
    };

diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp
index ee4eaaca..62148771 100644
--- a/src/llama-load-tensors.cpp
+++ b/src/llama-load-tensors.cpp
@@ -154,6 +154,7 @@ struct create_tensors_helper : public create_tensors_helper_interface {
    std::map<ggml_backend_buffer_type_t, int> buft_layer_count;
    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+   ggml_context * split_ctx = nullptr;

    size_t ctx_size;

    ggml_context * ctx_input;
@@ -221,6 +222,11 @@ create_tensors_helper::create_tensors_helper(llama_model_loader & _ml, llama_mod
        ctx_map[it.first] = ctx;
        model.ctxs.push_back(ctx);
    }
+   if (model.split_buft) {
+       if (auto it = ctx_map.find(model.split_buft); it != ctx_map.end()) {
+           split_ctx = it->second;
+       }
+   }
 #if 0
    printf("=======================================================================\n");
    auto n_device = model.device_count();
@@ -295,7 +301,7 @@ static std::vector create_split(int nr, int granularity, const std::vector<

 ggml_tensor * create_tensors_helper::create_tensor(ggml_context * ctx, const std::string & name, const std::vector<int64_t> & ne, int flags, ggml_context ** actual_context) {
-    auto requested_ctx = ctx;
+    //auto requested_ctx = ctx;
     if (ml.tensor_buft_overrides) {
         for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) {
             std::regex pattern(overrides->pattern);
@@ -308,7 +314,7 @@ ggml_tensor * create_tensors_helper::create_tensor(ggml_context * ctx, const std
     }
     if (actual_context) *actual_context = ctx;
     auto tensor = ml.create_tensor(ctx, name, ne, flags);
-    if (tensor && ctx == requested_ctx) {
+    if (tensor && ctx == split_ctx) {
        //printf("%s: adding tensor %s to split tensors\n", __func__, tensor->name);
        split_tensors.insert(tensor);
     }
@@ -390,12 +396,12 @@ bool create_tensors_helper::create_llama_tensors(const LLM_TN & tn) {
            // optional bias tensors
            layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);

-           layer.ffn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+           layer.ffn_norm = create_tensor(model.split_mode == LLAMA_SPLIT_MODE_GRAPH ? ctx_split : ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});

            layer.rope_freqs = create_tensor(ctx_split, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));

            if (n_expert == 0) {
-               create_std_ffn(i, tn, layer, n_ff, n_embd, ctx_split);
+               create_std_ffn(i, tn, layer, n_ff, n_embd, model.split_mode == LLAMA_SPLIT_MODE_GRAPH ? ctx_split : ctx_layer);

                // optional MLP bias
                layer.ffn_gate_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
@@ -1863,10 +1869,12 @@ bool create_tensors_helper::create_glm4_moe_tensors(const LLM_TN & tn) {
        layer.attn_k_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, llama_model_loader::TENSOR_NOT_REQUIRED | flags);

+       auto ffn_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH ? ctx_split : ctx_layer;
+
        // Why are we adding an additional tensor type?
        // attn_post_norm is the exact same thing as ffn_norm
        //layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags);
-       layer.ffn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags);
+       layer.ffn_norm = create_tensor(ffn_ctx, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags);

        // Check if this layer uses MoE or dense FFN based on n_layer_dense_lead
        // GLM 4.5 uses hybrid architecture: layer 0 is dense, layers 1+ are MoE
@@ -1874,35 +1882,35 @@ bool create_tensors_helper::create_glm4_moe_tensors(const LLM_TN & tn) {
        if (use_moe) {
            // MoE layers
-           layer.ffn_gate_inp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, flags);
+           layer.ffn_gate_inp = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, flags);
            // gate bias
-           layer.ffn_exp_probs_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, flags);
+           layer.ffn_exp_probs_b = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, flags);

            // MoE branch
            const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
-           layer.ffn_gate_exps = create_tensor(ctx_split,
+           layer.ffn_gate_exps = create_tensor(ffn_ctx,
                    tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags);
-           layer.ffn_down_exps = create_tensor(ctx_split,
+           layer.ffn_down_exps = create_tensor(ffn_ctx,
                    tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, flags);
-           layer.ffn_up_exps = create_tensor(ctx_split,
+           layer.ffn_up_exps = create_tensor(ffn_ctx,
                    tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags);

            // Shared expert
            if (n_expert_shared > 0) {
                const int64_t n_ff_shexp = n_ff_exp * n_expert_shared;
-               layer.ffn_gate_shexp = create_tensor(ctx_split,
+               layer.ffn_gate_shexp = create_tensor(ffn_ctx,
                        tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags);
-               layer.ffn_down_shexp = create_tensor(ctx_split,
+               layer.ffn_down_shexp = create_tensor(ffn_ctx,
                        tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, flags);
-               layer.ffn_up_shexp = create_tensor(ctx_split,
+               layer.ffn_up_shexp = create_tensor(ffn_ctx,
                        tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags);
            }
        } else {
            // Dense layers (first k layers) - GLM uses separate gate/up projections
-           layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, flags);
-           layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, flags);
-           layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, flags);
+           layer.ffn_gate = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, flags);
+           layer.ffn_down = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, flags);
+           layer.ffn_up = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, flags);
        }

        // --- NextN / MTP tensors (preserved but unused), on the final layer ---
        if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
@@ -2789,7 +2797,7 @@ static void prepare_split_tensors(int split_dim, ggml_context * ctx, ggml_tensor
 bool create_tensors_helper::create_tensors() {
     const auto tn = LLM_TN(model.arch);
     bool use_mmap_buffer = true;
-    if (ml.merge_qkv && model.split_mode == LLAMA_SPLIT_MODE_GRAPH) {
+    if (ml.merge_qkv && (model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN)) {
        LLAMA_LOG_WARN("\n========================================================\n");
        LLAMA_LOG_WARN("merge_qkv is not compatible with split model 'graph'\n");
        LLAMA_LOG_WARN(" => turning off merge_qkv\n");
@@ -2916,7 +2924,7 @@ bool create_tensors_helper::create_tensors() {
        default: throw std::runtime_error("unknown architecture");
    }

-   if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH) {
+   if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) {
        std::vector<size_t> mem_used(model.splits.size(), 0);
        const auto & hparams = model.hparams;
        int gqa_ratio = hparams.n_head() / hparams.n_head_kv();
@@ -2970,20 +2978,27 @@ bool create_tensors_helper::create_tensors() {
            }

            if (layer.ffn_norm) {
-               auto split = create_split(ggml_nrows(layer.ffn_norm), -1, model.splits);
-               prepare_split_tensors(-1, ctx_split, layer.ffn_norm, layer.split_ffn_norm, split, mem_used);
+               if (auto it = split_tensors.find(layer.ffn_norm); it != split_tensors.end()) {
+                   auto split = create_split(ggml_nrows(layer.ffn_norm), -1, model.splits);
+                   prepare_split_tensors(-1, ctx_split, layer.ffn_norm, layer.split_ffn_norm, split, mem_used);
+               }
            }

            if (layer.ffn_down && layer.ffn_up && layer.ffn_gate) {
-               int ffn_granularity = 16;
-               if (ggml_is_quantized(layer.ffn_down->type)) {
-                   auto tt = ggml_internal_get_type_traits(layer.ffn_down->type);
-                   if (tt.blck_size > ffn_granularity) ffn_granularity = tt.blck_size;
+               bool use_split = split_tensors.find(layer.ffn_down) != split_tensors.end() &&
+                                split_tensors.find(layer.ffn_gate) != split_tensors.end() &&
+                                split_tensors.find(layer.ffn_up) != split_tensors.end();
+               if (use_split) {
+                   int ffn_granularity = 16;
+                   if (ggml_is_quantized(layer.ffn_down->type)) {
+                       auto tt = ggml_internal_get_type_traits(layer.ffn_down->type);
+                       if (tt.blck_size > ffn_granularity) ffn_granularity = tt.blck_size;
+                   }
+                   auto split = create_split(layer.ffn_down->ne[0], ffn_granularity, model.splits);
+                   prepare_split_tensors(0, ctx_split, layer.ffn_down, layer.split_ffn_down, split, mem_used);
+                   prepare_split_tensors(1, ctx_split, layer.ffn_up, layer.split_ffn_up, split, mem_used);
+                   prepare_split_tensors(1, ctx_split, layer.ffn_gate, layer.split_ffn_gate, split, mem_used);
                }
-               auto split = create_split(layer.ffn_down->ne[0], ffn_granularity, model.splits);
-               prepare_split_tensors(0, ctx_split, layer.ffn_down, layer.split_ffn_down, split, mem_used);
-               prepare_split_tensors(1, ctx_split, layer.ffn_up, layer.split_ffn_up, split, mem_used);
-               prepare_split_tensors(1, ctx_split, layer.ffn_gate, layer.split_ffn_gate, split, mem_used);
            }

            //bool any_ffn_split = false;
@@ -3024,25 +3039,29 @@ bool create_tensors_helper::create_tensors() {
                }
            }

-           //if (any_ffn_split) {
-           if (layer.ffn_gate_inp) {
+           if (layer.ffn_gate_inp) {
+               if (auto it = split_tensors.find(layer.ffn_gate_inp); it != split_tensors.end()) {
                    auto shared_split = create_split(ggml_nrows(layer.ffn_gate_inp), -1, model.splits);
                    prepare_split_tensors(-1, ctx_split, layer.ffn_gate_inp, layer.split_ffn_gate_inp, shared_split, mem_used);
                }
-               if (layer.ffn_exp_probs_b) {
+           }
+           if (layer.ffn_exp_probs_b) {
+               if (auto it = split_tensors.find(layer.ffn_exp_probs_b); it != split_tensors.end()) {
                    auto shared_split = create_split(ggml_nrows(layer.ffn_exp_probs_b), -1, model.splits);
                    prepare_split_tensors(-1, ctx_split, layer.ffn_exp_probs_b, layer.split_ffn_exp_probs_b, shared_split, mem_used);
                }
-           //}
+           }
        }

        if (model.output) {
-           if (ggml_backend_buft_is_host(model.buft_output.buft_matrix)) {
-               LLAMA_LOG_INFO("%s: not splitting output tensor becausee buffer is host\n", __func__);
-           } else {
-               auto ctx_split = ctx_map[model.buft_output.buft_matrix];
-               auto split = create_split(model.output->ne[1], 16, model.splits);
-               prepare_split_tensors(1, ctx_split, model.output, model.split_output, split, mem_used);
+           if (auto it = split_tensors.find(model.output); it != split_tensors.end()) {
+               if (ggml_backend_buft_is_host(model.buft_output.buft_matrix)) {
+                   LLAMA_LOG_INFO("%s: not splitting output tensor becausee buffer is host\n", __func__);
+               } else {
+                   auto ctx_split = ctx_map[model.buft_output.buft_matrix];
+                   auto split = create_split(model.output->ne[1], 16, model.splits);
+                   prepare_split_tensors(1, ctx_split, model.output, model.split_output, split, mem_used);
+               }
            }
        }
        LLAMA_LOG_INFO("Estimated model buffer size per device:\n");
diff --git a/src/llama-model.h b/src/llama-model.h
index c7fe8b68..bb7134cc 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -410,6 +410,7 @@ struct llama_model {
    ggml_backend_buffer_type_t default_buffer_type_offload(int device) const;

    std::vector<float> splits;
+   ggml_backend_buffer_type_t split_buft = nullptr;
 };

 struct llama_lora_weight {
diff --git a/src/llama.cpp b/src/llama.cpp
index 4f608092..c1ffd5ac 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -461,18 +461,18 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_
     GGML_UNUSED(gpu);
 }

-static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_model & model, int fallback_gpu, const float * tensor_split) {
+static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_model & model, int fallback_gpu) {
     ggml_backend_buffer_type_t buft = nullptr;

 #ifdef GGML_USE_CUDA
     if (ggml_backend_cuda_get_device_count() > 1) {
-        buft = ggml_backend_cuda_split_buffer_type(tensor_split);
+        buft = ggml_backend_cuda_split_buffer_type(model.splits.data());
     }
 #endif

 #ifdef GGML_USE_SYCL
     if (ggml_backend_sycl_get_device_count() > 1) {
-        buft = ggml_backend_sycl_split_buffer_type(tensor_split);
+        buft = ggml_backend_sycl_split_buffer_type(model.splits.data());
     }
 #endif

@@ -481,7 +481,6 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_mo
     }

     return buft;
-    GGML_UNUSED(tensor_split);
 }

 int llama_model::device_count() const {
@@ -560,7 +559,7 @@ bool llama_context::update_cache_copies() {
     int n_layer = model.hparams.n_layer - model.hparams.nextn_predict_layers; //cache_copies.size()/2;
     if ((int)kv_self.k_l.size() != n_layer) return false;
     if (!(kv_self.v_l.empty() || (int)kv_self.v_l.size() == n_layer)) return false;
-    if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH && model.splits.size() > 1) {
+    if ((model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) && model.splits.size() > 1) {
        for (int il = 0; il < n_layer; ++il) {
            auto kl = (ggml_split_tensor_t *)kv_self.k_l[il]->extra;
            auto vl = !kv_self.v_l.empty() && kv_self.v_l[il] ? (ggml_split_tensor_t *)kv_self.v_l[il]->extra : nullptr;
@@ -607,7 +606,7 @@ bool llama_context::update_cache_copies() {

 llama_context::llama_context(const llama_model & model) : model(model) , sampling(llama_n_vocab(&model)) , t_start_us(model.t_start_us) , t_load_us(model.t_load_us) {
     const auto & hparams = model.hparams;
-    if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH && model.splits.size() > 1) {
+    if ((model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) && model.splits.size() > 1) {
         cache_copies.resize(2*model.splits.size()*hparams.n_layer);
     } else {
         cache_copies.resize(2*hparams.n_layer);
@@ -666,7 +665,7 @@ static bool llama_kv_cache_init(
     }

     bool split_cache = false;
-    if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH && model.arch != LLM_ARCH_DEEPSEEK2 && offload) {
+    if ((model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) && model.arch != LLM_ARCH_DEEPSEEK2 && offload) {
         cache.split_k_l.reserve(n_layer);
         cache.split_v_l.reserve(n_layer);
         split_cache = true;
@@ -1750,7 +1749,7 @@ static bool llm_load_tensors(

     auto & hparams = model.hparams;

-    if (split_mode == LLAMA_SPLIT_MODE_GRAPH) {
+    if (split_mode == LLAMA_SPLIT_MODE_GRAPH || split_mode == LLAMA_SPLIT_MODE_ATTN) {
         if (!is_model_split_supported(model)) {
             LLAMA_LOG_WARN("\n=======================================================\n");
             LLAMA_LOG_WARN("Split mode 'graph' is not supported for this model\n");
@@ -1804,11 +1803,11 @@ static bool llm_load_tensors(
         model.splits = { 1.0f };
     }

+    int device_count = model.splits.size();
+    // assign the repeating layers to the devices according to the splits
+    int act_gpu_layers = std::min(n_gpu_layers, (int)n_layer + 1);
     if (split_mode == LLAMA_SPLIT_MODE_LAYER) {
-        int device_count = model.splits.size();
-        // assign the repeating layers to the devices according to the splits
-        int act_gpu_layers = std::min(n_gpu_layers, (int)n_layer + 1);
         for (int i = i_gpu_start; i < n_layer; ++i) {
             int layer_gpu = std::upper_bound(model.splits.begin(), model.splits.begin() + device_count,
                     float(i - i_gpu_start)/act_gpu_layers) - model.splits.begin();
             model.buft_layer[i] = llama_default_buffer_type_offload(model, model.devices[layer_gpu]);
@@ -1822,18 +1821,24 @@ static bool llm_load_tensors(
         }
     } else {
         ggml_backend_buffer_type_t split_buft;
-        if (split_mode == LLAMA_SPLIT_MODE_GRAPH && model.splits.size() > 1) {
-            split_buft = llama_default_buffer_type_split(model, model.devices[main_gpu], model.splits.data());
+        if ((split_mode == LLAMA_SPLIT_MODE_GRAPH || split_mode == LLAMA_SPLIT_MODE_ATTN) && model.splits.size() > 1) {
+            split_buft = llama_default_buffer_type_split(model, model.devices[main_gpu]);
+            model.split_buft = split_buft;
         } else {
             // LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_LAYER in backends where it is not supported
             split_buft = llama_default_buffer_type_offload(model, model.devices[main_gpu]);
         }
+        auto buft_layer = llama_default_buffer_type_offload(model, model.devices[main_gpu]);
         // assign the repeating layers
         for (int i = i_gpu_start; i < n_layer; ++i) {
-            model.buft_layer[i] = {
-                split_buft,
-                llama_default_buffer_type_offload(model, model.devices[main_gpu])
-            };
+            if (split_mode == LLAMA_SPLIT_MODE_ATTN) {
+                int layer_gpu = std::upper_bound(model.splits.begin(), model.splits.begin() + device_count,
+                        float(i - i_gpu_start)/act_gpu_layers) - model.splits.begin();
+                model.buft_layer[i] = { split_buft, llama_default_buffer_type_offload(model, model.devices[layer_gpu]) };
+                printf("Layer %d: assigning buft_layer to GPU %d\n", i, layer_gpu);
+            } else {
+                model.buft_layer[i] = { split_buft, buft_layer };
+            }
         }
         // assign the output layer
         if (n_gpu_layers > n_layer) {
@@ -4476,8 +4481,8 @@ struct llama_context * llama_new_context_with_model(
             }
         }
 #elif defined(GGML_USE_VULKAN)
-        if (model->split_mode == LLAMA_SPLIT_MODE_GRAPH) {
-            LLAMA_LOG_ERROR("%s: Row split not supported. Failed to initialize Vulkan backend\n", __func__);
+        if (model->split_mode == LLAMA_SPLIT_MODE_GRAPH || model->split_mode == LLAMA_SPLIT_MODE_ATTN) {
+            LLAMA_LOG_ERROR("%s: split mode 'graph' or 'attn' not supported. Failed to initialize Vulkan backend\n", __func__);
             llama_free(ctx);
             return nullptr;
         }
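
Note: beyond the new "attn" value handled in common/common.cpp above (typically spelled -sm attn / --split-mode attn on the command line), the mode can also be selected through the C API. A minimal sketch, assuming the standard llama.h entry points of this codebase; only the enum value comes from this patch, and the model path and layer count are placeholders:

#include "llama.h"

int main() {
    llama_backend_init();

    // Load the model with self-attention tensors split across GPUs.
    // With LLAMA_SPLIT_MODE_ATTN the FFN weights of each layer stay on a single
    // device (ctx_layer), while attention weights and the KV cache are split.
    llama_model_params mparams = llama_model_default_params();
    mparams.n_gpu_layers = 99;                    // placeholder: offload all layers
    mparams.split_mode   = LLAMA_SPLIT_MODE_ATTN; // enum value added by this patch

    llama_model * model = llama_load_model_from_file("model.gguf", mparams); // placeholder path
    if (model == nullptr) {
        return 1;
    }

    // ... create a context and run inference as usual ...

    llama_free_model(model);
    llama_backend_free();
    return 0;
}

The per-device fractions still come from model.splits (the usual tensor-split handling), which is also what llama_default_buffer_type_split() now reads directly instead of taking a separate tensor_split pointer.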