mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
WIP
This commit is contained in:
@@ -42,15 +42,6 @@ struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy
|
||||
sequence<2, 1>,
|
||||
sequence<3, 3>>{});
|
||||
}
|
||||
|
||||
/// Builds the block-level 2D reduction object for the given Problem.
///
/// Assembles a BlockReduce2dProblem from Problem's InDataType, ComputeDataType and
/// BlockShape, and returns a value-initialized BlockReduce2d for that problem type.
/// No reduction operator is chosen here; the function only selects the reducer type.
/// NOTE(review): despite the name, nothing in this body restricts the reducer to a sum;
/// presumably the operator is supplied when the returned object is invoked (confirm).
///
/// @tparam Problem  Must expose nested types InDataType, ComputeDataType and BlockShape.
/// @return A BlockReduce2d<br_problem> instance.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSum()
|
||||
{
|
||||
using br_problem = BlockReduce2dProblem<typename Problem::InDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
return BlockReduce2d<br_problem>{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -26,7 +26,7 @@ struct SinkhornKnoppKernelReduce
|
||||
|
||||
CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const
|
||||
{
|
||||
// // Creating tensor descriptors, views and windows for inputs and outputs
|
||||
// Creating tensor descriptors, views and windows for inputs and outputs
|
||||
using S = Problem::BlockShape;
|
||||
using InDataType = typename Problem::InDataType;
|
||||
using ComputeDataType = typename Problem::ComputeDataType;
|
||||
@@ -62,9 +62,9 @@ struct SinkhornKnoppKernelReduce
|
||||
Policy::template MakeInputBlockTileDistribution<Problem>());
|
||||
}();
|
||||
|
||||
const OutDataType out_padding_value = acc_op.template GetIdentityValue<OutDataType>();
|
||||
[[maybe_unused]] auto out_window = [&]() {
|
||||
auto out_buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
auto out_window = [&]() {
|
||||
const OutDataType out_padding_value = acc_op.template GetIdentityValue<OutDataType>();
|
||||
auto out_buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
p_out, in_out_desc.get_element_space_size(), out_padding_value);
|
||||
|
||||
auto out_tensor = tensor_view<decltype(out_buffer_view), decltype(in_out_desc)>{
|
||||
@@ -72,17 +72,21 @@ struct SinkhornKnoppKernelReduce
|
||||
|
||||
return make_tile_window(out_tensor,
|
||||
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
|
||||
{0, 0},
|
||||
{0, 0},
|
||||
Policy::template MakeInputBlockTileDistribution<Problem>());
|
||||
}();
|
||||
|
||||
auto input_tile = load_tile(input_window);
|
||||
|
||||
// Run the first steps iteration of the Sinkhorn-Knopp algorithm
|
||||
// Exponentiate the input to make it strictly positive
|
||||
auto exp_func = [](ComputeDataType x) -> ComputeDataType { return ck_tile::exp(x); };
|
||||
// auto exp_func = [](ComputeDataType x) -> ComputeDataType { return ck_tile::exp(x); };
|
||||
// auto exp_func = []([[maybe_unused]] ComputeDataType x) {
|
||||
// return static_cast<ComputeDataType>(1.0);
|
||||
// };
|
||||
auto exp_func = [](InDataType x) -> ComputeDataType {
|
||||
return static_cast<ComputeDataType>(x);
|
||||
};
|
||||
|
||||
auto compute_tile = tile_elementwise_in(exp_func, input_tile);
|
||||
|
||||
@@ -92,13 +96,30 @@ struct SinkhornKnoppKernelReduce
|
||||
|
||||
// Hot loop for Sinkhorn-Knopp iterations from 1 to iterations
|
||||
// Use BlockReduce2D for row and column sums
|
||||
auto row_sum = Policy::template GetSum<Problem>();
|
||||
auto row_sum = [&](const auto& c_tile) {
|
||||
// TODO: Handle case where the input doesn't fit in a single tile
|
||||
using br_problem = BlockReduce2dProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
|
||||
for(int i = 0; i <= args.iterations; i++)
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<br_problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<br_problem>();
|
||||
// TODO: Deduce/allow specifying a separate type for the accumulators?
|
||||
// NOTE: MakeYBlockTile defaults to reducing 2nd dimension
|
||||
auto acc_tile = block_reduce2d.template MakeYBlockTile<decltype(c_tile)>();
|
||||
set_tile(acc_tile, acc_op.template GetIdentityValue<ComputeDataType>());
|
||||
|
||||
block_reduce2d(c_tile, acc_tile, acc_op);
|
||||
block_reduce2d_sync(acc_tile, acc_op);
|
||||
|
||||
return acc_tile;
|
||||
};
|
||||
|
||||
for(int i = 0; i < 1; i++)
|
||||
{
|
||||
// 1. Compute row sums (REDUCE)
|
||||
// FIXME: Uses overload that is hardcoded to reduce 2nd dimension, be explicit instead
|
||||
auto row_acc_tile = row_sum(compute_tile, out_padding_value, acc_op);
|
||||
auto row_acc_tile = row_sum(compute_tile);
|
||||
|
||||
// 2. Divide values by row sums (SWEEP)
|
||||
constexpr auto c_spans = compute_tile.get_distributed_spans();
|
||||
@@ -106,14 +127,53 @@ struct SinkhornKnoppKernelReduce
|
||||
sweep_tile_span(c_spans[number<1>{}], [&](const auto idx1) {
|
||||
constexpr auto c_idx = make_tuple(idx0, idx1);
|
||||
constexpr auto row_acc_idx = make_tuple(idx0);
|
||||
compute_tile(c_idx) = compute_tile(c_idx) / row_acc_tile(row_acc_idx);
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
print(row_acc_idx);
|
||||
print(":");
|
||||
print(row_acc_tile(row_acc_idx));
|
||||
print("\n");
|
||||
}
|
||||
compute_tile(c_idx) = compute_tile(c_idx) / row_acc_tile(row_acc_idx);
|
||||
});
|
||||
});
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
print("compute tile after rows summed and divided\n");
|
||||
[&](auto tile) {
|
||||
constexpr auto spans = tile.get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
|
||||
sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
|
||||
constexpr auto idx = make_tuple(idx0, idx1);
|
||||
print(idx);
|
||||
print(tile(idx));
|
||||
print("\n");
|
||||
});
|
||||
});
|
||||
}(compute_tile);
|
||||
}
|
||||
|
||||
transpose_tile2d(compute_tile_t, compute_tile);
|
||||
|
||||
// // Row sum is column sum for transposed c_tile
|
||||
auto col_acc_tile = row_sum(compute_tile_t, out_padding_value, acc_op);
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
print("compute tile transposed\n");
|
||||
[&](auto tile) {
|
||||
constexpr auto spans = tile.get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
|
||||
sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
|
||||
constexpr auto idx = make_tuple(idx0, idx1);
|
||||
print(idx);
|
||||
print(tile(idx));
|
||||
print("\n");
|
||||
});
|
||||
});
|
||||
}(compute_tile_t);
|
||||
}
|
||||
|
||||
// Row sum is column sum for transposed c_tile
|
||||
auto col_acc_tile = row_sum(compute_tile_t);
|
||||
|
||||
constexpr auto c_t_spans = compute_tile_t.get_distributed_spans();
|
||||
sweep_tile_span(c_t_spans[number<0>{}], [&](const auto idx0) {
|
||||
@@ -125,22 +185,6 @@ struct SinkhornKnoppKernelReduce
|
||||
});
|
||||
|
||||
transpose_tile2d(compute_tile, compute_tile_t);
|
||||
// 3. STORE the result of the division (in transposed format)
|
||||
|
||||
// 4. LOAD transposed x
|
||||
// 5. Compute column sums (REDUCE)
|
||||
// auto col_acc_tile = row_sum(compute_tile, out_padding_value, acc_op);
|
||||
|
||||
// 6. Divide values by column sums (SWEEP)
|
||||
// constexpr auto ct_spans = compute_tile.get_distributed_spans();
|
||||
// sweep_tile_span(ct_spans[number<0>{}], [&](const auto idx0) {
|
||||
// sweep_tile_span(ct_spans[number<1>{}], [&](const auto idx1) {
|
||||
// constexpr auto c_idx = make_tuple(idx0, idx1);
|
||||
// constexpr auto col_acc_idx = make_tuple(idx1);
|
||||
// compute_tile(c_idx) = compute_tile(c_idx) / col_acc_tile(acc_idx);
|
||||
// });
|
||||
// });
|
||||
// 7. STORE the result of the division (in transposed format)
|
||||
}
|
||||
|
||||
// Copy the final values to the output
|
||||
|
||||
Reference in New Issue
Block a user