WIP (the algorithm is correct, but the results are still inaccurate)

This commit is contained in:
Matti Eskelinen
2026-02-06 08:38:12 +00:00
parent 0662a6c799
commit 3b31e42359
4 changed files with 70 additions and 104 deletions

View File

@@ -23,8 +23,8 @@ void sinkhorn_knopp_ref(const HostTensor<XDataType>& x_n_n,
{
for(index_t j = 0; j < input_n; ++j)
{
// c_n_n(i, j) = exp(type_convert<ComputeDataType>(x_n_n(i, j)));
c_n_n(i, j) = type_convert<ComputeDataType>(x_n_n(i, j));
c_n_n(i, j) = exp(type_convert<ComputeDataType>(x_n_n(i, j)));
// c_n_n(i, j) = type_convert<ComputeDataType>(x_n_n(i, j));
}
}

View File

@@ -14,6 +14,63 @@ struct SinkhornKnoppArgs
int iterations;
};
/// Visits every element of a tile that exposes two distributed spans.
///
/// The traversal runs on a by-value copy of `my_tile`: span 0 is swept in the
/// outer loop and span 1 in the inner loop. For each index pair the callback
/// is invoked as `on_item(idx, item)`, where `idx` is the tuple of the two
/// span indices and `item` is a const copy of the element at that index.
///
/// @param my_tile Tile to traverse.
/// @param on_item Callable accepting (index tuple, element). It is copied into
///                the traversal closure and invoked through that const copy.
template <typename T, typename F>
CK_TILE_DEVICE auto iterate_tile_2d(const T& my_tile, F on_item)
{
    // The tile is received by value inside the closure; presumably this is what
    // allows its spans to be stored in a constexpr variable -- TODO confirm.
    const auto visit_all = [on_item](auto tile_copy) {
        constexpr auto tile_spans = tile_copy.get_distributed_spans();
        sweep_tile_span(tile_spans[number<0>{}], [&](const auto outer) {
            sweep_tile_span(tile_spans[number<1>{}], [&](const auto inner) {
                constexpr auto position = make_tuple(outer, inner);
                const auto element      = tile_copy(position);
                on_item(position, element);
            });
        });
    };
    visit_all(my_tile);
}
/// Visits every element of a tile that exposes a single distributed span.
///
/// The traversal runs on a by-value copy of `my_tile`. For each index of
/// span 0 the callback is invoked as `on_item(idx, item)`, where `idx` is a
/// one-element tuple holding the span index and `item` is a const copy of the
/// element at that index.
///
/// @param my_tile Tile to traverse.
/// @param on_item Callable accepting (index tuple, element). It is copied into
///                the traversal closure and invoked through that const copy.
template <typename T, typename F>
CK_TILE_DEVICE auto iterate_tile_1d(const T& my_tile, F on_item)
{
    // The tile is received by value inside the closure; presumably this is what
    // allows its spans to be stored in a constexpr variable -- TODO confirm.
    const auto visit_all = [on_item](auto tile_copy) {
        constexpr auto tile_spans = tile_copy.get_distributed_spans();
        sweep_tile_span(tile_spans[number<0>{}], [&](const auto only) {
            constexpr auto position = make_tuple(only);
            const auto element      = tile_copy(position);
            on_item(position, element);
        });
    };
    visit_all(my_tile);
}
/// Visits every element of a 1-D or 2-D tile, forwarding to `iterate_tile_1d`
/// or `iterate_tile_2d` according to the dimension count reported by the
/// tile's distribution.
///
/// A tile with any other dimension count is rejected at compile time.
/// Previously such a tile matched neither branch and the call silently did
/// nothing, which is easy to misread as "the tile is empty" when this helper
/// is used for debugging.
///
/// @param my_tile Tile to traverse.
/// @param on_item Callable accepting (index tuple, element); passed on to the
///                selected traversal helper.
template <typename T, typename F>
CK_TILE_DEVICE auto iterate_tile(const T& my_tile, F on_item)
{
    // NOTE(review): the dispatch keys on the number of Y dimensions of the tile
    // distribution, while the helpers index the distributed spans. Confirm that
    // these counts agree for the distributions used with this helper.
    constexpr index_t NDim = T::get_tile_distribution().get_num_of_dimension_y();

    static_assert(NDim == 1 || NDim == 2,
                  "iterate_tile supports only tiles whose distribution reports 1 or 2 "
                  "Y dimensions");

    if constexpr(NDim == 1)
    {
        iterate_tile_1d(my_tile, on_item);
    }
    else if constexpr(NDim == 2)
    {
        iterate_tile_2d(my_tile, on_item);
    }
}
// Value of threadIdx.x for the only thread that print_tile lets produce output.
constexpr int INSPECT_THREAD = 0;

/// Debug helper that prints the contents of a tile from a single thread.
///
/// Threads whose `threadIdx.x` differs from `INSPECT_THREAD` return without
/// printing. The inspecting thread walks the tile with `iterate_tile` and
/// writes one line per element in the form `<index>:<value>`.
///
/// @param my_tile Tile whose elements are printed.
template <typename T>
CK_TILE_DEVICE auto print_tile([[maybe_unused]] const T& my_tile)
{
    if(threadIdx.x != INSPECT_THREAD)
    {
        return;
    }

    // Emits a single "<index>:<value>" line for one tile element.
    const auto emit_entry = []([[maybe_unused]] auto idx, auto item) {
        print(idx);
        printf(":");
        print(item);
        printf("\n");
    };
    iterate_tile(my_tile, emit_entry);
}
template <typename Problem, typename Policy>
struct SinkhornKnoppKernelReduce
{
@@ -27,8 +84,6 @@ struct SinkhornKnoppKernelReduce
CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const
{
// constexpr int INSPECT_THREAD = 2;
// Creating tensor descriptors, views and windows for inputs and outputs
using S = Problem::BlockShape;
using InDataType = typename Problem::InDataType;
@@ -83,14 +138,8 @@ struct SinkhornKnoppKernelReduce
// Run the first steps iteration of the Sinkhorn-Knopp algorithm
// Exponentiate the input to make it strictly positive
// auto exp_func = [](InDataType x) -> ComputeDataType {
// return ck_tile::exp(type_convert<ComputeDataType>(x));
// };
// auto exp_func = []([[maybe_unused]] InDataType x) {
// return static_cast<ComputeDataType>(1.0);
// };
auto exp_func = [](InDataType x) -> ComputeDataType {
return type_convert<ComputeDataType>(x);
return ck_tile::exp(type_convert<ComputeDataType>(x));
};
auto compute_tile = tile_elementwise_in(exp_func, input_tile);
@@ -98,6 +147,7 @@ struct SinkhornKnoppKernelReduce
// Create a transposed tile
auto compute_tile_t = make_static_distributed_tensor<ComputeDataType>(
Policy::template MakeTransposedInputBlockTileDistribution<Problem>());
set_tile(compute_tile_t, acc_op.template GetIdentityValue<ComputeDataType>());
// Hot loop for Sinkhorn-Knopp iterations from 1 to iterations
// Use BlockReduce2D for row and column sums
@@ -109,11 +159,9 @@ struct SinkhornKnoppKernelReduce
auto block_reduce2d = Policy::template GetBlockReduce2d<br_problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<br_problem>();
// TODO: Deduce/allow specifying a separate type for the accumulators?
// NOTE: MakeYBlockTile defaults to reducing 2nd dimension
// auto acc_tile = block_reduce2d.template MakeYBlockTile<decltype(c_tile)>();
// set_tile(acc_tile, acc_op.template GetIdentityValue<ComputeDataType>());
// TODO: Deduce/allow specifying a separate type for the accumulators?
// NOTE: block_reduce2d defaults to reducing 2nd dimension
auto acc_tile =
block_reduce2d(c_tile, acc_op.template GetIdentityValue<ComputeDataType>(), acc_op);
@@ -124,114 +172,34 @@ struct SinkhornKnoppKernelReduce
for(int i = 0; i < args.iterations; i++)
{
// 1. Compute row sums (REDUCE)
// FIXME: Uses overload that is hardcoded to reduce 2nd dimension, be explicit instead
// 1. Compute row sums
auto row_acc_tile = row_sum(compute_tile);
// 2. Divide values by row sums (SWEEP)
// 2. Divide values by row sums
constexpr auto c_spans = compute_tile.get_distributed_spans();
sweep_tile_span(c_spans[number<0>{}], [&](const auto idx0) {
sweep_tile_span(c_spans[number<1>{}], [&](const auto idx1) {
constexpr auto c_idx = make_tuple(idx0, idx1);
constexpr auto row_acc_idx = make_tuple(idx0);
// if(threadIdx.x == INSPECT_THREAD)
// {
// print(row_acc_idx);
// print(":");
// print(row_acc_tile(row_acc_idx));
// print("\n");
// }
compute_tile(c_idx) = compute_tile(c_idx) / row_acc_tile(row_acc_idx);
compute_tile(c_idx) = compute_tile(c_idx) / row_acc_tile(row_acc_idx);
});
});
// if(threadIdx.x == INSPECT_THREAD)
// {
// printf("compute tile after normalization\n");
// [&](auto tile) {
// constexpr auto spans = tile.get_distributed_spans();
// sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
// sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
// constexpr auto idx = make_tuple(idx0, idx1);
// print(idx);
// print(tile(idx));
// print("\n");
// });
// });
// printf("\n");
// }(compute_tile);
// }
transpose_tile2d(compute_tile_t, compute_tile);
// if(threadIdx.x == INSPECT_THREAD)
// {
// printf("compute tile transposed\n");
// [&](auto tile) {
// constexpr auto spans = tile.get_distributed_spans();
// sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
// sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
// constexpr auto idx = make_tuple(idx0, idx1);
// print(idx);
// print(tile(idx));
// print("\n");
// });
// });
// }(compute_tile_t);
// }
// Row sum is column sum for transposed c_tile
// Row sum is a column sum for the transposed tile
auto col_acc_tile = row_sum(compute_tile_t);
constexpr auto c_t_spans = compute_tile_t.get_distributed_spans();
sweep_tile_span(c_t_spans[number<0>{}], [&](const auto idx0) {
sweep_tile_span(c_t_spans[number<1>{}], [&](const auto idx1) {
constexpr auto c_t_idx = make_tuple(idx0, idx1);
constexpr auto col_acc_idx = make_tuple(idx1);
// if(threadIdx.x == INSPECT_THREAD)
// {
// print(col_acc_idx);
// print(":");
// print(col_acc_tile(col_acc_idx));
// print("\n");
// }
constexpr auto col_acc_idx = make_tuple(idx0);
compute_tile_t(c_t_idx) = compute_tile_t(c_t_idx) / col_acc_tile(col_acc_idx);
});
});
// if(threadIdx.x == INSPECT_THREAD)
// {
// printf("transpose after normalizing\n");
// [&](auto tile) {
// constexpr auto spans = tile.get_distributed_spans();
// sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
// sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
// constexpr auto idx = make_tuple(idx0, idx1);
// print(idx);
// print(tile(idx));
// print("\n");
// });
// });
// }(compute_tile_t);
// }
transpose_tile2d(compute_tile, compute_tile_t);
// if(threadIdx.x == INSPECT_THREAD)
// {
// printf("thread %d, iter %d: original after transpose\n", INSPECT_THREAD, i);
// [&](auto tile) {
// constexpr auto spans = tile.get_distributed_spans();
// sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
// sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
// constexpr auto idx = make_tuple(idx0, idx1);
// print(idx);
// print(tile(idx));
// print("\n");
// });
// });
// }(compute_tile);
// }
}
// Copy the final values to the output