diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
index 8142c9253b..4fc8e69d2c 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
@@ -788,6 +788,20 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
         return true;
     }
 
+    // check if DsLayout is supported
+    template <typename RefLayout, typename DsLayout, index_t NumDTensor>
+    static bool CheckDLayout()
+    {
+        static bool valid = true;
+        // iterate over DLayout tuple
+        static_for<0, NumDTensor, 1>{}([&](auto i) {
+            using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
+            // if RefLayout and DLayout are same, keep valid true, otherwise false
+            valid = valid && is_same_v<RefLayout, DLayout>;
+        });
+        return valid;
+    }
+
     static bool IsSupportedArgument(const Argument& arg)
     {
         if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
@@ -795,6 +809,26 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
             return false;
         }
 
+        // Check supported layouts
+        // A0 - Row
+        // B0 - Col
+        // D0s - Rows
+        // B1 - Row or Col
+        // D1s - Rows
+        // E1 - Row
+        if(!(is_same_v<tensor_layout::gemm::RowMajor, A0Layout> &&
+             is_same_v<tensor_layout::gemm::ColumnMajor, B0Layout> &&
+             CheckDLayout<tensor_layout::gemm::RowMajor, D0sLayout, NumD0Tensor>() &&
+             (is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ||
+              is_same_v<tensor_layout::gemm::ColumnMajor,
+                        B1Layout>)&&CheckDLayout<tensor_layout::gemm::RowMajor,
+                                                 D1sLayout,
+                                                 NumD1Tensor>() &&
+             is_same_v<tensor_layout::gemm::RowMajor, E1Layout>))
+        {
+            return false;
+        }
+
         return GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_,
                                            arg.b0_grid_desc_n_k_,
                                            arg.b1_grid_desc_n_k_,