mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
WIP
This commit is contained in:
@@ -10,7 +10,7 @@ namespace ck_tile {
|
||||
struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeTransposedXBlockTileDistribution()
|
||||
CK_TILE_DEVICE static constexpr auto MakeTransposedInputBlockTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
return make_static_tile_distribution(
|
||||
@@ -20,29 +20,36 @@ struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy
|
||||
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::ThreadTile_N>,
|
||||
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::ThreadTile_M>>,
|
||||
tuple<sequence<2, 1>, sequence<2, 1>>,
|
||||
tuple<sequence<1, 1>, sequence<2, 1>>, // WarpPerBlock_M, WarpPerBlock_N;
|
||||
// ThreadPerWarp_M, ThreadPerWarp_N
|
||||
// sequence<1, 1, 2, 2>,
|
||||
// sequence<0, 3, 0, 3>>{}); // Repeat_N, ThreadTile_N, Repeat_M, ThreadTile_M
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>,
|
||||
// WarpPerBlock_M, WarpPerBlock_N, ThreadPerWarp_M, ThreadPerWarp_N
|
||||
sequence<2, 2, 1, 1>,
|
||||
sequence<0, 2, 0, 3>>{}); // Repeat_M, ThreadTile_M, Repeat_N, ThreadTile_N
|
||||
sequence<0, 3, 0, 3>>{}); // Repeat_M, ThreadTile_M, Repeat_N, ThreadTile_N
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
|
||||
CK_TILE_DEVICE static constexpr auto MakeInputBlockTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
sequence<>, // Repetitions (in input dimensions?)
|
||||
tuple<
|
||||
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::ThreadTile_M>,
|
||||
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::ThreadTile_N>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>, sequence<1, 1>>,
|
||||
tuple<sequence<1, 1>, sequence<1, 2>>,
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSum()
|
||||
{
|
||||
using br_problem = BlockReduce2dProblem<typename Problem::InDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
return BlockReduce2d<br_problem>{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -11,85 +11,123 @@ struct SinkhornKnoppArgs
|
||||
void* p_out;
|
||||
const void* p_in;
|
||||
const index_t input_m;
|
||||
int max_iterations;
|
||||
int iterations;
|
||||
};
|
||||
|
||||
template <typename Problem, typename Policy>
|
||||
struct SinkhornKnoppKernelReduce
|
||||
{
|
||||
CK_TILE_DEVICE void operator()([[maybe_unused]] const SinkhornKnoppArgs& args) const
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize()
|
||||
{
|
||||
return is_wave32() ? kBlockSize / 2 : kBlockSize;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const
|
||||
{
|
||||
// // Creating tensor descriptors, views and windows for inputs and outputs
|
||||
|
||||
using S = Problem::BlockShape;
|
||||
using InDataType = typename Problem::OutDataType;
|
||||
using OutDataType = typename Problem::OutDataType;
|
||||
using S = Problem::BlockShape;
|
||||
using InDataType = typename Problem::InDataType;
|
||||
using ComputeDataType = typename Problem::ComputeDataType;
|
||||
using OutDataType = typename Problem::OutDataType;
|
||||
|
||||
static_assert(S::Block_M == S::Block_N, "Input must be a square matrix!");
|
||||
|
||||
auto* p_in = static_cast<const Problem::InDataType*>(args.p_in);
|
||||
auto* p_out = static_cast<Problem::OutDataType*>(args.p_out);
|
||||
|
||||
[[maybe_unused]] auto exp_op = ck_tile::element_wise::Exp{};
|
||||
[[maybe_unused]] auto acc_op = ck_tile::ReduceOp::Add{};
|
||||
[[maybe_unused]] auto div_op = ck_tile::element_wise::UnaryDivide{};
|
||||
auto acc_op = ck_tile::ReduceOp::Add{};
|
||||
|
||||
// We require exp(input) > 0, and exp(padding) == 0
|
||||
const InDataType x_padding_value = -ck_tile::numeric<InDataType>::infinity();
|
||||
|
||||
const auto in_desc =
|
||||
const auto in_out_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(args.input_m, args.input_m),
|
||||
make_tuple(args.input_m, 1),
|
||||
number<4>{}, // TODO: Hardcoded
|
||||
number<16>{}, // TODO: Hardcoded
|
||||
// vectorization, //we should calculate it!
|
||||
number<1>{});
|
||||
|
||||
auto buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
p_in, in_desc.get_element_space_size(), x_padding_value);
|
||||
const auto input_window = [&]() {
|
||||
// We require exp(input) > 0, and exp(padding) == 0
|
||||
const InDataType input_padding_value = -ck_tile::numeric<InDataType>::infinity();
|
||||
|
||||
const auto x_tensor =
|
||||
tensor_view<decltype(buffer_view), decltype(in_desc)>{buffer_view, in_desc};
|
||||
auto buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
p_in, in_out_desc.get_element_space_size(), input_padding_value);
|
||||
|
||||
[[maybe_unused]] auto x_window =
|
||||
make_tile_window(x_tensor,
|
||||
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
const auto in_tensor =
|
||||
tensor_view<decltype(buffer_view), decltype(in_out_desc)>{buffer_view, in_out_desc};
|
||||
|
||||
const OutDataType y_padding_value = acc_op.template GetIdentityValue<OutDataType>();
|
||||
auto out_buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
p_out, in_desc.get_element_space_size(), y_padding_value);
|
||||
return make_tile_window(in_tensor,
|
||||
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeInputBlockTileDistribution<Problem>());
|
||||
}();
|
||||
|
||||
auto y_tensor =
|
||||
tensor_view<decltype(out_buffer_view), decltype(in_desc)>{out_buffer_view, in_desc};
|
||||
const OutDataType out_padding_value = acc_op.template GetIdentityValue<OutDataType>();
|
||||
[[maybe_unused]] auto out_window = [&]() {
|
||||
auto out_buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
p_out, in_out_desc.get_element_space_size(), out_padding_value);
|
||||
|
||||
[[maybe_unused]] auto y_window =
|
||||
make_tile_window(y_tensor,
|
||||
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
auto out_tensor = tensor_view<decltype(out_buffer_view), decltype(in_out_desc)>{
|
||||
out_buffer_view, in_out_desc};
|
||||
|
||||
[[maybe_unused]] auto c_window =
|
||||
make_null_tile_window(make_tuple(number<S::Block_M>{}, number<S::Block_N>{}));
|
||||
[[maybe_unused]] auto x_tile = load_tile(x_window);
|
||||
return make_tile_window(out_tensor,
|
||||
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeInputBlockTileDistribution<Problem>());
|
||||
}();
|
||||
|
||||
// // Run the first steps iteration of the Sinkhorn-Knopp algorithm
|
||||
// // Exponentiate the matrix x
|
||||
// elementwise()
|
||||
auto input_tile = load_tile(input_window);
|
||||
// Run the first steps iteration of the Sinkhorn-Knopp algorithm
|
||||
// Exponentiate the input to make it strictly positive
|
||||
auto exp_func = [](ComputeDataType x) { return ck_tile::exp(x); };
|
||||
|
||||
// // Hot loop for Sinkhorn-Knopp iterations from 1 to max_iterations
|
||||
// // Use BlockReduce2D for row and column sums
|
||||
// for(int i = 0; i <= args.max_iterations; i++)
|
||||
// {
|
||||
// // 0. LOAD x
|
||||
// // 1. Compute row sums (REDUCE)
|
||||
// // 2. Divide values by row sums (SWEEP)
|
||||
// // 3. STORE the result of the division (in transposed format)
|
||||
// // 4. LOAD transposed x
|
||||
// // 5. Compute column sums (REDUCE)
|
||||
// // 6. Divide values by column sums (SWEEP)
|
||||
// // 7. STORE the result of the division (in transposed format)
|
||||
// }
|
||||
auto compute_tile = tile_elementwise_in(exp_func, input_tile);
|
||||
|
||||
// Create a transposed tile
|
||||
[[maybe_unused]] auto compute_tile_t = make_static_distributed_tensor<ComputeDataType>(
|
||||
Policy::template MakeTransposedInputBlockTileDistribution<Problem>());
|
||||
|
||||
// Hot loop for Sinkhorn-Knopp iterations from 1 to iterations
|
||||
// Use BlockReduce2D for row and column sums
|
||||
auto row_sum = Policy::template GetSum<Problem>();
|
||||
for(int i = 0; i <= args.iterations; i++)
|
||||
{
|
||||
// 1. Compute row sums (REDUCE)
|
||||
// FIXME: Uses overload that is hardcoded to reduce 2nd dimension, be explicit instead
|
||||
auto row_acc_tile = row_sum(compute_tile, out_padding_value, acc_op);
|
||||
|
||||
// 2. Divide values by row sums (SWEEP)
|
||||
constexpr auto c_spans = compute_tile.get_distributed_spans();
|
||||
sweep_tile_span(c_spans[number<0>{}], [&](const auto idx0) {
|
||||
sweep_tile_span(c_spans[number<1>{}], [&](const auto idx1) {
|
||||
constexpr auto c_idx = make_tuple(idx0, idx1);
|
||||
constexpr auto row_acc_idx = make_tuple(idx0);
|
||||
compute_tile(c_idx) = compute_tile(c_idx) / row_acc_tile(row_acc_idx);
|
||||
});
|
||||
});
|
||||
|
||||
// 3. STORE the result of the division (in transposed format)
|
||||
// store_tile(compute_tile_t, compute_tile);
|
||||
|
||||
// 4. LOAD transposed x
|
||||
// 5. Compute column sums (REDUCE)
|
||||
// auto col_acc_tile = row_sum(compute_tile, out_padding_value, acc_op);
|
||||
|
||||
// 6. Divide values by column sums (SWEEP)
|
||||
// constexpr auto ct_spans = compute_tile.get_distributed_spans();
|
||||
// sweep_tile_span(ct_spans[number<0>{}], [&](const auto idx0) {
|
||||
// sweep_tile_span(ct_spans[number<1>{}], [&](const auto idx1) {
|
||||
// constexpr auto c_idx = make_tuple(idx0, idx1);
|
||||
// constexpr auto col_acc_idx = make_tuple(idx1);
|
||||
// compute_tile(c_idx) = compute_tile(c_idx) / col_acc_tile(acc_idx);
|
||||
// });
|
||||
// });
|
||||
// 7. STORE the result of the division (in transposed format)
|
||||
}
|
||||
|
||||
// Copy the final values to the output
|
||||
store_tile(out_window, compute_tile);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -104,20 +142,6 @@ struct SinkhornKnoppKernelDummyNonStochastic
|
||||
return is_wave32() ? kBlockSize / 2 : kBlockSize;
|
||||
}
|
||||
|
||||
template <typename XDistributedTensor_>
|
||||
CK_TILE_DEVICE static auto MakeComputeBlockTile()
|
||||
{
|
||||
constexpr auto dstr =
|
||||
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
|
||||
XDistributedTensor_::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding(),
|
||||
sequence<0>{}));
|
||||
|
||||
auto tensor = make_static_distributed_tensor<typename Problem::OutDataType>(dstr);
|
||||
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// template <typename XDistributedTensor_>
|
||||
// CK_TILE_DEVICE static auto MakeYBlockTile()
|
||||
// {
|
||||
@@ -171,11 +195,11 @@ struct SinkhornKnoppKernelDummyNonStochastic
|
||||
in_desc.get_element_space_size(),
|
||||
type_convert<OutDataType>(custom_padding_value));
|
||||
|
||||
auto y_tensor =
|
||||
auto out_tensor =
|
||||
tensor_view<decltype(out_buffer_view), decltype(in_desc)>{out_buffer_view, in_desc};
|
||||
|
||||
[[maybe_unused]] auto y_window =
|
||||
make_tile_window(y_tensor,
|
||||
[[maybe_unused]] auto out_window =
|
||||
make_tile_window(out_tensor,
|
||||
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
@@ -183,24 +207,24 @@ struct SinkhornKnoppKernelDummyNonStochastic
|
||||
|
||||
[[maybe_unused]] auto input_tile = load_tile(input_window);
|
||||
|
||||
// auto y_tile = MakeYBlockTile<decltype(input_window)>();
|
||||
auto y_tile = make_static_distributed_tensor<OutDataType>(
|
||||
// auto out_tile = MakeYBlockTile<decltype(input_window)>();
|
||||
auto out_tile = make_static_distributed_tensor<OutDataType>(
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
// Set all output elements to the custom padding value.
|
||||
// // Simple solution to set the whole tile to a constant //
|
||||
// set_tile(y_tile, custom_padding_value);
|
||||
// store_tile(y_window, y_tile);
|
||||
// set_tile(out_tile, custom_padding_value);
|
||||
// store_tile(out_window, out_tile);
|
||||
|
||||
constexpr auto y_spans = y_tile.get_distributed_spans();
|
||||
constexpr auto y_spans = out_tile.get_distributed_spans();
|
||||
sweep_tile_span(y_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(y_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto distributed_indices = make_tuple(idx0, idx1);
|
||||
y_tile(distributed_indices) = type_convert<OutDataType>(custom_padding_value);
|
||||
out_tile(distributed_indices) = type_convert<OutDataType>(custom_padding_value);
|
||||
});
|
||||
});
|
||||
|
||||
store_tile(y_window, y_tile);
|
||||
store_tile(out_window, out_tile);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user