From ff428f34786c41960ff06d4ca0b8531510f316bd Mon Sep 17 00:00:00 2001 From: Matti Eskelinen Date: Mon, 2 Feb 2026 09:45:43 -0500 Subject: [PATCH] WIP --- .../sinkhorn_knopp_default_policy.hpp | 9 -- .../sinkhorn_knopp/sinkhorn_knopp_kernel.hpp | 100 +++++++++++++----- test/ck_tile/sinkhorn/test_sinkhorn_impl.hpp | 3 +- 3 files changed, 74 insertions(+), 38 deletions(-) diff --git a/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_default_policy.hpp b/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_default_policy.hpp index 66c3103f13..7ec366f727 100644 --- a/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_default_policy.hpp +++ b/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_default_policy.hpp @@ -42,15 +42,6 @@ struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy sequence<2, 1>, sequence<3, 3>>{}); } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetSum() - { - using br_problem = BlockReduce2dProblem; - return BlockReduce2d{}; - } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp b/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp index 8b1e361860..ff48892026 100644 --- a/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp +++ b/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp @@ -26,7 +26,7 @@ struct SinkhornKnoppKernelReduce CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const { - // // Creating tensor descriptors, views and windows for inputs and outputs + // Creating tensor descriptors, views and windows for inputs and outputs using S = Problem::BlockShape; using InDataType = typename Problem::InDataType; using ComputeDataType = typename Problem::ComputeDataType; @@ -62,9 +62,9 @@ struct SinkhornKnoppKernelReduce Policy::template MakeInputBlockTileDistribution()); }(); - const OutDataType out_padding_value = acc_op.template GetIdentityValue(); - [[maybe_unused]] auto out_window = [&]() { - auto out_buffer_view = make_buffer_view( + auto out_window = [&]() { + const OutDataType out_padding_value = acc_op.template GetIdentityValue(); + auto out_buffer_view = make_buffer_view( p_out, in_out_desc.get_element_space_size(), out_padding_value); auto out_tensor = tensor_view{ @@ -72,17 +72,21 @@ struct SinkhornKnoppKernelReduce return make_tile_window(out_tensor, make_tuple(number{}, number{}), - {0, 0}, + {0, 0}, Policy::template MakeInputBlockTileDistribution()); }(); auto input_tile = load_tile(input_window); + // Run the first steps iteration of the Sinkhorn-Knopp algorithm // Exponentiate the input to make it strictly positive - auto exp_func = [](ComputeDataType x) -> ComputeDataType { return ck_tile::exp(x); }; + // auto exp_func = [](ComputeDataType x) -> ComputeDataType { return ck_tile::exp(x); }; // auto exp_func = []([[maybe_unused]] ComputeDataType x) { // return static_cast(1.0); // }; + auto exp_func = [](InDataType x) -> ComputeDataType { + return static_cast(x); + }; auto compute_tile = tile_elementwise_in(exp_func, input_tile); @@ -92,13 +96,30 @@ struct SinkhornKnoppKernelReduce // Hot loop for Sinkhorn-Knopp iterations from 1 to iterations // Use BlockReduce2D for row and column sums - auto row_sum = Policy::template GetSum(); + auto row_sum = [&](const auto& c_tile) { + // TODO: Handle case where the input doesn't fit in a single tile + using br_problem = BlockReduce2dProblem; - for(int i = 0; i <= args.iterations; i++) + auto block_reduce2d = Policy::template GetBlockReduce2d(); + auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); + // TODO: Deduce/allow specifying a separate type for the accumulators? + // NOTE: MakeYBlockTile defaults to reducing 2nd dimension + auto acc_tile = block_reduce2d.template MakeYBlockTile(); + set_tile(acc_tile, acc_op.template GetIdentityValue()); + + block_reduce2d(c_tile, acc_tile, acc_op); + block_reduce2d_sync(acc_tile, acc_op); + + return acc_tile; + }; + + for(int i = 0; i < 1; i++) { // 1. Compute row sums (REDUCE) // FIXME: Uses overload that is hardcoded to reduce 2nd dimension, be explicit instead - auto row_acc_tile = row_sum(compute_tile, out_padding_value, acc_op); + auto row_acc_tile = row_sum(compute_tile); // 2. Divide values by row sums (SWEEP) constexpr auto c_spans = compute_tile.get_distributed_spans(); @@ -106,14 +127,53 @@ struct SinkhornKnoppKernelReduce sweep_tile_span(c_spans[number<1>{}], [&](const auto idx1) { constexpr auto c_idx = make_tuple(idx0, idx1); constexpr auto row_acc_idx = make_tuple(idx0); - compute_tile(c_idx) = compute_tile(c_idx) / row_acc_tile(row_acc_idx); + if(threadIdx.x == 0) + { + print(row_acc_idx); + print(":"); + print(row_acc_tile(row_acc_idx)); + print("\n"); + } + compute_tile(c_idx) = compute_tile(c_idx) / row_acc_tile(row_acc_idx); }); }); + if(threadIdx.x == 0) + { + print("compute tile after rows summed and divided\n"); + [&](auto tile) { + constexpr auto spans = tile.get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](const auto idx0) { + sweep_tile_span(spans[number<1>{}], [&](const auto idx1) { + constexpr auto idx = make_tuple(idx0, idx1); + print(idx); + print(tile(idx)); + print("\n"); + }); + }); + }(compute_tile); + } + transpose_tile2d(compute_tile_t, compute_tile); - // // Row sum is column sum for transposed c_tile - auto col_acc_tile = row_sum(compute_tile_t, out_padding_value, acc_op); + if(threadIdx.x == 0) + { + print("compute tile transposed\n"); + [&](auto tile) { + constexpr auto spans = tile.get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](const auto idx0) { + sweep_tile_span(spans[number<1>{}], [&](const auto idx1) { + constexpr auto idx = make_tuple(idx0, idx1); + print(idx); + print(tile(idx)); + print("\n"); + }); + }); + }(compute_tile_t); + } + + // Row sum is column sum for transposed c_tile + auto col_acc_tile = row_sum(compute_tile_t); constexpr auto c_t_spans = compute_tile_t.get_distributed_spans(); sweep_tile_span(c_t_spans[number<0>{}], [&](const auto idx0) { @@ -125,22 +185,6 @@ struct SinkhornKnoppKernelReduce }); transpose_tile2d(compute_tile, compute_tile_t); - // 3. STORE the result of the division (in transposed format) - - // 4. LOAD transposed x - // 5. Compute column sums (REDUCE) - // auto col_acc_tile = row_sum(compute_tile, out_padding_value, acc_op); - - // 6. Divide values by column sums (SWEEP) - // constexpr auto ct_spans = compute_tile.get_distributed_spans(); - // sweep_tile_span(ct_spans[number<0>{}], [&](const auto idx0) { - // sweep_tile_span(ct_spans[number<1>{}], [&](const auto idx1) { - // constexpr auto c_idx = make_tuple(idx0, idx1); - // constexpr auto col_acc_idx = make_tuple(idx1); - // compute_tile(c_idx) = compute_tile(c_idx) / col_acc_tile(acc_idx); - // }); - // }); - // 7. STORE the result of the division (in transposed format) } // Copy the final values to the output diff --git a/test/ck_tile/sinkhorn/test_sinkhorn_impl.hpp b/test/ck_tile/sinkhorn/test_sinkhorn_impl.hpp index cfbc4261d8..bdc2ac536f 100644 --- a/test/ck_tile/sinkhorn/test_sinkhorn_impl.hpp +++ b/test/ck_tile/sinkhorn/test_sinkhorn_impl.hpp @@ -36,7 +36,8 @@ class TestCkTileSinkHorn : public ::testing::Test ck_tile::HostTensor h_x(input_shape, default_stride); ck_tile::HostTensor h_y(input_shape, default_stride); - ck_tile::FillUniformDistribution{-5.f, 5.f}(h_x); + // ck_tile::FillUniformDistribution{-5.f, 5.f}(h_x); + ck_tile::FillMonotonicSeq{}(h_x); auto buffer_size = h_x.get_element_space_size_in_bytes(); ck_tile::DeviceMem d_x_mem(buffer_size);