From 288a8cf842f486c73f0d14df2eaf5c35e5bd8892 Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Sat, 17 Jan 2026 15:09:26 +0000
Subject: [PATCH] WIP: add Q8_0 and BF16 as possible reduce types

Does not work - there is a bug somewhere

---
 common/common.cpp            | 38 ++++++++++++---
 common/common.h              |  4 +-
 ggml/src/ggml-cuda/norm.cu   | 68 +++++++++++++++++++++-----
 ggml/src/ggml-cuda/reduce.cu | 77 +++++++++++++++++++++++------
 include/llama.h              |  3 +-
 src/llama-build-context.cpp  | 95 ++++++++++++++++++++----------------
 src/llama-cparams.h          |  3 +-
 src/llama.cpp                | 22 +++++++--
 8 files changed, 226 insertions(+), 84 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index cfb96963..0b890b59 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1459,11 +1459,18 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "-smf16" || arg == "--split-mode-f16") {
-        params.split_mode_f16 = true;
+        params.reduce_type = "f16";
+        //params.split_mode_f16 = true;
         return true;
     }
     if (arg == "-smf32" || arg == "--split-mode-f32") {
-        params.split_mode_f16 = false;
+        params.reduce_type = "f32";
+        //params.split_mode_f16 = false;
+        return true;
+    }
+    if (arg == "-grt" || arg == "--graph-reduce-type") {
+        CHECK_ARG
+        params.reduce_type = argv[i];
         return true;
     }
     if (arg == "--numa") {
@@ -2154,8 +2161,9 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", "-mqkv, --merge-qkv,", "merge Q,K,V (default: %d)", params.merge_qkv});
     options.push_back({ "*", "-muge, --merge-up-gate-experts,","merge ffn_up/gate_exps (default: %d)", params.merge_up_gate_exps});
     options.push_back({ "*", "-khad, --k-cache-hadamard,", "Use Hadamard transform for K-cache (default: %d)", params.k_cache_hadamard});
-    options.push_back({ "*", "-smf16, --split-mode-f16,", "Use f16 for data exchange between GPUs (default: %d)", params.split_mode_f16});
-    options.push_back({ "*", "-smf32, --split-mode-f32,", "Use f32 for data exchange between GPUs (default: %d)", !params.split_mode_f16});
+    options.push_back({ "*", "-smf16, --split-mode-f16,", "Use f16 for data exchange between GPUs (default: %d)", true});
+    options.push_back({ "*", "-smf32, --split-mode-f32,", "Use f32 for data exchange between GPUs (default: %d)", false});
+    options.push_back({ "*", "-grt, --graph-reduce-type", "Type for data exchange between GPUs (default: %s)", "f32"});
     options.push_back({ "*", "-smgs, --split-mode-graph-scheduling,", "Force Split Mode Graph Scheduling (default: %d)", params.split_mode_graph_scheduling});
     options.push_back({ "*", "-sas, ==scheduler_async,", "Async evaluation of compute graphs: %d)", params.scheduler_async});
     options.push_back({ "*", "-vq, --validate-quants", "validate quantized data while loading the model (default: %d)", params.validate_quants});
@@ -3148,6 +3156,22 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
     throw std::runtime_error("Invalid cache type: " + s);
 }
 
+static ggml_type ggml_type_from_str(const std::string & s) {
+    if (s == "f32") {
+        return GGML_TYPE_F32;
+    }
+    if (s == "f16") {
+        return GGML_TYPE_F16;
+    }
+    if (s == "bf16") {
+        return GGML_TYPE_BF16;
+    }
+    if (s == "q8_0") {
+        return GGML_TYPE_Q8_0;
+    }
+    throw std::runtime_error("Invalid graph reduce type: " + s);
+}
+
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
     auto cparams = llama_context_default_params();
     int n_batch = params.n_batch;
@@ -3194,7 +3218,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.graph_reuse = params.graph_reuse;
     cparams.k_cache_hadamard = params.k_cache_hadamard;
     cparams.split_mode_graph_scheduling = params.split_mode_graph_scheduling;
-    cparams.split_mode_f16 = params.split_mode_f16;
+    //cparams.split_mode_f16 = params.split_mode_f16;
     cparams.scheduler_async = params.scheduler_async;
     cparams.min_experts = params.min_experts;
     cparams.thresh_experts = params.thresh_experts;
@@ -3203,6 +3227,7 @@ struct llama_context_params llama_context_params_from_gpt_param
     cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
     cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
+    cparams.type_reduce = ggml_type_from_str(params.reduce_type);
 
     if (!params.offload_policy.empty()) cparams.offload_policy = (void *)&params.offload_policy;
     if (!params.cuda_params.empty()) cparams.cuda_params = (void *)params.cuda_params.data();
 
@@ -4180,7 +4205,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "graph_reuse: %s # default: false\n", params.graph_reuse ? "true" : "false");
     fprintf(stream, "k_cache_hadamard: %s # default: false\n", params.k_cache_hadamard ? "true" : "false");
     fprintf(stream, "split_mode_graph_scheduling: %s # default: false\n", params.split_mode_graph_scheduling ? "true" : "false");
-    fprintf(stream, "split_mode_f16: %s # default: true\n", params.split_mode_f16 ? "true" : "false");
+    //fprintf(stream, "split_mode_f16: %s # default: true\n", params.split_mode_f16 ? "true" : "false");
+    fprintf(stream, "reduce_type: %s # default: f16\n", params.reduce_type.c_str());
     fprintf(stream, "scheduler_async: %s # default: false\n", params.scheduler_async ? "true" : "false");
     fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts);
     fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
diff --git a/common/common.h b/common/common.h
index d73bfd28..569e189c 100644
--- a/common/common.h
+++ b/common/common.h
@@ -290,7 +290,7 @@ struct gpt_params {
     bool merge_up_gate_exps= false; // if true, merge ffn_up_exps and ffn_gate_exps into a single, contiguous tensor
     bool k_cache_hadamard = false; // if true, use Hadamard transform for the K-cache (only makes sense with quantized cache)
     bool split_mode_graph_scheduling = false; // if true, force split mode graph scheduling
-    bool split_mode_f16 = true; // if true, intermediate results will be cast to f16 before copying to other GPUs to perform reduce ops
+    //bool split_mode_f16 = true; // if true, intermediate results will be cast to f16 before copying to other GPUs to perform reduce ops
    bool scheduler_async = false; // if true, in split mode graph the scheduler will use multiple threads to evaluate the graph
 
     std::string cache_type_k = "f16"; // KV cache data type for the K
@@ -298,6 +298,8 @@ struct gpt_params {
     std::string cache_type_k_draft = ""; // KV cache data type for K for the draft model
     std::string cache_type_v_draft = ""; // KV cache data type for V for the draft model
 
+    std::string reduce_type = "f16";
+
     // multimodal models (see examples/mtmd)
     model_paths mmproj;
     bool mmproj_use_gpu = true; // use GPU for multimodal model
diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu
index 99c69503..86f311cd 100644
--- a/ggml/src/ggml-cuda/norm.cu
+++ b/ggml/src/ggml-cuda/norm.cu
@@ -43,10 +43,20 @@ static __global__ void fused_norm_f32(const T * x, const float * c, float * dst,
 
     float2 mean_var = make_float2(0.f, 0.f);
 
-    for (int col = tid; col < ncols; col += block_size) {
-        const float xi = (float)x[row*ncols + col];
-        mean_var.x += xi;
-        mean_var.y += xi * xi;
+    if constexpr (std::is_same_v<T, block_q8_0>) {
+        static_assert(block_size % QK8_0 == 0);
+        auto xr = x + (row*ncols)/QK8_0;
+        for (int col = tid; col < ncols; col += block_size) {
+            const float xi = (float)xr[col / QK8_0].d * xr[col / QK8_0].qs[col % QK8_0];
+            mean_var.x += xi;
+            mean_var.y += xi * xi;
+        }
+    } else {
+        for (int col = tid; col < ncols; col += block_size) {
+            const float xi = (float)x[row*ncols + col];
+            mean_var.x += xi;
+            mean_var.y += xi * xi;
+        }
     }
 
     // sum up partial sums
@@ -67,8 +77,16 @@ static __global__ void fused_norm_f32(const T * x, const float * c, float * dst,
     const float var = mean_var.y / ncols - mean * mean;
     const float inv_std = rsqrtf(var + eps);
 
-    for (int col = tid; col < ncols; col += block_size) {
-        dst[row*ncols + col] = (T)(((float)x[row*ncols + col] - mean) * inv_std * c[col]);
+    if constexpr (std::is_same_v<T, block_q8_0>) {
+        static_assert(block_size % QK8_0 == 0);
+        auto xr = x + (row*ncols)/QK8_0;
+        for (int col = tid; col < ncols; col += block_size) {
+            dst[row*ncols + col] = ((float)xr[col/QK8_0].d*xr[col/QK8_0].qs[col%QK8_0] - mean) * inv_std * c[col];
+        }
+    } else {
+        for (int col = tid; col < ncols; col += block_size) {
+            dst[row*ncols + col] = ((float)x[row*ncols + col] - mean) * inv_std * c[col];
+        }
     }
 }
 
@@ -219,9 +237,19 @@ static __global__ void fused_rms_norm_f32(const src_t * x, const float * y, floa
 
     float tmp = 0.0f; // partial sum for thread in warp
 
-    for (int col = tid; col < ncols; col += block_size) {
-        const float xi = (float)x[row*ncols + col];
-        tmp += xi * xi;
+    if constexpr (std::is_same_v<src_t, block_q8_0>) {
+        static_assert(block_size % QK8_0 == 0);
+        auto xr = x + (row*ncols)/QK8_0;
+        for (int col = tid; col < ncols; col += block_size) {
+            const float xi = (float)xr[col / QK8_0].d * xr[col / QK8_0].qs[col % QK8_0];
+            tmp += xi * xi;
+        }
+
+    } else {
+        for (int col = tid; col < ncols; col += block_size) {
+            const float xi = (float)x[row*ncols + col];
+            tmp += xi * xi;
+        }
     }
 
     // sum up partial sums
@@ -241,8 +269,15 @@ static __global__ void fused_rms_norm_f32(const src_t * x, const float * y, floa
     const float mean = tmp / ncols;
     const float scale = rsqrtf(mean + eps);
 
-    for (int col = tid; col < ncols; col += block_size) {
-        dst[row*ncols + col] = scale * y[col] * (float)x[row*ncols + col];
+    if constexpr (std::is_same_v<src_t, block_q8_0>) {
+        auto xr = x + (row*ncols)/QK8_0;
+        for (int col = tid; col < ncols; col += block_size) {
+            dst[row*ncols + col] = scale * y[col] * (float)xr[col / QK8_0].d * xr[col / QK8_0].qs[col % QK8_0];
+        }
+    } else {
+        for (int col = tid; col < ncols; col += block_size) {
+            dst[row*ncols + col] = scale * y[col] * (float)x[row*ncols + col];
+        }
     }
 }
 
@@ -496,7 +531,8 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 ||
+                (ggml_is_contiguous(src0) && src0->type == GGML_TYPE_Q8_0));
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
     GGML_ASSERT(src0->ne[0] == src1->ne[0]);
@@ -511,8 +547,12 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
     const int64_t nrows = ggml_nrows(src0);
     if (src0->type == GGML_TYPE_F32) {
         fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, is_norm, stream);
-    } else {
+    } else if (src0->type == GGML_TYPE_F16) {
         fused_rms_norm_f32_cuda((const half *)src0_d, src1_d, dst_d, ne00, nrows, eps, is_norm, stream);
+    } else if (src0->type == GGML_TYPE_Q8_0) {
+        fused_rms_norm_f32_cuda((const block_q8_0 *)src0_d, src1_d, dst_d, ne00, nrows, eps, is_norm, stream);
+    } else {
+        fused_rms_norm_f32_cuda((const nv_bfloat16 *)src0_d, src1_d, dst_d, ne00, nrows, eps, is_norm, stream);
     }
 } else {
     if (is_norm) {
@@ -525,6 +565,8 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
         auto s03 = src0->nb[3] / ts0;
         if (src0->type == GGML_TYPE_F32) {
             fused_rms_norm_f32_nc_cuda(src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
+        } else if (src0->type == GGML_TYPE_BF16) {
+            fused_rms_norm_f32_nc_cuda((const nv_bfloat16 *)src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
         } else {
             fused_rms_norm_f32_nc_cuda((const half *)src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
         }
diff --git a/ggml/src/ggml-cuda/reduce.cu b/ggml/src/ggml-cuda/reduce.cu
index 6898c5fd..408ee2fd 100644
--- a/ggml/src/ggml-cuda/reduce.cu
+++ b/ggml/src/ggml-cuda/reduce.cu
@@ -6,6 +6,7 @@
 //
 
 #include "reduce.cuh"
+#include "ggml-common.h"
 
 #include 
 
@@ -16,6 +17,23 @@ static __global__ void k_add(int nelem, const T * src, T * dst) {
     dst[i] += src[i];
 }
 
+template <int block_size>
+static __global__ void k_add(int nelem, const block_q8_0 * src, block_q8_0 * dst) {
+    int i = blockIdx.x*block_size + threadIdx.x;
+    if (i >= nelem) return;
+    int ib = i / QK8_0;
+    int iq = i % QK8_0;
+    float x = (float)src[ib].d * src[ib].qs[iq] + (float)dst[ib].d * dst[ib].qs[iq];
+    float ax = fabsf(x);
+    float max = warp_reduce_max(ax);
+    float d = max / 127;
+    float id = d > 0 ? 1/d : 0;
+    dst[ib].qs[iq] = roundf(x * id);
+    if (threadIdx.x % WARP_SIZE == 0) {
+        dst[ib].d = (half)d;
+    }
+}
+
 template <int block_size, typename T>
 static __global__ void k_add_sym(int nelem, T * src, T * dst) {
     int i = blockIdx.x*block_size + threadIdx.x;
@@ -68,7 +86,8 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
     GGML_ASSERT(op == GGML_OP_ADD);
     int nreduce = dst->op_params[1];
     int nhave = dst->op_params[2];
-    GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32 ||
+                dst->type == GGML_TYPE_Q8_0 || dst->type == GGML_TYPE_BF16);
     GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(nhave >=2 && nhave <= nreduce);
     if (dst->op_params[3] == 1) {
@@ -82,10 +101,10 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
     // It does not work at all if not all GPUs participate in the reduce op, and we
     // get suboptimal prompt processing performance when we have more than 2 GPUs.
     // Hence, if enabled, we use NCCL only for the cases where it works and performs well.
-    if (info.have_nccl && nhave == nreduce && (nhave == 2 || dst->ne[1] < 32)) {
+    if (false && info.have_nccl && dst->type != GGML_TYPE_Q8_0 && nhave == nreduce && (nhave == 2 || dst->ne[1] < 32)) {
         GGML_ASSERT(info.have_nccl);
         GGML_ASSERT(info.device_count == nreduce);
-        auto data_type = dst->type == GGML_TYPE_F32 ? ncclFloat : ncclHalf;
+        auto data_type = dst->type == GGML_TYPE_F32 ? ncclFloat : dst->type == GGML_TYPE_BF16 ? ncclBfloat16 : ncclHalf;
         ncclGroupStart();
         for (int i = 0; i < nreduce; ++i) {
             ggml_cuda_set_device(i);
@@ -198,13 +217,25 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
     //
     if (dst->ne[1] >= 32) {
         auto nelem = ggml_nelements(dst);
-        auto elem_size = ggml_element_size(dst);
-        auto nelem_per_device = (nelem + nhave - 1)/nhave;
-        auto required_size = nelem_per_device*elem_size;
+        auto tt = ggml_internal_get_type_traits(dst->type);
+        GGML_ASSERT(nelem % tt.blck_size == 0);
+        auto nblocks = nelem / tt.blck_size;
+        auto nblocks_per_device = (nblocks + nhave - 1)/nhave;
+        auto nelem_per_device = nblocks_per_device * tt.blck_size;
+        auto size_per_device = nblocks_per_device * tt.type_size;
+        //size_t nelem_per_device, required_size;
+        //if (dst->type == GGML_TYPE_Q8_0) {
+        //    GGML_ASSERT(nelem % QK8_0 == 0);
+        //    nelem_per_device = QK8_0*((nelem/QK8_0 + nhave - 1)/nhave);
+        //    required_size nelem_per_device/QK8_0 * sizeof(ggml_block_q8_0);
+        //}
+        //auto elem_size = ggml_element_size(dst);
+        //auto nelem_per_device = (nelem + nhave - 1)/nhave;
+        //auto required_size = nelem_per_device*elem_size;
         for (int ii = 0; ii < nhave; ++ii) {
             int i = idx[ii];
             auto this_ctx = info.all_ctx[i];
-            if (!this_ctx->copy_event || !this_ctx->compute_event || required_size > this_ctx->copy_size) {
+            if (!this_ctx->copy_event || !this_ctx->compute_event || size_per_device > this_ctx->copy_size) {
                 ggml_cuda_set_device(this_ctx->device);
                 if (!this_ctx->copy_event) {
                     CUDA_CHECK(cudaEventCreateWithFlags(&this_ctx->copy_event, cudaEventDisableTiming));
@@ -212,12 +243,12 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
                 if (!this_ctx->compute_event) {
                     CUDA_CHECK(cudaEventCreateWithFlags(&this_ctx->compute_event, cudaEventDisableTiming));
                 }
-                if (required_size > this_ctx->copy_size) {
+                if (size_per_device > this_ctx->copy_size) {
                     if (this_ctx->copy_buffer) {
                         CUDA_CHECK(cudaFree(this_ctx->copy_buffer));
                     }
-                    CUDA_CHECK(ggml_cuda_device_malloc(&this_ctx->copy_buffer, required_size, this_ctx->device));
-                    this_ctx->copy_size = required_size;
+                    CUDA_CHECK(ggml_cuda_device_malloc(&this_ctx->copy_buffer, size_per_device, this_ctx->device));
+                    this_ctx->copy_size = size_per_device;
                 }
             }
         }
@@ -227,13 +258,14 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
             int i = idx[ii];
             int peer = idx[(ii+1)%nhave];
             auto this_nelem = std::min(nelem_per_device, nelem - ichunk*nelem_per_device);
+            auto this_size = (this_nelem / tt.blck_size) * tt.type_size;
             ggml_cuda_set_device(info.all_ctx[peer]->device);
             if (stage > 0) {
                 CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[peer]->stream(), info.all_ctx[i]->compute_event, 0));
             }
             CUDA_CHECK(cudaMemcpyPeerAsync(info.all_ctx[i]->copy_buffer, info.all_ctx[i]->device,
-                        (const char *)dst->src[peer]->data + ichunk*nelem_per_device*elem_size, info.all_ctx[peer]->device,
-                        this_nelem*elem_size, info.all_ctx[peer]->stream()));
+                        (const char *)dst->src[peer]->data + ichunk*size_per_device, info.all_ctx[peer]->device,
+                        this_size, info.all_ctx[peer]->stream()));
             CUDA_CHECK(cudaEventRecord(info.all_ctx[peer]->copy_event, info.all_ctx[peer]->stream()));
             ichunk = (ichunk + 1)%nhave;
         }
@@ -248,6 +280,13 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
             if (dst->type == GGML_TYPE_F16) {
                 k_add<<stream()>>>(this_nelem, (const half *)info.all_ctx[i]->copy_buffer,
                         (half *)dst->src[i]->data + ichunk*nelem_per_device);
+            } else if (dst->type == GGML_TYPE_Q8_0) {
+                k_add<<stream()>>>(this_nelem,
+                        (const block_q8_0 *)info.all_ctx[i]->copy_buffer, (block_q8_0 *)dst->src[i]->data + ichunk*nelem_per_device/tt.blck_size);
+            } else if (dst->type == GGML_TYPE_BF16) {
+                k_add<<stream()>>>(
+                        this_nelem, (const nv_bfloat16 *)info.all_ctx[i]->copy_buffer,
+                        (nv_bfloat16 *)dst->src[i]->data + ichunk*nelem_per_device);
             } else {
                 k_add<<stream()>>>(this_nelem, (const float *)info.all_ctx[i]->copy_buffer,
                         (float *)dst->src[i]->data + ichunk*nelem_per_device);
@@ -262,13 +301,14 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
             int i = idx[ii];
             int peer = idx[(ii+1)%nhave];
             auto this_nelem = std::min(nelem_per_device, nelem - ichunk*nelem_per_device);
+            auto this_size = (this_nelem / tt.blck_size) * tt.type_size;
             ggml_cuda_set_device(info.all_ctx[peer]->device);
             if (stage == 0) {
                 CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[peer]->stream(), info.all_ctx[i]->compute_event, 0));
             }
-            CUDA_CHECK(cudaMemcpyPeerAsync((char *)dst->src[i]->data + ichunk*nelem_per_device*elem_size, info.all_ctx[i]->device,
-                        (const char *)dst->src[peer]->data + ichunk*nelem_per_device*elem_size, info.all_ctx[peer]->device,
-                        this_nelem*elem_size, info.all_ctx[peer]->stream()));
+            CUDA_CHECK(cudaMemcpyPeerAsync((char *)dst->src[i]->data + ichunk*size_per_device, info.all_ctx[i]->device,
+                        (const char *)dst->src[peer]->data + ichunk*size_per_device, info.all_ctx[peer]->device,
+                        this_size, info.all_ctx[peer]->stream()));
             CUDA_CHECK(cudaEventRecord(info.all_ctx[peer]->copy_event, info.all_ctx[peer]->stream()));
             //ggml_cuda_set_device(info.all_ctx[i]->device);
             //CUDA_CHECK(cudaStreamWaitEvent(info.all_ctx[i]->stream(), info.all_ctx[peer]->copy_event, 0));
@@ -351,6 +391,7 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
         return;
     }
     if (dst->ne[1] < 32 && ctx.p2p_enabled) {
+        GGML_ASSERT(dst->type != GGML_TYPE_Q8_0);
         for (int ii = 0; ii < nhave; ++ii) {
             int i = idx[ii];
             GGML_ASSERT(dst->src[i]->type == dst->type);
@@ -464,6 +505,12 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
         CUDA_CHECK(cudaStreamWaitEvent(ctx.stream(), info.all_ctx[i]->copy_event, 0));
         if (dst->type == GGML_TYPE_F16) {
             k_add<<>>(nelem, (const half *)ptr, (half *)dst->data);
+        } else if (dst->type == GGML_TYPE_BF16) {
+            k_add<<>>(nelem,
+                    (const nv_bfloat16*)ptr, (nv_bfloat16 *)dst->data);
+        } else if (dst->type == GGML_TYPE_Q8_0) {
+            k_add<<>>(nelem, (const block_q8_0 *)ptr,
+                    (block_q8_0 *)dst->data);
         } else {
             k_add<<>>(nelem, (const float *)ptr, (float *)dst->data);
         }
diff --git a/include/llama.h b/include/llama.h
index 2473af9e..31325cc7 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -426,6 +426,7 @@ extern "C" {
 
         enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
         enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
+        enum ggml_type type_reduce; // data type for reduce operations
 
         // Keep the booleans together to avoid misalignment during copy-by-value.
        bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
@@ -445,7 +446,7 @@ extern "C" {
        bool only_active_experts;
        bool k_cache_hadamard; // if true, apply Hadamard transfrom to K-cache
        bool split_mode_graph_scheduling; // if true, force split mode graph scheduling
-        bool split_mode_f16; // if true, cast intermediate results to f16 before copying to other GPUs
+        //bool split_mode_f16; // if true, cast intermediate results to f16 before copying to other GPUs
        bool scheduler_async; // if true, with split mode "graph" graph evaluation will be done using multiple threads
 
        // Abort callback
diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index 0a251b33..d51f8980 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -633,6 +633,27 @@ static ggml_tensor * get_input_tensor_sm_graph(ggml_tensor * input, int id) {
     return cur;
 }
 
+static inline ggml_tensor * do_split_norm(ggml_context * ctx, ggml_tensor * cur, ggml_tensor * the_norm, const llama_hparams & hparams,
+        const llm_build_cb & cb, int id, int il_cb, bool is_norm) {
+    if (the_norm && the_norm->extra) {
+        auto norm = (ggml_split_tensor_t *)the_norm->extra;
+        GGML_ASSERT(norm->splits[id]);
+        //if (cur->type != GGML_TYPE_F16 && cur->type != GGML_TYPE_F32) {
+        //    cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
+        //}
+        if (is_norm) {
+            cur = ggml_fused_norm(ctx, cur, norm->splits[id], hparams.f_norm_eps);
+        } else {
+            cur = llm_build_context::llm_build_norm(ctx, cur, hparams, norm->splits[id], NULL, LLM_NORM_RMS, cb, il_cb);
+        }
+        cb(cur, "inp_normed", il_cb);
+    }
+    if (cur->type != GGML_TYPE_F32) {
+        cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
+    }
+    return cur;
+}
+
 ggml_tensor * llm_build_context::llm_build_ffn(
          ggml_context * ctx,
         llama_context & lctx,
@@ -673,19 +694,7 @@ ggml_tensor * llm_build_context::llm_build_ffn(
             GGML_ASSERT((!split_u && !split_g && !split_d) || (split_u && split_g && split_d));
             if (!split_u) continue;
             auto cur = get_input_tensor_sm_graph(input, id);
-            if (ffn_norm && ffn_norm->extra) {
-                auto norm = (ggml_split_tensor_t *)ffn_norm->extra;
-                GGML_ASSERT(norm->splits[id]);
-                if (is_norm) {
-                    cur = ggml_fused_norm(ctx, cur, norm->splits[id], lctx.model.hparams.f_norm_eps);
-                } else {
-                    cur = llm_build_norm(ctx, cur, lctx.model.hparams, norm->splits[id], NULL, LLM_NORM_RMS, cb, il);
-                }
-                cb(cur, "ffn_inp_normed", il_cb);
-            }
-            else if (cur->type != GGML_TYPE_F32) {
-                cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
-            }
+            cur = do_split_norm(ctx, cur, ffn_norm, lctx.model.hparams, cb, id, il_cb, is_norm);
             cur = ggml_fused_up_gate(ctx, split_u, split_g, cur, unary_op);
             cb(cur, "ffn_up_gate", il_cb);
             cur = llm_build_lora_mm(lctx, ctx, split_d, cur);
@@ -694,8 +703,8 @@ ggml_tensor * llm_build_context::llm_build_ffn(
                 // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
                 ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
             }
-            if (cur->ne[1] > 32 && lctx.cparams.split_mode_f16) {
-                cur = ggml_cast(ctx, cur, GGML_TYPE_F16);
+            if (cur->ne[1] > 32 && lctx.cparams.reduce_type != GGML_TYPE_F32) {
+                cur = ggml_cast(ctx, cur, lctx.cparams.reduce_type);
             }
             if (add_extra && add_extra->op == GGML_OP_REDUCE && add_extra->op_params[3] == 1) {
                 // When the reduce op is turned off via op_params[3] == 1, we need to add each src
@@ -1205,8 +1214,8 @@ llm_expert_gating_func_type gating_op,
                         split_down_shexp->splits[id], split_down_b_shexp ? split_down_b_shexp->splits[id] : nullptr,
                         nullptr, nullptr, type_op_shexp, LLM_FFN_PAR, cb, il);
                 cb(shared_out, "ffn_shexp_out", il_cb);
-                if (shared_out->ne[1] > 32 && lctx.cparams.split_mode_f16) {
-                    shared_out = ggml_cast(ctx, shared_out, GGML_TYPE_F16);
+                if (shared_out->ne[1] > 32 && lctx.cparams.reduce_type != GGML_TYPE_F32) {
+                    shared_out = ggml_cast(ctx, shared_out, lctx.cparams.reduce_type);
                 }
                 results.push_back(shared_out);
             }
@@ -1222,8 +1231,8 @@ llm_expert_gating_func_type gating_op,
                     cb(cur, "ffn_shared_combined", il);
                 }
             }
-            if (routed_out->ne[1] > 32 && lctx.cparams.split_mode_f16) {
-                auto routed_out_f16 = ggml_cast(ctx, routed_out, GGML_TYPE_F16);
+            if (routed_out->ne[1] > 32 && lctx.cparams.reduce_type != GGML_TYPE_F32) {
+                auto routed_out_f16 = ggml_cast(ctx, routed_out, lctx.cparams.reduce_type);
                 cur = ggml_add(ctx, routed_out_f16, cur);
             } else {
                 cur = ggml_add(ctx, routed_out, cur);
@@ -1269,15 +1278,16 @@ llm_expert_gating_func_type gating_op,
             if (!split_up_exps->splits[id]) continue;
             int il_cb = 1000*(id + 1) + il;
             auto cur = get_input_tensor_sm_graph(input, id);
-            if (ffn_norm) {
-                auto split_ffn_norm = (ggml_split_tensor_t *)ffn_norm->extra;
-                GGML_ASSERT(split_ffn_norm && split_ffn_norm->n_device == split_up_exps->n_device);
-                cur = llm_build_norm(ctx, cur, lctx.model.hparams, split_ffn_norm->splits[id], nullptr, LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_inp_normed", il_cb);
-            }
-            if (cur->type != GGML_TYPE_F32) {
-                cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
-            }
+            cur = do_split_norm(ctx, cur, ffn_norm, lctx.model.hparams, cb, id, il_cb, false);
+            //if (ffn_norm) {
+            //    auto split_ffn_norm = (ggml_split_tensor_t *)ffn_norm->extra;
+            //    GGML_ASSERT(split_ffn_norm && split_ffn_norm->n_device == split_up_exps->n_device);
+            //    cur = llm_build_norm(ctx, cur, lctx.model.hparams, split_ffn_norm->splits[id], nullptr, LLM_NORM_RMS, cb, il);
+            //    cb(cur, "ffn_inp_normed", il_cb);
+            //}
+            //if (cur->type != GGML_TYPE_F32) {
+            //    cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
+            //}
             GGML_ASSERT(!split_gate_inp_b || split_gate_inp_b->splits[id]);
             GGML_ASSERT(!split_exps_down_b || split_exps_down_b->splits[id]);
             GGML_ASSERT(!split_exps_gate_b || split_exps_gate_b->splits[id]);
@@ -1309,8 +1319,8 @@ llm_expert_gating_func_type gating_op,
             } else {
                 cur = routed_out;
             }
-            if (cur->ne[1] > 32 && lctx.cparams.split_mode_f16) {
-                cur = ggml_cast(ctx, cur, GGML_TYPE_F16);
+            if (cur->ne[1] > 32 && lctx.cparams.reduce_type != GGML_TYPE_F32) {
+                cur = ggml_cast(ctx, cur, lctx.cparams.reduce_type);
                 cb(cur, "ffn_out_f16", il_cb);
             }
             ggml_build_forward_expand(graph, cur);
@@ -9221,16 +9231,17 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
                     (split_wq && split_wk && split_wv && split_wo && split_kl && split_vl));
             if (!split_wq) continue;
             auto cur = get_input_tensor_sm_graph(input, id);
-            if (attn_norm) {
-                if (is_norm) {
-                    cur = ggml_fused_norm(ctx0, cur, attn_norm->splits[id], lctx.model.hparams.f_norm_eps);
-                } else {
-                    cur = llm_build_norm(ctx0, cur, lctx.model.hparams, attn_norm->splits[id], NULL, LLM_NORM_RMS, cb, il);
-                }
-            }
-            if (cur->type != GGML_TYPE_F32) {
-                cur = ggml_cast(ctx0, cur, GGML_TYPE_F32);
-            }
+            cur = do_split_norm(ctx0, cur, the_attn_norm, lctx.model.hparams, cb, id, il_cb, is_norm);
+            //if (attn_norm) {
+            //    if (is_norm) {
+            //        cur = ggml_fused_norm(ctx0, cur, attn_norm->splits[id], lctx.model.hparams.f_norm_eps);
+            //    } else {
+            //        cur = llm_build_norm(ctx0, cur, lctx.model.hparams, attn_norm->splits[id], NULL, LLM_NORM_RMS, cb, il);
+            //    }
+            //}
+            //if (cur->type != GGML_TYPE_F32) {
+            //    cur = ggml_cast(ctx0, cur, GGML_TYPE_F32);
+            //}
             auto the_q_norm = model.layers[il].attn_q_norm ? model.layers[il].attn_q_norm->extra ?
                 ((ggml_split_tensor_t *)model.layers[il].attn_q_norm->extra)->splits[id] : model.layers[il].attn_q_norm : nullptr;
             auto the_k_norm = model.layers[il].attn_k_norm ? model.layers[il].attn_k_norm->extra ?
@@ -9368,8 +9379,8 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
                 cb(cur, "kqv_wo_biased", il_cb);
                 output_bias_added = true;
             }
-            if (cur->ne[1] > 32 && lctx.cparams.split_mode_f16) {
-                cur = ggml_cast(ctx0, cur, GGML_TYPE_F16);
+            if (cur->ne[1] > 32 && lctx.cparams.reduce_type != GGML_TYPE_F32) {
+                cur = ggml_cast(ctx0, cur, lctx.cparams.reduce_type);
             }
             ggml_build_forward_expand(gf, cur);
             attn[id] = cur;
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index 3735a474..3ee26c55 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -41,11 +41,12 @@ struct llama_cparams {
     bool graph_reuse;
     bool k_cache_hadamard;
     bool split_mode_graph_scheduling;
-    bool split_mode_f16;
+    //bool split_mode_f16;
     bool scheduler_async;
     int min_experts;
     float thresh_experts;
 
+    enum ggml_type reduce_type;
     enum llama_pooling_type pooling_type;
 
     ggml_backend_sched_eval_callback cb_eval;
diff --git a/src/llama.cpp b/src/llama.cpp
index fd254752..347b3b70 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -4070,6 +4070,7 @@ struct llama_context_params llama_context_default_params() {
         /*.cb_eval_user_data =*/ nullptr,
         /*.type_k =*/ GGML_TYPE_F16,
         /*.type_v =*/ GGML_TYPE_F16,
+        /*.type_reduce =*/ GGML_TYPE_F16,
         /*.logits_all =*/ false,
         /*.embeddings =*/ false,
         /*.offload_kqv =*/ true,
@@ -4087,7 +4088,7 @@ struct llama_context_params llama_context_default_params() {
         /*.only_active_experts =*/ false,
         /*.k_cache_hadamard =*/ false,
         /*.split_mode_graph_scheduling =*/ false,
-        /*.split_mode_f16 =*/ true,
+        // /*.split_mode_f16 =*/ true,
         /*.scheduler_async =*/ false,
         /*.abort_callback =*/ nullptr,
         /*.abort_callback_data =*/ nullptr,
@@ -4382,6 +4383,8 @@ struct llama_context * llama_new_context_with_model(
                  struct llama_model * model,
         struct llama_context_params params) {
 
+    printf("===================================== %s: %s\n", __func__, ggml_type_name(params.type_reduce));
+
     if (!model) {
         LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__);
         return nullptr;
@@ -4452,12 +4455,13 @@ struct llama_context * llama_new_context_with_model(
     cparams.graph_reuse = params.graph_reuse;
     cparams.k_cache_hadamard = params.k_cache_hadamard;
     cparams.split_mode_graph_scheduling = params.split_mode_graph_scheduling;
-    cparams.split_mode_f16 = params.split_mode_f16;
+    //cparams.split_mode_f16 = params.split_mode_f16;
     cparams.scheduler_async = params.scheduler_async;
     cparams.min_experts = params.min_experts;
     cparams.thresh_experts = params.thresh_experts;
     cparams.cuda_params = params.cuda_params;
 
+    cparams.reduce_type = params.type_reduce;
     cparams.pooling_type = params.pooling_type;
 
     cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
@@ -4527,12 +4531,19 @@ struct llama_context * llama_new_context_with_model(
         cparams.mla_attn = 0;
     }
     if (model->arch == LLM_ARCH_OPENAI_MOE && model->split_mode == LLAMA_SPLIT_MODE_GRAPH) {
-        if (cparams.split_mode_f16) {
+        //if (cparams.split_mode_f16) {
+        //    LLAMA_LOG_WARN("=====================================================================\n");
+        //    LLAMA_LOG_WARN("GPT-OSS with split mode graph requires f32 precision\n");
+        //    LLAMA_LOG_WARN(" => changing cparams.split_mode_f16 to 'false'\n");
+        //    LLAMA_LOG_WARN("=====================================================================\n");
+        //    cparams.split_mode_f16 = false;
+        //}
+        if (cparams.reduce_type == GGML_TYPE_F16) {
            LLAMA_LOG_WARN("=====================================================================\n");
            LLAMA_LOG_WARN("GPT-OSS with split mode graph requires f32 precision\n");
            LLAMA_LOG_WARN(" => changing cparams.split_mode_f16 to 'false'\n");
            LLAMA_LOG_WARN("=====================================================================\n");
-            cparams.split_mode_f16 = false;
+            cparams.reduce_type = GGML_TYPE_F32;
         }
     }
 
@@ -4552,7 +4563,8 @@ struct llama_context * llama_new_context_with_model(
     LLAMA_LOG_INFO("%s: graph_reuse = %d\n", __func__, cparams.graph_reuse);
     LLAMA_LOG_INFO("%s: k_cache_hadam = %d\n", __func__, cparams.k_cache_hadamard);
     LLAMA_LOG_INFO("%s: split_mode_graph_scheduling = %d\n", __func__, cparams.split_mode_graph_scheduling);
-    LLAMA_LOG_INFO("%s: split_mode_f16= %d\n", __func__, cparams.split_mode_f16);
+    //LLAMA_LOG_INFO("%s: split_mode_f16= %d\n", __func__, cparams.split_mode_f16);
+    LLAMA_LOG_INFO("%s: reduce_type = %s\n", __func__, ggml_type_name(cparams.reduce_type));
     LLAMA_LOG_INFO("%s: sched_async = %d\n", __func__, cparams.scheduler_async);
     LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts);
     LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
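
Editor's note (not part of the patch): the core of the Q8_0 reduce path is the new `k_add` overload for `block_q8_0`, which dequantizes both operands with their per-block scales, adds them, and re-quantizes the sum with a fresh scale `d = max|x| / 127` taken over the 32-element block. The sketch below reproduces that arithmetic on the host as a minimal, self-contained C++ illustration; `q8_block` and `add_q8_block` are hypothetical stand-ins (ggml's real `block_q8_0` stores the scale as fp16 and the kernel computes the block max with `warp_reduce_max`, one thread per element).

```cpp
// Host-side sketch of the Q8_0 block add/requantize that the CUDA k_add overload performs.
// Assumes the usual QK8_0 = 32 block layout; types/names here are illustrative only.
#include <cmath>
#include <cstdint>
#include <cstdio>

constexpr int QK8_0 = 32;

struct q8_block {           // stand-in for ggml's block_q8_0
    float  d;               // block scale (fp16 in ggml, float here for simplicity)
    int8_t qs[QK8_0];       // quantized values
};

// dst += src, re-quantizing dst in place with a new per-block scale
static void add_q8_block(q8_block & dst, const q8_block & src) {
    float sum[QK8_0];
    float amax = 0.0f;
    for (int i = 0; i < QK8_0; ++i) {
        sum[i] = dst.d * dst.qs[i] + src.d * src.qs[i];   // dequantize both operands and add
        amax   = std::fmax(amax, std::fabs(sum[i]));
    }
    const float d  = amax / 127.0f;                       // new scale for the summed block
    const float id = d > 0.0f ? 1.0f/d : 0.0f;
    for (int i = 0; i < QK8_0; ++i) {
        dst.qs[i] = (int8_t)std::lround(sum[i] * id);     // re-quantize the sum
    }
    dst.d = d;
}

int main() {
    q8_block a{0.5f, {}}, b{0.25f, {}};
    for (int i = 0; i < QK8_0; ++i) { a.qs[i] = (int8_t)(i - 16); b.qs[i] = (int8_t)(16 - i); }
    add_q8_block(a, b);
    printf("new scale %g, q[0] = %d\n", a.d, a.qs[0]);
    return 0;
}
```

In the CUDA kernel the same per-block maximum is obtained via a warp reduction across the 32 threads handling one block, which is why the patch asserts `block_size % QK8_0 == 0` and writes the new scale only from the first lane of each warp.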