From c862437f0586040ca3d146e4b66cce9682dd979a Mon Sep 17 00:00:00 2001 From: joye Date: Wed, 23 Apr 2025 19:27:29 -0500 Subject: [PATCH] fix some issues --- .../transpose/batched_transpose_kernel.hpp | 8 +- .../ck_tile/ops/transpose/block_transpose.hpp | 83 ++++++++++--------- .../ops/transpose/transpose_policy.hpp | 16 ++-- 3 files changed, 55 insertions(+), 52 deletions(-) diff --git a/include/ck_tile/ops/transpose/batched_transpose_kernel.hpp b/include/ck_tile/ops/transpose/batched_transpose_kernel.hpp index 33b6d7b72c..24e2492ce5 100644 --- a/include/ck_tile/ops/transpose/batched_transpose_kernel.hpp +++ b/include/ck_tile/ops/transpose/batched_transpose_kernel.hpp @@ -71,8 +71,8 @@ struct BatchedTransposeKernel CK_TILE_DEVICE void operator()(Kargs kargs) const { __shared__ char smem[Pipeline::GetSmemSize()]; - static constexpr ck_tile::index_t kMPerBlock = Problem::kSecondDimWarps; - static constexpr ck_tile::index_t kNPerBlock = Problem::kLeadDimWarps; + static constexpr ck_tile::index_t kMPerBlock = Problem::kSecondSizePerBlock; + static constexpr ck_tile::index_t kNPerBlock = Problem::kLeadSizePerBlock; const auto iDim = blockIdx.z; const auto x_m_n = [&]() { @@ -88,8 +88,8 @@ struct BatchedTransposeKernel sequence{}); }(); - const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock); - const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock); + const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.y * kMPerBlock); + const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.x * kNPerBlock); const auto y_n_m = [&]() { const auto y_dram_naive = make_naive_tensor_view( diff --git a/include/ck_tile/ops/transpose/block_transpose.hpp b/include/ck_tile/ops/transpose/block_transpose.hpp index 02d458552e..950659aba9 100644 --- a/include/ck_tile/ops/transpose/block_transpose.hpp +++ b/include/ck_tile/ops/transpose/block_transpose.hpp @@ -30,10 +30,10 @@ template // this is the col number per warp per iteration + index_t kRowPerBlock_, // 
row number per block + index_t kColPerBlock_, // col number per block + index_t kRowPerXdl_, // row number per xdl ops + index_t kColPerXdl_> // col number per xdl ops struct TransposePipelineProblem { static_assert(kRowWarps_ * kColWarps_ * get_warp_size() == kBlockSize_, @@ -41,18 +41,41 @@ struct TransposePipelineProblem using DataType = remove_cvref_t; using Layout = remove_cvref_t; static constexpr index_t kBlockSize = kBlockSize_; - static constexpr index_t kLeadDimWarps = + static constexpr index_t kLeadNumWarps = TransposeTraits::kLeadDim; - static constexpr index_t kSecondDimWarps = + static constexpr index_t kSecondNumWarps = TransposeTraits::kSecondDim; - static constexpr index_t kLeadDimPerBlock = + static constexpr index_t kLeadSizePerBlock = TransposeTraits::kLeadDim; - static constexpr index_t kSecondDimPerBlock = + static constexpr index_t kSecondSizePerBlock = TransposeTraits::kSecondDim; - static constexpr index_t kLeadDimPerWarp = - TransposeTraits::kLeadDim; - static constexpr index_t kSecondDimPerWarp = - TransposeTraits::kSecondDim; + static constexpr index_t kLeadSizePerXdl = + TransposeTraits::kLeadDim; + static constexpr index_t kSecondSizePerXdl = + TransposeTraits::kSecondDim; + + static constexpr index_t kQuadrantLeadDim = QuartTransposeTraits::kleadDim; + static constexpr index_t kQuadrantSecondDim = QuartTransposeTraits::ksecondDim; + + static_assert(kLeadSizePerBlock % kLeadNumWarps == 0, "block dim should be divided by warp dim!"); + static_assert(kSecondSizePerBlock % kSecondNumWarps == 0, "block dim should be divided by warp dim!"); + + static constexpr index_t kLeadSizePerWarp = kLeadSizePerBlock / kLeadNumWarps; + static constexpr index_t kSecondSizePerWarp = kSecondSizePerBlock / kSecondNumWarps; + + static_assert(kLeadSizePerWarp % kLeadSizePerXdl == 0, "warp dim should be divided by xdl dim!"); + static_assert(kSecondSizePerWarp % kSecondSizePerXdl == 0, "warp dim should be divided by xdl dim!"); + + static constexpr index_t 
kLeadXdlNumPerWarp = kLeadSizePerWarp / kLeadSizePerXdl; + static constexpr index_t kSecondXdlNumPerWarp = kSecondSizePerWarp / kSecondSizePerXdl; + + static_assert(kLeadSizePerXdl % kQuadrantLeadDim == 0, "xdl dim should be divided by quad dim!"); + static_assert(kSecondSizePerXdl % kQuadrantSecondDim == 0, "xdl dim should be divided by quad dim!"); + + static constexpr index_t kQuadNumPerLeadDim = kLeadSizePerXdl / kQuadrantLeadDim; + static constexpr index_t kQuadNumPerSecondDim = kSecondSizePerXdl / kQuadrantSecondDim; + + static constexpr index_t kIterationsPerSecondDim = kQuadNumPerLeadDim * kQuadNumPerSecondDim * 16 / get_warp_size(); }; template @@ -65,32 +88,12 @@ struct BlockTranspose using Layout = remove_cvref_t; static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kLeadDimPerBlock = Problem::kLeadDimPerBlock; - static constexpr index_t kSecondDimPerBlock = Problem::kSecondDimPerBlock; - static constexpr index_t kLeadDimPerWarp = Problem::kLeadDimPerWarp; - static constexpr index_t kSecondDimPerWarp = Problem::kSecondDimPerWarp; - - static constexpr index_t kQuadrantLeadDim = QuartTransposeTraits::kleadDim; - static constexpr index_t kQuadrantSecondDim = QuartTransposeTraits::ksecondDim; - - static_assert(kLeadDimPerBlock % kLeadDimPerWarp == 0, "row per block is not correct!"); - static_assert(kSecondDimPerBlock % kSecondDimPerWarp == 0, "col per block is not correct!"); - - static_assert(kLeadDimPerWarp % kQuadrantLeadDim == 0, "row per warp is not correct!"); - static_assert(kSecondDimPerWarp % kQuadrantSecondDim == 0, "col per warp is not correct!"); - - static constexpr index_t kNumWarpInLeadDim = kLeadDimPerBlock / kLeadDimPerWarp; - static constexpr index_t kNumWarpInSecondDim = kSecondDimPerBlock / kSecondDimPerWarp; - - static constexpr index_t kLeadDimPerWarpInQuadrant = kLeadDimPerWarp / kQuadrantLeadDim; - static constexpr index_t kSecondDimPerWarpInQuadrant = kSecondDimPerWarp / kQuadrantSecondDim; - - // 
this pipeline is only designed for wave64 now - static_assert(get_warp_size() == 64, "the warp size is not correct!"); - static_assert(kBlockSize == kNumWarpInLeadDim * kNumWarpInSecondDim * get_warp_size(), - "the block size is not correct!"); - //static_assert(kLeadDimPerWarpInQuadrant * kSecondDimPerWarpInQuadrant * 4 == get_warp_size(), - // "the warp size is not correct!"); + static constexpr index_t kLeadSizePerBlock = Problem::kLeadSizePerBlock; + static constexpr index_t kSecondSizePerBlock = Problem::kSecondSizePerBlock; + static constexpr index_t kLeadSizePerXdl = Problem::kLeadSizePerXdl; + static constexpr index_t kSecondSizePerXdl = Problem::kSecondSizePerXdl; + static constexpr index_t kLeadNumWarps = Problem::kLeadNumWarps; + static constexpr index_t kSecondNumWarps = Problem::kSecondNumWarps; static constexpr index_t GetVectorSize() { return Policy::template GetVectorSize(); } @@ -119,12 +122,12 @@ struct BlockTranspose auto copy_to_lds_window = make_tile_window(input_lds_block, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {0, 0}); auto load_from_lds_window = make_tile_window(output_lds_block, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {0, 0}, Policy::template MakeLdsLoadTileDistribution()); diff --git a/include/ck_tile/ops/transpose/transpose_policy.hpp b/include/ck_tile/ops/transpose/transpose_policy.hpp index 3175c883f9..0375d57085 100644 --- a/include/ck_tile/ops/transpose/transpose_policy.hpp +++ b/include/ck_tile/ops/transpose/transpose_policy.hpp @@ -59,8 +59,8 @@ struct TransposePolicy CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution() { constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t LeadDimPerBlock = Problem::kLeadDimPerBlock; - constexpr index_t SecondDimPerBlock = Problem::kSecondDimPerBlock; + constexpr index_t LeadDimPerBlock = Problem::kLeadSizePerBlock; + constexpr index_t SecondDimPerBlock = Problem::kSecondSizePerBlock; constexpr index_t 
VecLoadSize = 16 / sizeof(typename Problem::DataType); using TileEncodingPattern = TileDistributionEncodingPattern2D; - constexpr index_t kLeadDimPerBlock = Problem::kLeadDimPerBlock; - constexpr index_t kSecondDimPerBlock = Problem::kSecondDimPerBlock; + constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock; + constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock; constexpr index_t kVectorSize = 16 / sizeof(typename Problem::DataType); constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( @@ -130,9 +130,9 @@ struct TransposePolicy decltype(detail::make_embed_tile_distribution_encoding( WarpLevelOuterDistribution_{}, QuartTransposeTileDistribution{})); constexpr index_t LeadDimIterPerWarp = - Problem::kLeadDimPerBlock / (Problem::kLeadDimPerWarp * Problem::kLeadDimWarps); + Problem::kLeadSizePerBlock / (Problem::kLeadSizePerWarp * Problem::kLeadNumWarps); constexpr index_t SecondDimIterPerWarp = - Problem::kSecondDimPerBlock / (Problem::kSecondDimPerWarp * Problem::kSecondDimWarps); + Problem::kSecondSizePerBlock / (Problem::kSecondSizePerWarp * Problem::kSecondNumWarps); constexpr auto block_outer_dst_encoding = tile_distribution_encoding< sequence<>,