// NOTE(review): this copy of the file appears damaged and cannot compile as shown.
//   - Line breaks are collapsed, so each leading `//` comment swallows the code that
//     follows it on the same line.
//   - Template argument lists and `#include <...>` targets seem to have been stripped,
//     e.g. `#include #include`, `template struct PipelineSelector`,
//     `ck_tile::HostTensor x_host(...)`, `std::array{...}` without an element type,
//     and the empty `PipelineConfig` arguments of every Case* fixture.
//   Restore the file from version control rather than re-inventing the missing
//   arguments by hand.
//
// Purpose: GoogleTest value-parameterized correctness tests for the ck_tile
// batched-transpose kernel (from "ck_tile/ops/batched_transpose.hpp").
//
// Contents:
//  - `PipelineTag` names the two pipelines under test: `Universal` and `LDSLoadTranspose`.
//  - `PipelineSelector<tag>` maps a tag to Problem / Policy / Pipeline aliases:
//      Universal        -> ck_tile::BatchedTransposeProblem / BatchedTransposePolicy /
//                          BatchedTransposePipeline
//      LDSLoadTranspose -> ck_tile::BatchedTransposeLdsProblem / BatchedTransposeLdsPolicy /
//                          BatchedTransposeLdsPipeline
//  - `PipelineConfig` bundles the configuration of one test case:
//      - the data type, `BlockTile` and `WarpLayout` sequences;
//      - the padding flags kPadM and kPadN, and the pipeline tag;
//      - the block sizes kBlockX and kBlockY, and the warp counts kNumWarpsX and kNumWarpsY.
//    From these it derives `Problem`, `Pipeline` and `Kernel` (ck_tile::BatchedTransposeKernel).
//  - `TestCkTileBatchedTranspose<Config>::Run(param)` takes the tuple (N, C, H, W, nchw2nhwc).
//    nchw2nhwc == true checks NCHW -> NHWC; false checks NHWC -> NCHW. In order, it:
//      1. builds host tensors x (input), y (kernel output) and y_ref (reference)
//         with explicit packed strides for the chosen layouts;
//      2. fills x uniformly in [-0.5, 0.5] and y with the constant -37, then uploads both
//         to device buffers (the -37 fill is presumably a sentinel, so that elements the
//         kernel never writes fail the comparison; confirm);
//      3. sets height = C and width = H*W for NCHW input, or the reverse for NHWC input;
//      4. skips the test in three cases:
//           - height % kBlockX != 0 and kPadM is false;
//           - width % kBlockY != 0 and kPadN is false;
//           - the LDSLoadTranspose pipeline is requested on a device whose name
//             does not contain "gfx950";
//      5. passes N, height, width and height*width (presumably the per-batch stride)
//         plus the two BlockTile extents to ck_tile::BatchedTransposeHostArgs;
//      6. launches via ck_tile::launch_kernel with a default stream_config and copies y back;
//      7. compares y against ck_tile::reference_batched_transpose with rtol = 0,
//         atol = 0 (exact match) and allow_inf = false. The failure message records
//         the shape, layouts, grid/block sizes and device name.
//  - `kTestingValues` lists 22 (N, C, H, W, nchw2nhwc) shapes. They include odd sizes
//    (e.g. C=1334 with W=37) and degenerate 1x1x1x1 cases. Every Case* fixture is
//    instantiated with the full list.
//  - The 15 Case* classes specialize the fixture with different PipelineConfig arguments.
//    Their names suggest Half/Byte/Word data types, padding, multi-warp, 128MN and
//    rectangular tiles, with and without LDS load-transpose; the actual arguments are
//    not visible in this copy. Each class gets one TEST_P (TestCorrectness) and one
//    INSTANTIATE_TEST_SUITE_P.
//
// NOTE(review): `GTEST_SKIP_(msg)` (trailing underscore) is GoogleTest's internal macro.
//   The public, documented form is `GTEST_SKIP() << msg;`.
// NOTE(review): Run() performs its skip checks only after the host tensors are filled and
//   the device buffers are allocated and uploaded. Hoisting the checks to the top of Run()
//   would avoid that work for skipped cases.
// SPDX-License-Identifier: MIT // Copyright (c) Advanced Micro Devices, Inc. All rights reserved. #include #include #include "ck_tile/host.hpp" #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/batched_transpose.hpp" enum class PipelineTag : ck_tile::index_t { Universal, LDSLoadTranspose, }; template struct PipelineSelector { }; template <> struct PipelineSelector { template using Problem = ck_tile::BatchedTransposeProblem; using Policy = ck_tile::BatchedTransposePolicy; template using Pipeline = ck_tile::BatchedTransposePipeline; }; template <> struct PipelineSelector { template using Problem = ck_tile::BatchedTransposeLdsProblem; using Policy = ck_tile::BatchedTransposeLdsPolicy; template using Pipeline = ck_tile::BatchedTransposeLdsPipeline; }; template struct PipelineConfig { using DataType = DataType_; using BlockTile = ck_tile::sequence; using WarpLayout = ck_tile::sequence; static constexpr bool kPadM = kPadM_; static constexpr bool kPadN = kPadN_; static constexpr PipelineTag kPipelineId = kPipelineId_; static constexpr ck_tile::index_t kBlockX = kBlockX_; static constexpr ck_tile::index_t kBlockY = kBlockY_; static constexpr ck_tile::index_t kNumWarpsX = kNumWarpsX_; static constexpr ck_tile::index_t kNumWarpsY = kNumWarpsY_; using Problem = typename PipelineSelector< kPipelineId_>::template Problem; using Pipeline = typename PipelineSelector::template Pipeline; using Kernel = ck_tile::BatchedTransposeKernel; }; template class TestCkTileBatchedTranspose // N C H W layout_in==NCHW : public ::testing::TestWithParam> { protected: void Run(std::tuple param) { using DataType = typename Config::DataType; const auto [N, C, H, W, nchw2nhwc] = param; const std::string layout_in = nchw2nhwc ? "NCHW" : "NHWC"; const std::string layout_out = nchw2nhwc ? "NHWC" : "NCHW"; const auto X_dim = nchw2nhwc ? std::array{N, C, H, W} : std::array{N, H, W, C}; const auto X_stride = nchw2nhwc ? 
std::array{C * H * W, H * W, W, 1} : std::array{C * H * W, C * W, C, 1}; ck_tile::HostTensor x_host(X_dim, X_stride); const auto Y_dim = nchw2nhwc ? std::array{N, H, W, C} : std::array{N, C, H, W}; const auto Y_stride = nchw2nhwc ? std::array{C * H * W, C * W, C, 1} : std::array{C * H * W, H * W, W, 1}; ck_tile::HostTensor y_host(Y_dim, Y_stride); ck_tile::HostTensor y_ref(Y_dim, Y_stride); ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); ck_tile::FillConstant{-37}(y_host); ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_dev(y_host.get_element_space_size_in_bytes()); x_dev.ToDevice(x_host.data()); y_dev.ToDevice(y_host.data()); using Kernel = typename Config::Kernel; const ck_tile::index_t height = nchw2nhwc ? C : H * W; const ck_tile::index_t width = nchw2nhwc ? H * W : C; if(height % Config::kBlockX != 0 && !Config::kPadM) { GTEST_SKIP_("Input cannot be covered with block tiles and Kernel does not force height " "padding"); } if(width % Config::kBlockY != 0 && !Config::kPadN) { GTEST_SKIP_( "Input cannot be covered with block tiles and Kernel does not force width padding"); } const auto device_name = ck_tile::get_device_name(); if(Config::kPipelineId == PipelineTag::LDSLoadTranspose && device_name.find("gfx950") == std::string::npos) { GTEST_SKIP_("LDS Load Transpose cannot be launched with this device"); } const auto host_args = ck_tile::BatchedTransposeHostArgs{x_dev.GetDeviceBuffer(), y_dev.GetDeviceBuffer(), N, height, width, height * width, Config::BlockTile::at(0), Config::BlockTile::at(1)}; auto kargs = Kernel::MakeKargs(host_args); auto sc = ck_tile::stream_config{}; const dim3 grid_size = Kernel::GridSize(host_args); const dim3 block_size = Kernel::BlockSize(); ck_tile::launch_kernel(sc, ck_tile::make_kernel<1>(Kernel{}, grid_size, block_size, 0, kargs)); y_dev.FromDevice(y_host.data()); ck_tile::reference_batched_transpose(x_host, y_ref, layout_in, layout_out); std::ostringstream message; message << "N=" << N 
<< " C=" << C << " H=" << H << " W=" << W << " layout_in=" << layout_in << " layout_out=" << layout_out << " grid_size={" << grid_size.x << ", " << grid_size.y << ", " << grid_size.z << "} block_size=" << block_size.x << " device_name=" << device_name; // NB: order of output and reference matters bool pass = ck_tile::check_err( /* out */ y_host, /* ref */ y_ref, message.str(), /* rtol */ 0, /* atol */ 0, /* allow inf */ false); EXPECT_TRUE(pass); } }; // clang-format off // the default indent is not sane static const auto kTestingValues = ::testing::Values( // N C H W layout_in==NCHW std::tuple{1, 32, 1, 32, true}, std::tuple{1, 64, 1, 64, true}, std::tuple{1, 32, 1, 64, true}, std::tuple{1, 64, 1, 32, true}, std::tuple{2, 12, 1, 32, false}, std::tuple{3, 1334, 1, 37, false}, std::tuple{4, 27, 1, 32, true}, std::tuple{5, 1234, 1, 12, true}, std::tuple{1, 1, 1, 1, true}, std::tuple{1, 1, 1, 1, false}, std::tuple{17, 1024, 64, 64, true}, std::tuple{17, 1024, 64, 64, false}, std::tuple{16, 64, 32, 128, true}, std::tuple{16, 64, 128, 32, false}, std::tuple{1, 2048, 1, 1, true}, std::tuple{1, 2048, 1, 1, false}, std::tuple{1, 1, 1024, 1024, true}, std::tuple{1, 1, 1024, 1024, false}, std::tuple{8, 16, 8, 16, true}, std::tuple{8, 16, 8, 16, false}, std::tuple{1, 64, 1, 1024, true}, std::tuple{1, 64, 1024, 1, false} ); // clang-format on class CaseHalf : public TestCkTileBatchedTranspose> { }; class CaseByte : public TestCkTileBatchedTranspose> { }; class CaseWord : public TestCkTileBatchedTranspose> { }; class CaseHalfLoadTranspose : public TestCkTileBatchedTranspose< PipelineConfig> { }; class CaseByteLoadTranspose : public TestCkTileBatchedTranspose< PipelineConfig> { }; class CaseHalfPad : public TestCkTileBatchedTranspose< PipelineConfig> { }; class CaseHalfPadLoadTranspose : public TestCkTileBatchedTranspose> { }; class CaseHalfPadMultiWarp : public TestCkTileBatchedTranspose< PipelineConfig> { }; class CaseHalfPadMultiWarpLoadTranspose : public 
TestCkTileBatchedTranspose> { }; class CaseHalfPadMultiWarp128MNLoadTranspose : public TestCkTileBatchedTranspose> { }; class CaseHalfPadMultiWarp128MN : public TestCkTileBatchedTranspose< PipelineConfig> { }; class CaseHalfPadRectTile1 : public TestCkTileBatchedTranspose< PipelineConfig> { }; class CaseHalfPadRectTile2 : public TestCkTileBatchedTranspose< PipelineConfig> { }; class CaseHalfPadRectTile1LoadTranspose : public TestCkTileBatchedTranspose> { }; class CaseHalfPadRectTile2LoadTranspose : public TestCkTileBatchedTranspose> { }; TEST_P(CaseHalf, TestCorrectness) { this->Run(GetParam()); } TEST_P(CaseByte, TestCorrectness) { this->Run(GetParam()); } TEST_P(CaseWord, TestCorrectness) { this->Run(GetParam()); } TEST_P(CaseHalfLoadTranspose, TestCorrectness) { this->Run(GetParam()); } TEST_P(CaseByteLoadTranspose, TestCorrectness) { this->Run(GetParam()); } TEST_P(CaseHalfPad, TestCorrectness) { this->Run(GetParam()); } TEST_P(CaseHalfPadLoadTranspose, TestCorrectness) { this->Run(GetParam()); } TEST_P(CaseHalfPadMultiWarp, TestCorrectness) { this->Run(GetParam()); } TEST_P(CaseHalfPadMultiWarpLoadTranspose, TestCorrectness) { this->Run(GetParam()); } TEST_P(CaseHalfPadMultiWarp128MN, TestCorrectness) { this->Run(GetParam()); } TEST_P(CaseHalfPadMultiWarp128MNLoadTranspose, TestCorrectness) { this->Run(GetParam()); } TEST_P(CaseHalfPadRectTile1, TestCorrectness) { this->Run(GetParam()); } TEST_P(CaseHalfPadRectTile1LoadTranspose, TestCorrectness) { this->Run(GetParam()); } TEST_P(CaseHalfPadRectTile2, TestCorrectness) { this->Run(GetParam()); } TEST_P(CaseHalfPadRectTile2LoadTranspose, TestCorrectness) { this->Run(GetParam()); } // clang-format off INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalf, kTestingValues); INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseByte, kTestingValues); INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseWord, kTestingValues); INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, 
CaseHalfLoadTranspose, kTestingValues); INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseByteLoadTranspose, kTestingValues); INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPad, kTestingValues); INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPadLoadTranspose, kTestingValues); INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPadMultiWarp, kTestingValues); INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPadMultiWarpLoadTranspose, kTestingValues); INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPadMultiWarp128MN, kTestingValues); INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPadMultiWarp128MNLoadTranspose, kTestingValues); INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPadRectTile1, kTestingValues); INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPadRectTile1LoadTranspose, kTestingValues); INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPadRectTile2, kTestingValues); INSTANTIATE_TEST_SUITE_P(TestCkTileBatchedTransposeSuite, CaseHalfPadRectTile2LoadTranspose, kTestingValues); // clang-format on