mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
WIP
This commit is contained in:
@@ -20,10 +20,10 @@ struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy
|
||||
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::ThreadTile_N>,
|
||||
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::ThreadTile_M>>,
|
||||
tuple<sequence<2, 1>, sequence<2, 1>>,
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>,
|
||||
tuple<sequence<1, 1>, sequence<2, 1>>,
|
||||
// WarpPerBlock_M, WarpPerBlock_N, ThreadPerWarp_M, ThreadPerWarp_N
|
||||
sequence<2, 2, 1, 1>,
|
||||
sequence<0, 3, 0, 3>>{}); // Repeat_M, ThreadTile_M, Repeat_N, ThreadTile_N
|
||||
sequence<2, 1>,
|
||||
sequence<3, 3>>{}); // Repeat_M, ThreadTile_M, Repeat_N, ThreadTile_N
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -38,8 +38,9 @@ struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy
|
||||
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::ThreadTile_N>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>, sequence<1, 2>>,
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
// WarpPerBlock_M, WarpPerBlock_N, ThreadPerWarp_M, ThreadPerWarp_N
|
||||
sequence<2, 1>,
|
||||
sequence<3, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -27,7 +27,6 @@ struct SinkhornKnoppKernelReduce
|
||||
CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const
|
||||
{
|
||||
// Create tensor descriptors, views and windows for inputs and outputs
|
||||
|
||||
using S = Problem::BlockShape;
|
||||
using InDataType = typename Problem::InDataType;
|
||||
using ComputeDataType = typename Problem::ComputeDataType;
|
||||
@@ -43,7 +42,7 @@ struct SinkhornKnoppKernelReduce
|
||||
const auto in_out_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(args.input_m, args.input_m),
|
||||
make_tuple(args.input_m, 1),
|
||||
number<16>{}, // TODO: Hardcoded
|
||||
number<4>{}, // TODO: Hardcoded
|
||||
// vectorization, //we should calculate it!
|
||||
number<1>{});
|
||||
|
||||
@@ -80,7 +79,10 @@ struct SinkhornKnoppKernelReduce
|
||||
auto input_tile = load_tile(input_window);
|
||||
// Run the first iteration of the Sinkhorn-Knopp algorithm
|
||||
// Exponentiate the input to make it strictly positive
|
||||
auto exp_func = [](ComputeDataType x) { return ck_tile::exp(x); };
|
||||
auto exp_func = [](ComputeDataType x) -> ComputeDataType { return ck_tile::exp(x); };
|
||||
// auto exp_func = []([[maybe_unused]] ComputeDataType x) {
|
||||
// return static_cast<ComputeDataType>(1.0);
|
||||
// };
|
||||
|
||||
auto compute_tile = tile_elementwise_in(exp_func, input_tile);
|
||||
|
||||
@@ -91,6 +93,7 @@ struct SinkhornKnoppKernelReduce
|
||||
// Hot loop for Sinkhorn-Knopp iterations from 1 to iterations
|
||||
// Use BlockReduce2D for row and column sums
|
||||
auto row_sum = Policy::template GetSum<Problem>();
|
||||
|
||||
for(int i = 0; i <= args.iterations; i++)
|
||||
{
|
||||
// 1. Compute row sums (REDUCE)
|
||||
@@ -107,8 +110,22 @@ struct SinkhornKnoppKernelReduce
|
||||
});
|
||||
});
|
||||
|
||||
transpose_tile2d(compute_tile_t, compute_tile);
|
||||
|
||||
// // Row sum is column sum for transposed c_tile
|
||||
auto col_acc_tile = row_sum(compute_tile_t, out_padding_value, acc_op);
|
||||
|
||||
constexpr auto c_t_spans = compute_tile_t.get_distributed_spans();
|
||||
sweep_tile_span(c_t_spans[number<0>{}], [&](const auto idx0) {
|
||||
sweep_tile_span(c_t_spans[number<1>{}], [&](const auto idx1) {
|
||||
constexpr auto c_t_idx = make_tuple(idx0, idx1);
|
||||
constexpr auto col_acc_idx = make_tuple(idx0);
|
||||
compute_tile_t(c_t_idx) = compute_tile_t(c_t_idx) / col_acc_tile(col_acc_idx);
|
||||
});
|
||||
});
|
||||
|
||||
transpose_tile2d(compute_tile, compute_tile_t);
|
||||
// 3. STORE the result of the division (in transposed format)
|
||||
// store_tile(compute_tile_t, compute_tile);
|
||||
|
||||
// 4. LOAD transposed x
|
||||
// 5. Compute column sums (REDUCE)
|
||||
|
||||
Reference in New Issue
Block a user