mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
WIP
This commit is contained in:
@@ -58,15 +58,15 @@ struct SinkhornKnoppShape
|
||||
static constexpr index_t BlockSize = ck_tile::get_warp_size();
|
||||
};
|
||||
|
||||
template <typename _XDataType,
|
||||
typename _YDataType,
|
||||
template <typename _InDataType,
|
||||
typename _OutDataType,
|
||||
typename _BlockShape,
|
||||
typename _ComputeDataType = float>
|
||||
typename _ComputeDataType = _OutDataType>
|
||||
struct SinkhornKnoppProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<_XDataType>;
|
||||
using InDataType = remove_cvref_t<_InDataType>;
|
||||
using ComputeDataType = remove_cvref_t<_ComputeDataType>;
|
||||
using YDataType = remove_cvref_t<_YDataType>;
|
||||
using OutDataType = remove_cvref_t<_OutDataType>;
|
||||
|
||||
using BlockShape = remove_cvref_t<_BlockShape>;
|
||||
};
|
||||
|
||||
@@ -6,42 +6,76 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// template <typename XDataType, typename YDataType>
|
||||
// struct SinkhornKnoppArgs
|
||||
// {
|
||||
// YDataType* out;
|
||||
// const XDataType* p_x;
|
||||
// const index_t input_m;
|
||||
// int max_iterations;
|
||||
// };
|
||||
|
||||
struct SinkhornKnoppArgs
|
||||
{
|
||||
void* out;
|
||||
const void* p_x;
|
||||
void* p_out;
|
||||
const void* p_in;
|
||||
const index_t input_m;
|
||||
int max_iterations;
|
||||
};
|
||||
|
||||
template <typename Problem, typename Policy>
|
||||
struct SinkhornKnoppKernelReduce
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE void operator()([[maybe_unused]] const SinkhornKnoppArgs& args) const
|
||||
{
|
||||
// // Creating tensor descriptors, views and windows for inputs and outputs
|
||||
|
||||
// // Create the reduce ops
|
||||
// // * Reduce Op ADD for row and column sums
|
||||
// // * Elementwise Op EXP for exponentiation
|
||||
using S = Problem::BlockShape;
|
||||
using InDataType = typename Problem::OutDataType;
|
||||
using OutDataType = typename Problem::OutDataType;
|
||||
|
||||
// using ExponentiationOp = ElementwiseOp<ExponentiationOperation>;
|
||||
// using AddOp = ElementwiseOp<AddOperation>;
|
||||
// using DivideOp = ElementwiseOp<DivideOperation>;
|
||||
static_assert(S::Block_M == S::Block_N, "Input must be a square matrix!");
|
||||
|
||||
auto* p_in = static_cast<const Problem::InDataType*>(args.p_in);
|
||||
auto* p_out = static_cast<Problem::OutDataType*>(args.p_out);
|
||||
|
||||
[[maybe_unused]] auto exp_op = ck_tile::element_wise::Exp{};
|
||||
[[maybe_unused]] auto acc_op = ck_tile::ReduceOp::Add{};
|
||||
[[maybe_unused]] auto div_op = ck_tile::element_wise::UnaryDivide{};
|
||||
|
||||
// We require exp(input) > 0, and exp(padding) == 0
|
||||
const InDataType x_padding_value = -ck_tile::numeric<InDataType>::infinity();
|
||||
|
||||
const auto in_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(args.input_m, args.input_m),
|
||||
make_tuple(args.input_m, 1),
|
||||
number<4>{}, // TODO: Hardcoded
|
||||
// vectorization, //we should calculate it!
|
||||
number<1>{});
|
||||
|
||||
auto buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
p_in, in_desc.get_element_space_size(), x_padding_value);
|
||||
|
||||
const auto x_tensor =
|
||||
tensor_view<decltype(buffer_view), decltype(in_desc)>{buffer_view, in_desc};
|
||||
|
||||
[[maybe_unused]] auto x_window =
|
||||
make_tile_window(x_tensor,
|
||||
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
const OutDataType y_padding_value = acc_op.template GetIdentityValue<OutDataType>();
|
||||
auto out_buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
p_out, in_desc.get_element_space_size(), y_padding_value);
|
||||
|
||||
auto y_tensor =
|
||||
tensor_view<decltype(out_buffer_view), decltype(in_desc)>{out_buffer_view, in_desc};
|
||||
|
||||
[[maybe_unused]] auto y_window =
|
||||
make_tile_window(y_tensor,
|
||||
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
[[maybe_unused]] auto c_window =
|
||||
make_null_tile_window(make_tuple(number<S::Block_M>{}, number<S::Block_N>{}));
|
||||
[[maybe_unused]] auto x_tile = load_tile(x_window);
|
||||
|
||||
// using ReduceOp = ReduceOp<AddOp, AddOp>;
|
||||
// // Run the first steps iteration of the Sinkhorn-Knopp algorithm
|
||||
// // Exponentiate the matrix x
|
||||
// auto x = load_tile(...);
|
||||
// elementwise()
|
||||
|
||||
// // Hot loop for Sinkhorn-Knopp iterations from 1 to max_iterations
|
||||
// // Use BlockReduce2D for row and column sums
|
||||
@@ -79,7 +113,7 @@ struct SinkhornKnoppKernelDummyNonStochastic
|
||||
.get_static_tile_distribution_encoding(),
|
||||
sequence<0>{}));
|
||||
|
||||
auto tensor = make_static_distributed_tensor<typename Problem::YDataType>(dstr);
|
||||
auto tensor = make_static_distributed_tensor<typename Problem::OutDataType>(dstr);
|
||||
|
||||
return tensor;
|
||||
}
|
||||
@@ -93,49 +127,52 @@ struct SinkhornKnoppKernelDummyNonStochastic
|
||||
// .get_static_tile_distribution_encoding(),
|
||||
// sequence<0>{}));
|
||||
|
||||
// auto tensor = make_static_distributed_tensor<typename Problem::YDataType>(dstr);
|
||||
// auto tensor = make_static_distributed_tensor<typename Problem::OutDataType>(dstr);
|
||||
|
||||
// return tensor;
|
||||
// }
|
||||
|
||||
CK_TILE_DEVICE void operator()([[maybe_unused]] const SinkhornKnoppArgs& args) const
|
||||
{
|
||||
using S = Problem::BlockShape;
|
||||
using XDataType = typename Problem::XDataType;
|
||||
using YDataType = typename Problem::YDataType;
|
||||
using S = Problem::BlockShape;
|
||||
using InDataType = typename Problem::InDataType;
|
||||
using OutDataType = typename Problem::OutDataType;
|
||||
|
||||
static_assert(S::Block_M == S::Block_N, "Input must be a square matrix!");
|
||||
|
||||
auto* p_x = static_cast<const Problem::XDataType*>(args.p_x);
|
||||
auto* p_y = static_cast<Problem::YDataType*>(args.out);
|
||||
auto* p_in = static_cast<const Problem::InDataType*>(args.p_in);
|
||||
auto* p_out = static_cast<Problem::OutDataType*>(args.p_out);
|
||||
auto reduce_func = ck_tile::ReduceOp::Add{};
|
||||
|
||||
const XDataType custom_padding_value = type_convert<XDataType>(
|
||||
const InDataType custom_padding_value = type_convert<InDataType>(
|
||||
reduce_func.GetIdentityValue<typename Problem::ComputeDataType>());
|
||||
|
||||
const auto x_desc = make_naive_tensor_descriptor(make_tuple(args.input_m, args.input_m),
|
||||
make_tuple(args.input_m, 1),
|
||||
number<4>{}, // TODO: Hardcoded
|
||||
// vectorization, //we should calculate it!
|
||||
number<1>{});
|
||||
const auto in_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(args.input_m, args.input_m),
|
||||
make_tuple(args.input_m, 1),
|
||||
number<4>{}, // TODO: Hardcoded
|
||||
// vectorization, //we should calculate it!
|
||||
number<1>{});
|
||||
|
||||
auto buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
p_x, x_desc.get_element_space_size(), custom_padding_value);
|
||||
p_in, in_desc.get_element_space_size(), custom_padding_value);
|
||||
|
||||
const auto x_tensor =
|
||||
tensor_view<decltype(buffer_view), decltype(x_desc)>{buffer_view, x_desc};
|
||||
const auto input_tensor =
|
||||
tensor_view<decltype(buffer_view), decltype(in_desc)>{buffer_view, in_desc};
|
||||
|
||||
[[maybe_unused]] auto x_window =
|
||||
make_tile_window(x_tensor,
|
||||
[[maybe_unused]] auto input_window =
|
||||
make_tile_window(input_tensor,
|
||||
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
auto out_buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
p_y, x_desc.get_element_space_size(), type_convert<YDataType>(custom_padding_value));
|
||||
p_out,
|
||||
in_desc.get_element_space_size(),
|
||||
type_convert<OutDataType>(custom_padding_value));
|
||||
|
||||
auto y_tensor =
|
||||
tensor_view<decltype(out_buffer_view), decltype(x_desc)>{out_buffer_view, x_desc};
|
||||
tensor_view<decltype(out_buffer_view), decltype(in_desc)>{out_buffer_view, in_desc};
|
||||
|
||||
[[maybe_unused]] auto y_window =
|
||||
make_tile_window(y_tensor,
|
||||
@@ -144,10 +181,10 @@ struct SinkhornKnoppKernelDummyNonStochastic
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
// Dummy copy from input to output
|
||||
|
||||
[[maybe_unused]] auto x_tile = load_tile(x_window);
|
||||
[[maybe_unused]] auto input_tile = load_tile(input_window);
|
||||
|
||||
// auto y_tile = MakeYBlockTile<decltype(x_window)>();
|
||||
auto y_tile = make_static_distributed_tensor<YDataType>(
|
||||
// auto y_tile = MakeYBlockTile<decltype(input_window)>();
|
||||
auto y_tile = make_static_distributed_tensor<OutDataType>(
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
// Set all output elements to the custom padding value.
|
||||
@@ -159,7 +196,7 @@ struct SinkhornKnoppKernelDummyNonStochastic
|
||||
sweep_tile_span(y_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(y_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto distributed_indices = make_tuple(idx0, idx1);
|
||||
y_tile(distributed_indices) = type_convert<YDataType>(custom_padding_value);
|
||||
y_tile(distributed_indices) = type_convert<OutDataType>(custom_padding_value);
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user