Merge commit '5f1ad09b610cb0e083f63988479ab022bda70588' into develop

This commit is contained in:
assistant-librarian[bot]
2025-06-13 01:40:16 +00:00
parent dfa4ac5eb3
commit c1e4ebf424
17 changed files with 727 additions and 110 deletions

View File

@@ -56,19 +56,24 @@ template <index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t VecSize,
tile_distribution_pattern DistributionPattern>
tile_distribution_pattern DistributionPattern,
index_t NumWaveGroups = 1>
struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern
{
};
// Thread raked
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
template <index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t VecSize,
index_t NumWaveGroups>
struct TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::thread_raked>
: public TileDistributionEncodingPattern
tile_distribution_pattern::thread_raked,
NumWaveGroups> : public TileDistributionEncodingPattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
@@ -83,45 +88,76 @@ struct TileDistributionEncodingPattern2D<BlockSize,
static constexpr index_t Y1 = warp_size / X0;
static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");
static constexpr index_t Y0 = num_warps;
static constexpr index_t Y0 = num_warps / NumWaveGroups;
// YPerWarp = YPerTile / Y0;
// Y2 = YPerWarp / Y1;
static constexpr index_t Y2 = YPerTile / (Y1 * Y0); // # of iters within wavefront
static_assert(X0 * Y1 * Y0 == BlockSize, "X0 * warp_ys * Y0 must cover whole workgroup!");
static_assert(X0 * Y1 * Y0 * NumWaveGroups == BlockSize,
"X0 * warp_ys * Y0 must cover whole workgroup!");
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
if constexpr(NumWaveGroups != 1)
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<Y0>,
tuple<sequence<Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<0>, sequence<1, 2>>,
tuple<sequence<0>, sequence<0, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
else
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
}
}
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<1, 2>>{});
if constexpr(NumWaveGroups != 1)
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<Y0>,
tuple<sequence<X0, X1>, sequence<Y1, Y2>>,
tuple<sequence<0>, sequence<2, 1>>,
tuple<sequence<0>, sequence<0, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
else
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<1, 2>>{});
}
}
};
// Warp raked
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
template <index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t VecSize,
index_t NumWaveGroups>
struct TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::warp_raked>
: public TileDistributionEncodingPattern
tile_distribution_pattern::warp_raked,
NumWaveGroups> : public TileDistributionEncodingPattern
{
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
@@ -164,13 +200,17 @@ struct TileDistributionEncodingPattern2D<BlockSize,
};
// Block raked
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
template <index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t VecSize,
index_t NumWaveGroups>
struct TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::block_raked>
: public TileDistributionEncodingPattern
tile_distribution_pattern::block_raked,
NumWaveGroups> : public TileDistributionEncodingPattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!