Use ThreadGroupTensorSliceTransfer_v7r3

This commit is contained in:
Anton Gorenko
2025-05-29 13:10:27 +05:00
parent ed047d08b4
commit deebe1ea13

View File

@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
@@ -225,8 +225,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// TODO: remove
static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock =
static constexpr auto EShuffleBlockTransferScalarPerVector =
CDEShuffleBlockTransferScalarPerVectors{}[I0];
// K1 should be Number<...>
@@ -1304,32 +1303,30 @@ struct GridwiseGemm_wmma_cshuffle_v3
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
{
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
if(karg.N % EShuffleBlockTransferScalarPerVector != 0)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
"EShuffleBlockTransferScalarPerVector ("
<< EShuffleBlockTransferScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
}
else
{
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
if(karg.M % EShuffleBlockTransferScalarPerVector != 0)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
"EShuffleBlockTransferScalarPerVector ("
<< EShuffleBlockTransferScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
@@ -1719,28 +1716,31 @@ struct GridwiseGemm_wmma_cshuffle_v3
// Blockwise copy that loads C from LDS and the D tensors from global memory,
// applies the elementwise operation, and stores the result E to global memory.
auto cde_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v7<
auto cde_shuffle_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
ThisThreadBlock, // ThreadGroup
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation, // ElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // DstInMemOp,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // DstInMemOps,
Sequence<1,
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1,
CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
sequence_merge_t<Sequence<true>,
uniform_sequence_gen_t<
NumDTensor,
false>>, // bool ThreadTransferSrcResetCoordinateAfterRun,
Sequence<false>> // bool ThreadTransferDstResetCoordinateAfterRun>
Sequence<0, 1, 2, 3>, // ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // SrcDimAccessOrder,
Sequence<0, 1, 2, 3>, // DstDimAccessOrder,
3, // SrcVectorDim,
3, // DstVectorDim,
CDEShuffleBlockTransferScalarPerVectors, // SrcScalarPerVectors
EShuffleBlockTransferScalarPerVector, // DstScalarPerVector
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
@@ -1790,7 +1790,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
// each block loads its C data from LDS, D from global, applies elementwise
// operation and stores result E to global
cde_shuffle_block_copy_lds_to_global.Run(
cde_shuffle_block_copy_lds_and_global.Run(
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
@@ -1801,13 +1801,13 @@ struct GridwiseGemm_wmma_cshuffle_v3
constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
// advance the source slice window of every D tensor by the current step
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_shuffle_block_copy_lds_to_global.MoveSrcSliceWindow(
cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_global_step);
});
// advance the destination (E) slice window by the same step
cde_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), I0, cde_global_step);
cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
}
});
}