mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
MI308 fix for streamk 1-Tile floating point exception (#2101)
This commit is contained in:
committed by
GitHub
parent
a738e43445
commit
b092c18da7
@@ -1438,6 +1438,7 @@ struct BlockToCTileMap_GemmStreamK_v2
|
||||
__host__ __device__ BlockToCTileMap_GemmStreamK_v2(
|
||||
uint32_t m, uint32_t n, uint32_t k, uint32_t grid_size = 1, uint32_t streamk_sel = 1)
|
||||
{
|
||||
|
||||
// total output tiles
|
||||
uint32_t num_tiles =
|
||||
math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
|
||||
@@ -1445,6 +1446,9 @@ struct BlockToCTileMap_GemmStreamK_v2
|
||||
|
||||
uint32_t dp_tiles, dp_num_blocks, sk_total_iters;
|
||||
|
||||
// Ensure grid_size is at least 1 to avoid division by zero
|
||||
grid_size = math::max(grid_size, 1u);
|
||||
|
||||
// default to regular DP GEMM if sk blocks == 0
|
||||
if(streamk_sel == 0)
|
||||
{
|
||||
@@ -1460,31 +1464,45 @@ struct BlockToCTileMap_GemmStreamK_v2
|
||||
// 2-tile sk + DP GEMM
|
||||
else
|
||||
{
|
||||
|
||||
// check if there's enough work for DP+ stream-k
|
||||
bool bigEnough = num_tiles > grid_size;
|
||||
// select between stream-k strategies
|
||||
|
||||
// Select between stream-k strategies
|
||||
// Add safety checks to prevent zero or negative values
|
||||
uint32_t sk_tiles = 0;
|
||||
if(streamk_sel == 1) // 1 tile stream-k
|
||||
{
|
||||
sk_tiles = bigEnough ? (num_tiles % grid_size) : num_tiles;
|
||||
|
||||
// Ensure sk_tiles is at least 1
|
||||
sk_tiles = math::max(sk_tiles, 1u);
|
||||
}
|
||||
else if(streamk_sel == 2) // 2-tile stream-k
|
||||
{
|
||||
sk_tiles = bigEnough ? (grid_size + num_tiles % grid_size) : num_tiles;
|
||||
|
||||
// Ensure sk_tiles is at least 1 but not more than num_tiles
|
||||
sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles);
|
||||
}
|
||||
else if(streamk_sel == 3) // 3-tile stream-k
|
||||
{
|
||||
sk_tiles = (num_tiles > (2 * grid_size)) ? (2 * grid_size + num_tiles % grid_size)
|
||||
: num_tiles;
|
||||
|
||||
// Ensure sk_tiles is at least 1 but not more than num_tiles
|
||||
sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles);
|
||||
}
|
||||
else if(streamk_sel == 4) // 4-tile stream-k
|
||||
{
|
||||
sk_tiles = (num_tiles > (3 * grid_size)) ? (3 * grid_size + num_tiles % grid_size)
|
||||
: num_tiles;
|
||||
|
||||
// Ensure sk_tiles is at least 1 but not more than num_tiles
|
||||
sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles);
|
||||
}
|
||||
|
||||
sk_num_blocks = sk_tiles;
|
||||
// remaining tiles are DP tiles
|
||||
// Remaining tiles are DP tiles
|
||||
dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0;
|
||||
|
||||
sk_total_iters = k_iters_per_tile.get() * sk_tiles;
|
||||
@@ -1500,24 +1518,51 @@ struct BlockToCTileMap_GemmStreamK_v2
|
||||
// => sk_blocks * m + b = sk_total_iters
|
||||
// => b = sk_total_iters - m * sk_blocks
|
||||
// NOTE: big could be zero
|
||||
uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
|
||||
sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
|
||||
k_iters_per_big_block = k_iters_per_sk_block + 1;
|
||||
|
||||
// Add safety check for sk_num_blocks to prevent division by zero
|
||||
if(sk_num_blocks > 0)
|
||||
{
|
||||
uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
|
||||
sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
|
||||
k_iters_per_big_block = k_iters_per_sk_block + 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Fallback to default GEMM if no stream-k blocks
|
||||
sk_num_blocks = 0;
|
||||
sk_num_big_blocks = 0;
|
||||
k_iters_per_big_block = 0;
|
||||
dp_tiles = num_tiles;
|
||||
dp_num_blocks = num_tiles;
|
||||
dp_start_block_idx = 0;
|
||||
sk_total_iters = 0;
|
||||
}
|
||||
|
||||
dp_num_blocks = dp_tiles;
|
||||
dp_start_block_idx = sk_num_blocks;
|
||||
}
|
||||
|
||||
n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
|
||||
// using multiple blocks for parallel reduction
|
||||
// Using multiple blocks for parallel reduction
|
||||
reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;
|
||||
|
||||
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
|
||||
uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
|
||||
equiv_tiles_big = MDiv(upper_big / k_iters_per_tile.get());
|
||||
equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
|
||||
// Add additional safety checks
|
||||
if(k_iters_per_big_block > 0 && k_iters_per_tile.get() > 0)
|
||||
{
|
||||
uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
|
||||
uint32_t upper_little =
|
||||
math::lcm(math::max(k_iters_per_big_block - 1, 1u), k_iters_per_tile.get());
|
||||
equiv_tiles_big = MDiv(upper_big / k_iters_per_tile.get());
|
||||
equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
|
||||
}
|
||||
else
|
||||
{
|
||||
// Default safe values
|
||||
equiv_tiles_big = MDiv(1);
|
||||
equiv_tiles_little = MDiv(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user