// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.

// GoogleTest typed tests for the ck_tile TDM copy kernel: each test fills an input tensor on the
// host, launches the copy kernel (optionally with a cluster launch attribute and/or a row-gather
// index buffer) and compares the device output against a host-built reference.
//
// NOTE(review): in this copy of the file the contents of template parameter/argument lists
// (`<...>`) and the names of the first three included headers appear to have been stripped.
// The code tokens below are kept exactly as found; restore the missing pieces from the
// upstream file before attempting to build.

// NOTE(review): the three header names below are missing in this copy.
#include
#include
#include

#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/tdm.hpp"
#include "ck_tile/host/kernel_launch.hpp"

namespace ck_tile {
namespace test {

// Short aliases for the element types and tensor layouts used by the test type list.
using F16 = half_t;
using F8  = fp8_t;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;

// Tag types selecting gather mode on/off and the gather index width.
// NOTE(review): the template arguments of bool_constant / constant are not visible here.
using GatherModeEnable  = bool_constant;
using GatherModeDisable = bool_constant;
using Gather16bitIndex  = constant;
using Gather32bitIndex  = constant;

// Runtime parameters for a single test case.
struct TDMTestParams
{
    index_t m         = 16; // problem size, first dimension
    index_t n         = 16; // problem size, second dimension
    index_t x_stride  = -1; // input leading stride; negative means "not set yet"
    index_t y_stride  = -1; // output leading stride; negative means "not set yet"
    int do_validation = 1;  // non-zero: compare the device result against the host reference
    int warmup        = 0;  // forwarded to stream_config in launch_tdm_kernel
    int repeat        = 1;  // forwarded to stream_config in launch_tdm_kernel

    // Replaces any still-negative stride with a default: n when the is_same_v condition holds,
    // m otherwise. Strides that were set explicitly (>= 0) are left untouched.
    // NOTE(review): the operands of is_same_v are not visible; presumably this compares the
    // layout template parameter against the row-major layout -- confirm.
    template
    void normalize()
    {
        if constexpr(std::is_same_v)
        {
            if(x_stride < 0)
                x_stride = n;
            if(y_stride < 0)
                y_stride = n;
        }
        else
        {
            if(x_stride < 0)
                x_stride = m;
            if(y_stride < 0)
                y_stride = m;
        }
    }
};

// Type combinations the typed test suite is instantiated for (eight std::tuple entries).
// NOTE(review): the tuple element lists are not visible in this copy.
using TestTypes = ::testing::Types,
                                   std::tuple,
                                   std::tuple,
                                   std::tuple,
                                   std::tuple,
                                   std::tuple,
                                   std::tuple,
                                   std::tuple>;

// Typed test fixture. TypeParam is one of the tuples in TestTypes; element 0 is the data type
// and element 1 is the layout.
template
class TDMBasicTypedTest : public ::testing::Test
{
    protected:
    using DataType = std::tuple_element_t<0, TypeParam>;
    using Layout   = std::tuple_element_t<1, TypeParam>;
    // Gather mode is enabled when the tested `::value` equals 3.
    // NOTE(review): the operand is not visible; presumably the tuple size of TypeParam, i.e.
    // a third tuple element carries the gather index width -- confirm.
    using GatherMode = std::
        conditional_t::value == 3, GatherModeEnable, GatherModeDisable>;

    // Maps the gather configuration to the integer type used for gather indices.
    // Primary template: placeholder type used when gather mode is off.
    template
    struct GatherModeDTypeHelper
    {
        using type = uint16_t; // dummy data type when gather mode is disabled
    };

    // Specialization: uint16_t when the compared value equals
    // TDMGatherIndexSize::Row16bit_Index, otherwise uint32_t.
    template
    struct GatherModeDTypeHelper
    {
        using type = std::conditional_t{}() == TDMGatherIndexSize::Row16bit_Index,
                                        uint16_t,
                                        uint32_t>;
    };
    using GatherModeDType = GatherModeDTypeHelper>::type;

    // Tile configuration: one 16x16 tile handled by a 1x1 arrangement of warps.
    static constexpr index_t tensor_rank = 2;
    static constexpr index_t tile_m      = 16;
    static constexpr index_t tile_n      = 16;
    static constexpr index_t warp_m      = 1;
    static constexpr index_t warp_n      = 1;
    static constexpr index_t warp_tile_m = 16;
    static constexpr index_t warp_tile_n = 16;

    // Common type definitions
    using TDMShape = TDMTileShape, sequence, sequence>;

    // Constants
    static constexpr index_t warp_size     = 32; // launch_tdm_kernel asserts is_wave32()
    static constexpr index_t cluster_dim_x = 2;  // cluster dimensions used for cluster launches
    static constexpr index_t cluster_dim_y = 1;
    static constexpr index_t cluster_dim_z = 1;

    private:
    // Helper functions

    // Returns the two host tensor dimensions for a test case: {n, m} when this is not a cluster
    // test and the is_same_v condition holds, {m, n} in every other case (cluster tests always
    // get {m, n}).
    // NOTE(review): the operands of is_same_v are not visible; presumably a layout check.
    static std::vector get_tensor_dims(const TDMTestParams& params, bool is_cluster_test)
    {
        return (!is_cluster_test && std::is_same_v) ? std::vector{params.n, params.m}
                                                    : std::vector{params.m, params.n};
    }

    // Builds the TDMPipelineTraits type for a given gather/cluster combination. Only the gather
    // and cluster flags vary; the remaining boolean options are fixed to false.
    template
    struct TDMTraitsFactory
    {
        using type = TDMPipelineTraits<
            DataType,
            std::conditional_t,
            GatherModeDType,
            false,         /*AtomicBarrierEnable_*/
            IsGatherMode,  /*IsGatherMode_*/
            false,         /*IterateEnable_*/
            false,         /*PadEnable_*/
            false,         /*EarlyTimeOutEnable_*/
            IsClusterMode /*ClusterEnable_*/>;
    };

    // Host tensors and device buffers for one test run. The constructor also fills the input,
    // builds the expected result (ref_host) and uploads the inputs to the device.
    struct TDMTestData
    {
        HostTensor x_host;            // input data
        HostTensor y_host;            // receives the device output
        HostTensor ref_host;          // expected output, computed on the host
        HostTensor gather_index_host; // row permutation; empty when gather is off

        DeviceMem x_buf;
        DeviceMem y_buf;
        DeviceMem gather_index_buf; // sized 0 when gather is off

        // dims:        {rows, cols} of the host tensors (see get_tensor_dims).
        // params:      supplies the leading strides of x (x_stride) and of y/ref (y_stride).
        // use_cluster: build the reference expected from a cluster launch.
        // use_gather:  create a shuffled row index and build a gathered reference.
        TDMTestData(const std::vector& dims,
                    const TDMTestParams& params,
                    bool use_cluster,
                    bool use_gather)
            : x_host({dims[0], dims[1]}, {params.x_stride, 1}),
              y_host({dims[0], dims[1]}, {params.y_stride, 1}),
              ref_host({dims[0], dims[1]}, {params.y_stride, 1}),
              gather_index_host(use_gather ? std::vector{warp_tile_m} : std::vector{}),
              x_buf(x_host.get_element_space_size_in_bytes()),
              y_buf(y_host.get_element_space_size_in_bytes()),
              gather_index_buf(use_gather ? gather_index_host.get_element_space_size_in_bytes()
                                          : 0)
        {
            // Fill the input; the filler is constructed with -.5f and .5f
            // (presumably the bounds of the uniform range).
            FillUniformDistribution{-.5f, .5f}(x_host);

            if(use_gather)
            {
                // Start from the identity permutation 0 .. warp_tile_m-1 ...
                for(index_t i = 0; i < warp_tile_m; i++)
                {
                    gather_index_host.data()[i] = static_cast(i);
                }
                // ... and shuffle it. The generator is seeded from std::random_device, so the
                // permutation differs from run to run.
                std::shuffle(gather_index_host.begin(),
                             gather_index_host.end(),
                             std::mt19937{std::random_device{}()});
                gather_index_buf.ToDevice(gather_index_host.data());

                // Reference: within every group of warp_tile_m rows starting at r, output row
                // (r + inner_r) is input row (r + gather_index[inner_r]).
                for(index_t r = 0; r < dims[0]; r += warp_tile_m)
                {
                    for(index_t inner_r = 0; inner_r < warp_tile_m; inner_r++)
                    {
                        // NOTE(review): ref_idx is re-declared as 0 on every inner_r iteration,
                        // so it never changes the row index and the increment below has no
                        // observable effect.
                        index_t ref_idx = 0;
                        index_t gather_idx =
                            static_cast(gather_index_host(static_cast(inner_r)));
                        for(index_t c = 0; c < dims[1]; c++)
                        {
                            ref_host({static_cast(r + inner_r + ref_idx), static_cast(c)}) =
                                x_host(
                                    {static_cast(r + gather_idx), static_cast(c)});
                        }
                        ref_idx++;
                    }
                }
            }
            else
            {
                // No gather: the expected output is an element-wise copy of the input.
                for(index_t r = 0; r < dims[0]; r += 1)
                {
                    for(index_t c = 0; c < dims[1]; c += 1)
                    {
                        ref_host({static_cast(r), static_cast(c)}) =
                            x_host({static_cast(r), static_cast(c)});
                    }
                }
            }

            if(use_cluster)
            {
                // for sanity test; only copy the first half data.
                // Rows in the second half of the reference repeat the matching row of the
                // first half (r - dims[0] / 2); first-half rows are copied unchanged.
                // NOTE(review): this loop rewrites every element of ref_host, so it also
                // replaces whatever the gather branch above produced -- confirm this is the
                // intended expectation for cluster + gather runs.
                for(index_t r = 0; r < dims[0]; r += 1)
                {
                    for(index_t c = 0; c < dims[1]; c += 1)
                    {
                        ref_host({static_cast(r), static_cast(c)}) =
                            r >= dims[0] / 2
                                ? x_host({static_cast(r - dims[0] / 2), static_cast(c)})
                                : x_host({static_cast(r), static_cast(c)});
                    }
                }
            }

            // Upload the input and clear the output buffer before the kernel runs.
            x_buf.ToDevice(x_host.data());
            y_buf.SetZero();
        }
    };

    // Launches the TDM copy kernel over test_data and copies the device output back into
    // test_data.y_host.
    //
    // use_cluster: launch through hipLaunchKernelEx with a cluster-dimension attribute instead
    //              of the regular launch_kernel path.
    // use_gather:  pass the gather index device buffer to the kernel (nullptr otherwise).
    //
    // Always returns true; launch failures on the cluster path are handled by HIP_CHECK_ERROR.
    template
    bool launch_tdm_kernel(TDMTestData& test_data,
                           const TDMTestParams& params,
                           bool use_cluster = false,
                           bool use_gather  = true)
    {
        // One block per tile; sizes are rounded up to whole tiles.
        dim3 grid((params.m + tile_m - 1) / tile_m, (params.n + tile_n - 1) / tile_n);

        // block_size below is computed with warp_size = 32.
        assert(is_wave32());
        const index_t block_size = warp_m * warp_n * warp_size;
        dim3 block(block_size);

        stream_config s{nullptr, false, 0, params.warmup, params.repeat};

        // Determine gather pointer based on usage
        void* gather_ptr = use_gather ? test_data.gather_index_buf.GetDeviceBuffer() : nullptr;

        TDMCopyDeviceKernArgs args{test_data.x_buf.GetDeviceBuffer(),
                                   test_data.y_buf.GetDeviceBuffer(),
                                   gather_ptr,
                                   params.m,
                                   params.n,
                                   params.x_stride,
                                   params.y_stride};

        if(use_cluster)
        {
            // Cluster launch: same grid/block, no dynamic shared memory, plus one launch
            // attribute carrying the cluster dimensions.
            hipLaunchConfig_t config{};
            config.gridDim          = grid;
            config.blockDim         = block;
            config.dynamicSmemBytes = 0;
            config.stream           = s.stream_id_;

            hipLaunchAttribute attribute[1];
            attribute[0].id               = hipLaunchAttributeClusterDimension;
            attribute[0].val.clusterDim.x = cluster_dim_x;
            attribute[0].val.clusterDim.y = cluster_dim_y;
            attribute[0].val.clusterDim.z = cluster_dim_z;

            config.attrs    = attribute;
            config.numAttrs = 1;

            auto kernel_func = kentry, TDMCopyDeviceKernArgs>;
            HIP_CHECK_ERROR(hipLaunchKernelEx(&config, kernel_func, args));
        }
        else
        {
            TDMCopyKernel tdm_kernel;
            launch_kernel(s, make_kernel(tdm_kernel, grid, block, 0, args));
        }

        // Bring the kernel output back to the host for validation.
        test_data.y_buf.FromDevice(test_data.y_host.data());
        return true;
    }

    // Compares the device output (y_host) with the host reference (ref_host) via check_err and
    // returns its result.
    bool validate_results(TDMTestData& test_data) const
    {
        return check_err(
            test_data.y_host, test_data.ref_host, "Error: Incorrect tdm copy results!");
    }

    // Shared driver: prepares the data, launches the kernel and, when params.do_validation is
    // non-zero, returns the validation result. Returns true when validation is disabled.
    template
    bool run_tdm_test_generic(const TDMTestParams& params)
    {
        const std::vector dims = get_tensor_dims(params, IsClusterMode);
        TDMTestData test_data(dims, params, IsClusterMode, IsGatherMode);

        using TDMTraits  = typename TDMTraitsFactory::type;
        using TDMProblem = TDMPipelineProblem;

        launch_tdm_kernel(test_data, params, IsClusterMode, IsGatherMode);

        if(params.do_validation)
        {
            return validate_results(test_data);
        }
        return true;
    }

    public:
    // Runs a test through run_tdm_test_generic.
    // NOTE(review): the template arguments passed on are not visible in this copy.
    bool run_tdm_test(const TDMTestParams& params)
    {
        return run_tdm_test_generic>(params);
    }

    // Runs a test through run_tdm_test_generic; used by the cluster test cases.
    // NOTE(review): the template parameters and the arguments passed on are not visible here.
    template
    bool run_tdm_cluster_test(const TDMTestParams& params)
    {
        return run_tdm_test_generic(params);
    }
};

TYPED_TEST_SUITE(TDMBasicTypedTest, TestTypes);

// Smallest case: 16 x 16.
TYPED_TEST(TDMBasicTypedTest, SanityTest)
{
    TDMTestParams params;
    params.m = 16;
    params.n = 16;
    params.template normalize();

    EXPECT_TRUE(this->run_tdm_test(params));
}

// Cluster launch on a 32 x 16 problem. Skipped for the type combination selected by the
// is_same_v check (its operands are not visible in this copy).
TYPED_TEST(TDMBasicTypedTest, SanityClusterTest)
{
    TDMTestParams params;
    params.m = 32;
    params.n = 16;
    if constexpr(std::is_same_v)
    {
        GTEST_SKIP();
    }
    params.template normalize();

    EXPECT_TRUE(this->run_tdm_cluster_test(params));
}

// Cluster launch variant on a 32 x 16 problem, with the same skip condition as
// SanityClusterTest. NOTE(review): the template argument given to run_tdm_cluster_test is not
// visible; judging by the test name it presumably enables gather -- confirm.
TYPED_TEST(TDMBasicTypedTest, SanityClusterGatherTest)
{
    TDMTestParams params;
    params.m = 32;
    params.n = 16;
    if constexpr(std::is_same_v)
    {
        GTEST_SKIP();
    }
    params.template normalize();

    EXPECT_TRUE(this->template run_tdm_cluster_test(params));
}

// Non-square case: 64 x 32 (several tiles in each dimension).
TYPED_TEST(TDMBasicTypedTest, RectangleTest)
{
    TDMTestParams params;
    params.m = 64;
    params.n = 32;
    params.template normalize();

    EXPECT_TRUE(this->run_tdm_test(params));
}

// Larger case: 256 x 256.
TYPED_TEST(TDMBasicTypedTest, LargeDimTest)
{
    TDMTestParams params;
    params.m = 256;
    params.n = 256;
    params.template normalize();

    EXPECT_TRUE(this->run_tdm_test(params));
}

} // namespace test
} // namespace ck_tile

// Test binary entry point: initializes GoogleTest and runs every registered test.
int main(int argc, char** argv)
{
    ::testing::InitGoogleTest(&argc, argv);
    return RUN_ALL_TESTS();
}