mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-26 16:14:10 +00:00
POC: CUDA tensor parallel (MoE models) (#1022)
* Remove most of split mode row
* WIP
* WIP: also allocate the KV cache using tensor split
* WIP: it runs with wrong result
But it also looks like the backend scheduler is not going to help:
* It copies mask and input positions to GPU 0
* => RoPE ops must run on GPU 0
* => To proceed attn evaluation, GPU 1 must wait for GPU 0 to finish its
entire attn calculation
* Same with FFN. The rms_norm gets scheduled on GPU 0. Hence, GPU 1 must
wait for GPU 0 to finish its entore FFN calculation before it can
start (as it needs to copy the result of rms_norm from GPU 0)
* => Seems useless without writing a bespoke TP scheduling
* WIP
* This works, but it is slow
* This is slightly better
the graph is still not being computed in parallel.
Why? Because the scheduler creates graph splits where the
result of the computation on one GPU becomes an input for the
other split. Hence, to trigger the computation on the second GPU
one needs to wait for the computation on the first GPU to finish,
even thiough the two can be done in parallel up to the sunchronization
point. So, all that is left to do is to trick the scheduler to create
to splits that can be done in parallel, and then have a graph split
where the results get combined.
* Playing games with the scheduler
This change tricks it into doing the right thing^TM.
Still quite a bit slower than split mode layer for the 8B LlaMA model.
But for the 70B LlaMA it now beats split mode layer for TG:
28 t/s vs 24.4 t/s. PP is 627 t/s vs 744 t/s.
In comparison, split mode "row" in mainline gets
484 t/s PP and 19.3 t/s TG.
* Fix attn split
Granularity for Wq, Wo is not just head size, but
head size * gqa_ratio.
Else the Wk, Wv tensors end up not being a multiple of the
head size when we divide the split determined by Wo with
the gqa_ratio.
* Show memory used per device
* Make it work with partial offload
but no tensor overrides yet, just ngl < num_layers.
* Allow for f16 source in fused_rms_norm
* This results in faster PP.
Now PP is faster than split mode layer for L3-70B.
* Rename split mode "row" to split mode "graph"
* Leave FFN partial results as f16
* WIP GLM4.5 - runs with wrong results
* WIP GLM4.5 - this works
PP is already better than split mode layer, but TG for zero context
is kind of low - 60 vs 92 t/s. TG becomes better than split mode layer
at around 20k tokens. PP at 26k tokens is 1.55X of sm layer.
* Work around compiler bug
It issues a warning that there is an extra semicolon outside of a function,
but there isn't. If I remove the anonymous namespace and turn the
functions inside into static, the warning disapears, so clearly
a compiler bug.
* Make graph reuse work with split mode graph
* Remove more split mode row remnants
* WIP tensor overrides
Runs with wrong results, don't see where the issue could be.
* This works but is slow
Still does not work for row-interleaved quants
* Slightly better
* Slightly better
* Row-interleaved quants work
* Better
* Minor
* Guarad against using split mode "graph" for unsupported models
* Guards against using merge_qkv with split mode "graph"
* WIP split mode attn
Works for LlaMA models, but not for GLM-4.5.
Doesn't seem to improve performance, so I guess no point in trying to
fix it.
* Split mode graph for qwen3moe
* Try to better distribute the splits
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
340
src/llama.cpp
340
src/llama.cpp
@@ -108,6 +108,7 @@
|
||||
#include <mutex>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
#include <unordered_set>
|
||||
#include <sstream>
|
||||
#include <thread>
|
||||
#include <type_traits>
|
||||
@@ -460,18 +461,18 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_
|
||||
GGML_UNUSED(gpu);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_model & model, int fallback_gpu, const float * tensor_split) {
|
||||
static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_model & model, int fallback_gpu) {
|
||||
ggml_backend_buffer_type_t buft = nullptr;
|
||||
|
||||
#ifdef GGML_USE_CUDA
|
||||
if (ggml_backend_cuda_get_device_count() > 1) {
|
||||
buft = ggml_backend_cuda_split_buffer_type(tensor_split);
|
||||
buft = ggml_backend_cuda_split_buffer_type(model.splits.data());
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_SYCL
|
||||
if (ggml_backend_sycl_get_device_count() > 1) {
|
||||
buft = ggml_backend_sycl_split_buffer_type(tensor_split);
|
||||
buft = ggml_backend_sycl_split_buffer_type(model.splits.data());
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -480,7 +481,14 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_mo
|
||||
}
|
||||
return buft;
|
||||
|
||||
GGML_UNUSED(tensor_split);
|
||||
}
|
||||
|
||||
int llama_model::device_count() const {
|
||||
return llama_get_device_count(*this);
|
||||
}
|
||||
|
||||
ggml_backend_buffer_type_t llama_model::default_buffer_type_offload(int device) const {
|
||||
return llama_default_buffer_type_offload(*this, device);
|
||||
}
|
||||
|
||||
static size_t llama_get_device_memory(const llama_model & model, int device) {
|
||||
@@ -548,23 +556,49 @@ bool llama_context::can_reuse_graph(const llama_batch & u_batch) {
|
||||
}
|
||||
|
||||
bool llama_context::update_cache_copies() {
|
||||
int n_layer = cache_copies.size()/2;
|
||||
int n_layer = model.hparams.n_layer - model.hparams.nextn_predict_layers; //cache_copies.size()/2;
|
||||
if ((int)kv_self.k_l.size() != n_layer) return false;
|
||||
if (!(kv_self.v_l.empty() || (int)kv_self.v_l.size() == n_layer)) return false;
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
auto& c = cache_copies[2*il+0];
|
||||
if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kv_self.k_l[il]) return false;
|
||||
c.cpy->view_offs = kv_self.head*c.step;
|
||||
c.cpy->src[1]->data = (char *)kv_self.k_l[il]->data + c.cpy->view_offs;
|
||||
c.cpy->data = c.cpy->src[1]->data;
|
||||
}
|
||||
if (kv_self.v_l.empty()) return true;
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
auto& c = cache_copies[2*il+1];
|
||||
if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kv_self.v_l[il]) return false;
|
||||
c.cpy->view_offs = kv_self.head*c.step;
|
||||
c.cpy->src[1]->data = (char *)kv_self.v_l[il]->data + c.cpy->view_offs;
|
||||
c.cpy->data = c.cpy->src[1]->data;
|
||||
if ((model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) && model.splits.size() > 1) {
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
auto kl = (ggml_split_tensor_t *)kv_self.k_l[il]->extra;
|
||||
auto vl = !kv_self.v_l.empty() && kv_self.v_l[il] ? (ggml_split_tensor_t *)kv_self.v_l[il]->extra : nullptr;
|
||||
GGML_ASSERT(kl && (!kv_self.v_l[il] || vl));
|
||||
if (vl) {
|
||||
GGML_ASSERT(kl->n_device == vl->n_device);
|
||||
}
|
||||
for (int id = 0; id < kl->n_device; ++id) {
|
||||
auto& c = cache_copies[2*model.splits.size()*il + 2*id + 0];
|
||||
if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kl->splits[id]) return false;
|
||||
c.cpy->view_offs = kv_self.head*c.step;
|
||||
c.cpy->src[1]->data = (char *)kl->splits[id]->data + c.cpy->view_offs;
|
||||
c.cpy->data = c.cpy->src[1]->data;
|
||||
}
|
||||
if (!vl) continue;
|
||||
for (int id = 0; id < vl->n_device; ++id) {
|
||||
auto& c = cache_copies[2*model.splits.size()*il + 2*id + 1];
|
||||
if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != vl->splits[id]) return false;
|
||||
c.cpy->view_offs = kv_self.head*c.step;
|
||||
c.cpy->src[1]->data = (char *)vl->splits[id]->data + c.cpy->view_offs;
|
||||
c.cpy->data = c.cpy->src[1]->data;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
auto& c = cache_copies[2*il+0];
|
||||
if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kv_self.k_l[il]) return false;
|
||||
c.cpy->view_offs = kv_self.head*c.step;
|
||||
c.cpy->src[1]->data = (char *)kv_self.k_l[il]->data + c.cpy->view_offs;
|
||||
c.cpy->data = c.cpy->src[1]->data;
|
||||
}
|
||||
if (kv_self.v_l.empty()) return true;
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
auto& c = cache_copies[2*il+1];
|
||||
if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kv_self.v_l[il]) return false;
|
||||
c.cpy->view_offs = kv_self.head*c.step;
|
||||
c.cpy->src[1]->data = (char *)kv_self.v_l[il]->data + c.cpy->view_offs;
|
||||
c.cpy->data = c.cpy->src[1]->data;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@@ -572,7 +606,11 @@ bool llama_context::update_cache_copies() {
|
||||
llama_context::llama_context(const llama_model & model)
|
||||
: model(model) , sampling(llama_n_vocab(&model)) , t_start_us(model.t_start_us) , t_load_us(model.t_load_us) {
|
||||
const auto & hparams = model.hparams;
|
||||
cache_copies.resize(2*hparams.n_layer);
|
||||
if ((model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) && model.splits.size() > 1) {
|
||||
cache_copies.resize(2*model.splits.size()*hparams.n_layer);
|
||||
} else {
|
||||
cache_copies.resize(2*hparams.n_layer);
|
||||
}
|
||||
}
|
||||
|
||||
llama_context::~llama_context() {
|
||||
@@ -626,42 +664,35 @@ static bool llama_kv_cache_init(
|
||||
}
|
||||
}
|
||||
|
||||
bool split_cache = false;
|
||||
if ((model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) && model.arch != LLM_ARCH_DEEPSEEK2 && offload) {
|
||||
cache.split_k_l.reserve(n_layer);
|
||||
cache.split_v_l.reserve(n_layer);
|
||||
split_cache = true;
|
||||
}
|
||||
|
||||
// count used buffer types
|
||||
std::map<ggml_backend_buffer_type_t, int> buft_layer_count;
|
||||
if (offload) {
|
||||
for (int64_t i = 0; i < n_layer; ++i) {
|
||||
buft_layer_count[model.buft_layer[i].buft]++;
|
||||
if (split_cache) {
|
||||
buft_layer_count[model.buft_layer[i].buft_matrix]++;
|
||||
} else {
|
||||
buft_layer_count[model.buft_layer[i].buft]++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
buft_layer_count[llama_default_buffer_type_cpu(true)] = n_layer;
|
||||
}
|
||||
|
||||
//if (cparams.fused_moe_up_gate) {
|
||||
// int nbad = 0;
|
||||
// for (int i = 0; i < (int) n_layer; i++) {
|
||||
// auto& layer = model.layers[i];
|
||||
// if (layer.ffn_gate_exps && layer.ffn_up_exps && layer.ffn_gate_exps->type != layer.ffn_up_exps->type) {
|
||||
// ++nbad;
|
||||
// }
|
||||
// }
|
||||
// if (nbad > 0) {
|
||||
// if (nbad == (int)n_layer) {
|
||||
// LLAMA_LOG_WARN("=============== ffn_up and ffn_gate are of different type => disabling fmoe\n");
|
||||
// const_cast<llama_cparams&>(cparams).fused_moe_up_gate = false;
|
||||
// }
|
||||
// else {
|
||||
// LLAMA_LOG_WARN("=============== ffn_up and ffn_gate are of different in %d out of %d layers, where fmoe will be disabled\n",
|
||||
// nbad, (int)n_layer);
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
|
||||
// create a context for each buffer type
|
||||
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
||||
for (auto & it : buft_layer_count) {
|
||||
int n_layers = it.second;
|
||||
size_t ctx_mem_size = 5u*n_layers*ggml_tensor_overhead();
|
||||
if (split_cache) ctx_mem_size += 2*model.splits.size()*n_layers*ggml_tensor_overhead();
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ 5u*n_layers*ggml_tensor_overhead(),
|
||||
/*.mem_size =*/ ctx_mem_size,
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
@@ -698,24 +729,25 @@ static bool llama_kv_cache_init(
|
||||
}
|
||||
}
|
||||
|
||||
cache.k_l.reserve(n_layer);
|
||||
bool needs_v_cache = true;
|
||||
cache.k_l.reserve(n_layer);
|
||||
if (model.arch == LLM_ARCH_DEEPSEEK2 && cparams.mla_attn) {
|
||||
needs_v_cache = cparams.mla_attn == 1 && !cparams.flash_attn;
|
||||
}
|
||||
if (needs_v_cache) cache.v_l.reserve(n_layer);
|
||||
|
||||
std::vector<size_t> mem_split(model.splits.size(), 0);
|
||||
|
||||
int n_mla = 0;
|
||||
for (int i = 0; i < (int) n_layer; i++) {
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
|
||||
const uint32_t n_head_kv = hparams.n_head_kv(i);
|
||||
const uint32_t n_embd_head_k= hparams.n_embd_head_k;
|
||||
|
||||
|
||||
struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
|
||||
struct ggml_context * ctx = split_cache ? ctx_map.at(model.buft_layer[i].buft_matrix) : offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
|
||||
ggml_tensor * k;
|
||||
ggml_tensor * v;
|
||||
if (cparams.mla_attn) {
|
||||
if (model.arch == LLM_ARCH_DEEPSEEK2 && cparams.mla_attn) {
|
||||
// DeepSeek MLA
|
||||
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const uint32_t kv_lora_rank = hparams.n_lora_kv;
|
||||
@@ -740,10 +772,53 @@ static bool llama_kv_cache_init(
|
||||
else {
|
||||
k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size);
|
||||
v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
|
||||
ggml_format_name(k, "cache_k_l%d", i);
|
||||
ggml_format_name(v, "cache_v_l%d", i);
|
||||
auto k_name = std::string{"cache_k_l"} + std::to_string(i);
|
||||
auto v_name = std::string{"cache_v_l"} + std::to_string(i);
|
||||
ggml_set_name(k, k_name.c_str());
|
||||
ggml_set_name(v, v_name.c_str());
|
||||
//ggml_format_name(k, "cache_k_l%d", i);
|
||||
//ggml_format_name(v, "cache_v_l%d", i);
|
||||
cache.k_l.push_back(k);
|
||||
cache.v_l.push_back(v);
|
||||
if (split_cache) {
|
||||
auto K = model.layers[i].wk;
|
||||
auto V = model.layers[i].wv;
|
||||
if (K && V && K->extra && V->extra) {
|
||||
auto extra_K = (const ggml_split_tensor_t *)K->extra;
|
||||
auto extra_V = (const ggml_split_tensor_t *)V->extra;
|
||||
auto & split_k_l = cache.split_k_l.emplace_back();
|
||||
auto & split_v_l = cache.split_v_l.emplace_back();
|
||||
split_k_l.tensor_splits.resize(extra_K->n_device, nullptr);
|
||||
split_v_l.tensor_splits.resize(extra_V->n_device, nullptr);
|
||||
for (int is = 0; is < extra_K->n_device; ++is) {
|
||||
auto split = extra_K->splits[is];
|
||||
if (!split) continue;
|
||||
split_k_l.tensor_splits[is] = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, split->ne[1]/n_embd_head_k * kv_size);
|
||||
auto split_name = k_name + '.' + std::to_string(is);
|
||||
ggml_set_name(split_k_l.tensor_splits[is], split_name.c_str());
|
||||
mem_split[is] += ggml_nbytes(split_k_l.tensor_splits[is]);
|
||||
}
|
||||
split_k_l.ggml.n_device = extra_K->n_device;
|
||||
split_k_l.ggml.split_dim = 0;
|
||||
split_k_l.ggml.splits = split_k_l.tensor_splits.data();
|
||||
for (int is = 0; is < extra_V->n_device; ++is) {
|
||||
auto split = extra_V->splits[is];
|
||||
if (!split) continue;
|
||||
split_v_l.tensor_splits[is] = ggml_new_tensor_1d(ctx, type_v, split->ne[1] * kv_size);
|
||||
auto split_name = v_name + '.' + std::to_string(is);
|
||||
ggml_set_name(split_v_l.tensor_splits[is], split_name.c_str());
|
||||
mem_split[is] += ggml_nbytes(split_v_l.tensor_splits[is]);
|
||||
}
|
||||
split_v_l.ggml.n_device = extra_V->n_device;
|
||||
split_v_l.ggml.split_dim = 0;
|
||||
split_v_l.ggml.splits = split_v_l.tensor_splits.data();
|
||||
k->extra = (void *)&split_k_l.ggml;
|
||||
v->extra = (void *)&split_v_l.ggml;
|
||||
}
|
||||
//} else {
|
||||
// printf("Oops: don't have yet K and V for layer %d\n", i);
|
||||
//}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (model.arch == LLM_ARCH_DEEPSEEK2 && cparams.mla_attn && n_mla < n_layer && n_mla > 0) {
|
||||
@@ -756,15 +831,46 @@ static bool llama_kv_cache_init(
|
||||
for (auto it : ctx_map) {
|
||||
ggml_backend_buffer_type_t buft = it.first;
|
||||
ggml_context * ctx = it.second;
|
||||
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
|
||||
if (!buf) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__);
|
||||
return false;
|
||||
int ntensor = 0;
|
||||
for (auto t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
++ntensor;
|
||||
}
|
||||
if (ntensor > 0) {
|
||||
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
|
||||
if (!buf) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__);
|
||||
return false;
|
||||
}
|
||||
ggml_backend_buffer_clear(buf, 0);
|
||||
LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
|
||||
cache.bufs.push_back(buf);
|
||||
}
|
||||
ggml_backend_buffer_clear(buf, 0);
|
||||
LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
|
||||
cache.bufs.push_back(buf);
|
||||
}
|
||||
if (split_cache) {
|
||||
LLAMA_LOG_INFO("%s: KV cache size per device:\n", __func__);
|
||||
for (int i = 0; i < int(mem_split.size()); ++i) printf(" Device %d: %g MiB\n", i, mem_split[i]/1024./1024.);
|
||||
}
|
||||
|
||||
#if 0
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
if (cache.k_l[il]->extra) {
|
||||
printf("Layer %2d, K-buffer: %p:", il, (void *)cache.k_l[il]->buffer);
|
||||
auto split_kl = (ggml_split_tensor_t *)cache.k_l[il]->extra;
|
||||
for (int id = 0; id < split_kl->n_device; ++id) {
|
||||
if (split_kl->splits[id]) printf(" %p,%p", (void *)split_kl->splits[id]->data, (void *)split_kl->splits[id]->buffer);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
if (cache.v_l[il]->extra) {
|
||||
printf("Layer %2d, V-buffer: %p:", il, (void *)cache.v_l[il]->buffer);
|
||||
auto split_vl = (ggml_split_tensor_t *)cache.v_l[il]->extra;
|
||||
for (int id = 0; id < split_vl->n_device; ++id) {
|
||||
if (split_vl->splits[id]) printf(" %p,%p", (void *)split_vl->splits[id]->data, (void *)split_vl->splits[id]->buffer);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -1617,6 +1723,16 @@ static void ggml_backend_add_from_device(llama_context* ctx, ggml_backend_t back
|
||||
}
|
||||
}
|
||||
|
||||
static bool is_model_split_supported(const llama_model & model) {
|
||||
static std::unordered_set<llm_arch> k_supported = {
|
||||
LLM_ARCH_LLAMA,
|
||||
LLM_ARCH_QWEN3MOE,
|
||||
LLM_ARCH_GLM4_MOE,
|
||||
};
|
||||
auto it = k_supported.find(model.arch);
|
||||
return it != k_supported.end();
|
||||
}
|
||||
|
||||
// Returns false if cancelled by progress_callback
|
||||
static bool llm_load_tensors(
|
||||
llama_model_loader & ml,
|
||||
@@ -1634,6 +1750,16 @@ static bool llm_load_tensors(
|
||||
|
||||
auto & hparams = model.hparams;
|
||||
|
||||
if (split_mode == LLAMA_SPLIT_MODE_GRAPH || split_mode == LLAMA_SPLIT_MODE_ATTN) {
|
||||
if (!is_model_split_supported(model)) {
|
||||
LLAMA_LOG_WARN("\n=======================================================\n");
|
||||
LLAMA_LOG_WARN("Split mode 'graph' is not supported for this model\n");
|
||||
LLAMA_LOG_WARN(" => changing split mode to 'layer'\n");
|
||||
LLAMA_LOG_WARN("=======================================================\n\n");
|
||||
split_mode = LLAMA_SPLIT_MODE_LAYER;
|
||||
}
|
||||
}
|
||||
|
||||
model.split_mode = split_mode;
|
||||
model.main_gpu = main_gpu;
|
||||
model.n_gpu_layers = n_gpu_layers;
|
||||
@@ -1652,10 +1778,7 @@ static bool llm_load_tensors(
|
||||
model.buft_layer[i] = llama_default_buffer_type_cpu(true);
|
||||
}
|
||||
|
||||
if (split_mode == LLAMA_SPLIT_MODE_LAYER) {
|
||||
// calculate the split points
|
||||
// int device_count = llama_get_device_count(model);
|
||||
int device_count = model.devices.size();
|
||||
if (int device_count = model.devices.size(); device_count > 1) {
|
||||
bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + device_count, [](float x) { return x == 0.0f; });
|
||||
std::vector<float> splits(device_count);
|
||||
if (all_zero) {
|
||||
@@ -1676,46 +1799,47 @@ static bool llm_load_tensors(
|
||||
for (int i = 0; i < device_count; ++i) {
|
||||
splits[i] /= split_sum;
|
||||
}
|
||||
model.splits = std::move(splits);
|
||||
} else {
|
||||
model.splits = { 1.0f };
|
||||
}
|
||||
|
||||
int device_count = model.splits.size();
|
||||
// assign the repeating layers to the devices according to the splits
|
||||
int act_gpu_layers = std::min(n_gpu_layers, (int)n_layer + 1);
|
||||
if (split_mode == LLAMA_SPLIT_MODE_LAYER) {
|
||||
|
||||
// assign the repeating layers to the devices according to the splits
|
||||
int act_gpu_layers = std::min(n_gpu_layers, (int)n_layer + 1);
|
||||
for (int i = i_gpu_start; i < n_layer; ++i) {
|
||||
int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(i - i_gpu_start)/act_gpu_layers) - splits.begin();
|
||||
#ifndef NDEBUG
|
||||
ggml_backend_buffer_type_t buft = llama_default_buffer_type_offload(model, model.devices[layer_gpu]);
|
||||
const char* name = ggml_backend_buft_name(buft);
|
||||
LLAMA_LOG_DEBUG("load_tensors: layers %3d assigned to backend %s\n", i,
|
||||
name);
|
||||
#endif
|
||||
int layer_gpu = std::upper_bound(model.splits.begin(), model.splits.begin() + device_count, float(i - i_gpu_start)/act_gpu_layers) - model.splits.begin();
|
||||
model.buft_layer[i] = llama_default_buffer_type_offload(model, model.devices[layer_gpu]);
|
||||
}
|
||||
// assign the output layer
|
||||
if (n_gpu_layers > n_layer) {
|
||||
int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(act_gpu_layers - 1)/act_gpu_layers) - splits.begin();
|
||||
#ifndef NDEBUG
|
||||
ggml_backend_buffer_type_t buft = llama_default_buffer_type_offload(model, model.devices[layer_gpu]);
|
||||
const char* name = ggml_backend_buft_name(buft);
|
||||
LLAMA_LOG_DEBUG("load_tensors: output layers assigned to backend %s\n",
|
||||
name);
|
||||
#endif
|
||||
int layer_gpu = std::upper_bound(model.splits.begin(), model.splits.begin() + device_count, float(act_gpu_layers - 1)/act_gpu_layers) - model.splits.begin();
|
||||
model.buft_output = llama_default_buffer_type_offload(model, model.devices[layer_gpu]);
|
||||
} else {
|
||||
model.buft_output = llama_default_buffer_type_cpu(true);
|
||||
}
|
||||
} else {
|
||||
ggml_backend_buffer_type_t split_buft;
|
||||
if (split_mode == LLAMA_SPLIT_MODE_ROW) {
|
||||
split_buft = llama_default_buffer_type_split(model, model.devices[main_gpu], tensor_split);
|
||||
if ((split_mode == LLAMA_SPLIT_MODE_GRAPH || split_mode == LLAMA_SPLIT_MODE_ATTN) && model.splits.size() > 1) {
|
||||
split_buft = llama_default_buffer_type_split(model, model.devices[main_gpu]);
|
||||
model.split_buft = split_buft;
|
||||
} else {
|
||||
// LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_LAYER in backends where it is not supported
|
||||
split_buft = llama_default_buffer_type_offload(model, model.devices[main_gpu]);
|
||||
}
|
||||
auto buft_layer = llama_default_buffer_type_offload(model, model.devices[main_gpu]);
|
||||
// assign the repeating layers
|
||||
for (int i = i_gpu_start; i < n_layer; ++i) {
|
||||
model.buft_layer[i] = {
|
||||
split_buft,
|
||||
llama_default_buffer_type_offload(model, model.devices[main_gpu])
|
||||
};
|
||||
if (split_mode == LLAMA_SPLIT_MODE_ATTN) {
|
||||
int layer_gpu = std::upper_bound(model.splits.begin(), model.splits.begin() + device_count,
|
||||
float(i - i_gpu_start)/act_gpu_layers) - model.splits.begin();
|
||||
model.buft_layer[i] = { split_buft, llama_default_buffer_type_offload(model, model.devices[layer_gpu]) };
|
||||
printf("Layer %d: assigning buft_layer to GPU %d\n", i, layer_gpu);
|
||||
} else {
|
||||
model.buft_layer[i] = { split_buft, buft_layer };
|
||||
}
|
||||
}
|
||||
// assign the output layer
|
||||
if (n_gpu_layers > n_layer) {
|
||||
@@ -1807,24 +1931,33 @@ static bool llm_load_tensors(
|
||||
}
|
||||
#endif
|
||||
else {
|
||||
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
|
||||
if (buf == nullptr) {
|
||||
throw std::runtime_error("unable to allocate backend buffer");
|
||||
int ntensor = 0;
|
||||
for (auto t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
++ntensor;
|
||||
}
|
||||
model.bufs.push_back(buf);
|
||||
if (use_mlock && ggml_backend_buffer_is_host(buf)) {
|
||||
model.mlock_bufs.emplace_back(new llama_mlock);
|
||||
auto & mlock_buf = model.mlock_bufs.back();
|
||||
mlock_buf->init (ggml_backend_buffer_get_base(buf));
|
||||
mlock_buf->grow_to(ggml_backend_buffer_get_size(buf));
|
||||
}
|
||||
for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
|
||||
bufs.emplace(idx, buf);
|
||||
if (ntensor > 0) {
|
||||
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
|
||||
if (buf == nullptr) {
|
||||
LLAMA_LOG_ERROR("Failed to allocate buffer type %s\n", ggml_backend_buft_name(buft));
|
||||
throw std::runtime_error("unable to allocate backend buffer");
|
||||
}
|
||||
model.bufs.push_back(buf);
|
||||
if (use_mlock && ggml_backend_buffer_is_host(buf)) {
|
||||
model.mlock_bufs.emplace_back(new llama_mlock);
|
||||
auto & mlock_buf = model.mlock_bufs.back();
|
||||
mlock_buf->init (ggml_backend_buffer_get_base(buf));
|
||||
mlock_buf->grow_to(ggml_backend_buffer_get_size(buf));
|
||||
}
|
||||
for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
|
||||
bufs.emplace(idx, buf);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (bufs.empty()) {
|
||||
throw std::runtime_error("failed to allocate buffer");
|
||||
LLAMA_LOG_WARN("No tensors in buffer type %s\n", ggml_backend_buft_name(buft));
|
||||
continue;
|
||||
//throw std::runtime_error("failed to allocate buffer (1)");
|
||||
}
|
||||
|
||||
for (auto & buf : bufs) {
|
||||
@@ -4326,8 +4459,8 @@ struct llama_context * llama_new_context_with_model(
|
||||
ggml_backend_add_from_device(ctx, ctx->backend_metal);
|
||||
}
|
||||
#elif defined(GGML_USE_CUDA)
|
||||
if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
|
||||
// with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
|
||||
if (model->split_mode == LLAMA_SPLIT_MODE_NONE) {
|
||||
// with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_GRAPH, only the main GPU backend is used
|
||||
ggml_backend_t backend = ggml_backend_cuda_init(model->main_gpu, cparams.cuda_params);
|
||||
if (backend == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize CUDA%d backend\n", __func__, model->main_gpu);
|
||||
@@ -4337,7 +4470,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
ggml_backend_add_from_device(ctx, backend);
|
||||
|
||||
} else {
|
||||
// LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
|
||||
// LLAMA_SPLIT_MODE_LAYER and LLAMA_SPLIT_MODE_GRAPH require a backend for each GPU
|
||||
for (int device = 0; device < ggml_backend_cuda_get_device_count(); ++device) {
|
||||
ggml_backend_t backend = ggml_backend_cuda_init(device, cparams.cuda_params);
|
||||
if (backend == nullptr) {
|
||||
@@ -4346,12 +4479,11 @@ struct llama_context * llama_new_context_with_model(
|
||||
return nullptr;
|
||||
}
|
||||
ggml_backend_add_from_device(ctx, backend);
|
||||
|
||||
}
|
||||
}
|
||||
#elif defined(GGML_USE_VULKAN)
|
||||
if (model->split_mode == LLAMA_SPLIT_MODE_ROW) {
|
||||
LLAMA_LOG_ERROR("%s: Row split not supported. Failed to initialize Vulkan backend\n", __func__);
|
||||
if (model->split_mode == LLAMA_SPLIT_MODE_GRAPH || model->split_mode == LLAMA_SPLIT_MODE_ATTN) {
|
||||
LLAMA_LOG_ERROR("%s: split mode 'graph' or 'attn' not supported. Failed to initialize Vulkan backend\n", __func__);
|
||||
llama_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
@@ -4375,8 +4507,8 @@ struct llama_context * llama_new_context_with_model(
|
||||
}
|
||||
}
|
||||
#elif defined(GGML_USE_SYCL)
|
||||
// with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
|
||||
if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
|
||||
// with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_GRAPH, only the main GPU backend is used
|
||||
if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_GRAPH) {
|
||||
ggml_backend_t backend = ggml_backend_sycl_init(model->main_gpu);
|
||||
if (backend == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d backend\n", __func__, model->main_gpu);
|
||||
@@ -4407,9 +4539,9 @@ struct llama_context * llama_new_context_with_model(
|
||||
ggml_backend_add_from_device(ctx, backend);
|
||||
}
|
||||
#elif defined(GGML_USE_CANN)
|
||||
// with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
|
||||
// with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_GRAPH, only the main GPU backend is used
|
||||
// TODO: ggml_backend_cann is not support split tensor now, just leave code here.
|
||||
if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
|
||||
if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_GRAPH) {
|
||||
ggml_backend_t backend = ggml_backend_cann_init(model->main_gpu);
|
||||
if (backend == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize CANN%d backend\n", __func__, model->main_gpu);
|
||||
|
||||
Reference in New Issue
Block a user