mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
@@ -332,7 +332,10 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
cde_element_op_{cde_element_op},
|
||||
MRaw_{MRaw},
|
||||
NRaw_{NRaw},
|
||||
KRaw_{KRaw}
|
||||
{
|
||||
// populate pointer, desc for Ds
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
@@ -400,6 +403,11 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
|
||||
// for checking vector load/store
|
||||
index_t MRaw_;
|
||||
index_t NRaw_;
|
||||
index_t KRaw_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -486,6 +494,86 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load/store
|
||||
{
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
// check vector load of A
|
||||
if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
|
||||
{
|
||||
if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
// FIXME: not rigorous
|
||||
if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector laod of B
|
||||
if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
|
||||
{
|
||||
if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
// FIXME: not rigorous
|
||||
if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load of Ds
|
||||
// only support RowMajor for now
|
||||
bool all_valid = true;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
if constexpr(!is_same_v<DLayout, Row>)
|
||||
{
|
||||
all_valid = false;
|
||||
}
|
||||
});
|
||||
|
||||
if(!all_valid)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector store of E
|
||||
// only support RowMajor for now
|
||||
if constexpr(is_same_v<ELayout, Row>)
|
||||
{
|
||||
if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
|
||||
Reference in New Issue
Block a user