diff --git a/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_default_policy.hpp b/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_default_policy.hpp new file mode 100644 index 0000000000..9f7106ef84 --- /dev/null +++ b/include/ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_default_policy.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp" + +namespace ck_tile { + +struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy +{ + template + CK_TILE_DEVICE static constexpr auto MakeTransposedXBlockTileDistribution() + { + using S = typename Problem::BlockShape; + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple< + sequence, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 2>>, + sequence<1, 1, 2, 2>, + sequence<0, 3, 0, 3>>{}); + } +}; + +} // namespace ck_tile \ No newline at end of file diff --git a/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp b/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp index d8a6c1a99c..d437e8bdb1 100644 --- a/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp +++ b/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp @@ -1,11 +1,14 @@ #pragma once +#include "ck_tile/core.hpp" + namespace ck_tile { struct SinkhornKnoppArgs { + const void* p_x; + const index_t n; int max_iterations; - }; } // namespace ck_tile \ No newline at end of file