Mirror of https://github.com/ikawrakow/ik_llama.cpp.git, synced 2026-01-26 17:20:01 +00:00
WIP split mode attn
Works for LLaMA models, but not for GLM-4.5. It doesn't seem to improve performance, so there is probably no point in trying to fix it.
@@ -1276,6 +1276,9 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         else if (arg_next == "layer") {
             params.split_mode = LLAMA_SPLIT_MODE_LAYER;
         }
+        else if (arg_next == "attn") {
+            params.split_mode = LLAMA_SPLIT_MODE_ATTN;
+        }
         else if (arg_next == "graph") {
             params.split_mode = LLAMA_SPLIT_MODE_GRAPH;
         }
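Assuming this fork keeps upstream llama.cpp's `-sm` / `--split-mode` flag (whose value is what `arg_next` holds here), the new mode would presumably be selected with `--split-mode attn`, and the full graph split with `--split-mode graph`; this is an inference from the parsing above rather than something the commit documents.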
@@ -2989,6 +2989,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
                 ggml_is_contiguous(dst->src[0]) &&
                 ggml_is_contiguous(dst->src[1]) &&
+                dst->src[0]->type == GGML_TYPE_F32 && // with split mode "attn" we can end up having f16
                 ggml_are_same_shape(dst->src[0], dst->src[1]) &&
                 dst == cgraph->nodes[i+1]->src[0] &&
                 ggml_is_contiguous(cgraph->nodes[i+1]->src[1]) &&
@@ -275,7 +275,8 @@ extern "C" {
     enum llama_split_mode {
         LLAMA_SPLIT_MODE_NONE  = 0, // single GPU
         LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
-        LLAMA_SPLIT_MODE_GRAPH = 2, // splits computations across GPUs
+        LLAMA_SPLIT_MODE_ATTN  = 2, // splits self-attention computations across GPUs
+        LLAMA_SPLIT_MODE_GRAPH = 3, // splits computations across GPUs
     };
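Most of the remaining hunks widen checks of the form `model.split_mode == LLAMA_SPLIT_MODE_GRAPH` so that they also accept `LLAMA_SPLIT_MODE_ATTN`. The following predicate is only a hypothetical sketch of how that repeated condition could be centralized; it is not part of the commit, and the enum values are simply copied from the hunk above.

    // Hypothetical helper, not in this commit: one place for the repeated
    // "graph or attn" test that the diff applies in many locations.
    enum llama_split_mode {
        LLAMA_SPLIT_MODE_NONE  = 0, // single GPU
        LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
        LLAMA_SPLIT_MODE_ATTN  = 2, // splits self-attention computations across GPUs
        LLAMA_SPLIT_MODE_GRAPH = 3, // splits computations across GPUs
    };

    static inline bool llama_split_mode_is_tensor_split(enum llama_split_mode mode) {
        return mode == LLAMA_SPLIT_MODE_GRAPH || mode == LLAMA_SPLIT_MODE_ATTN;
    }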
@@ -154,6 +154,7 @@ struct create_tensors_helper : public create_tensors_helper_interface {

     std::map<ggml_backend_buffer_type_t, int> buft_layer_count;
     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+    ggml_context * split_ctx = nullptr;
     size_t ctx_size;

     ggml_context * ctx_input;
@@ -221,6 +222,11 @@ create_tensors_helper::create_tensors_helper(llama_model_loader & _ml, llama_mod
         ctx_map[it.first] = ctx;
         model.ctxs.push_back(ctx);
     }
+    if (model.split_buft) {
+        if (auto it = ctx_map.find(model.split_buft); it != ctx_map.end()) {
+            split_ctx = it->second;
+        }
+    }
 #if 0
     printf("=======================================================================\n");
     auto n_device = model.device_count();
@@ -295,7 +301,7 @@ static std::vector<int> create_split(int nr, int granularity, const std::vector<

 ggml_tensor * create_tensors_helper::create_tensor(ggml_context * ctx, const std::string & name, const std::vector<int64_t> & ne,
         int flags, ggml_context ** actual_context) {
-    auto requested_ctx = ctx;
+    //auto requested_ctx = ctx;
     if (ml.tensor_buft_overrides) {
         for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) {
             std::regex pattern(overrides->pattern);
@@ -308,7 +314,7 @@ ggml_tensor * create_tensors_helper::create_tensor(ggml_context * ctx, const std
     }
     if (actual_context) *actual_context = ctx;
     auto tensor = ml.create_tensor(ctx, name, ne, flags);
-    if (tensor && ctx == requested_ctx) {
+    if (tensor && ctx == split_ctx) {
         //printf("%s: adding tensor %s to split tensors\n", __func__, tensor->name);
         split_tensors.insert(tensor);
     }
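With this check, a tensor is recorded in split_tensors only when the context it was actually created in is the dedicated split context resolved in the constructor above, rather than whenever no tensor buffer-type override redirected it. The guards added further down in create_tensors(), which test split_tensors.find(...) before calling prepare_split_tensors, rely on exactly this bookkeeping.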
@@ -390,12 +396,12 @@ bool create_tensors_helper::create_llama_tensors(const LLM_TN & tn) {
         // optional bias tensors
         layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);

-        layer.ffn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+        layer.ffn_norm = create_tensor(model.split_mode == LLAMA_SPLIT_MODE_GRAPH ? ctx_split : ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});

         layer.rope_freqs = create_tensor(ctx_split, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));

         if (n_expert == 0) {
-            create_std_ffn(i, tn, layer, n_ff, n_embd, ctx_split);
+            create_std_ffn(i, tn, layer, n_ff, n_embd, model.split_mode == LLAMA_SPLIT_MODE_GRAPH ? ctx_split : ctx_layer);

             // optional MLP bias
             layer.ffn_gate_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
@@ -1863,10 +1869,12 @@ bool create_tensors_helper::create_glm4_moe_tensors(const LLM_TN & tn) {
         layer.attn_k_norm = create_tensor(ctx_split,
                 tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, llama_model_loader::TENSOR_NOT_REQUIRED | flags);

+        auto ffn_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH ? ctx_split : ctx_layer;
+
         // Why are we adding an additional tensor type?
         // attn_post_norm is the exact same thing as ffn_norm
         //layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags);
-        layer.ffn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags);
+        layer.ffn_norm = create_tensor(ffn_ctx, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags);

         // Check if this layer uses MoE or dense FFN based on n_layer_dense_lead
         // GLM 4.5 uses hybrid architecture: layer 0 is dense, layers 1+ are MoE
@@ -1874,35 +1882,35 @@ bool create_tensors_helper::create_glm4_moe_tensors(const LLM_TN & tn) {

         if (use_moe) {
             // MoE layers
-            layer.ffn_gate_inp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, flags);
+            layer.ffn_gate_inp = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, flags);
             // gate bias
-            layer.ffn_exp_probs_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, flags);
+            layer.ffn_exp_probs_b = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, flags);

             // MoE branch
             const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;

-            layer.ffn_gate_exps = create_tensor(ctx_split,
+            layer.ffn_gate_exps = create_tensor(ffn_ctx,
                     tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags);
-            layer.ffn_down_exps = create_tensor(ctx_split,
+            layer.ffn_down_exps = create_tensor(ffn_ctx,
                     tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, flags);
-            layer.ffn_up_exps = create_tensor(ctx_split,
+            layer.ffn_up_exps = create_tensor(ffn_ctx,
                     tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags);

             // Shared expert
             if (n_expert_shared > 0) {
                 const int64_t n_ff_shexp = n_ff_exp * n_expert_shared;
-                layer.ffn_gate_shexp = create_tensor(ctx_split,
+                layer.ffn_gate_shexp = create_tensor(ffn_ctx,
                         tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags);
-                layer.ffn_down_shexp = create_tensor(ctx_split,
+                layer.ffn_down_shexp = create_tensor(ffn_ctx,
                         tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, flags);
-                layer.ffn_up_shexp = create_tensor(ctx_split,
+                layer.ffn_up_shexp = create_tensor(ffn_ctx,
                         tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags);
             }
         } else {
             // Dense layers (first k layers) - GLM uses separate gate/up projections
-            layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, flags);
-            layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, flags);
-            layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, flags);
+            layer.ffn_gate = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, flags);
+            layer.ffn_down = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, flags);
+            layer.ffn_up = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, flags);
         }
         // --- NextN / MTP tensors (preserved but unused), on the final layer ---
         if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
@@ -2789,7 +2797,7 @@ static void prepare_split_tensors(int split_dim, ggml_context * ctx, ggml_tensor
 bool create_tensors_helper::create_tensors() {
     const auto tn = LLM_TN(model.arch);
     bool use_mmap_buffer = true;
-    if (ml.merge_qkv && model.split_mode == LLAMA_SPLIT_MODE_GRAPH) {
+    if (ml.merge_qkv && (model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN)) {
         LLAMA_LOG_WARN("\n========================================================\n");
         LLAMA_LOG_WARN("merge_qkv is not compatible with split model 'graph'\n");
         LLAMA_LOG_WARN(" => turning off merge_qkv\n");
@@ -2916,7 +2924,7 @@ bool create_tensors_helper::create_tensors() {
         default:
             throw std::runtime_error("unknown architecture");
     }
-    if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH) {
+    if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) {
         std::vector<size_t> mem_used(model.splits.size(), 0);
         const auto & hparams = model.hparams;
         int gqa_ratio = hparams.n_head() / hparams.n_head_kv();
@@ -2970,20 +2978,27 @@ bool create_tensors_helper::create_tensors() {
             }

             if (layer.ffn_norm) {
-                auto split = create_split(ggml_nrows(layer.ffn_norm), -1, model.splits);
-                prepare_split_tensors(-1, ctx_split, layer.ffn_norm, layer.split_ffn_norm, split, mem_used);
+                if (auto it = split_tensors.find(layer.ffn_norm); it != split_tensors.end()) {
+                    auto split = create_split(ggml_nrows(layer.ffn_norm), -1, model.splits);
+                    prepare_split_tensors(-1, ctx_split, layer.ffn_norm, layer.split_ffn_norm, split, mem_used);
+                }
             }

             if (layer.ffn_down && layer.ffn_up && layer.ffn_gate) {
-                int ffn_granularity = 16;
-                if (ggml_is_quantized(layer.ffn_down->type)) {
-                    auto tt = ggml_internal_get_type_traits(layer.ffn_down->type);
-                    if (tt.blck_size > ffn_granularity) ffn_granularity = tt.blck_size;
+                bool use_split = split_tensors.find(layer.ffn_down) != split_tensors.end() &&
+                                 split_tensors.find(layer.ffn_gate) != split_tensors.end() &&
+                                 split_tensors.find(layer.ffn_up) != split_tensors.end();
+                if (use_split) {
+                    int ffn_granularity = 16;
+                    if (ggml_is_quantized(layer.ffn_down->type)) {
+                        auto tt = ggml_internal_get_type_traits(layer.ffn_down->type);
+                        if (tt.blck_size > ffn_granularity) ffn_granularity = tt.blck_size;
+                    }
+                    auto split = create_split(layer.ffn_down->ne[0], ffn_granularity, model.splits);
+                    prepare_split_tensors(0, ctx_split, layer.ffn_down, layer.split_ffn_down, split, mem_used);
+                    prepare_split_tensors(1, ctx_split, layer.ffn_up, layer.split_ffn_up, split, mem_used);
+                    prepare_split_tensors(1, ctx_split, layer.ffn_gate, layer.split_ffn_gate, split, mem_used);
                 }
-                auto split = create_split(layer.ffn_down->ne[0], ffn_granularity, model.splits);
-                prepare_split_tensors(0, ctx_split, layer.ffn_down, layer.split_ffn_down, split, mem_used);
-                prepare_split_tensors(1, ctx_split, layer.ffn_up, layer.split_ffn_up, split, mem_used);
-                prepare_split_tensors(1, ctx_split, layer.ffn_gate, layer.split_ffn_gate, split, mem_used);
             }

             //bool any_ffn_split = false;
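The implementation of create_split() is not included in this diff, so the following is only a sketch of the kind of computation such a helper might perform, assuming model.splits holds cumulative per-device fractions (e.g. {0.5f, 1.0f} for an even two-GPU split, which is what the std::upper_bound usage further down suggests) and that the returned vector holds row boundaries aligned to `granularity`:

    // Sketch only; the real create_split() in ik_llama.cpp may differ.
    #include <algorithm>
    #include <vector>

    static std::vector<int> create_split_sketch(int nr, int granularity, const std::vector<float> & splits) {
        std::vector<int> boundaries;
        boundaries.push_back(0);
        for (size_t d = 0; d + 1 < splits.size(); ++d) {
            int p = int(splits[d] * nr);                              // proportional cut point
            if (granularity > 1) p = (p / granularity) * granularity; // keep cuts block-aligned for quantized tensors
            boundaries.push_back(std::min(p, nr));
        }
        boundaries.push_back(nr);   // device d owns rows [boundaries[d], boundaries[d+1])
        return boundaries;
    }

For example, with nr = 1000, granularity = 16 and splits = {0.5f, 1.0f}, this sketch yields boundaries {0, 496, 1000}.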
@@ -3024,25 +3039,29 @@ bool create_tensors_helper::create_tensors() {
                 }
             }

-            //if (any_ffn_split) {
-                if (layer.ffn_gate_inp) {
+            if (layer.ffn_gate_inp) {
+                if (auto it = split_tensors.find(layer.ffn_gate_inp); it != split_tensors.end()) {
                     auto shared_split = create_split(ggml_nrows(layer.ffn_gate_inp), -1, model.splits);
                     prepare_split_tensors(-1, ctx_split, layer.ffn_gate_inp, layer.split_ffn_gate_inp, shared_split, mem_used);
                 }
-                if (layer.ffn_exp_probs_b) {
+            }
+            if (layer.ffn_exp_probs_b) {
+                if (auto it = split_tensors.find(layer.ffn_exp_probs_b); it != split_tensors.end()) {
                     auto shared_split = create_split(ggml_nrows(layer.ffn_exp_probs_b), -1, model.splits);
                     prepare_split_tensors(-1, ctx_split, layer.ffn_exp_probs_b, layer.split_ffn_exp_probs_b, shared_split, mem_used);
                 }
-            //}
+            }
         }

         if (model.output) {
-            if (ggml_backend_buft_is_host(model.buft_output.buft_matrix)) {
-                LLAMA_LOG_INFO("%s: not splitting output tensor becausee buffer is host\n", __func__);
-            } else {
-                auto ctx_split = ctx_map[model.buft_output.buft_matrix];
-                auto split = create_split(model.output->ne[1], 16, model.splits);
-                prepare_split_tensors(1, ctx_split, model.output, model.split_output, split, mem_used);
+            if (auto it = split_tensors.find(model.output); it != split_tensors.end()) {
+                if (ggml_backend_buft_is_host(model.buft_output.buft_matrix)) {
+                    LLAMA_LOG_INFO("%s: not splitting output tensor becausee buffer is host\n", __func__);
+                } else {
+                    auto ctx_split = ctx_map[model.buft_output.buft_matrix];
+                    auto split = create_split(model.output->ne[1], 16, model.splits);
+                    prepare_split_tensors(1, ctx_split, model.output, model.split_output, split, mem_used);
+                }
             }
         }
         LLAMA_LOG_INFO("Estimated model buffer size per device:\n");
@@ -410,6 +410,7 @@ struct llama_model {
     ggml_backend_buffer_type_t default_buffer_type_offload(int device) const;

     std::vector<float> splits;
+    ggml_backend_buffer_type_t split_buft = nullptr;
 };

 struct llama_lora_weight {
@@ -461,18 +461,18 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_
     GGML_UNUSED(gpu);
 }

-static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_model & model, int fallback_gpu, const float * tensor_split) {
+static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_model & model, int fallback_gpu) {
     ggml_backend_buffer_type_t buft = nullptr;

 #ifdef GGML_USE_CUDA
     if (ggml_backend_cuda_get_device_count() > 1) {
-        buft = ggml_backend_cuda_split_buffer_type(tensor_split);
+        buft = ggml_backend_cuda_split_buffer_type(model.splits.data());
     }
 #endif

 #ifdef GGML_USE_SYCL
     if (ggml_backend_sycl_get_device_count() > 1) {
-        buft = ggml_backend_sycl_split_buffer_type(tensor_split);
+        buft = ggml_backend_sycl_split_buffer_type(model.splits.data());
     }
 #endif
@@ -481,7 +481,6 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_mo
     }
     return buft;

-    GGML_UNUSED(tensor_split);
 }

 int llama_model::device_count() const {
@@ -560,7 +559,7 @@ bool llama_context::update_cache_copies() {
     int n_layer = model.hparams.n_layer - model.hparams.nextn_predict_layers; //cache_copies.size()/2;
     if ((int)kv_self.k_l.size() != n_layer) return false;
     if (!(kv_self.v_l.empty() || (int)kv_self.v_l.size() == n_layer)) return false;
-    if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH && model.splits.size() > 1) {
+    if ((model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) && model.splits.size() > 1) {
         for (int il = 0; il < n_layer; ++il) {
             auto kl = (ggml_split_tensor_t *)kv_self.k_l[il]->extra;
             auto vl = !kv_self.v_l.empty() && kv_self.v_l[il] ? (ggml_split_tensor_t *)kv_self.v_l[il]->extra : nullptr;
@@ -607,7 +606,7 @@ bool llama_context::update_cache_copies() {
 llama_context::llama_context(const llama_model & model)
     : model(model) , sampling(llama_n_vocab(&model)) , t_start_us(model.t_start_us) , t_load_us(model.t_load_us) {
     const auto & hparams = model.hparams;
-    if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH && model.splits.size() > 1) {
+    if ((model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) && model.splits.size() > 1) {
         cache_copies.resize(2*model.splits.size()*hparams.n_layer);
     } else {
         cache_copies.resize(2*hparams.n_layer);
@@ -666,7 +665,7 @@ static bool llama_kv_cache_init(
     }

     bool split_cache = false;
-    if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH && model.arch != LLM_ARCH_DEEPSEEK2 && offload) {
+    if ((model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) && model.arch != LLM_ARCH_DEEPSEEK2 && offload) {
         cache.split_k_l.reserve(n_layer);
         cache.split_v_l.reserve(n_layer);
         split_cache = true;
@@ -1750,7 +1749,7 @@ static bool llm_load_tensors(

     auto & hparams = model.hparams;

-    if (split_mode == LLAMA_SPLIT_MODE_GRAPH) {
+    if (split_mode == LLAMA_SPLIT_MODE_GRAPH || split_mode == LLAMA_SPLIT_MODE_ATTN) {
         if (!is_model_split_supported(model)) {
             LLAMA_LOG_WARN("\n=======================================================\n");
             LLAMA_LOG_WARN("Split mode 'graph' is not supported for this model\n");
@@ -1804,11 +1803,11 @@ static bool llm_load_tensors(
         model.splits = { 1.0f };
     }

+    int device_count = model.splits.size();
+    // assign the repeating layers to the devices according to the splits
+    int act_gpu_layers = std::min(n_gpu_layers, (int)n_layer + 1);
     if (split_mode == LLAMA_SPLIT_MODE_LAYER) {

-        int device_count = model.splits.size();
-        // assign the repeating layers to the devices according to the splits
-        int act_gpu_layers = std::min(n_gpu_layers, (int)n_layer + 1);
         for (int i = i_gpu_start; i < n_layer; ++i) {
             int layer_gpu = std::upper_bound(model.splits.begin(), model.splits.begin() + device_count, float(i - i_gpu_start)/act_gpu_layers) - model.splits.begin();
             model.buft_layer[i] = llama_default_buffer_type_offload(model, model.devices[layer_gpu]);
@@ -1822,18 +1821,24 @@ static bool llm_load_tensors(
         }
     } else {
         ggml_backend_buffer_type_t split_buft;
-        if (split_mode == LLAMA_SPLIT_MODE_GRAPH && model.splits.size() > 1) {
-            split_buft = llama_default_buffer_type_split(model, model.devices[main_gpu], model.splits.data());
+        if ((split_mode == LLAMA_SPLIT_MODE_GRAPH || split_mode == LLAMA_SPLIT_MODE_ATTN) && model.splits.size() > 1) {
+            split_buft = llama_default_buffer_type_split(model, model.devices[main_gpu]);
+            model.split_buft = split_buft;
         } else {
             // LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_LAYER in backends where it is not supported
             split_buft = llama_default_buffer_type_offload(model, model.devices[main_gpu]);
         }
+        auto buft_layer = llama_default_buffer_type_offload(model, model.devices[main_gpu]);
         // assign the repeating layers
         for (int i = i_gpu_start; i < n_layer; ++i) {
-            model.buft_layer[i] = {
-                split_buft,
-                llama_default_buffer_type_offload(model, model.devices[main_gpu])
-            };
+            if (split_mode == LLAMA_SPLIT_MODE_ATTN) {
+                int layer_gpu = std::upper_bound(model.splits.begin(), model.splits.begin() + device_count,
+                        float(i - i_gpu_start)/act_gpu_layers) - model.splits.begin();
+                model.buft_layer[i] = { split_buft, llama_default_buffer_type_offload(model, model.devices[layer_gpu]) };
+                printf("Layer %d: assigning buft_layer to GPU %d\n", i, layer_gpu);
+            } else {
+                model.buft_layer[i] = { split_buft, buft_layer };
+            }
         }
         // assign the output layer
         if (n_gpu_layers > n_layer) {
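To make the layer-to-GPU mapping in the LLAMA_SPLIT_MODE_ATTN branch above concrete, here is a small self-contained example; the split values are made up, and i_gpu_start, act_gpu_layers and the cumulative-fraction reading of model.splits follow the surrounding code:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<float> splits = {0.4f, 1.0f};   // cumulative fractions for 2 devices (example values)
        int device_count   = (int)splits.size();
        int i_gpu_start    = 0;
        int act_gpu_layers = 10;
        for (int i = i_gpu_start; i < act_gpu_layers; ++i) {
            // first cumulative fraction strictly greater than the layer's relative position
            int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count,
                    float(i - i_gpu_start)/act_gpu_layers) - splits.begin();
            printf("layer %d -> GPU %d\n", i, layer_gpu);  // layers 0..3 -> GPU 0, layers 4..9 -> GPU 1
        }
        return 0;
    }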
@@ -4476,8 +4481,8 @@ struct llama_context * llama_new_context_with_model(
         }
     }
 #elif defined(GGML_USE_VULKAN)
-    if (model->split_mode == LLAMA_SPLIT_MODE_GRAPH) {
-        LLAMA_LOG_ERROR("%s: Row split not supported. Failed to initialize Vulkan backend\n", __func__);
+    if (model->split_mode == LLAMA_SPLIT_MODE_GRAPH || model->split_mode == LLAMA_SPLIT_MODE_ATTN) {
+        LLAMA_LOG_ERROR("%s: split mode 'graph' or 'attn' not supported. Failed to initialize Vulkan backend\n", __func__);
         llama_free(ctx);
         return nullptr;
     }