From 3e6e867dae47c94587b77a960a0a353dabec2441 Mon Sep 17 00:00:00 2001
From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
Date: Wed, 15 Mar 2023 11:12:12 -0500
Subject: [PATCH] Add layout check to IsSupportedArgument (#627)

* Add layout check to IsSupportedArgument

* Format

---------

Co-authored-by: Rosty Geyyer
Co-authored-by: zjing14

[ROCm/composable_kernel commit: c10a6e8293ade863a9b177de956d31eb86f4b128]
---
 ...ultiple_d_gemm_multiple_d_xdl_cshuffle.hpp | 34 +++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
index 8142c9253b..4fc8e69d2c 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
@@ -788,6 +788,20 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
         return true;
     }
 
+    // Check that DsLayout is supported: returns true only if every layout in the
+    // DsLayout tuple is the same type as RefLayout.
+    template <typename RefLayout, typename DsLayout, index_t NumDTensor>
+    static bool CheckDLayout()
+    {
+        // local, not static: no mutable state shared between calls or threads
+        bool valid = true;
+        static_for<0, NumDTensor, 1>{}([&](auto i) {
+            using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
+            valid = valid && is_same_v<RefLayout, DLayout>;
+        });
+        return valid;
+    }
+
     static bool IsSupportedArgument(const Argument& arg)
     {
         if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
@@ -795,6 +809,26 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
             return false;
         }
 
+        // Check supported layouts
+        // A0 - Row
+        // B0 - Col
+        // D0s - Rows
+        // B1 - Row or Col
+        // D1s - Rows
+        // E1 - Row
+        if(!(is_same_v<tensor_layout::gemm::RowMajor, A0Layout> &&
+             is_same_v<tensor_layout::gemm::ColumnMajor, B0Layout> &&
+             CheckDLayout<tensor_layout::gemm::RowMajor, D0sLayout, NumD0Tensor>() &&
+             (is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ||
+              is_same_v<tensor_layout::gemm::ColumnMajor,
+                        B1Layout>)&&CheckDLayout<tensor_layout::gemm::RowMajor,
+                                                 D1sLayout,
+                                                 NumD1Tensor>() &&
+             is_same_v<tensor_layout::gemm::RowMajor, E1Layout>))
+        {
+            return false;
+        }
+
         return
GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_, arg.b0_grid_desc_n_k_, arg.b1_grid_desc_n_k_,