This commit is contained in:
Matti Eskelinen
2026-01-30 03:47:34 -05:00
parent f669f39eaf
commit f5732e875b
3 changed files with 117 additions and 87 deletions

View File

@@ -10,7 +10,7 @@ namespace ck_tile {
struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy
{
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeTransposedXBlockTileDistribution()
CK_TILE_DEVICE static constexpr auto MakeTransposedInputBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
@@ -20,29 +20,36 @@ struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::ThreadTile_N>,
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::ThreadTile_M>>,
tuple<sequence<2, 1>, sequence<2, 1>>,
tuple<sequence<1, 1>, sequence<2, 1>>, // WarpPerBlock_M, WarpPerBlock_N;
// ThreadPerWarp_M, ThreadPerWarp_N
// sequence<1, 1, 2, 2>,
// sequence<0, 3, 0, 3>>{}); // Repeat_N, ThreadTile_N, Repeat_M, ThreadTile_M
tuple<sequence<1, 1>, sequence<2, 2>>,
// WarpPerBlock_M, WarpPerBlock_N, ThreadPerWarp_M, ThreadPerWarp_N
sequence<2, 2, 1, 1>,
sequence<0, 2, 0, 3>>{}); // Repeat_M, ThreadTile_M, Repeat_N, ThreadTile_N
sequence<0, 3, 0, 3>>{}); // Repeat_M, ThreadTile_M, Repeat_N, ThreadTile_N
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
CK_TILE_DEVICE static constexpr auto MakeInputBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
sequence<>, // Repetitions (in input dimensions?)
tuple<
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::ThreadTile_M>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::ThreadTile_N>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<1, 1>, sequence<1, 1>>,
tuple<sequence<1, 1>, sequence<1, 2>>,
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSum()
{
using br_problem = BlockReduce2dProblem<typename Problem::InDataType,
typename Problem::ComputeDataType,
typename Problem::BlockShape>;
return BlockReduce2d<br_problem>{};
}
};
} // namespace ck_tile

View File

@@ -11,85 +11,123 @@ struct SinkhornKnoppArgs
void* p_out;
const void* p_in;
const index_t input_m;
int max_iterations;
int iterations;
};
template <typename Problem, typename Policy>
struct SinkhornKnoppKernelReduce
{
CK_TILE_DEVICE void operator()([[maybe_unused]] const SinkhornKnoppArgs& args) const
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
CK_TILE_HOST static constexpr auto BlockSize()
{
return is_wave32() ? kBlockSize / 2 : kBlockSize;
}
CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const
{
// // Creating tensor descriptors, views and windows for inputs and outputs
using S = Problem::BlockShape;
using InDataType = typename Problem::OutDataType;
using OutDataType = typename Problem::OutDataType;
using S = Problem::BlockShape;
using InDataType = typename Problem::InDataType;
using ComputeDataType = typename Problem::ComputeDataType;
using OutDataType = typename Problem::OutDataType;
static_assert(S::Block_M == S::Block_N, "Input must be a square matrix!");
auto* p_in = static_cast<const Problem::InDataType*>(args.p_in);
auto* p_out = static_cast<Problem::OutDataType*>(args.p_out);
[[maybe_unused]] auto exp_op = ck_tile::element_wise::Exp{};
[[maybe_unused]] auto acc_op = ck_tile::ReduceOp::Add{};
[[maybe_unused]] auto div_op = ck_tile::element_wise::UnaryDivide{};
auto acc_op = ck_tile::ReduceOp::Add{};
// We require exp(input) > 0, and exp(padding) == 0
const InDataType x_padding_value = -ck_tile::numeric<InDataType>::infinity();
const auto in_desc =
const auto in_out_desc =
make_naive_tensor_descriptor(make_tuple(args.input_m, args.input_m),
make_tuple(args.input_m, 1),
number<4>{}, // TODO: Hardcoded
number<16>{}, // TODO: Hardcoded
// vectorization, //we should calculate it!
number<1>{});
auto buffer_view = make_buffer_view<address_space_enum::global>(
p_in, in_desc.get_element_space_size(), x_padding_value);
const auto input_window = [&]() {
// We require exp(input) > 0, and exp(padding) == 0
const InDataType input_padding_value = -ck_tile::numeric<InDataType>::infinity();
const auto x_tensor =
tensor_view<decltype(buffer_view), decltype(in_desc)>{buffer_view, in_desc};
auto buffer_view = make_buffer_view<address_space_enum::global>(
p_in, in_out_desc.get_element_space_size(), input_padding_value);
[[maybe_unused]] auto x_window =
make_tile_window(x_tensor,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{0, 0},
Policy::template MakeXBlockTileDistribution<Problem>());
const auto in_tensor =
tensor_view<decltype(buffer_view), decltype(in_out_desc)>{buffer_view, in_out_desc};
const OutDataType y_padding_value = acc_op.template GetIdentityValue<OutDataType>();
auto out_buffer_view = make_buffer_view<address_space_enum::global>(
p_out, in_desc.get_element_space_size(), y_padding_value);
return make_tile_window(in_tensor,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{0, 0},
Policy::template MakeInputBlockTileDistribution<Problem>());
}();
auto y_tensor =
tensor_view<decltype(out_buffer_view), decltype(in_desc)>{out_buffer_view, in_desc};
const OutDataType out_padding_value = acc_op.template GetIdentityValue<OutDataType>();
[[maybe_unused]] auto out_window = [&]() {
auto out_buffer_view = make_buffer_view<address_space_enum::global>(
p_out, in_out_desc.get_element_space_size(), out_padding_value);
[[maybe_unused]] auto y_window =
make_tile_window(y_tensor,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{0, 0},
Policy::template MakeXBlockTileDistribution<Problem>());
auto out_tensor = tensor_view<decltype(out_buffer_view), decltype(in_out_desc)>{
out_buffer_view, in_out_desc};
[[maybe_unused]] auto c_window =
make_null_tile_window(make_tuple(number<S::Block_M>{}, number<S::Block_N>{}));
[[maybe_unused]] auto x_tile = load_tile(x_window);
return make_tile_window(out_tensor,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{0, 0},
Policy::template MakeInputBlockTileDistribution<Problem>());
}();
// // Run the first steps iteration of the Sinkhorn-Knopp algorithm
// // Exponentiate the matrix x
// elementwise()
auto input_tile = load_tile(input_window);
// Run the first steps iteration of the Sinkhorn-Knopp algorithm
// Exponentiate the input to make it strictly positive
auto exp_func = [](ComputeDataType x) { return ck_tile::exp(x); };
// // Hot loop for Sinkhorn-Knopp iterations from 1 to max_iterations
// // Use BlockReduce2D for row and column sums
// for(int i = 0; i <= args.max_iterations; i++)
// {
// // 0. LOAD x
// // 1. Compute row sums (REDUCE)
// // 2. Divide values by row sums (SWEEP)
// // 3. STORE the result of the division (in transposed format)
// // 4. LOAD transposed x
// // 5. Compute column sums (REDUCE)
// // 6. Divide values by column sums (SWEEP)
// // 7. STORE the result of the division (in transposed format)
// }
auto compute_tile = tile_elementwise_in(exp_func, input_tile);
// Create a transposed tile
[[maybe_unused]] auto compute_tile_t = make_static_distributed_tensor<ComputeDataType>(
Policy::template MakeTransposedInputBlockTileDistribution<Problem>());
// Hot loop for Sinkhorn-Knopp iterations from 1 to iterations
// Use BlockReduce2D for row and column sums
auto row_sum = Policy::template GetSum<Problem>();
for(int i = 0; i <= args.iterations; i++)
{
// 1. Compute row sums (REDUCE)
// FIXME: Uses overload that is hardcoded to reduce 2nd dimension, be explicit instead
auto row_acc_tile = row_sum(compute_tile, out_padding_value, acc_op);
// 2. Divide values by row sums (SWEEP)
constexpr auto c_spans = compute_tile.get_distributed_spans();
sweep_tile_span(c_spans[number<0>{}], [&](const auto idx0) {
sweep_tile_span(c_spans[number<1>{}], [&](const auto idx1) {
constexpr auto c_idx = make_tuple(idx0, idx1);
constexpr auto row_acc_idx = make_tuple(idx0);
compute_tile(c_idx) = compute_tile(c_idx) / row_acc_tile(row_acc_idx);
});
});
// 3. STORE the result of the division (in transposed format)
// store_tile(compute_tile_t, compute_tile);
// 4. LOAD transposed x
// 5. Compute column sums (REDUCE)
// auto col_acc_tile = row_sum(compute_tile, out_padding_value, acc_op);
// 6. Divide values by column sums (SWEEP)
// constexpr auto ct_spans = compute_tile.get_distributed_spans();
// sweep_tile_span(ct_spans[number<0>{}], [&](const auto idx0) {
// sweep_tile_span(ct_spans[number<1>{}], [&](const auto idx1) {
// constexpr auto c_idx = make_tuple(idx0, idx1);
// constexpr auto col_acc_idx = make_tuple(idx1);
// compute_tile(c_idx) = compute_tile(c_idx) / col_acc_tile(acc_idx);
// });
// });
// 7. STORE the result of the division (in transposed format)
}
// Copy the final values to the output
store_tile(out_window, compute_tile);
}
};
@@ -104,20 +142,6 @@ struct SinkhornKnoppKernelDummyNonStochastic
return is_wave32() ? kBlockSize / 2 : kBlockSize;
}
template <typename XDistributedTensor_>
CK_TILE_DEVICE static auto MakeComputeBlockTile()
{
constexpr auto dstr =
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
XDistributedTensor_::get_tile_distribution()
.get_static_tile_distribution_encoding(),
sequence<0>{}));
auto tensor = make_static_distributed_tensor<typename Problem::OutDataType>(dstr);
return tensor;
}
// template <typename XDistributedTensor_>
// CK_TILE_DEVICE static auto MakeYBlockTile()
// {
@@ -171,11 +195,11 @@ struct SinkhornKnoppKernelDummyNonStochastic
in_desc.get_element_space_size(),
type_convert<OutDataType>(custom_padding_value));
auto y_tensor =
auto out_tensor =
tensor_view<decltype(out_buffer_view), decltype(in_desc)>{out_buffer_view, in_desc};
[[maybe_unused]] auto y_window =
make_tile_window(y_tensor,
[[maybe_unused]] auto out_window =
make_tile_window(out_tensor,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{0, 0},
Policy::template MakeXBlockTileDistribution<Problem>());
@@ -183,24 +207,24 @@ struct SinkhornKnoppKernelDummyNonStochastic
[[maybe_unused]] auto input_tile = load_tile(input_window);
// auto y_tile = MakeYBlockTile<decltype(input_window)>();
auto y_tile = make_static_distributed_tensor<OutDataType>(
// auto out_tile = MakeYBlockTile<decltype(input_window)>();
auto out_tile = make_static_distributed_tensor<OutDataType>(
Policy::template MakeXBlockTileDistribution<Problem>());
// Set all output elements to the custom padding value.
// // Simple solution to set the whole tile to a constant //
// set_tile(y_tile, custom_padding_value);
// store_tile(y_window, y_tile);
// set_tile(out_tile, custom_padding_value);
// store_tile(out_window, out_tile);
constexpr auto y_spans = y_tile.get_distributed_spans();
constexpr auto y_spans = out_tile.get_distributed_spans();
sweep_tile_span(y_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(y_spans[number<1>{}], [&](auto idx1) {
constexpr auto distributed_indices = make_tuple(idx0, idx1);
y_tile(distributed_indices) = type_convert<OutDataType>(custom_padding_value);
out_tile(distributed_indices) = type_convert<OutDataType>(custom_padding_value);
});
});
store_tile(y_window, y_tile);
store_tile(out_window, out_tile);
}
};

View File

@@ -39,7 +39,7 @@ class TestCkTileSinkHorn : public ::testing::Test
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(h_x);
auto buffer_size = h_x.get_element_space_size_in_bytes();
ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_x_mem(buffer_size);
ck_tile::DeviceMem d_y_mem(buffer_size);
ck_tile::SinkhornKnoppArgs args{static_cast<void*>(d_y_mem.GetDeviceBuffer()),
@@ -53,8 +53,7 @@ class TestCkTileSinkHorn : public ::testing::Test
using Problem =
ck_tile::SinkhornKnoppProblem<XDataType, YDataType, TestSinkhornShape, ComputeDataType>;
using Kernel =
ck_tile::SinkhornKnoppKernelDummyNonStochastic<Problem,
ck_tile::SinkhornKnoppDefaultPolicy>;
ck_tile::SinkhornKnoppKernelReduce<Problem, ck_tile::SinkhornKnoppDefaultPolicy>;
// Launch configuration
const ck_tile::index_t kBlockSize = Kernel::BlockSize();