Add the structure for testing the Sinkhorn-Knopp kernel

This commit is contained in:
Damien Lejeune
2026-01-20 11:42:10 -05:00
parent 9e79b07298
commit e3efa236ec
7 changed files with 244 additions and 60 deletions

View File

@@ -36,8 +36,8 @@ struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy
tile_distribution_encoding<
sequence<>,
tuple<
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::ThreadTile_M>>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::ThreadTile_N>,
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::ThreadTile_M>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::ThreadTile_N>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<1, 1>, sequence<1, 1>>,
sequence<1, 1, 2, 2>,

View File

@@ -4,24 +4,60 @@
namespace ck_tile {
template <WarpPerBlock_M,
WarpPerBlock_N,
ThreadPerWarp_M,
ThreadPerWarp_N,
ThreadTile_M,
ThreadTile_N,
Repeat_M,
Repeat_N>
// template <WarpPerBlock_M,
// WarpPerBlock_N,
// ThreadPerWarp_M,
// ThreadPerWarp_N,
// ThreadTile_M,
// ThreadTile_N,
// Repeat_M,
// Repeat_N>
// struct SinkHornKnoppShape
// {
// static constexpr index_t Block_M = WarpPerBlock_M;
// static constexpr index_t Block_N = WarpPerBlock_N;
// static constexpr index_t ThreadPerWarp_M = ThreadPerWarp_M;
// static constexpr index_t ThreadPerWarp_N = ThreadPerWarp_N;
// static constexpr index_t ThreadTile_M = ThreadTile_M;
// static constexpr index_t ThreadTile_N = ThreadTile_N;
// static constexpr index_t Repeat_M = Repeat_M;
// static constexpr index_t Repeat_N = Repeat_N;
// };
// NOTE(review): this span is a commit-diff rendering — it interleaves the
// pre-change struct body (deleted lines) with the post-change body (added
// lines). Taken literally as C++ it would not compile (duplicate members).
//
// Describes the 2D tile shape used by the Sinkhorn-Knopp kernel, from block
// tile down to per-thread tile: block -> warps-per-block -> threads-per-warp
// -> thread tile, along the M and N axes.
template <typename BlockWarps, // num warps along seq<M, N>
typename BlockTile, // block size, seq<M, N>
typename WarpTile, // warp size, seq<M, N>
typename ThreadTile> // contiguous pixels(vector size) along seq<M, N>
struct SinkHornKnoppShape
{
// --- deleted by this commit: old parameterization. The self-referential
// initializers (e.g. `ThreadPerWarp_M = ThreadPerWarp_M`) could never have
// compiled; the seq<>-based parameters below replace them. ---
static constexpr index_t Block_M = WarpPerBlock_M;
static constexpr index_t Block_N = WarpPerBlock_N;
static constexpr index_t ThreadPerWarp_M = ThreadPerWarp_M;
static constexpr index_t ThreadPerWarp_N = ThreadPerWarp_N;
static constexpr index_t ThreadTile_M = ThreadTile_M;
static constexpr index_t ThreadTile_N = ThreadTile_N;
static constexpr index_t Repeat_M = Repeat_M;
static constexpr index_t Repeat_N = Repeat_N;
// --- added by this commit: derive every extent from the four seq<> params ---
static constexpr index_t Block_M = BlockTile::at(number<0>{});
static constexpr index_t Block_N = BlockTile::at(number<1>{});
static constexpr index_t Warp_M = WarpTile::at(number<0>{});
static constexpr index_t Warp_N = WarpTile::at(number<1>{});
static constexpr index_t ThreadTile_M = ThreadTile::at(number<0>{});
static constexpr index_t ThreadTile_N = ThreadTile::at(number<1>{});
static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});
// Number of thread-tiles each lane must cover so that warp_size lanes tile
// the whole Warp_M x Warp_N region:
//   (Warp_M * Warp_N) / (ThreadTile_M * ThreadTile_N * warp_size).
// Assumes all divisions are exact — TODO confirm inputs are validated.
static constexpr index_t RepeatInWarp =
Warp_M * Warp_N / ThreadTile_M / ThreadTile_N / ck_tile::get_warp_size();
// The in-warp repeat is assigned entirely to whichever dimension has the
// larger thread-slot count; the strict '>' sends it to N on ties.
static constexpr index_t RepeatInWarp_M =
(Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? RepeatInWarp : 1;
static constexpr index_t RepeatInWarp_N =
(Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? 1 : RepeatInWarp;
// Lanes per warp along each axis; by construction
// ThreadPerWarp_M * ThreadPerWarp_N == get_warp_size() (given exact division).
static constexpr index_t ThreadPerWarp_M = Warp_M / ThreadTile_M / RepeatInWarp_M;
static constexpr index_t ThreadPerWarp_N = Warp_N / ThreadTile_N / RepeatInWarp_N;
// Block-level repeats, chosen so that
//   Repeat_M * WarpPerBlock_M * ThreadPerWarp_M * ThreadTile_M == Block_M
// (substituting ThreadPerWarp_M yields the closed form below); likewise for N.
// I.e. the leading Repeat dimension folds the in-warp repeat in.
static constexpr index_t Repeat_M = Block_M * RepeatInWarp_M / (WarpPerBlock_M * Warp_M);
static constexpr index_t Repeat_N = Block_N * RepeatInWarp_N / (WarpPerBlock_N * Warp_N);
// static constexpr index_t BlockSize = ck_tile::get_warp_size();
// NOTE(review): placeholder — presumably should be
// WarpPerBlock_M * WarpPerBlock_N * ck_tile::get_warp_size(); confirm before
// launching with a real grid.
static constexpr index_t BlockSize = 1; // TODO
};
template <typename _XDataType,

View File

@@ -17,36 +17,36 @@ struct SinkhornKnoppArgs
// NOTE(review): this span is a commit-diff rendering — each original line is
// immediately followed by the commit's replacement, so taken literally it is
// not valid C++ (e.g. two operator() signatures below). The commit stubs the
// entire body out, keeping the Sinkhorn-Knopp algorithm sketch only as
// comments for a later implementation.
struct SinkhornKnoppKernelReduce
{
template <typename Problem>
// deleted line: old signature — `args` became unused once the body was stubbed
CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const
// added line: [[maybe_unused]] silences the unused-parameter warning
CK_TILE_DEVICE void operator()([[maybe_unused]] const SinkhornKnoppArgs& args) const
{
// Creating tensor descriptors, views and windows for inputs and outputs
// // Creating tensor descriptors, views and windows for inputs and outputs
// Create the reduce ops
// * Reduce Op ADD for row and column sums
// * Elementwise Op EXP for exponentiation
// // Create the reduce ops
// // * Reduce Op ADD for row and column sums
// // * Elementwise Op EXP for exponentiation
using ExponentiationOp = ElementwiseOp<ExponentiationOperation>;
using AddOp = ElementwiseOp<AddOperation>;
using DivideOp = ElementwiseOp<DivideOperation>;
// using ExponentiationOp = ElementwiseOp<ExponentiationOperation>;
// using AddOp = ElementwiseOp<AddOperation>;
// using DivideOp = ElementwiseOp<DivideOperation>;
// NOTE(review): the alias below reuses the name of what is presumably an
// outer ReduceOp template — legal shadowing but confusing; consider renaming
// the alias when this code is re-enabled.
using ReduceOp = ReduceOp<AddOp, AddOp>;
// Run the first steps iteration of the Sinkhorn-Knopp algorithm
// Exponentiate the matrix x
auto x = load_tile(...);
// using ReduceOp = ReduceOp<AddOp, AddOp>;
// // Run the first steps iteration of the Sinkhorn-Knopp algorithm
// // Exponentiate the matrix x
// auto x = load_tile(...);
// Hot loop for Sinkhorn-Knopp iterations from 1 to max_iterations
// Use BlockReduce2D for row and column sums
// NOTE(review): `i = 0; i <= args.max_iterations` executes
// max_iterations + 1 iterations, contradicting the "from 1 to
// max_iterations" comment above — fix the bound when re-enabling the loop
// (the same off-by-one is carried into the commented-out copy below).
for(int i = 0; i <= args.max_iterations; i++)
{
// 0. LOAD x
// 1. Compute row sums (REDUCE)
// 2. Divide values by row sums (SWEEP)
// 3. STORE the result of the division (in transposed format)
// 4. LOAD transposed x
// 5. Compute column sums (REDUCE)
// 6. Divide values by column sums (SWEEP)
// 7. STORE the result of the division (in transposed format)
}
// // Hot loop for Sinkhorn-Knopp iterations from 1 to max_iterations
// // Use BlockReduce2D for row and column sums
// for(int i = 0; i <= args.max_iterations; i++)
// {
// // 0. LOAD x
// // 1. Compute row sums (REDUCE)
// // 2. Divide values by row sums (SWEEP)
// // 3. STORE the result of the division (in transposed format)
// // 4. LOAD transposed x
// // 5. Compute column sums (REDUCE)
// // 6. Divide values by column sums (SWEEP)
// // 7. STORE the result of the division (in transposed format)
// }
}
};
@@ -75,30 +75,30 @@ struct SinkhornKnoppKernelDummyNonStochastic
// NOTE(review): commit-diff rendering — each active statement is followed by
// the commit's commented-out copy; the commit disables everything after the
// descriptor construction. The enclosing operator() signature is outside this
// hunk.
//
// Builds a row-major input_m x input_m descriptor plus input/output tile
// windows of S::Block_M x S::Block_N at origin {0, 0}.
const auto x_desc = make_naive_tensor_descriptor(make_tuple(args.input_m, args.input_m),
make_tuple(args.input_m, 1),
// deleted line: `number<args.input_m>{}` — input_m is a runtime argument, so
// it presumably cannot be a number<> template argument; replaced by the
// hardcoded 4 below.
number<args.input_m>{},
number<4>{}, // TODO: Hardcoded vectorization, we should calculate it!
number<1>{});
// NOTE(review): the deleted line below reads `desc.get_element_space_size()`
// but the only descriptor in scope is `x_desc` — a typo to fix if re-enabled
// (the output path further down correctly uses x_desc).
auto buffer_view = make_buffer_view<address_space_enum::global>(
args.p_x, desc.get_element_space_size(), number<0>{});
// auto buffer_view = make_buffer_view<address_space_enum::global>(
// args.p_x, desc.get_element_space_size(), number<0>{});
const auto x_tensor =
tensor_view<decltype(buffer_view), decltype(x_desc)>{buffer_view, x_desc};
// const auto x_tensor =
// tensor_view<decltype(buffer_view), decltype(x_desc)>{buffer_view, x_desc};
auto x_window = make_tile_window(x_tensor,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{0, 0},
Policy::template MakeXBlockTileDistribution<Problem>());
// auto x_window = make_tile_window(x_tensor,
// make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
// {0, 0},
// Policy::template MakeXBlockTileDistribution<Problem>());
// Output view over args.out, reusing the input descriptor (same shape/strides).
auto out_buffer_view = make_buffer_view<address_space_enum::global>(
args.out, x_desc.get_element_space_size(), number<0>{});
// auto out_buffer_view = make_buffer_view<address_space_enum::global>(
// args.out, x_desc.get_element_space_size(), number<0>{});
const auto y_tensor =
tensor_view<decltype(out_buffer_view), decltype(x_desc)>{out_buffer_view, x_desc};
// const auto y_tensor =
// tensor_view<decltype(out_buffer_view), decltype(x_desc)>{out_buffer_view, x_desc};
// NOTE(review): x_window / y_window are constructed but never loaded from or
// stored to — this "dummy" kernel is scaffolding only.
auto y_window = make_tile_window(y_tensor,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{0, 0},
Policy::template MakeXBlockTileDistribution<Problem>());
// auto y_window = make_tile_window(y_tensor,
// make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
// {0, 0},
// Policy::template MakeXBlockTileDistribution<Problem>());
}
};