mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
add vector load check (#680)
Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
@@ -273,7 +273,10 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
N01_{N01},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
cde_element_op_{cde_element_op},
|
||||
MRaw_{M},
|
||||
NRaw_{N},
|
||||
KRaw_{K}
|
||||
{
|
||||
a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
|
||||
b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
|
||||
@@ -335,6 +338,11 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
|
||||
// for checking vector load/store
|
||||
index_t MRaw_;
|
||||
index_t NRaw_;
|
||||
index_t KRaw_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -488,6 +496,85 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
{
|
||||
return false;
|
||||
}
|
||||
// check vector load/store
|
||||
{
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
// check vector load of A
|
||||
if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
|
||||
{
|
||||
if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
// FIXME: not rigorous
|
||||
if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector laod of B
|
||||
if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
|
||||
{
|
||||
if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
// FIXME: not rigorous
|
||||
if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load of Ds
|
||||
// only support RowMajor for now
|
||||
bool all_valid = true;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
if constexpr(!is_same_v<DLayout, Row>)
|
||||
{
|
||||
all_valid = false;
|
||||
}
|
||||
});
|
||||
|
||||
if(!all_valid)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector store of E
|
||||
// only support RowMajor for now
|
||||
if constexpr(is_same_v<ELayout, Row>)
|
||||
{
|
||||
if(arg.NRaw_ % CDEShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
|
||||
@@ -239,7 +239,10 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
N01_{N01},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
c_element_op_{c_element_op},
|
||||
MRaw_{M},
|
||||
NRaw_{N},
|
||||
KRaw_{K}
|
||||
{
|
||||
a_grid_desc_k0_m_k1_ =
|
||||
DeviceGemmWmma_CShuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
|
||||
@@ -276,6 +279,10 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
// for checking vector load/store
|
||||
index_t MRaw_;
|
||||
index_t NRaw_;
|
||||
index_t KRaw_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -417,6 +424,68 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load/store
|
||||
{
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
// check vector load of A
|
||||
if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
|
||||
{
|
||||
if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
// FIXME: not rigorous
|
||||
if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector laod of B
|
||||
if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
|
||||
{
|
||||
if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
|
||||
{
|
||||
// FIXME: not rigorous
|
||||
if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector store of C
|
||||
// only support RowMajor for now
|
||||
if constexpr(is_same_v<CLayout, Row>)
|
||||
{
|
||||
if(arg.NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
|
||||
Reference in New Issue
Block a user