This commit is contained in:
Matti Eskelinen
2026-01-19 08:55:00 -05:00
parent 1c35e916f0
commit 5a0fea7f5a
3 changed files with 30 additions and 25 deletions

View File

@@ -24,6 +24,22 @@ struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy
sequence<2, 2, 1, 1>,
sequence<0, 2, 0, 3>>{}); // Repeat_M, ThreadTile_M, Repeat_N, ThreadTile_N
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::ThreadTile_M>>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::ThreadTile_N>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<1, 1>, sequence<1, 1>>,
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});
}
};
} // namespace ck_tile

View File

@@ -5,8 +5,8 @@ namespace ck_tile {
template <WarpPerBlock_M, WarpPerBlock_N, ThreadPerWarp_M, ThreadPerWarp_N, ThreadTile_M, ThreadTile_N, Repeat_M, Repeat_N>
struct SinkHornKnoppShape
{
static constexpr index_t WarpPerBlock_M = WarpPerBlock_M;
static constexpr index_t WarpPerBlock_N = WarpPerBlock_N;
static constexpr index_t Block_M = WarpPerBlock_M;
static constexpr index_t Block_N = WarpPerBlock_N;
static constexpr index_t ThreadPerWarp_M = ThreadPerWarp_M;
static constexpr index_t ThreadPerWarp_N = ThreadPerWarp_N;
static constexpr index_t ThreadTile_M = ThreadTile_M;

View File

@@ -49,33 +49,22 @@ struct SinkhornKnoppKernelDummyNonStochastic
{
template <typename Problem>
CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const {
// Creating tensor descriptors, views and windows for inputs and outputs
using S = Problem::BlockShape;
auto desc = make_naive_tensor_descriptor();
// Create the reduce ops
// * Reduce Op ADD for row and column sums
// * Elementwise Op EXP for exponentiation
auto buffer_view = make_buffer_view<address_space_enum::global>(
args.p_x, desc.get_element_space_size(), number<0>{});
using ExponentiationOp = ElementwiseOp<ExponentiationOperation>;
using AddOp = ElementwiseOp<AddOperation>;
using DivideOp = ElementwiseOp<DivideOperation>;
const auto x_tensor =
tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
using ReduceOp = ReduceOp<AddOp, AddOp>;
// Run the first steps iteration of the Sinkhorn-Knopp algorithm
// Exponentiate the matrix x
auto x = load_tile(...);
auto x_window =
make_tile_window(x_tensor,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{0, 0},
Policy::template MakeXBlockTileDistribution<Problem>());
// Hot loop for Sinkhorn-Knopp iterations from 1 to max_iterations
// Use BlockReduce2D for row and column sums
for (int i = 0; i <= args.max_iterations; i++) {
// 0. LOAD x
// 1. Compute row sums (REDUCE)
// 2. Divide values by row sums (SWEEP)
// 3. STORE the result of the division (in transposed format)
// 4. LOAD transposed x
// 5. Compute column sums (REDUCE)
// 6. Divide values by column sums (SWEEP)
// 7. STORE the result of the division (in transposed format)
}
}
};