From c4d9d16deafe9cb2046e5eb201a0097b33ed9e46 Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Fri, 16 Jan 2026 10:14:40 -0500 Subject: [PATCH] sketch algorithm --- .../sinkhorn_knopp/sinkhorn_knopp_kernel.hpp | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp b/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp index 4a6bb04eb7..093b5dd8c3 100644 --- a/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp +++ b/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp @@ -6,6 +6,7 @@ namespace ck_tile { struct SinkhornKnoppArgs { + void* out; const void* p_x; const index_t n; int max_iterations; @@ -21,11 +22,28 @@ struct SinkhornKnoppKernel // * Reduce Op ADD for row and column sums // * Elementwise Op EXP for exponentiation - // Run the first steps iteration of the Sinkhorn-Knopp algorithm - // Using the exponentiation as the elementwise operation + using ExponentiationOp = ElementwiseOp; + using AddOp = ElementwiseOp; + using DivideOp = ElementwiseOp; - // Hot loop for Sinkhorn-Knopp iterations from max_iterations=1 - // + using ReduceOp = ReduceOp; + // Run the first steps iteration of the Sinkhorn-Knopp algorithm + // Exponentiate the matrix x + auto x = load_tile(...); + + // Hot loop for Sinkhorn-Knopp iterations from 1 to max_iterations + // Use BlockReduce2D for row and column sums + for (int i = 0; i <= args.max_iterations; i++) { + // 0. LOAD x + // 1. Compute row sums (REDUCE) + // 2. Divide values by row sums (SWEEP) + // 3. STORE the result of the division (in transposed format) + // 4. LOAD transposed x + // 5. Compute column sums (REDUCE) + // 6. Divide values by column sums (SWEEP) + // 7. STORE the result of the division (in transposed format) + } } }; + } // namespace ck_tile \ No newline at end of file