mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
fix bug in gemm profiler (#344)
[ROCm/composable_kernel commit: 146972f447]
This commit is contained in:
@@ -26,6 +26,7 @@ static constexpr auto ConvSpec =
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
#if 1
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
@@ -78,6 +79,60 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8>;
|
||||
#else
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename BiasLayout,
|
||||
typename OutLayout>
|
||||
using DeviceGroupedConvNDFwdInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<
|
||||
NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
ck::Tuple<BiasLayout>,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
ck::Tuple<BiasDataType>,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
ConvSpec, // ConvForwardSpecialization
|
||||
GemmSpec, // GemmSpecialization
|
||||
1, //
|
||||
256, // BlockSize
|
||||
256, // MPerBlock
|
||||
16, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerXdl
|
||||
16, // NPerXdl
|
||||
4, // MXdlPerWave
|
||||
1, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
S<4, 16, 4>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
2, // BBlockTransferSrcScalarPerVector
|
||||
2, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
4, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 256, 1, 1>,
|
||||
1>;
|
||||
#endif
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user