mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Unify block transfer for fwd and bwd directions. Rename ThreadSliceDim to ThreadClusterRank.
This commit is contained in:
@@ -10,11 +10,12 @@
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
// Block transfer parameters for A or B tensor.
|
||||
template <size_t ThreadClusterRank = 3>
|
||||
struct BlockTransfer
|
||||
{
|
||||
ck::Array<size_t, 3> thread_cluster_dims{}; // k0, m, k1
|
||||
ck::Array<size_t, 3> thread_cluster_order{};
|
||||
ck::Array<size_t, 3> src_access_order{};
|
||||
ck::Array<size_t, ThreadClusterRank> thread_cluster_dims{};
|
||||
ck::Array<size_t, ThreadClusterRank> thread_cluster_order{};
|
||||
ck::Array<size_t, ThreadClusterRank> src_access_order{};
|
||||
size_t src_vector_dim = 0;
|
||||
size_t src_scalar_per_vector = 0;
|
||||
size_t lds_dst_scalar_per_vector = 0;
|
||||
@@ -22,27 +23,15 @@ struct BlockTransfer
|
||||
bool lds_padding = false;
|
||||
};
|
||||
|
||||
template <size_t ThreadSliceDim = 3>
|
||||
struct BwdBlockTransfer
|
||||
{
|
||||
ck::Array<size_t, ThreadSliceDim> thread_cluster_dims{};
|
||||
ck::Array<size_t, ThreadSliceDim> thread_cluster_order{};
|
||||
ck::Array<size_t, ThreadSliceDim> src_access_order{};
|
||||
size_t src_vector_dim = 0;
|
||||
size_t src_scalar_per_vector = 0;
|
||||
size_t lds_dst_scalar_per_vector = 0;
|
||||
bool lds_padding = false;
|
||||
};
|
||||
|
||||
template <auto TRANSFER>
|
||||
constexpr BlockTransfer SetFwdConvBlockTransfer()
|
||||
constexpr BlockTransfer<> SetFwdConvBlockTransfer()
|
||||
{
|
||||
auto& block_xfer = TRANSFER.block_transfer;
|
||||
auto& block_order = TRANSFER.block_transfer_access_order;
|
||||
auto& src_order = TRANSFER.src_access_order;
|
||||
auto& lds_cfg = TRANSFER.lds_transfer;
|
||||
|
||||
return BlockTransfer{
|
||||
return BlockTransfer<>{
|
||||
.thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1},
|
||||
.thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]},
|
||||
.src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]},
|
||||
@@ -68,7 +57,7 @@ constexpr auto SetBwdConvBlockTransfer()
|
||||
|
||||
if constexpr(array_length == 3)
|
||||
{
|
||||
return BwdBlockTransfer<3>{
|
||||
return BlockTransfer<3>{
|
||||
.thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1},
|
||||
.thread_cluster_order = {block_order.order[0],
|
||||
block_order.order[1],
|
||||
@@ -82,7 +71,7 @@ constexpr auto SetBwdConvBlockTransfer()
|
||||
}
|
||||
else if constexpr(array_length == 4)
|
||||
{
|
||||
return BwdBlockTransfer<4>{
|
||||
return BlockTransfer<4>{
|
||||
.thread_cluster_dims = {block_xfer.k_batch_size,
|
||||
block_xfer.k0,
|
||||
block_xfer.m_n,
|
||||
|
||||
@@ -120,20 +120,20 @@ inline std::string to_string<BlockGemmPipeline>(BlockGemmPipeline t)
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <size_t ThreadSliceDim>
|
||||
inline std::string to_string(BlockTransfer<ThreadSliceDim> t)
|
||||
template <size_t ThreadClusterRank>
|
||||
inline std::string to_string(BlockTransfer<ThreadClusterRank> t)
|
||||
{
|
||||
if constexpr(ThreadSliceDim == 4)
|
||||
if constexpr(ThreadClusterRank == 4)
|
||||
{
|
||||
return array_to_seq(std::array<size_t, 4>{t.k_batch_size, t.k0, t.m_n, t.k1});
|
||||
}
|
||||
else if constexpr(ThreadSliceDim == 3)
|
||||
else if constexpr(ThreadClusterRank == 3)
|
||||
{
|
||||
return array_to_seq(std::array<size_t, 3>{t.k0, t.m_n, t.k1});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadSliceDim == 3 || ThreadSliceDim == 4, "Unsupported ThreadSliceDim");
|
||||
static_assert(ThreadClusterRank == 3 || ThreadClusterRank == 4, "Unsupported ThreadClusterRank");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -288,8 +288,8 @@ inline std::string to_string<WmmaGemm_>(WmmaGemm_ t)
|
||||
return to_string(t.gridwise_gemm);
|
||||
}
|
||||
|
||||
template <size_t ThreadSliceDim = 3>
|
||||
inline std::string to_string(Transfer_<ThreadSliceDim> t)
|
||||
template <size_t ThreadClusterRank = 3>
|
||||
inline std::string to_string(Transfer_<ThreadClusterRank> t)
|
||||
{
|
||||
return to_string(t.transfer);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user