From 5f1d79b42d9bc32b28e781dfdbee4371af2bfc32 Mon Sep 17 00:00:00 2001 From: Damien Lejeune Date: Wed, 21 Jan 2026 11:04:40 -0500 Subject: [PATCH] Add basic SinkhornKnoppKernelDummyNonStochastic implementation --- .../sinkhorn_knopp/sinkhorn_knopp_kernel.hpp | 106 +++++++++++++----- 1 file changed, 81 insertions(+), 25 deletions(-) diff --git a/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp b/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp index ae65e8c0a9..876674c2ff 100644 --- a/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp +++ b/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp @@ -6,6 +6,15 @@ namespace ck_tile { +// template +// struct SinkhornKnoppArgs +// { +// YDataType* out; +// const XDataType* p_x; +// const index_t input_m; +// int max_iterations; +// }; + struct SinkhornKnoppArgs { void* out; @@ -62,7 +71,7 @@ struct SinkhornKnoppKernelDummyNonStochastic } template - CK_TILE_DEVICE static auto MakeYBlockTile() + CK_TILE_DEVICE static auto MakeComputeBlockTile() { constexpr auto dstr = make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding( @@ -70,44 +79,91 @@ struct SinkhornKnoppKernelDummyNonStochastic .get_static_tile_distribution_encoding(), sequence<0>{})); - auto tensor = make_static_distributed_tensor(dstr); + auto tensor = make_static_distributed_tensor(dstr); return tensor; } + // template + // CK_TILE_DEVICE static auto MakeYBlockTile() + // { + // constexpr auto dstr = + // make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding( + // XDistributedTensor_::get_tile_distribution() + // .get_static_tile_distribution_encoding(), + // sequence<0>{})); + + // auto tensor = make_static_distributed_tensor(dstr); + + // return tensor; + // } + CK_TILE_DEVICE void operator()([[maybe_unused]] const SinkhornKnoppArgs& args) const { - // using S = Problem::BlockShape; + using S = Problem::BlockShape; + using XDataType = typename Problem::XDataType; + using YDataType = typename Problem::YDataType; - // static_assert(S::Block_M == S::Block_N, "Input must be a square matrix!"); + static_assert(S::Block_M == S::Block_N, "Input must be a square matrix!"); - // const auto x_desc = make_naive_tensor_descriptor(make_tuple(args.input_m, args.input_m), - // make_tuple(args.input_m, 1), - // number<4>{}, // TODO: Hardcoded - // vectorization, we should calculate it! - // number<1>{}); + auto* p_x = static_cast(args.p_x); + auto* p_y = static_cast(args.out); + auto reduce_func = ck_tile::ReduceOp::Add{}; - // auto buffer_view = make_buffer_view( - // args.p_x, desc.get_element_space_size(), number<0>{}); + const XDataType custom_padding_value = type_convert( + reduce_func.GetIdentityValue()); - // const auto x_tensor = - // tensor_view{buffer_view, x_desc}; + const auto x_desc = make_naive_tensor_descriptor(make_tuple(args.input_m, args.input_m), + make_tuple(args.input_m, 1), + number<4>{}, // TODO: Hardcoded + // vectorization, //we should calculate it! + number<1>{}); - // auto x_window = make_tile_window(x_tensor, - // make_tuple(number{}, number{}), - // {0, 0}, - // Policy::template MakeXBlockTileDistribution()); + auto buffer_view = make_buffer_view( + p_x, x_desc.get_element_space_size(), custom_padding_value); - // auto out_buffer_view = make_buffer_view( - // args.out, x_desc.get_element_space_size(), number<0>{}); + const auto x_tensor = + tensor_view{buffer_view, x_desc}; - // const auto y_tensor = - // tensor_view{out_buffer_view, x_desc}; + [[maybe_unused]] auto x_window = + make_tile_window(x_tensor, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeXBlockTileDistribution()); - // auto y_window = make_tile_window(y_tensor, - // make_tuple(number{}, number{}), - // {0, 0}, - // Policy::template MakeXBlockTileDistribution()); + auto out_buffer_view = make_buffer_view( + p_y, x_desc.get_element_space_size(), type_convert(custom_padding_value)); + + auto y_tensor = + tensor_view{out_buffer_view, x_desc}; + + [[maybe_unused]] auto y_window = + make_tile_window(y_tensor, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeXBlockTileDistribution()); + // Dummy copy from input to output + + [[maybe_unused]] auto x_tile = load_tile(x_window); + + // auto y_tile = MakeYBlockTile(); + auto y_tile = make_static_distributed_tensor( + Policy::template MakeXBlockTileDistribution()); + + // Set all output elements to the custom padding value. + // // Simple solution to set the whole tile to a constant // + // set_tile(y_tile, custom_padding_value); + // store_tile(y_window, y_tile); + + constexpr auto y_spans = y_tile.get_distributed_spans(); + sweep_tile_span(y_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(y_spans[number<1>{}], [&](auto idx1) { + constexpr auto distributed_indices = make_tuple(idx0, idx1); + y_tile(distributed_indices) = type_convert(custom_padding_value); + }); + }); + + store_tile(y_window, y_tile); } };