This commit is contained in:
Matti Eskelinen
2026-01-27 07:23:12 -05:00
parent 0a43ee37b4
commit f669f39eaf
2 changed files with 86 additions and 49 deletions

View File

@@ -58,15 +58,15 @@ struct SinkhornKnoppShape
static constexpr index_t BlockSize = ck_tile::get_warp_size();
};
template <typename _XDataType,
typename _YDataType,
template <typename _InDataType,
typename _OutDataType,
typename _BlockShape,
typename _ComputeDataType = float>
typename _ComputeDataType = _OutDataType>
struct SinkhornKnoppProblem
{
using XDataType = remove_cvref_t<_XDataType>;
using InDataType = remove_cvref_t<_InDataType>;
using ComputeDataType = remove_cvref_t<_ComputeDataType>;
using YDataType = remove_cvref_t<_YDataType>;
using OutDataType = remove_cvref_t<_OutDataType>;
using BlockShape = remove_cvref_t<_BlockShape>;
};

View File

@@ -6,42 +6,76 @@
namespace ck_tile {
// template <typename XDataType, typename YDataType>
// struct SinkhornKnoppArgs
// {
// YDataType* out;
// const XDataType* p_x;
// const index_t input_m;
// int max_iterations;
// };
struct SinkhornKnoppArgs
{
void* out;
const void* p_x;
void* p_out;
const void* p_in;
const index_t input_m;
int max_iterations;
};
template <typename Problem, typename Policy>
struct SinkhornKnoppKernelReduce
{
template <typename Problem>
CK_TILE_DEVICE void operator()([[maybe_unused]] const SinkhornKnoppArgs& args) const
{
// // Creating tensor descriptors, views and windows for inputs and outputs
// // Create the reduce ops
// // * Reduce Op ADD for row and column sums
// // * Elementwise Op EXP for exponentiation
using S = Problem::BlockShape;
using InDataType = typename Problem::OutDataType;
using OutDataType = typename Problem::OutDataType;
// using ExponentiationOp = ElementwiseOp<ExponentiationOperation>;
// using AddOp = ElementwiseOp<AddOperation>;
// using DivideOp = ElementwiseOp<DivideOperation>;
static_assert(S::Block_M == S::Block_N, "Input must be a square matrix!");
auto* p_in = static_cast<const Problem::InDataType*>(args.p_in);
auto* p_out = static_cast<Problem::OutDataType*>(args.p_out);
[[maybe_unused]] auto exp_op = ck_tile::element_wise::Exp{};
[[maybe_unused]] auto acc_op = ck_tile::ReduceOp::Add{};
[[maybe_unused]] auto div_op = ck_tile::element_wise::UnaryDivide{};
// We require exp(input) > 0, and exp(padding) == 0
const InDataType x_padding_value = -ck_tile::numeric<InDataType>::infinity();
const auto in_desc =
make_naive_tensor_descriptor(make_tuple(args.input_m, args.input_m),
make_tuple(args.input_m, 1),
number<4>{}, // TODO: Hardcoded
// vectorization, //we should calculate it!
number<1>{});
auto buffer_view = make_buffer_view<address_space_enum::global>(
p_in, in_desc.get_element_space_size(), x_padding_value);
const auto x_tensor =
tensor_view<decltype(buffer_view), decltype(in_desc)>{buffer_view, in_desc};
[[maybe_unused]] auto x_window =
make_tile_window(x_tensor,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{0, 0},
Policy::template MakeXBlockTileDistribution<Problem>());
const OutDataType y_padding_value = acc_op.template GetIdentityValue<OutDataType>();
auto out_buffer_view = make_buffer_view<address_space_enum::global>(
p_out, in_desc.get_element_space_size(), y_padding_value);
auto y_tensor =
tensor_view<decltype(out_buffer_view), decltype(in_desc)>{out_buffer_view, in_desc};
[[maybe_unused]] auto y_window =
make_tile_window(y_tensor,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{0, 0},
Policy::template MakeXBlockTileDistribution<Problem>());
[[maybe_unused]] auto c_window =
make_null_tile_window(make_tuple(number<S::Block_M>{}, number<S::Block_N>{}));
[[maybe_unused]] auto x_tile = load_tile(x_window);
// using ReduceOp = ReduceOp<AddOp, AddOp>;
// // Run the first steps iteration of the Sinkhorn-Knopp algorithm
// // Exponentiate the matrix x
// auto x = load_tile(...);
// elementwise()
// // Hot loop for Sinkhorn-Knopp iterations from 1 to max_iterations
// // Use BlockReduce2D for row and column sums
@@ -79,7 +113,7 @@ struct SinkhornKnoppKernelDummyNonStochastic
.get_static_tile_distribution_encoding(),
sequence<0>{}));
auto tensor = make_static_distributed_tensor<typename Problem::YDataType>(dstr);
auto tensor = make_static_distributed_tensor<typename Problem::OutDataType>(dstr);
return tensor;
}
@@ -93,49 +127,52 @@ struct SinkhornKnoppKernelDummyNonStochastic
// .get_static_tile_distribution_encoding(),
// sequence<0>{}));
// auto tensor = make_static_distributed_tensor<typename Problem::YDataType>(dstr);
// auto tensor = make_static_distributed_tensor<typename Problem::OutDataType>(dstr);
// return tensor;
// }
CK_TILE_DEVICE void operator()([[maybe_unused]] const SinkhornKnoppArgs& args) const
{
using S = Problem::BlockShape;
using XDataType = typename Problem::XDataType;
using YDataType = typename Problem::YDataType;
using S = Problem::BlockShape;
using InDataType = typename Problem::InDataType;
using OutDataType = typename Problem::OutDataType;
static_assert(S::Block_M == S::Block_N, "Input must be a square matrix!");
auto* p_x = static_cast<const Problem::XDataType*>(args.p_x);
auto* p_y = static_cast<Problem::YDataType*>(args.out);
auto* p_in = static_cast<const Problem::InDataType*>(args.p_in);
auto* p_out = static_cast<Problem::OutDataType*>(args.p_out);
auto reduce_func = ck_tile::ReduceOp::Add{};
const XDataType custom_padding_value = type_convert<XDataType>(
const InDataType custom_padding_value = type_convert<InDataType>(
reduce_func.GetIdentityValue<typename Problem::ComputeDataType>());
const auto x_desc = make_naive_tensor_descriptor(make_tuple(args.input_m, args.input_m),
make_tuple(args.input_m, 1),
number<4>{}, // TODO: Hardcoded
// vectorization, //we should calculate it!
number<1>{});
const auto in_desc =
make_naive_tensor_descriptor(make_tuple(args.input_m, args.input_m),
make_tuple(args.input_m, 1),
number<4>{}, // TODO: Hardcoded
// vectorization, //we should calculate it!
number<1>{});
auto buffer_view = make_buffer_view<address_space_enum::global>(
p_x, x_desc.get_element_space_size(), custom_padding_value);
p_in, in_desc.get_element_space_size(), custom_padding_value);
const auto x_tensor =
tensor_view<decltype(buffer_view), decltype(x_desc)>{buffer_view, x_desc};
const auto input_tensor =
tensor_view<decltype(buffer_view), decltype(in_desc)>{buffer_view, in_desc};
[[maybe_unused]] auto x_window =
make_tile_window(x_tensor,
[[maybe_unused]] auto input_window =
make_tile_window(input_tensor,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{0, 0},
Policy::template MakeXBlockTileDistribution<Problem>());
auto out_buffer_view = make_buffer_view<address_space_enum::global>(
p_y, x_desc.get_element_space_size(), type_convert<YDataType>(custom_padding_value));
p_out,
in_desc.get_element_space_size(),
type_convert<OutDataType>(custom_padding_value));
auto y_tensor =
tensor_view<decltype(out_buffer_view), decltype(x_desc)>{out_buffer_view, x_desc};
tensor_view<decltype(out_buffer_view), decltype(in_desc)>{out_buffer_view, in_desc};
[[maybe_unused]] auto y_window =
make_tile_window(y_tensor,
@@ -144,10 +181,10 @@ struct SinkhornKnoppKernelDummyNonStochastic
Policy::template MakeXBlockTileDistribution<Problem>());
// Dummy copy from input to output
[[maybe_unused]] auto x_tile = load_tile(x_window);
[[maybe_unused]] auto input_tile = load_tile(input_window);
// auto y_tile = MakeYBlockTile<decltype(x_window)>();
auto y_tile = make_static_distributed_tensor<YDataType>(
// auto y_tile = MakeYBlockTile<decltype(input_window)>();
auto y_tile = make_static_distributed_tensor<OutDataType>(
Policy::template MakeXBlockTileDistribution<Problem>());
// Set all output elements to the custom padding value.
@@ -159,7 +196,7 @@ struct SinkhornKnoppKernelDummyNonStochastic
sweep_tile_span(y_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(y_spans[number<1>{}], [&](auto idx1) {
constexpr auto distributed_indices = make_tuple(idx0, idx1);
y_tile(distributed_indices) = type_convert<YDataType>(custom_padding_value);
y_tile(distributed_indices) = type_convert<OutDataType>(custom_padding_value);
});
});