This commit is contained in:
Bartlomiej Kocot
2025-09-02 22:43:43 +00:00
parent 957639a291
commit e688dcfcda
2 changed files with 15 additions and 12 deletions

View File

@@ -369,6 +369,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// Compile-time guard: NDimSpatial must be 2 or 3; other values are rejected.
static_assert(NDimSpatial == 2 || NDimSpatial == 3,
"wrong! only implemented for 2D and 3D now");
// Compile-time guard: whenever SkipBLds is set, AK1 and BK1 must be equal.
static_assert(!SkipBLds || AK1 == BK1);
// MaxGroupedGemmGroupsNum specifies the number of GEMM args at compile time. With this
// implementation we can avoid copying data to the workspace before kernel launch, since the
// number of groups is a runtime parameter. If the number of groups is larger than
// MaxGroupedGemmGroupsNum then
@@ -529,18 +531,19 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType
// Compile-time constant forwarded as one of the arguments in
// GridwiseGemmMultiDSkipBLdsTemplateParams.
// NOTE(review): presumably the number of B block buffers used when B LDS is skipped - confirm.
static constexpr index_t BBlockBufferSize = 1;
// Force to 1, due to KN layout for GKYXC
// This constant is forwarded in GridwiseGemmMultiDSkipBLdsTemplateParams.
static constexpr index_t BScalarPerVectorSkipLds = 1;
#define GridwiseGemmMultiDSkipBLdsTemplateParams \
BlockSize, ABDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
InMemoryDataOperationEnum::Set, element_wise::PassThrough, element_wise::PassThrough, \
element_wise::PassThrough, MPerBlock, NPerBlock, KPerBlock / AK1, MPerXDL, NPerXDL, AK1, \
MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \
ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \
ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \
ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \
BBlockTransferSrcScalarPerVector, false, BBlockBufferSize, CShuffleMXdlPerWavePerShuffle, \
CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
// Comma-separated template-argument list, kept in a macro so it can be spelled once.
// Among the arguments it forwards BScalarPerVectorSkipLds and BBlockBufferSize, and it uses
// element_wise::PassThrough for all three element-wise operations and
// InMemoryDataOperationEnum::Set as the memory operation.
// NOTE(review): presumably expanded by the `using GridwiseGemm =` alias that follows, to
// instantiate the skip-B-LDS gridwise GEMM - confirm against the full file.
#define GridwiseGemmMultiDSkipBLdsTemplateParams \
BlockSize, ABDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
InMemoryDataOperationEnum::Set, element_wise::PassThrough, element_wise::PassThrough, \
element_wise::PassThrough, MPerBlock, NPerBlock, KPerBlock / AK1, MPerXDL, NPerXDL, AK1, \
MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \
ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \
ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \
ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BScalarPerVectorSkipLds, \
false, BBlockBufferSize, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock
using GridwiseGemm =

View File

@@ -277,7 +277,7 @@ struct GridwiseGemm_xdlops_skip_b_lds_multiple_d_cshuffle
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned) * sizeof(ABDataType),
c_block_size * sizeof(EDataType));
c_block_size * sizeof(CShuffleDataType));
}
template <bool HasMainK0BlockLoop,