diff --git a/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_default_policy.hpp b/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_default_policy.hpp index 81826004e0..e0e32a304f 100644 --- a/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_default_policy.hpp +++ b/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_default_policy.hpp @@ -10,7 +10,7 @@ namespace ck_tile { struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy { template - CK_TILE_DEVICE static constexpr auto MakeTransposedXBlockTileDistribution() + CK_TILE_DEVICE static constexpr auto MakeTransposedInputBlockTileDistribution() { using S = typename Problem::BlockShape; return make_static_tile_distribution( @@ -20,29 +20,36 @@ struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy sequence, sequence>, tuple, sequence<2, 1>>, - tuple, sequence<2, 1>>, // WarpPerBlock_M, WarpPerBlock_N; - // ThreadPerWarp_M, ThreadPerWarp_N - // sequence<1, 1, 2, 2>, - // sequence<0, 3, 0, 3>>{}); // Repeat_N, ThreadTile_N, Repeat_M, ThreadTile_M + tuple, sequence<2, 2>>, + // WarpPerBlock_M, WarpPerBlock_N, ThreadPerWarp_M, ThreadPerWarp_N sequence<2, 2, 1, 1>, - sequence<0, 2, 0, 3>>{}); // Repeat_M, ThreadTile_M, Repeat_N, ThreadTile_N + sequence<0, 3, 0, 3>>{}); // Repeat_M, ThreadTile_M, Repeat_N, ThreadTile_N } template - CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution() + CK_TILE_DEVICE static constexpr auto MakeInputBlockTileDistribution() { using S = typename Problem::BlockShape; return make_static_tile_distribution( tile_distribution_encoding< - sequence<>, + sequence<>, // Repetitions (in input dimensions?) 
tuple< sequence, sequence>, tuple, sequence<1, 2>>, - tuple, sequence<1, 1>>, + tuple, sequence<1, 2>>, sequence<1, 1, 2, 2>, sequence<0, 3, 0, 3>>{}); } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSum() + { + using br_problem = BlockReduce2dProblem; + return BlockReduce2d{}; + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp b/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp index 65416e3fe3..d5ac9d63fa 100644 --- a/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp +++ b/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp @@ -11,85 +11,123 @@ struct SinkhornKnoppArgs void* p_out; const void* p_in; const index_t input_m; - int max_iterations; + int iterations; }; template struct SinkhornKnoppKernelReduce { - CK_TILE_DEVICE void operator()([[maybe_unused]] const SinkhornKnoppArgs& args) const + static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; + + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? 
kBlockSize / 2 : kBlockSize; + } + + CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const { // // Creating tensor descriptors, views and windows for inputs and outputs - using S = Problem::BlockShape; - using InDataType = typename Problem::OutDataType; - using OutDataType = typename Problem::OutDataType; + using S = Problem::BlockShape; + using InDataType = typename Problem::InDataType; + using ComputeDataType = typename Problem::ComputeDataType; + using OutDataType = typename Problem::OutDataType; static_assert(S::Block_M == S::Block_N, "Input must be a square matrix!"); auto* p_in = static_cast(args.p_in); auto* p_out = static_cast(args.p_out); - [[maybe_unused]] auto exp_op = ck_tile::element_wise::Exp{}; - [[maybe_unused]] auto acc_op = ck_tile::ReduceOp::Add{}; - [[maybe_unused]] auto div_op = ck_tile::element_wise::UnaryDivide{}; + auto acc_op = ck_tile::ReduceOp::Add{}; - // We require exp(input) > 0, and exp(padding) == 0 - const InDataType x_padding_value = -ck_tile::numeric::infinity(); - - const auto in_desc = + const auto in_out_desc = make_naive_tensor_descriptor(make_tuple(args.input_m, args.input_m), make_tuple(args.input_m, 1), - number<4>{}, // TODO: Hardcoded + number<16>{}, // TODO: Hardcoded // vectorization, //we should calculate it! 
number<1>{}); - auto buffer_view = make_buffer_view( - p_in, in_desc.get_element_space_size(), x_padding_value); + const auto input_window = [&]() { + // We require exp(input) > 0, and exp(padding) == 0 + const InDataType input_padding_value = -ck_tile::numeric::infinity(); - const auto x_tensor = - tensor_view{buffer_view, in_desc}; + auto buffer_view = make_buffer_view( + p_in, in_out_desc.get_element_space_size(), input_padding_value); - [[maybe_unused]] auto x_window = - make_tile_window(x_tensor, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeXBlockTileDistribution()); + const auto in_tensor = + tensor_view{buffer_view, in_out_desc}; - const OutDataType y_padding_value = acc_op.template GetIdentityValue(); - auto out_buffer_view = make_buffer_view( - p_out, in_desc.get_element_space_size(), y_padding_value); + return make_tile_window(in_tensor, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeInputBlockTileDistribution()); + }(); - auto y_tensor = - tensor_view{out_buffer_view, in_desc}; + const OutDataType out_padding_value = acc_op.template GetIdentityValue(); + [[maybe_unused]] auto out_window = [&]() { + auto out_buffer_view = make_buffer_view( + p_out, in_out_desc.get_element_space_size(), out_padding_value); - [[maybe_unused]] auto y_window = - make_tile_window(y_tensor, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeXBlockTileDistribution()); + auto out_tensor = tensor_view{ + out_buffer_view, in_out_desc}; - [[maybe_unused]] auto c_window = - make_null_tile_window(make_tuple(number{}, number{})); - [[maybe_unused]] auto x_tile = load_tile(x_window); + return make_tile_window(out_tensor, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeInputBlockTileDistribution()); + }(); - // // Run the first steps iteration of the Sinkhorn-Knopp algorithm - // // Exponentiate the matrix x - // elementwise() + auto input_tile = load_tile(input_window); + // Run the first steps iteration of the 
Sinkhorn-Knopp algorithm + // Exponentiate the input to make it strictly positive + auto exp_func = [](ComputeDataType x) { return ck_tile::exp(x); }; - // // Hot loop for Sinkhorn-Knopp iterations from 1 to max_iterations - // // Use BlockReduce2D for row and column sums - // for(int i = 0; i <= args.max_iterations; i++) - // { - // // 0. LOAD x - // // 1. Compute row sums (REDUCE) - // // 2. Divide values by row sums (SWEEP) - // // 3. STORE the result of the division (in transposed format) - // // 4. LOAD transposed x - // // 5. Compute column sums (REDUCE) - // // 6. Divide values by column sums (SWEEP) - // // 7. STORE the result of the division (in transposed format) - // } + auto compute_tile = tile_elementwise_in(exp_func, input_tile); + + // Create a transposed tile + [[maybe_unused]] auto compute_tile_t = make_static_distributed_tensor( + Policy::template MakeTransposedInputBlockTileDistribution()); + + // Hot loop for Sinkhorn-Knopp iterations from 1 to iterations + // Use BlockReduce2D for row and column sums + auto row_sum = Policy::template GetSum(); + for(int i = 0; i < args.iterations; i++) + { + // 1. Compute row sums (REDUCE) + // FIXME: Uses overload that is hardcoded to reduce 2nd dimension, be explicit instead + auto row_acc_tile = row_sum(compute_tile, out_padding_value, acc_op); + + // 2. Divide values by row sums (SWEEP) + constexpr auto c_spans = compute_tile.get_distributed_spans(); + sweep_tile_span(c_spans[number<0>{}], [&](const auto idx0) { + sweep_tile_span(c_spans[number<1>{}], [&](const auto idx1) { + constexpr auto c_idx = make_tuple(idx0, idx1); + constexpr auto row_acc_idx = make_tuple(idx0); + compute_tile(c_idx) = compute_tile(c_idx) / row_acc_tile(row_acc_idx); + }); + }); + + // 3. STORE the result of the division (in transposed format) + // store_tile(compute_tile_t, compute_tile); + + // 4. LOAD transposed x + // 5. 
Compute column sums (REDUCE) + // auto col_acc_tile = row_sum(compute_tile, out_padding_value, acc_op); + + // 6. Divide values by column sums (SWEEP) + // constexpr auto ct_spans = compute_tile.get_distributed_spans(); + // sweep_tile_span(ct_spans[number<0>{}], [&](const auto idx0) { + // sweep_tile_span(ct_spans[number<1>{}], [&](const auto idx1) { + // constexpr auto c_idx = make_tuple(idx0, idx1); + // constexpr auto col_acc_idx = make_tuple(idx1); + // compute_tile(c_idx) = compute_tile(c_idx) / col_acc_tile(col_acc_idx); + // }); + // }); + // 7. STORE the result of the division (in transposed format) + } + + // Copy the final values to the output + store_tile(out_window, compute_tile); } }; @@ -104,20 +142,6 @@ struct SinkhornKnoppKernelDummyNonStochastic return is_wave32() ? kBlockSize / 2 : kBlockSize; } - template - CK_TILE_DEVICE static auto MakeComputeBlockTile() - { - constexpr auto dstr = - make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding( - XDistributedTensor_::get_tile_distribution() - .get_static_tile_distribution_encoding(), - sequence<0>{})); - - auto tensor = make_static_distributed_tensor(dstr); - - return tensor; - } - // template // CK_TILE_DEVICE static auto MakeYBlockTile() // { @@ -171,11 +195,11 @@ struct SinkhornKnoppKernelDummyNonStochastic in_desc.get_element_space_size(), type_convert(custom_padding_value)); - auto y_tensor = + auto out_tensor = tensor_view{out_buffer_view, in_desc}; - [[maybe_unused]] auto y_window = - make_tile_window(y_tensor, + [[maybe_unused]] auto out_window = + make_tile_window(out_tensor, make_tuple(number{}, number{}), {0, 0}, Policy::template MakeXBlockTileDistribution()); @@ -183,24 +207,24 @@ struct SinkhornKnoppKernelDummyNonStochastic [[maybe_unused]] auto input_tile = load_tile(input_window); - // auto y_tile = MakeYBlockTile(); - auto y_tile = make_static_distributed_tensor( + // auto out_tile = MakeYBlockTile(); + auto out_tile = make_static_distributed_tensor( 
Policy::template MakeXBlockTileDistribution()); // Set all output elements to the custom padding value. // // Simple solution to set the whole tile to a constant // - // set_tile(y_tile, custom_padding_value); - // store_tile(y_window, y_tile); + // set_tile(out_tile, custom_padding_value); + // store_tile(out_window, out_tile); - constexpr auto y_spans = y_tile.get_distributed_spans(); + constexpr auto y_spans = out_tile.get_distributed_spans(); sweep_tile_span(y_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(y_spans[number<1>{}], [&](auto idx1) { constexpr auto distributed_indices = make_tuple(idx0, idx1); - y_tile(distributed_indices) = type_convert(custom_padding_value); + out_tile(distributed_indices) = type_convert(custom_padding_value); }); }); - store_tile(y_window, y_tile); + store_tile(out_window, out_tile); } }; diff --git a/test/ck_tile/sinkhorn/test_sinkhorn_impl.hpp b/test/ck_tile/sinkhorn/test_sinkhorn_impl.hpp index 0fc8ba826a..cfbc4261d8 100644 --- a/test/ck_tile/sinkhorn/test_sinkhorn_impl.hpp +++ b/test/ck_tile/sinkhorn/test_sinkhorn_impl.hpp @@ -39,7 +39,7 @@ class TestCkTileSinkHorn : public ::testing::Test ck_tile::FillUniformDistribution{-5.f, 5.f}(h_x); auto buffer_size = h_x.get_element_space_size_in_bytes(); - ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_x_mem(buffer_size); ck_tile::DeviceMem d_y_mem(buffer_size); ck_tile::SinkhornKnoppArgs args{static_cast(d_y_mem.GetDeviceBuffer()), @@ -53,8 +53,7 @@ class TestCkTileSinkHorn : public ::testing::Test using Problem = ck_tile::SinkhornKnoppProblem; using Kernel = - ck_tile::SinkhornKnoppKernelDummyNonStochastic; + ck_tile::SinkhornKnoppKernelReduce; // Launch configuration const ck_tile::index_t kBlockSize = Kernel::BlockSize();