diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp index 4ef2f533c9..d873a4b903 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp @@ -10,11 +10,12 @@ namespace ck_tile::builder::factory::internal { // Block transfer parameters for A or B tensor. +template struct BlockTransfer { - ck::Array thread_cluster_dims{}; // k0, m, k1 - ck::Array thread_cluster_order{}; - ck::Array src_access_order{}; + ck::Array thread_cluster_dims{}; + ck::Array thread_cluster_order{}; + ck::Array src_access_order{}; size_t src_vector_dim = 0; size_t src_scalar_per_vector = 0; size_t lds_dst_scalar_per_vector = 0; @@ -22,27 +23,15 @@ struct BlockTransfer bool lds_padding = false; }; -template -struct BwdBlockTransfer -{ - ck::Array thread_cluster_dims{}; - ck::Array thread_cluster_order{}; - ck::Array src_access_order{}; - size_t src_vector_dim = 0; - size_t src_scalar_per_vector = 0; - size_t lds_dst_scalar_per_vector = 0; - bool lds_padding = false; -}; - template -constexpr BlockTransfer SetFwdConvBlockTransfer() +constexpr BlockTransfer<> SetFwdConvBlockTransfer() { auto& block_xfer = TRANSFER.block_transfer; auto& block_order = TRANSFER.block_transfer_access_order; auto& src_order = TRANSFER.src_access_order; auto& lds_cfg = TRANSFER.lds_transfer; - return BlockTransfer{ + return BlockTransfer<>{ .thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1}, .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]}, .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]}, @@ -68,7 +57,7 @@ constexpr auto SetBwdConvBlockTransfer() if constexpr(array_length == 3) { - return BwdBlockTransfer<3>{ + return BlockTransfer<3>{ .thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1}, .thread_cluster_order = {block_order.order[0], block_order.order[1], @@ -82,7 +71,7 @@ constexpr auto SetBwdConvBlockTransfer() } else if constexpr(array_length == 4) { - return BwdBlockTransfer<4>{ + return BlockTransfer<4>{ .thread_cluster_dims = {block_xfer.k_batch_size, block_xfer.k0, block_xfer.m_n, diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index d80d6a1b8c..f0c3bcde39 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -120,20 +120,20 @@ inline std::string to_string(BlockGemmPipeline t) return oss.str(); } -template -inline std::string to_string(BlockTransfer t) +template +inline std::string to_string(BlockTransfer t) { - if constexpr(ThreadSliceDim == 4) + if constexpr(ThreadClusterRank == 4) { return array_to_seq(std::array{t.k_batch_size, t.k0, t.m_n, t.k1}); } - else if constexpr(ThreadSliceDim == 3) + else if constexpr(ThreadClusterRank == 3) { return array_to_seq(std::array{t.k0, t.m_n, t.k1}); } else { - static_assert(ThreadSliceDim == 3 || ThreadSliceDim == 4, "Unsupported ThreadSliceDim"); + static_assert(ThreadClusterRank == 3 || ThreadClusterRank == 4, "Unsupported ThreadClusterRank"); } } @@ -288,8 +288,8 @@ inline std::string to_string(WmmaGemm_ t) return to_string(t.gridwise_gemm); } -template -inline std::string to_string(Transfer_ t) +template +inline std::string to_string(Transfer_ t) { return to_string(t.transfer); }