This works, but it is slow

This commit is contained in:
Kawrakow
2025-11-26 15:50:46 +00:00
parent 4303587f1c
commit 97143330a1
7 changed files with 110 additions and 38 deletions

View File

@@ -334,7 +334,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -ngl, --n-gpu-layers <n> (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str()); printf(" -ngl, --n-gpu-layers <n> (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str());
printf(" --n-cpu-moe <n> (default: none)\n"); printf(" --n-cpu-moe <n> (default: none)\n");
printf(" -rpc, --rpc <rpc_servers> (default: %s)\n", join(cmd_params_defaults.rpc_servers, ",").c_str()); printf(" -rpc, --rpc <rpc_servers> (default: %s)\n", join(cmd_params_defaults.rpc_servers, ",").c_str());
printf(" -sm, --split-mode <none|layer> (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str()); printf(" -sm, --split-mode <none|row|layer> (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str());
printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str()); printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str()); printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
@@ -631,11 +631,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
} else if (m == "layer") { } else if (m == "layer") {
mode = LLAMA_SPLIT_MODE_LAYER; mode = LLAMA_SPLIT_MODE_LAYER;
} else if (m == "row") { } else if (m == "row") {
fprintf(stderr, "\n\n=======================================================================\n"); mode = LLAMA_SPLIT_MODE_ROW;
fprintf(stderr, "Split mode 'row' is no longer supported\n"); //fprintf(stderr, "\n\n=======================================================================\n");
fprintf(stderr, "=======================================================================\n\n\n"); //fprintf(stderr, "Split mode 'row' is no longer supported\n");
invalid_param = true; //fprintf(stderr, "=======================================================================\n\n\n");
break; //invalid_param = true;
//break;
} else { } else {
invalid_param = true; invalid_param = true;
break; break;

View File

@@ -1395,6 +1395,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
// do not overwrite user assignments // do not overwrite user assignments
if (*leaf_backend_id == -1) { if (*leaf_backend_id == -1) {
*leaf_backend_id = ggml_backend_sched_backend_id_from_cur(sched, leaf); *leaf_backend_id = ggml_backend_sched_backend_id_from_cur(sched, leaf);
//printf("Pass 1: assigned backend %d to leaf %d, %s\n", *leaf_backend_id, i, graph->leafs[i]->name);
} }
} }
@@ -1404,6 +1405,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
// do not overwrite user assignments // do not overwrite user assignments
if (*node_backend_id == -1) { if (*node_backend_id == -1) {
*node_backend_id = ggml_backend_sched_backend_id_from_cur(sched, node); *node_backend_id = ggml_backend_sched_backend_id_from_cur(sched, node);
//printf("Pass 1: assigned backend %d to node %d, %s(%s)\n", *node_backend_id, i, ggml_op_name(node->op), node->name);
#if 0 #if 0
// src // src
@@ -1447,6 +1449,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
cur_backend_id = *node_backend_id; cur_backend_id = *node_backend_id;
} }
} else if (cur_backend_id != -1) { } else if (cur_backend_id != -1) {
//printf("(u1) invoking ggml_backend_sched_set_if_supported for node %d, %s with cur_backend_id = %d, node_backend_id = %d\n", i, node->name, cur_backend_id, *node_backend_id);
ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id); ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
} }
} }
@@ -1468,6 +1471,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
cur_backend_id = *node_backend_id; cur_backend_id = *node_backend_id;
} }
} else if (cur_backend_id != -1) { } else if (cur_backend_id != -1) {
//printf("(d1) invoking ggml_backend_sched_set_if_supported for node %d, %s with cur_backend_id = %d, node_backend_id = %d\n", i, node->name, cur_backend_id, *node_backend_id);
ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id); ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
} }
} }
@@ -1484,6 +1488,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
if (*node_backend_id != -1) { if (*node_backend_id != -1) {
cur_backend_id = *node_backend_id; cur_backend_id = *node_backend_id;
} else if (cur_backend_id != -1) { } else if (cur_backend_id != -1) {
//printf("(u2) invoking ggml_backend_sched_set_if_supported for node %d, %s with cur_backend_id = %d, node_backend_id = %d\n", i, node->name, cur_backend_id, *node_backend_id);
ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id); ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
} }
} }
@@ -1500,6 +1505,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
if (*node_backend_id != -1) { if (*node_backend_id != -1) {
cur_backend_id = *node_backend_id; cur_backend_id = *node_backend_id;
} else if (cur_backend_id != -1) { } else if (cur_backend_id != -1) {
//printf("(d2) invoking ggml_backend_sched_set_if_supported for node %d, %s with cur_backend_id = %d, node_backend_id = %d\n", i, node->name, cur_backend_id, *node_backend_id);
ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id); ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
} }
} }
@@ -1537,6 +1543,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
if (n_supported > n_supported_best) { if (n_supported > n_supported_best) {
n_supported_best = n_supported; n_supported_best = n_supported;
*node_backend_id = b; *node_backend_id = b;
//printf("Pass 3: assigned backend %d to unassigned node %d, %s\n", b, i, node->name);
SET_CAUSE(node, "3.best"); SET_CAUSE(node, "3.best");
} }
} }
@@ -1557,6 +1564,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
} }
} }
if (supported) { if (supported) {
//printf("Pass 3: assigned backend %d to node %d, %s previously assigned to backend %d\n", b, i, node->name, *node_backend_id);
*node_backend_id = b; *node_backend_id = b;
SET_CAUSE(node, "3.upg"); SET_CAUSE(node, "3.upg");
break; break;
@@ -1585,9 +1593,11 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
// views are always on the same backend as the source // views are always on the same backend as the source
*src_backend_id = tensor_backend_id(src->view_src); *src_backend_id = tensor_backend_id(src->view_src);
SET_CAUSE(src, "4.vsrc"); SET_CAUSE(src, "4.vsrc");
//printf("Pass 4: assigned backend %d to src %d, %s in node %d, %s frpm view_src\n", *src_backend_id, j, src->name, i, node->name);
} else { } else {
*src_backend_id = *cur_backend_id; *src_backend_id = *cur_backend_id;
SET_CAUSE(src, "4.cur"); SET_CAUSE(src, "4.cur");
//printf("Pass 4: assigned backend %d to src %d, %s in node %d, %s frpm current\n", *src_backend_id, j, src->name, i, node->name);
} }
} }
} }

View File

@@ -787,7 +787,7 @@ GGML_CALL static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buf
GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor([[maybe_unused]] ggml_backend_buffer_t buffer, ggml_tensor * tensor) { GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor([[maybe_unused]] ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
if (!tensor->extra) return; if (!tensor->extra) return;
printf("%s(%s, %p)\n", __func__, tensor->name, tensor->extra); //printf("%s(%s, %p)\n", __func__, tensor->name, tensor->extra);
auto extra = (ggml_split_tensor_t *)tensor->extra; auto extra = (ggml_split_tensor_t *)tensor->extra;
GGML_ASSERT(extra->n_device <= ggml_backend_cuda_get_device_count()); GGML_ASSERT(extra->n_device <= ggml_backend_cuda_get_device_count());
for (int i = 0; i < extra->n_device; ++i) { for (int i = 0; i < extra->n_device; ++i) {
@@ -808,8 +808,8 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor([[maybe_unused]
if (padded_size > size) { if (padded_size > size) {
CUDA_CHECK(cudaMemset(buf + size, 0, padded_size - size)); CUDA_CHECK(cudaMemset(buf + size, 0, padded_size - size));
} }
printf(" allocated %zu bytes for tensor %s of type %s, dim = %ld x %ld x %ld. padding: %zu\n", padded_size, split->name, ggml_type_name(split->type), //printf(" allocated %zu bytes for tensor %s of type %s, dim = %ld x %ld x %ld. padding: %zu\n", padded_size, split->name, ggml_type_name(split->type),
split->ne[0], split->ne[1], split->ne[2], padded_size - size); // split->ne[0], split->ne[1], split->ne[2], padded_size - size);
split->data = buf; split->data = buf;
auto ctx = new ggml_backend_cuda_buffer_context(i, buf); auto ctx = new ggml_backend_cuda_buffer_context(i, buf);
auto buft = ggml_backend_cuda_buffer_type(i); auto buft = ggml_backend_cuda_buffer_type(i);
@@ -868,7 +868,7 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor([[maybe_unused]
GGML_CALL static void ggml_backend_cuda_split_buffer_set_tensor([[maybe_unused]] ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { GGML_CALL static void ggml_backend_cuda_split_buffer_set_tensor([[maybe_unused]] ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
if (!tensor->extra) return; if (!tensor->extra) return;
printf("%s(%s)\n", __func__, tensor->name); //printf("%s(%s)\n", __func__, tensor->name);
// split tensors must always be set in their entirety at once // split tensors must always be set in their entirety at once
GGML_ASSERT(offset == 0); GGML_ASSERT(offset == 0);
@@ -880,10 +880,23 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_set_tensor([[maybe_unused]]
for (int i = 0; i < extra->n_device; ++i) { for (int i = 0; i < extra->n_device; ++i) {
auto split = extra->splits[i]; auto split = extra->splits[i];
if (!split) continue; if (!split) continue;
printf(" Split %d: %p, %p, %s\n", i, (void *)split->data, (void *)split->buffer, split->buffer ? ggml_backend_buffer_name(split->buffer) : "none"); //printf(" Split %d: %p, %p, %s\n", i, (void *)split->data, (void *)split->buffer, split->buffer ? ggml_backend_buffer_name(split->buffer) : "none");
} }
if (extra->split_dim == 0) { if (extra->split_dim < 0) {
GGML_ASSERT(ggml_is_contiguous(tensor));
auto nbytes = ggml_nbytes(tensor);
for (int i = 0; i < extra->n_device; ++i) {
auto split = extra->splits[i];
if (!split) continue;
GGML_ASSERT(split->type == tensor->type);
GGML_ASSERT(ggml_are_same_shape(tensor, split));
GGML_ASSERT(ggml_nbytes(split) == nbytes);
ggml_cuda_set_device(i);
CUDA_CHECK(cudaMemcpyAsync(split->data, data, nbytes, cudaMemcpyHostToDevice, cudaStreamPerThread));
}
}
else if (extra->split_dim == 0) {
if (tensor->type >= GGML_TYPE_Q4_0_R8) { if (tensor->type >= GGML_TYPE_Q4_0_R8) {
GGML_ABORT("Dim 0 copy of row-interleaved quants is not supported yet"); GGML_ABORT("Dim 0 copy of row-interleaved quants is not supported yet");
} }
@@ -3072,7 +3085,7 @@ static inline bool ops_are_same_device(const ggml_cgraph * cgraph, int first, in
static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, const ggml_cgraph * cgraph, int & i) { static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, const ggml_cgraph * cgraph, int & i) {
// why is this here instead of mul_mat? // why is this here instead of mul_mat?
if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) { if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) {
printf("%s: split buffer for %s(%s)\n", __func__, ggml_op_name(dst->op), dst->name); //printf("%s: split buffer for %s(%s)\n", __func__, ggml_op_name(dst->op), dst->name);
ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device); ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
} }
@@ -3084,7 +3097,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
auto fusion = ctx.fusion; auto fusion = ctx.fusion;
printf("%4d %s(%s) on device %d. time = %ld\n", i, ggml_op_name(dst->op), dst->name, ctx.device, ggml_time_us()); //printf("%4d %s(%s) on device %d. time = %ld\n", i, ggml_op_name(dst->op), dst->name, ctx.device, ggml_time_us());
switch (dst->op) { switch (dst->op) {
case GGML_OP_ARGMAX: case GGML_OP_ARGMAX:
ggml_cuda_argmax(ctx, dst); ggml_cuda_argmax(ctx, dst);
@@ -3809,7 +3822,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
// TODO // TODO
const bool integrated = false; //ggml_cuda_info().devices[cuda_ctx->device].integrated; const bool integrated = false; //ggml_cuda_info().devices[cuda_ctx->device].integrated;
printf("======================== %s: graph with %d nodes on device %d. time = %ld\n", __func__, cgraph->n_nodes, cuda_ctx->device, ggml_time_us()); //printf("======================== %s: graph with %d nodes on device %d. time = %ld\n", __func__, cgraph->n_nodes, cuda_ctx->device, ggml_time_us());
while (!graph_evaluated_or_captured) { while (!graph_evaluated_or_captured) {
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
// With the use of CUDA graphs, the execution will be performed by the graph launch. // With the use of CUDA graphs, the execution will be performed by the graph launch.

View File

@@ -621,7 +621,7 @@ ggml_tensor * llm_build_context::llm_build_norm(
ggml_tensor * llm_build_context::llm_build_ffn( ggml_tensor * llm_build_context::llm_build_ffn(
ggml_context * ctx, ggml_context * ctx,
llama_context & lctx, llama_context & lctx,
ggml_tensor * cur, ggml_tensor * input,
ggml_tensor * up, ggml_tensor * up,
ggml_tensor * up_b, ggml_tensor * up_b,
ggml_tensor * up_s, ggml_tensor * up_s,
@@ -654,7 +654,7 @@ ggml_tensor * llm_build_context::llm_build_ffn(
auto split_d = d->splits[id]; auto split_d = d->splits[id];
GGML_ASSERT((!split_u && !split_g && split_d) || (split_u && split_g && split_d)); GGML_ASSERT((!split_u && !split_g && split_d) || (split_u && split_g && split_d));
if (!split_u) continue; if (!split_u) continue;
cur = ggml_fused_up_gate(ctx, split_u, split_g, cur, unary_op); auto cur = ggml_fused_up_gate(ctx, split_u, split_g, input, unary_op);
cb(cur, "ffn_up_gate", il_cb); cb(cur, "ffn_up_gate", il_cb);
cur = llm_build_lora_mm(lctx, ctx, split_d, cur); cur = llm_build_lora_mm(lctx, ctx, split_d, cur);
cb(cur, "ffn_down", il_cb); cb(cur, "ffn_down", il_cb);
@@ -668,7 +668,7 @@ ggml_tensor * llm_build_context::llm_build_ffn(
ffn.push_back(cur); ffn.push_back(cur);
} }
if (ffn.size() == 1) return ffn.front(); if (ffn.size() == 1) return ffn.front();
cur = ggml_add(ctx, ffn[0], ffn[1]); auto cur = ggml_add(ctx, ffn[0], ffn[1]);
cb(cur, "combine_ffn", il); cb(cur, "combine_ffn", il);
for (int id = 2; id < int(ffn.size()); ++id) { for (int id = 2; id < int(ffn.size()); ++id) {
cur = ggml_add(ctx, cur, ffn[id]); cur = ggml_add(ctx, cur, ffn[id]);
@@ -682,7 +682,7 @@ ggml_tensor * llm_build_context::llm_build_ffn(
(type_op == LLM_FFN_SILU || type_op == LLM_FFN_RELU || (type_op == LLM_FFN_GELU && !act_scales))) { (type_op == LLM_FFN_SILU || type_op == LLM_FFN_RELU || (type_op == LLM_FFN_GELU && !act_scales))) {
auto unary_op = type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : auto unary_op = type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU :
type_op == LLM_FFN_RELU ? GGML_UNARY_OP_RELU : GGML_UNARY_OP_GELU; type_op == LLM_FFN_RELU ? GGML_UNARY_OP_RELU : GGML_UNARY_OP_GELU;
cur = ggml_fused_up_gate(ctx, up, gate, cur, unary_op); auto cur = ggml_fused_up_gate(ctx, up, gate, input, unary_op);
cb(cur, "ffn_up_gate", il); cb(cur, "ffn_up_gate", il);
if (down) { if (down) {
cur = llm_build_lora_mm(lctx, ctx, down, cur); cur = llm_build_lora_mm(lctx, ctx, down, cur);
@@ -704,7 +704,7 @@ ggml_tensor * llm_build_context::llm_build_ffn(
return cur; return cur;
} }
struct ggml_tensor * tmp = up ? llm_build_lora_mm(lctx, ctx, up, cur) : cur; struct ggml_tensor * tmp = up ? llm_build_lora_mm(lctx, ctx, up, input) : input;
cb(tmp, "ffn_up", il); cb(tmp, "ffn_up", il);
if (up_b) { if (up_b) {
@@ -717,6 +717,7 @@ ggml_tensor * llm_build_context::llm_build_ffn(
cb(tmp, "ffn_up_s", il); cb(tmp, "ffn_up_s", il);
} }
auto cur = input;
if (gate) { if (gate) {
switch (type_gate) { switch (type_gate) {
case LLM_FFN_SEQ: case LLM_FFN_SEQ:
@@ -1448,19 +1449,21 @@ ggml_cgraph * llm_build_context::build_llama() {
KQ_mask_swa : KQ_mask; KQ_mask_swa : KQ_mask;
int this_n_swa = this_KQ_mask == KQ_mask_swa ? hparams.n_swa : 0; int this_n_swa = this_KQ_mask == KQ_mask_swa ? hparams.n_swa : 0;
// norm
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
// rope freq factors for llama3; may return nullptr for llama2 and other models // rope freq factors for llama3; may return nullptr for llama2 and other models
auto rope_factors = build_rope_factors(il); //auto rope_factors = build_rope_factors(il);
// self-attention // self-attention
if (use_rope) { if (use_rope) {
cur = build_std_attention(gf, cur, inp_pos, rope_factors, this_KQ_mask, nullptr, kq_scale, hparams.f_attention_scale, this_n_swa, il); cur = build_std_attention(gf, inpL, inp_pos, nullptr, this_KQ_mask, nullptr, kq_scale, hparams.f_attention_scale, this_n_swa, il);
} }
else { else {
auto rope_factors = build_rope_factors(il);
// norm
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, model.layers[il].bqkv, model.layers[il].wqkv, model.layers[il].bqkv,
model.layers[il].wqk, model.layers[il].bqk, model.layers[il].wqk, model.layers[il].bqk,
@@ -3595,10 +3598,10 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
struct ggml_tensor * inpSA = inpL; struct ggml_tensor * inpSA = inpL;
// norm // norm
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); //cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il); //cb(cur, "attn_norm", il);
cur = build_std_attention(gf, cur, inp_pos, nullptr, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), 0.0f, 0, il); cur = build_std_attention(gf, inpL, inp_pos, nullptr, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), 0.0f, 0, il);
if (il == n_layer - 1) { if (il == n_layer - 1) {
// skip computing output for unused tokens // skip computing output for unused tokens
@@ -9041,11 +9044,12 @@ ggml_cgraph * llm_build_context::llama_build_graph(
return result; return result;
} }
ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * inp_pos, ggml_tensor * rope_factors, ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tensor * input, ggml_tensor * inp_pos, ggml_tensor * rope_factors_in,
ggml_tensor * KQ_mask, ggml_tensor * sinks, float KQ_scale, float f_attn_scale, int n_swa, int il) { ggml_tensor * KQ_mask, ggml_tensor * sinks, float KQ_scale, float f_attn_scale, int n_swa, int il) {
if (!model.layers[il].wqkv && !model.layers[il].wqk && cparams.flash_attn && if (!model.layers[il].wqkv && !model.layers[il].wqk && cparams.flash_attn &&
model.layers[il].wq->extra && model.layers[il].wk->extra && model.layers[il].wv->extra && model.layers[il].wo->extra) { model.layers[il].wq->extra && model.layers[il].wk->extra && model.layers[il].wv->extra && model.layers[il].wo->extra) {
if (kv_self.k_l[il]->extra && kv_self.v_l[il]->extra) { if (kv_self.k_l[il]->extra && kv_self.v_l[il]->extra) {
ggml_split_tensor_t * attn_norm = model.layers[il].attn_norm ? (ggml_split_tensor_t *)model.layers[il].attn_norm->extra : nullptr;
auto wq = (ggml_split_tensor_t *)model.layers[il].wq->extra; auto wq = (ggml_split_tensor_t *)model.layers[il].wq->extra;
auto wk = (ggml_split_tensor_t *)model.layers[il].wk->extra; auto wk = (ggml_split_tensor_t *)model.layers[il].wk->extra;
auto wv = (ggml_split_tensor_t *)model.layers[il].wv->extra; auto wv = (ggml_split_tensor_t *)model.layers[il].wv->extra;
@@ -9066,9 +9070,20 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
GGML_ASSERT((!split_wq && !split_wk && !split_wv && !split_wo && !split_kl && !split_vl) || GGML_ASSERT((!split_wq && !split_wk && !split_wv && !split_wo && !split_kl && !split_vl) ||
(split_wq && split_wk && split_wv && split_wo && split_kl && split_vl)); (split_wq && split_wk && split_wv && split_wo && split_kl && split_vl));
if (!split_wq) continue; if (!split_wq) continue;
auto cur = input;
if (attn_norm) {
auto split_norm = attn_norm->splits[id];
cur = llm_build_norm(ctx0, cur, hparams, split_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il_cb);
}
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, nullptr, nullptr, nullptr, nullptr, auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, nullptr, nullptr, nullptr, nullptr,
split_wq, nullptr, split_wk, nullptr, split_wv, nullptr, split_wq, nullptr, split_wk, nullptr, split_wv, nullptr,
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, f_attn_scale, il_cb); model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, f_attn_scale, il_cb);
auto rope_factors = rope_factors_in;
if (!rope_factors && model.layers[il].rope_freqs && model.layers[il].rope_freqs->extra) {
auto extra = (ggml_split_tensor_t *)model.layers[il].rope_freqs->extra;
rope_factors = extra->splits[id];
}
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow); ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -9160,7 +9175,7 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
attn.push_back(cur); attn.push_back(cur);
} }
if (attn.size() == 1) return attn.front(); if (attn.size() == 1) return attn.front();
cur = ggml_add(ctx0, attn[0], attn[1]); auto cur = ggml_add(ctx0, attn[0], attn[1]);
cb(cur, "combine_attn", il); cb(cur, "combine_attn", il);
for (int id = 2; id < (int)attn.size(); ++id) { for (int id = 2; id < (int)attn.size(); ++id) {
cur = ggml_add(ctx0, cur, attn[id]); cur = ggml_add(ctx0, cur, attn[id]);
@@ -9170,15 +9185,21 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
} }
} }
auto cur = input;
if (model.layers[il].attn_norm) {
cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
}
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, model.layers[il].bqkv, model.layers[il].wqkv, model.layers[il].bqkv,
model.layers[il].wqk, model.layers[il].bqk, model.layers[il].wqk, model.layers[il].bqk,
model.layers[il].wq, model.layers[il].bq, model.layers[il].wk, model.layers[il].bk, model.layers[il].wv, model.layers[il].bv, model.layers[il].wq, model.layers[il].bq, model.layers[il].wk, model.layers[il].bk, model.layers[il].wv, model.layers[il].bv,
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, f_attn_scale, il); model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, f_attn_scale, il);
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors_in, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow); ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, rope_factors_in, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow); ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);

View File

@@ -210,6 +210,7 @@ create_tensors_helper::create_tensors_helper(llama_model_loader & _ml, llama_mod
ctx_map[it.first] = ctx; ctx_map[it.first] = ctx;
model.ctxs.push_back(ctx); model.ctxs.push_back(ctx);
} }
#if 0
printf("=======================================================================\n"); printf("=======================================================================\n");
auto n_device = model.device_count(); auto n_device = model.device_count();
printf(" Model has %d devices:\n", n_device); printf(" Model has %d devices:\n", n_device);
@@ -226,11 +227,13 @@ create_tensors_helper::create_tensors_helper(llama_model_loader & _ml, llama_mod
for (auto s : model.splits) printf(" %g", s); for (auto s : model.splits) printf(" %g", s);
printf("\n"); printf("\n");
} }
#endif
} }
static std::vector<int> create_split(int nr, int granularity, const std::vector<float> & splits) { static std::vector<int> create_split(int nr, int granularity, const std::vector<float> & splits) {
GGML_ASSERT(nr % granularity == 0); GGML_ASSERT(nr % granularity == 0);
GGML_ASSERT(!splits.empty()); GGML_ASSERT(!splits.empty());
if (granularity < 0) return std::vector<int>(splits.size(), nr);
int nchunk = nr / granularity; int nchunk = nr / granularity;
std::vector<int> result(splits.size()); std::vector<int> result(splits.size());
float last_split = 0; float last_split = 0;
@@ -394,7 +397,7 @@ bool create_tensors_helper::create_llama_tensors(const LLM_TN & tn) {
auto & layer = model.layers[i]; auto & layer = model.layers[i];
layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); layer.attn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
use_mmap_buffer &= !merge_qkv(tn, i, 1); use_mmap_buffer &= !merge_qkv(tn, i, 1);
@@ -405,7 +408,7 @@ bool create_tensors_helper::create_llama_tensors(const LLM_TN & tn) {
layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
layer.rope_freqs = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); layer.rope_freqs = create_tensor(ctx_split, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
if (n_expert == 0) { if (n_expert == 0) {
create_std_ffn(i, tn, layer, n_ff, n_embd, ctx_split); create_std_ffn(i, tn, layer, n_ff, n_embd, ctx_split);
@@ -2745,11 +2748,22 @@ bool create_tensors_helper::merge_qkv(const LLM_TN & tn, int i, int bias, bool i
} }
static void prepare_split_tensors(int split_dim, ggml_context * ctx, ggml_tensor * tensor, llama_split_tensor & split_tensor, const std::vector<int> & splits) { static void prepare_split_tensors(int split_dim, ggml_context * ctx, ggml_tensor * tensor, llama_split_tensor & split_tensor, const std::vector<int> & splits) {
GGML_ASSERT(split_dim == 0 || split_dim == 1); GGML_ASSERT(split_dim <= 1);
GGML_ASSERT(splits.size() > 1); GGML_ASSERT(splits.size() > 1);
std::string name{tensor->name}; std::string name{tensor->name};
split_tensor.tensor_splits.resize(splits.size()); split_tensor.tensor_splits.resize(splits.size());
if (split_dim == 1) { if (split_dim < 0) {
for (int i = 0; i < int(splits.size()); ++i) {
if (splits[i] > 0) {
split_tensor.tensor_splits[i] = ggml_new_tensor_3d(ctx, tensor->type, tensor->ne[0], tensor->ne[1], tensor->ne[2]);
auto name_i = name + '.' + std::to_string(i);
ggml_set_name(split_tensor.tensor_splits[i], name_i.c_str());
} else {
split_tensor.tensor_splits[i] = nullptr;
}
}
}
else if (split_dim == 1) {
for (int i = 0; i < int(splits.size()); ++i) { for (int i = 0; i < int(splits.size()); ++i) {
if (splits[i] > 0) { if (splits[i] > 0) {
split_tensor.tensor_splits[i] = ggml_new_tensor_3d(ctx, tensor->type, tensor->ne[0], splits[i], tensor->ne[2]); split_tensor.tensor_splits[i] = ggml_new_tensor_3d(ctx, tensor->type, tensor->ne[0], splits[i], tensor->ne[2]);
@@ -2902,10 +2916,18 @@ bool create_tensors_helper::create_tensors() {
if (model.split_mode == LLAMA_SPLIT_MODE_ROW) { if (model.split_mode == LLAMA_SPLIT_MODE_ROW) {
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
int gqa_ratio = hparams.n_head() / hparams.n_head_kv(); int gqa_ratio = hparams.n_head() / hparams.n_head_kv();
printf("GQA ratio: %d\n", gqa_ratio); //printf("GQA ratio: %d\n", gqa_ratio);
for (int il = 0; il < int(model.layers.size()); ++il) { for (int il = 0; il < int(model.layers.size()); ++il) {
auto & layer = model.layers[il]; auto & layer = model.layers[il];
auto ctx_split = ctx_for_layer_split(il); auto ctx_split = ctx_for_layer_split(il);
if (layer.attn_norm) {
auto split = create_split(ggml_nrows(layer.attn_norm), -1, model.splits);
prepare_split_tensors(-1, ctx_split, layer.attn_norm, layer.split_attn_norm, split);
}
if (layer.rope_freqs) {
auto split = create_split(ggml_nrows(layer.rope_freqs), -1, model.splits);
prepare_split_tensors(-1, ctx_split, layer.rope_freqs, layer.split_rope_freqs, split);
}
if (layer.wo && layer.wq && layer.wk && layer.wv) { if (layer.wo && layer.wq && layer.wk && layer.wv) {
int attn_granularity = hparams.n_embd_head_k; int attn_granularity = hparams.n_embd_head_k;
if (ggml_is_quantized(layer.wo->type)) { if (ggml_is_quantized(layer.wo->type)) {

View File

@@ -183,6 +183,7 @@ struct llama_layer {
struct ggml_tensor * bqk = nullptr; struct ggml_tensor * bqk = nullptr;
struct ggml_tensor * bkv = nullptr; struct ggml_tensor * bkv = nullptr;
llama_split_tensor split_attn_norm;
llama_split_tensor split_wq; llama_split_tensor split_wq;
llama_split_tensor split_wk; llama_split_tensor split_wk;
llama_split_tensor split_wv; llama_split_tensor split_wv;
@@ -280,6 +281,8 @@ struct llama_layer {
struct ggml_tensor * rope_short = nullptr; struct ggml_tensor * rope_short = nullptr;
struct ggml_tensor * rope_freqs = nullptr; struct ggml_tensor * rope_freqs = nullptr;
llama_split_tensor split_rope_freqs;
// bitnet scale // bitnet scale
struct ggml_tensor * wq_scale = nullptr; struct ggml_tensor * wq_scale = nullptr;
struct ggml_tensor * wk_scale = nullptr; struct ggml_tensor * wk_scale = nullptr;

View File

@@ -806,6 +806,7 @@ static bool llama_kv_cache_init(
cache.bufs.push_back(buf); cache.bufs.push_back(buf);
} }
#if 0
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
if (cache.k_l[il]->extra) { if (cache.k_l[il]->extra) {
printf("Layer %2d, K-buffer: %p:", il, (void *)cache.k_l[il]->buffer); printf("Layer %2d, K-buffer: %p:", il, (void *)cache.k_l[il]->buffer);
@@ -824,6 +825,7 @@ static bool llama_kv_cache_init(
printf("\n"); printf("\n");
} }
} }
#endif
return true; return true;
} }