From 9e07839ba3077c1e5eda99895a418525ae14cea8 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 21 Jan 2026 07:53:18 +0200 Subject: [PATCH 1/6] Correct GLM-4.7-Flash gating function (#1174) * Correct GLM-4.7-Flash gating function * This is better --- src/llama-hparams.cpp | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 79e0a1d0..7b889c9c 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -27,6 +27,14 @@ const char * llama_hparams::rope_scaling_type_name(llama_rope_scaling_type type) return LLAMA_ROPE_SCALING_TYPES.at(type); } +static inline const char * llm_expert_gating_func_name(llm_expert_gating_func_type type) { + switch (type) { + case LLM_EXPERT_GATING_FUNC_SOFTMAX: return "softmax"; + case LLM_EXPERT_GATING_FUNC_SIGMOID: return "sigmoid"; + case LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT: return "weight"; + default: return "none"; + } +} void llm_load_hparams( @@ -778,11 +786,17 @@ void llm_load_hparams( ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + hparams.expert_gating_func = LLM_EXPERT_GATING_FUNC_TYPE_NONE; ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - if (hparams.expert_gating_func == 0) { - // for compatibility with existing DeepSeek V2 and V2.5 GGUFs - // that have no expert_gating_func model parameter set - hparams.expert_gating_func = LLM_EXPERT_GATING_FUNC_SOFTMAX; + if (hparams.expert_gating_func == LLM_EXPERT_GATING_FUNC_TYPE_NONE) { + // Some models don't have the experts gating function recorded in the GGUF + // Hence, we make the LLM_KV_EXPERT_GATING_FUNC entry optional, and set here if missing. 
+ // DeepSeek models normally have softmax as gating function, but there is GLM-4.7-Flash now + // (identified via number of layers being 47 or 48), which uses sigmoid. + hparams.expert_gating_func = hparams.n_layer == 47 || hparams.n_layer == 48 ? + LLM_EXPERT_GATING_FUNC_SIGMOID : LLM_EXPERT_GATING_FUNC_SOFTMAX; + printf("================= Missing experts gating function -> set to %s\n", + llm_expert_gating_func_name(llm_expert_gating_func_type(hparams.expert_gating_func))); } ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false); From 987651e54c6432defb8342d9c8ba1bb5467c4dd5 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 21 Jan 2026 09:12:40 +0200 Subject: [PATCH 2/6] Make comments more precise when experts gating function is missing (#1175) --- src/llama-hparams.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 7b889c9c..24d07a9a 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -789,10 +789,13 @@ void llm_load_hparams( hparams.expert_gating_func = LLM_EXPERT_GATING_FUNC_TYPE_NONE; ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); if (hparams.expert_gating_func == LLM_EXPERT_GATING_FUNC_TYPE_NONE) { - // Some models don't have the experts gating function recorded in the GGUF + // Older DeepSeek models from the 2.0/2.5 series may not have the experts gating function recorded in the GGUF. + // Such models use SOFTMAX as the experts gating function. + // The new (new as of this commit) GLM-4.7-Flash may also be missing the experts gating function. + // GLM-4.7-Flash uses SIGMOID as the experts gating function. // Hence, we make the LLM_KV_EXPERT_GATING_FUNC entry optional, and set here if missing. - // DeepSeek models normally have softmax as gating function, but there is GLM-4.7-Flash now - // (identified via number of layers being 47 or 48), which uses sigmoid. 
+ // We distinguish between GLM-4.7-Flash and DeepSeek-2/2.5 models by the number of layers. + // GLM-4.7-Flash has 47 layers (or 48, if an MTP layer is included in the GGUF). hparams.expert_gating_func = hparams.n_layer == 47 || hparams.n_layer == 48 ? LLM_EXPERT_GATING_FUNC_SIGMOID : LLM_EXPERT_GATING_FUNC_SOFTMAX; printf("================= Missing experts gating function -> set to %s\n", From 77c18acc90acba215d7f7979a1004632d9aa9aa4 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 22 Jan 2026 12:25:05 +0200 Subject: [PATCH 3/6] Fix non-contiguous batched cuBLAS (#1178) --- ggml/src/ggml-cuda.cu | 243 ++++++++++++++++++++++----------- ggml/src/ggml-cuda/convert.cu | 66 +++++++++ ggml/src/ggml-cuda/convert.cuh | 22 +++ ggml/src/ggml-cuda/dmmv.cu | 3 +- 4 files changed, 249 insertions(+), 85 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index eef83572..a306ee26 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -1836,7 +1836,7 @@ static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml } static __global__ void k_compute_batched_ptrs( - const half * src0_as_f16, const half * src1_as_f16, char * dst, + const void * src0_as_f16, const void * src1_as_f16, char * dst, const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23, @@ -1844,86 +1844,155 @@ static __global__ void k_compute_batched_ptrs( size_t nb12, size_t nb13, size_t nbd2, size_t nbd3, int64_t r2, int64_t r3) { - int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x; - int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y; + const int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x; + const int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y; if (i13 >= ne13 || i12 >= ne12) { return; } - int64_t i03 = i13 / r3; - int64_t i02 = i12 / r2; + const int64_t i03 = i13 / r3; + const int64_t i02 = i12 / r2; ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03; ptrs_src[1*ne23 + i12 + i13*ne12] = 
(const char *) src1_as_f16 + i12*nb12 + i13*nb13; ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3; } -static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +// Type traits for mapping ggml types to CUDA/cuBLAS types +template +struct batched_mul_mat_traits; + +template<> +struct batched_mul_mat_traits { + using cuda_type = float; + static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; + static inline const cudaDataType_t data_type = CUDA_R_32F; + static inline const ggml_type ggml_type_val = GGML_TYPE_F32; + static inline const float alpha = 1.0f; + static inline const float beta = 0.0f; + static inline const void* get_alpha() { static const float val = alpha; return &val; } + static inline const void* get_beta() { static const float val = beta; return &val; } + static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp32_nc_cuda(src_type); } +}; + +template<> +struct batched_mul_mat_traits { + using cuda_type = nv_bfloat16; + static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; + static inline const cudaDataType_t data_type = CUDA_R_16BF; + static inline const ggml_type ggml_type_val = GGML_TYPE_BF16; + static inline const float alpha = 1.0f; + static inline const float beta = 0.0f; + static inline const void* get_alpha() { static const float val = alpha; return &val; } + static inline const void* get_beta() { static const float val = beta; return &val; } + static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_bf16_nc_cuda(src_type); } +}; + +template<> +struct batched_mul_mat_traits { + using cuda_type = half; + static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; + static inline const cudaDataType_t data_type = CUDA_R_16F; + static inline const ggml_type ggml_type_val = GGML_TYPE_F16; + static inline const half alpha = 1.0; + static 
inline const half beta = 0.0; + static inline const void* get_alpha() { static const half val = alpha; return &val; } + static inline const void* get_beta() { static const half val = beta; return &val; } + static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); } +}; + +template +static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + using traits = batched_mul_mat_traits; + using cuda_t = typename traits::cuda_type; + GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); + GGML_ASSERT(src0->type == src0_type); + GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer)); - GGML_ASSERT(src0->type == GGML_TYPE_F16); + // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst. + // As long as dst is contiguous this does not matter though. GGML_TENSOR_BINARY_OP_LOCALS const int64_t ne_dst = ggml_nelements(dst); - cudaStream_t main_stream = ctx.stream(); - CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream)); - void * src0_ddq = src0->data; - half * src0_f16 = (half *) src0_ddq; - float * src1_ddf = (float *) src1->data; - float * dst_ddf = (float *) dst->data; + float * dst_ddf = (float *) dst->data; + const size_t ts_src1 = ggml_type_size(src1->type); + GGML_ASSERT(nb10 == ts_src1); + int64_t s11 = nb11 / ts_src1; + int64_t s12 = nb12 / ts_src1; + int64_t s13 = nb13 / ts_src1; - // convert src1 to fp16 - ggml_cuda_pool_alloc src1_f16_alloc(ctx.pool()); - if (src1->type != GGML_TYPE_F16) { - const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); + const cuda_t * src0_ptr = nullptr; + const cuda_t * src1_ptr = nullptr; + + ggml_cuda_pool_alloc src0_alloc(ctx.pool()); + ggml_cuda_pool_alloc src1_alloc(ctx.pool()); + + bool is_src0_cont_2 = ggml_is_contiguous_2(src0); + bool is_src1_cont_2 = 
ggml_is_contiguous_2(src1); + + // Handle src0 + src0_ptr = (const cuda_t *) src0->data; + + // Handle src1 - convert if necessary + if (src1->type == src0_type) { + src1_ptr = (const cuda_t *) src1->data; + } else { + // Convert src1 to target type using traits conversion functions const int64_t ne_src1 = ggml_nelements(src1); - src1_f16_alloc.alloc(ne_src1); - GGML_ASSERT(to_fp16_cuda != nullptr); - to_fp16_cuda(src1_ddf, src1_f16_alloc.get(), ggml_nrows(src1), src1->ne[0], main_stream); + src1_alloc.alloc(ne_src1); + + const auto convert_func = traits::get_nc_converter(src1->type); + GGML_ASSERT(convert_func != nullptr); + convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream); + src1_ptr = src1_alloc.get(); + s11 = ne10; + s12 = ne11*s11; + s13 = ne12*s12; + + is_src1_cont_2 = true; } - half * src1_f16 = src1->type == GGML_TYPE_F16 ? (half *) src1_ddf : src1_f16_alloc.get(); - ggml_cuda_pool_alloc dst_f16(ctx.pool()); + // Setup destination buffer + ggml_cuda_pool_alloc dst_temp(ctx.pool()); char * dst_t; - - cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F; - cudaDataType_t cu_data_type = CUDA_R_16F; - - // dst strides size_t nbd2 = dst->nb[2]; size_t nbd3 = dst->nb[3]; - const half alpha_f16 = 1.0f; - const half beta_f16 = 0.0f; - + cublasComputeType_t cu_compute_type = traits::compute_type; + cudaDataType_t cu_data_type = traits::data_type; + cudaDataType_t cu_data_type_a = traits::data_type; + cudaDataType_t cu_data_type_b = traits::data_type; + const void * alpha = traits::get_alpha(); + const void * beta = traits::get_beta(); const float alpha_f32 = 1.0f; - const float beta_f32 = 0.0f; - - const void * alpha = &alpha_f16; - const void * beta = &beta_f16; + const float beta_f32 = 0.0f; if (dst->op_params[0] == GGML_PREC_DEFAULT) { - dst_t = (char *) dst_f16.alloc(ne_dst); - - nbd2 /= sizeof(float) / sizeof(half); - nbd3 /= sizeof(float) / sizeof(half); + if constexpr (src0_type == GGML_TYPE_F32) { + dst_t = 
(char *) dst_ddf; // Direct F32 output + } else { + dst_t = (char *) dst_temp.alloc(ne_dst); + nbd2 /= sizeof(float) / sizeof(cuda_t); + nbd3 /= sizeof(float) / sizeof(cuda_t); + } } else { dst_t = (char *) dst_ddf; - cu_compute_type = CUBLAS_COMPUTE_32F; - cu_data_type = CUDA_R_32F; - + cu_data_type = CUDA_R_32F; alpha = &alpha_f32; - beta = &beta_f32; + beta = &beta_f32; } + int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + GGML_ASSERT(ne12 % ne02 == 0); GGML_ASSERT(ne13 % ne03 == 0); @@ -1931,77 +2000,85 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co const int64_t r2 = ne12/ne02; const int64_t r3 = ne13/ne03; -#if 0 - // use cublasGemmEx - { - for (int i13 = 0; i13 < ne13; ++i13) { - for (int i12 = 0; i12 < ne12; ++i12) { - int i03 = i13 / r3; - int i02 = i12 / r2; + if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) { + // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3: + const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00; + const int64_t smb = ne12 == 1 ? 
s13 : s12; - CUBLAS_CHECK( - cublasGemmEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N, - ne01, ne11, ne10, - alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half), - (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float), - beta, ( char *) dst_t + i12*nbd2 + i13*nbd3, cu_data_type, ne01, - cu_compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - } - } - } -#else -#ifdef GGML_USE_MUSA - GGML_ASSERT(false); -#else // !GGML_USE_MUSA - if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { // there is no broadcast and src0, src1 are contiguous across dims 2, 3 // use cublasGemmStridedBatchedEx CUBLAS_CHECK( cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, - alpha, (const char *) src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA - (const char *) src1_f16, CUDA_R_16F, nb11/nb10, nb12/nb10, // strideB - beta, ( char *) dst_t, cu_data_type, ne01, nb2/nb0, // strideC + alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma, // strideA + src1_ptr, cu_data_type_b, s11, smb, // strideB + beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC ne12*ne13, cu_compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } else { // use cublasGemmBatchedEx - const int ne23 = ne12*ne13; + const int64_t ne23 = ne12*ne13; ggml_cuda_pool_alloc ptrs_src(ctx.pool(), 2*ne23); ggml_cuda_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23); - dim3 block_dims(ne13, ne12); - k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>( - src0_f16, src1_f16, dst_t, + size_t src1_stride_size = sizeof(cuda_t); + + const int threads_x = 16; + const int threads_y = 16; + dim3 block_dims(threads_x, threads_y); + + dim3 grid_dims( + (ne13 + threads_x - 1) / threads_x, + (ne12 + threads_y - 1) / threads_y + ); + k_compute_batched_ptrs<<>>( + src0_ptr, src1_ptr, dst_t, ptrs_src.get(), ptrs_dst.get(), ne12, ne13, ne23, nb02, nb03, - src1->type == 
GGML_TYPE_F16 ? nb12 : nb12/2, - src1->type == GGML_TYPE_F16 ? nb13 : nb13/2, + (src1->type == src0_type) ? nb12 : s12*src1_stride_size, + (src1->type == src0_type) ? nb13 : s13*src1_stride_size, nbd2, nbd3, r2, r3); + CUDA_CHECK(cudaGetLastError()); CUBLAS_CHECK( cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, - alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00, - (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, nb11/nb10, - beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01, + alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00, + (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11, + beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0, ne23, cu_compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } -#endif // GGML_USE_MUSA -#endif - if (dst->op_params[0] == GGML_PREC_DEFAULT) { - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); - to_fp32_cuda(dst_f16.get(), dst_ddf, ggml_nrows(dst), dst->ne[0], main_stream); + // Convert output back to F32 if needed + if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) { + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val); + to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, 1, main_stream); + } +} + +static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32); + + switch (src0->type) { + case GGML_TYPE_F32: + ggml_cuda_mul_mat_batched_cublas_impl(ctx, src0, src1, dst); + break; + case GGML_TYPE_BF16: + ggml_cuda_mul_mat_batched_cublas_impl(ctx, src0, src1, dst); + break; + case GGML_TYPE_F16: + ggml_cuda_mul_mat_batched_cublas_impl(ctx, src0, src1, dst); + break; + default: + GGML_ABORT("Unsupported type"); } } diff --git 
a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 689613f5..05e0ba19 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -2122,3 +2122,69 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return nullptr; } } + +// non-contuigous conversions + +template +static __global__ void convert_unary( + const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t s01, const int64_t s02, const int64_t s03) { + const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; + + if (i00 >= ne00) { + return; + } + + const int64_t i01 = blockIdx.y; + const int64_t i02 = blockIdx.z % ne02; + const int64_t i03 = blockIdx.z / ne02; + + const src_t * x = (const src_t *) vx; + + const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00; + const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00; + y[iy] = ggml_cuda_cast(x[ix]); +} + +template +static void convert_unary_cuda(const void * vx, dst_t * y, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) { + const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03); + convert_unary<<>> + (vx, y, ne00, ne01, ne02, s01, s02, s03); +} + +to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_F32: + return convert_unary_cuda; + case GGML_TYPE_BF16: + return convert_unary_cuda; + default: + return nullptr; + } +} + +to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_F32: + return convert_unary_cuda; + case GGML_TYPE_F16: + return convert_unary_cuda; + default: + return nullptr; + } +} + +to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_F16: + return convert_unary_cuda; + case GGML_TYPE_BF16: + return convert_unary_cuda; + 
default: + return nullptr; + } +} + diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index 9e131d2b..a519980c 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -22,6 +22,19 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type); to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type); +template +using to_t_nc_cuda_t = void (*)(const void * x, T * y, + int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, + int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream); + +typedef to_t_nc_cuda_t to_fp32_nc_cuda_t; +typedef to_t_nc_cuda_t to_fp16_nc_cuda_t; +typedef to_t_nc_cuda_t to_bf16_nc_cuda_t; + +to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type); +to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type); +to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type); + template __host__ __device__ inline dst_t ggml_cuda_cast(src_t x) { if constexpr (std::is_same_v) { @@ -30,6 +43,15 @@ template return __float2bfloat16(float(x)); } else if constexpr(std::is_same_v) { return __bfloat162float(x); + } else if constexpr(std::is_same_v && std::is_same_v) { + return __float22half2_rn(x); + } else if constexpr(std::is_same_v && std::is_same_v) { + // bypass compile error on cuda 12.0.1 +#ifdef GGML_USE_HIPBLAS + return __float22bfloat162_rn(x); +#else + return {x.x, x.y}; +#endif // GGML_USE_HIP } else if constexpr(std::is_same_v) { return int32_t(x); } else { diff --git a/ggml/src/ggml-cuda/dmmv.cu b/ggml/src/ggml-cuda/dmmv.cu index 50e6458d..7a5e3841 100644 --- a/ggml/src/ggml-cuda/dmmv.cu +++ b/ggml/src/ggml-cuda/dmmv.cu @@ -940,6 +940,5 @@ bool ggml_cuda_dmmv_type_supported(ggml_type src0_type) { src0_type == GGML_TYPE_Q8_0 || src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q4_K || src0_type == GGML_TYPE_Q5_K || src0_type == GGML_TYPE_Q6_K || - src0_type == GGML_TYPE_IQ2_KT || src0_type == GGML_TYPE_IQ3_KT || src0_type == GGML_TYPE_IQ4_KT || - src0_type == 
GGML_TYPE_F16; + src0_type == GGML_TYPE_IQ2_KT || src0_type == GGML_TYPE_IQ3_KT || src0_type == GGML_TYPE_IQ4_KT; } From 1cb8cd534f8b1214c8da483fd44579752c32b8eb Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 22 Jan 2026 12:26:23 +0200 Subject: [PATCH 4/6] Fix build failure when OpenMP is not available (#1171) --- ggml/src/ggml-backend.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index d99d7022..38fd5692 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -2165,6 +2165,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s for (auto & s : sched->statuses) s = GGML_STATUS_SUCCESS; + int first_reduce = -1; bool work_done = false; #ifdef GGML_USE_OPENMP //This may not be available in old OpenMP versions @@ -2185,7 +2186,6 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s } } } - int first_reduce = -1; for (int i = 0; i < sched->n_splits; i++) { auto split = &sched->splits[i]; if (split->graph.n_nodes == 1 && split->graph.nodes[0]->op == GGML_OP_REDUCE) { From 101fe54797e8029b727af5f522aba49848c3a7d8 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 22 Jan 2026 12:28:11 +0200 Subject: [PATCH 5/6] CUDA graphs with tensor overrides (#1172) * Use GUDA graphs also when theretensor overrides * Change graph key --- ggml/src/ggml-cuda.cu | 140 +++++++++++++++------------------- ggml/src/ggml-cuda/common.cuh | 8 +- ggml/src/ggml-cuda/cpy.cu | 30 ++++---- ggml/src/ggml-cuda/graph.cuh | 2 + 4 files changed, 83 insertions(+), 97 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index a306ee26..298d3214 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3740,11 +3740,23 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { #ifdef USE_CUDA_GRAPH -static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, 
ggml_cgraph * cgraph, - bool use_cuda_graph) { +static inline const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) { + return cgraph->nodes[0]; +} + +static inline ggml_cuda_graph * ggml_cuda_get_graph(ggml_backend_cuda_context & ctx, const void * key) { + auto & graph = ctx.cuda_graphs[key]; + if (!graph) { + graph = std::make_unique(); + } + return graph.get(); +} + +static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_cuda_graph * graph, ggml_cgraph * cgraph, + bool use_cuda_graph, cudaStream_t stream) { // Loop over nodes in GGML graph to obtain info needed for CUDA graph - cuda_ctx->cuda_graph->cpy_dest_ptrs.clear(); + graph->cpy_dest_ptrs.clear(); const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected"; const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj"; @@ -3755,16 +3767,7 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; - if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { - continue; - } - - if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) { - use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture -#ifndef NDEBUG - GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer %s\n", __func__, node->src[0]->name); -#endif - } + if (ggml_is_noop(node)) continue; if (node->op == GGML_OP_MUL_MAT_ID && (node->ne[2] != 1 || node->src[2]->ne[0] != 1)) { use_cuda_graph = false; // This node type is not supported by CUDA graph capture @@ -3812,7 +3815,7 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud // Store the pointers which are updated for each token, such that these can be sent // to the device and accessed using indirection from CUDA 
graph - cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data); + graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data); // store a pointer to each copy op CUDA kernel to identify it later void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]); @@ -3829,9 +3832,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud } if (use_cuda_graph) { - cuda_ctx->cuda_graph->use_cpy_indirection = true; + graph->use_cpy_indirection = true; // copy pointers to GPU so they can be accessed via indirection within CUDA graph - ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream()); + ggml_cuda_cpy_dest_ptrs_copy(graph, graph->cpy_dest_ptrs.data(), graph->cpy_dest_ptrs.size(), stream); } return use_cuda_graph; @@ -3888,18 +3891,18 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra return true; } -static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { +static bool is_cuda_graph_update_required(ggml_cuda_graph * graph, ggml_cgraph * cgraph) { bool cuda_graph_update_required = false; - if (cuda_ctx->cuda_graph->instance == nullptr) { + if (graph->instance == nullptr) { cuda_graph_update_required = true; } // Check if the graph size has changed - if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) { + if (graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) { cuda_graph_update_required = true; - cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes); + graph->ggml_graph_properties.resize(cgraph->n_nodes); } // Loop over nodes in GGML graph to determine if CUDA graph update is required @@ -3907,26 +3910,26 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, for (int i = 0; i < cgraph->n_nodes; i++) { bool has_matching_properties = true; if 
(!cuda_graph_update_required) { - has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); + has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &graph->ggml_graph_properties[i]); } if (!has_matching_properties) { cuda_graph_update_required = true; } - set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); + set_ggml_graph_node_properties(cgraph->nodes[i], &graph->ggml_graph_properties[i]); } return cuda_graph_update_required; } -static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { +static void update_cuda_graph_executable(ggml_cuda_graph * graph) { #if CUDART_VERSION >= 12000 cudaGraphExecUpdateResultInfo result_info; - cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); + cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &result_info); #else cudaGraphNode_t errorNode; cudaGraphExecUpdateResult result_info; - cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info); + cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &errorNode, &result_info); #endif // CUDART_VERSION >= 12000 if (stat == cudaErrorGraphExecUpdateFailure) { @@ -3937,9 +3940,9 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { // The pre-existing graph exec cannot be updated due to violated constraints // so instead clear error and re-instantiate (void)cudaGetLastError(); - CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance)); - cuda_ctx->cuda_graph->instance = nullptr; - CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); + CUDA_CHECK(cudaGraphExecDestroy(graph->instance)); + graph->instance = nullptr; + 
CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0)); } else { GGML_ASSERT(stat == cudaSuccess); } @@ -3952,6 +3955,10 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx // TODO [[maybe_unused]] const bool integrated = false; //ggml_cuda_info().devices[cuda_ctx->device].integrated; +#ifdef USE_CUDA_GRAPH + auto graph = ggml_cuda_get_graph(*cuda_ctx, ggml_cuda_graph_get_key(cgraph)); +#endif + //printf("======================== %s: graph with %d nodes on device %d. time = %ld\n", __func__, cgraph->n_nodes, cuda_ctx->device, ggml_time_us()); while (!graph_evaluated_or_captured) { // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. @@ -3961,34 +3968,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; - if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { - continue; - } - -#if 0 - static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); - if (!disable_fusion) { - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) { - ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]); - i++; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) { - i += 2; - ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node); - continue; - } - } -#endif -#ifndef NDEBUG - //assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); - //for (int j = 0; j < GGML_MAX_SRC; j++) { - // if (node->src[j] != nullptr) { - // assert(node->src[j]->buffer); - // } - //} -#endif // NDEBUG + if (ggml_is_noop(node)) continue; bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, cgraph, i); if (!ok) { @@ -3999,12 +3979,12 
@@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx } #ifdef USE_CUDA_GRAPH if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture - if (cuda_ctx->cuda_graph->graph != nullptr) { - CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph)); - cuda_ctx->cuda_graph->graph = nullptr; + if (graph->graph != nullptr) { + CUDA_CHECK(cudaGraphDestroy(graph->graph)); + graph->graph = nullptr; } - CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph)); + CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph->graph)); graph_evaluated_or_captured = true; // CUDA graph has been captured std::lock_guard lock(ggml_cuda_lock); @@ -4017,14 +3997,14 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx } if (use_cuda_graph) { - if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph. - CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); + if (graph->instance == nullptr) { // Create executable graph from captured graph. 
+ CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0)); } if (cuda_graph_update_required) { // Update graph executable - update_cuda_graph_executable(cuda_ctx); + update_cuda_graph_executable(graph); } // Launch graph - CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); + CUDA_CHECK(cudaGraphLaunch(graph->instance, cuda_ctx->stream())); #else graph_evaluated_or_captured = true; #endif // USE_CUDA_GRAPH @@ -4037,6 +4017,8 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t ggml_cuda_set_device(cuda_ctx->device); #ifdef USE_CUDA_GRAPH + cuda_ctx->cur_graph = nullptr; + static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly, @@ -4044,16 +4026,14 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t // Also disable for multi-gpu for now. 
TO DO investigate bool use_cuda_graph = !disable_cuda_graphs_due_to_env && cuda_ctx->use_cuda_graph; - // Objects required for CUDA Graph - if (cuda_ctx->cuda_graph == nullptr) { - cuda_ctx->cuda_graph.reset(new ggml_cuda_graph()); - } + auto graph = ggml_cuda_get_graph(*cuda_ctx, ggml_cuda_graph_get_key(cgraph)); + cuda_ctx->cur_graph = graph; bool cuda_graph_update_required = false; - if (use_cuda_graph && cuda_ctx->cuda_graph->graph == nullptr) { + if (use_cuda_graph && graph->graph == nullptr) { if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) { - cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true; + graph->disable_due_to_gpu_arch = true; #ifndef NDEBUG GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__); #endif @@ -4061,26 +4041,26 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t } if (use_cuda_graph && ( - cuda_ctx->cuda_graph->disable_due_to_gpu_arch || - cuda_ctx->cuda_graph->disable_due_to_too_many_updates || - cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture)) { + graph->disable_due_to_gpu_arch || + graph->disable_due_to_too_many_updates || + graph->disable_due_to_failed_graph_capture)) { use_cuda_graph = false; } if (use_cuda_graph) { - cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph); + cuda_graph_update_required = is_cuda_graph_update_required(graph, cgraph); - use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph); + use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(graph, cgraph, use_cuda_graph, cuda_ctx->stream()); // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. 
if (use_cuda_graph && cuda_graph_update_required) { - cuda_ctx->cuda_graph->number_consecutive_updates++; + graph->number_consecutive_updates++; } else { - cuda_ctx->cuda_graph->number_consecutive_updates = 0; + graph->number_consecutive_updates = 0; } - if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) { - cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true; + if (graph->number_consecutive_updates >= 4) { + graph->disable_due_to_too_many_updates = true; #ifndef NDEBUG GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); #endif @@ -4098,7 +4078,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t } if (!use_cuda_graph) { - cuda_ctx->cuda_graph->use_cpy_indirection = false; + graph->use_cpy_indirection = false; } #else diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index ba26db5b..45021c3e 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -25,6 +25,7 @@ #include #include #include +#include #if defined(GGML_USE_HIPBLAS) #include "vendors/hip.h" @@ -849,13 +850,16 @@ struct ggml_backend_cuda_context { cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } }; cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; - std::unique_ptr cuda_graph; - int fusion = GGML_CUDA_FUSION; int offload_batch_size = GGML_CUDA_MIN_BATCH_OFFLOAD; int mmq_id_thresh = 32; #ifdef USE_CUDA_GRAPH bool use_cuda_graph = true; + + ggml_cuda_graph * cur_graph = nullptr; + + std::unordered_map> cuda_graphs; + #endif void * copy_buffer = nullptr; diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index b5fe2d87..7b7e6b26 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -542,9 +542,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg char ** dest_ptrs_d = nullptr; int graph_cpynode_index = -1; #if defined(GGML_CUDA_USE_GRAPHS) || 
defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS) - if(!disable_indirection_for_this_node && ctx.cuda_graph && ctx.cuda_graph->use_cpy_indirection) { - dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d; - graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index; + if(!disable_indirection_for_this_node && ctx.cur_graph && ctx.cur_graph->use_cpy_indirection) { + dest_ptrs_d = ctx.cur_graph->dest_ptrs_d; + graph_cpynode_index = ctx.cur_graph->graph_cpynode_index; } #else GGML_UNUSED(disable_indirection_for_this_node); @@ -651,8 +651,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_type_name(src0->type), ggml_type_name(src1->type)); } #if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS) - if(!disable_indirection_for_this_node && ctx.cuda_graph && ctx.cuda_graph->use_cpy_indirection) { - ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index; + if(!disable_indirection_for_this_node && ctx.cur_graph && ctx.cur_graph->use_cpy_indirection) { + ctx.cur_graph->graph_cpynode_index = graph_cpynode_index; } #else GGML_UNUSED(disable_indirection_for_this_node); @@ -796,9 +796,9 @@ bool ggml_cuda_cpy_2(ggml_backend_cuda_context & ctx, const ggml_tensor * src1, char ** dest_ptrs = nullptr; int graph_cpynode_index = -1; #if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS) - if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection) { - dest_ptrs = ctx.cuda_graph->dest_ptrs_d; - graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index; + if(ctx.cur_graph->use_cpy_indirection && !disable_indirection) { + dest_ptrs = ctx.cur_graph->dest_ptrs_d; + graph_cpynode_index = ctx.cur_graph->graph_cpynode_index; } #else GGML_UNUSED(disable_indirection); @@ -813,8 +813,8 @@ bool ggml_cuda_cpy_2(ggml_backend_cuda_context & ctx, const ggml_tensor * src1, } #if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS) - 
if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection) { - ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index; + if(ctx.cur_graph->use_cpy_indirection && !disable_indirection) { + ctx.cur_graph->graph_cpynode_index = graph_cpynode_index; } #endif return true; @@ -859,9 +859,9 @@ bool ggml_cuda_concat_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * c char ** dest_ptrs = nullptr; int graph_cpynode_index = -1; #if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS) - if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection) { - dest_ptrs = ctx.cuda_graph->dest_ptrs_d; - graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index; + if(ctx.cur_graph->use_cpy_indirection && !disable_indirection) { + dest_ptrs = ctx.cur_graph->dest_ptrs_d; + graph_cpynode_index = ctx.cur_graph->graph_cpynode_index; } #endif @@ -874,8 +874,8 @@ bool ggml_cuda_concat_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * c } #if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS) - if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection) { - ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index; + if(ctx.cur_graph->use_cpy_indirection && !disable_indirection) { + ctx.cur_graph->graph_cpynode_index = graph_cpynode_index; } #endif return true; diff --git a/ggml/src/ggml-cuda/graph.cuh b/ggml/src/ggml-cuda/graph.cuh index ed032aa5..1b3a30ff 100644 --- a/ggml/src/ggml-cuda/graph.cuh +++ b/ggml/src/ggml-cuda/graph.cuh @@ -1,5 +1,7 @@ #pragma once +#include "ggml.h" + struct ggml_graph_node_properties { void * node_address; ggml_op node_op; From 573e23679dae6524d683288a34a9be102f53918f Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 22 Jan 2026 12:28:30 +0200 Subject: [PATCH 6/6] sweep_bench: set number of repetitions (#1176) --- common/common.cpp | 5 +++ common/common.h | 59 ++++++++++++------------ examples/sweep-bench/sweep-bench.cpp | 67 +++++++++++++++++----------- 3 files
changed, 77 insertions(+), 54 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 2e7d6312..3192fd37 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -786,6 +786,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.max_extra_alloc_MiB = std::stoi(argv[i]); return true; } + if (arg == "-nrep" || arg == "--n-repetitions") { + CHECK_ARG + params.nrep = std::stoi(argv[i]); + return true; + } if (arg == "--samplers") { CHECK_ARG const auto sampler_names = string_split(argv[i], ";"); diff --git a/common/common.h b/common/common.h index b67d62d1..1de82a6c 100644 --- a/common/common.h +++ b/common/common.h @@ -145,38 +145,39 @@ struct gpt_params { int32_t n_threads = cpu_get_num_math(); int32_t n_threads_draft = -1; - int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) + int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) int32_t n_threads_batch_draft = -1; - int32_t n_predict = -1; // new tokens to predict - int32_t n_ctx = 0; // context size - int32_t n_ctx_draft = 0; // context size for draft model - int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) - int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS) - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_draft = 16; // number of tokens to draft during speculative decoding - int32_t n_draft_min = 1; // minimum number of tokens to draft during speculative decoding - float p_draft_min = 0.8f; // minimum speculative decoding probability (greedy) - int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) - int32_t n_parallel = 1; // number of parallel sequences to decode - int32_t n_sequences = 1; // number of sequences to decode - float p_split = 0.1f; // speculative decoding split probability - int32_t n_gpu_layers = -1; // 
number of layers to store in VRAM (-1 - use default) - int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) - int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors - int32_t max_gpu = 0; // max number of GPUs to use at a time for split mode "graph" - float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs - int32_t grp_attn_n = 1; // group-attention factor - int32_t grp_attn_w = 512; // group-attention width - int32_t n_print = -1; // print token count every n tokens (-1 = disabled) - float rope_freq_base = 0.0f; // RoPE base frequency - float rope_freq_scale = 0.0f; // RoPE frequency scaling factor - float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor + int32_t n_predict = -1; // new tokens to predict + int32_t n_ctx = 0; // context size + int32_t n_ctx_draft = 0; // context size for draft model + int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_draft = 16; // number of tokens to draft during speculative decoding + int32_t n_draft_min = 1; // minimum number of tokens to draft during speculative decoding + float p_draft_min = 0.8f; // minimum speculative decoding probability (greedy) + int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) + int32_t n_parallel = 1; // number of parallel sequences to decode + int32_t n_sequences = 1; // number of sequences to decode + float p_split = 0.1f; // speculative decoding split probability + int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) + int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) + int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors + 
int32_t max_gpu = 0; // max number of GPUs to use at a time for split mode "graph" + float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs + int32_t grp_attn_n = 1; // group-attention factor + int32_t grp_attn_w = 512; // group-attention width + int32_t n_print = -1; // print token count every n tokens (-1 = disabled) + float rope_freq_base = 0.0f; // RoPE base frequency + float rope_freq_scale = 0.0f; // RoPE frequency scaling factor + float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor float yarn_attn_factor = -1.0f; // YaRN magnitude scaling factor - float yarn_beta_fast = -1.0f; // YaRN low correction dim + float yarn_beta_fast = -1.0f; // YaRN low correction dim float yarn_beta_slow = -1.0f; // YaRN high correction dim - int32_t yarn_orig_ctx = 0; // YaRN original context length - float defrag_thold = -1.0f; // KV cache defragmentation threshold - int32_t max_extra_alloc_MiB = 256; // additional VRAM per GPU the scheduler may allocate for more efficient compute graph evaluation + int32_t yarn_orig_ctx = 0; // YaRN original context length + float defrag_thold = -1.0f; // KV cache defragmentation threshold + int32_t max_extra_alloc_MiB = 256; // extra VRAM per GPU the scheduler may allocate for more efficient compute graph evaluation + int32_t nrep = 1; // number of repetitions used in sweep bench ggml_backend_sched_eval_callback cb_eval = nullptr; void * cb_eval_user_data = nullptr; diff --git a/examples/sweep-bench/sweep-bench.cpp b/examples/sweep-bench/sweep-bench.cpp index 449a0b66..f77ca658 100644 --- a/examples/sweep-bench/sweep-bench.cpp +++ b/examples/sweep-bench/sweep-bench.cpp @@ -31,6 +31,7 @@ int main(int argc, char ** argv) { print_usage(argc, argv); return 1; } + if (params.nrep < 1) params.nrep = 1; // init LLM @@ -135,49 +136,63 @@ int main(int argc, char ** argv) { common_batch_clear(batch); llama_kv_cache_clear(ctx); + int i_loop = 0; + for (unsigned int n_kv = 0; n_kv < n_kv_max; n_kv += 
params.n_ubatch) { // clean up KV cache before generation - llama_kv_cache_seq_rm(ctx, 0, n_kv, -1); + //llama_kv_cache_seq_rm(ctx, 0, n_kv, -1); + + int nrep = i_loop < 1 ? params.nrep : 1; // first measure token generation performance at this context size const auto t_tg_start = ggml_time_us(); - for (unsigned int i = 0; i < tg; ++i) { + for (int irep = 0; irep < nrep; ++irep) { + + llama_kv_cache_seq_rm(ctx, 0, n_kv, -1); + + for (unsigned int i = 0; i < tg; ++i) { + common_batch_clear(batch); + common_batch_add(batch, std::rand() % n_vocab, n_kv + i, { 0 }, true); + + if (!decode_helper(ctx, batch, ctx_params.n_batch)) { + LOG_TEE("%s: llama_decode() failed\n", __func__); + return 1; + } + } + + } + + const auto t_tg_end = ggml_time_us(); + + // measure prompt processing performance + const auto t_pp_start = ggml_time_us(); + + for (int irep = 0; irep < nrep; ++irep) { + + // clean up KV cache after generation + llama_kv_cache_seq_rm(ctx, 0, n_kv, -1); + + // prepare batch of pp size for prompt processing performance measurement common_batch_clear(batch); - common_batch_add(batch, std::rand() % n_vocab, n_kv + i, { 0 }, true); + + for (unsigned int i = 0; i < pp; ++i) { + common_batch_add(batch, std::rand() % n_vocab, n_kv + i, { 0 }, false); + } + batch.logits[batch.n_tokens - 1] = true; if (!decode_helper(ctx, batch, ctx_params.n_batch)) { LOG_TEE("%s: llama_decode() failed\n", __func__); return 1; } - } - const auto t_tg_end = ggml_time_us(); - - // clean up KV cache after generation - llama_kv_cache_seq_rm(ctx, 0, n_kv, -1); - - // prepare batch of pp size for prompt processing performance measurement - common_batch_clear(batch); - - for (unsigned int i = 0; i < pp; ++i) { - common_batch_add(batch, std::rand() % n_vocab, n_kv + i, { 0 }, false); - } - batch.logits[batch.n_tokens - 1] = true; - - // measure prompt processing performance - const auto t_pp_start = ggml_time_us(); - - if (!decode_helper(ctx, batch, ctx_params.n_batch)) { - LOG_TEE("%s: 
llama_decode() failed\n", __func__); - return 1; } const auto t_pp_end = ggml_time_us(); // calculate and print metrics - const float t_pp = (t_pp_end - t_pp_start) / 1000000.0f; - const float t_tg = (t_tg_end - t_tg_start) / 1000000.0f; + const float t_pp = (t_pp_end - t_pp_start) / 1000000.0f / nrep; + const float t_tg = (t_tg_end - t_tg_start) / 1000000.0f / nrep; const float speed_pp = pp / t_pp; const float speed_tg = tg / t_tg; @@ -192,6 +207,8 @@ int main(int argc, char ** argv) { } else { LOG_TEE("|%6d | %6d | %6d | %8.3f | %8.2f | %8.3f | %8.2f |\n", pp, tg, n_kv, t_pp, speed_pp, t_tg, speed_tg); } + + ++i_loop; } llama_batch_free(batch);