diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 6205ec355b..90690c15dc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -369,6 +369,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 static_assert(NDimSpatial == 2 || NDimSpatial == 3, "wrong! only implemented for 2D and 3D now"); + static_assert(!SkipBLds || AK1 == BK1); + // MaxGroupedGemmGroupsNum is used to specify number of gemm args in compile time. With this // implementation we can avoid copy data to workspace before kernel launch since number of // groups is runtime parameter. If number of groups is larger than MaxGroupedGemmGroupsNum then @@ -529,18 +531,19 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType static constexpr index_t BBlockBufferSize = 1; + // Force to 1, due to KN layout for GKYXC + static constexpr index_t BScalarPerVectorSkipLds = 1; -#define GridwiseGemmMultiDSkipBLdsTemplateParams \ - BlockSize, ABDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ - InMemoryDataOperationEnum::Set, element_wise::PassThrough, element_wise::PassThrough, \ - element_wise::PassThrough, MPerBlock, NPerBlock, KPerBlock / AK1, MPerXDL, NPerXDL, AK1, \ - MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ - ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ - ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ - ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ - BBlockTransferSrcScalarPerVector, false, BBlockBufferSize, CShuffleMXdlPerWavePerShuffle, \ - CShuffleNXdlPerWavePerShuffle, \ - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ +#define GridwiseGemmMultiDSkipBLdsTemplateParams \ + BlockSize, ABDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ + InMemoryDataOperationEnum::Set, element_wise::PassThrough, element_wise::PassThrough, \ + element_wise::PassThrough, MPerBlock, NPerBlock, KPerBlock / AK1, MPerXDL, NPerXDL, AK1, \ + MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ + ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ + ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BScalarPerVectorSkipLds, \ + false, BBlockBufferSize, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock using GridwiseGemm = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_multiple_d_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_multiple_d_cshuffle.hpp index e4384ccc47..fa2f28cdc9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_multiple_d_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_multiple_d_cshuffle.hpp @@ -277,7 +277,7 @@ struct GridwiseGemm_xdlops_skip_b_lds_multiple_d_cshuffle c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); return math::max((a_block_space_size_aligned) * sizeof(ABDataType), - c_block_size * sizeof(EDataType)); + c_block_size * sizeof(CShuffleDataType)); } template