From 88a74ec31754f1c305be736e4ac706d547611dd4 Mon Sep 17 00:00:00 2001 From: Matti Eskelinen Date: Mon, 2 Feb 2026 05:05:06 -0500 Subject: [PATCH] WIP --- .../sinkhorn_knopp_default_policy.hpp | 11 ++++---- .../sinkhorn_knopp/sinkhorn_knopp_kernel.hpp | 25 ++++++++++++++++--- test/ck_tile/sinkhorn/test_sinkhorn.cpp | 23 +++++++---------- 3 files changed, 36 insertions(+), 23 deletions(-) diff --git a/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_default_policy.hpp b/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_default_policy.hpp index e0e32a304f..66c3103f13 100644 --- a/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_default_policy.hpp +++ b/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_default_policy.hpp @@ -20,10 +20,10 @@ struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy sequence, sequence>, tuple, sequence<2, 1>>, - tuple, sequence<2, 2>>, + tuple, sequence<2, 1>>, // WarpPerBlock_M, WarpPerBlock_N, ThreadPerWarp_M, ThreadPerWarp_N - sequence<2, 2, 1, 1>, - sequence<0, 3, 0, 3>>{}); // Repeat_M, ThreadTile_M, Repeat_N, ThreadTile_N + sequence<2, 1>, + sequence<3, 3>>{}); // Repeat_M, ThreadTile_M, Repeat_N, ThreadTile_N } template @@ -38,8 +38,9 @@ struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy sequence>, tuple, sequence<1, 2>>, tuple, sequence<1, 2>>, - sequence<1, 1, 2, 2>, - sequence<0, 3, 0, 3>>{}); + // WarpPerBlock_M, WarpPerBlock_N, ThreadPerWarp_M, ThreadPerWarp_N + sequence<2, 1>, + sequence<3, 3>>{}); } template diff --git a/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp b/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp index d5ac9d63fa..8b1e361860 100644 --- a/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp +++ b/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp @@ -27,7 +27,6 @@ struct SinkhornKnoppKernelReduce CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const { // // Creating tensor descriptors, views and windows for inputs and outputs - using S = Problem::BlockShape; using InDataType = typename Problem::InDataType; using ComputeDataType = typename Problem::ComputeDataType; @@ -43,7 +42,7 @@ struct SinkhornKnoppKernelReduce const auto in_out_desc = make_naive_tensor_descriptor(make_tuple(args.input_m, args.input_m), make_tuple(args.input_m, 1), - number<16>{}, // TODO: Hardcoded + number<4>{}, // TODO: Hardcoded // vectorization, //we should calculate it! number<1>{}); @@ -80,7 +79,10 @@ struct SinkhornKnoppKernelReduce auto input_tile = load_tile(input_window); // Run the first steps iteration of the Sinkhorn-Knopp algorithm // Exponentiate the input to make it strictly positive - auto exp_func = [](ComputeDataType x) { return ck_tile::exp(x); }; + auto exp_func = [](ComputeDataType x) -> ComputeDataType { return ck_tile::exp(x); }; + // auto exp_func = []([[maybe_unused]] ComputeDataType x) { + // return static_cast(1.0); + // }; auto compute_tile = tile_elementwise_in(exp_func, input_tile); @@ -91,6 +93,7 @@ struct SinkhornKnoppKernelReduce // Hot loop for Sinkhorn-Knopp iterations from 1 to iterations // Use BlockReduce2D for row and column sums auto row_sum = Policy::template GetSum(); + for(int i = 0; i <= args.iterations; i++) { // 1. Compute row sums (REDUCE) @@ -107,8 +110,22 @@ struct SinkhornKnoppKernelReduce }); }); + transpose_tile2d(compute_tile_t, compute_tile); + + // // Row sum is column sum for transposed c_tile + auto col_acc_tile = row_sum(compute_tile_t, out_padding_value, acc_op); + + constexpr auto c_t_spans = compute_tile_t.get_distributed_spans(); + sweep_tile_span(c_t_spans[number<0>{}], [&](const auto idx0) { + sweep_tile_span(c_t_spans[number<1>{}], [&](const auto idx1) { + constexpr auto c_t_idx = make_tuple(idx0, idx1); + constexpr auto col_acc_idx = make_tuple(idx0); + compute_tile_t(c_t_idx) = compute_tile_t(c_t_idx) / col_acc_tile(col_acc_idx); + }); + }); + + transpose_tile2d(compute_tile, compute_tile_t); // 3. STORE the result of the division (in transposed format) - // store_tile(compute_tile_t, compute_tile); // 4. LOAD transposed x // 5. Compute column sums (REDUCE) diff --git a/test/ck_tile/sinkhorn/test_sinkhorn.cpp b/test/ck_tile/sinkhorn/test_sinkhorn.cpp index a525050e3e..498fcdc9d3 100644 --- a/test/ck_tile/sinkhorn/test_sinkhorn.cpp +++ b/test/ck_tile/sinkhorn/test_sinkhorn.cpp @@ -19,24 +19,19 @@ using Shape1_BlockWarps = ck_tile::sequence<4, 1>; using Shape1_BlockTile = ck_tile::sequence<128, 128>; using Shape1_WarpTile = ck_tile::sequence<32, 128>; -using Shape1_ThreadTile = ck_tile::sequence<8, 8>; +using Shape1_ThreadTile = ck_tile::sequence<1, 4>; // Test configurations for different data types and input size -using TestConfig_F16 = std::tuple< - ck_tile::half_t, // XDataType - float, // ComputeDataType - float, // YDataType - Shape1_BlockWarps, - Shape1_BlockTile, - Shape1_WarpTile, - Shape1_ThreadTile>; - +using TestConfig_F16 = std::tuple; using TestTypes = ::testing::Types; TYPED_TEST_SUITE(TestCkTileSinkHorn, TestTypes); -TYPED_TEST(TestCkTileSinkHorn, Test_4x4) -{ - this->RunGenericTest({4, 4}, 10); -} +TYPED_TEST(TestCkTileSinkHorn, Test_4x4) { this->RunGenericTest({4, 4}, 20); }