mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
fixes
This commit is contained in:
@@ -369,6 +369,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
static_assert(NDimSpatial == 2 || NDimSpatial == 3,
|
||||
"wrong! only implemented for 2D and 3D now");
|
||||
|
||||
static_assert(!SkipBLds || AK1 == BK1);
|
||||
|
||||
// MaxGroupedGemmGroupsNum is used to specify number of gemm args in compile time. With this
|
||||
// implementation we can avoid copy data to workspace before kernel launch since number of
|
||||
// groups is runtime parameter. If number of groups is larger than MaxGroupedGemmGroupsNum then
|
||||
@@ -529,18 +531,19 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType
|
||||
|
||||
static constexpr index_t BBlockBufferSize = 1;
|
||||
// Force to 1, due to KN layout for GKYXC
|
||||
static constexpr index_t BScalarPerVectorSkipLds = 1;
|
||||
|
||||
#define GridwiseGemmMultiDSkipBLdsTemplateParams \
|
||||
BlockSize, ABDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
|
||||
InMemoryDataOperationEnum::Set, element_wise::PassThrough, element_wise::PassThrough, \
|
||||
element_wise::PassThrough, MPerBlock, NPerBlock, KPerBlock / AK1, MPerXDL, NPerXDL, AK1, \
|
||||
MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \
|
||||
ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \
|
||||
ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \
|
||||
ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \
|
||||
BBlockTransferSrcScalarPerVector, false, BBlockBufferSize, CShuffleMXdlPerWavePerShuffle, \
|
||||
CShuffleNXdlPerWavePerShuffle, \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
#define GridwiseGemmMultiDSkipBLdsTemplateParams \
|
||||
BlockSize, ABDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
|
||||
InMemoryDataOperationEnum::Set, element_wise::PassThrough, element_wise::PassThrough, \
|
||||
element_wise::PassThrough, MPerBlock, NPerBlock, KPerBlock / AK1, MPerXDL, NPerXDL, AK1, \
|
||||
MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \
|
||||
ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \
|
||||
ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \
|
||||
ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BScalarPerVectorSkipLds, \
|
||||
false, BBlockBufferSize, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
using GridwiseGemm =
|
||||
|
||||
@@ -277,7 +277,7 @@ struct GridwiseGemm_xdlops_skip_b_lds_multiple_d_cshuffle
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
|
||||
|
||||
return math::max((a_block_space_size_aligned) * sizeof(ABDataType),
|
||||
c_block_size * sizeof(EDataType));
|
||||
c_block_size * sizeof(CShuffleDataType));
|
||||
}
|
||||
|
||||
template <bool HasMainK0BlockLoop,
|
||||
|
||||
Reference in New Issue
Block a user