diff --git a/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp b/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp index a89a190489..a4150e8d84 100644 --- a/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp +++ b/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp @@ -49,9 +49,11 @@ struct BatchedTransposeKernel CK_TILE_HOST static constexpr auto GridSize(const Hargs& host_args) { - size_t grid_size_x = (host_args.height + host_args.dim_block_h - 1) / host_args.dim_block_h; - size_t grid_size_y = (host_args.width + host_args.dim_block_w - 1) / host_args.dim_block_w; - size_t grid_size_z = host_args.batch; + const size_t grid_size_x = + ck_tile::integer_divide_ceil(host_args.height, host_args.dim_block_h); + const size_t grid_size_y = + ck_tile::integer_divide_ceil(host_args.width, host_args.dim_block_w); + const size_t grid_size_z = host_args.batch; return dim3(grid_size_x, grid_size_y, grid_size_z); } @@ -71,41 +73,43 @@ struct BatchedTransposeKernel CK_TILE_DEVICE void operator()(Kargs kargs) const { - static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock; - static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock; - static constexpr bool kPadM = Problem::kPadM; - static constexpr bool kPadN = Problem::kPadN; - static constexpr ck_tile::index_t VectorSizeInput = Problem::VectorSizeInput; - static constexpr ck_tile::index_t VectorSizeOutput = Problem::VectorSizeOutput; + static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock; + static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock; + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr ck_tile::index_t VectorSizeInput = Problem::VectorSizeInput; + static constexpr ck_tile::index_t VectorStrideInput = 1; + static constexpr ck_tile::index_t VectorSizeOutput = Problem::VectorSizeOutput; + static constexpr ck_tile::index_t VectorStrideOutput = 1; - const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock); - const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock); - const auto iDim = blockIdx.z; + const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock); + const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock); + const auto offset = __builtin_amdgcn_readfirstlane(blockIdx.z * kargs.height * kargs.width); const auto x_m_n = [&]() { const auto x_dram_naive = make_naive_tensor_view( - static_cast(kargs.p_input) + iDim * kargs.dim_stride, + static_cast(kargs.p_input) + offset, make_tuple(kargs.height, kargs.width), make_tuple(kargs.width, 1), number{}, - number<1>{}); + number{}); return pad_tensor_view(x_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); }(); const auto y_n_m = [&]() { const auto y_dram_naive = make_naive_tensor_view( - static_cast(kargs.p_output) + iDim * kargs.dim_stride, + static_cast(kargs.p_output) + offset, make_tuple(kargs.width, kargs.height), make_tuple(kargs.height, 1), number{}, - number<1>{}); + number{}); return pad_tensor_view(y_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); }(); auto x_block_window = make_tile_window( diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp index e344c24bf5..3b8d5a142e 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp @@ -15,15 +15,15 @@ struct BatchedTransposeCommonPolicy template CK_TILE_DEVICE static constexpr auto MakeInputDistribution() { - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t LeadDimPerBlock = Problem::kMPerBlock; - constexpr index_t SecondDimPerBlock = Problem::kNPerBlock; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kLeadDimPerBlock = Problem::kNPerBlock; + constexpr index_t kSecondDimPerBlock = Problem::kMPerBlock; - constexpr index_t kVectorSize = Problem::VectorSizeOutput; - - using TileEncodingPattern = TileDistributionEncodingPattern2D; return TileEncodingPattern::Make2DStaticTileDistribution(); diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp index 491db37564..45803ae2da 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp @@ -18,19 +18,19 @@ struct BatchedTransposeLdsProblem { using DataType = remove_cvref_t; - static constexpr index_t kRowWarps_ = NumWarps::at(number<1>{}); - static constexpr index_t kColWarps_ = NumWarps::at(number<0>{}); + static constexpr index_t kRowWarps_ = NumWarps::at(number<0>{}); + static constexpr index_t kColWarps_ = NumWarps::at(number<1>{}); static constexpr index_t kBlockSize_ = get_warp_size() * kRowWarps_ * kColWarps_; - static constexpr index_t kRowPerBlock_ = BlockTile::at(number<1>{}); - static constexpr index_t kColPerBlock_ = BlockTile::at(number<0>{}); + static constexpr index_t kRowPerBlock_ = BlockTile::at(number<0>{}); + static constexpr index_t kColPerBlock_ = BlockTile::at(number<1>{}); static constexpr index_t kBlockSize = kBlockSize_; // warps per block - static constexpr index_t kLeadNumWarps = kRowWarps_; - static constexpr index_t kSecondNumWarps = kColWarps_; + static constexpr index_t kLeadNumWarps = kColWarps_; + static constexpr index_t kSecondNumWarps = kRowWarps_; - static constexpr index_t kLeadSizePerBlock = kRowPerBlock_; - static constexpr index_t kSecondSizePerBlock = kColPerBlock_; + static constexpr index_t kLeadSizePerBlock = kColPerBlock_; + static constexpr index_t kSecondSizePerBlock = kRowPerBlock_; static constexpr index_t kQuadrantLeadDim = LaneGroupTransposeTraits::kleadDim; static constexpr index_t kQuadrantSecondDim = LaneGroupTransposeTraits::ksecondDim; @@ -60,8 +60,8 @@ struct BatchedTransposeLdsProblem static constexpr bool kPadM = kPadM_; static constexpr bool kPadN = kPadN_; - static constexpr auto kMPerBlock = kLeadSizePerBlock; - static constexpr auto kNPerBlock = kSecondSizePerBlock; + static constexpr auto kMPerBlock = kSecondSizePerBlock; + static constexpr auto kNPerBlock = kLeadSizePerBlock; // 128-bit is the max single-instruction bandwidth for load/store static constexpr index_t MaxLoadStoreSize = 16; diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp index 5238fecdc5..e6bbc709ea 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp @@ -19,8 +19,8 @@ struct BatchedTransposePolicy : public BatchedTransposeCommonPolicy constexpr index_t VecLoadSize = Problem::VectorSizeOutput; using TileEncodingPattern = TileDistributionEncodingPattern2D; return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); diff --git a/test/ck_tile/batched_transpose/test_batched_transpose.cpp b/test/ck_tile/batched_transpose/test_batched_transpose.cpp index cce00e27cb..77d5825eed 100644 --- a/test/ck_tile/batched_transpose/test_batched_transpose.cpp +++ b/test/ck_tile/batched_transpose/test_batched_transpose.cpp @@ -95,10 +95,12 @@ class TestCkTileBatchedTranspose // N C H W layout_in== ck_tile::HostTensor y_ref(Y_dim, Y_stride); ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); + ck_tile::FillConstant{-37}(y_host); ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_dev(y_host.get_element_space_size_in_bytes()); x_dev.ToDevice(x_host.data()); + y_dev.ToDevice(y_host.data()); using Kernel = typename Config::Kernel; @@ -131,8 +133,8 @@ class TestCkTileBatchedTranspose // N C H W layout_in== height, width, height * width, - Config::BlockTile::at(1), - Config::BlockTile::at(0)}; + Config::BlockTile::at(0), + Config::BlockTile::at(1)}; auto kargs = Kernel::MakeKargs(host_args); auto sc = ck_tile::stream_config{}; @@ -140,15 +142,24 @@ class TestCkTileBatchedTranspose // N C H W layout_in== constexpr dim3 block_size = Kernel::BlockSize(); ck_tile::launch_kernel( sc, ck_tile::make_kernel(Kernel{}, grid_size, block_size, 0, kargs)); + y_dev.FromDevice(y_host.data()); ck_tile::reference_batched_transpose(x_host, y_ref, layout_in, layout_out); std::ostringstream message; message << "N=" << N << " C=" << C << " H=" << H << " W=" << W << " layout_in=" << layout_in - << " layout_out=" << layout_out << " device_name=" << device_name; + << " layout_out=" << layout_out << " grid_size={" << grid_size.x << ", " + << grid_size.y << ", " << grid_size.z << "} block_size=" << block_size.x + << " device_name=" << device_name; + // NB: order of output and reference matters bool pass = ck_tile::check_err( - y_ref, y_host, message.str(), /* rtol */ 0, /* atol */ 0, /* allow inf */ false); + /* out */ y_host, + /* ref */ y_ref, + message.str(), + /* rtol */ 0, + /* atol */ 0, + /* allow inf */ false); EXPECT_TRUE(pass); } @@ -160,14 +171,16 @@ static const auto kTestingValues = ::testing::Values( // N C H W layout_in==NCHW std::tuple{1, 32, 1, 32, true}, std::tuple{1, 64, 1, 64, true}, + std::tuple{1, 32, 1, 64, true}, + std::tuple{1, 64, 1, 32, true}, std::tuple{2, 12, 1, 32, false}, std::tuple{3, 1334, 1, 37, false}, std::tuple{4, 27, 1, 32, true}, std::tuple{5, 1234, 1, 12, true}, std::tuple{1, 1, 1, 1, true}, std::tuple{1, 1, 1, 1, false}, - std::tuple{128, 1024, 64, 64, true}, - std::tuple{128, 1024, 64, 64, false}, + std::tuple{17, 1024, 64, 64, true}, + std::tuple{17, 1024, 64, 64, false}, std::tuple{16, 64, 32, 128, true}, std::tuple{16, 64, 128, 32, false}, std::tuple{1, 2048, 1, 1, true}, @@ -239,6 +252,60 @@ class CaseHalfPadMultiWarpLoadTranspose { }; +class CaseHalfPadMultiWarp128MNLoadTranspose + : public TestCkTileBatchedTranspose> +{ +}; + +class CaseHalfPadMultiWarp128MN + : public TestCkTileBatchedTranspose< + PipelineConfig> +{ +}; + +class CaseHalfPadRectTile1 + : public TestCkTileBatchedTranspose< + PipelineConfig> +{ +}; + +class CaseHalfPadRectTile2 + : public TestCkTileBatchedTranspose< + PipelineConfig> +{ +}; + +class CaseHalfPadRectTile1LoadTranspose + : public TestCkTileBatchedTranspose> +{ +}; + +class CaseHalfPadRectTile2LoadTranspose + : public TestCkTileBatchedTranspose> +{ +}; + TEST_P(CaseHalf, TestCorrectness) { this->Run(GetParam()); } TEST_P(CaseByte, TestCorrectness) { this->Run(GetParam()); } TEST_P(CaseWord, TestCorrectness) { this->Run(GetParam()); } @@ -248,6 +315,12 @@ TEST_P(CaseHalfPad, TestCorrectness) { this->Run(GetParam()); } TEST_P(CaseHalfPadLoadTranspose, TestCorrectness) { this->Run(GetParam()); } TEST_P(CaseHalfPadMultiWarp, TestCorrectness) { this->Run(GetParam()); } TEST_P(CaseHalfPadMultiWarpLoadTranspose, TestCorrectness) { this->Run(GetParam()); } +TEST_P(CaseHalfPadMultiWarp128MN, TestCorrectness) { this->Run(GetParam()); } +TEST_P(CaseHalfPadMultiWarp128MNLoadTranspose, TestCorrectness) { this->Run(GetParam()); } +TEST_P(CaseHalfPadRectTile1, TestCorrectness) { this->Run(GetParam()); } +TEST_P(CaseHalfPadRectTile1LoadTranspose, TestCorrectness) { this->Run(GetParam()); } +TEST_P(CaseHalfPadRectTile2, TestCorrectness) { this->Run(GetParam()); } +TEST_P(CaseHalfPadRectTile2LoadTranspose, TestCorrectness) { this->Run(GetParam()); } // clang-format off INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalf, kTestingValues); @@ -259,4 +332,11 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPad, kTestingV INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPadLoadTranspose, kTestingValues); INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPadMultiWarp, kTestingValues); INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPadMultiWarpLoadTranspose, kTestingValues); +INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPadMultiWarp128MN, kTestingValues); +INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPadMultiWarp128MNLoadTranspose, kTestingValues); +INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPadRectTile1, kTestingValues); +INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPadRectTile1LoadTranspose, kTestingValues); +INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPadRectTile2, kTestingValues); +INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPadRectTile2LoadTranspose, kTestingValues); + // clang-format on