Unify block transfer for fwd and bwd directions. Rename ThreadSliceDim to ThreadClusterRank.

This commit is contained in:
Ville Pietilä
2026-01-13 03:19:10 -05:00
parent 97793cf352
commit 1d519792ca
2 changed files with 15 additions and 26 deletions

View File

@@ -10,11 +10,12 @@
namespace ck_tile::builder::factory::internal {
// Block transfer parameters for A or B tensor.
template <size_t ThreadClusterRank = 3>
struct BlockTransfer
{
ck::Array<size_t, 3> thread_cluster_dims{}; // k0, m, k1
ck::Array<size_t, 3> thread_cluster_order{};
ck::Array<size_t, 3> src_access_order{};
ck::Array<size_t, ThreadClusterRank> thread_cluster_dims{};
ck::Array<size_t, ThreadClusterRank> thread_cluster_order{};
ck::Array<size_t, ThreadClusterRank> src_access_order{};
size_t src_vector_dim = 0;
size_t src_scalar_per_vector = 0;
size_t lds_dst_scalar_per_vector = 0;
@@ -22,27 +23,15 @@ struct BlockTransfer
bool lds_padding = false;
};
// Block transfer parameters whose per-dimension arrays hold ThreadSliceDim
// entries (3 unless another rank is requested).
// NOTE(review): this file is shown as a diff; every use of BwdBlockTransfer visible
// here is paired with a BlockTransfer<N> replacement line, so this struct appears to
// be the definition being removed -- confirm before relying on it.
template <size_t ThreadSliceDim = 3>
struct BwdBlockTransfer
{
// SetBwdConvBlockTransfer fills this with {k0, m_n, k1} for rank 3; the rank-4
// initializer starts with k_batch_size, k0, m_n.
ck::Array<size_t, ThreadSliceDim> thread_cluster_dims{};
// Populated from block_order.order[...] in SetBwdConvBlockTransfer.
ck::Array<size_t, ThreadSliceDim> thread_cluster_order{};
ck::Array<size_t, ThreadSliceDim> src_access_order{};
// Scalar settings default to 0 and lds_padding to false until a setter overrides them.
size_t src_vector_dim = 0;
size_t src_scalar_per_vector = 0;
size_t lds_dst_scalar_per_vector = 0;
bool lds_padding = false;
};
template <auto TRANSFER>
constexpr BlockTransfer SetFwdConvBlockTransfer()
constexpr BlockTransfer<> SetFwdConvBlockTransfer()
{
auto& block_xfer = TRANSFER.block_transfer;
auto& block_order = TRANSFER.block_transfer_access_order;
auto& src_order = TRANSFER.src_access_order;
auto& lds_cfg = TRANSFER.lds_transfer;
return BlockTransfer{
return BlockTransfer<>{
.thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1},
.thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]},
.src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]},
@@ -68,7 +57,7 @@ constexpr auto SetBwdConvBlockTransfer()
if constexpr(array_length == 3)
{
return BwdBlockTransfer<3>{
return BlockTransfer<3>{
.thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1},
.thread_cluster_order = {block_order.order[0],
block_order.order[1],
@@ -82,7 +71,7 @@ constexpr auto SetBwdConvBlockTransfer()
}
else if constexpr(array_length == 4)
{
return BwdBlockTransfer<4>{
return BlockTransfer<4>{
.thread_cluster_dims = {block_xfer.k_batch_size,
block_xfer.k0,
block_xfer.m_n,

View File

@@ -120,20 +120,20 @@ inline std::string to_string<BlockGemmPipeline>(BlockGemmPipeline t)
return oss.str();
}
// Converts a BlockTransfer to its string form by packing its members into a
// std::array and passing that to array_to_seq.
// Rank 4 emits {k_batch_size, k0, m_n, k1}; rank 3 emits {k0, m_n, k1}.
// NOTE(review): this span is a diff with the +/- markers stripped; each
// ThreadSliceDim line is the pre-rename form of the ThreadClusterRank line that follows it.
// NOTE(review): the members read here (k0, m_n, k1, k_batch_size) are not among the
// fields shown for the BlockTransfer struct in the first file of this diff
// (thread_cluster_dims, ...); presumably this overload targets a different
// BlockTransfer type -- confirm.
template <size_t ThreadSliceDim>
inline std::string to_string(BlockTransfer<ThreadSliceDim> t)
template <size_t ThreadClusterRank>
inline std::string to_string(BlockTransfer<ThreadClusterRank> t)
{
if constexpr(ThreadSliceDim == 4)
if constexpr(ThreadClusterRank == 4)
{
return array_to_seq(std::array<size_t, 4>{t.k_batch_size, t.k0, t.m_n, t.k1});
}
else if constexpr(ThreadSliceDim == 3)
else if constexpr(ThreadClusterRank == 3)
{
return array_to_seq(std::array<size_t, 3>{t.k0, t.m_n, t.k1});
}
else
{
// Only instantiated for ranks other than 3 and 4, where the condition is false,
// so an unsupported rank is rejected at compile time with this message.
static_assert(ThreadSliceDim == 3 || ThreadSliceDim == 4, "Unsupported ThreadSliceDim");
static_assert(ThreadClusterRank == 3 || ThreadClusterRank == 4, "Unsupported ThreadClusterRank");
}
}
@@ -288,8 +288,8 @@ inline std::string to_string<WmmaGemm_>(WmmaGemm_ t)
return to_string(t.gridwise_gemm);
}
// Converts a Transfer_ wrapper to a string by forwarding its `transfer` member to
// the matching to_string overload. The rank template parameter defaults to 3.
// NOTE(review): this span is a diff with the +/- markers stripped; the
// ThreadSliceDim lines are the pre-rename form of the ThreadClusterRank lines.
template <size_t ThreadSliceDim = 3>
inline std::string to_string(Transfer_<ThreadSliceDim> t)
template <size_t ThreadClusterRank = 3>
inline std::string to_string(Transfer_<ThreadClusterRank> t)
{
return to_string(t.transfer);
}