Add basic SinkhornKnoppKernelDummyNonStochastic implementation

This commit is contained in:
Damien Lejeune
2026-01-21 11:04:40 -05:00
parent 9a743139af
commit 5f1d79b42d

View File

@@ -6,6 +6,15 @@
namespace ck_tile {
// template <typename XDataType, typename YDataType>
// struct SinkhornKnoppArgs
// {
// YDataType* out;
// const XDataType* p_x;
// const index_t input_m;
// int max_iterations;
// };
struct SinkhornKnoppArgs
{
void* out;
@@ -62,7 +71,7 @@ struct SinkhornKnoppKernelDummyNonStochastic
}
// NOTE(review): this span is a rendered diff hunk — the removed signature
// (MakeYBlockTile) and the added one (MakeComputeBlockTile) appear on
// consecutive lines, and a `@@` hunk header interrupts the body; the text is
// not compilable verbatim. Only comments are added here.
template <typename XDistributedTensor_>
// Old name, removed by this commit:
CK_TILE_DEVICE static auto MakeYBlockTile()
// New name, added by this commit. Builds a statically-distributed per-block
// tensor whose distribution is derived from XDistributedTensor_'s encoding,
// reduced along dimension 0 (sequence<0>{}).
CK_TILE_DEVICE static auto MakeComputeBlockTile()
{
constexpr auto dstr =
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
@@ -70,44 +79,91 @@ struct SinkhornKnoppKernelDummyNonStochastic
.get_static_tile_distribution_encoding(),
sequence<0>{}));
// Old element type (removed): Problem::ComputeDataType.
auto tensor = make_static_distributed_tensor<Problem::ComputeDataType>(dstr);
// New element type (added): the output data type. NOTE(review): the old line
// is missing `typename` before the dependent name; the new line fixes that.
auto tensor = make_static_distributed_tensor<typename Problem::YDataType>(dstr);
return tensor;
}
// template <typename XDistributedTensor_>
// CK_TILE_DEVICE static auto MakeYBlockTile()
// {
// constexpr auto dstr =
// make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
// XDistributedTensor_::get_tile_distribution()
// .get_static_tile_distribution_encoding(),
// sequence<0>{}));
// auto tensor = make_static_distributed_tensor<typename Problem::YDataType>(dstr);
// return tensor;
// }
// Kernel body of the dummy, non-stochastic Sinkhorn-Knopp placeholder.
// What the visible code does: builds global-memory tile windows over a square
// args.input_m x args.input_m row-major matrix for input (x) and output (y),
// loads one x tile (currently unused), then fills the y tile with the
// identity value of ReduceOp::Add (presumably 0 — TODO confirm
// GetIdentityValue) and stores it. No Sinkhorn iterations happen yet.
// NOTE(review): this span is a rendered diff; adjacent commented/uncommented
// line pairs below are the old/new sides of the commit, not duplicate logic.
CK_TILE_DEVICE void operator()([[maybe_unused]] const SinkhornKnoppArgs& args) const
{
// using S = Problem::BlockShape;
using S = Problem::BlockShape;
using XDataType = typename Problem::XDataType;
using YDataType = typename Problem::YDataType;
// The block tile must be square, per the assert message below — presumably
// because Sinkhorn-Knopp alternates row/column scaling; verify with authors.
// static_assert(S::Block_M == S::Block_N, "Input must be a square matrix!");
static_assert(S::Block_M == S::Block_N, "Input must be a square matrix!");
// const auto x_desc = make_naive_tensor_descriptor(make_tuple(args.input_m, args.input_m),
// make_tuple(args.input_m, 1),
// number<4>{}, // TODO: Hardcoded
// vectorization, we should calculate it!
// number<1>{});
// Typed views of the void*/const void* pointers carried by SinkhornKnoppArgs.
auto* p_x = static_cast<const Problem::XDataType*>(args.p_x);
auto* p_y = static_cast<Problem::YDataType*>(args.out);
auto reduce_func = ck_tile::ReduceOp::Add{};
// auto buffer_view = make_buffer_view<address_space_enum::global>(
// args.p_x, desc.get_element_space_size(), number<0>{});
// Padding/fill value: the Add-reduce identity, computed in ComputeDataType
// and converted to the input element type.
const XDataType custom_padding_value = type_convert<XDataType>(
reduce_func.GetIdentityValue<typename Problem::ComputeDataType>());
// const auto x_tensor =
// tensor_view<decltype(buffer_view), decltype(x_desc)>{buffer_view, x_desc};
// Row-major (stride = input_m, 1) descriptor shared by input and output.
const auto x_desc = make_naive_tensor_descriptor(make_tuple(args.input_m, args.input_m),
make_tuple(args.input_m, 1),
number<4>{}, // TODO: Hardcoded
// vectorization, //we should calculate it!
number<1>{});
// auto x_window = make_tile_window(x_tensor,
// make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
// {0, 0},
// Policy::template MakeXBlockTileDistribution<Problem>());
// Global-memory view of the input, with out-of-bounds reads returning the
// custom padding value.
auto buffer_view = make_buffer_view<address_space_enum::global>(
p_x, x_desc.get_element_space_size(), custom_padding_value);
// auto out_buffer_view = make_buffer_view<address_space_enum::global>(
// args.out, x_desc.get_element_space_size(), number<0>{});
const auto x_tensor =
tensor_view<decltype(buffer_view), decltype(x_desc)>{buffer_view, x_desc};
// const auto y_tensor =
// tensor_view<decltype(out_buffer_view), decltype(x_desc)>{out_buffer_view, x_desc};
// Block_M x Block_N input window anchored at the origin. NOTE(review): the
// origin is hard-coded to {0, 0}, so every workgroup would read the same
// tile — presumably a placeholder until real block indexing is added.
[[maybe_unused]] auto x_window =
make_tile_window(x_tensor,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{0, 0},
Policy::template MakeXBlockTileDistribution<Problem>());
// auto y_window = make_tile_window(y_tensor,
// make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
// {0, 0},
// Policy::template MakeXBlockTileDistribution<Problem>());
// Output buffer/window mirrors the input descriptor and distribution.
auto out_buffer_view = make_buffer_view<address_space_enum::global>(
p_y, x_desc.get_element_space_size(), type_convert<YDataType>(custom_padding_value));
auto y_tensor =
tensor_view<decltype(out_buffer_view), decltype(x_desc)>{out_buffer_view, x_desc};
[[maybe_unused]] auto y_window =
make_tile_window(y_tensor,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{0, 0},
Policy::template MakeXBlockTileDistribution<Problem>());
// Dummy copy from input to output
// NOTE(review): despite the comment above, x_tile is never written out — the
// output is filled with custom_padding_value, not copied from the input.
[[maybe_unused]] auto x_tile = load_tile(x_window);
// auto y_tile = MakeYBlockTile<decltype(x_window)>();
// y uses the same distribution as x (not the reduced MakeComputeBlockTile
// distribution) so it can be stored through y_window directly.
auto y_tile = make_static_distributed_tensor<YDataType>(
Policy::template MakeXBlockTileDistribution<Problem>());
// Set all output elements to the custom padding value.
// // Simple solution to set the whole tile to a constant //
// set_tile(y_tile, custom_padding_value);
// store_tile(y_window, y_tile);
// Element-wise sweep over the tile's two distributed spans, writing the
// constant into every register-resident element of y_tile.
constexpr auto y_spans = y_tile.get_distributed_spans();
sweep_tile_span(y_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(y_spans[number<1>{}], [&](auto idx1) {
constexpr auto distributed_indices = make_tuple(idx0, idx1);
y_tile(distributed_indices) = type_convert<YDataType>(custom_padding_value);
});
});
store_tile(y_window, y_tile);
}
};