diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index 666599bf44..3a24385ef4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -11,7 +11,7 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" @@ -225,8 +225,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 static constexpr auto I6 = Number<6>{}; static constexpr auto I7 = Number<7>{}; - // TODO: remove - static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = + static constexpr auto EShuffleBlockTransferScalarPerVector = CDEShuffleBlockTransferScalarPerVectors{}[I0]; // K1 should be Number<...> @@ -1304,32 +1303,30 @@ struct GridwiseGemm_wmma_cshuffle_v3 if constexpr(is_same::value) { - if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + if(karg.N % EShuffleBlockTransferScalarPerVector != 0) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg N (" << karg.N << ") value is not a multiple of " - "CShuffleBlockTransferScalarPerVector_NPerBlock (" - << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; + "EShuffleBlockTransferScalarPerVector (" + << EShuffleBlockTransferScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; } return false; } } else { - if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + if(karg.M % EShuffleBlockTransferScalarPerVector != 0) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg M (" << karg.M << ") value is not a multiple of " - "CShuffleBlockTransferScalarPerVector_NPerBlock (" - << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; + "EShuffleBlockTransferScalarPerVector (" + << EShuffleBlockTransferScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; } return false; } @@ -1719,28 +1716,31 @@ struct GridwiseGemm_wmma_cshuffle_v3 // blockwise copy which loads C from LDS, D from global, applies elementwise // operation and stores result E to global - auto cde_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v7< + auto cde_shuffle_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< ThisThreadBlock, // ThreadGroup decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), Tuple, decltype(c_ds_desc_refs), decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), CDEElementwiseOperation, // ElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // DstInMemOp, + Sequence(EGlobalMemoryDataOperation)>, // DstInMemOps, Sequence<1, CShuffleMRepeatPerShuffle * MWave * MPerWmma, 1, CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - sequence_merge_t, - uniform_sequence_gen_t< - NumDTensor, - false>>, // bool ThreadTransferSrcResetCoordinateAfterRun, - Sequence> // bool ThreadTransferDstResetCoordinateAfterRun> + Sequence<0, 1, 2, 3>, // ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // DstDimAccessOrder, + 3, // SrcVectorDim, + 3, // DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, // SrcScalarPerVectors + EShuffleBlockTransferScalarPerVector, // DstScalarPerVector + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags {c_ds_desc_refs, idx_c_ds_block_begin, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), @@ -1790,7 +1790,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 // each block loads its C data from LDS, D from global, applies elementwise // operation and stores result E to global - cde_shuffle_block_copy_lds_to_global.Run( + cde_shuffle_block_copy_lds_and_global.Run( c_ds_desc_refs, c_ds_buf_refs, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), @@ -1801,13 +1801,13 @@ struct GridwiseGemm_wmma_cshuffle_v3 constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id); // move on Ds static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_shuffle_block_copy_lds_to_global.MoveSrcSliceWindow( + cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow( c_ds_desc_refs, i + I1, cde_global_step); }); // move on E - cde_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), I0, cde_global_step); + cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step); } }); }