From e3efa236ec7991d5060c28b4cf952d1ae2a15647 Mon Sep 17 00:00:00 2001 From: Damien Lejeune Date: Tue, 20 Jan 2026 11:42:10 -0500 Subject: [PATCH] Add the structure for testing the Sinkhorn-Knopp kernel --- .../sinkhorn_knopp_default_policy.hpp | 4 +- .../pipeline/sinkhorn_knopp_problem.hpp | 68 ++++++++++--- .../sinkhorn_knopp/sinkhorn_knopp_kernel.hpp | 84 ++++++++-------- test/ck_tile/CMakeLists.txt | 1 + test/ck_tile/sinkhorn/CMakeLists.txt | 8 ++ test/ck_tile/sinkhorn/test_sinkhorn.cpp | 42 ++++++++ test/ck_tile/sinkhorn/test_sinkhorn_impl.hpp | 97 +++++++++++++++++++ 7 files changed, 244 insertions(+), 60 deletions(-) create mode 100644 test/ck_tile/sinkhorn/CMakeLists.txt create mode 100644 test/ck_tile/sinkhorn/test_sinkhorn.cpp create mode 100644 test/ck_tile/sinkhorn/test_sinkhorn_impl.hpp diff --git a/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_default_policy.hpp b/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_default_policy.hpp index 2abc6377e3..81826004e0 100644 --- a/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_default_policy.hpp +++ b/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_default_policy.hpp @@ -36,8 +36,8 @@ struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy tile_distribution_encoding< sequence<>, tuple< - sequence>, - sequence, + sequence, + sequence>, tuple, sequence<1, 2>>, tuple, sequence<1, 1>>, sequence<1, 1, 2, 2>, diff --git a/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_problem.hpp b/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_problem.hpp index 5953ecd8ff..69f0bee3bc 100644 --- a/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_problem.hpp +++ b/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_problem.hpp @@ -4,24 +4,60 @@ namespace ck_tile { -template +// template +// struct SinkHornKnoppShape +// { +// static constexpr index_t Block_M = WarpPerBlock_M; +// static constexpr index_t Block_N = WarpPerBlock_N; +// static constexpr index_t ThreadPerWarp_M = ThreadPerWarp_M; +// static constexpr index_t ThreadPerWarp_N = ThreadPerWarp_N; +// static constexpr index_t ThreadTile_M = ThreadTile_M; +// static constexpr index_t ThreadTile_N = ThreadTile_N; +// static constexpr index_t Repeat_M = Repeat_M; +// static constexpr index_t Repeat_N = Repeat_N; +// }; + + +template + typename BlockTile, // block size, seq + typename WarpTile, // warp size, seq + typename ThreadTile> // contiguous pixels(vector size) along seq struct SinkHornKnoppShape { - static constexpr index_t Block_M = WarpPerBlock_M; - static constexpr index_t Block_N = WarpPerBlock_N; - static constexpr index_t ThreadPerWarp_M = ThreadPerWarp_M; - static constexpr index_t ThreadPerWarp_N = ThreadPerWarp_N; - static constexpr index_t ThreadTile_M = ThreadTile_M; - static constexpr index_t ThreadTile_N = ThreadTile_N; - static constexpr index_t Repeat_M = Repeat_M; - static constexpr index_t Repeat_N = Repeat_N; + static constexpr index_t Block_M = BlockTile::at(number<0>{}); + static constexpr index_t Block_N = BlockTile::at(number<1>{}); + + static constexpr index_t Warp_M = WarpTile::at(number<0>{}); + static constexpr index_t Warp_N = WarpTile::at(number<1>{}); + + static constexpr index_t ThreadTile_M = ThreadTile::at(number<0>{}); + static constexpr index_t ThreadTile_N = ThreadTile::at(number<1>{}); + + static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{}); + static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{}); + + static constexpr index_t RepeatInWarp = + Warp_M * Warp_N / ThreadTile_M / ThreadTile_N / ck_tile::get_warp_size(); + static constexpr index_t RepeatInWarp_M = + (Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? RepeatInWarp : 1; + static constexpr index_t RepeatInWarp_N = + (Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? 1 : RepeatInWarp; + + static constexpr index_t ThreadPerWarp_M = Warp_M / ThreadTile_M / RepeatInWarp_M; + static constexpr index_t ThreadPerWarp_N = Warp_N / ThreadTile_N / RepeatInWarp_N; + + static constexpr index_t Repeat_M = Block_M * RepeatInWarp_M / (WarpPerBlock_M * Warp_M); + static constexpr index_t Repeat_N = Block_N * RepeatInWarp_N / (WarpPerBlock_N * Warp_N); + + // static constexpr index_t BlockSize = ck_tile::get_warp_size(); + static constexpr index_t BlockSize = 1; // TODO }; template - CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const + CK_TILE_DEVICE void operator()([[maybe_unused]] const SinkhornKnoppArgs& args) const { - // Creating tensor descriptors, views and windows for inputs and outputs + // // Creating tensor descriptors, views and windows for inputs and outputs - // Create the reduce ops - // * Reduce Op ADD for row and column sums - // * Elementwise Op EXP for exponentiation + // // Create the reduce ops + // // * Reduce Op ADD for row and column sums + // // * Elementwise Op EXP for exponentiation - using ExponentiationOp = ElementwiseOp; - using AddOp = ElementwiseOp; - using DivideOp = ElementwiseOp; + // using ExponentiationOp = ElementwiseOp; + // using AddOp = ElementwiseOp; + // using DivideOp = ElementwiseOp; - using ReduceOp = ReduceOp; - // Run the first steps iteration of the Sinkhorn-Knopp algorithm - // Exponentiate the matrix x - auto x = load_tile(...); + // using ReduceOp = ReduceOp; + // // Run the first steps iteration of the Sinkhorn-Knopp algorithm + // // Exponentiate the matrix x + // auto x = load_tile(...); - // Hot loop for Sinkhorn-Knopp iterations from 1 to max_iterations - // Use BlockReduce2D for row and column sums - for(int i = 0; i <= args.max_iterations; i++) - { - // 0. LOAD x - // 1. Compute row sums (REDUCE) - // 2. Divide values by row sums (SWEEP) - // 3. STORE the result of the division (in transposed format) - // 4. LOAD transposed x - // 5. Compute column sums (REDUCE) - // 6. Divide values by column sums (SWEEP) - // 7. STORE the result of the division (in transposed format) - } + // // Hot loop for Sinkhorn-Knopp iterations from 1 to max_iterations + // // Use BlockReduce2D for row and column sums + // for(int i = 0; i <= args.max_iterations; i++) + // { + // // 0. LOAD x + // // 1. Compute row sums (REDUCE) + // // 2. Divide values by row sums (SWEEP) + // // 3. STORE the result of the division (in transposed format) + // // 4. LOAD transposed x + // // 5. Compute column sums (REDUCE) + // // 6. Divide values by column sums (SWEEP) + // // 7. STORE the result of the division (in transposed format) + // } } }; @@ -75,30 +75,30 @@ struct SinkhornKnoppKernelDummyNonStochastic const auto x_desc = make_naive_tensor_descriptor(make_tuple(args.input_m, args.input_m), make_tuple(args.input_m, 1), - number{}, + number<4>{}, // TODO: Hardcoded vectorization, we should calculate it! number<1>{}); - auto buffer_view = make_buffer_view( - args.p_x, desc.get_element_space_size(), number<0>{}); + // auto buffer_view = make_buffer_view( + // args.p_x, desc.get_element_space_size(), number<0>{}); - const auto x_tensor = - tensor_view{buffer_view, x_desc}; + // const auto x_tensor = + // tensor_view{buffer_view, x_desc}; - auto x_window = make_tile_window(x_tensor, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeXBlockTileDistribution()); + // auto x_window = make_tile_window(x_tensor, + // make_tuple(number{}, number{}), + // {0, 0}, + // Policy::template MakeXBlockTileDistribution()); - auto out_buffer_view = make_buffer_view( - args.out, x_desc.get_element_space_size(), number<0>{}); + // auto out_buffer_view = make_buffer_view( + // args.out, x_desc.get_element_space_size(), number<0>{}); - const auto y_tensor = - tensor_view{out_buffer_view, x_desc}; + // const auto y_tensor = + // tensor_view{out_buffer_view, x_desc}; - auto y_window = make_tile_window(y_tensor, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeXBlockTileDistribution()); + // auto y_window = make_tile_window(y_tensor, + // make_tuple(number{}, number{}), + // {0, 0}, + // Policy::template MakeXBlockTileDistribution()); } }; diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 93cd7fa063..02266162ca 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -40,3 +40,4 @@ add_subdirectory(fmha) add_subdirectory(gemm_tile_engine) add_subdirectory(pooling) add_subdirectory(grouped_conv) +add_subdirectory(sinkhorn) diff --git a/test/ck_tile/sinkhorn/CMakeLists.txt b/test/ck_tile/sinkhorn/CMakeLists.txt new file mode 100644 index 0000000000..0d4327ff72 --- /dev/null +++ b/test/ck_tile/sinkhorn/CMakeLists.txt @@ -0,0 +1,8 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") + add_gtest_executable(test_ck_tile_sinkhorn test_sinkhorn.cpp) + target_link_libraries(test_ck_tile_sinkhorn) +endif() + diff --git a/test/ck_tile/sinkhorn/test_sinkhorn.cpp b/test/ck_tile/sinkhorn/test_sinkhorn.cpp new file mode 100644 index 0000000000..a525050e3e --- /dev/null +++ b/test/ck_tile/sinkhorn/test_sinkhorn.cpp @@ -0,0 +1,42 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/sinkhorn_knopp.hpp" +#include "ck_tile/host/kernel_launch.hpp" + +#include "test_sinkhorn_impl.hpp" + +// Shape parameters for different test configurations +using Shape1_BlockWarps = ck_tile::sequence<4, 1>; +using Shape1_BlockTile = ck_tile::sequence<128, 128>; +using Shape1_WarpTile = ck_tile::sequence<32, 128>; +using Shape1_ThreadTile = ck_tile::sequence<8, 8>; + +// Test configurations for different data types and input size +using TestConfig_F16 = std::tuple< + ck_tile::half_t, // XDataType + float, // ComputeDataType + float, // YDataType + Shape1_BlockWarps, + Shape1_BlockTile, + Shape1_WarpTile, + Shape1_ThreadTile>; + + +using TestTypes = ::testing::Types; + +TYPED_TEST_SUITE(TestCkTileSinkHorn, TestTypes); + +TYPED_TEST(TestCkTileSinkHorn, Test_4x4) +{ + this->RunGenericTest({4, 4}, 10); +} diff --git a/test/ck_tile/sinkhorn/test_sinkhorn_impl.hpp b/test/ck_tile/sinkhorn/test_sinkhorn_impl.hpp new file mode 100644 index 0000000000..0fe5d9f356 --- /dev/null +++ b/test/ck_tile/sinkhorn/test_sinkhorn_impl.hpp @@ -0,0 +1,97 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/sinkhorn_knopp.hpp" +#include "ck_tile/host/kernel_launch.hpp" + +template +class TestCkTileSinkHorn: public ::testing::Test +{ + protected: + using XDataType = std::tuple_element_t<0, Tuple>; + using ComputeDataType = std::tuple_element_t<1, Tuple>; + using YDataType = std::tuple_element_t<2, Tuple>; + using BlockWarps_ = std::tuple_element_t<3, Tuple>; + using BlockTile_ = std::tuple_element_t<4, Tuple>; + using WarpTile_ = std::tuple_element_t<5, Tuple>; + using ThreadTile_ = std::tuple_element_t<6, Tuple>; + + using TestSinkhornShape = + ck_tile::SinkhornKnoppShape< + BlockWarps_, + BlockTile_, + WarpTile_, + ThreadTile_ + >; + + void RunGenericTest(const std::vector& input_shape, const int max_iterations) + { + + SinkhornKnoppArgs args{}; + args.input_m = static_cast(input_shape[0]); + args.max_iterations = max_iterations; + + auto default_stride = {args.input_m, 1}; + + ck_tile::HostTensor h_x(input_shape, default_stride); + ck_tile::HostTensor h_y(input_shape, default_stride); + + ck_tile::FillUniformDistribution{-5.f, 5.f}(h_x); + + auto buffer_size = h_xs.get_element_space_size_in_bytes(); + ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_y_mem(output_buffer_size); + + args.p_x = static_cast(d_x_mem.GetDeviceBuffer()); + args.out = static_cast(d_y_mem.GetDeviceBuffer()); + + d_x_mem.ToDevice(h_x.data()); + d_y_mem.ToDevice(h_y.data()); + + using Problem = ck_tile::SinkhornKnoppProblem; + using Kernel = ck_tile::SinkhornKnoppKernelDummyNonStochastic< + Problem, + ck_tile::SinkhornKnoppPolicy>; + + // Launch configuration + const ck_tile::index_t kBlockSize = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + + ck_tile::index_t kGridSize = 1 // TODO + + //TODO + // if(!Kernel::IsSupportedArgument()) + // { + // throw std::runtime_error("Wrong! Arguments not supported!\n"); + // } + + ck_tile::launch_kernel( + ck_tile::stream_config{nullptr, false, 0}, + ck_tile::make_kernel(Kernel{}, + kGridSize, + kBlockSize, + 0, + args)); + + // Reference computation + // TODO + + // Transfer data from device and check error for each operation + // TODO + + EXPECT_TRUE(true); // TODO + } +};