From 04f8f3ed5d00a8e6391e3eb0b3964bc267c1499f Mon Sep 17 00:00:00 2001
From: Matti Eskelinen
Date: Mon, 26 Jan 2026 09:14:16 +0000
Subject: [PATCH] Add CPU reference computation
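
Compute a CPU reference for the Sinkhorn-Knopp kernel and compare the
device output against it instead of unconditionally passing the test.
The reference exponentiates the input to obtain a positive matrix and
then alternates row and column normalization for the requested number
of iterations:

    c(i, j) <- c(i, j) / sum_j c(i, j)   (row pass)
    c(i, j) <- c(i, j) / sum_i c(i, j)   (column pass)

One such pass is implemented in sinkhorn_knopp_ref_single_iter;
iterating both passes drives the matrix towards a doubly stochastic
one.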
---
 test/ck_tile/sinkhorn/test_sinkhorn_impl.hpp | 155 +++++++++++++------
 1 file changed, 109 insertions(+), 46 deletions(-)

diff --git a/test/ck_tile/sinkhorn/test_sinkhorn_impl.hpp b/test/ck_tile/sinkhorn/test_sinkhorn_impl.hpp
index 3c5ccfc5f7..0be972319e 100644
--- a/test/ck_tile/sinkhorn/test_sinkhorn_impl.hpp
+++ b/test/ck_tile/sinkhorn/test_sinkhorn_impl.hpp
@@ -14,30 +14,92 @@
 #include "ck_tile/host/kernel_launch.hpp"
 
 template <typename Tuple>
-class TestCkTileSinkHorn: public ::testing::Test
+class TestCkTileSinkHorn : public ::testing::Test
 {
     protected:
-    using XDataType = std::tuple_element_t<0, Tuple>;
-    using ComputeDataType = std::tuple_element_t<1, Tuple>;
-    using YDataType = std::tuple_element_t<2, Tuple>;
-    using BlockWarps_ = std::tuple_element_t<3, Tuple>;
-    using BlockTile_ = std::tuple_element_t<4, Tuple>;
-    using WarpTile_ = std::tuple_element_t<5, Tuple>;
-    using ThreadTile_ = std::tuple_element_t<6, Tuple>;
+    using XDataType       = std::tuple_element_t<0, Tuple>;
+    using ComputeDataType = std::tuple_element_t<1, Tuple>;
+    using YDataType       = std::tuple_element_t<2, Tuple>;
+    using BlockWarps_     = std::tuple_element_t<3, Tuple>;
+    using BlockTile_      = std::tuple_element_t<4, Tuple>;
+    using WarpTile_       = std::tuple_element_t<5, Tuple>;
+    using ThreadTile_     = std::tuple_element_t<6, Tuple>;
 
     using TestSinkhornShape =
-        ck_tile::SinkhornKnoppShape<
-            BlockWarps_,
-            BlockTile_,
-            WarpTile_,
-            ThreadTile_
-        >;
+        ck_tile::SinkhornKnoppShape<BlockWarps_, BlockTile_, WarpTile_, ThreadTile_>;
+
+    // One Sinkhorn-Knopp iteration: normalize rows, then columns; acc_n is scratch for the sums
+    void sinkhorn_knopp_ref_single_iter(ck_tile::HostTensor<ComputeDataType>& c_n_n,
+                                        ck_tile::HostTensor<ComputeDataType>& acc_n)
+    {
+        const ck_tile::index_t input_n = acc_n.get_length(0);
+
+        // Sum and scale rowwise
+        for(ck_tile::index_t i = 0; i < input_n; ++i)
+        {
+            acc_n(i) = 0;
+            for(ck_tile::index_t j = 0; j < input_n; ++j)
+            {
+                acc_n(i) += c_n_n(i, j);
+            }
+            for(ck_tile::index_t j = 0; j < input_n; ++j)
+            {
+                c_n_n(i, j) /= acc_n(i);
+            }
+        }
+
+        // Repeat columnwise
+        for(ck_tile::index_t i = 0; i < input_n; ++i)
+        {
+            acc_n(i) = 0;
+            for(ck_tile::index_t j = 0; j < input_n; ++j)
+            {
+                acc_n(i) += c_n_n(j, i);
+            }
+            for(ck_tile::index_t j = 0; j < input_n; ++j)
+            {
+                c_n_n(j, i) /= acc_n(i);
+            }
+        }
+    }
+
+    void sinkhorn_knopp_ref(const ck_tile::HostTensor<XDataType>& x_n_n,
+                            ck_tile::HostTensor<YDataType>& y_n_n,
+                            const int n_iter)
+    {
+        const ck_tile::index_t input_n = x_n_n.get_length(0);
+        ck_tile::HostTensor<ComputeDataType> c_n_n({input_n, input_n}, {1, input_n});
+        ck_tile::HostTensor<ComputeDataType> acc_n({input_n}, {1});
+
+        // First apply exp to make the input positive
+        for(ck_tile::index_t i = 0; i < input_n; ++i)
+        {
+            for(ck_tile::index_t j = 0; j < input_n; ++j)
+            {
+                c_n_n(i, j) = std::exp(ck_tile::type_convert<ComputeDataType>(x_n_n(i, j)));
+            }
+        }
+
+        // Iterate normalization on rows and columns
+        for(int it = 0; it < n_iter; ++it)
+        {
+            sinkhorn_knopp_ref_single_iter(c_n_n, acc_n);
+        }
+
+        // Copy and cast to output type
+        for(ck_tile::index_t i = 0; i < input_n; ++i)
+        {
+            for(ck_tile::index_t j = 0; j < input_n; ++j)
+            {
+                y_n_n(i, j) = ck_tile::type_convert<YDataType>(c_n_n(i, j));
+            }
+        }
+    }
 
     void RunGenericTest(const std::vector<ck_tile::index_t>& input_shape, const int max_iterations)
     {
-        auto input_m = input_shape[0];
-
-        auto default_stride = {input_m, 1};
+        auto input_n        = input_shape[0];
+        auto default_stride = {input_n, 1};
 
         ck_tile::HostTensor<XDataType> h_x(input_shape, default_stride);
         ck_tile::HostTensor<YDataType> h_y(input_shape, default_stride);
@@ -48,24 +110,19 @@ class TestCkTileSinkHorn: public ::testing::Test
 
         ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes());
         ck_tile::DeviceMem d_y_mem(buffer_size);
 
-        ck_tile::SinkhornKnoppArgs args{
-            static_cast<YDataType*>(d_y_mem.GetDeviceBuffer()),
-            static_cast<XDataType*>(d_x_mem.GetDeviceBuffer()),
-            input_m,
-            max_iterations
-        };
+        ck_tile::SinkhornKnoppArgs args{static_cast<YDataType*>(d_y_mem.GetDeviceBuffer()),
+                                        static_cast<XDataType*>(d_x_mem.GetDeviceBuffer()),
+                                        input_n,
+                                        max_iterations};
 
         d_x_mem.ToDevice(h_x.data());
        d_y_mem.ToDevice(h_y.data());
 
-        using Problem = ck_tile::SinkhornKnoppProblem<XDataType,
-                                                      ComputeDataType,
-                                                      YDataType,
-                                                      TestSinkhornShape>;
-        using Kernel = ck_tile::SinkhornKnoppKernelDummyNonStochastic<
-            Problem,
-            ck_tile::SinkhornKnoppDefaultPolicy>;
+        using Problem =
+            ck_tile::SinkhornKnoppProblem<XDataType, ComputeDataType, YDataType, TestSinkhornShape>;
+        using Kernel =
+            ck_tile::SinkhornKnoppKernelDummyNonStochastic<Problem, ck_tile::SinkhornKnoppDefaultPolicy>;
 
         // Launch configuration
         const ck_tile::index_t kBlockSize = Kernel::BlockSize();
@@ -73,26 +130,32 @@ class TestCkTileSinkHorn: public ::testing::Test
         ck_tile::index_t kGridSize = 1; // TODO
 
-        //TODO
-        // if(!Kernel::IsSupportedArgument())
-        // {
-        // throw std::runtime_error("Wrong! Arguments not supported!\n");
-        // }
+        // TODO
+        // if(!Kernel::IsSupportedArgument())
+        // {
+        //     throw std::runtime_error("Wrong! Arguments not supported!\n");
+        // }
 
         ck_tile::launch_kernel(
             ck_tile::stream_config{nullptr, false, 0},
-            ck_tile::make_kernel(Kernel{},
-                                 kGridSize,
-                                 kBlockSize,
-                                 0,
-                                 args));
+            ck_tile::make_kernel(Kernel{}, kGridSize, kBlockSize, 0, args));
 
         // Reference computation
-        // TODO
+        ck_tile::HostTensor<YDataType> h_y_ref(input_shape, default_stride);
+        sinkhorn_knopp_ref(h_x, h_y_ref, max_iterations);
 
-        // Transfer data from device and check error for each operation
-        // TODO
+        // TODO: Test whether the output is actually doubly stochastic
 
-        EXPECT_TRUE(true); // TODO
-    }
+        // TODO: Refine tolerances
+        const float rtol = 1e-7;
+        const float atol = 1e-8;
+
+        // Transfer data from device and check that it matches the reference
+        d_y_mem.FromDevice(h_y.data());
+        bool result = true;
+        result &= ck_tile::check_err(
+            h_y, h_y_ref, "Error: Sinkhorn-Knopp doesn't match CPU reference!", rtol, atol);
+
+        EXPECT_TRUE(result);
+    }
 };
 
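
Note (reviewer aid, not part of the patch): the remaining TODO about
verifying double stochasticity can be checked directly on the host.
A minimal sketch, assuming only the ck_tile::HostTensor accessors
already used in this file; the helper name and default tolerance are
hypothetical:

    #include <cmath>

    // Hypothetical helper: returns true when every row and every
    // column of the n x n tensor sums to 1 within `tol`.
    template <typename DataType>
    bool is_doubly_stochastic(const ck_tile::HostTensor<DataType>& y_n_n,
                              const double tol = 1e-5)
    {
        const ck_tile::index_t n = y_n_n.get_length(0);
        for(ck_tile::index_t i = 0; i < n; ++i)
        {
            double row_sum = 0.0;
            double col_sum = 0.0;
            for(ck_tile::index_t j = 0; j < n; ++j)
            {
                row_sum += ck_tile::type_convert<double>(y_n_n(i, j));
                col_sum += ck_tile::type_convert<double>(y_n_n(j, i));
            }
            if(std::abs(row_sum - 1.0) > tol || std::abs(col_sum - 1.0) > tol)
            {
                return false;
            }
        }
        return true;
    }

    // Possible use inside RunGenericTest, after d_y_mem.FromDevice(...):
    //     EXPECT_TRUE(is_doubly_stochastic(h_y));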