mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Add layout check to IsSupportedArgument (#627)
* Add layout check to IsSupportedArgument * Format --------- Co-authored-by: Rosty Geyyer <rosty.geyyer@amd.com> Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
@@ -788,6 +788,20 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
return true;
|
||||
}
|
||||
|
||||
// check if DsLayout is supported
|
||||
template <typename RefLayout, typename DsLayout, const index_t NumDTensor>
|
||||
static bool CheckDLayout()
|
||||
{
|
||||
static bool valid = true;
|
||||
// iterate over DLayout tuple
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
// if RefLayout and DLayout are same, keep valid true, otherwise false
|
||||
valid = valid && is_same_v<RefLayout, DLayout>;
|
||||
});
|
||||
return valid;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
|
||||
@@ -795,6 +809,26 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check supported layouts
|
||||
// A0 - Row
|
||||
// B0 - Col
|
||||
// D0s - Rows
|
||||
// B1 - Row or Col
|
||||
// D1s - Rows
|
||||
// E1 - Row
|
||||
if(!(is_same_v<tensor_layout::gemm::RowMajor, A0Layout> &&
|
||||
is_same_v<tensor_layout::gemm::ColumnMajor, B0Layout> &&
|
||||
CheckDLayout<tensor_layout::gemm::RowMajor, D0sLayout, NumD0Tensor>() &&
|
||||
(is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ||
|
||||
is_same_v<tensor_layout::gemm::ColumnMajor,
|
||||
B1Layout>)&&CheckDLayout<tensor_layout::gemm::RowMajor,
|
||||
D1sLayout,
|
||||
NumD1Tensor>() &&
|
||||
is_same_v<tensor_layout::gemm::RowMajor, E1Layout>))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_,
|
||||
arg.b0_grid_desc_n_k_,
|
||||
arg.b1_grid_desc_n_k_,
|
||||
|
||||
Reference in New Issue
Block a user