From 62e4a80be52cd231bcac05ef02a9eb3e9793f7c0 Mon Sep 17 00:00:00 2001 From: Astha Rai Date: Mon, 23 Jun 2025 18:42:55 +0000 Subject: [PATCH] resolved floating point error for some instances --- .../device_gemm_xdl_cshuffle_streamk_v3.hpp | 62 +++++++++++++------ 1 file changed, 42 insertions(+), 20 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp index 7eff891cd1..ccffbb1b44 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp @@ -159,25 +159,15 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2; calculate_grid_size(kernel); } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + + // printf("Case 15\n"); + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds; + calculate_grid_size(kernel); + } } + // printf("num_cu: %u\n", static_cast(num_cu)); + // printf("occupancy: %u\n", static_cast(occupancy)); return std::make_unique(static_cast(p_a), static_cast(p_b), static_cast(p_c),