From f669f39eafc3253aad2172a147bc0ac238746eb5 Mon Sep 17 00:00:00 2001 From: Matti Eskelinen Date: Tue, 27 Jan 2026 07:23:12 -0500 Subject: [PATCH] WIP --- .../pipeline/sinkhorn_knopp_problem.hpp | 10 +- .../sinkhorn_knopp/sinkhorn_knopp_kernel.hpp | 125 ++++++++++++------ 2 files changed, 86 insertions(+), 49 deletions(-) diff --git a/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_problem.hpp b/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_problem.hpp index 6d888c9fd2..4c180ca35d 100644 --- a/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_problem.hpp +++ b/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_problem.hpp @@ -58,15 +58,15 @@ struct SinkhornKnoppShape static constexpr index_t BlockSize = ck_tile::get_warp_size(); }; -template + typename _ComputeDataType = _OutDataType> struct SinkhornKnoppProblem { - using XDataType = remove_cvref_t<_XDataType>; + using InDataType = remove_cvref_t<_InDataType>; using ComputeDataType = remove_cvref_t<_ComputeDataType>; - using YDataType = remove_cvref_t<_YDataType>; + using OutDataType = remove_cvref_t<_OutDataType>; using BlockShape = remove_cvref_t<_BlockShape>; }; diff --git a/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp b/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp index 876674c2ff..65416e3fe3 100644 --- a/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp +++ b/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp @@ -6,42 +6,76 @@ namespace ck_tile { -// template -// struct SinkhornKnoppArgs -// { -// YDataType* out; -// const XDataType* p_x; -// const index_t input_m; -// int max_iterations; -// }; - struct SinkhornKnoppArgs { - void* out; - const void* p_x; + void* p_out; + const void* p_in; const index_t input_m; int max_iterations; }; +template struct SinkhornKnoppKernelReduce { - template CK_TILE_DEVICE void operator()([[maybe_unused]] const SinkhornKnoppArgs& args) const { // // Creating tensor descriptors, views and windows for inputs and outputs - // // Create the reduce ops - // // * Reduce Op ADD for row and column sums - // // * Elementwise Op EXP for exponentiation + using S = Problem::BlockShape; + using InDataType = typename Problem::OutDataType; + using OutDataType = typename Problem::OutDataType; - // using ExponentiationOp = ElementwiseOp; - // using AddOp = ElementwiseOp; - // using DivideOp = ElementwiseOp; + static_assert(S::Block_M == S::Block_N, "Input must be a square matrix!"); + + auto* p_in = static_cast(args.p_in); + auto* p_out = static_cast(args.p_out); + + [[maybe_unused]] auto exp_op = ck_tile::element_wise::Exp{}; + [[maybe_unused]] auto acc_op = ck_tile::ReduceOp::Add{}; + [[maybe_unused]] auto div_op = ck_tile::element_wise::UnaryDivide{}; + + // We require exp(input) > 0, and exp(padding) == 0 + const InDataType x_padding_value = -ck_tile::numeric::infinity(); + + const auto in_desc = + make_naive_tensor_descriptor(make_tuple(args.input_m, args.input_m), + make_tuple(args.input_m, 1), + number<4>{}, // TODO: Hardcoded + // vectorization, //we should calculate it! + number<1>{}); + + auto buffer_view = make_buffer_view( + p_in, in_desc.get_element_space_size(), x_padding_value); + + const auto x_tensor = + tensor_view{buffer_view, in_desc}; + + [[maybe_unused]] auto x_window = + make_tile_window(x_tensor, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeXBlockTileDistribution()); + + const OutDataType y_padding_value = acc_op.template GetIdentityValue(); + auto out_buffer_view = make_buffer_view( + p_out, in_desc.get_element_space_size(), y_padding_value); + + auto y_tensor = + tensor_view{out_buffer_view, in_desc}; + + [[maybe_unused]] auto y_window = + make_tile_window(y_tensor, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeXBlockTileDistribution()); + + [[maybe_unused]] auto c_window = + make_null_tile_window(make_tuple(number{}, number{})); + [[maybe_unused]] auto x_tile = load_tile(x_window); - // using ReduceOp = ReduceOp; // // Run the first steps iteration of the Sinkhorn-Knopp algorithm // // Exponentiate the matrix x - // auto x = load_tile(...); + // elementwise() // // Hot loop for Sinkhorn-Knopp iterations from 1 to max_iterations // // Use BlockReduce2D for row and column sums @@ -79,7 +113,7 @@ struct SinkhornKnoppKernelDummyNonStochastic .get_static_tile_distribution_encoding(), sequence<0>{})); - auto tensor = make_static_distributed_tensor(dstr); + auto tensor = make_static_distributed_tensor(dstr); return tensor; } @@ -93,49 +127,52 @@ struct SinkhornKnoppKernelDummyNonStochastic // .get_static_tile_distribution_encoding(), // sequence<0>{})); - // auto tensor = make_static_distributed_tensor(dstr); + // auto tensor = make_static_distributed_tensor(dstr); // return tensor; // } CK_TILE_DEVICE void operator()([[maybe_unused]] const SinkhornKnoppArgs& args) const { - using S = Problem::BlockShape; - using XDataType = typename Problem::XDataType; - using YDataType = typename Problem::YDataType; + using S = Problem::BlockShape; + using InDataType = typename Problem::InDataType; + using OutDataType = typename Problem::OutDataType; static_assert(S::Block_M == S::Block_N, "Input must be a square matrix!"); - auto* p_x = static_cast(args.p_x); - auto* p_y = static_cast(args.out); + auto* p_in = static_cast(args.p_in); + auto* p_out = static_cast(args.p_out); auto reduce_func = ck_tile::ReduceOp::Add{}; - const XDataType custom_padding_value = type_convert( + const InDataType custom_padding_value = type_convert( reduce_func.GetIdentityValue()); - const auto x_desc = make_naive_tensor_descriptor(make_tuple(args.input_m, args.input_m), - make_tuple(args.input_m, 1), - number<4>{}, // TODO: Hardcoded - // vectorization, //we should calculate it! - number<1>{}); + const auto in_desc = + make_naive_tensor_descriptor(make_tuple(args.input_m, args.input_m), + make_tuple(args.input_m, 1), + number<4>{}, // TODO: Hardcoded + // vectorization, //we should calculate it! + number<1>{}); auto buffer_view = make_buffer_view( - p_x, x_desc.get_element_space_size(), custom_padding_value); + p_in, in_desc.get_element_space_size(), custom_padding_value); - const auto x_tensor = - tensor_view{buffer_view, x_desc}; + const auto input_tensor = + tensor_view{buffer_view, in_desc}; - [[maybe_unused]] auto x_window = - make_tile_window(x_tensor, + [[maybe_unused]] auto input_window = + make_tile_window(input_tensor, make_tuple(number{}, number{}), {0, 0}, Policy::template MakeXBlockTileDistribution()); auto out_buffer_view = make_buffer_view( - p_y, x_desc.get_element_space_size(), type_convert(custom_padding_value)); + p_out, + in_desc.get_element_space_size(), + type_convert(custom_padding_value)); auto y_tensor = - tensor_view{out_buffer_view, x_desc}; + tensor_view{out_buffer_view, in_desc}; [[maybe_unused]] auto y_window = make_tile_window(y_tensor, @@ -144,10 +181,10 @@ struct SinkhornKnoppKernelDummyNonStochastic Policy::template MakeXBlockTileDistribution()); // Dummy copy from input to output - [[maybe_unused]] auto x_tile = load_tile(x_window); + [[maybe_unused]] auto input_tile = load_tile(input_window); - // auto y_tile = MakeYBlockTile(); - auto y_tile = make_static_distributed_tensor( + // auto y_tile = MakeYBlockTile(); + auto y_tile = make_static_distributed_tensor( Policy::template MakeXBlockTileDistribution()); // Set all output elements to the custom padding value. @@ -159,7 +196,7 @@ struct SinkhornKnoppKernelDummyNonStochastic sweep_tile_span(y_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(y_spans[number<1>{}], [&](auto idx1) { constexpr auto distributed_indices = make_tuple(idx0, idx1); - y_tile(distributed_indices) = type_convert(custom_padding_value); + y_tile(distributed_indices) = type_convert(custom_padding_value); }); });