diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 4a4bbdaf..85b84483 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -568,6 +568,7 @@ GGML_CALL static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t
 }
 
 GGML_CALL static bool ggml_backend_cuda_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
+    printf("%s(%s -> %s)\n", __func__, src->name, dst->name);
     if (ggml_backend_buffer_is_cuda(src->buffer)) {
         ggml_backend_cuda_buffer_context * src_ctx = (ggml_backend_cuda_buffer_context *)src->buffer->context;
         ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *)dst->buffer->context;
@@ -788,6 +789,37 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor([[maybe_unused]
 
 GGML_CALL static void ggml_backend_cuda_split_buffer_set_tensor([[maybe_unused]] ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     if (!tensor->extra) return;
+    static std::map<ggml_type, int> k_map = {
+        { GGML_TYPE_Q4_0_R8   ,  8},
+        { GGML_TYPE_Q5_0_R4   ,  4},
+        { GGML_TYPE_Q8_0_R8   ,  8},
+        { GGML_TYPE_Q2_K_R4   ,  4},
+        { GGML_TYPE_Q3_K_R4   ,  4},
+        { GGML_TYPE_Q4_K_R4   ,  4},
+        { GGML_TYPE_Q5_K_R4   ,  4},
+        { GGML_TYPE_Q6_K_R4   ,  4},
+        { GGML_TYPE_IQ2_XXS_R4,  4},
+        { GGML_TYPE_IQ2_XS_R4 ,  4},
+        { GGML_TYPE_IQ3_XXS_R4,  4},
+        { GGML_TYPE_IQ1_S_R4  ,  4},
+        { GGML_TYPE_IQ4_NL_R4 ,  4},
+        { GGML_TYPE_IQ3_S_R4  ,  4},
+        { GGML_TYPE_IQ2_S_R4  ,  4},
+        { GGML_TYPE_IQ4_XS_R8 ,  8},
+        { GGML_TYPE_IQ1_M_R4  ,  4},
+        { GGML_TYPE_BF16_R16  , 16},
+        { GGML_TYPE_Q6_0_R4   ,  4},
+        { GGML_TYPE_IQ2_BN_R4 ,  4},
+        { GGML_TYPE_IQ2_K_R4  ,  4},
+        { GGML_TYPE_IQ3_K_R4  ,  4},
+        { GGML_TYPE_IQ4_K_R4  ,  4},
+        { GGML_TYPE_IQ5_K_R4  ,  4},
+        { GGML_TYPE_IQ4_KS_R4 ,  4},
+        { GGML_TYPE_IQ5_KS_R4 ,  4},
+        { GGML_TYPE_Q8_K_R16  ,  4},
+        { GGML_TYPE_Q8_KV_R8  ,  4},
+        { GGML_TYPE_Q8_K_R8   ,  8},
+    };
     //printf("%s(%s)\n", __func__, tensor->name);
 
     // split tensors must always be set in their entirety at once
@@ -811,9 +843,11 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_set_tensor([[maybe_unused]]
         }
     }
     else if (extra->split_dim == 0) {
-        if (tensor->type >= GGML_TYPE_Q4_0_R8) {
-            GGML_ABORT("Dim 0 copy of row-interleaved quants is not supported yet");
-        }
+        int n_interleave = 1;
+        if (auto it = k_map.find(tensor->type); it != k_map.end()) n_interleave = it->second;
+        //if (tensor->type >= GGML_TYPE_Q4_0_R8) {
+        //    GGML_ABORT("Dim 0 copy of row-interleaved quants is not supported yet");
+        //}
         auto tt = ggml_internal_get_type_traits(tensor->type);
         std::vector<char> host_buffer;
         GGML_ASSERT(ggml_is_contiguous(tensor));
@@ -825,21 +859,22 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_set_tensor([[maybe_unused]]
         for (int i = 0; i < extra->n_device; ++i) {
            auto split = extra->splits[i];
            if (!split) continue;
+           GGML_ASSERT(split->ne[1]%n_interleave == 0);
            ggml_cuda_set_device(i);
            GGML_ASSERT(split->type == tensor->type);
            GGML_ASSERT((int)ggml_nrows(split) == nrows);
            GGML_ASSERT(split->ne[0] % bs == 0);
-           auto source_offset = tt.row_meta_size + (ne / bs) * ts;
+           auto source_offset = n_interleave*(tt.row_meta_size + (ne / bs) * ts);
            auto split_row_size = ggml_row_size(split->type, split->ne[0]);
            if (host_buffer.size() < nrows*split_row_size) host_buffer.resize(nrows*split_row_size);
            for (int64_t i02 = 0; i02 < split->ne[2]; ++i02) {
-               for (int64_t i01 = 0; i01 < split->ne[1]; ++i01) {
+               for (int64_t i01 = 0; i01 < split->ne[1]; i01 += n_interleave) {
                    auto dst = host_buffer.data() + (i02*split->ne[1] + i01)*split_row_size;
                    auto src = (const char *)data + i02*tensor->nb[2] + i01*tensor->nb[1];
                    if (tt.row_meta_size > 0) {
-                       memcpy(dst, src, tt.row_meta_size);
+                       memcpy(dst, src, tt.row_meta_size*n_interleave);
                    }
-                   memcpy(dst + tt.row_meta_size, src + source_offset, split_row_size - tt.row_meta_size);
+                   memcpy(dst + tt.row_meta_size*n_interleave, src + source_offset, n_interleave*(split_row_size - tt.row_meta_size));
                }
            }
            CUDA_CHECK(cudaMemcpyAsync(split->data, host_buffer.data(), nrows*split_row_size, cudaMemcpyHostToDevice, cudaStreamPerThread));
@@ -3487,7 +3522,7 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
         if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
             use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
 #ifndef NDEBUG
-            GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
+            GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer %s\n", __func__, node->src[0]->name);
 #endif
         }
 
diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index 9866124d..245dbe55 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -1075,36 +1075,83 @@ llm_expert_gating_func_type gating_op,
         llm_ffn_op_type type_op_shexp,
         const llm_build_cb & cb,
         int il,
         ggml_cgraph * graph) {
-    auto split_up_exps   = (ggml_split_tensor_t *)up_exps->extra;
-    auto split_gate_exps = (ggml_split_tensor_t *)gate_exps->extra;
-    auto split_down_exps = (ggml_split_tensor_t *)down_exps->extra;
+    auto split_up_exps      = (ggml_split_tensor_t *)up_exps->extra;
+    auto split_gate_exps    = (ggml_split_tensor_t *)gate_exps->extra;
+    auto split_down_exps    = (ggml_split_tensor_t *)down_exps->extra;
+    auto split_up_shexp     = up_shexp     ? (ggml_split_tensor_t *)up_shexp->extra   : nullptr;
+    auto split_gate_shexp   = gate_shexp   ? (ggml_split_tensor_t *)gate_shexp->extra : nullptr;
+    auto split_down_shexp   = down_shexp   ? (ggml_split_tensor_t *)down_shexp->extra : nullptr;
+    auto split_up_b_shexp   = up_b_shexp   ? (ggml_split_tensor_t *)up_b_shexp   : nullptr;
+    auto split_gate_b_shexp = gate_b_shexp ? (ggml_split_tensor_t *)gate_b_shexp : nullptr;
+    auto split_down_b_shexp = down_b_shexp ? (ggml_split_tensor_t *)down_b_shexp : nullptr;
     if (!split_up_exps && !split_gate_exps && !split_down_exps) {
         auto cur = input;
         if (ffn_norm) {
-            cur = llm_build_norm(ctx, input, lctx.model.hparams, ffn_norm, nullptr, LLM_NORM_RMS, cb, il);
+            auto the_ffn_norm = ffn_norm->extra ? ((ggml_split_tensor_t *)ffn_norm->extra)->splits[lctx.model.main_gpu] : ffn_norm;
+            cur = llm_build_norm(ctx, input, lctx.model.hparams, the_ffn_norm, nullptr, LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_inp_normed", il);
         }
+        else if (cur->type != GGML_TYPE_F32) {
+            cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
+        }
+        auto the_gate_inp    = gate_inp->extra ? ((ggml_split_tensor_t *)gate_inp->extra)->splits[lctx.model.main_gpu] : gate_inp;
+        auto the_gate_inp_b  = gate_inp_b ? gate_inp_b->extra ? ((ggml_split_tensor_t *)gate_inp_b->extra)->splits[lctx.model.main_gpu] : gate_inp_b : nullptr;
+        auto the_exp_probs_b = exp_probs_b ? exp_probs_b->extra ? ((ggml_split_tensor_t *)exp_probs_b->extra)->splits[lctx.model.main_gpu] : exp_probs_b : nullptr;
+        //printf("Using non-split llm_build_moe_ffn for layer %d\n", il);
         auto routed_out = llm_build_moe_ffn(ctx, lctx, cur,
-                gate_inp, gate_inp_b,
+                the_gate_inp, the_gate_inp_b,
                 up_exps, up_exps_b,
                 gate_exps, gate_exps_b,
                 down_exps, down_exps_b,
-                exp_probs_b,
+                the_exp_probs_b,
                 n_expert, n_expert_used,
                 type_op, norm_w, scale_w, w_scale, gating_op, cb, il, graph);
         cb(routed_out, "routed_out", il);
         if (up_shexp && gate_shexp && down_shexp) {
-            auto shared_out = llm_build_ffn(ctx, lctx, nullptr, cur,
-                    up_shexp, up_b_shexp, nullptr,
-                    gate_shexp, gate_b_shexp, nullptr,
-                    down_shexp, down_b_shexp, nullptr,
-                    nullptr, type_op_shexp, LLM_FFN_PAR, cb, il);
-            cb(shared_out, "ffn_shexp_out", il);
-
-            cur = ggml_add(ctx, routed_out, shared_out);
-            cb(cur, "ffn_out", il);
+            if (split_up_shexp) {
+                //printf("Using split ffn for shared experts in layer %d\n", il);
+                std::vector<ggml_tensor *> results(split_up_shexp->n_device);
+                GGML_ASSERT(!split_up_b_shexp   || split_up_b_shexp->n_device   == split_up_shexp->n_device);
+                GGML_ASSERT(!split_gate_b_shexp || split_gate_b_shexp->n_device == split_up_shexp->n_device);
+                GGML_ASSERT(!split_down_b_shexp || split_down_b_shexp->n_device == split_up_shexp->n_device);
+                for (int id = 0; id < split_up_shexp->n_device; ++id) {
+                    int il_cb = 1000*id + il;
+                    auto shared_out = llm_build_ffn(ctx, lctx, nullptr, cur,
+                            split_up_shexp->splits[id],   split_up_b_shexp   ? split_up_b_shexp->splits[id]   : nullptr, nullptr,
+                            split_gate_shexp->splits[id], split_gate_b_shexp ? split_gate_b_shexp->splits[id] : nullptr, nullptr,
+                            split_down_shexp->splits[id], split_down_b_shexp ? split_down_b_shexp->splits[id] : nullptr, nullptr,
+                            nullptr, type_op_shexp, LLM_FFN_PAR, cb, il);
+                    cb(shared_out, "ffn_shexp_out", il_cb);
+                    if (shared_out->ne[1] > 32) {
+                        shared_out = ggml_cast(ctx, shared_out, GGML_TYPE_F16);
+                    }
+                    results[id] = shared_out;
+                }
+                auto cur = ggml_add(ctx, results[0], results[1]);
+                cur->op_params[0] = 0xff;
+                cb(cur, "ffn_shared_combined", il);
+                for (int id = 2; id < int(results.size()); ++id) {
+                    cur = ggml_add(ctx, cur, results[id]);
+                    cb(cur, "ffn_shared_combined", il);
+                }
+                if (cur->type == GGML_TYPE_F16) {
+                    cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
+                }
+                cur = ggml_add(ctx, routed_out, cur);
+                cb(cur, "ffn_out", il);
+            } else {
+                //printf("Using non-split ffn for shared experts in layer %d\n", il);
+                auto shared_out = llm_build_ffn(ctx, lctx, nullptr, cur,
+                        up_shexp, up_b_shexp, nullptr,
+                        gate_shexp, gate_b_shexp, nullptr,
+                        down_shexp, down_b_shexp, nullptr,
+                        nullptr, type_op_shexp, LLM_FFN_PAR, cb, il);
+                cb(shared_out, "ffn_shexp_out", il);
+                cur = ggml_add(ctx, routed_out, shared_out);
+                cb(cur, "ffn_out", il);
+            }
         } else {
             cur = routed_out;
         }
@@ -1113,16 +1160,12 @@ llm_expert_gating_func_type gating_op,
     GGML_ASSERT(split_up_exps && split_gate_exps && split_down_exps);
     GGML_ASSERT(split_up_exps->n_device == split_gate_exps->n_device && split_up_exps->n_device == split_down_exps->n_device);
     std::vector<ggml_tensor *> results(split_up_exps->n_device);
-    auto split_up_shexp   = up_shexp   ? (ggml_split_tensor_t *)up_shexp->extra   : nullptr;
-    auto split_gate_shexp = gate_shexp ? (ggml_split_tensor_t *)gate_shexp->extra : nullptr;
-    auto split_down_shexp = down_shexp ? (ggml_split_tensor_t *)down_shexp->extra : nullptr;
     GGML_ASSERT((!split_up_shexp && !split_gate_shexp && !split_down_shexp) || ( split_up_shexp && split_gate_shexp && split_down_shexp));
     auto split_gate_inp = (ggml_split_tensor_t *)gate_inp->extra;
     GGML_ASSERT(split_gate_inp && split_gate_inp->n_device == split_up_exps->n_device);
     auto split_exp_probs_b = exp_probs_b ? (ggml_split_tensor_t *)exp_probs_b->extra : nullptr;
     GGML_ASSERT(!split_exp_probs_b || split_exp_probs_b->n_device == split_up_exps->n_device);
 
-    if (gate_inp_b || up_exps_b || gate_exps_b || down_exps_b) printf("Have expert biases %p, %p, %p, %p\n", (void *)gate_inp_b, (void *)up_exps_b, (void *)gate_exps_b, (void *)down_exps_b);
     for (int id = 0; id < split_up_exps->n_device; ++id) {
         int il_cb = 1000*(id + 1) + il;
         auto cur = input;
@@ -1147,9 +1190,6 @@ llm_expert_gating_func_type gating_op,
         cb(routed_out, "routed_out", il_cb);
 
         if (split_up_shexp) {
-            auto split_up_b_shexp   = up_b_shexp   ? (ggml_split_tensor_t *)up_b_shexp   : nullptr;
-            auto split_gate_b_shexp = gate_b_shexp ? (ggml_split_tensor_t *)gate_b_shexp : nullptr;
-            auto split_down_b_shexp = down_b_shexp ? (ggml_split_tensor_t *)down_b_shexp : nullptr;
             GGML_ASSERT(!split_up_b_shexp   || split_up_b_shexp->n_device   == split_up_exps->n_device);
             GGML_ASSERT(!split_gate_b_shexp || split_gate_b_shexp->n_device == split_up_exps->n_device);
             GGML_ASSERT(!split_down_b_shexp || split_down_b_shexp->n_device == split_up_exps->n_device);
@@ -1499,12 +1539,12 @@ std::tuple llm_build_context::llm_buil
         cb(Kcur, "Kcur", il);
         cb(Vcur, "Vcur", il);
         if (q_norm) {
-            Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
+            Qcur = llm_build_norm(ctx0, Qcur, hparams, q_norm, NULL, LLM_NORM_RMS, cb, il);
             cb(Qcur, "Qcur_normed", il);
             ggml_build_forward_expand(gf, Qcur);
         }
         if (k_norm) {
-            Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
+            Kcur = llm_build_norm(ctx0, Kcur, hparams, k_norm, NULL, LLM_NORM_RMS, cb, il);
             cb(Kcur, "Kcur_normed", il);
             ggml_build_forward_expand(gf, Kcur);
         }
@@ -1536,12 +1576,12 @@ std::tuple llm_build_context::llm_buil
         cb(Qcur, "Qcur", il);
         cb(Kcur, "Kcur", il);
         if (q_norm) {
-            Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
+            Qcur = llm_build_norm(ctx0, Qcur, hparams, q_norm, NULL, LLM_NORM_RMS, cb, il);
             cb(Qcur, "Qcur_normed", il);
             ggml_build_forward_expand(gf, Qcur);
         }
         if (k_norm) {
-            Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
+            Kcur = llm_build_norm(ctx0, Kcur, hparams, k_norm, NULL, LLM_NORM_RMS, cb, il);
             cb(Kcur, "Kcur_normed", il);
             ggml_build_forward_expand(gf, Kcur);
         }
@@ -1559,7 +1599,7 @@ std::tuple llm_build_context::llm_buil
         auto Kcur = ggml_reshape_3d(ctx0, K, n_embd_head, K->ne[0]/n_embd_head, n_tokens);
 
         if (k_norm) {
-            Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
+            Kcur = llm_build_norm(ctx0, Kcur, hparams, k_norm, NULL, LLM_NORM_RMS, cb, il);
             cb(Kcur, "Kcur_normed", il);
         }
         auto Vcur = V;
diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp
index fa2155fc..e5951b5b 100644
--- a/src/llama-load-tensors.cpp
+++ b/src/llama-load-tensors.cpp
@@ -10,6 +10,7 @@
 #include
 #include
 #include
+#include <unordered_set>
 
 #define LLAMA_API_INTERNAL
 
@@ -159,6 +160,8 @@ struct create_tensors_helper : public create_tensors_helper_interface {
     ggml_context * ctx_output;
     ggml_context * ctx_output_split;
 
+    std::unordered_set<ggml_tensor *> split_tensors;
+
     inline ggml_context * ctx_for_buft(ggml_backend_buffer_type_t buft) {
         if (auto it = ctx_map.find(buft); it != ctx_map.end()) return it->second;
@@ -292,7 +295,7 @@ static std::vector create_split(int nr, int granularity, const std::vector<
 ggml_tensor * create_tensors_helper::create_tensor(ggml_context * ctx, const std::string & name, const std::vector<int64_t> & ne, int flags, ggml_context ** actual_context) {
-    //auto requested_ctx = ctx;
+    auto requested_ctx = ctx;
     if (ml.tensor_buft_overrides) {
         for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) {
             std::regex pattern(overrides->pattern);
@@ -305,39 +308,11 @@ ggml_tensor * create_tensors_helper::create_tensor(ggml_context * ctx, const std
         }
     }
     if (actual_context) *actual_context = ctx;
     auto tensor = ml.create_tensor(ctx, name, ne, flags);
-    //if (tensor && requested_ctx == ctx && model.split_mode == LLAMA_SPLIT_MODE_GRAPH) {
-    //    int i_layer = -1;
-    //    if (auto pos = name.find("blk."); pos == 0) {
-    //        GGML_ASSERT(sscanf(name.c_str(), "blk.%d.", &i_layer) == 1);
-    //    }
-    //    if (i_layer >= 0) {
-    //        auto & layer = model.layers[i_layer];
-    //        auto & hparams = model.hparams;
-    //        if (auto pos = name.find("attn_q.weight"); pos != std::string::npos) {
-    //            auto split = create_split(tensor->ne[1], hparams.n_embd_head_k, model.splits);
-    //            printf("%s(%s):", __func__, name.c_str());
-    //            for (auto s : split) printf(" %d", s);
-    //            printf("\n");
-    //            layer.split_wq.tensor_splits.resize(split.size());
-    //            size_t offset = 0;
-    //            for (int i = 0; i < (int)split.size(); ++i) {
-    //                if (split[i] > 0) {
-    //                    layer.split_wq.tensor_splits[i] = ggml_view_2d(ctx, tensor, tensor->ne[0], split[i], tensor->nb[1], offset);
-    //                    auto split_name = name + '.' + std::to_string(i);
-    //                    ggml_set_name(layer.split_wq.tensor_splits[i], split_name.c_str());
-    //                    offset += tensor->nb[1]*split[i];
-    //                } else {
-    //                    layer.split_wq.tensor_splits[i] = nullptr;
-    //                }
-    //            }
-    //            layer.split_wq.ggml.n_device = split.size();
-    //            layer.split_wq.ggml.split_dim = 1;
-    //            layer.split_wq.ggml.splits = layer.split_wq.tensor_splits.data();
-    //        }
-    //    }
-    //}
+    if (tensor && ctx == requested_ctx) {
+        printf("%s: adding tensor %s to split tensors\n", __func__, tensor->name);
+        split_tensors.insert(tensor);
+    }
     return tensor;
-    //return ml.create_tensor(ctx, name, ne, flags);
 }
 
 #define LOADING_PRELUDE \
@@ -2998,41 +2973,45 @@ bool create_tensors_helper::create_tensors() {
                 prepare_split_tensors(1, ctx_split, layer.ffn_gate, layer.split_ffn_gate, split, mem_used);
             }
 
+            //bool any_ffn_split = false;
             if (layer.ffn_down_shexp && layer.ffn_up_shexp && layer.ffn_gate_shexp) {
-                int ffn_granularity = 16;
-                if (ggml_is_quantized(layer.ffn_down_shexp->type)) {
-                    auto tt = ggml_internal_get_type_traits(layer.ffn_down_shexp->type);
-                    if (tt.blck_size > ffn_granularity) ffn_granularity = tt.blck_size;
+                bool use_split = split_tensors.find(layer.ffn_down_shexp) != split_tensors.end() &&
+                                 split_tensors.find(layer.ffn_gate_shexp) != split_tensors.end() &&
+                                 split_tensors.find(layer.ffn_up_shexp)   != split_tensors.end();
+                if (use_split) {
+                    //any_ffn_split = true;
+                    int ffn_granularity = 16;
+                    if (ggml_is_quantized(layer.ffn_down_shexp->type)) {
+                        auto tt = ggml_internal_get_type_traits(layer.ffn_down_shexp->type);
+                        if (tt.blck_size > ffn_granularity) ffn_granularity = tt.blck_size;
+                    }
+                    auto split = create_split(layer.ffn_down_shexp->ne[0], ffn_granularity, model.splits);
+                    prepare_split_tensors(0, ctx_split, layer.ffn_down_shexp, layer.split_ffn_down_shexp, split, mem_used);
+                    prepare_split_tensors(1, ctx_split, layer.ffn_up_shexp, layer.split_ffn_up_shexp, split, mem_used);
+                    prepare_split_tensors(1, ctx_split, layer.ffn_gate_shexp, layer.split_ffn_gate_shexp, split, mem_used);
                 }
-                auto split = create_split(layer.ffn_down_shexp->ne[0], ffn_granularity, model.splits);
-                prepare_split_tensors(0, ctx_split, layer.ffn_down_shexp, layer.split_ffn_down_shexp, split, mem_used);
-                prepare_split_tensors(1, ctx_split, layer.ffn_up_shexp, layer.split_ffn_up_shexp, split, mem_used);
-                prepare_split_tensors(1, ctx_split, layer.ffn_gate_shexp, layer.split_ffn_gate_shexp, split, mem_used);
             }
 
             if (layer.ffn_down_exps && layer.ffn_up_exps && layer.ffn_gate_exps) {
-                int ffn_granularity = 16;
-                if (ggml_is_quantized(layer.ffn_down_exps->type)) {
-                    auto tt = ggml_internal_get_type_traits(layer.ffn_down_exps->type);
-                    if (tt.blck_size > ffn_granularity) ffn_granularity = tt.blck_size;
-                }
-                auto split = create_split(layer.ffn_down_exps->ne[0], ffn_granularity, model.splits);
-                prepare_split_tensors(0, ctx_split, layer.ffn_down_exps, layer.split_ffn_down_exps, split, mem_used);
-                prepare_split_tensors(1, ctx_split, layer.ffn_up_exps, layer.split_ffn_up_exps, split, mem_used);
-                prepare_split_tensors(1, ctx_split, layer.ffn_gate_exps, layer.split_ffn_gate_exps, split, mem_used);
-                //printf("=== Layer %d routed experts, %s, %s, %s:\n", il, ggml_type_name(layer.ffn_down_exps->type), ggml_type_name(layer.ffn_gate_exps->type), ggml_type_name(layer.ffn_up_exps->type));
-                //printf("mem_used:"); for (auto mem : mem_used) printf(" %8.2f", mem/1024./1024.);
-                //printf(" MiB\n");
-                //printf(" down:");
-                //for (auto split : layer.split_ffn_down_exps.tensor_splits) printf(" %ldx%ldx%ld", split->ne[0], split->ne[1], split->ne[2]);
-                //printf("\n");
-                //printf(" gate:");
-                //for (auto split : layer.split_ffn_gate_exps.tensor_splits) printf(" %ldx%ldx%ld", split->ne[0], split->ne[1], split->ne[2]);
-                //printf("\n");
-                //printf(" up:");
-                //for (auto split : layer.split_ffn_up_exps.tensor_splits) printf(" %ldx%ldx%ld", split->ne[0], split->ne[1], split->ne[2]);
-                //printf("\n");
+                bool use_split = split_tensors.find(layer.ffn_down_exps) != split_tensors.end() &&
+                                 split_tensors.find(layer.ffn_gate_exps) != split_tensors.end() &&
+                                 split_tensors.find(layer.ffn_up_exps)   != split_tensors.end();
+                if (use_split) {
+                    //any_ffn_split = true;
+                    int ffn_granularity = 16;
+                    if (ggml_is_quantized(layer.ffn_down_exps->type)) {
+                        auto tt = ggml_internal_get_type_traits(layer.ffn_down_exps->type);
+                        if (tt.blck_size > ffn_granularity) ffn_granularity = tt.blck_size;
+                    }
+                    auto split = create_split(layer.ffn_down_exps->ne[0], ffn_granularity, model.splits);
+                    prepare_split_tensors(0, ctx_split, layer.ffn_down_exps, layer.split_ffn_down_exps, split, mem_used);
+                    prepare_split_tensors(1, ctx_split, layer.ffn_up_exps, layer.split_ffn_up_exps, split, mem_used);
+                    prepare_split_tensors(1, ctx_split, layer.ffn_gate_exps, layer.split_ffn_gate_exps, split, mem_used);
+                }
+            }
+
+            //if (any_ffn_split) {
             if (layer.ffn_gate_inp) {
                 auto shared_split = create_split(ggml_nrows(layer.ffn_gate_inp), -1, model.splits);
                 prepare_split_tensors(-1, ctx_split, layer.ffn_gate_inp, layer.split_ffn_gate_inp, shared_split, mem_used);
@@ -3041,7 +3020,7 @@ bool create_tensors_helper::create_tensors() {
                 auto shared_split = create_split(ggml_nrows(layer.ffn_exp_probs_b), -1, model.splits);
                 prepare_split_tensors(-1, ctx_split, layer.ffn_exp_probs_b, layer.split_ffn_exp_probs_b, shared_split, mem_used);
             }
-            }
+            //}
         }
 
     if (model.output) {