diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp index 1f60818e39..21afc06040 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -64,7 +64,7 @@ __global__ void const index_t N = gemm_desc_ptr[group_id].N; const index_t K = gemm_desc_ptr[group_id].K; - if(M * N * K == 0) + if(M == 0 || N == 0 || K == 0) return; const auto StrideAs = gemm_desc_ptr[group_id].StrideAs; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp index 8354335577..68c6dcc0f5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp @@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage const index_t N = gemm_descs[i].N_; const index_t K = gemm_descs[i].K_; - if(M * N * K == 0) + if(M == 0 || N == 0 || K == 0) { skipped_group_count_++; continue; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp index 70011124fc..2884e558cd 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp @@ -109,7 +109,7 @@ __global__ void N = gemm_desc_ptr[group_id].N; K = gemm_desc_ptr[group_id].K; - if(M * N * K == 0) + if(M == 0 || N == 0 || K == 0) { grid_size_grp = 0; continue; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp index c98ec6e2aa..ac05a0703f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp @@ -68,7 +68,7 @@ __global__ void const index_t N = gemm_desc_ptr[group_id].N; const index_t K = gemm_desc_ptr[group_id].K; - if(M * N * K == 0) + if(M == 0 || N == 0 || K == 0) return; const auto StrideA = gemm_desc_ptr[group_id].StrideA;