mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Debugging WIP
This commit is contained in:
@@ -23,7 +23,8 @@ void sinkhorn_knopp_ref(const HostTensor<XDataType>& x_n_n,
|
||||
{
|
||||
for(index_t j = 0; j < input_n; ++j)
|
||||
{
|
||||
c_n_n(i, j) = exp(type_convert<ComputeDataType>(x_n_n(i, j)));
|
||||
// c_n_n(i, j) = exp(type_convert<ComputeDataType>(x_n_n(i, j)));
|
||||
c_n_n(i, j) = type_convert<ComputeDataType>(x_n_n(i, j));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,38 +10,48 @@ namespace ck_tile {
|
||||
struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeTransposedInputBlockTileDistribution()
|
||||
CK_TILE_DEVICE static constexpr auto MakeInputBlockTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
// using S = typename Problem::BlockShape;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<
|
||||
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::ThreadTile_N>,
|
||||
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::ThreadTile_M>>,
|
||||
tuple<sequence<2, 1>, sequence<2, 1>>,
|
||||
tuple<sequence<1, 1>, sequence<2, 1>>,
|
||||
// WarpPerBlock_M, WarpPerBlock_N, ThreadPerWarp_M, ThreadPerWarp_N
|
||||
sequence<2, 1>,
|
||||
sequence<3, 3>>{}); // Repeat_M, ThreadTile_M, Repeat_N, ThreadTile_N
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<4, 1>, sequence<1, 4>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeInputBlockTileDistribution()
|
||||
CK_TILE_DEVICE static constexpr auto MakeTransposedInputBlockTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
// using S = typename Problem::BlockShape;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>, // Repetitions (in input dimensions?)
|
||||
tuple<
|
||||
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::ThreadTile_M>,
|
||||
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::ThreadTile_N>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>, sequence<1, 2>>,
|
||||
// WarpPerBlock_M, WarpPerBlock_N, ThreadPerWarp_M, ThreadPerWarp_N
|
||||
sequence<2, 1>,
|
||||
sequence<3, 3>>{});
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<1, 4>, sequence<4, 1>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
|
||||
// template <typename Problem>
|
||||
// CK_TILE_DEVICE static constexpr auto MakeInputBlockTileDistribution()
|
||||
// {
|
||||
// using S = typename Problem::BlockShape;
|
||||
// return make_static_tile_distribution(
|
||||
// tile_distribution_encoding<
|
||||
// sequence<>,
|
||||
// tuple<
|
||||
// sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M,
|
||||
// S::ThreadTile_M>, sequence<S::Repeat_N, S::WarpPerBlock_N,
|
||||
// S::ThreadPerWarp_N, S::ThreadTile_N>>,
|
||||
// tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
// tuple<sequence<1, 1>, sequence<1, 2>>,
|
||||
// // WarpPerBlock_M, WarpPerBlock_N, ThreadPerWarp_M, ThreadPerWarp_N
|
||||
// sequence<2, 1>,
|
||||
// sequence<3, 3>>{});
|
||||
// }
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -42,18 +42,18 @@ struct SinkhornKnoppShape
|
||||
static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
|
||||
static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});
|
||||
|
||||
static constexpr index_t RepeatInWarp =
|
||||
Warp_M * Warp_N / ThreadTile_M / ThreadTile_N / ck_tile::get_warp_size();
|
||||
static constexpr index_t RepeatInWarp_M =
|
||||
(Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? RepeatInWarp : 1;
|
||||
static constexpr index_t RepeatInWarp_N =
|
||||
(Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? 1 : RepeatInWarp;
|
||||
// static constexpr index_t RepeatInWarp =
|
||||
// Warp_M * Warp_N / ThreadTile_M / ThreadTile_N / ck_tile::get_warp_size();
|
||||
// static constexpr index_t RepeatInWarp_M =
|
||||
// (Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? RepeatInWarp : 1;
|
||||
// static constexpr index_t RepeatInWarp_N =
|
||||
// (Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? 1 : RepeatInWarp;
|
||||
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / ThreadTile_M / RepeatInWarp_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / ThreadTile_N / RepeatInWarp_N;
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / ThreadTile_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / ThreadTile_N;
|
||||
|
||||
static constexpr index_t Repeat_M = Block_M * RepeatInWarp_M / (WarpPerBlock_M * Warp_M);
|
||||
static constexpr index_t Repeat_N = Block_N * RepeatInWarp_N / (WarpPerBlock_N * Warp_N);
|
||||
// static constexpr index_t Repeat_M = Block_M * RepeatInWarp_M / (WarpPerBlock_M * Warp_M);
|
||||
// static constexpr index_t Repeat_N = Block_N * RepeatInWarp_N / (WarpPerBlock_N * Warp_N);
|
||||
|
||||
static constexpr index_t BlockSize = ck_tile::get_warp_size();
|
||||
};
|
||||
|
||||
@@ -26,6 +26,9 @@ struct SinkhornKnoppKernelReduce
|
||||
|
||||
CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const
|
||||
{
|
||||
|
||||
// constexpr int INSPECT_THREAD = 2;
|
||||
|
||||
// Creating tensor descriptors, views and windows for inputs and outputs
|
||||
using S = Problem::BlockShape;
|
||||
using InDataType = typename Problem::InDataType;
|
||||
@@ -80,18 +83,20 @@ struct SinkhornKnoppKernelReduce
|
||||
|
||||
// Run the first steps iteration of the Sinkhorn-Knopp algorithm
|
||||
// Exponentiate the input to make it strictly positive
|
||||
// auto exp_func = [](ComputeDataType x) -> ComputeDataType { return ck_tile::exp(x); };
|
||||
// auto exp_func = []([[maybe_unused]] ComputeDataType x) {
|
||||
// auto exp_func = [](InDataType x) -> ComputeDataType {
|
||||
// return ck_tile::exp(type_convert<ComputeDataType>(x));
|
||||
// };
|
||||
// auto exp_func = []([[maybe_unused]] InDataType x) {
|
||||
// return static_cast<ComputeDataType>(1.0);
|
||||
// };
|
||||
auto exp_func = [](InDataType x) -> ComputeDataType {
|
||||
return static_cast<ComputeDataType>(x);
|
||||
return type_convert<ComputeDataType>(x);
|
||||
};
|
||||
|
||||
auto compute_tile = tile_elementwise_in(exp_func, input_tile);
|
||||
|
||||
// Create a transposed tile
|
||||
[[maybe_unused]] auto compute_tile_t = make_static_distributed_tensor<ComputeDataType>(
|
||||
auto compute_tile_t = make_static_distributed_tensor<ComputeDataType>(
|
||||
Policy::template MakeTransposedInputBlockTileDistribution<Problem>());
|
||||
|
||||
// Hot loop for Sinkhorn-Knopp iterations from 1 to iterations
|
||||
@@ -106,16 +111,18 @@ struct SinkhornKnoppKernelReduce
|
||||
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<br_problem>();
|
||||
// TODO: Deduce/allow specifying a separate type for the accumulators?
|
||||
// NOTE: MakeYBlockTile defaults to reducing 2nd dimension
|
||||
auto acc_tile = block_reduce2d.template MakeYBlockTile<decltype(c_tile)>();
|
||||
set_tile(acc_tile, acc_op.template GetIdentityValue<ComputeDataType>());
|
||||
// auto acc_tile = block_reduce2d.template MakeYBlockTile<decltype(c_tile)>();
|
||||
// set_tile(acc_tile, acc_op.template GetIdentityValue<ComputeDataType>());
|
||||
|
||||
auto acc_tile =
|
||||
block_reduce2d(c_tile, acc_op.template GetIdentityValue<ComputeDataType>(), acc_op);
|
||||
|
||||
block_reduce2d(c_tile, acc_tile, acc_op);
|
||||
block_reduce2d_sync(acc_tile, acc_op);
|
||||
|
||||
return acc_tile;
|
||||
};
|
||||
|
||||
for(int i = 0; i < 1; i++)
|
||||
for(int i = 0; i < args.iterations; i++)
|
||||
{
|
||||
// 1. Compute row sums (REDUCE)
|
||||
// FIXME: Uses overload that is hardcoded to reduce 2nd dimension, be explicit instead
|
||||
@@ -127,50 +134,51 @@ struct SinkhornKnoppKernelReduce
|
||||
sweep_tile_span(c_spans[number<1>{}], [&](const auto idx1) {
|
||||
constexpr auto c_idx = make_tuple(idx0, idx1);
|
||||
constexpr auto row_acc_idx = make_tuple(idx0);
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
print(row_acc_idx);
|
||||
print(":");
|
||||
print(row_acc_tile(row_acc_idx));
|
||||
print("\n");
|
||||
}
|
||||
// if(threadIdx.x == INSPECT_THREAD)
|
||||
// {
|
||||
// print(row_acc_idx);
|
||||
// print(":");
|
||||
// print(row_acc_tile(row_acc_idx));
|
||||
// print("\n");
|
||||
// }
|
||||
compute_tile(c_idx) = compute_tile(c_idx) / row_acc_tile(row_acc_idx);
|
||||
});
|
||||
});
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
print("compute tile after rows summed and divided\n");
|
||||
[&](auto tile) {
|
||||
constexpr auto spans = tile.get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
|
||||
sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
|
||||
constexpr auto idx = make_tuple(idx0, idx1);
|
||||
print(idx);
|
||||
print(tile(idx));
|
||||
print("\n");
|
||||
});
|
||||
});
|
||||
}(compute_tile);
|
||||
}
|
||||
// if(threadIdx.x == INSPECT_THREAD)
|
||||
// {
|
||||
// printf("compute tile after normalization\n");
|
||||
// [&](auto tile) {
|
||||
// constexpr auto spans = tile.get_distributed_spans();
|
||||
// sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
|
||||
// sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
|
||||
// constexpr auto idx = make_tuple(idx0, idx1);
|
||||
// print(idx);
|
||||
// print(tile(idx));
|
||||
// print("\n");
|
||||
// });
|
||||
// });
|
||||
// printf("\n");
|
||||
// }(compute_tile);
|
||||
// }
|
||||
|
||||
transpose_tile2d(compute_tile_t, compute_tile);
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
print("compute tile transposed\n");
|
||||
[&](auto tile) {
|
||||
constexpr auto spans = tile.get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
|
||||
sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
|
||||
constexpr auto idx = make_tuple(idx0, idx1);
|
||||
print(idx);
|
||||
print(tile(idx));
|
||||
print("\n");
|
||||
});
|
||||
});
|
||||
}(compute_tile_t);
|
||||
}
|
||||
// if(threadIdx.x == INSPECT_THREAD)
|
||||
// {
|
||||
// printf("compute tile transposed\n");
|
||||
// [&](auto tile) {
|
||||
// constexpr auto spans = tile.get_distributed_spans();
|
||||
// sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
|
||||
// sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
|
||||
// constexpr auto idx = make_tuple(idx0, idx1);
|
||||
// print(idx);
|
||||
// print(tile(idx));
|
||||
// print("\n");
|
||||
// });
|
||||
// });
|
||||
// }(compute_tile_t);
|
||||
// }
|
||||
|
||||
// Row sum is column sum for transposed c_tile
|
||||
auto col_acc_tile = row_sum(compute_tile_t);
|
||||
@@ -179,16 +187,55 @@ struct SinkhornKnoppKernelReduce
|
||||
sweep_tile_span(c_t_spans[number<0>{}], [&](const auto idx0) {
|
||||
sweep_tile_span(c_t_spans[number<1>{}], [&](const auto idx1) {
|
||||
constexpr auto c_t_idx = make_tuple(idx0, idx1);
|
||||
constexpr auto col_acc_idx = make_tuple(idx0);
|
||||
constexpr auto col_acc_idx = make_tuple(idx1);
|
||||
// if(threadIdx.x == INSPECT_THREAD)
|
||||
// {
|
||||
// print(col_acc_idx);
|
||||
// print(":");
|
||||
// print(col_acc_tile(col_acc_idx));
|
||||
// print("\n");
|
||||
// }
|
||||
compute_tile_t(c_t_idx) = compute_tile_t(c_t_idx) / col_acc_tile(col_acc_idx);
|
||||
});
|
||||
});
|
||||
|
||||
// if(threadIdx.x == INSPECT_THREAD)
|
||||
// {
|
||||
// printf("transpose after normalizing\n");
|
||||
// [&](auto tile) {
|
||||
// constexpr auto spans = tile.get_distributed_spans();
|
||||
// sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
|
||||
// sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
|
||||
// constexpr auto idx = make_tuple(idx0, idx1);
|
||||
// print(idx);
|
||||
// print(tile(idx));
|
||||
// print("\n");
|
||||
// });
|
||||
// });
|
||||
// }(compute_tile_t);
|
||||
// }
|
||||
|
||||
transpose_tile2d(compute_tile, compute_tile_t);
|
||||
|
||||
// if(threadIdx.x == INSPECT_THREAD)
|
||||
// {
|
||||
// printf("thread %d, iter %d: original after transpose\n", INSPECT_THREAD, i);
|
||||
// [&](auto tile) {
|
||||
// constexpr auto spans = tile.get_distributed_spans();
|
||||
// sweep_tile_span(spans[number<0>{}], [&](const auto idx0) {
|
||||
// sweep_tile_span(spans[number<1>{}], [&](const auto idx1) {
|
||||
// constexpr auto idx = make_tuple(idx0, idx1);
|
||||
// print(idx);
|
||||
// print(tile(idx));
|
||||
// print("\n");
|
||||
// });
|
||||
// });
|
||||
// }(compute_tile);
|
||||
// }
|
||||
}
|
||||
|
||||
// Copy the final values to the output
|
||||
store_tile(out_window, compute_tile);
|
||||
store_tile(out_window, cast_tile<OutDataType>(compute_tile));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -16,10 +16,10 @@
|
||||
#include "test_sinkhorn_impl.hpp"
|
||||
|
||||
// Shape parameters for different test configurations
|
||||
using Shape1_BlockWarps = ck_tile::sequence<4, 1>;
|
||||
using Shape1_BlockTile = ck_tile::sequence<128, 128>;
|
||||
using Shape1_WarpTile = ck_tile::sequence<32, 128>;
|
||||
using Shape1_ThreadTile = ck_tile::sequence<1, 4>;
|
||||
using Shape1_BlockWarps = ck_tile::sequence<1, 1>;
|
||||
using Shape1_BlockTile = ck_tile::sequence<4, 4>;
|
||||
using Shape1_WarpTile = ck_tile::sequence<4, 4>;
|
||||
using Shape1_ThreadTile = ck_tile::sequence<4, 1>;
|
||||
|
||||
// Test configurations for different data types and input size
|
||||
using TestConfig_F16 = std::tuple<float, // ck_tile::half_t, // XDataType
|
||||
@@ -34,4 +34,4 @@ using TestTypes = ::testing::Types<TestConfig_F16>;
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileSinkHorn, TestTypes);
|
||||
|
||||
TYPED_TEST(TestCkTileSinkHorn, Test_4x4) { this->RunGenericTest({4, 4}, 20); }
|
||||
TYPED_TEST(TestCkTileSinkHorn, Test_4x4) { this->RunGenericTest({4, 4}, 1); }
|
||||
|
||||
@@ -31,13 +31,14 @@ class TestCkTileSinkHorn : public ::testing::Test
|
||||
void RunGenericTest(const std::vector<ck_tile::index_t>& input_shape, const int max_iterations)
|
||||
{
|
||||
auto input_n = input_shape[0];
|
||||
auto default_stride = {input_n, 1};
|
||||
auto default_stride = {1, input_n};
|
||||
|
||||
ck_tile::HostTensor<XDataType> h_x(input_shape, default_stride);
|
||||
ck_tile::HostTensor<YDataType> h_y(input_shape, default_stride);
|
||||
|
||||
// ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(h_x);
|
||||
ck_tile::FillMonotonicSeq<XDataType>{}(h_x);
|
||||
std::cout << h_x << std::endl;
|
||||
|
||||
auto buffer_size = h_x.get_element_space_size_in_bytes();
|
||||
ck_tile::DeviceMem d_x_mem(buffer_size);
|
||||
|
||||
Reference in New Issue
Block a user