Mirror of https://github.com/ikawrakow/ik_llama.cpp.git, synced 2026-01-26 17:20:01 +00:00
POC: CUDA tensor parallel (MoE models) (#1022)
* Remove most of split mode row
* WIP
* WIP: also allocate the KV cache using tensor split
* WIP: it runs with wrong result
But it also looks like the backend scheduler is not going to help:
* It copies mask and input positions to GPU 0
* => RoPE ops must run on GPU 0
* => To proceed with the attn evaluation, GPU 1 must wait for GPU 0 to finish its
entire attn calculation
* Same with the FFN. The rms_norm gets scheduled on GPU 0. Hence, GPU 1 must
wait for GPU 0 to finish its entire FFN calculation before it can
start (as it needs to copy the result of rms_norm from GPU 0)
* => Seems useless without writing a bespoke TP scheduler (conceptual sketch below)
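For reference, a conceptual sketch (not code from this commit) of the tensor-parallel FFN this is aiming for, assuming two GPUs that each hold a slice of the up/gate/down weights; everything above the final add can run concurrently, and that add is the only cross-device synchronization point:

    // minimal sketch using the public ggml graph-building API; names are illustrative
    ggml_tensor * ffn_tp(ggml_context * ctx, ggml_tensor * x,
                         ggml_tensor * up[2], ggml_tensor * gate[2], ggml_tensor * down[2]) {
        ggml_tensor * part[2];
        for (int d = 0; d < 2; ++d) {                         // one branch per device
            ggml_tensor * u = ggml_mul_mat(ctx, up[d],   x);  // up/gate are split along dim 1
            ggml_tensor * g = ggml_mul_mat(ctx, gate[d], x);
            part[d] = ggml_mul_mat(ctx, down[d], ggml_mul(ctx, ggml_silu(ctx, g), u));
        }
        return ggml_add(ctx, part[0], part[1]);               // single combine/sync point
    }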
* WIP
* This works, but it is slow
* This is slightly better
The graph is still not being computed in parallel.
Why? Because the scheduler creates graph splits where the
result of the computation on one GPU becomes an input for the
other split. Hence, to trigger the computation on the second GPU
one needs to wait for the computation on the first GPU to finish,
even though the two can be done in parallel up to the synchronization
point. So, all that is left to do is to trick the scheduler into creating
two splits that can be done in parallel, and then have a graph split
where the results get combined (see the sketch below).
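A minimal sketch of that trick, based on the scheduler hunk further down in this diff (how the op actually gets tagged on the llama.cpp side is not shown on this page, so treat the direct assignment here as an assumption):

    // mark the ADD that combines the per-GPU partial results
    ggml_tensor * combined = ggml_add(ctx, partial_gpu0, partial_gpu1);
    combined->op_params[0] = 0xff;  // marker checked in ggml_backend_sched_split_graph()

    // scheduler side (see the ggml-backend hunk below):
    //   if (node->op == GGML_OP_ADD && node->op_params[0] == 0xff) need_new_split = true;

This forces the scheduler to close the current splits at the combine point, so the two per-GPU branches end up in separate splits that can be executed concurrently.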
* Playing games with the scheduler
This change tricks it into doing the right thing^TM.
Still quite a bit slower than split mode layer for the 8B LLaMA model.
But for the 70B LLaMA it now beats split mode layer for TG:
28 t/s vs 24.4 t/s. PP is 627 t/s vs 744 t/s.
In comparison, split mode "row" in mainline gets
484 t/s PP and 19.3 t/s TG.
* Fix attn split
Granularity for Wq, Wo is not just the head size, but
head size * gqa_ratio.
Otherwise the Wk, Wv tensors end up not being a multiple of the
head size when we divide the split determined by Wo by
the gqa_ratio (worked example below).
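A worked example with illustrative, roughly LLaMA-3-70B-like numbers (these values are not taken from this commit):

    constexpr int head_dim  = 128;                            // n_embd_head_k
    constexpr int n_head    = 64, n_head_kv = 8;
    constexpr int gqa_ratio = n_head / n_head_kv;             // 8
    constexpr int attn_granularity = head_dim * gqa_ratio;    // 1024
    // A Wo/Wq split that is a multiple of 1024 yields a Wk/Wv split of
    // split/gqa_ratio, still a multiple of 128. With a granularity of only
    // head_dim, a 640-column Wo split would leave 640/8 = 80 rows for Wk/Wv,
    // which is not a multiple of the head size.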
* Show memory used per device
* Make it work with partial offload
but no tensor overrides yet, just ngl < num_layers.
* Allow for f16 source in fused_rms_norm
* This results in faster PP.
Now PP is faster than split mode layer for L3-70B.
* Rename split mode "row" to split mode "graph"
* Leave FFN partial results as f16
* WIP GLM4.5 - runs with wrong results
* WIP GLM4.5 - this works
PP is already better than split mode layer, but TG for zero context
is kind of low - 60 vs 92 t/s. TG becomes better than split mode layer
at around 20k tokens. PP at 26k tokens is 1.55X of sm layer.
* Work around compiler bug
It issues a warning that there is an extra semicolon outside of a function,
but there isn't. If I remove the anonymous namespace and turn the
functions inside into static, the warning disapears, so clearly
a compiler bug.
* Make graph reuse work with split mode graph
* Remove more split mode row remnants
* WIP tensor overrides
Runs with wrong results; I don't see where the issue could be.
* This works but is slow
Still does not work for row-interleaved quants
* Slightly better
* Slightly better
* Row-interleaved quants work
* Better
* Minor
* Guard against using split mode "graph" for unsupported models
* Guard against using merge_qkv with split mode "graph"
* WIP split mode attn
Works for LLaMA models, but not for GLM-4.5.
Doesn't seem to improve performance, so I guess no point in trying to
fix it.
* Split mode graph for qwen3moe
* Try to better distribute the splits
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
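For reference, a hypothetical way to select the new mode through the C API (field and enum names as they appear in llama.h after this change; the model path and split fractions are placeholders):

    llama_model_params mparams = llama_model_default_params();
    mparams.n_gpu_layers = 99;
    mparams.split_mode   = LLAMA_SPLIT_MODE_GRAPH;        // added by this PR
    static const float tensor_split[] = {1.0f, 1.0f};     // even split across 2 GPUs
    mparams.tensor_split = tensor_split;
    llama_model * model  = llama_load_model_from_file("model.gguf", mparams);

On the command line this corresponds to -sm graph, optionally combined with -ts to set the per-GPU proportions, as the option changes below show.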
@@ -1276,12 +1276,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
else if (arg_next == "layer") {
params.split_mode = LLAMA_SPLIT_MODE_LAYER;
}
else if (arg_next == "row") {
fprintf(stderr, "\n\n=====================================================================================\n");
fprintf(stderr, " Split mode row is no longer supported\n");
fprintf(stderr, "=====================================================================================\n\n\n");
GGML_ABORT("fatal error");
params.split_mode = LLAMA_SPLIT_MODE_ROW;
else if (arg_next == "attn") {
params.split_mode = LLAMA_SPLIT_MODE_ATTN;
}
else if (arg_next == "graph") {
params.split_mode = LLAMA_SPLIT_MODE_GRAPH;
}
else {
invalid_param = true;
@@ -2249,6 +2248,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", "-sm, --split-mode SPLIT_MODE",
"how to split the model across multiple GPUs, one of:\n"
" - none: use one GPU only\n"
" - graph: split model tensors and computation graph across GPUs\n"
" - layer (default): split layers and KV across GPUs\n" });
options.push_back({ "*", "-ts, --tensor-split SPLIT",
"fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1" });
@@ -217,7 +217,7 @@ static const char * split_mode_str(llama_split_mode mode) {
switch (mode) {
case LLAMA_SPLIT_MODE_NONE: return "none";
case LLAMA_SPLIT_MODE_LAYER: return "layer";
case LLAMA_SPLIT_MODE_ROW: return "row";
case LLAMA_SPLIT_MODE_GRAPH: return "graph";
default: GGML_ABORT("invalid split mode");
}
}
@@ -334,7 +334,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -ngl, --n-gpu-layers <n> (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str());
printf(" --n-cpu-moe <n> (default: none)\n");
printf(" -rpc, --rpc <rpc_servers> (default: %s)\n", join(cmd_params_defaults.rpc_servers, ",").c_str());
printf(" -sm, --split-mode <none|layer> (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str());
printf(" -sm, --split-mode <none|row|layer> (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str());
printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
@@ -630,12 +630,8 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
mode = LLAMA_SPLIT_MODE_NONE;
} else if (m == "layer") {
mode = LLAMA_SPLIT_MODE_LAYER;
} else if (m == "row") {
fprintf(stderr, "\n\n=======================================================================\n");
fprintf(stderr, "Split mode 'row' is no longer supported\n");
fprintf(stderr, "=======================================================================\n\n\n");
invalid_param = true;
break;
} else if (m == "graph") {
mode = LLAMA_SPLIT_MODE_GRAPH;
} else {
invalid_param = true;
break;
@@ -3021,6 +3021,13 @@ extern "C" {

GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);

typedef struct {
    int n_device;
    int split_dim;
    struct ggml_tensor * tensor;
    struct ggml_tensor ** splits;
} ggml_split_tensor_t;

#ifdef __cplusplus
}
#endif
@@ -43,7 +43,7 @@ GGML_CALL size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buf
// get_alloc_size is optional, defaults to ggml_nbytes
if (buft->iface.get_alloc_size) {
size_t size = buft->iface.get_alloc_size(buft, tensor);
assert(size >= ggml_nbytes(tensor));
//assert(size >= ggml_nbytes(tensor));
return size;
}
return ggml_nbytes(tensor);
@@ -1216,8 +1216,10 @@ static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, co
return -1;
}

//printf("%s: have %d backends, buffer is %s\n", __func__, sched->n_backends, ggml_backend_buffer_name(buffer));
// find highest prio backend that supports the buffer type and the op
for (int i = 0; i < sched->n_backends; i++) {
//printf(" Checking bacckend %d (%s)\n", i, ggml_backend_name(sched->backends[i]));
if (ggml_backend_supports_buft(sched->backends[i], buffer->buft) &&
ggml_backend_supports_op(sched->backends[i], op)) {
return i;
@@ -1393,6 +1395,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
// do not overwrite user assignments
if (*leaf_backend_id == -1) {
*leaf_backend_id = ggml_backend_sched_backend_id_from_cur(sched, leaf);
//printf("Pass 1: assigned backend %d to leaf %d, %s\n", *leaf_backend_id, i, graph->leafs[i]->name);
}
}

@@ -1402,6 +1405,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
// do not overwrite user assignments
if (*node_backend_id == -1) {
*node_backend_id = ggml_backend_sched_backend_id_from_cur(sched, node);
//printf("Pass 1: assigned backend %d to node %d, %s(%s)\n", *node_backend_id, i, ggml_op_name(node->op), node->name);

#if 0
// src
@@ -1445,6 +1449,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
cur_backend_id = *node_backend_id;
}
} else if (cur_backend_id != -1) {
//printf("(u1) invoking ggml_backend_sched_set_if_supported for node %d, %s with cur_backend_id = %d, node_backend_id = %d\n", i, node->name, cur_backend_id, *node_backend_id);
ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
}
}
@@ -1466,6 +1471,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
cur_backend_id = *node_backend_id;
}
} else if (cur_backend_id != -1) {
//printf("(d1) invoking ggml_backend_sched_set_if_supported for node %d, %s with cur_backend_id = %d, node_backend_id = %d\n", i, node->name, cur_backend_id, *node_backend_id);
ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
}
}
@@ -1482,6 +1488,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
if (*node_backend_id != -1) {
cur_backend_id = *node_backend_id;
} else if (cur_backend_id != -1) {
//printf("(u2) invoking ggml_backend_sched_set_if_supported for node %d, %s with cur_backend_id = %d, node_backend_id = %d\n", i, node->name, cur_backend_id, *node_backend_id);
ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
}
}
@@ -1498,6 +1505,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
if (*node_backend_id != -1) {
cur_backend_id = *node_backend_id;
} else if (cur_backend_id != -1) {
//printf("(d2) invoking ggml_backend_sched_set_if_supported for node %d, %s with cur_backend_id = %d, node_backend_id = %d\n", i, node->name, cur_backend_id, *node_backend_id);
ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
}
}
@@ -1535,6 +1543,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
if (n_supported > n_supported_best) {
n_supported_best = n_supported;
*node_backend_id = b;
//printf("Pass 3: assigned backend %d to unassigned node %d, %s\n", b, i, node->name);
SET_CAUSE(node, "3.best");
}
}
@@ -1555,6 +1564,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
}
}
if (supported) {
//printf("Pass 3: assigned backend %d to node %d, %s previously assigned to backend %d\n", b, i, node->name, *node_backend_id);
*node_backend_id = b;
SET_CAUSE(node, "3.upg");
break;
@@ -1583,9 +1593,11 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
// views are always on the same backend as the source
*src_backend_id = tensor_backend_id(src->view_src);
SET_CAUSE(src, "4.vsrc");
//printf("Pass 4: assigned backend %d to src %d, %s in node %d, %s frpm view_src\n", *src_backend_id, j, src->name, i, node->name);
} else {
*src_backend_id = *cur_backend_id;
SET_CAUSE(src, "4.cur");
//printf("Pass 4: assigned backend %d to src %d, %s in node %d, %s frpm current\n", *src_backend_id, j, src->name, i, node->name);
}
}
}
@@ -1620,7 +1632,10 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg

// check if we should start a new split based on the sources of the current node
bool need_new_split = false;
if (node_backend_id == cur_backend_id && split->n_inputs > 0) {
if (node->op == GGML_OP_ADD && node->op_params[0] == 0xff) {
need_new_split = true;
}
else if (node_backend_id == cur_backend_id && split->n_inputs > 0) {
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * src = node->src[j];
if (src == NULL) {

[File diff suppressed because it is too large]
@@ -18,17 +18,8 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M_R4> {
|
||||
static constexpr int qi = 4;
|
||||
};
|
||||
|
||||
// Reminder:
|
||||
// constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||
// constexpr int qi = ggml_cuda_type_traits<type>::qi;
|
||||
// constexpr int vdr = get_vdr_mmvq(type);
|
||||
|
||||
// QI4_XS = 256/(4*2) = 32
|
||||
// vdr = 4, qi = 32 -> qi/vdr = 8, kqs = 4*(tid%8), blocks_per_iter = 4*1*32/32 = 4
|
||||
// vdr = 2, qi = 32 -> qi/vdr =16, kqs = 2*(tid%16), blocks_per_iter = 2*1*32/32 = 2
|
||||
namespace {
|
||||
template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y, int n_interleaved = 1>
|
||||
__device__ void iqk_mul_mat_vec_q_kerne(
|
||||
static __device__ void iqk_mul_mat_vec_q_kerne(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy,
|
||||
const float * bias, float * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size) {
|
||||
@@ -110,7 +101,7 @@ __device__ void iqk_mul_mat_vec_q_kerne(
|
||||
}
|
||||
|
||||
template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y, int n_interleaved = 1>
|
||||
__device__ void iqk_fused_mul_mat_vec_q_kernel(
|
||||
static __device__ void iqk_fused_mul_mat_vec_q_kernel(
|
||||
const void * __restrict__ vup, const void * __restrict__ vgate, const void * __restrict__ vy, float * __restrict__ dst,
|
||||
const float * __restrict__ bias_u, const float * __restrict__ bias_g,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size,
|
||||
@@ -228,7 +219,7 @@ template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y,
|
||||
// tell the compiler to use as many registers as it wants, see nwarps definition below
|
||||
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__global__ void iqk_mul_mat_vec_q(
|
||||
static __global__ void iqk_mul_mat_vec_q(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||
const char * __restrict__ ids_data, const void * __restrict__ bias,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size,
|
||||
@@ -248,7 +239,7 @@ template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y,
|
||||
// tell the compiler to use as many registers as it wants, see nwarps definition below
|
||||
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__global__ void iqk_fused_mul_mat_vec_q(
|
||||
static __global__ void iqk_fused_mul_mat_vec_q(
|
||||
const void * __restrict__ vx_u, const void * __restrict__ vx_g, const void * __restrict__ vy, float * __restrict__ dst,
|
||||
const char * __restrict__ ids_data, const void * __restrict__ bias_u, const void * __restrict__ bias_g, const uint64_t bias_nb1,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size,
|
||||
@@ -269,7 +260,7 @@ __global__ void iqk_fused_mul_mat_vec_q(
|
||||
}
|
||||
|
||||
template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int n_interleaved = 1>
|
||||
void iqk_mul_mat_vec_q_cuda(const mmvq_args & args, cudaStream_t stream) {
|
||||
static void iqk_mul_mat_vec_q_cuda(const mmvq_args & args, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(args.ncols_x % ggml_blck_size(type) == 0);
|
||||
//GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
|
||||
@@ -428,7 +419,7 @@ void iqk_mul_mat_vec_q_cuda(const mmvq_args & args, cudaStream_t stream) {
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void get_int_from_table_16_shift(const uint32_t & q4, uint16_t shift, const uint8_t * all_values,
|
||||
static __device__ __forceinline__ void get_int_from_table_16_shift(const uint32_t & q4, uint16_t shift, const uint8_t * all_values,
|
||||
int & val1, int & val2) {
|
||||
|
||||
uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32;
|
||||
@@ -476,7 +467,7 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int int_from_table(const uint8_t * a8, const uint8_t * values) {
|
||||
static __device__ __forceinline__ int int_from_table(const uint8_t * a8, const uint8_t * values) {
|
||||
uint16_t v1 = values[a8[0]] | (values[a8[1]] << 8);
|
||||
uint16_t v2 = values[a8[2]] | (values[a8[3]] << 8);
|
||||
return v1 | (v2 << 16);
|
||||
@@ -506,8 +497,6 @@ __device__ __forceinline__ int int_from_table(const uint8_t * a8, const uint8_t
|
||||
#define VDR_IQ3_K_Q8_1_MMVQ 4
|
||||
#define VDR_IQ3_K_Q8_1_MMQ 4
|
||||
|
||||
} // namespace
|
||||
|
||||
extern void mul_mat_vec_iq2_k_q8_1_cuda(const mmvq_args & args, cudaStream_t stream);
|
||||
extern void mul_mat_vec_iq3_k_q8_1_cuda(const mmvq_args & args, cudaStream_t stream);
|
||||
extern void mul_mat_vec_iq4_k_q8_1_cuda(const mmvq_args & args, cudaStream_t stream);
|
||||
|
||||
@@ -176,15 +176,15 @@ static __global__ void rms_norm_f32_nc(
|
||||
}
|
||||
}
|
||||
|
||||
template <int block_size>
|
||||
static __global__ void fused_rms_norm_f32(const float * x, const float * y, float * dst, const int ncols, const float eps) {
|
||||
template <int block_size, typename src_t>
|
||||
static __global__ void fused_rms_norm_f32(const src_t * x, const float * y, float * dst, const int ncols, const float eps) {
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[row*ncols + col];
|
||||
const float xi = (float)x[row*ncols + col];
|
||||
tmp += xi * xi;
|
||||
}
|
||||
|
||||
@@ -206,13 +206,13 @@ static __global__ void fused_rms_norm_f32(const float * x, const float * y, floa
|
||||
const float scale = rsqrtf(mean + eps);
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[row*ncols + col] = scale * y[col] * x[row*ncols + col];
|
||||
dst[row*ncols + col] = scale * y[col] * (float)x[row*ncols + col];
|
||||
}
|
||||
}
|
||||
|
||||
template <int block_size>
|
||||
template <int block_size, typename src_t>
|
||||
static __global__ void fused_rms_norm_f32_nc(
|
||||
const float * x, const float * y, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
|
||||
const src_t * x, const float * y, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
|
||||
const int64_t stride_sample, const float eps) {
|
||||
const int nrows = gridDim.x;
|
||||
const int nchannels = gridDim.y;
|
||||
@@ -229,7 +229,7 @@ static __global__ void fused_rms_norm_f32_nc(
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[col];
|
||||
const float xi = (float)x[col];
|
||||
tmp += xi * xi;
|
||||
}
|
||||
|
||||
@@ -257,7 +257,7 @@ static __global__ void fused_rms_norm_f32_nc(
|
||||
const float scale = rsqrtf(mean + eps);
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[col] = scale * y[col] * x[col];
|
||||
dst[col] = scale * y[col] * (float)x[col];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -307,7 +307,8 @@ static void rms_norm_f32_nc_cuda(
|
||||
}
|
||||
}
|
||||
|
||||
static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
|
||||
template <typename src_t>
|
||||
static void fused_rms_norm_f32_cuda(const src_t * x, const float * y, float * dst,
|
||||
const int ncols, const int nrows, const float eps, cudaStream_t stream) {
|
||||
constexpr int kBlockSize = 256;
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
@@ -331,8 +332,9 @@ static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * ds
|
||||
}
|
||||
}
|
||||
|
||||
template <typename src_t>
|
||||
static void fused_rms_norm_f32_nc_cuda(
|
||||
const float * x, const float * y, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
||||
const src_t * x, const float * y, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
||||
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
|
||||
const dim3 blocks_num(nrows, nchannels, nsamples);
|
||||
if (ncols < 1024) {
|
||||
@@ -432,7 +434,7 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
float * dst_d = (float *)dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src0->ne[0] == src1->ne[0]);
|
||||
@@ -445,14 +447,22 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
|
||||
if (ggml_is_contiguous(src0)) {
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
|
||||
} else {
|
||||
fused_rms_norm_f32_cuda((const half *)src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
|
||||
}
|
||||
} else {
|
||||
auto ts0 = ggml_type_size(src0->type);
|
||||
GGML_ASSERT(src0->nb[0] == ts0);
|
||||
auto s01 = src0->nb[1] / ts0;
|
||||
auto s02 = src0->nb[2] / ts0;
|
||||
auto s03 = src0->nb[3] / ts0;
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
fused_rms_norm_f32_nc_cuda(src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
|
||||
} else {
|
||||
fused_rms_norm_f32_nc_cuda((const half *)src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7284,7 +7284,19 @@ static struct ggml_tensor * ggml_fused_rms_norm_impl(
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
struct ggml_tensor * result;
|
||||
if (inplace) {
|
||||
GGML_ASSERT(a->type == GGML_TYPE_F32);
|
||||
result = ggml_view_tensor(ctx, a);
|
||||
} else {
|
||||
if (a->type == GGML_TYPE_F32) {
|
||||
result = ggml_dup_tensor(ctx, a);
|
||||
} else {
|
||||
result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], a->ne[1], a->ne[2], a->ne[3]);
|
||||
}
|
||||
}
|
||||
|
||||
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
ggml_set_op_params(result, &eps, sizeof(eps));
|
||||
|
||||
|
||||
@@ -275,7 +275,8 @@ extern "C" {
enum llama_split_mode {
    LLAMA_SPLIT_MODE_NONE  = 0, // single GPU
    LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
    LLAMA_SPLIT_MODE_ROW   = 2, // split rows across GPUs
    LLAMA_SPLIT_MODE_ATTN  = 2, // splits self-attention computations across GPUs
    LLAMA_SPLIT_MODE_GRAPH = 3, // splits computations across GPUs
};

[File diff suppressed because it is too large]
@@ -148,7 +148,7 @@ struct llm_build_context {
|
||||
ggml_tensor * wq, ggml_tensor * bq,
|
||||
ggml_tensor * wk, ggml_tensor * bk,
|
||||
ggml_tensor * wv, ggml_tensor * bv,
|
||||
float attention_scale, int il);
|
||||
float attention_scale, int il) const;
|
||||
|
||||
std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_mul_mat_qkv(ggml_cgraph * gf, ggml_tensor * cur,
|
||||
ggml_tensor * wqkv, ggml_tensor * bqkv,
|
||||
@@ -156,7 +156,7 @@ struct llm_build_context {
|
||||
ggml_tensor * wq, ggml_tensor * bq,
|
||||
ggml_tensor * wk, ggml_tensor * bk,
|
||||
ggml_tensor * wv, ggml_tensor * bv,
|
||||
ggml_tensor * q_norm, ggml_tensor * k_norm, float attention_scale, int il);
|
||||
ggml_tensor * q_norm, ggml_tensor * k_norm, float attention_scale, int il) const;
|
||||
|
||||
ggml_cgraph * build_llama();
|
||||
|
||||
@@ -317,7 +317,7 @@ struct llm_build_context {
|
||||
float kq_scale,
|
||||
const llm_build_cb & cb, int il, ggml_tensor * sinks = nullptr, int n_swa = 0);
|
||||
|
||||
static ggml_tensor * llm_build_ffn(ggml_context * ctx, llama_context & lctx,
|
||||
static ggml_tensor * llm_build_ffn(ggml_context * ctx, llama_context & lctx, ggml_tensor * ffn_norm,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * up,
|
||||
ggml_tensor * up_b,
|
||||
@@ -331,7 +331,7 @@ struct llm_build_context {
|
||||
ggml_tensor * act_scales,
|
||||
llm_ffn_op_type type_op,
|
||||
llm_ffn_gate_type type_gate,
|
||||
const llm_build_cb & cb, int il);
|
||||
const llm_build_cb & cb, int il, ggml_cgraph * graph = nullptr);
|
||||
|
||||
static ggml_tensor * llm_build_moe_ffn(ggml_context * ctx, llama_context & lctx,
|
||||
ggml_tensor * cur,
|
||||
@@ -375,6 +375,27 @@ llm_expert_gating_func_type gating_op,
|
||||
gating_op, cb, il, graph);
|
||||
}
|
||||
|
||||
static ggml_tensor * llm_build_std_moe_ffn(ggml_context * ctx, llama_context & lctx,
|
||||
ggml_tensor * ffn_norm,
|
||||
ggml_tensor * input,
|
||||
ggml_tensor * gate_inp, ggml_tensor * gate_inp_b,
|
||||
ggml_tensor * up_exps, ggml_tensor * up_exps_b,
|
||||
ggml_tensor * gate_exps, ggml_tensor * gate_exps_b,
|
||||
ggml_tensor * down_exps, ggml_tensor * down_exps_b,
|
||||
ggml_tensor * exp_probs_b,
|
||||
ggml_tensor * up_shexp, ggml_tensor * up_b_shexp,
|
||||
ggml_tensor * gate_shexp, ggml_tensor * gate_b_shexp,
|
||||
ggml_tensor * down_shexp, ggml_tensor * down_b_shexp,
|
||||
int64_t n_expert,
|
||||
int64_t n_expert_used,
|
||||
llm_ffn_op_type type_op,
|
||||
bool norm_w,
|
||||
bool scale_w,
|
||||
float w_scale,
|
||||
llm_expert_gating_func_type gating_op,
|
||||
llm_ffn_op_type type_op_shexp,
|
||||
const llm_build_cb & cb, int il, ggml_cgraph * graph);
|
||||
|
||||
static ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids);
|
||||
|
||||
static ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx);
|
||||
@@ -383,4 +404,7 @@ llm_expert_gating_func_type gating_op,
|
||||
|
||||
static ggml_cgraph * llama_build_graph(llama_context & lctx, const llama_batch & batch, bool worst_case);
|
||||
|
||||
ggml_tensor * build_std_attention(ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * inp_pos, ggml_tensor * rope_factors,
|
||||
ggml_tensor * KQ_mask, ggml_tensor * sinks, float KQ_scale, float f_attn_scale, int n_swa, int il);
|
||||
|
||||
};
|
||||
|
||||
@@ -57,6 +57,9 @@ struct llama_kv_cache {
|
||||
std::vector<struct ggml_tensor *> k_l; // per layer
|
||||
std::vector<struct ggml_tensor *> v_l;
|
||||
|
||||
std::vector<llama_split_tensor> split_k_l;
|
||||
std::vector<llama_split_tensor> split_v_l;
|
||||
|
||||
std::vector<struct ggml_context *> ctxs;
|
||||
std::vector<ggml_backend_buffer_t> bufs;
|
||||
|
||||
|
||||
@@ -224,3 +224,8 @@ struct gguf_context;
|
||||
std::string gguf_kv_to_str(const gguf_context * ctx_gguf, int i);
|
||||
|
||||
ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer);
|
||||
|
||||
struct llama_split_tensor {
|
||||
std::vector<ggml_tensor *> tensor_splits;
|
||||
ggml_split_tensor_t ggml;
|
||||
};
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <array>
|
||||
#include <future>
|
||||
#include <regex>
|
||||
#include <unordered_set>
|
||||
|
||||
#define LLAMA_API_INTERNAL
|
||||
|
||||
@@ -139,7 +140,7 @@ struct create_tensors_helper : public create_tensors_helper_interface {
|
||||
ggml_context ** actual_ctx = nullptr);
|
||||
|
||||
void create_default_embd_output(const LLM_TN & tn, int n_embd, int n_vocab, bool norm_bias);
|
||||
void create_embd_output(const LLM_TN & tn, int n_embd, int n_vocab, bool has_norm = true);
|
||||
void create_embd_output(const LLM_TN & tn, int n_embd, int n_vocab, bool has_norm = true, bool use_ctx_split = false);
|
||||
|
||||
void create_std_attn(int i, const LLM_TN & tn, llama_layer & layer, int n_embd, int n_embd_gqa, ggml_context * ctx_split);
|
||||
void create_std_ffn(int i, const LLM_TN & tn, llama_layer & layer, int n_ff, int n_embd, ggml_context * ctx_split);
|
||||
@@ -153,12 +154,15 @@ struct create_tensors_helper : public create_tensors_helper_interface {
|
||||
|
||||
std::map<ggml_backend_buffer_type_t, int> buft_layer_count;
|
||||
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
||||
ggml_context * split_ctx = nullptr;
|
||||
size_t ctx_size;
|
||||
|
||||
ggml_context * ctx_input;
|
||||
ggml_context * ctx_output;
|
||||
ggml_context * ctx_output_split;
|
||||
|
||||
std::unordered_set<ggml_tensor *> split_tensors;
|
||||
|
||||
inline ggml_context * ctx_for_buft(ggml_backend_buffer_type_t buft) {
|
||||
if (auto it = ctx_map.find(buft); it != ctx_map.end()) return it->second;
|
||||
|
||||
@@ -179,6 +183,14 @@ struct create_tensors_helper : public create_tensors_helper_interface {
|
||||
|
||||
create_tensors_helper::create_tensors_helper(llama_model_loader & _ml, llama_model & _model) : ml(_ml), model(_model) {
|
||||
|
||||
#if 0
|
||||
for (int i = 0; i < model.hparams.n_layer; ++i) {
|
||||
printf("Layer %2d: %s %s\n", i, ggml_backend_buft_name(model.buft_layer[i].buft_matrix), ggml_backend_buft_name(model.buft_layer[i].buft));
|
||||
}
|
||||
printf("Output: %s %s\n", ggml_backend_buft_name(model.buft_output.buft_matrix), ggml_backend_buft_name(model.buft_output.buft));
|
||||
printf(" Input: %s %s\n", ggml_backend_buft_name(model.buft_input.buft_matrix), ggml_backend_buft_name(model.buft_input.buft));
|
||||
#endif
|
||||
|
||||
const int n_layer = model.hparams.n_layer;
|
||||
buft_layer_count[model.buft_input.buft]++;
|
||||
buft_layer_count[model.buft_input.buft_matrix]++;
|
||||
@@ -192,6 +204,11 @@ create_tensors_helper::create_tensors_helper(llama_model_loader & _ml, llama_mod
|
||||
ctx_size = ggml_tensor_overhead()*(ml.n_tensors + 1); // +1 for models where tok_embd is duplicated as output
|
||||
ctx_size += ggml_tensor_overhead()*n_layer*3; // for moe merged tensors
|
||||
|
||||
if (model.splits.size() > 1) {
|
||||
ctx_size += ggml_tensor_overhead()*n_layer*4; // for KV cache
|
||||
ctx_size *= (model.splits.size() + 1);
|
||||
}
|
||||
|
||||
for (auto & it : buft_layer_count) {
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ ctx_size,
|
||||
@@ -205,10 +222,95 @@ create_tensors_helper::create_tensors_helper(llama_model_loader & _ml, llama_mod
|
||||
ctx_map[it.first] = ctx;
|
||||
model.ctxs.push_back(ctx);
|
||||
}
|
||||
if (model.split_buft) {
|
||||
if (auto it = ctx_map.find(model.split_buft); it != ctx_map.end()) {
|
||||
split_ctx = it->second;
|
||||
}
|
||||
}
|
||||
#if 0
|
||||
printf("=======================================================================\n");
|
||||
auto n_device = model.device_count();
|
||||
printf(" Model has %d devices:\n", n_device);
|
||||
for (int device = 0; device < n_device; ++device) {
|
||||
auto buft = model.default_buffer_type_offload(device);
|
||||
if (buft) {
|
||||
printf(" %d %s\n", device, ggml_backend_buft_name(buft));
|
||||
} else {
|
||||
printf(" Oops: null buft for debvice %d\n", device);
|
||||
}
|
||||
}
|
||||
if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH) {
|
||||
printf("model.splits:");
|
||||
for (auto s : model.splits) printf(" %g", s);
|
||||
printf("\n");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
static std::vector<int> create_split(int nr, int granularity, const std::vector<float> & splits, const std::vector<size_t> & mem_used) {
|
||||
GGML_ASSERT(nr % granularity == 0);
|
||||
GGML_ASSERT(!splits.empty());
|
||||
if (granularity < 0) return std::vector<int>(splits.size(), nr);
|
||||
GGML_ASSERT(mem_used.size() == splits.size());
|
||||
size_t tot_memory_used = 1;
|
||||
for (auto & mem : mem_used) tot_memory_used += mem;
|
||||
int nchunk = nr / granularity;
|
||||
std::vector<int> result(splits.size());
|
||||
float last_split = 0;
|
||||
int sum = 0;
|
||||
for (int i = 0; i < (int)splits.size(); ++i) {
|
||||
float p = splits[i] - last_split;
|
||||
p += (p - 1.f*mem_used[i]/tot_memory_used);
|
||||
result[i] = roundf(p*nchunk);
|
||||
if (result[i] < 0) result[i] = 0;
|
||||
sum += result[i];
|
||||
last_split = splits[i];
|
||||
}
|
||||
while (sum > nchunk) {
|
||||
last_split = 0;
|
||||
float best_err = std::numeric_limits<float>::max();
|
||||
int ibest = -1;
|
||||
for (int i = 0; i < (int)splits.size(); ++i) {
|
||||
if (result[i] > 0) {
|
||||
float p = splits[i] - last_split;
|
||||
float n_want = p*nchunk;
|
||||
float err = std::abs(n_want - result[i] + 1);
|
||||
//float err = std::abs(n_want - result[i] + 1) + std::abs(p - 1.f*mem_used[i]/tot_memory_used)*nchunk;
|
||||
if (err < best_err) {
|
||||
best_err = err; ibest = i;
|
||||
}
|
||||
}
|
||||
last_split = splits[i];
|
||||
}
|
||||
GGML_ASSERT(ibest >= 0 && result[ibest] > 0);
|
||||
--result[ibest];
|
||||
--sum;
|
||||
}
|
||||
while (sum < nchunk) {
|
||||
last_split = 0;
|
||||
float best_err = std::numeric_limits<float>::max();
|
||||
int ibest = -1;
|
||||
for (int i = 0; i < (int)splits.size(); ++i) {
|
||||
float p = splits[i] - last_split;
|
||||
float n_want = p*nchunk;
|
||||
float err = std::abs(n_want - result[i] - 1);
|
||||
//float err = std::abs(n_want - result[i] - 1) + std::abs(p - 1.f*mem_used[i]/tot_memory_used)*nchunk;
|
||||
if (err < best_err) {
|
||||
best_err = err; ibest = i;
|
||||
}
|
||||
last_split = splits[i];
|
||||
}
|
||||
GGML_ASSERT(ibest >= 0);
|
||||
++result[ibest];
|
||||
++sum;
|
||||
}
|
||||
for (auto & r : result) r *= granularity;
|
||||
return result;
|
||||
}
|
||||
|
||||
ggml_tensor * create_tensors_helper::create_tensor(ggml_context * ctx, const std::string & name, const std::vector<int64_t> & ne,
|
||||
int flags, ggml_context ** actual_context) {
|
||||
//auto requested_ctx = ctx;
|
||||
if (ml.tensor_buft_overrides) {
|
||||
for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) {
|
||||
std::regex pattern(overrides->pattern);
|
||||
@@ -220,7 +322,12 @@ ggml_tensor * create_tensors_helper::create_tensor(ggml_context * ctx, const std
|
||||
}
|
||||
}
|
||||
if (actual_context) *actual_context = ctx;
|
||||
return ml.create_tensor(ctx, name, ne, flags);
|
||||
auto tensor = ml.create_tensor(ctx, name, ne, flags);
|
||||
if (tensor && ctx == split_ctx) {
|
||||
//printf("%s: adding tensor %s to split tensors\n", __func__, tensor->name);
|
||||
split_tensors.insert(tensor);
|
||||
}
|
||||
return tensor;
|
||||
}
|
||||
|
||||
#define LOADING_PRELUDE \
|
||||
@@ -251,17 +358,18 @@ ggml_tensor * create_tensors_helper::create_tensor(ggml_context * ctx, const std
|
||||
bool use_mmap_buffer = true;
|
||||
|
||||
|
||||
void create_tensors_helper::create_embd_output(const LLM_TN & tn, int n_embd, int n_vocab, bool has_norm) {
|
||||
void create_tensors_helper::create_embd_output(const LLM_TN & tn, int n_embd, int n_vocab, bool has_norm, bool use_ctx_split) {
|
||||
model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||
|
||||
auto out_ctx = use_ctx_split ? ctx_output_split : ctx_output;
|
||||
if (has_norm) {
|
||||
model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||
model.output_norm = create_tensor(out_ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||
}
|
||||
model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
model.output = create_tensor(out_ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (model.output == NULL) {
|
||||
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
||||
model.output = create_tensor(out_ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -280,7 +388,7 @@ void create_tensors_helper::create_std_ffn(int i, const LLM_TN & tn, llama_layer
|
||||
|
||||
bool create_tensors_helper::create_llama_tensors(const LLM_TN & tn) {
|
||||
LOADING_PRELUDE
|
||||
create_embd_output(tn, n_embd, n_vocab);
|
||||
create_embd_output(tn, n_embd, n_vocab, true, true);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||
@@ -288,7 +396,7 @@ bool create_tensors_helper::create_llama_tensors(const LLM_TN & tn) {
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||
layer.attn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||
|
||||
use_mmap_buffer &= !merge_qkv(tn, i, 1);
|
||||
|
||||
@@ -297,12 +405,12 @@ bool create_tensors_helper::create_llama_tensors(const LLM_TN & tn) {
|
||||
// optional bias tensors
|
||||
layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
|
||||
layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||
layer.ffn_norm = create_tensor(model.split_mode == LLAMA_SPLIT_MODE_GRAPH ? ctx_split : ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||
|
||||
layer.rope_freqs = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
||||
layer.rope_freqs = create_tensor(ctx_split, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
||||
|
||||
if (n_expert == 0) {
|
||||
create_std_ffn(i, tn, layer, n_ff, n_embd, ctx_split);
|
||||
create_std_ffn(i, tn, layer, n_ff, n_embd, model.split_mode == LLAMA_SPLIT_MODE_GRAPH ? ctx_split : ctx_layer);
|
||||
|
||||
// optional MLP bias
|
||||
layer.ffn_gate_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
@@ -1043,11 +1151,11 @@ bool create_tensors_helper::create_qwen3_moe_tensors(const LLM_TN & tn) {
|
||||
|
||||
// output
|
||||
{
|
||||
model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||
model.output_norm = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||
model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (model.output == NULL) {
|
||||
model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
||||
model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1057,18 +1165,19 @@ bool create_tensors_helper::create_qwen3_moe_tensors(const LLM_TN & tn) {
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||
layer.attn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||
|
||||
use_mmap_buffer &= !merge_qkv(tn, i, 0);
|
||||
|
||||
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
|
||||
|
||||
layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k});
|
||||
layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k});
|
||||
layer.attn_k_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k});
|
||||
layer.attn_q_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k});
|
||||
|
||||
layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||
auto ffn_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH ? ctx_split : ctx_layer;
|
||||
layer.ffn_norm = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||
|
||||
layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
|
||||
layer.ffn_gate_inp = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
|
||||
|
||||
if (n_expert == 0) {
|
||||
throw std::runtime_error("n_expert must be > 0 for QWEN3MOE");
|
||||
@@ -1080,9 +1189,9 @@ bool create_tensors_helper::create_qwen3_moe_tensors(const LLM_TN & tn) {
|
||||
// MoE branch
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
|
||||
|
||||
layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert});
|
||||
layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert});
|
||||
layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert});
|
||||
layer.ffn_gate_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert});
|
||||
layer.ffn_down_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert});
|
||||
layer.ffn_up_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert});
|
||||
}
|
||||
return use_mmap_buffer;
|
||||
}
|
||||
@@ -1734,7 +1843,7 @@ bool create_tensors_helper::create_glm4_moe_tensors(const LLM_TN & tn) {
|
||||
GGML_ASSERT(hparams.n_expert > 0 && "n_expert must be > 0 for GLM4_MOE MoE layers");
|
||||
GGML_ASSERT(hparams.n_expert_used > 0 && "n_expert_used must be > 0 for GLM4_MOE MoE layers");
|
||||
|
||||
create_embd_output(tn, n_embd, n_vocab);
|
||||
create_embd_output(tn, n_embd, n_vocab, true, true);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||
@@ -1748,7 +1857,7 @@ bool create_tensors_helper::create_glm4_moe_tensors(const LLM_TN & tn) {
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags);
|
||||
layer.attn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags);
|
||||
|
||||
// GLM-style attention with bias terms
|
||||
if (!flags) {
|
||||
@@ -1765,12 +1874,17 @@ bool create_tensors_helper::create_glm4_moe_tensors(const LLM_TN & tn) {
|
||||
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags);
|
||||
|
||||
// K/Q norm tensors (optional for GLM-4.5 355B variant)
|
||||
layer.attn_q_norm = create_tensor(ctx_layer,
|
||||
layer.attn_q_norm = create_tensor(ctx_split,
|
||||
tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, llama_model_loader::TENSOR_NOT_REQUIRED | flags);
|
||||
layer.attn_k_norm = create_tensor(ctx_layer,
|
||||
layer.attn_k_norm = create_tensor(ctx_split,
|
||||
tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, llama_model_loader::TENSOR_NOT_REQUIRED | flags);
|
||||
|
||||
layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags);
|
||||
auto ffn_ctx = model.split_mode == LLAMA_SPLIT_MODE_GRAPH ? ctx_split : ctx_layer;
|
||||
|
||||
// Why are we adding an additional tensor type?
|
||||
// attn_post_norm is the exact same thing as ffn_norm
|
||||
//layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags);
|
||||
layer.ffn_norm = create_tensor(ffn_ctx, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags);
|
||||
|
||||
// Check if this layer uses MoE or dense FFN based on n_layer_dense_lead
|
||||
// GLM 4.5 uses hybrid architecture: layer 0 is dense, layers 1+ are MoE
|
||||
@@ -1778,35 +1892,35 @@ bool create_tensors_helper::create_glm4_moe_tensors(const LLM_TN & tn) {
|
||||
|
||||
if (use_moe) {
|
||||
// MoE layers
|
||||
layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, flags);
|
||||
layer.ffn_gate_inp = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, flags);
|
||||
// gate bias
|
||||
layer.ffn_exp_probs_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, flags);
|
||||
layer.ffn_exp_probs_b = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, flags);
|
||||
|
||||
// MoE branch
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
|
||||
|
||||
layer.ffn_gate_exps = create_tensor(ctx_split,
|
||||
layer.ffn_gate_exps = create_tensor(ffn_ctx,
|
||||
tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags);
|
||||
layer.ffn_down_exps = create_tensor(ctx_split,
|
||||
layer.ffn_down_exps = create_tensor(ffn_ctx,
|
||||
tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, flags);
|
||||
layer.ffn_up_exps = create_tensor(ctx_split,
|
||||
layer.ffn_up_exps = create_tensor(ffn_ctx,
|
||||
tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags);
|
||||
|
||||
// Shared expert
|
||||
if (n_expert_shared > 0) {
|
||||
const int64_t n_ff_shexp = n_ff_exp * n_expert_shared;
|
||||
layer.ffn_gate_shexp = create_tensor(ctx_split,
|
||||
layer.ffn_gate_shexp = create_tensor(ffn_ctx,
|
||||
tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags);
|
||||
layer.ffn_down_shexp = create_tensor(ctx_split,
|
||||
layer.ffn_down_shexp = create_tensor(ffn_ctx,
|
||||
tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, flags);
|
||||
layer.ffn_up_shexp = create_tensor(ctx_split,
|
||||
layer.ffn_up_shexp = create_tensor(ffn_ctx,
|
||||
tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags);
|
||||
}
|
||||
} else {
|
||||
// Dense layers (first k layers) - GLM uses separate gate/up projections
|
||||
layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, flags);
|
||||
layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, flags);
|
||||
layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, flags);
|
||||
layer.ffn_gate = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, flags);
|
||||
layer.ffn_down = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, flags);
|
||||
layer.ffn_up = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, flags);
|
||||
}
|
||||
// --- NextN / MTP tensors (preserved but unused), on the final layer ---
|
||||
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
|
||||
@@ -2629,18 +2743,77 @@ bool create_tensors_helper::merge_qkv(const LLM_TN & tn, int i, int bias, bool i
|
||||
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
|
||||
if (bias) {
|
||||
auto flags = bias == 1 ? llama_model_loader::TENSOR_NOT_REQUIRED : 0;
|
||||
layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {layer.wq->ne[1]}, flags);
|
||||
layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {layer.wk->ne[1]}, flags);
|
||||
layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {layer.wv->ne[1]}, flags);
|
||||
layer.bq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "bias", i), {layer.wq->ne[1]}, flags);
|
||||
layer.bk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "bias", i), {layer.wk->ne[1]}, flags);
|
||||
layer.bv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "bias", i), {layer.wv->ne[1]}, flags);
|
||||
}
|
||||
}
|
||||
|
||||
return fused_qkv;
|
||||
}
|
||||
|
||||
static void prepare_split_tensors(int split_dim, ggml_context * ctx, ggml_tensor * tensor, llama_split_tensor & split_tensor,
|
||||
const std::vector<int> & splits, std::vector<size_t> & mem_used) {
|
||||
GGML_ASSERT(split_dim <= 1);
|
||||
GGML_ASSERT(splits.size() > 1);
|
||||
std::string name{tensor->name};
|
||||
split_tensor.tensor_splits.resize(splits.size());
|
||||
if (split_dim < 0) {
|
||||
for (int i = 0; i < int(splits.size()); ++i) {
|
||||
if (splits[i] > 0) {
|
||||
split_tensor.tensor_splits[i] = ggml_new_tensor_3d(ctx, tensor->type, tensor->ne[0], tensor->ne[1], tensor->ne[2]);
|
||||
auto name_i = name + '.' + std::to_string(i);
|
||||
ggml_set_name(split_tensor.tensor_splits[i], name_i.c_str());
|
||||
} else {
|
||||
split_tensor.tensor_splits[i] = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (split_dim == 1) {
|
||||
for (int i = 0; i < int(splits.size()); ++i) {
|
||||
if (splits[i] > 0) {
|
||||
split_tensor.tensor_splits[i] = ggml_new_tensor_3d(ctx, tensor->type, tensor->ne[0], splits[i], tensor->ne[2]);
|
||||
auto name_i = name + '.' + std::to_string(i);
|
||||
ggml_set_name(split_tensor.tensor_splits[i], name_i.c_str());
|
||||
} else {
|
||||
split_tensor.tensor_splits[i] = nullptr;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < int(splits.size()); ++i) {
|
||||
if (splits[i] > 0) {
|
||||
split_tensor.tensor_splits[i] = ggml_new_tensor_3d(ctx, tensor->type, splits[i], tensor->ne[1], tensor->ne[2]);
|
||||
auto name_i = name + '.' + std::to_string(i);
|
||||
ggml_set_name(split_tensor.tensor_splits[i], name_i.c_str());
|
||||
} else {
|
||||
split_tensor.tensor_splits[i] = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
split_tensor.ggml.n_device = splits.size();
|
||||
split_tensor.ggml.split_dim = split_dim;
|
||||
split_tensor.ggml.splits = split_tensor.tensor_splits.data();
|
||||
tensor->extra = (void *)&split_tensor.ggml;
|
||||
GGML_ASSERT(mem_used.size() >= splits.size());
|
||||
for (int i = 0; i < split_tensor.ggml.n_device; ++i) {
|
||||
if (split_tensor.ggml.splits[i]) {
|
||||
//auto nbytes = ggml_nbytes(split_tensor.ggml.splits[i]);
|
||||
//printf("mem_used(%s): %8.2f, total: %8.2f\n", split_tensor.ggml.splits[i]->name, nbytes/1024./1024., (mem_used[i] + nbytes)/1024./1024.);
|
||||
mem_used[i] += ggml_nbytes(split_tensor.ggml.splits[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool create_tensors_helper::create_tensors() {
|
||||
const auto tn = LLM_TN(model.arch);
|
||||
bool use_mmap_buffer = true;
|
||||
if (ml.merge_qkv && (model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN)) {
|
||||
LLAMA_LOG_WARN("\n========================================================\n");
|
||||
LLAMA_LOG_WARN("merge_qkv is not compatible with split model 'graph'\n");
|
||||
LLAMA_LOG_WARN(" => turning off merge_qkv\n");
|
||||
LLAMA_LOG_WARN("========================================================\n\n");
|
||||
ml.merge_qkv = false;
|
||||
}
|
||||
switch (model.arch) {
|
||||
case LLM_ARCH_LLAMA:
|
||||
case LLM_ARCH_REFACT:
|
||||
@@ -2761,6 +2934,157 @@ bool create_tensors_helper::create_tensors() {
|
||||
default:
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
if (model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) {
|
||||
std::vector<size_t> mem_used(model.splits.size(), 0);
|
||||
        const auto & hparams = model.hparams;
        int gqa_ratio = hparams.n_head() / hparams.n_head_kv();
        //printf("GQA ratio: %d\n", gqa_ratio);
        for (int il = 0; il < int(model.layers.size()); ++il) {
            if (ggml_backend_buft_is_host(model.buft_layer[il].buft_matrix)) {
                LLAMA_LOG_INFO("%s: not splitting layer %d because buffer type is host\n", __func__, il);
                continue;
            }
            auto & layer = model.layers[il];
            auto ctx_split = ctx_for_layer_split(il);
            if (layer.attn_norm) {
                auto split = create_split(ggml_nrows(layer.attn_norm), -1, model.splits, mem_used);
                prepare_split_tensors(-1, ctx_split, layer.attn_norm, layer.split_attn_norm, split, mem_used);
            }
            if (layer.rope_freqs) {
                auto split = create_split(ggml_nrows(layer.rope_freqs), -1, model.splits, mem_used);
                prepare_split_tensors(-1, ctx_split, layer.rope_freqs, layer.split_rope_freqs, split, mem_used);
            }
            if (layer.wo && layer.wq && layer.wk && layer.wv) {
                int attn_granularity = hparams.n_embd_head_k * gqa_ratio;
                if (ggml_is_quantized(layer.wo->type)) {
                    auto tt = ggml_internal_get_type_traits(layer.wo->type);
                    if (tt.blck_size > attn_granularity) attn_granularity = tt.blck_size;
                }
                GGML_ASSERT(attn_granularity % hparams.n_embd_head_k == 0);
                auto split = create_split(layer.wo->ne[0], attn_granularity, model.splits, mem_used);
                prepare_split_tensors(0, ctx_split, layer.wo, layer.split_wo, split, mem_used);
                prepare_split_tensors(1, ctx_split, layer.wq, layer.split_wq, split, mem_used);
                if (layer.bo) {
                    prepare_split_tensors(-1, ctx_split, layer.bo, layer.split_bo, split, mem_used);
                }
                if (layer.bq) {
                    prepare_split_tensors(0, ctx_split, layer.bq, layer.split_bq, split, mem_used);
                }
                if (layer.attn_q_norm) {
                    prepare_split_tensors(-1, ctx_split, layer.attn_q_norm, layer.split_q_norm, split, mem_used);
                }
                for (auto & s : split) s /= gqa_ratio;
                prepare_split_tensors(1, ctx_split, layer.wk, layer.split_wk, split, mem_used);
                prepare_split_tensors(1, ctx_split, layer.wv, layer.split_wv, split, mem_used);
                if (layer.bk) {
                    prepare_split_tensors(0, ctx_split, layer.bk, layer.split_bk, split, mem_used);
                }
                if (layer.bv) {
                    prepare_split_tensors(0, ctx_split, layer.bv, layer.split_bv, split, mem_used);
                }
                if (layer.attn_k_norm) {
                    prepare_split_tensors(-1, ctx_split, layer.attn_k_norm, layer.split_k_norm, split, mem_used);
                }
            }

            if (layer.ffn_norm) {
                if (auto it = split_tensors.find(layer.ffn_norm); it != split_tensors.end()) {
                    auto split = create_split(ggml_nrows(layer.ffn_norm), -1, model.splits, mem_used);
                    prepare_split_tensors(-1, ctx_split, layer.ffn_norm, layer.split_ffn_norm, split, mem_used);
                }
            }

            if (layer.ffn_down && layer.ffn_up && layer.ffn_gate) {
                bool use_split = split_tensors.find(layer.ffn_down) != split_tensors.end() &&
                                 split_tensors.find(layer.ffn_gate) != split_tensors.end() &&
                                 split_tensors.find(layer.ffn_up) != split_tensors.end();
                if (use_split) {
                    int ffn_granularity = 16;
                    if (ggml_is_quantized(layer.ffn_down->type)) {
                        auto tt = ggml_internal_get_type_traits(layer.ffn_down->type);
                        if (tt.blck_size > ffn_granularity) ffn_granularity = tt.blck_size;
                    }
                    auto split = create_split(layer.ffn_down->ne[0], ffn_granularity, model.splits, mem_used);
                    prepare_split_tensors(0, ctx_split, layer.ffn_down, layer.split_ffn_down, split, mem_used);
                    prepare_split_tensors(1, ctx_split, layer.ffn_up, layer.split_ffn_up, split, mem_used);
                    prepare_split_tensors(1, ctx_split, layer.ffn_gate, layer.split_ffn_gate, split, mem_used);
                }
            }

            //bool any_ffn_split = false;
            if (layer.ffn_down_shexp && layer.ffn_up_shexp && layer.ffn_gate_shexp) {
                bool use_split = split_tensors.find(layer.ffn_down_shexp) != split_tensors.end() &&
                                 split_tensors.find(layer.ffn_gate_shexp) != split_tensors.end() &&
                                 split_tensors.find(layer.ffn_up_shexp) != split_tensors.end();
                if (use_split) {
                    //any_ffn_split = true;
                    int ffn_granularity = 16;
                    if (ggml_is_quantized(layer.ffn_down_shexp->type)) {
                        auto tt = ggml_internal_get_type_traits(layer.ffn_down_shexp->type);
                        if (tt.blck_size > ffn_granularity) ffn_granularity = tt.blck_size;
                    }
                    auto split = create_split(layer.ffn_down_shexp->ne[0], ffn_granularity, model.splits, mem_used);
                    prepare_split_tensors(0, ctx_split, layer.ffn_down_shexp, layer.split_ffn_down_shexp, split, mem_used);
                    prepare_split_tensors(1, ctx_split, layer.ffn_up_shexp, layer.split_ffn_up_shexp, split, mem_used);
                    prepare_split_tensors(1, ctx_split, layer.ffn_gate_shexp, layer.split_ffn_gate_shexp, split, mem_used);
                }
            }

            if (layer.ffn_down_exps && layer.ffn_up_exps && layer.ffn_gate_exps) {
                bool use_split = split_tensors.find(layer.ffn_down_exps) != split_tensors.end() &&
                                 split_tensors.find(layer.ffn_gate_exps) != split_tensors.end() &&
                                 split_tensors.find(layer.ffn_up_exps) != split_tensors.end();

                if (use_split) {
                    //any_ffn_split = true;
                    int ffn_granularity = 16;
                    if (ggml_is_quantized(layer.ffn_down_exps->type)) {
                        auto tt = ggml_internal_get_type_traits(layer.ffn_down_exps->type);
                        if (tt.blck_size > ffn_granularity) ffn_granularity = tt.blck_size;
                    }
                    auto split = create_split(layer.ffn_down_exps->ne[0], ffn_granularity, model.splits, mem_used);
                    //printf("split(%2d):", il); for (auto & s : split) printf(" %d", s); printf("\n");
                    prepare_split_tensors(0, ctx_split, layer.ffn_down_exps, layer.split_ffn_down_exps, split, mem_used);
                    prepare_split_tensors(1, ctx_split, layer.ffn_up_exps, layer.split_ffn_up_exps, split, mem_used);
                    prepare_split_tensors(1, ctx_split, layer.ffn_gate_exps, layer.split_ffn_gate_exps, split, mem_used);
                }
            }

            if (layer.ffn_gate_inp) {
                if (auto it = split_tensors.find(layer.ffn_gate_inp); it != split_tensors.end()) {
                    auto shared_split = create_split(ggml_nrows(layer.ffn_gate_inp), -1, model.splits, mem_used);
                    prepare_split_tensors(-1, ctx_split, layer.ffn_gate_inp, layer.split_ffn_gate_inp, shared_split, mem_used);
                }
            }
            if (layer.ffn_exp_probs_b) {
                if (auto it = split_tensors.find(layer.ffn_exp_probs_b); it != split_tensors.end()) {
                    auto shared_split = create_split(ggml_nrows(layer.ffn_exp_probs_b), -1, model.splits, mem_used);
                    prepare_split_tensors(-1, ctx_split, layer.ffn_exp_probs_b, layer.split_ffn_exp_probs_b, shared_split, mem_used);
                }
            }
        }

        if (model.output) {
            if (auto it = split_tensors.find(model.output); it != split_tensors.end()) {
                if (ggml_backend_buft_is_host(model.buft_output.buft_matrix)) {
                    LLAMA_LOG_INFO("%s: not splitting output tensor because buffer is host\n", __func__);
                } else {
                    auto ctx_split = ctx_map[model.buft_output.buft_matrix];
                    auto split = create_split(model.output->ne[1], 16, model.splits, mem_used);
                    prepare_split_tensors(1, ctx_split, model.output, model.split_output, split, mem_used);
                    if (auto it = split_tensors.find(model.output_norm); it != split_tensors.end() && !ggml_backend_buft_is_host(model.buft_output.buft_matrix)) {
                        auto ctx_split = ctx_map[model.buft_output.buft_matrix];
                        prepare_split_tensors(-1, ctx_split, model.output_norm, model.split_output_norm, split, mem_used);
                    }
                }
            }
        }

        LLAMA_LOG_INFO("Estimated model buffer size per device:\n");
        for (int i = 0; i < int(mem_used.size()); ++i) {
            LLAMA_LOG_INFO("    Device %d: %8.2f MiB\n", i, mem_used[i]/1024./1024.);
        }
    }
    return use_mmap_buffer;
}

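The granularity arithmetic above is easiest to see with numbers: with, say, 64 attention heads, 8 KV heads and a head size of 128 (so gqa_ratio = 8), attn_granularity becomes 128*8 = 1024. Splitting Wo's ne[0] = 64*128 = 8192 evenly over two GPUs gives {4096, 4096}; dividing that split by gqa_ratio yields {512, 512} for Wk/Wv, i.e. 4 of the 8 KV heads per device, so each Q head always lands on the same GPU as its KV head. (The figures are illustrative, not taken from the commit.)

create_split() and prepare_split_tensors() are defined outside this excerpt, so the helper below is only a sketch of the kind of partitioning create_split() appears to perform, assuming model.splits stores normalized cumulative device fractions (the std::upper_bound-based layer assignment later in the diff suggests this); the real implementation may well differ:

    #include <vector>

    // Hypothetical stand-in for create_split(): distribute a dimension of size n across
    // the devices in proportion to the cumulative fractions, rounding each boundary down
    // to a multiple of `granularity` (head_size*gqa_ratio for attention tensors, at least
    // the quant block size otherwise). The -1 granularity used above for small tensors
    // that are not really split is treated as 1 here just to keep the sketch total-preserving.
    static std::vector<int> make_proportional_split(int n, int granularity, const std::vector<float> & cum_fractions) {
        if (granularity < 1) granularity = 1;
        std::vector<int> split(cum_fractions.size(), 0);
        int assigned = 0;
        for (size_t i = 0; i < cum_fractions.size(); ++i) {
            int ni = (i + 1 == cum_fractions.size()) ? n - assigned
                                                     : int(cum_fractions[i]*n)/granularity*granularity - assigned;
            if (ni < 0) ni = 0;
            split[i]  = ni;
            assigned += ni;
        }
        return split;
    }
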
@@ -183,6 +183,24 @@ struct llama_layer {
    struct ggml_tensor * bqk = nullptr;
    struct ggml_tensor * bkv = nullptr;

    llama_split_tensor split_attn_norm;
    llama_split_tensor split_wq;
    llama_split_tensor split_wk;
    llama_split_tensor split_wv;
    llama_split_tensor split_wo;
    llama_split_tensor split_wqkv;
    llama_split_tensor split_wqk;
    llama_split_tensor split_wkv;
    llama_split_tensor split_bq;
    llama_split_tensor split_bk;
    llama_split_tensor split_bv;
    llama_split_tensor split_bo;
    llama_split_tensor split_bqkv;
    llama_split_tensor split_bqk;
    llama_split_tensor split_bkv;
    llama_split_tensor split_q_norm;
    llama_split_tensor split_k_norm;

    // relative position bias
    struct ggml_tensor * attn_rel_b = nullptr;
    struct ggml_tensor * attn_rel_b_enc = nullptr;
@@ -205,12 +223,22 @@ struct llama_layer {
    struct ggml_tensor * ffn_down_enc = nullptr;
    struct ggml_tensor * ffn_up_enc = nullptr;

    llama_split_tensor split_ffn_up;
    llama_split_tensor split_ffn_gate;
    llama_split_tensor split_ffn_down;
    llama_split_tensor split_ffn_norm;

    // ff MoE
    struct ggml_tensor * ffn_gate_inp = nullptr;
    struct ggml_tensor * ffn_gate_exps = nullptr;
    struct ggml_tensor * ffn_down_exps = nullptr;
    struct ggml_tensor * ffn_up_exps = nullptr;

    llama_split_tensor split_ffn_gate_inp;
    llama_split_tensor split_ffn_up_exps;
    llama_split_tensor split_ffn_gate_exps;
    llama_split_tensor split_ffn_down_exps;

    // ff MoE bias
    struct ggml_tensor * ffn_gate_inp_b = nullptr;
    struct ggml_tensor * ffn_gate_exps_b = nullptr;
@@ -226,6 +254,15 @@ struct llama_layer {
    struct ggml_tensor * ffn_down_shexp = nullptr;
    struct ggml_tensor * ffn_up_shexp = nullptr;

    llama_split_tensor split_ffn_up_shexp;
    llama_split_tensor split_ffn_gate_shexp;
    llama_split_tensor split_ffn_down_shexp;

    llama_split_tensor split_ffn_gate_inp_b;
    llama_split_tensor split_ffn_gate_exps_b;
    llama_split_tensor split_ffn_down_exps_b;
    llama_split_tensor split_ffn_up_exps_b;

    // ff bias
    struct ggml_tensor * ffn_gate_b = nullptr;
    struct ggml_tensor * ffn_down_b = nullptr; // b2
@@ -233,6 +270,12 @@ struct llama_layer {
    struct ggml_tensor * ffn_act = nullptr;
    struct ggml_tensor * ffn_exp_probs_b = nullptr;

    llama_split_tensor split_ffn_gate_b;
    llama_split_tensor split_ffn_down_b;
    llama_split_tensor split_ffn_up_b;
    llama_split_tensor split_ffn_act;
    llama_split_tensor split_ffn_exp_probs_b;

    // mamba proj
    struct ggml_tensor * ssm_in = nullptr;
    struct ggml_tensor * ssm_x = nullptr;
@@ -253,6 +296,8 @@ struct llama_layer {
    struct ggml_tensor * rope_short = nullptr;
    struct ggml_tensor * rope_freqs = nullptr;

    llama_split_tensor split_rope_freqs;

    // bitnet scale
    struct ggml_tensor * wq_scale = nullptr;
    struct ggml_tensor * wk_scale = nullptr;
@@ -298,6 +343,9 @@ struct llama_model {
    struct ggml_tensor * output_b;
    struct ggml_tensor * output_norm_enc;

    llama_split_tensor split_output;
    llama_split_tensor split_output_norm;

    std::vector<llama_layer> layers;

    llama_split_mode split_mode;
@@ -358,6 +406,12 @@ struct llama_model {
    }

    void set_tensor_overrides(const llama_model_params& params);

    int device_count() const;
    ggml_backend_buffer_type_t default_buffer_type_offload(int device) const;

    std::vector<float> splits;
    ggml_backend_buffer_type_t split_buft = nullptr;
};

struct llama_lora_weight {

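llama_split_tensor itself is declared elsewhere in this commit. Judging by how the KV-cache code further down builds the equivalent structure (a vector of per-device slices plus a ggml_split_tensor_t descriptor attached to the parent tensor via ->extra), it presumably looks roughly like the sketch below; the real definition may differ:

    #include <vector>
    #include "ggml.h"  // assuming ggml_split_tensor_t is declared in the ggml headers on this branch

    // Hypothetical shape of llama_split_tensor, inferred from usage only.
    struct llama_split_tensor {
        std::vector<ggml_tensor *> tensor_splits; // one slice per device, nullptr where a device holds nothing
        ggml_split_tensor_t        ggml = {};     // {n_device, split_dim, splits} - attached as tensor->extra
    };
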
266  src/llama.cpp
@@ -108,6 +108,7 @@
#include <mutex>
#include <numeric>
#include <set>
#include <unordered_set>
#include <sstream>
#include <thread>
#include <type_traits>
@@ -460,18 +461,18 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_
    GGML_UNUSED(gpu);
}

static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_model & model, int fallback_gpu, const float * tensor_split) {
static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_model & model, int fallback_gpu) {
    ggml_backend_buffer_type_t buft = nullptr;

#ifdef GGML_USE_CUDA
    if (ggml_backend_cuda_get_device_count() > 1) {
        buft = ggml_backend_cuda_split_buffer_type(tensor_split);
        buft = ggml_backend_cuda_split_buffer_type(model.splits.data());
    }
#endif

#ifdef GGML_USE_SYCL
    if (ggml_backend_sycl_get_device_count() > 1) {
        buft = ggml_backend_sycl_split_buffer_type(tensor_split);
        buft = ggml_backend_sycl_split_buffer_type(model.splits.data());
    }
#endif

@@ -480,7 +481,14 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_mo
    }
    return buft;

    GGML_UNUSED(tensor_split);
}

int llama_model::device_count() const {
    return llama_get_device_count(*this);
}

ggml_backend_buffer_type_t llama_model::default_buffer_type_offload(int device) const {
    return llama_default_buffer_type_offload(*this, device);
}

static size_t llama_get_device_memory(const llama_model & model, int device) {
@@ -548,9 +556,34 @@ bool llama_context::can_reuse_graph(const llama_batch & u_batch) {
}

bool llama_context::update_cache_copies() {
    int n_layer = cache_copies.size()/2;
    int n_layer = model.hparams.n_layer - model.hparams.nextn_predict_layers; //cache_copies.size()/2;
    if ((int)kv_self.k_l.size() != n_layer) return false;
    if (!(kv_self.v_l.empty() || (int)kv_self.v_l.size() == n_layer)) return false;
    if ((model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) && model.splits.size() > 1) {
        for (int il = 0; il < n_layer; ++il) {
            auto kl = (ggml_split_tensor_t *)kv_self.k_l[il]->extra;
            auto vl = !kv_self.v_l.empty() && kv_self.v_l[il] ? (ggml_split_tensor_t *)kv_self.v_l[il]->extra : nullptr;
            GGML_ASSERT(kl && (!kv_self.v_l[il] || vl));
            if (vl) {
                GGML_ASSERT(kl->n_device == vl->n_device);
            }
            for (int id = 0; id < kl->n_device; ++id) {
                auto& c = cache_copies[2*model.splits.size()*il + 2*id + 0];
                if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kl->splits[id]) return false;
                c.cpy->view_offs = kv_self.head*c.step;
                c.cpy->src[1]->data = (char *)kl->splits[id]->data + c.cpy->view_offs;
                c.cpy->data = c.cpy->src[1]->data;
            }
            if (!vl) continue;
            for (int id = 0; id < vl->n_device; ++id) {
                auto& c = cache_copies[2*model.splits.size()*il + 2*id + 1];
                if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != vl->splits[id]) return false;
                c.cpy->view_offs = kv_self.head*c.step;
                c.cpy->src[1]->data = (char *)vl->splits[id]->data + c.cpy->view_offs;
                c.cpy->data = c.cpy->src[1]->data;
            }
        }
    } else {
        for (int il = 0; il < n_layer; ++il) {
            auto& c = cache_copies[2*il+0];
            if (!c.cpy || c.cpy->op != GGML_OP_CPY || c.cpy->view_src != kv_self.k_l[il]) return false;
@@ -566,14 +599,19 @@ bool llama_context::update_cache_copies() {
            c.cpy->src[1]->data = (char *)kv_self.v_l[il]->data + c.cpy->view_offs;
            c.cpy->data = c.cpy->src[1]->data;
        }
    }
    return true;
}

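The indexing into cache_copies above follows a fixed layout; the helper below spells it out (a reading aid, not code from the diff). With S = model.splits.size() devices and split mode graph/attn, the copy node for layer il, device id and kv in {0 = K, 1 = V} sits at 2*S*il + 2*id + kv, matching the 2*S*n_layer resize in the llama_context constructor that follows; the single-device layout keeps the original 2*il + kv:

    #include <cstddef>

    // Index of the CPY node that refreshes the K (kv = 0) or V (kv = 1) cache view of
    // layer `il` on device `id`; `S` is the number of device splits.
    static size_t cache_copy_index(size_t S, int il, int id, int kv, bool split_graph) {
        return split_graph ? 2*S*il + 2*id + kv : 2*size_t(il) + kv;
    }
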
llama_context::llama_context(const llama_model & model)
    : model(model) , sampling(llama_n_vocab(&model)) , t_start_us(model.t_start_us) , t_load_us(model.t_load_us) {
    const auto & hparams = model.hparams;
    if ((model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) && model.splits.size() > 1) {
        cache_copies.resize(2*model.splits.size()*hparams.n_layer);
    } else {
        cache_copies.resize(2*hparams.n_layer);
    }
}

llama_context::~llama_context() {
    ggml_backend_sched_free(sched);
@@ -626,42 +664,35 @@ static bool llama_kv_cache_init(
        }
    }

    bool split_cache = false;
    if ((model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN) && model.arch != LLM_ARCH_DEEPSEEK2 && offload) {
        cache.split_k_l.reserve(n_layer);
        cache.split_v_l.reserve(n_layer);
        split_cache = true;
    }

    // count used buffer types
    std::map<ggml_backend_buffer_type_t, int> buft_layer_count;
    if (offload) {
        for (int64_t i = 0; i < n_layer; ++i) {
            if (split_cache) {
                buft_layer_count[model.buft_layer[i].buft_matrix]++;
            } else {
                buft_layer_count[model.buft_layer[i].buft]++;
            }
        }
    } else {
        buft_layer_count[llama_default_buffer_type_cpu(true)] = n_layer;
    }

    //if (cparams.fused_moe_up_gate) {
    //    int nbad = 0;
    //    for (int i = 0; i < (int) n_layer; i++) {
    //        auto& layer = model.layers[i];
    //        if (layer.ffn_gate_exps && layer.ffn_up_exps && layer.ffn_gate_exps->type != layer.ffn_up_exps->type) {
    //            ++nbad;
    //        }
    //    }
    //    if (nbad > 0) {
    //        if (nbad == (int)n_layer) {
    //            LLAMA_LOG_WARN("=============== ffn_up and ffn_gate are of different type => disabling fmoe\n");
    //            const_cast<llama_cparams&>(cparams).fused_moe_up_gate = false;
    //        }
    //        else {
    //            LLAMA_LOG_WARN("=============== ffn_up and ffn_gate are of different in %d out of %d layers, where fmoe will be disabled\n",
    //                nbad, (int)n_layer);
    //        }
    //    }
    //}

    // create a context for each buffer type
    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
    for (auto & it : buft_layer_count) {
        int n_layers = it.second;
        size_t ctx_mem_size = 5u*n_layers*ggml_tensor_overhead();
        if (split_cache) ctx_mem_size += 2*model.splits.size()*n_layers*ggml_tensor_overhead();
        struct ggml_init_params params = {
            /*.mem_size   =*/ 5u*n_layers*ggml_tensor_overhead(),
            /*.mem_size   =*/ ctx_mem_size,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ true,
        };
@@ -698,24 +729,25 @@ static bool llama_kv_cache_init(
        }
    }

    cache.k_l.reserve(n_layer);
    bool needs_v_cache = true;
    cache.k_l.reserve(n_layer);
    if (model.arch == LLM_ARCH_DEEPSEEK2 && cparams.mla_attn) {
        needs_v_cache = cparams.mla_attn == 1 && !cparams.flash_attn;
    }
    if (needs_v_cache) cache.v_l.reserve(n_layer);

    std::vector<size_t> mem_split(model.splits.size(), 0);

    int n_mla = 0;
    for (int i = 0; i < (int) n_layer; i++) {
        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
        const uint32_t n_head_kv = hparams.n_head_kv(i);
        const uint32_t n_embd_head_k= hparams.n_embd_head_k;

        struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
        struct ggml_context * ctx = split_cache ? ctx_map.at(model.buft_layer[i].buft_matrix) : offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
        ggml_tensor * k;
        ggml_tensor * v;
        if (cparams.mla_attn) {
        if (model.arch == LLM_ARCH_DEEPSEEK2 && cparams.mla_attn) {
            // DeepSeek MLA
            const uint32_t n_embd_head_qk_rope = hparams.n_rot;
            const uint32_t kv_lora_rank = hparams.n_lora_kv;
@@ -740,10 +772,53 @@ static bool llama_kv_cache_init(
        else {
            k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size);
            v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
            ggml_format_name(k, "cache_k_l%d", i);
            ggml_format_name(v, "cache_v_l%d", i);
            auto k_name = std::string{"cache_k_l"} + std::to_string(i);
            auto v_name = std::string{"cache_v_l"} + std::to_string(i);
            ggml_set_name(k, k_name.c_str());
            ggml_set_name(v, v_name.c_str());
            //ggml_format_name(k, "cache_k_l%d", i);
            //ggml_format_name(v, "cache_v_l%d", i);
            cache.k_l.push_back(k);
            cache.v_l.push_back(v);
            if (split_cache) {
                auto K = model.layers[i].wk;
                auto V = model.layers[i].wv;
                if (K && V && K->extra && V->extra) {
                    auto extra_K = (const ggml_split_tensor_t *)K->extra;
                    auto extra_V = (const ggml_split_tensor_t *)V->extra;
                    auto & split_k_l = cache.split_k_l.emplace_back();
                    auto & split_v_l = cache.split_v_l.emplace_back();
                    split_k_l.tensor_splits.resize(extra_K->n_device, nullptr);
                    split_v_l.tensor_splits.resize(extra_V->n_device, nullptr);
                    for (int is = 0; is < extra_K->n_device; ++is) {
                        auto split = extra_K->splits[is];
                        if (!split) continue;
                        split_k_l.tensor_splits[is] = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, split->ne[1]/n_embd_head_k * kv_size);
                        auto split_name = k_name + '.' + std::to_string(is);
                        ggml_set_name(split_k_l.tensor_splits[is], split_name.c_str());
                        mem_split[is] += ggml_nbytes(split_k_l.tensor_splits[is]);
                    }
                    split_k_l.ggml.n_device = extra_K->n_device;
                    split_k_l.ggml.split_dim = 0;
                    split_k_l.ggml.splits = split_k_l.tensor_splits.data();
                    for (int is = 0; is < extra_V->n_device; ++is) {
                        auto split = extra_V->splits[is];
                        if (!split) continue;
                        split_v_l.tensor_splits[is] = ggml_new_tensor_1d(ctx, type_v, split->ne[1] * kv_size);
                        auto split_name = v_name + '.' + std::to_string(is);
                        ggml_set_name(split_v_l.tensor_splits[is], split_name.c_str());
                        mem_split[is] += ggml_nbytes(split_v_l.tensor_splits[is]);
                    }
                    split_v_l.ggml.n_device = extra_V->n_device;
                    split_v_l.ggml.split_dim = 0;
                    split_v_l.ggml.splits = split_v_l.tensor_splits.data();
                    k->extra = (void *)&split_k_l.ggml;
                    v->extra = (void *)&split_v_l.ggml;
                }
                //} else {
                //    printf("Oops: don't have yet K and V for layer %d\n", i);
                //}
            }
        }
    }
    if (model.arch == LLM_ARCH_DEEPSEEK2 && cparams.mla_attn && n_mla < n_layer && n_mla > 0) {
@@ -756,6 +831,11 @@ static bool llama_kv_cache_init(
    for (auto it : ctx_map) {
        ggml_backend_buffer_type_t buft = it.first;
        ggml_context * ctx = it.second;
        int ntensor = 0;
        for (auto t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
            ++ntensor;
        }
        if (ntensor > 0) {
            ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
            if (!buf) {
                LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__);
@@ -765,6 +845,32 @@ static bool llama_kv_cache_init(
            LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
            cache.bufs.push_back(buf);
        }
    }
    if (split_cache) {
        LLAMA_LOG_INFO("%s: KV cache size per device:\n", __func__);
        for (int i = 0; i < int(mem_split.size()); ++i) printf("    Device %d: %g MiB\n", i, mem_split[i]/1024./1024.);
    }

#if 0
    for (int il = 0; il < n_layer; ++il) {
        if (cache.k_l[il]->extra) {
            printf("Layer %2d, K-buffer: %p:", il, (void *)cache.k_l[il]->buffer);
            auto split_kl = (ggml_split_tensor_t *)cache.k_l[il]->extra;
            for (int id = 0; id < split_kl->n_device; ++id) {
                if (split_kl->splits[id]) printf(" %p,%p", (void *)split_kl->splits[id]->data, (void *)split_kl->splits[id]->buffer);
            }
            printf("\n");
        }
        if (cache.v_l[il]->extra) {
            printf("Layer %2d, V-buffer: %p:", il, (void *)cache.v_l[il]->buffer);
            auto split_vl = (ggml_split_tensor_t *)cache.v_l[il]->extra;
            for (int id = 0; id < split_vl->n_device; ++id) {
                if (split_vl->splits[id]) printf(" %p,%p", (void *)split_vl->splits[id]->data, (void *)split_vl->splits[id]->buffer);
            }
            printf("\n");
        }
    }
#endif

    return true;
}

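To make the per-device K/V sizing above concrete, assume 2 GPUs with an even split, 8 KV heads of size n_embd_head_k = 128 and kv_size = 4096 (all figures illustrative, not taken from the commit). Each device's Wk slice then has split->ne[1] = 4*128 = 512, so its K-cache slice is allocated as ggml_new_tensor_2d(ctx, type_k, 128, (512/128)*4096), i.e. 4 of the 8 heads per cache slot, and its V-cache slice as ggml_new_tensor_1d(ctx, type_v, 512*4096); together the slices cover exactly the unsplit [128, 8*4096] K and 1024*4096 V tensors.
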
@@ -1617,6 +1723,16 @@ static void ggml_backend_add_from_device(llama_context* ctx, ggml_backend_t back
    }
}

static bool is_model_split_supported(const llama_model & model) {
    static std::unordered_set<llm_arch> k_supported = {
        LLM_ARCH_LLAMA,
        LLM_ARCH_QWEN3MOE,
        LLM_ARCH_GLM4_MOE,
    };
    auto it = k_supported.find(model.arch);
    return it != k_supported.end();
}

// Returns false if cancelled by progress_callback
static bool llm_load_tensors(
        llama_model_loader & ml,
@@ -1634,6 +1750,16 @@ static bool llm_load_tensors(

    auto & hparams = model.hparams;

    if (split_mode == LLAMA_SPLIT_MODE_GRAPH || split_mode == LLAMA_SPLIT_MODE_ATTN) {
        if (!is_model_split_supported(model)) {
            LLAMA_LOG_WARN("\n=======================================================\n");
            LLAMA_LOG_WARN("Split mode 'graph' is not supported for this model\n");
            LLAMA_LOG_WARN(" => changing split mode to 'layer'\n");
            LLAMA_LOG_WARN("=======================================================\n\n");
            split_mode = LLAMA_SPLIT_MODE_LAYER;
        }
    }

    model.split_mode   = split_mode;
    model.main_gpu     = main_gpu;
    model.n_gpu_layers = n_gpu_layers;
@@ -1652,10 +1778,7 @@ static bool llm_load_tensors(
        model.buft_layer[i] = llama_default_buffer_type_cpu(true);
    }

    if (split_mode == LLAMA_SPLIT_MODE_LAYER) {
        // calculate the split points
        // int device_count = llama_get_device_count(model);
        int device_count = model.devices.size();
    if (int device_count = model.devices.size(); device_count > 1) {
        bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + device_count, [](float x) { return x == 0.0f; });
        std::vector<float> splits(device_count);
        if (all_zero) {
@@ -1676,46 +1799,47 @@ static bool llm_load_tensors(
        for (int i = 0; i < device_count; ++i) {
            splits[i] /= split_sum;
        }
        model.splits = std::move(splits);
    } else {
        model.splits = { 1.0f };
    }

    int device_count = model.splits.size();
    // assign the repeating layers to the devices according to the splits
    int act_gpu_layers = std::min(n_gpu_layers, (int)n_layer + 1);
    if (split_mode == LLAMA_SPLIT_MODE_LAYER) {

        for (int i = i_gpu_start; i < n_layer; ++i) {
            int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(i - i_gpu_start)/act_gpu_layers) - splits.begin();
#ifndef NDEBUG
            ggml_backend_buffer_type_t buft = llama_default_buffer_type_offload(model, model.devices[layer_gpu]);
            const char* name = ggml_backend_buft_name(buft);
            LLAMA_LOG_DEBUG("load_tensors: layers %3d assigned to backend %s\n", i,
                name);
#endif
            int layer_gpu = std::upper_bound(model.splits.begin(), model.splits.begin() + device_count, float(i - i_gpu_start)/act_gpu_layers) - model.splits.begin();
            model.buft_layer[i] = llama_default_buffer_type_offload(model, model.devices[layer_gpu]);
        }
        // assign the output layer
        if (n_gpu_layers > n_layer) {
            int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(act_gpu_layers - 1)/act_gpu_layers) - splits.begin();
#ifndef NDEBUG
            ggml_backend_buffer_type_t buft = llama_default_buffer_type_offload(model, model.devices[layer_gpu]);
            const char* name = ggml_backend_buft_name(buft);
            LLAMA_LOG_DEBUG("load_tensors: output layers assigned to backend %s\n",
                name);
#endif
            int layer_gpu = std::upper_bound(model.splits.begin(), model.splits.begin() + device_count, float(act_gpu_layers - 1)/act_gpu_layers) - model.splits.begin();
            model.buft_output = llama_default_buffer_type_offload(model, model.devices[layer_gpu]);
        } else {
            model.buft_output = llama_default_buffer_type_cpu(true);
        }
    } else {
        ggml_backend_buffer_type_t split_buft;
        if (split_mode == LLAMA_SPLIT_MODE_ROW) {
            split_buft = llama_default_buffer_type_split(model, model.devices[main_gpu], tensor_split);
        if ((split_mode == LLAMA_SPLIT_MODE_GRAPH || split_mode == LLAMA_SPLIT_MODE_ATTN) && model.splits.size() > 1) {
            split_buft = llama_default_buffer_type_split(model, model.devices[main_gpu]);
            model.split_buft = split_buft;
        } else {
            // LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_LAYER in backends where it is not supported
            split_buft = llama_default_buffer_type_offload(model, model.devices[main_gpu]);
        }
        auto buft_layer = llama_default_buffer_type_offload(model, model.devices[main_gpu]);
        // assign the repeating layers
        for (int i = i_gpu_start; i < n_layer; ++i) {
            model.buft_layer[i] = {
                split_buft,
                llama_default_buffer_type_offload(model, model.devices[main_gpu])
            };
            if (split_mode == LLAMA_SPLIT_MODE_ATTN) {
                int layer_gpu = std::upper_bound(model.splits.begin(), model.splits.begin() + device_count,
                        float(i - i_gpu_start)/act_gpu_layers) - model.splits.begin();
                model.buft_layer[i] = { split_buft, llama_default_buffer_type_offload(model, model.devices[layer_gpu]) };
                printf("Layer %d: assigning buft_layer to GPU %d\n", i, layer_gpu);
            } else {
                model.buft_layer[i] = { split_buft, buft_layer };
            }
        }
        // assign the output layer
        if (n_gpu_layers > n_layer) {
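The layer-to-device assignment above relies on model.splits holding normalized cumulative fractions. A minimal, self-contained illustration of the rule (not code from the diff; the cumulative-fraction convention is inferred from the std::upper_bound usage):

    #include <algorithm>
    #include <vector>

    // Maps offloaded layer i to a device index, given cumulative split fractions
    // such as {0.25f, 1.0f} for a 1:3 two-GPU split.
    static int layer_to_gpu(int i, int i_gpu_start, int act_gpu_layers, const std::vector<float> & cum_splits) {
        float x = float(i - i_gpu_start)/act_gpu_layers;   // relative position of layer i in [0, 1)
        return int(std::upper_bound(cum_splits.begin(), cum_splits.end(), x) - cum_splits.begin());
    }

With cum_splits = {0.25f, 1.0f}, i_gpu_start = 0 and 32 offloaded layers, layers 0-7 land on GPU 0 and layers 8-31 on GPU 1.
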
@@ -1807,8 +1931,14 @@ static bool llm_load_tensors(
        }
#endif
        else {
            int ntensor = 0;
            for (auto t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
                ++ntensor;
            }
            if (ntensor > 0) {
                ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
                if (buf == nullptr) {
                    LLAMA_LOG_ERROR("Failed to allocate buffer type %s\n", ggml_backend_buft_name(buft));
                    throw std::runtime_error("unable to allocate backend buffer");
                }
                model.bufs.push_back(buf);
@@ -1822,9 +1952,12 @@ static bool llm_load_tensors(
                bufs.emplace(idx, buf);
            }
        }
    }

    if (bufs.empty()) {
        throw std::runtime_error("failed to allocate buffer");
        LLAMA_LOG_WARN("No tensors in buffer type %s\n", ggml_backend_buft_name(buft));
        continue;
        //throw std::runtime_error("failed to allocate buffer (1)");
    }

    for (auto & buf : bufs) {
@@ -4326,8 +4459,8 @@ struct llama_context * llama_new_context_with_model(
        ggml_backend_add_from_device(ctx, ctx->backend_metal);
    }
#elif defined(GGML_USE_CUDA)
    if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
        // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
    if (model->split_mode == LLAMA_SPLIT_MODE_NONE) {
        // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_GRAPH, only the main GPU backend is used
        ggml_backend_t backend = ggml_backend_cuda_init(model->main_gpu, cparams.cuda_params);
        if (backend == nullptr) {
            LLAMA_LOG_ERROR("%s: failed to initialize CUDA%d backend\n", __func__, model->main_gpu);
@@ -4337,7 +4470,7 @@ struct llama_context * llama_new_context_with_model(
        ggml_backend_add_from_device(ctx, backend);

    } else {
        // LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
        // LLAMA_SPLIT_MODE_LAYER and LLAMA_SPLIT_MODE_GRAPH require a backend for each GPU
        for (int device = 0; device < ggml_backend_cuda_get_device_count(); ++device) {
            ggml_backend_t backend = ggml_backend_cuda_init(device, cparams.cuda_params);
            if (backend == nullptr) {
@@ -4346,12 +4479,11 @@ struct llama_context * llama_new_context_with_model(
                return nullptr;
            }
            ggml_backend_add_from_device(ctx, backend);

        }
    }
#elif defined(GGML_USE_VULKAN)
    if (model->split_mode == LLAMA_SPLIT_MODE_ROW) {
        LLAMA_LOG_ERROR("%s: Row split not supported. Failed to initialize Vulkan backend\n", __func__);
    if (model->split_mode == LLAMA_SPLIT_MODE_GRAPH || model->split_mode == LLAMA_SPLIT_MODE_ATTN) {
        LLAMA_LOG_ERROR("%s: split mode 'graph' or 'attn' not supported. Failed to initialize Vulkan backend\n", __func__);
        llama_free(ctx);
        return nullptr;
    }
@@ -4375,8 +4507,8 @@ struct llama_context * llama_new_context_with_model(
        }
    }
#elif defined(GGML_USE_SYCL)
    // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
    if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
    // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_GRAPH, only the main GPU backend is used
    if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_GRAPH) {
        ggml_backend_t backend = ggml_backend_sycl_init(model->main_gpu);
        if (backend == nullptr) {
            LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d backend\n", __func__, model->main_gpu);
@@ -4407,9 +4539,9 @@ struct llama_context * llama_new_context_with_model(
        ggml_backend_add_from_device(ctx, backend);
    }
#elif defined(GGML_USE_CANN)
    // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
    // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_GRAPH, only the main GPU backend is used
    // TODO: ggml_backend_cann is not support split tensor now, just leave code here.
    if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
    if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_GRAPH) {
        ggml_backend_t backend = ggml_backend_cann_init(model->main_gpu);
        if (backend == nullptr) {
            LLAMA_LOG_ERROR("%s: failed to initialize CANN%d backend\n", __func__, model->main_gpu);

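For reference, a minimal sketch of how the new mode would be requested through the public C API, assuming the rename also exposes LLAMA_SPLIT_MODE_GRAPH in llama.h; the model path and split fractions are placeholders:

    #include "llama.h"

    int main() {
        llama_backend_init();
        llama_model_params mparams = llama_model_default_params();
        mparams.n_gpu_layers = 999;                         // offload all layers
        mparams.split_mode   = LLAMA_SPLIT_MODE_GRAPH;      // tensor-parallel graph split
        mparams.main_gpu     = 0;
        static const float tensor_split[2] = {1.0f, 1.0f};  // even two-GPU split
        mparams.tensor_split = tensor_split;
        llama_model * model = llama_load_model_from_file("model.gguf", mparams);
        // ... create a context and run generation as usual ...
        llama_free_model(model);
        llama_backend_free();
        return 0;
    }
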