MI308 fix for streamk 1-Tile floating point exception (#2101)

This commit is contained in:
Muhammed Emin Ozturk
2025-04-21 11:44:07 -07:00
committed by GitHub
parent a738e43445
commit b092c18da7
2 changed files with 56 additions and 39 deletions

View File

@@ -1438,6 +1438,7 @@ struct BlockToCTileMap_GemmStreamK_v2
__host__ __device__ BlockToCTileMap_GemmStreamK_v2(
uint32_t m, uint32_t n, uint32_t k, uint32_t grid_size = 1, uint32_t streamk_sel = 1)
{
// total output tiles
uint32_t num_tiles =
math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
@@ -1445,6 +1446,9 @@ struct BlockToCTileMap_GemmStreamK_v2
uint32_t dp_tiles, dp_num_blocks, sk_total_iters;
// Ensure grid_size is at least 1 to avoid division by zero
grid_size = math::max(grid_size, 1u);
// default to regular DP GEMM (no sk blocks) when streamk_sel == 0
if(streamk_sel == 0)
{
@@ -1460,31 +1464,45 @@ struct BlockToCTileMap_GemmStreamK_v2
// stream-k (1- to 4-tile, chosen by streamk_sel) + DP GEMM
else
{
// check if there's enough work for DP+ stream-k
bool bigEnough = num_tiles > grid_size;
// select between stream-k strategies
// Select between stream-k strategies
// Clamp sk_tiles to at least 1: it becomes sk_num_blocks, which is used as a divisor below
uint32_t sk_tiles = 0;
if(streamk_sel == 1) // 1 tile stream-k
{
sk_tiles = bigEnough ? (num_tiles % grid_size) : num_tiles;
// Ensure sk_tiles is at least 1
sk_tiles = math::max(sk_tiles, 1u);
}
else if(streamk_sel == 2) // 2-tile stream-k
{
sk_tiles = bigEnough ? (grid_size + num_tiles % grid_size) : num_tiles;
// Ensure sk_tiles is at least 1 but not more than num_tiles
sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles);
}
else if(streamk_sel == 3) // 3-tile stream-k
{
sk_tiles = (num_tiles > (2 * grid_size)) ? (2 * grid_size + num_tiles % grid_size)
: num_tiles;
// Ensure sk_tiles is at least 1 but not more than num_tiles
sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles);
}
else if(streamk_sel == 4) // 4-tile stream-k
{
sk_tiles = (num_tiles > (3 * grid_size)) ? (3 * grid_size + num_tiles % grid_size)
: num_tiles;
// Ensure sk_tiles is at least 1 but not more than num_tiles
sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles);
}
sk_num_blocks = sk_tiles;
// remaining tiles are DP tiles
// Remaining tiles are DP tiles
dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0;
sk_total_iters = k_iters_per_tile.get() * sk_tiles;
@@ -1500,24 +1518,51 @@ struct BlockToCTileMap_GemmStreamK_v2
// => sk_blocks * m + b = sk_total_iters
// => b = sk_total_iters - m * sk_blocks
// NOTE: big could be zero
uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
k_iters_per_big_block = k_iters_per_sk_block + 1;
// Add safety check for sk_num_blocks to prevent division by zero
if(sk_num_blocks > 0)
{
uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
k_iters_per_big_block = k_iters_per_sk_block + 1;
}
else
{
// Fallback to default GEMM if no stream-k blocks
sk_num_blocks = 0;
sk_num_big_blocks = 0;
k_iters_per_big_block = 0;
dp_tiles = num_tiles;
dp_num_blocks = num_tiles;
dp_start_block_idx = 0;
sk_total_iters = 0;
}
dp_num_blocks = dp_tiles;
dp_start_block_idx = sk_num_blocks;
}
n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
// using multiple blocks for parallel reduction
// Using multiple blocks for parallel reduction
reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
{
uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
equiv_tiles_big = MDiv(upper_big / k_iters_per_tile.get());
equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
// Add additional safety checks
if(k_iters_per_big_block > 0 && k_iters_per_tile.get() > 0)
{
uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
uint32_t upper_little =
math::lcm(math::max(k_iters_per_big_block - 1, 1u), k_iters_per_tile.get());
equiv_tiles_big = MDiv(upper_big / k_iters_per_tile.get());
equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
}
else
{
// Default safe values
equiv_tiles_big = MDiv(1);
equiv_tiles_little = MDiv(1);
}
}
}