mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
Solve the Static Encoding Pattern compile error when the tile size is too small (#2079)
This commit is contained in:
@@ -73,10 +73,11 @@ struct TileDistributionEncodingPattern2D<BlockSize,

 // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
 static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
-static constexpr index_t warp_size = get_warp_size();
-static constexpr index_t num_warps = BlockSize / get_warp_size();
-static constexpr index_t X1 = VecSize;
-static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
+static constexpr index_t warp_size = get_warp_size();
+static constexpr index_t num_warps = BlockSize / get_warp_size();
+static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
+static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
+static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim

 // # of rows in Y dim accessed by single wavefront in one iteration
 static constexpr index_t Y1 = warp_size / X0;
@@ -124,10 +125,11 @@ struct TileDistributionEncodingPattern2D<BlockSize,
 {

 static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
-static constexpr index_t warp_size = get_warp_size();
-static constexpr index_t num_warps = BlockSize / get_warp_size();
-static constexpr index_t X1 = VecSize;
-static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
+static constexpr index_t warp_size = get_warp_size();
+static constexpr index_t num_warps = BlockSize / get_warp_size();
+static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
+static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
+static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim

 static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
 static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
@@ -173,10 +175,11 @@ struct TileDistributionEncodingPattern2D<BlockSize,

 // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
 static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
-static constexpr index_t warp_size = get_warp_size();
-static constexpr index_t num_warps = BlockSize / get_warp_size();
-static constexpr index_t X1 = VecSize;
-static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
+static constexpr index_t warp_size = get_warp_size();
+static constexpr index_t num_warps = BlockSize / get_warp_size();
+static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
+static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
+static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
 static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
 static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
 static constexpr index_t Y1 = num_warps;
Reference in New Issue
Block a user