Debugging WIP

This commit is contained in:
Matti Eskelinen
2026-02-05 11:39:50 +00:00
parent ff428f3478
commit 0662a6c799
6 changed files with 147 additions and 88 deletions

View File

@@ -23,7 +23,8 @@ void sinkhorn_knopp_ref(const HostTensor<XDataType>& x_n_n,
{
for(index_t j = 0; j < input_n; ++j)
{
c_n_n(i, j) = exp(type_convert<ComputeDataType>(x_n_n(i, j)));
// c_n_n(i, j) = exp(type_convert<ComputeDataType>(x_n_n(i, j)));
c_n_n(i, j) = type_convert<ComputeDataType>(x_n_n(i, j));
}
}

View File

@@ -10,38 +10,48 @@ namespace ck_tile {
struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy
{
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeTransposedInputBlockTileDistribution()
CK_TILE_DEVICE static constexpr auto MakeInputBlockTileDistribution()
{
using S = typename Problem::BlockShape;
// using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::ThreadTile_N>,
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::ThreadTile_M>>,
tuple<sequence<2, 1>, sequence<2, 1>>,
tuple<sequence<1, 1>, sequence<2, 1>>,
// WarpPerBlock_M, WarpPerBlock_N, ThreadPerWarp_M, ThreadPerWarp_N
sequence<2, 1>,
sequence<3, 3>>{}); // Repeat_M, ThreadTile_M, Repeat_N, ThreadTile_N
tile_distribution_encoding<sequence<>,
tuple<sequence<4, 1>, sequence<1, 4>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
sequence<2, 1>,
sequence<1, 1>>{});
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeInputBlockTileDistribution()
CK_TILE_DEVICE static constexpr auto MakeTransposedInputBlockTileDistribution()
{
using S = typename Problem::BlockShape;
// using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>, // Repetitions (in input dimensions?)
tuple<
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::ThreadTile_M>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::ThreadTile_N>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<1, 1>, sequence<1, 2>>,
// WarpPerBlock_M, WarpPerBlock_N, ThreadPerWarp_M, ThreadPerWarp_N
sequence<2, 1>,
sequence<3, 3>>{});
tile_distribution_encoding<sequence<>,
tuple<sequence<1, 4>, sequence<4, 1>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
sequence<2, 1>,
sequence<1, 1>>{});
}
// template <typename Problem>
// CK_TILE_DEVICE static constexpr auto MakeInputBlockTileDistribution()
// {
// using S = typename Problem::BlockShape;
// return make_static_tile_distribution(
// tile_distribution_encoding<
// sequence<>,
// tuple<
// sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M,
// S::ThreadTile_M>, sequence<S::Repeat_N, S::WarpPerBlock_N,
// S::ThreadPerWarp_N, S::ThreadTile_N>>,
// tuple<sequence<1, 2>, sequence<1, 2>>,
// tuple<sequence<1, 1>, sequence<1, 2>>,
// // WarpPerBlock_M, WarpPerBlock_N, ThreadPerWarp_M, ThreadPerWarp_N
// sequence<2, 1>,
// sequence<3, 3>>{});
// }
};
} // namespace ck_tile

View File

@@ -42,18 +42,18 @@ struct SinkhornKnoppShape
static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});
static constexpr index_t RepeatInWarp =
Warp_M * Warp_N / ThreadTile_M / ThreadTile_N / ck_tile::get_warp_size();
static constexpr index_t RepeatInWarp_M =
(Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? RepeatInWarp : 1;
static constexpr index_t RepeatInWarp_N =
(Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? 1 : RepeatInWarp;
// static constexpr index_t RepeatInWarp =
// Warp_M * Warp_N / ThreadTile_M / ThreadTile_N / ck_tile::get_warp_size();
// static constexpr index_t RepeatInWarp_M =
// (Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? RepeatInWarp : 1;
// static constexpr index_t RepeatInWarp_N =
// (Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? 1 : RepeatInWarp;
static constexpr index_t ThreadPerWarp_M = Warp_M / ThreadTile_M / RepeatInWarp_M;
static constexpr index_t ThreadPerWarp_N = Warp_N / ThreadTile_N / RepeatInWarp_N;
static constexpr index_t ThreadPerWarp_M = Warp_M / ThreadTile_M;
static constexpr index_t ThreadPerWarp_N = Warp_N / ThreadTile_N;
static constexpr index_t Repeat_M = Block_M * RepeatInWarp_M / (WarpPerBlock_M * Warp_M);
static constexpr index_t Repeat_N = Block_N * RepeatInWarp_N / (WarpPerBlock_N * Warp_N);
// static constexpr index_t Repeat_M = Block_M * RepeatInWarp_M / (WarpPerBlock_M * Warp_M);
// static constexpr index_t Repeat_N = Block_N * RepeatInWarp_N / (WarpPerBlock_N * Warp_N);
static constexpr index_t BlockSize = ck_tile::get_warp_size();
};

View File

@@ -26,6 +26,9 @@ struct SinkhornKnoppKernelReduce
CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const
{
// constexpr int INSPECT_THREAD = 2;
// Creating tensor descriptors, views and windows for inputs and outputs
using S = Problem::BlockShape;
using InDataType = typename Problem::InDataType;
@@ -80,18 +83,20 @@ struct SinkhornKnoppKernelReduce
// Run the first steps iteration of the Sinkhorn-Knopp algorithm
// Exponentiate the input to make it strictly positive
// auto exp_func = [](ComputeDataType x) -> ComputeDataType { return ck_tile::exp(x); };
// auto exp_func = []([[maybe_unused]] ComputeDataType x) {
// auto exp_func = [](InDataType x) -> ComputeDataType {
// return ck_tile::exp(type_convert<ComputeDataType>(x));
// };
// auto exp_func = []([[maybe_unused]] InDataType x) {
// return static_cast<ComputeDataType>(1.0);
// };
auto exp_func = [](InDataType x) -> ComputeDataType {
return static_cast<ComputeDataType>(x);
return type_convert<ComputeDataType>(x);
};
auto compute_tile = tile_elementwise_in(exp_func, input_tile);
// Create a transposed tile
[[maybe_unused]] auto compute_tile_t = make_static_distributed_tensor<ComputeDataType>(
auto compute_tile_t = make_static_distributed_tensor<ComputeDataType>(
Policy::template MakeTransposedInputBlockTileDistribution<Problem>());
// Hot loop for Sinkhorn-Knopp iterations from 1 to iterations
@@ -106,16 +111,18 @@ struct SinkhornKnoppKernelReduce
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<br_problem>();
// TODO: Deduce/allow specifying a separate type for the accumulators?
// NOTE: MakeYBlockTile defaults to reducing 2nd dimension
auto acc_tile = block_reduce2d.template MakeYBlockTile<decltype(c_tile)>();
set_tile(acc_tile, acc_op.template GetIdentityValue<ComputeDataType>());
// auto acc_tile = block_reduce2d.template MakeYBlockTile<decltype(c_tile)>();
// set_tile(acc_tile, acc_op.template GetIdentityValue<ComputeDataType>());
auto acc_tile =
block_reduce2d(c_tile, acc_op.template GetIdentityValue<ComputeDataType>(), acc_op);
block_reduce2d(c_tile, acc_tile, acc_op);
block_reduce2d_sync(acc_tile, acc_op);
return acc_tile;
};
for(int i = 0; i < 1; i++)
for(int i = 0; i < args.iterations; i++)
{
// 1. Compute row sums (REDUCE)
// FIXME: Uses overload that is hardcoded to reduce 2nd dimension, be explicit instead
@@ -127,50 +134,51 @@ struct SinkhornKnoppKernelReduce
sweep_tile_span(c_spans[number<1>{}], [&](const auto idx1) {
constexpr auto c_idx = make_tuple(idx0, idx1);
constexpr auto row_acc_idx = make_tuple(idx0);
if(threadIdx.x == 0)
{
print(row_acc_idx);
print(":");
print(row_acc_tile(row_acc_idx));
print("\n");
}
// if(threadIdx.x == INSPECT_THREAD)
// {
// print(row_acc_idx);
// print(":");
// print(row_acc_tile(row_acc_idx));
// print("\n");
// }
compute_tile(c_idx) = compute_tile(c_idx) / row_acc_tile(row_acc_idx);
});
});
if(threadIdx.x == 0)
{
print("compute tile after rows summed and divided\n");
[&](auto tile) {
constexpr auto spans = tile.get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
constexpr auto idx = make_tuple(idx0, idx1);
print(idx);
print(tile(idx));
print("\n");
});
});
}(compute_tile);
}
// if(threadIdx.x == INSPECT_THREAD)
// {
// printf("compute tile after normalization\n");
// [&](auto tile) {
// constexpr auto spans = tile.get_distributed_spans();
// sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
// sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
// constexpr auto idx = make_tuple(idx0, idx1);
// print(idx);
// print(tile(idx));
// print("\n");
// });
// });
// printf("\n");
// }(compute_tile);
// }
transpose_tile2d(compute_tile_t, compute_tile);
if(threadIdx.x == 0)
{
print("compute tile transposed\n");
[&](auto tile) {
constexpr auto spans = tile.get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
constexpr auto idx = make_tuple(idx0, idx1);
print(idx);
print(tile(idx));
print("\n");
});
});
}(compute_tile_t);
}
// if(threadIdx.x == INSPECT_THREAD)
// {
// printf("compute tile transposed\n");
// [&](auto tile) {
// constexpr auto spans = tile.get_distributed_spans();
// sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
// sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
// constexpr auto idx = make_tuple(idx0, idx1);
// print(idx);
// print(tile(idx));
// print("\n");
// });
// });
// }(compute_tile_t);
// }
// Row sum is column sum for transposed c_tile
auto col_acc_tile = row_sum(compute_tile_t);
@@ -179,16 +187,55 @@ struct SinkhornKnoppKernelReduce
sweep_tile_span(c_t_spans[number<0>{}], [&](const auto idx0) {
sweep_tile_span(c_t_spans[number<1>{}], [&](const auto idx1) {
constexpr auto c_t_idx = make_tuple(idx0, idx1);
constexpr auto col_acc_idx = make_tuple(idx0);
constexpr auto col_acc_idx = make_tuple(idx1);
// if(threadIdx.x == INSPECT_THREAD)
// {
// print(col_acc_idx);
// print(":");
// print(col_acc_tile(col_acc_idx));
// print("\n");
// }
compute_tile_t(c_t_idx) = compute_tile_t(c_t_idx) / col_acc_tile(col_acc_idx);
});
});
// if(threadIdx.x == INSPECT_THREAD)
// {
// printf("transpose after normalizing\n");
// [&](auto tile) {
// constexpr auto spans = tile.get_distributed_spans();
// sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
// sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
// constexpr auto idx = make_tuple(idx0, idx1);
// print(idx);
// print(tile(idx));
// print("\n");
// });
// });
// }(compute_tile_t);
// }
transpose_tile2d(compute_tile, compute_tile_t);
// if(threadIdx.x == INSPECT_THREAD)
// {
// printf("thread %d, iter %d: original after transpose\n", INSPECT_THREAD, i);
// [&](auto tile) {
// constexpr auto spans = tile.get_distributed_spans();
// sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
// sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
// constexpr auto idx = make_tuple(idx0, idx1);
// print(idx);
// print(tile(idx));
// print("\n");
// });
// });
// }(compute_tile);
// }
}
// Copy the final values to the output
store_tile(out_window, compute_tile);
store_tile(out_window, cast_tile<OutDataType>(compute_tile));
}
};

View File

@@ -16,10 +16,10 @@
#include "test_sinkhorn_impl.hpp"
// Shape parameters for different test configurations
using Shape1_BlockWarps = ck_tile::sequence<4, 1>;
using Shape1_BlockTile = ck_tile::sequence<128, 128>;
using Shape1_WarpTile = ck_tile::sequence<32, 128>;
using Shape1_ThreadTile = ck_tile::sequence<1, 4>;
using Shape1_BlockWarps = ck_tile::sequence<1, 1>;
using Shape1_BlockTile = ck_tile::sequence<4, 4>;
using Shape1_WarpTile = ck_tile::sequence<4, 4>;
using Shape1_ThreadTile = ck_tile::sequence<4, 1>;
// Test configurations for different data types and input size
using TestConfig_F16 = std::tuple<float, // ck_tile::half_t, // XDataType
@@ -34,4 +34,4 @@ using TestTypes = ::testing::Types<TestConfig_F16>;
TYPED_TEST_SUITE(TestCkTileSinkHorn, TestTypes);
TYPED_TEST(TestCkTileSinkHorn, Test_4x4) { this->RunGenericTest({4, 4}, 20); }
TYPED_TEST(TestCkTileSinkHorn, Test_4x4) { this->RunGenericTest({4, 4}, 1); }

View File

@@ -31,13 +31,14 @@ class TestCkTileSinkHorn : public ::testing::Test
void RunGenericTest(const std::vector<ck_tile::index_t>& input_shape, const int max_iterations)
{
auto input_n = input_shape[0];
auto default_stride = {input_n, 1};
auto default_stride = {1, input_n};
ck_tile::HostTensor<XDataType> h_x(input_shape, default_stride);
ck_tile::HostTensor<YDataType> h_y(input_shape, default_stride);
// ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(h_x);
ck_tile::FillMonotonicSeq<XDataType>{}(h_x);
std::cout << h_x << std::endl;
auto buffer_size = h_x.get_element_space_size_in_bytes();
ck_tile::DeviceMem d_x_mem(buffer_size);