diff --git a/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_problem.hpp b/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_problem.hpp index 69f0bee3bc..80c6fa40b1 100644 --- a/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_problem.hpp +++ b/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_problem.hpp @@ -29,7 +29,7 @@ template typename BlockTile, // block size, seq typename WarpTile, // warp size, seq typename ThreadTile> // contiguous pixels(vector size) along seq -struct SinkHornKnoppShape +struct SinkhornKnoppShape { static constexpr index_t Block_M = BlockTile::at(number<0>{}); static constexpr index_t Block_N = BlockTile::at(number<1>{}); diff --git a/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp b/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp index c9083230eb..c1b27e0c1f 100644 --- a/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp +++ b/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp @@ -53,6 +53,14 @@ struct SinkhornKnoppKernelReduce template struct SinkhornKnoppKernelDummyNonStochastic { + + static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; + + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? kBlockSize / 2 : kBlockSize; + } + template CK_TILE_DEVICE static auto MakeYBlockTile() { @@ -67,16 +75,16 @@ struct SinkhornKnoppKernelDummyNonStochastic return tensor; } - CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const + CK_TILE_DEVICE void operator()([[maybe_unused]]const SinkhornKnoppArgs& args) const { - using S = Problem::BlockShape; + // using S = Problem::BlockShape; - static_assert(S::Block_M == S::Block_N, "Input must be a square matrix!"); + // static_assert(S::Block_M == S::Block_N, "Input must be a square matrix!"); - const auto x_desc = make_naive_tensor_descriptor(make_tuple(args.input_m, args.input_m), - make_tuple(args.input_m, 1), - number<4>{}, // TODO: Hardcoded vectorization, we should calculate it! - number<1>{}); + // const auto x_desc = make_naive_tensor_descriptor(make_tuple(args.input_m, args.input_m), + // make_tuple(args.input_m, 1), + // number<4>{}, // TODO: Hardcoded vectorization, we should calculate it! + // number<1>{}); // auto buffer_view = make_buffer_view( // args.p_x, desc.get_element_space_size(), number<0>{}); diff --git a/test/ck_tile/sinkhorn/test_sinkhorn_impl.hpp b/test/ck_tile/sinkhorn/test_sinkhorn_impl.hpp index 0fe5d9f356..3c5ccfc5f7 100644 --- a/test/ck_tile/sinkhorn/test_sinkhorn_impl.hpp +++ b/test/ck_tile/sinkhorn/test_sinkhorn_impl.hpp @@ -35,24 +35,25 @@ class TestCkTileSinkHorn: public ::testing::Test void RunGenericTest(const std::vector& input_shape, const int max_iterations) { + auto input_m = input_shape[0]; - SinkhornKnoppArgs args{}; - args.input_m = static_cast(input_shape[0]); - args.max_iterations = max_iterations; - - auto default_stride = {args.input_m, 1}; + auto default_stride = {input_m, 1}; ck_tile::HostTensor h_x(input_shape, default_stride); ck_tile::HostTensor h_y(input_shape, default_stride); ck_tile::FillUniformDistribution{-5.f, 5.f}(h_x); - auto buffer_size = h_xs.get_element_space_size_in_bytes(); + auto buffer_size = h_x.get_element_space_size_in_bytes(); ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes()); - ck_tile::DeviceMem d_y_mem(output_buffer_size); + ck_tile::DeviceMem d_y_mem(buffer_size); - args.p_x = static_cast(d_x_mem.GetDeviceBuffer()); - args.out = static_cast(d_y_mem.GetDeviceBuffer()); + ck_tile::SinkhornKnoppArgs args{ + static_cast(d_y_mem.GetDeviceBuffer()), + static_cast(d_x_mem.GetDeviceBuffer()), + input_m, + max_iterations + }; d_x_mem.ToDevice(h_x.data()); d_y_mem.ToDevice(h_y.data()); @@ -64,13 +65,13 @@ class TestCkTileSinkHorn: public ::testing::Test >; using Kernel = ck_tile::SinkhornKnoppKernelDummyNonStochastic< Problem, - ck_tile::SinkhornKnoppPolicy>; + ck_tile::SinkhornKnoppDefaultPolicy>; // Launch configuration const ck_tile::index_t kBlockSize = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; - ck_tile::index_t kGridSize = 1 // TODO + ck_tile::index_t kGridSize = 1; // TODO //TODO // if(!Kernel::IsSupportedArgument())