mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
WIP (algorithm correct, but inaccurate)
This commit is contained in:
@@ -23,8 +23,8 @@ void sinkhorn_knopp_ref(const HostTensor<XDataType>& x_n_n,
|
||||
{
|
||||
for(index_t j = 0; j < input_n; ++j)
|
||||
{
|
||||
// c_n_n(i, j) = exp(type_convert<ComputeDataType>(x_n_n(i, j)));
|
||||
c_n_n(i, j) = type_convert<ComputeDataType>(x_n_n(i, j));
|
||||
c_n_n(i, j) = exp(type_convert<ComputeDataType>(x_n_n(i, j)));
|
||||
// c_n_n(i, j) = type_convert<ComputeDataType>(x_n_n(i, j));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,63 @@ struct SinkhornKnoppArgs
|
||||
int iterations;
|
||||
};
|
||||
|
||||
template <typename T, typename F>
|
||||
CK_TILE_DEVICE auto iterate_tile_2d(const T& my_tile, F on_item)
|
||||
{
|
||||
[on_item](auto tile) {
|
||||
constexpr auto spans = tile.get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
|
||||
sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
|
||||
constexpr auto idx = make_tuple(idx0, idx1);
|
||||
const auto item = tile(idx);
|
||||
on_item(idx, item);
|
||||
});
|
||||
});
|
||||
}(my_tile);
|
||||
}
|
||||
|
||||
template <typename T, typename F>
|
||||
CK_TILE_DEVICE auto iterate_tile_1d(const T& my_tile, F on_item)
|
||||
{
|
||||
[on_item](auto tile) {
|
||||
constexpr auto spans = tile.get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
|
||||
constexpr auto idx = make_tuple(idx0);
|
||||
const auto item = tile(idx);
|
||||
on_item(idx, item);
|
||||
});
|
||||
}(my_tile);
|
||||
}
|
||||
|
||||
template <typename T, typename F>
|
||||
CK_TILE_DEVICE auto iterate_tile(const T& my_tile, F on_item)
|
||||
{
|
||||
constexpr index_t NDim = T::get_tile_distribution().get_num_of_dimension_y();
|
||||
if constexpr(NDim == 1)
|
||||
{
|
||||
iterate_tile_1d(my_tile, on_item);
|
||||
}
|
||||
else if constexpr(NDim == 2)
|
||||
{
|
||||
iterate_tile_2d(my_tile, on_item);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr int INSPECT_THREAD = 0;
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE auto print_tile([[maybe_unused]] const T& my_tile)
|
||||
{
|
||||
if(threadIdx.x == INSPECT_THREAD)
|
||||
{
|
||||
iterate_tile(my_tile, []([[maybe_unused]] auto idx, auto item) {
|
||||
print(idx);
|
||||
printf(":");
|
||||
print(item);
|
||||
printf("\n");
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem, typename Policy>
|
||||
struct SinkhornKnoppKernelReduce
|
||||
{
|
||||
@@ -27,8 +84,6 @@ struct SinkhornKnoppKernelReduce
|
||||
CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const
|
||||
{
|
||||
|
||||
// constexpr int INSPECT_THREAD = 2;
|
||||
|
||||
// Creating tensor descriptors, views and windows for inputs and outputs
|
||||
using S = Problem::BlockShape;
|
||||
using InDataType = typename Problem::InDataType;
|
||||
@@ -83,14 +138,8 @@ struct SinkhornKnoppKernelReduce
|
||||
|
||||
// Run the first steps iteration of the Sinkhorn-Knopp algorithm
|
||||
// Exponentiate the input to make it strictly positive
|
||||
// auto exp_func = [](InDataType x) -> ComputeDataType {
|
||||
// return ck_tile::exp(type_convert<ComputeDataType>(x));
|
||||
// };
|
||||
// auto exp_func = []([[maybe_unused]] InDataType x) {
|
||||
// return static_cast<ComputeDataType>(1.0);
|
||||
// };
|
||||
auto exp_func = [](InDataType x) -> ComputeDataType {
|
||||
return type_convert<ComputeDataType>(x);
|
||||
return ck_tile::exp(type_convert<ComputeDataType>(x));
|
||||
};
|
||||
|
||||
auto compute_tile = tile_elementwise_in(exp_func, input_tile);
|
||||
@@ -98,6 +147,7 @@ struct SinkhornKnoppKernelReduce
|
||||
// Create a transposed tile
|
||||
auto compute_tile_t = make_static_distributed_tensor<ComputeDataType>(
|
||||
Policy::template MakeTransposedInputBlockTileDistribution<Problem>());
|
||||
set_tile(compute_tile_t, acc_op.template GetIdentityValue<ComputeDataType>());
|
||||
|
||||
// Hot loop for Sinkhorn-Knopp iterations from 1 to iterations
|
||||
// Use BlockReduce2D for row and column sums
|
||||
@@ -109,11 +159,9 @@ struct SinkhornKnoppKernelReduce
|
||||
|
||||
auto block_reduce2d = Policy::template GetBlockReduce2d<br_problem>();
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<br_problem>();
|
||||
// TODO: Deduce/allow specifying a separate type for the accumulators?
|
||||
// NOTE: MakeYBlockTile defaults to reducing 2nd dimension
|
||||
// auto acc_tile = block_reduce2d.template MakeYBlockTile<decltype(c_tile)>();
|
||||
// set_tile(acc_tile, acc_op.template GetIdentityValue<ComputeDataType>());
|
||||
|
||||
// TODO: Deduce/allow specifying a separate type for the accumulators?
|
||||
// NOTE: block_reduce2d defaults to reducing 2nd dimension
|
||||
auto acc_tile =
|
||||
block_reduce2d(c_tile, acc_op.template GetIdentityValue<ComputeDataType>(), acc_op);
|
||||
|
||||
@@ -124,114 +172,34 @@ struct SinkhornKnoppKernelReduce
|
||||
|
||||
for(int i = 0; i < args.iterations; i++)
|
||||
{
|
||||
// 1. Compute row sums (REDUCE)
|
||||
// FIXME: Uses overload that is hardcoded to reduce 2nd dimension, be explicit instead
|
||||
// 1. Compute row sums
|
||||
auto row_acc_tile = row_sum(compute_tile);
|
||||
|
||||
// 2. Divide values by row sums (SWEEP)
|
||||
// 2. Divide values by row sums
|
||||
constexpr auto c_spans = compute_tile.get_distributed_spans();
|
||||
sweep_tile_span(c_spans[number<0>{}], [&](const auto idx0) {
|
||||
sweep_tile_span(c_spans[number<1>{}], [&](const auto idx1) {
|
||||
constexpr auto c_idx = make_tuple(idx0, idx1);
|
||||
constexpr auto row_acc_idx = make_tuple(idx0);
|
||||
// if(threadIdx.x == INSPECT_THREAD)
|
||||
// {
|
||||
// print(row_acc_idx);
|
||||
// print(":");
|
||||
// print(row_acc_tile(row_acc_idx));
|
||||
// print("\n");
|
||||
// }
|
||||
compute_tile(c_idx) = compute_tile(c_idx) / row_acc_tile(row_acc_idx);
|
||||
compute_tile(c_idx) = compute_tile(c_idx) / row_acc_tile(row_acc_idx);
|
||||
});
|
||||
});
|
||||
|
||||
// if(threadIdx.x == INSPECT_THREAD)
|
||||
// {
|
||||
// printf("compute tile after normalization\n");
|
||||
// [&](auto tile) {
|
||||
// constexpr auto spans = tile.get_distributed_spans();
|
||||
// sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
|
||||
// sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
|
||||
// constexpr auto idx = make_tuple(idx0, idx1);
|
||||
// print(idx);
|
||||
// print(tile(idx));
|
||||
// print("\n");
|
||||
// });
|
||||
// });
|
||||
// printf("\n");
|
||||
// }(compute_tile);
|
||||
// }
|
||||
|
||||
transpose_tile2d(compute_tile_t, compute_tile);
|
||||
|
||||
// if(threadIdx.x == INSPECT_THREAD)
|
||||
// {
|
||||
// printf("compute tile transposed\n");
|
||||
// [&](auto tile) {
|
||||
// constexpr auto spans = tile.get_distributed_spans();
|
||||
// sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
|
||||
// sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
|
||||
// constexpr auto idx = make_tuple(idx0, idx1);
|
||||
// print(idx);
|
||||
// print(tile(idx));
|
||||
// print("\n");
|
||||
// });
|
||||
// });
|
||||
// }(compute_tile_t);
|
||||
// }
|
||||
|
||||
// Row sum is column sum for transposed c_tile
|
||||
// Row sum is a column sum for the transposed tile
|
||||
auto col_acc_tile = row_sum(compute_tile_t);
|
||||
|
||||
constexpr auto c_t_spans = compute_tile_t.get_distributed_spans();
|
||||
sweep_tile_span(c_t_spans[number<0>{}], [&](const auto idx0) {
|
||||
sweep_tile_span(c_t_spans[number<1>{}], [&](const auto idx1) {
|
||||
constexpr auto c_t_idx = make_tuple(idx0, idx1);
|
||||
constexpr auto col_acc_idx = make_tuple(idx1);
|
||||
// if(threadIdx.x == INSPECT_THREAD)
|
||||
// {
|
||||
// print(col_acc_idx);
|
||||
// print(":");
|
||||
// print(col_acc_tile(col_acc_idx));
|
||||
// print("\n");
|
||||
// }
|
||||
constexpr auto col_acc_idx = make_tuple(idx0);
|
||||
compute_tile_t(c_t_idx) = compute_tile_t(c_t_idx) / col_acc_tile(col_acc_idx);
|
||||
});
|
||||
});
|
||||
|
||||
// if(threadIdx.x == INSPECT_THREAD)
|
||||
// {
|
||||
// printf("transpose after normalizing\n");
|
||||
// [&](auto tile) {
|
||||
// constexpr auto spans = tile.get_distributed_spans();
|
||||
// sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
|
||||
// sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
|
||||
// constexpr auto idx = make_tuple(idx0, idx1);
|
||||
// print(idx);
|
||||
// print(tile(idx));
|
||||
// print("\n");
|
||||
// });
|
||||
// });
|
||||
// }(compute_tile_t);
|
||||
// }
|
||||
|
||||
transpose_tile2d(compute_tile, compute_tile_t);
|
||||
|
||||
// if(threadIdx.x == INSPECT_THREAD)
|
||||
// {
|
||||
// printf("thread %d, iter %d: original after transpose\n", INSPECT_THREAD, i);
|
||||
// [&](auto tile) {
|
||||
// constexpr auto spans = tile.get_distributed_spans();
|
||||
// sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
|
||||
// sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
|
||||
// constexpr auto idx = make_tuple(idx0, idx1);
|
||||
// print(idx);
|
||||
// print(tile(idx));
|
||||
// print("\n");
|
||||
// });
|
||||
// });
|
||||
// }(compute_tile);
|
||||
// }
|
||||
}
|
||||
|
||||
// Copy the final values to the output
|
||||
|
||||
Reference in New Issue
Block a user