mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 03:11:51 +00:00
POC: CUDA tensor parallel (MoE models) (#1022)
* Remove most of split mode row
* WIP
* WIP: also allocate the KV cache using tensor split
* WIP: it runs with wrong result
But it also looks like the backend scheduler is not going to help:
* It copies mask and input positions to GPU 0
* => RoPE ops must run on GPU 0
* => To proceed attn evaluation, GPU 1 must wait for GPU 0 to finish its
entire attn calculation
* Same with FFN. The rms_norm gets scheduled on GPU 0. Hence, GPU 1 must
wait for GPU 0 to finish its entore FFN calculation before it can
start (as it needs to copy the result of rms_norm from GPU 0)
* => Seems useless without writing a bespoke TP scheduling
* WIP
* This works, but it is slow
* This is slightly better
the graph is still not being computed in parallel.
Why? Because the scheduler creates graph splits where the
result of the computation on one GPU becomes an input for the
other split. Hence, to trigger the computation on the second GPU
one needs to wait for the computation on the first GPU to finish,
even thiough the two can be done in parallel up to the sunchronization
point. So, all that is left to do is to trick the scheduler to create
to splits that can be done in parallel, and then have a graph split
where the results get combined.
* Playing games with the scheduler
This change tricks it into doing the right thing^TM.
Still quite a bit slower than split mode layer for the 8B LlaMA model.
But for the 70B LlaMA it now beats split mode layer for TG:
28 t/s vs 24.4 t/s. PP is 627 t/s vs 744 t/s.
In comparison, split mode "row" in mainline gets
484 t/s PP and 19.3 t/s TG.
* Fix attn split
Granularity for Wq, Wo is not just head size, but
head size * gqa_ratio.
Else the Wk, Wv tensors end up not being a multiple of the
head size when we divide the split determined by Wo with
the gqa_ratio.
* Show memory used per device
* Make it work with partial offload
but no tensor overrides yet, just ngl < num_layers.
* Allow for f16 source in fused_rms_norm
* This results in faster PP.
Now PP is faster than split mode layer for L3-70B.
* Rename split mode "row" to split mode "graph"
* Leave FFN partial results as f16
* WIP GLM4.5 - runs with wrong results
* WIP GLM4.5 - this works
PP is already better than split mode layer, but TG for zero context
is kind of low - 60 vs 92 t/s. TG becomes better than split mode layer
at around 20k tokens. PP at 26k tokens is 1.55X of sm layer.
* Work around compiler bug
It issues a warning that there is an extra semicolon outside of a function,
but there isn't. If I remove the anonymous namespace and turn the
functions inside into static, the warning disapears, so clearly
a compiler bug.
* Make graph reuse work with split mode graph
* Remove more split mode row remnants
* WIP tensor overrides
Runs with wrong results, don't see where the issue could be.
* This works but is slow
Still does not work for row-interleaved quants
* Slightly better
* Slightly better
* Row-interleaved quants work
* Better
* Minor
* Guarad against using split mode "graph" for unsupported models
* Guards against using merge_qkv with split mode "graph"
* WIP split mode attn
Works for LlaMA models, but not for GLM-4.5.
Doesn't seem to improve performance, so I guess no point in trying to
fix it.
* Split mode graph for qwen3moe
* Try to better distribute the splits
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -10,6 +10,7 @@
|
||||
#include <array>
|
||||
#include <future>
|
||||
#include <regex>
|
||||
#include <unordered_set>
|
||||
|
||||
#define LLAMA_API_INTERNAL
|
||||
|
||||
@@ -139,7 +140,7 @@ struct create_tensors_helper : public create_tensors_helper_interface {
|
||||
ggml_context ** actual_ctx = nullptr);
|
||||
|
||||
void create_default_embd_output(const LLM_TN & tn, int n_embd, int n_vocab, bool norm_bias);
|
||||
void create_embd_output(const LLM_TN & tn, int n_embd, int n_vocab, bool has_norm = true);
|
||||
void create_embd_output(const LLM_TN & tn, int n_embd, int n_vocab, bool has_norm = true, bool use_ctx_split = false);
|
||||
|
||||
void create_std_attn(int i, const LLM_TN & tn, llama_layer & layer, int n_embd, int n_embd_gqa, ggml_context * ctx_split);
|
||||
void create_std_ffn(int i, const LLM_TN & tn, llama_layer & layer, int n_ff, int n_embd, ggml_context * ctx_split);
|
||||
@@ -153,12 +154,15 @@ struct create_tensors_helper : public create_tensors_helper_interface {
|
||||
|
||||
std::map<ggml_backend_buffer_type_t, int> buft_layer_count;
|
||||
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
||||
ggml_context * split_ctx = nullptr;
|
||||
size_t ctx_size;
|
||||
|
||||
ggml_context * ctx_input;
|
||||
ggml_context * ctx_output;
|
||||
ggml_context * ctx_output_split;
|
||||
|
||||
std::unordered_set<ggml_tensor *> split_tensors;
|
||||
|
||||
inline ggml_context * ctx_for_buft(ggml_backend_buffer_type_t buft) {
|
||||
if (auto it = ctx_map.find(buft); it != ctx_map.end()) return it->second;
|
||||
|
||||
@@ -179,6 +183,14 @@ struct create_tensors_helper : public create_tensors_helper_interface {
|
||||
|
||||
create_tensors_helper::create_tensors_helper(llama_model_loader & _ml, llama_model & _model) : ml(_ml), model(_model) {
|
||||
|
||||
#if 0
|
||||
for (int i = 0; i < model.hparams.n_layer; ++i) {
|
||||
printf("Layer %2d: %s %s\n", i, ggml_backend_buft_name(model.buft_layer[i].buft_matrix), ggml_backend_buft_name(model.buft_layer[i].buft));
|
||||
}
|
||||
printf("Output: %s %s\n", ggml_backend_buft_name(model.buft_output.buft_matrix), ggml_backend_buft_name(model.buft_output.buft));
|
||||
printf(" Input: %s %s\n", ggml_backend_buft_name(model.buft_input.buft_matrix), ggml_backend_buft_name(model.buft_input.buft));
|
||||
#endif
|
||||
|
||||
const int n_layer = model.hparams.n_layer;
|
||||
buft_layer_count[model.buft_input.buft]++;
|
||||
buft_layer_count[model.buft_input.buft_matrix]++;
|
||||
@@ -192,6 +204,11 @@ create_tensors_helper::create_tensors_helper(llama_model_loader & _ml, llama_mod
|
||||
ctx_size = ggml_tensor_overhead()*(ml.n_tensors + 1); // +1 for models where tok_embd is duplicated as output
|
||||
ctx_size += ggml_tensor_overhead()*n_layer*3; // for moe merged tensors
|
||||
|
||||
if (model.splits.size() > 1) {
|
||||
ctx_size += ggml_tensor_overhead()*n_layer*4; // for KV cache
|
||||
ctx_size *= (model.splits.size() + 1);
|
||||
}
|
||||
|
||||
for (auto & it : buft_layer_count) {
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ ctx_size,
|
||||
@@ -205,10 +222,95 @@ create_tensors_helper::create_tensors_helper(llama_model_loader & _ml, llama_mod
|
||||
ctx_map[it.first] = ctx;
|
||||
model.ctxs.push_back(ctx);
|
||||
}
|
||||
if (model.split_buft) {
|
||||
if (auto it = ctx_map.find(model.split_buft); it != ctx_map.end()) {
|
||||
split_ctx = it->second;
|
||||
}
|
||||
}
|
||||
#if 0
|
||||
printf("=======================================================================\n");
|
||||
auto n_device = model.device_count();
|
||||
printf(" Model has %d devices:\n", n_device);
|
||||
for (int device = 0; device < n_device; ++device) {
|
||||
auto buft = model.default_buffer_type_offload(device);
|
||||
if (buft) {
|
||||
printf(" %d %s\n", device, ggml_backend_buft_name(buft));
|
||||
} else {
|
||||
printf(" Oops: null buft for debvice %d\n", device);
|
||||
}
|
||||
}
|
||||
if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH) {
|
||||
printf("model.splits:");
|
||||
for (auto s : model.splits) printf(" %g", s);
|
||||
printf("\n");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
static std::vector<int> create_split(int nr, int granularity, const std::vector<float> & splits, const std::vector<size_t> & mem_used) {
|
||||
GGML_ASSERT(nr % granularity == 0);
|
||||
GGML_ASSERT(!splits.empty());
|
||||
if (granularity < 0) return std::vector<int>(splits.size(), nr);
|
||||
GGML_ASSERT(mem_used.size() == splits.size());
|
||||
size_t tot_memory_used = 1;
|
||||
for (auto & mem : mem_used) tot_memory_used += mem;
|
||||
int nchunk = nr / granularity;
|
||||
std::vector<int> result(splits.size());
|
||||
float last_split = 0;
|
||||
int sum = 0;
|
||||
for (int i = 0; i < (int)splits.size(); ++i) {
|
||||
float p = splits[i] - last_split;
|
||||
p += (p - 1.f*mem_used[i]/tot_memory_used);
|
||||
result[i] = roundf(p*nchunk);
|
||||
if (result[i] < 0) result[i] = 0;
|
||||
sum += result[i];
|
||||
last_split = splits[i];
|
||||
}
|
||||
while (sum > nchunk) {
|
||||
last_split = 0;
|
||||
float best_err = std::numeric_limits<float>::max();
|
||||
int ibest = -1;
|
||||
for (int i = 0; i < (int)splits.size(); ++i) {
|
||||
if (result[i] > 0) {
|
||||
float p = splits[i] - last_split;
|
||||
float n_want = p*nchunk;
|
||||
float err = std::abs(n_want - result[i] + 1);
|
||||
//float err = std::abs(n_want - result[i] + 1) + std::abs(p - 1.f*mem_used[i]/tot_memory_used)*nchunk;
|
||||
if (err < best_err) {
|
||||
best_err = err; ibest = i;
|
||||
}
|
||||
}
|
||||
last_split = splits[i];
|
||||
}
|
||||
GGML_ASSERT(ibest >= 0 && result[ibest] > 0);
|
||||
--result[ibest];
|
||||
--sum;
|
||||
}
|
||||
while (sum < nchunk) {
|
||||
last_split = 0;
|
||||
float best_err = std::numeric_limits<float>::max();
|
||||
int ibest = -1;
|
||||
for (int i = 0; i < (int)splits.size(); ++i) {
|
||||
float p = splits[i] - last_split;
|
||||
float n_want = p*nchunk;
|
||||
float err = std::abs(n_want - result[i] - 1);
|
||||
//float err = std::abs(n_want - result[i] - 1) + std::abs(p - 1.f*mem_used[i]/tot_memory_used)*nchunk;
|
||||
if (err < best_err) {
|
||||
best_err = err; ibest = i;
|
||||
}
|
||||
last_split = splits[i];
|
||||
}
|
||||
GGML_ASSERT(ibest >= 0);
|
||||
++result[ibest];
|
||||
++sum;
|
||||
}
|
||||
for (auto & r : result) r *= granularity;
|
||||
return result;
|
||||
}
|
||||
|
||||
ggml_tensor * create_tensors_helper::create_tensor(ggml_context * ctx, const std::string & name, const std::vector<int64_t> & ne,
|
||||
int flags, ggml_context ** actual_context) {
|
||||
//auto requested_ctx = ctx;
|
||||
if (ml.tensor_buft_overrides) {
|
||||
for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) {
|
||||
std::regex pattern(overrides->pattern);
|
||||
@@ -220,7 +322,12 @@ ggml_tensor * create_tensors_helper::create_tensor(ggml_context * ctx, const std
|
||||
}
|
||||
}
|
||||
if (actual_context) *actual_context = ctx;
|
||||
return ml.create_tensor(ctx, name, ne, flags);
|
||||
auto tensor = ml.create_tensor(ctx, name, ne, flags);
|
||||
if (tensor && ctx == split_ctx) {
|
||||
//printf("%s: adding tensor %s to split tensors\n", __func__, tensor->name);
|
||||
split_tensors.insert(tensor);
|
||||
}
|
||||
return tensor;
|
||||
}
|
||||
|
||||
#define LOADING_PRELUDE \
|
||||
@@ -251,17 +358,18 @@ ggml_tensor * create_tensors_helper::create_tensor(ggml_context * ctx, const std
|
||||
bool use_mmap_buffer = true;
|
||||
|
||||
|
||||
void create_tensors_helper::create_embd_output(const LLM_TN & tn, int n_embd, int n_vocab, bool has_norm) {
|
||||
void create_tensors_helper::create_embd_output(const LLM_TN & tn, int n_embd, int n_vocab, bool has_norm, bool use_ctx_split) {
|
||||
model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||
|
||||
auto out_ctx = use_ctx_split ? ctx_output_split : ctx_output;
|
||||
if (has_norm) {
|
||||
model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||
model.output_norm = create_tensor(out_ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||
}
|
||||
model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
model.output = create_tensor(out_ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (model.output == NULL) {
|
||||
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
||||
model.output = create_tensor(out_ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -280,7 +388,7 @@ void create_tensors_helper::create_std_ffn(int i, const LLM_TN & tn, llama_layer
|
||||
|
||||
bool create_tensors_helper::create_llama_tensors(const LLM_TN & tn) {
|
||||
LOADING_PRELUDE
|
||||
create_embd_output(tn, n_embd, n_vocab);
|
||||
create_embd_output(tn, n_embd, n_vocab, true, true);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||
@@ -288,7 +396,7 @@ bool create_tensors_helper::create_llama_tensors(const LLM_TN & tn) {
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||
layer.attn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||
|
||||
use_mmap_buffer &= !merge_qkv(tn, i, 1);
|
||||
|
||||
@@ -297,12 +405,12 @@ bool create_tensors_helper::create_llama_tensors(const LLM_TN & tn) {
|
||||
// optional bias tensors
|
||||
layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
|
||||
layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||
layer.ffn_norm = create_tensor(model.split_mode == LLAMA_SPLIT_MODE_GRAPH ? ctx_split : ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||
|
||||
layer.rope_freqs = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
||||
layer.rope_freqs = create_tensor(ctx_split, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
||||
|
||||
if (n_expert == 0) {
|
||||
create_std_ffn(i, tn, layer, n_ff, n_embd, ctx_split);
|
||||
create_std_ffn(i, tn, layer, n_ff, n_embd, model.split_mode == LLAMA_SPLIT_MODE_GRAPH ? ctx_split : ctx_layer);
|
||||
|
||||
// optional MLP bias
|
||||
layer.ffn_gate_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
@@ -1043,11 +1151,11 @@ bool create_tensors_helper::create_qwen3_moe_tensors(const LLM_TN & tn) {
|
||||
|
||||
// output
|
||||
{
|
||||
model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||
model.output_norm = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||
model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (model.output == NULL) {
|
||||
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
||||
model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1057,18 +1165,19 @@ bool create_tensors_helper::create_qwen3_moe_tensors(const LLM_TN & tn) {
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||
layer.attn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||
|
||||
use_mmap_buffer &= !merge_qkv(tn, i, 0);
|
||||
|
||||
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
|
||||
|
||||
layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k});
|
||||
layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k});
|
||||
layer.attn_k_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k});
|
||||
layer.attn_q_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k});
|
||||
|
||||
layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||
auto ffn_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH ? ctx_split : ctx_layer;
|
||||
layer.ffn_norm = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||
|
||||
layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
|
||||
layer.ffn_gate_inp = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
|
||||
|
||||
if (n_expert == 0) {
|
||||
throw std::runtime_error("n_expert must be > 0 for QWEN3MOE");
|
||||
@@ -1080,9 +1189,9 @@ bool create_tensors_helper::create_qwen3_moe_tensors(const LLM_TN & tn) {
|
||||
// MoE branch
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
|
||||
|
||||
layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert});
|
||||
layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert});
|
||||
layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert});
|
||||
layer.ffn_gate_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert});
|
||||
layer.ffn_down_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert});
|
||||
layer.ffn_up_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert});
|
||||
}
|
||||
return use_mmap_buffer;
|
||||
}
|
||||
@@ -1734,7 +1843,7 @@ bool create_tensors_helper::create_glm4_moe_tensors(const LLM_TN & tn) {
|
||||
GGML_ASSERT(hparams.n_expert > 0 && "n_expert must be > 0 for GLM4_MOE MoE layers");
|
||||
GGML_ASSERT(hparams.n_expert_used > 0 && "n_expert_used must be > 0 for GLM4_MOE MoE layers");
|
||||
|
||||
create_embd_output(tn, n_embd, n_vocab);
|
||||
create_embd_output(tn, n_embd, n_vocab, true, true);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||
@@ -1748,7 +1857,7 @@ bool create_tensors_helper::create_glm4_moe_tensors(const LLM_TN & tn) {
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags);
|
||||
layer.attn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags);
|
||||
|
||||
// GLM-style attention with bias terms
|
||||
if (!flags) {
|
||||
@@ -1765,12 +1874,17 @@ bool create_tensors_helper::create_glm4_moe_tensors(const LLM_TN & tn) {
|
||||
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags);
|
||||
|
||||
// K/Q norm tensors (optional for GLM-4.5 355B variant)
|
||||
layer.attn_q_norm = create_tensor(ctx_layer,
|
||||
layer.attn_q_norm = create_tensor(ctx_split,
|
||||
tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, llama_model_loader::TENSOR_NOT_REQUIRED | flags);
|
||||
layer.attn_k_norm = create_tensor(ctx_layer,
|
||||
layer.attn_k_norm = create_tensor(ctx_split,
|
||||
tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, llama_model_loader::TENSOR_NOT_REQUIRED | flags);
|
||||
|
||||
layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags);
|
||||
auto ffn_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH ? ctx_split : ctx_layer;
|
||||
|
||||
// Why are we adding an additional tensor type?
|
||||
// attn_post_norm is the exact same thing as ffn_norm
|
||||
//layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags);
|
||||
layer.ffn_norm = create_tensor(ffn_ctx, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags);
|
||||
|
||||
// Check if this layer uses MoE or dense FFN based on n_layer_dense_lead
|
||||
// GLM 4.5 uses hybrid architecture: layer 0 is dense, layers 1+ are MoE
|
||||
@@ -1778,35 +1892,35 @@ bool create_tensors_helper::create_glm4_moe_tensors(const LLM_TN & tn) {
|
||||
|
||||
if (use_moe) {
|
||||
// MoE layers
|
||||
layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, flags);
|
||||
layer.ffn_gate_inp = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, flags);
|
||||
// gate bias
|
||||
layer.ffn_exp_probs_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, flags);
|
||||
layer.ffn_exp_probs_b = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, flags);
|
||||
|
||||
// MoE branch
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
|
||||
|
||||
layer.ffn_gate_exps = create_tensor(ctx_split,
|
||||
layer.ffn_gate_exps = create_tensor(ffn_ctx,
|
||||
tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags);
|
||||
layer.ffn_down_exps = create_tensor(ctx_split,
|
||||
layer.ffn_down_exps = create_tensor(ffn_ctx,
|
||||
tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, flags);
|
||||
layer.ffn_up_exps = create_tensor(ctx_split,
|
||||
layer.ffn_up_exps = create_tensor(ffn_ctx,
|
||||
tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags);
|
||||
|
||||
// Shared expert
|
||||
if (n_expert_shared > 0) {
|
||||
const int64_t n_ff_shexp = n_ff_exp * n_expert_shared;
|
||||
layer.ffn_gate_shexp = create_tensor(ctx_split,
|
||||
layer.ffn_gate_shexp = create_tensor(ffn_ctx,
|
||||
tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags);
|
||||
layer.ffn_down_shexp = create_tensor(ctx_split,
|
||||
layer.ffn_down_shexp = create_tensor(ffn_ctx,
|
||||
tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, flags);
|
||||
layer.ffn_up_shexp = create_tensor(ctx_split,
|
||||
layer.ffn_up_shexp = create_tensor(ffn_ctx,
|
||||
tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags);
|
||||
}
|
||||
} else {
|
||||
// Dense layers (first k layers) - GLM uses separate gate/up projections
|
||||
layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, flags);
|
||||
layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, flags);
|
||||
layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, flags);
|
||||
layer.ffn_gate = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, flags);
|
||||
layer.ffn_down = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, flags);
|
||||
layer.ffn_up = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, flags);
|
||||
}
|
||||
// --- NextN / MTP tensors (preserved but unused), on the final layer ---
|
||||
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
|
||||
@@ -2629,18 +2743,77 @@ bool create_tensors_helper::merge_qkv(const LLM_TN & tn, int i, int bias, bool i
|
||||
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
|
||||
if (bias) {
|
||||
auto flags = bias == 1 ? llama_model_loader::TENSOR_NOT_REQUIRED : 0;
|
||||
layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {layer.wq->ne[1]}, flags);
|
||||
layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {layer.wk->ne[1]}, flags);
|
||||
layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {layer.wv->ne[1]}, flags);
|
||||
layer.bq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "bias", i), {layer.wq->ne[1]}, flags);
|
||||
layer.bk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "bias", i), {layer.wk->ne[1]}, flags);
|
||||
layer.bv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "bias", i), {layer.wv->ne[1]}, flags);
|
||||
}
|
||||
}
|
||||
|
||||
return fused_qkv;
|
||||
}
|
||||
|
||||
static void prepare_split_tensors(int split_dim, ggml_context * ctx, ggml_tensor * tensor, llama_split_tensor & split_tensor,
|
||||
const std::vector<int> & splits, std::vector<size_t> & mem_used) {
|
||||
GGML_ASSERT(split_dim <= 1);
|
||||
GGML_ASSERT(splits.size() > 1);
|
||||
std::string name{tensor->name};
|
||||
split_tensor.tensor_splits.resize(splits.size());
|
||||
if (split_dim < 0) {
|
||||
for (int i = 0; i < int(splits.size()); ++i) {
|
||||
if (splits[i] > 0) {
|
||||
split_tensor.tensor_splits[i] = ggml_new_tensor_3d(ctx, tensor->type, tensor->ne[0], tensor->ne[1], tensor->ne[2]);
|
||||
auto name_i = name + '.' + std::to_string(i);
|
||||
ggml_set_name(split_tensor.tensor_splits[i], name_i.c_str());
|
||||
} else {
|
||||
split_tensor.tensor_splits[i] = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (split_dim == 1) {
|
||||
for (int i = 0; i < int(splits.size()); ++i) {
|
||||
if (splits[i] > 0) {
|
||||
split_tensor.tensor_splits[i] = ggml_new_tensor_3d(ctx, tensor->type, tensor->ne[0], splits[i], tensor->ne[2]);
|
||||
auto name_i = name + '.' + std::to_string(i);
|
||||
ggml_set_name(split_tensor.tensor_splits[i], name_i.c_str());
|
||||
} else {
|
||||
split_tensor.tensor_splits[i] = nullptr;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < int(splits.size()); ++i) {
|
||||
if (splits[i] > 0) {
|
||||
split_tensor.tensor_splits[i] = ggml_new_tensor_3d(ctx, tensor->type, splits[i], tensor->ne[1], tensor->ne[2]);
|
||||
auto name_i = name + '.' + std::to_string(i);
|
||||
ggml_set_name(split_tensor.tensor_splits[i], name_i.c_str());
|
||||
} else {
|
||||
split_tensor.tensor_splits[i] = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
split_tensor.ggml.n_device = splits.size();
|
||||
split_tensor.ggml.split_dim = split_dim;
|
||||
split_tensor.ggml.splits = split_tensor.tensor_splits.data();
|
||||
tensor->extra = (void *)&split_tensor.ggml;
|
||||
GGML_ASSERT(mem_used.size() >= splits.size());
|
||||
for (int i = 0; i < split_tensor.ggml.n_device; ++i) {
|
||||
if (split_tensor.ggml.splits[i]) {
|
||||
//auto nbytes = ggml_nbytes(split_tensor.ggml.splits[i]);
|
||||
//printf("mem_used(%s): %8.2f, total: %8.2f\n", split_tensor.ggml.splits[i]->name, nbytes/1024./1024., (mem_used[i] + nbytes)/1024./1024.);
|
||||
mem_used[i] += ggml_nbytes(split_tensor.ggml.splits[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool create_tensors_helper::create_tensors() {
|
||||
const auto tn = LLM_TN(model.arch);
|
||||
bool use_mmap_buffer = true;
|
||||
if (ml.merge_qkv && (model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN)) {
|
||||
LLAMA_LOG_WARN("\n========================================================\n");
|
||||
LLAMA_LOG_WARN("merge_qkv is not compatible with split model 'graph'\n");
|
||||
LLAMA_LOG_WARN(" => turning off merge_qkv\n");
|
||||
LLAMA_LOG_WARN("========================================================\n\n");
|
||||
ml.merge_qkv = false;
|
||||
}
|
||||
switch (model.arch) {
|
||||
case LLM_ARCH_LLAMA:
|
||||
case LLM_ARCH_REFACT:
|
||||
@@ -2761,6 +2934,157 @@ bool create_tensors_helper::create_tensors() {
|
||||
default:
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) {
|
||||
std::vector<size_t> mem_used(model.splits.size(), 0);
|
||||
const auto & hparams = model.hparams;
|
||||
int gqa_ratio = hparams.n_head() / hparams.n_head_kv();
|
||||
//printf("GQA ratio: %d\n", gqa_ratio);
|
||||
for (int il = 0; il < int(model.layers.size()); ++il) {
|
||||
if (ggml_backend_buft_is_host(model.buft_layer[il].buft_matrix)) {
|
||||
LLAMA_LOG_INFO("%s: not splitting layer %d because buffer type is host\n", __func__, il);
|
||||
continue;
|
||||
}
|
||||
auto & layer = model.layers[il];
|
||||
auto ctx_split = ctx_for_layer_split(il);
|
||||
if (layer.attn_norm) {
|
||||
auto split = create_split(ggml_nrows(layer.attn_norm), -1, model.splits, mem_used);
|
||||
prepare_split_tensors(-1, ctx_split, layer.attn_norm, layer.split_attn_norm, split, mem_used);
|
||||
}
|
||||
if (layer.rope_freqs) {
|
||||
auto split = create_split(ggml_nrows(layer.rope_freqs), -1, model.splits, mem_used);
|
||||
prepare_split_tensors(-1, ctx_split, layer.rope_freqs, layer.split_rope_freqs, split, mem_used);
|
||||
}
|
||||
if (layer.wo && layer.wq && layer.wk && layer.wv) {
|
||||
int attn_granularity = hparams.n_embd_head_k * gqa_ratio;
|
||||
if (ggml_is_quantized(layer.wo->type)) {
|
||||
auto tt = ggml_internal_get_type_traits(layer.wo->type);
|
||||
if (tt.blck_size > attn_granularity) attn_granularity = tt.blck_size;
|
||||
}
|
||||
GGML_ASSERT(attn_granularity % hparams.n_embd_head_k == 0);
|
||||
auto split = create_split(layer.wo->ne[0], attn_granularity, model.splits, mem_used);
|
||||
prepare_split_tensors(0, ctx_split, layer.wo, layer.split_wo, split, mem_used);
|
||||
prepare_split_tensors(1, ctx_split, layer.wq, layer.split_wq, split, mem_used);
|
||||
if (layer.bo) {
|
||||
prepare_split_tensors(-1, ctx_split, layer.bo, layer.split_bo, split, mem_used);
|
||||
}
|
||||
if (layer.bq) {
|
||||
prepare_split_tensors(0, ctx_split, layer.bq, layer.split_bq, split, mem_used);
|
||||
}
|
||||
if (layer.attn_q_norm) {
|
||||
prepare_split_tensors(-1, ctx_split, layer.attn_q_norm, layer.split_q_norm, split, mem_used);
|
||||
}
|
||||
for (auto & s : split) s /= gqa_ratio;
|
||||
prepare_split_tensors(1, ctx_split, layer.wk, layer.split_wk, split, mem_used);
|
||||
prepare_split_tensors(1, ctx_split, layer.wv, layer.split_wv, split, mem_used);
|
||||
if (layer.bk) {
|
||||
prepare_split_tensors(0, ctx_split, layer.bk, layer.split_bk, split, mem_used);
|
||||
}
|
||||
if (layer.bv) {
|
||||
prepare_split_tensors(0, ctx_split, layer.bv, layer.split_bv, split, mem_used);
|
||||
}
|
||||
if (layer.attn_k_norm) {
|
||||
prepare_split_tensors(-1, ctx_split, layer.attn_k_norm, layer.split_k_norm, split, mem_used);
|
||||
}
|
||||
}
|
||||
|
||||
if (layer.ffn_norm) {
|
||||
if (auto it = split_tensors.find(layer.ffn_norm); it != split_tensors.end()) {
|
||||
auto split = create_split(ggml_nrows(layer.ffn_norm), -1, model.splits, mem_used);
|
||||
prepare_split_tensors(-1, ctx_split, layer.ffn_norm, layer.split_ffn_norm, split, mem_used);
|
||||
}
|
||||
}
|
||||
|
||||
if (layer.ffn_down && layer.ffn_up && layer.ffn_gate) {
|
||||
bool use_split = split_tensors.find(layer.ffn_down) != split_tensors.end() &&
|
||||
split_tensors.find(layer.ffn_gate) != split_tensors.end() &&
|
||||
split_tensors.find(layer.ffn_up) != split_tensors.end();
|
||||
if (use_split) {
|
||||
int ffn_granularity = 16;
|
||||
if (ggml_is_quantized(layer.ffn_down->type)) {
|
||||
auto tt = ggml_internal_get_type_traits(layer.ffn_down->type);
|
||||
if (tt.blck_size > ffn_granularity) ffn_granularity = tt.blck_size;
|
||||
}
|
||||
auto split = create_split(layer.ffn_down->ne[0], ffn_granularity, model.splits, mem_used);
|
||||
prepare_split_tensors(0, ctx_split, layer.ffn_down, layer.split_ffn_down, split, mem_used);
|
||||
prepare_split_tensors(1, ctx_split, layer.ffn_up, layer.split_ffn_up, split, mem_used);
|
||||
prepare_split_tensors(1, ctx_split, layer.ffn_gate, layer.split_ffn_gate, split, mem_used);
|
||||
}
|
||||
}
|
||||
|
||||
//bool any_ffn_split = false;
|
||||
if (layer.ffn_down_shexp && layer.ffn_up_shexp && layer.ffn_gate_shexp) {
|
||||
bool use_split = split_tensors.find(layer.ffn_down_shexp) != split_tensors.end() &&
|
||||
split_tensors.find(layer.ffn_gate_shexp) != split_tensors.end() &&
|
||||
split_tensors.find(layer.ffn_up_shexp) != split_tensors.end();
|
||||
if (use_split) {
|
||||
//any_ffn_split = true;
|
||||
int ffn_granularity = 16;
|
||||
if (ggml_is_quantized(layer.ffn_down_shexp->type)) {
|
||||
auto tt = ggml_internal_get_type_traits(layer.ffn_down_shexp->type);
|
||||
if (tt.blck_size > ffn_granularity) ffn_granularity = tt.blck_size;
|
||||
}
|
||||
auto split = create_split(layer.ffn_down_shexp->ne[0], ffn_granularity, model.splits, mem_used);
|
||||
prepare_split_tensors(0, ctx_split, layer.ffn_down_shexp, layer.split_ffn_down_shexp, split, mem_used);
|
||||
prepare_split_tensors(1, ctx_split, layer.ffn_up_shexp, layer.split_ffn_up_shexp, split, mem_used);
|
||||
prepare_split_tensors(1, ctx_split, layer.ffn_gate_shexp, layer.split_ffn_gate_shexp, split, mem_used);
|
||||
}
|
||||
}
|
||||
|
||||
if (layer.ffn_down_exps && layer.ffn_up_exps && layer.ffn_gate_exps) {
|
||||
bool use_split = split_tensors.find(layer.ffn_down_exps) != split_tensors.end() &&
|
||||
split_tensors.find(layer.ffn_gate_exps) != split_tensors.end() &&
|
||||
split_tensors.find(layer.ffn_up_exps) != split_tensors.end();
|
||||
|
||||
if (use_split) {
|
||||
//any_ffn_split = true;
|
||||
int ffn_granularity = 16;
|
||||
if (ggml_is_quantized(layer.ffn_down_exps->type)) {
|
||||
auto tt = ggml_internal_get_type_traits(layer.ffn_down_exps->type);
|
||||
if (tt.blck_size > ffn_granularity) ffn_granularity = tt.blck_size;
|
||||
}
|
||||
auto split = create_split(layer.ffn_down_exps->ne[0], ffn_granularity, model.splits, mem_used);
|
||||
//printf("split(%2d):", il); for (auto & s : split) printf(" %d", s); printf("\n");
|
||||
prepare_split_tensors(0, ctx_split, layer.ffn_down_exps, layer.split_ffn_down_exps, split, mem_used);
|
||||
prepare_split_tensors(1, ctx_split, layer.ffn_up_exps, layer.split_ffn_up_exps, split, mem_used);
|
||||
prepare_split_tensors(1, ctx_split, layer.ffn_gate_exps, layer.split_ffn_gate_exps, split, mem_used);
|
||||
}
|
||||
}
|
||||
|
||||
if (layer.ffn_gate_inp) {
|
||||
if (auto it = split_tensors.find(layer.ffn_gate_inp); it != split_tensors.end()) {
|
||||
auto shared_split = create_split(ggml_nrows(layer.ffn_gate_inp), -1, model.splits, mem_used);
|
||||
prepare_split_tensors(-1, ctx_split, layer.ffn_gate_inp, layer.split_ffn_gate_inp, shared_split, mem_used);
|
||||
}
|
||||
}
|
||||
if (layer.ffn_exp_probs_b) {
|
||||
if (auto it = split_tensors.find(layer.ffn_exp_probs_b); it != split_tensors.end()) {
|
||||
auto shared_split = create_split(ggml_nrows(layer.ffn_exp_probs_b), -1, model.splits, mem_used);
|
||||
prepare_split_tensors(-1, ctx_split, layer.ffn_exp_probs_b, layer.split_ffn_exp_probs_b, shared_split, mem_used);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (model.output) {
|
||||
if (auto it = split_tensors.find(model.output); it != split_tensors.end()) {
|
||||
if (ggml_backend_buft_is_host(model.buft_output.buft_matrix)) {
|
||||
LLAMA_LOG_INFO("%s: not splitting output tensor becausee buffer is host\n", __func__);
|
||||
} else {
|
||||
auto ctx_split = ctx_map[model.buft_output.buft_matrix];
|
||||
auto split = create_split(model.output->ne[1], 16, model.splits, mem_used);
|
||||
prepare_split_tensors(1, ctx_split, model.output, model.split_output, split, mem_used);
|
||||
if (auto it = split_tensors.find(model.output_norm); it != split_tensors.end() && !ggml_backend_buft_is_host(model.buft_output.buft_matrix)) {
|
||||
auto ctx_split = ctx_map[model.buft_output.buft_matrix];
|
||||
prepare_split_tensors(-1, ctx_split, model.output_norm, model.split_output_norm, split, mem_used);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("Estimated model buffer size per device:\n");
|
||||
for (int i = 0; i < int(mem_used.size()); ++i) {
|
||||
LLAMA_LOG_INFO(" Device %d: %8.2f MiB\n", i, mem_used[i]/1024./1024.);
|
||||
}
|
||||
}
|
||||
return use_mmap_buffer;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user