Use defined shape

This commit is contained in:
Matti Eskelinen
2026-02-06 10:43:54 +00:00
parent 2b5a5e364c
commit deecabacf8

View File

@@ -12,10 +12,10 @@ struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeInputBlockTileDistribution()
{
// using S = typename Problem::BlockShape;
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<4, 1>, sequence<1, 4>>,
tuple<sequence<S::Block_M, 1>, sequence<1, S::Block_N>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
sequence<2, 1>,
@@ -25,10 +25,10 @@ struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeTransposedInputBlockTileDistribution()
{
// using S = typename Problem::BlockShape;
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<1, 4>, sequence<4, 1>>,
tuple<sequence<1, S::Block_N>, sequence<S::Block_M, 1>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 0>>,
sequence<2, 1>,