From 5a0fea7f5a6f618152260a2669dd538c694e180b Mon Sep 17 00:00:00 2001 From: Matti Eskelinen Date: Mon, 19 Jan 2026 08:55:00 -0500 Subject: [PATCH] WIP --- .../sinkhorn_knopp_default_policy.hpp | 16 +++++++++ .../pipeline/sinkhorn_knopp_problem.hpp | 4 +-- .../sinkhorn_knopp/sinkhorn_knopp_kernel.hpp | 35 +++++++------------ 3 files changed, 30 insertions(+), 25 deletions(-) diff --git a/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_default_policy.hpp b/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_default_policy.hpp index b0284e2645..acf837ac9b 100644 --- a/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_default_policy.hpp +++ b/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_default_policy.hpp @@ -24,6 +24,22 @@ struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy sequence<2, 2, 1, 1>, sequence<0, 2, 0, 3>>{}); // Repeat_M, ThreadTile_M, Repeat_N, ThreadTile_N } + + template + CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution() + { + using S = typename Problem::BlockShape; + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple< + sequence>, + sequence, + tuple, sequence<1, 2>>, + tuple, sequence<1, 1>>, + sequence<1, 1, 2, 2>, + sequence<0, 3, 0, 3>>{}); + } }; } // namespace ck_tile \ No newline at end of file diff --git a/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_problem.hpp b/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_problem.hpp index 6eb448737b..a8cbb9ef1f 100644 --- a/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_problem.hpp +++ b/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_problem.hpp @@ -5,8 +5,8 @@ namespace ck_tile { template struct SinkHornKnoppShape { - static constexpr index_t WarpPerBlock_M = WarpPerBlock_M; - static constexpr index_t WarpPerBlock_N = WarpPerBlock_N; + static constexpr index_t Block_M = WarpPerBlock_M; + static constexpr index_t Block_N = WarpPerBlock_N; static constexpr index_t ThreadPerWarp_M = ThreadPerWarp_M; static constexpr index_t ThreadPerWarp_N = ThreadPerWarp_N; static constexpr index_t ThreadTile_M = ThreadTile_M; diff --git a/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp b/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp index 850a089a60..209d11287c 100644 --- a/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp +++ b/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp @@ -49,33 +49,22 @@ struct SinkhornKnoppKernelDummyNonStochastic { template CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const { - // Creating tensor descriptors, views and windows for inputs and outputs + using S = Problem::BlockShape; + + auto desc = make_naive_tensor_descriptor(); - // Create the reduce ops - // * Reduce Op ADD for row and column sums - // * Elementwise Op EXP for exponentiation + auto buffer_view = make_buffer_view( + args.p_x, desc.get_element_space_size(), number<0>{}); - using ExponentiationOp = ElementwiseOp; - using AddOp = ElementwiseOp; - using DivideOp = ElementwiseOp; + const auto x_tensor = + tensor_view{buffer_view, desc}; - using ReduceOp = ReduceOp; - // Run the first steps iteration of the Sinkhorn-Knopp algorithm - // Exponentiate the matrix x - auto x = load_tile(...); + auto x_window = + make_tile_window(x_tensor, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeXBlockTileDistribution()); - // Hot loop for Sinkhorn-Knopp iterations from 1 to max_iterations - // Use BlockReduce2D for row and column sums - for (int i = 0; i <= args.max_iterations; i++) { - // 0. LOAD x - // 1. Compute row sums (REDUCE) - // 2. Divide values by row sums (SWEEP) - // 3. STORE the result of the division (in transposed format) - // 4. LOAD transposed x - // 5. Compute column sums (REDUCE) - // 6. Divide values by column sums (SWEEP) - // 7. STORE the result of the division (in transposed format) - } } };