diff --git a/common/common.cpp b/common/common.cpp
index 341c4f97..df4d5d0f 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1020,6 +1020,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.fused_up_gate = false;
         return true;
     }
+    if (arg == "-no-mmad" || arg == "--no-fused-mul-multiadd") {
+        params.fused_mmad = false;
+        return true;
+    }
     if (arg == "-ser" || arg == "--smart-expert-reduction") {
         CHECK_ARG
         auto values = string_split_pairs(argv[i], ',');
@@ -1806,6 +1810,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" });
     options.push_back({ "*", "-ger, --grouped-expert-routing", "enable grouped expert routing (default: %s)", params.grouped_expert_routing ? "enabled" : "disabled" });
     options.push_back({ "*", "-no-fug, --no-fused-up-gate", "disaable fused up-gate (default: %s)", params.fused_up_gate ? "enabled" : "disabled" });
+    options.push_back({ "*", "-no-mmad, --no-fused-mul-multiadd", "disable fused mul-multi_add (default: %s)", params.fused_mmad ? "enabled" : "disabled" });
     options.push_back({ "*", "-ser, --smart-expert-reduction,","experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts});
     options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
                         "in conversation mode, this will be used as system prompt\n"
@@ -2762,6 +2767,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.fused_moe_up_gate = params.fused_moe_up_gate;
     cparams.grouped_expert_routing = params.grouped_expert_routing;
     cparams.fused_up_gate = params.fused_up_gate;
+    cparams.fused_mmad = params.fused_mmad;
     cparams.min_experts = params.min_experts;
     cparams.thresh_experts = params.thresh_experts;
     cparams.only_active_experts = params.only_active_exps;
@@ -3879,6 +3885,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false");
     fprintf(stream, "grouped_expert_routing: %s # default: false\n", params.grouped_expert_routing ? "true" : "false");
     fprintf(stream, "fused_up_gate: %s # default: true\n", params.fused_up_gate ? "true" : "false");
+    fprintf(stream, "fused_mmad: %s # default: true\n", params.fused_mmad ? "true" : "false");
     fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts);
     fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
diff --git a/common/common.h b/common/common.h
index 17628b98..327afb24 100644
--- a/common/common.h
+++ b/common/common.h
@@ -235,6 +235,7 @@ struct gpt_params {
     int attn_max_batch = 0; // Max batch size to use when computing attention (only applicable if flash_attn = false)
     bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models
     bool fused_up_gate = true; // fused up*unary(gate) op
+    bool fused_mmad = true; // fused mul+multi_add op
     bool grouped_expert_routing = false; // if to use grouped expert routing (BailingMoeV2 arch)
     int min_experts = -1;
     float thresh_experts = 0;
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 41d691bb..d2cb164c 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -619,6 +619,7 @@ extern "C" {
         GGML_OP_OUT_PROD,
         GGML_OP_FUSED_UP_GATE,
         GGML_OP_MOE_FUSED_UP_GATE,
+        GGML_OP_MUL_MULTI_ADD,
 
         GGML_OP_SCALE,
         GGML_OP_SET,
@@ -1083,6 +1084,11 @@ extern "C" {
             struct ggml_tensor * a,
             int n_experts);
 
+    GGML_API struct ggml_tensor * ggml_mul_multi_add(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b);
+
     // dst = a
     // view(dst, nb1, nb2, nb3, offset) += b
    // return dst
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 29f9d26c..9f7fd33f 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -46,6 +46,7 @@
 #include "ggml-cuda/conv2d-dw.cuh"
 #include "ggml-cuda/set-rows.cuh"
 #include "ggml-cuda/argmax.cuh"
+#include "ggml-cuda/multiadd.cuh"
 
 #include
 #include
@@ -3178,6 +3179,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_MULTI_ADD:
             ggml_cuda_op_multi_add(ctx, dst);
             break;
+        case GGML_OP_MUL_MULTI_ADD:
+            ggml_cuda_op_mul_multi_add(ctx, dst);
+            break;
         case GGML_OP_ACC:
             ggml_cuda_op_acc(ctx, dst);
             break;
@@ -4408,6 +4412,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_ADD:
         case GGML_OP_ADD_ID:
         case GGML_OP_MULTI_ADD:
+        case GGML_OP_MUL_MULTI_ADD:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_FUSED_RMS_NORM:
diff --git a/ggml/src/ggml-cuda/multiadd.cu b/ggml/src/ggml-cuda/multiadd.cu
new file mode 100644
index 00000000..fba7271a
--- /dev/null
+++ b/ggml/src/ggml-cuda/multiadd.cu
@@ -0,0 +1,87 @@
+#include "multiadd.cuh"
+
+static __global__ void multi_add_f32(int nused, int64_t ne0, int64_t ne1, int64_t nb1, int64_t nb01, const char * src0, char * dst) {
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+    int64_t k = ne0*ne1;
+    if (i >= k) {
+        return;
+    }
+    int i1 = i / ne0;
+    int i0 = i % ne0;
+    float * result = (float *)(dst + i1*nb1);
+    const float * s = (const float *)(src0 + i1*nb01) + i0;
+    if (nused == 1) {
+        result[i0] = s[0];
+    } else {
+        float sum = s[0] + s[ne0];
+        for (int j = 2; j < nused; ++j) sum += s[j*ne0];
+        result[i0] = sum;
+    }
+}
+
+static void multi_add_f32_cuda(int nused, int64_t ne0, int64_t ne1, int64_t nb1, int64_t nb01, const char * src0, char * dst, cudaStream_t stream) {
+    int64_t k = ne0 * ne1;
+    const int num_blocks = (k + CUDA_MULTI_ADD_BLOCK_SIZE - 1) / CUDA_MULTI_ADD_BLOCK_SIZE;
+    multi_add_f32<<<num_blocks, CUDA_MULTI_ADD_BLOCK_SIZE, 0, stream>>>(nused, ne0, ne1, nb1, nb01, src0, dst);
+}
+
+void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->ne[2] == 1 && dst->ne[3] == 1);
+    GGML_ASSERT(dst->nb[0] == sizeof(float));
+    int nused = dst->op_params[0];
+    GGML_ASSERT(nused >= 1);
+    const char * src0 = (const char *)dst->src[0]->data;
+    cudaStream_t stream = ctx.stream();
+    multi_add_f32_cuda(nused, dst->ne[0], dst->ne[1], dst->nb[1], dst->src[0]->nb[1], src0, (char *)dst->data, stream);
+}
+
+
+static __global__ void mul_multi_add_f32(int nused, int64_t ne0, int64_t ne1, int64_t nb1, int64_t nb01, int64_t nb02, int64_t nb11, int64_t nb12, const char * src0, const char * src1, char * dst) {
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+    int64_t k = ne0*ne1;
+    if (i >= k) {
+        return;
+    }
+    int i1 = i / ne0;
+    int i0 = i % ne0;
+    float * result = (float *)(dst + i1*nb1);
+
+    auto c0 = src0 + i1*nb02;
+    auto c1 = src1 + i1*nb12;
+
+    float sum = 0;
+    for (int j = 0; j < nused; ++j) {
+        auto x0 = (const float *)c0;
+        auto x1 = (const float *)c1;
+        sum += x0[i0] * x1[0];
+        c0 += nb01;
+        c1 += nb11;
+    }
+    result[i0] = sum;
+}
+
+static void mul_multi_add_f32_cuda(int nused, int64_t ne0, int64_t ne1, int64_t nb1, int64_t nb01, int64_t nb02, int64_t nb11, int64_t nb12,
+        const char * src0, const char * src1, char * dst, cudaStream_t stream) {
+    int64_t k = ne0 * ne1;
+    const int num_blocks = (k + CUDA_MULTI_ADD_BLOCK_SIZE - 1) / CUDA_MULTI_ADD_BLOCK_SIZE;
+    mul_multi_add_f32<<<num_blocks, CUDA_MULTI_ADD_BLOCK_SIZE, 0, stream>>>(nused, ne0, ne1, nb1, nb01, nb02, nb11, nb12, src0, src1, dst);
+}
+
+void ggml_cuda_op_mul_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    auto src0 = dst->src[0];
+    auto src1 = dst->src[1];
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->ne[0] == dst->ne[0]);
+    GGML_ASSERT(src0->ne[2] == dst->ne[1]);
+    GGML_ASSERT(src0->ne[1] == src1->ne[1]);
+    GGML_ASSERT(src0->ne[2] == src1->ne[2]);
+    GGML_ASSERT(src0->ne[3] == src1->ne[3]);
+    GGML_ASSERT(src0->ne[3] == 1);
+    GGML_ASSERT(src1->ne[0] == 1);
+
+    mul_multi_add_f32_cuda(src0->ne[1], dst->ne[0], dst->ne[1], dst->nb[1], src0->nb[1], src0->nb[2], src1->nb[1], src1->nb[2],
+            (const char *)src0->data, (const char *)src1->data, (char *)dst->data, ctx.stream());
+}
diff --git a/ggml/src/ggml-cuda/multiadd.cuh b/ggml/src/ggml-cuda/multiadd.cuh
new file mode 100644
index 00000000..f923597b
--- /dev/null
+++ b/ggml/src/ggml-cuda/multiadd.cuh
@@ -0,0 +1,14 @@
+//
+// Copyright (C) 2023-2024 The ggml authors
+// Copyright (C) 2024 Iwan Kawrakow
+// MIT license
+// SPDX-License-Identifier: MIT
+//
+
+#include "common.cuh"
+
+#define CUDA_MULTI_ADD_BLOCK_SIZE 256
+
+void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_mul_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
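A note on the CUDA implementation above: both kernels use a flat one-thread-per-output-element mapping and accumulate the nused terms in a register, so no shared memory, atomics, or inter-thread reduction is needed; this is a sensible choice when nused (here the number of active experts) is small. A worked example of the launch arithmetic, with made-up sizes:

    // Hypothetical sizes, not taken from the patch:
    int64_t ne0 = 4096, ne1 = 32;   // e.g. n_embd x n_tokens
    int64_t k = ne0 * ne1;          // 131072 dst elements, one thread each
    int num_blocks = (k + CUDA_MULTI_ADD_BLOCK_SIZE - 1) / CUDA_MULTI_ADD_BLOCK_SIZE;
    // = 512 blocks of CUDA_MULTI_ADD_BLOCK_SIZE (256) threads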
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index 090f5e86..49f22b98 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -59,25 +59,6 @@ static __global__ void fused_mul_silu_f32(const float * x, const float * y, floa
     dst[i] = x[i] * y[i] / (1.0f + expf(-x[i]));
 }
 
-static __global__ void multi_add_f32(int nused, int64_t ne0, int64_t ne1, int64_t nb1, int64_t nb01, const char * src0, char * dst) {
-    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
-    int64_t k = ne0*ne1;
-    if (i >= k) {
-        return;
-    }
-    int i1 = i / ne0;
-    int i0 = i % ne0;
-    float * result = (float *)(dst + i1*nb1);
-    const float * s = (const float *)(src0 + i1*nb01) + i0;
-    if (nused == 1) {
-        result[i0] = s[0];
-    } else {
-        float sum = s[0] + s[ne0];
-        for (int j = 2; j < nused; ++j) sum += s[j*ne0];
-        result[i0] = sum;
-    }
-}
-
 static __global__ void fused_mul_relu_f32(const float * x, const float * y, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -261,23 +242,6 @@ static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
-static void multi_add_f32_cuda(int nused, int64_t ne0, int64_t ne1, int64_t nb1, int64_t nb01, const char * src0, char * dst, cudaStream_t stream) {
-    int64_t k = ne0 * ne1;
-    const int num_blocks = (k + CUDA_MULTI_ADD_BLOCK_SIZE - 1) / CUDA_MULTI_ADD_BLOCK_SIZE;
-    multi_add_f32<<<num_blocks, CUDA_MULTI_ADD_BLOCK_SIZE, 0, stream>>>(nused, ne0, ne1, nb1, nb01, src0, dst);
-}
-
-void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->ne[2] == 1 && dst->ne[3] == 1);
-    GGML_ASSERT(dst->nb[0] == sizeof(float));
-    int nused = dst->op_params[0];
-    GGML_ASSERT(nused >= 1);
-    const char * src0 = (const char *)dst->src[0]->data;
-    cudaStream_t stream = ctx.stream();
-    multi_add_f32_cuda(nused, dst->ne[0], dst->ne[1], dst->nb[1], dst->src[0]->nb[1], src0, (char *)dst->data, stream);
-}
-
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 9434c8be..d66750be 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -4222,6 +4222,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OUT_PROD",
     "FUSED_UP_GATE",
     "MOE_FUSED_UP_GATE",
+    "MUL_MULTI_ADD",
 
     "SCALE",
     "SET",
@@ -4289,7 +4290,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
+static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -4326,6 +4327,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "X*Y",
     "X*Y1&X*Y2",
     "X*Y1&X*Y2",
+    "x1*y1+x2*y2+...",
 
     "x*v",
     "y-\\>view(x)",
@@ -4393,7 +4395,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
+static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -6103,6 +6105,31 @@ struct ggml_tensor * ggml_multi_add(
     return result;
 }
 
+struct ggml_tensor * ggml_mul_multi_add(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+
+    bool is_node = false;
+
+    GGML_ASSERT(a->ne[1] == b->ne[1]);
+    GGML_ASSERT(a->ne[2] == b->ne[2]);
+    GGML_ASSERT(a->ne[3] == b->ne[3]);
+    GGML_ASSERT(a->ne[3] == 1);
+    GGML_ASSERT(b->ne[0] == 1);
+
+    int64_t ne[GGML_MAX_DIMS] = { a->ne[0], a->ne[2], 1, 1 };
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, ne);
+
+    result->op = GGML_OP_MUL_MULTI_ADD;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
 // ggml_add_cast
 
 static struct ggml_tensor * ggml_add_cast_impl(
@@ -22319,6 +22346,10 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
             {
                 ggml_compute_forward_multi_add(params, tensor);
             } break;
+        case GGML_OP_MUL_MULTI_ADD:
+            {
+                iqk_mul_multi_add(tensor, params->ith, params->nth);
+            } break;
         case GGML_OP_ACC:
             {
                 ggml_compute_forward_acc(params, tensor);
@@ -23157,6 +23188,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ABORT("fatal error"); // TODO: implement
             }
+        case GGML_OP_MUL_MULTI_ADD:
+            {
+                GGML_ABORT("fatal error"); // TODO: implement
+            }
         case GGML_OP_CONCAT:
             {
                 GGML_ABORT("fatal error"); // TODO: implement
@@ -24241,6 +24276,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_ADD1:
         case GGML_OP_ACC:
         case GGML_OP_MULTI_ADD:
+        case GGML_OP_MUL_MULTI_ADD:
             {
                 n_tasks = n_threads;
             } break;
diff --git a/ggml/src/iqk/iqk_cpu_ops.cpp b/ggml/src/iqk/iqk_cpu_ops.cpp
index a9bffadb..a27b0282 100644
--- a/ggml/src/iqk/iqk_cpu_ops.cpp
+++ b/ggml/src/iqk/iqk_cpu_ops.cpp
@@ -12,6 +12,7 @@
 #include
 #include
 #include
+#include <cstring>
 
 namespace {
 // Playing around with group scores: use sum of probabilities in the group
@@ -409,3 +410,41 @@ void iqk_openai_experts(struct ggml_tensor * topk, struct ggml_tensor * softmax,
         for (int j = 0; j < ne0; ++j) weights[j] *= norm;
     }
 }
+
+void iqk_mul_multi_add(struct ggml_tensor * dst, int ith, int nth) {
+    auto src0 = dst->src[0];
+    auto src1 = dst->src[1];
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->ne[0] == dst->ne[0]);
+    GGML_ASSERT(src0->ne[2] == dst->ne[1]);
+    GGML_ASSERT(src0->ne[1] == src1->ne[1]);
+    GGML_ASSERT(src0->ne[2] == src1->ne[2]);
+    GGML_ASSERT(src0->ne[3] == src1->ne[3]);
+    GGML_ASSERT(src0->ne[3] == 1);
+    GGML_ASSERT(src1->ne[0] == 1);
+
+    int nrows = dst->ne[1];
+    int npt = (nrows + nth - 1)/nth;
+    int first = ith*npt;
+    int last = std::min(nrows, first + npt);
+
+    int ne01 = src0->ne[1];
+    int ne00 = src0->ne[0];
+
+    for (int ir = first; ir < last; ++ir) {
+        auto c0 = (const char *)src0->data + ir*src0->nb[2];
+        auto c1 = (const char *)src1->data + ir*src1->nb[2];
+        auto cy = (      char *) dst->data + ir* dst->nb[1];
+        std::memset(cy, 0, ne00*sizeof(float));
+        for (int j = 0; j < ne01; ++j) {
+            auto x0 = (const float *)c0;
+            auto x1 = (const float *)c1;
+            auto y  = (      float *)cy;
+            for (int k = 0; k < ne00; ++k) y[k] += x0[k] * x1[0];
+            c0 += src0->nb[1];
+            c1 += src1->nb[1];
+        }
+    }
+}
diff --git a/ggml/src/iqk/iqk_cpu_ops.h b/ggml/src/iqk/iqk_cpu_ops.h
index e00157a5..833eb9a5 100644
--- a/ggml/src/iqk/iqk_cpu_ops.h
+++ b/ggml/src/iqk/iqk_cpu_ops.h
@@ -26,6 +26,8 @@ void iqk_glm45moe_experts(struct ggml_tensor * dst, struct ggml_tensor * topk_vi
 
 void iqk_openai_experts(struct ggml_tensor * topk, struct ggml_tensor * softmax, int ith, int nth);
 
+void iqk_mul_multi_add(struct ggml_tensor * dst, int ith, int nth);
+
 #ifdef __cplusplus
 }
 #endif
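To spell out what the CUDA kernel and iqk_mul_multi_add both compute, here is a scalar reference for contiguous f32 tensors (a sketch for illustration; the function name and the contiguity assumption are mine, the real implementations work on the strided ggml tensors):

    // src0: [ne00, ne01, ne02], src1: [1, ne01, ne02], dst: [ne00, ne02]
    // dst(i0, i2) = sum over i1 of src0(i0, i1, i2) * src1(0, i1, i2)
    static void mul_multi_add_ref(const float * src0, const float * src1, float * dst,
                                  int64_t ne00, int64_t ne01, int64_t ne02) {
        for (int64_t i2 = 0; i2 < ne02; ++i2) {
            for (int64_t i0 = 0; i0 < ne00; ++i0) {
                float sum = 0.0f;
                for (int64_t i1 = 0; i1 < ne01; ++i1) {
                    sum += src0[(i2*ne01 + i1)*ne00 + i0] * src1[i2*ne01 + i1];
                }
                dst[i2*ne00 + i0] = sum;
            }
        }
    }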
diff --git a/include/llama.h b/include/llama.h
index d24de230..504f9280 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -422,6 +422,7 @@ extern "C" {
         bool fused_moe_up_gate; // whether to use fused MoE up/gate op
         bool grouped_expert_routing; // whether to use grouped expert routing (BailingMoeV2 arch)
         bool fused_up_gate; // whether to use fused up/gate op [EXPERIMENTAL]
+        bool fused_mmad; // whether to use fused mul+multi_add op [EXPERIMENTAL]
         int min_experts;
         float thresh_experts;
         bool only_active_experts;
diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index 5a5e1f27..b8c7393d 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -50,6 +50,7 @@ llm_build_context::llm_build_context(
     fused_moe_up_gate(cparams.fused_moe_up_gate),
     grouped_expert_routing(cparams.grouped_expert_routing),
     fused_up_gate (cparams.fused_up_gate),
+    fused_mmad (cparams.fused_mmad),
     min_experts (cparams.min_experts),
     thresh_experts (cparams.thresh_experts),
     pooling_type (cparams.pooling_type),
@@ -941,6 +942,11 @@ llm_expert_gating_func_type gating_op,
     }
 
     if (!weight_before_ffn) {
+        if (lctx.cparams.fused_mmad) {
+            experts = ggml_mul_multi_add(ctx, experts, weights);
+            cb(experts, "ffn_moe_weighted", il);
+            return experts;
+        }
         experts = ggml_mul(ctx, experts, weights);
         cb(experts, "ffn_moe_weighted", il);
     }
diff --git a/src/llama-build-context.h b/src/llama-build-context.h
index 2381a656..f3eeae75 100644
--- a/src/llama-build-context.h
+++ b/src/llama-build-context.h
@@ -80,6 +80,7 @@ struct llm_build_context {
     const bool fused_moe_up_gate;
     const bool grouped_expert_routing;
     const bool fused_up_gate;
+    const bool fused_mmad;
     const int min_experts;
     const float thresh_experts;
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index cbfb4949..528184c8 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -33,6 +33,7 @@ struct llama_cparams {
     bool fused_moe_up_gate;
     bool grouped_expert_routing;
     bool fused_up_gate;
+    bool fused_mmad;
     int min_experts;
     float thresh_experts;
diff --git a/src/llama.cpp b/src/llama.cpp
index 57f2b75c..fe3bfbce 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3756,6 +3756,7 @@ struct llama_context_params llama_context_default_params() {
         /*.fused_moe_up_gate =*/ false,
         /*.grouped_expert_routing =*/ false,
         /*.fused_up_gate =*/ true,
+        /*.fused_mmad =*/ true,
         /*.min_experts =*/ -1,
         /*.thtesh_experts =*/ 0.0f,
         /*.only_active_experts =*/ false,
@@ -3966,6 +3967,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.fused_moe_up_gate= params.fused_moe_up_gate;
     cparams.grouped_expert_routing = params.grouped_expert_routing;
     cparams.fused_up_gate = params.fused_up_gate;
+    cparams.fused_mmad = params.fused_mmad;
     cparams.min_experts = params.min_experts;
     cparams.thresh_experts = params.thresh_experts;
 
@@ -4047,6 +4049,7 @@ struct llama_context * llama_new_context_with_model(
     LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate);
     LLAMA_LOG_INFO("%s: grouped er = %d\n", __func__, cparams.grouped_expert_routing);
     LLAMA_LOG_INFO("%s: fused_up_gate = %d\n", __func__, cparams.fused_up_gate);
+    LLAMA_LOG_INFO("%s: fused_mmad = %d\n", __func__, cparams.fused_mmad);
     LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts);
     LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
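Graph-level effect of the new flag, for reviewers: in llm_build_moe_ffn the fused path replaces both the element-wise ggml_mul of the expert outputs with their routing weights and the reduction over the used experts that the unfused path still performs later in the builder (outside this hunk) with a single node. Roughly (a sketch, not the literal surrounding code):

    // unfused: weight each expert's rows, reduce over experts afterwards
    experts = ggml_mul(ctx, experts, weights);           // [n_embd, n_used, n_tokens]
    // ... summation over the n_used dimension happens further down ...

    // fused (this patch): one node yields the already-reduced result
    experts = ggml_mul_multi_add(ctx, experts, weights); // [n_embd, n_tokens]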