This commit is contained in:
Matti Eskelinen
2026-02-02 09:45:43 -05:00
parent fc977c88a2
commit ff428f3478
3 changed files with 74 additions and 38 deletions

View File

@@ -42,15 +42,6 @@ struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy
sequence<2, 1>,
sequence<3, 3>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSum()
{
using br_problem = BlockReduce2dProblem<typename Problem::InDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2d<br_problem>{};
}
};
} // namespace ck_tile

View File

@@ -26,7 +26,7 @@ struct SinkhornKnoppKernelReduce
CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const
{
// // Creating tensor descriptors, views and windows for inputs and outputs
// Creating tensor descriptors, views and windows for inputs and outputs
using S = Problem::BlockShape;
using InDataType = typename Problem::InDataType;
using ComputeDataType = typename Problem::ComputeDataType;
@@ -62,9 +62,9 @@ struct SinkhornKnoppKernelReduce
Policy::template MakeInputBlockTileDistribution<Problem>());
}();
const OutDataType out_padding_value = acc_op.template GetIdentityValue<OutDataType>();
[[maybe_unused]] auto out_window = [&]() {
auto out_buffer_view = make_buffer_view<address_space_enum::global>(
auto out_window = [&]() {
const OutDataType out_padding_value = acc_op.template GetIdentityValue<OutDataType>();
auto out_buffer_view = make_buffer_view<address_space_enum::global>(
p_out, in_out_desc.get_element_space_size(), out_padding_value);
auto out_tensor = tensor_view<decltype(out_buffer_view), decltype(in_out_desc)>{
@@ -72,17 +72,21 @@ struct SinkhornKnoppKernelReduce
return make_tile_window(out_tensor,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{0, 0},
{0, 0},
Policy::template MakeInputBlockTileDistribution<Problem>());
}();
auto input_tile = load_tile(input_window);
// Run the first steps iteration of the Sinkhorn-Knopp algorithm
// Exponentiate the input to make it strictly positive
auto exp_func = [](ComputeDataType x) -> ComputeDataType { return ck_tile::exp(x); };
// auto exp_func = [](ComputeDataType x) -> ComputeDataType { return ck_tile::exp(x); };
// auto exp_func = []([[maybe_unused]] ComputeDataType x) {
// return static_cast<ComputeDataType>(1.0);
// };
auto exp_func = [](InDataType x) -> ComputeDataType {
return static_cast<ComputeDataType>(x);
};
auto compute_tile = tile_elementwise_in(exp_func, input_tile);
@@ -92,13 +96,30 @@ struct SinkhornKnoppKernelReduce
// Hot loop for Sinkhorn-Knopp iterations from 1 to iterations
// Use BlockReduce2D for row and column sums
auto row_sum = Policy::template GetSum<Problem>();
auto row_sum = [&](const auto& c_tile) {
// TODO: Handle case where the input doesn't fit in a single tile
using br_problem = BlockReduce2dProblem<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
for(int i = 0; i <= args.iterations; i++)
auto block_reduce2d = Policy::template GetBlockReduce2d<br_problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<br_problem>();
// TODO: Deduce/allow specifying a separate type for the accumulators?
// NOTE: MakeYBlockTile defaults to reducing 2nd dimension
auto acc_tile = block_reduce2d.template MakeYBlockTile<decltype(c_tile)>();
set_tile(acc_tile, acc_op.template GetIdentityValue<ComputeDataType>());
block_reduce2d(c_tile, acc_tile, acc_op);
block_reduce2d_sync(acc_tile, acc_op);
return acc_tile;
};
for(int i = 0; i < 1; i++)
{
// 1. Compute row sums (REDUCE)
// FIXME: Uses overload that is hardcoded to reduce 2nd dimension, be explicit instead
auto row_acc_tile = row_sum(compute_tile, out_padding_value, acc_op);
auto row_acc_tile = row_sum(compute_tile);
// 2. Divide values by row sums (SWEEP)
constexpr auto c_spans = compute_tile.get_distributed_spans();
@@ -106,14 +127,53 @@ struct SinkhornKnoppKernelReduce
sweep_tile_span(c_spans[number<1>{}], [&](const auto idx1) {
constexpr auto c_idx = make_tuple(idx0, idx1);
constexpr auto row_acc_idx = make_tuple(idx0);
compute_tile(c_idx) = compute_tile(c_idx) / row_acc_tile(row_acc_idx);
if(threadIdx.x == 0)
{
print(row_acc_idx);
print(":");
print(row_acc_tile(row_acc_idx));
print("\n");
}
compute_tile(c_idx) = compute_tile(c_idx) / row_acc_tile(row_acc_idx);
});
});
if(threadIdx.x == 0)
{
print("compute tile after rows summed and divided\n");
[&](auto tile) {
constexpr auto spans = tile.get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
constexpr auto idx = make_tuple(idx0, idx1);
print(idx);
print(tile(idx));
print("\n");
});
});
}(compute_tile);
}
transpose_tile2d(compute_tile_t, compute_tile);
// // Row sum is column sum for transposed c_tile
auto col_acc_tile = row_sum(compute_tile_t, out_padding_value, acc_op);
if(threadIdx.x == 0)
{
print("compute tile transposed\n");
[&](auto tile) {
constexpr auto spans = tile.get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
constexpr auto idx = make_tuple(idx0, idx1);
print(idx);
print(tile(idx));
print("\n");
});
});
}(compute_tile_t);
}
// Row sum is column sum for transposed c_tile
auto col_acc_tile = row_sum(compute_tile_t);
constexpr auto c_t_spans = compute_tile_t.get_distributed_spans();
sweep_tile_span(c_t_spans[number<0>{}], [&](const auto idx0) {
@@ -125,22 +185,6 @@ struct SinkhornKnoppKernelReduce
});
transpose_tile2d(compute_tile, compute_tile_t);
// 3. STORE the result of the division (in transposed format)
// 4. LOAD transposed x
// 5. Compute column sums (REDUCE)
// auto col_acc_tile = row_sum(compute_tile, out_padding_value, acc_op);
// 6. Divide values by column sums (SWEEP)
// constexpr auto ct_spans = compute_tile.get_distributed_spans();
// sweep_tile_span(ct_spans[number<0>{}], [&](const auto idx0) {
// sweep_tile_span(ct_spans[number<1>{}], [&](const auto idx1) {
// constexpr auto c_idx = make_tuple(idx0, idx1);
// constexpr auto col_acc_idx = make_tuple(idx1);
// compute_tile(c_idx) = compute_tile(c_idx) / col_acc_tile(acc_idx);
// });
// });
// 7. STORE the result of the division (in transposed format)
}
// Copy the final values to the output