From 16fd74e8d3ec2d9bd5f12279d9cfb6b1236adedb Mon Sep 17 00:00:00 2001 From: Muhammed Emin Ozturk Date: Mon, 21 Apr 2025 11:44:07 -0700 Subject: [PATCH] MI308 fix for streamk 1-Tile floating point exception (#2101) [ROCm/composable_kernel commit: b092c18da708422fb529193de40b6224446007c5] --- .../gpu/grid/block_to_ctile_map.hpp | 67 ++++++++++++++++--- ...t_gemm_universal_streamk_ut_cases_bf16.inc | 28 -------- 2 files changed, 56 insertions(+), 39 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index 64fad1ca48..311545aad6 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -1438,6 +1438,7 @@ struct BlockToCTileMap_GemmStreamK_v2 __host__ __device__ BlockToCTileMap_GemmStreamK_v2( uint32_t m, uint32_t n, uint32_t k, uint32_t grid_size = 1, uint32_t streamk_sel = 1) { + // total output tiles uint32_t num_tiles = math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock); @@ -1445,6 +1446,9 @@ struct BlockToCTileMap_GemmStreamK_v2 uint32_t dp_tiles, dp_num_blocks, sk_total_iters; + // Ensure grid_size is at least 1 to avoid division by zero + grid_size = math::max(grid_size, 1u); + // default to regular DP GEMM if sk blocks == 0 if(streamk_sel == 0) { @@ -1460,31 +1464,45 @@ struct BlockToCTileMap_GemmStreamK_v2 // 2-tile sk + DP GEMM else { - // check if there's enough work for DP+ stream-k bool bigEnough = num_tiles > grid_size; - // select between stream-k strategies + + // Select between stream-k strategies + // Add safety checks to prevent zero or negative values uint32_t sk_tiles = 0; if(streamk_sel == 1) // 1 tile stream-k { sk_tiles = bigEnough ? (num_tiles % grid_size) : num_tiles; + + // Ensure sk_tiles is at least 1 + sk_tiles = math::max(sk_tiles, 1u); } else if(streamk_sel == 2) // 2-tile stream-k { sk_tiles = bigEnough ? (grid_size + num_tiles % grid_size) : num_tiles; + + // Ensure sk_tiles is at least 1 but not more than num_tiles + sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles); } else if(streamk_sel == 3) // 3-tile stream-k { sk_tiles = (num_tiles > (2 * grid_size)) ? (2 * grid_size + num_tiles % grid_size) : num_tiles; + + // Ensure sk_tiles is at least 1 but not more than num_tiles + sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles); } else if(streamk_sel == 4) // 4-tile stream-k { sk_tiles = (num_tiles > (3 * grid_size)) ? (3 * grid_size + num_tiles % grid_size) : num_tiles; + + // Ensure sk_tiles is at least 1 but not more than num_tiles + sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles); } + sk_num_blocks = sk_tiles; - // remaining tiles are DP tiles + // Remaining tiles are DP tiles dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0; sk_total_iters = k_iters_per_tile.get() * sk_tiles; @@ -1500,24 +1518,51 @@ struct BlockToCTileMap_GemmStreamK_v2 // => sk_blocks * m + b = sk_total_iters // => b = sk_total_iters - m * sk_blocks // NOTE: big could be zero - uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks; - sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks; - k_iters_per_big_block = k_iters_per_sk_block + 1; + + // Add safety check for sk_num_blocks to prevent division by zero + if(sk_num_blocks > 0) + { + uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks; + sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks; + k_iters_per_big_block = k_iters_per_sk_block + 1; + } + else + { + // Fallback to default GEMM if no stream-k blocks + sk_num_blocks = 0; + sk_num_big_blocks = 0; + k_iters_per_big_block = 0; + dp_tiles = num_tiles; + dp_num_blocks = num_tiles; + dp_start_block_idx = 0; + sk_total_iters = 0; + } dp_num_blocks = dp_tiles; dp_start_block_idx = sk_num_blocks; } n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock)); - // using multiple blocks for parallel reduction + // Using multiple blocks for parallel reduction reduction_start_block_idx = dp_start_block_idx + dp_num_blocks; if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction) { - uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get()); - uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get()); - equiv_tiles_big = MDiv(upper_big / k_iters_per_tile.get()); - equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get()); + // Add additional safety checks + if(k_iters_per_big_block > 0 && k_iters_per_tile.get() > 0) + { + uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get()); + uint32_t upper_little = + math::lcm(math::max(k_iters_per_big_block - 1, 1u), k_iters_per_tile.get()); + equiv_tiles_big = MDiv(upper_big / k_iters_per_tile.get()); + equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get()); + } + else + { + // Default safe values + equiv_tiles_big = MDiv(1); + equiv_tiles_little = MDiv(1); + } } } diff --git a/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_bf16.inc b/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_bf16.inc index b6970c4233..22977866b5 100644 --- a/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_bf16.inc +++ b/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_bf16.inc @@ -44,34 +44,6 @@ TYPED_TEST(TestGemmUniversal_Streamk_BF16_KM_KN, SmallM) } } -TYPED_TEST(TestGemmUniversal_Streamk_BF16_MK_KN, MidLargeM) -{ - std::vector Ms{127, 255, 312, 799, 1573}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_Streamk_BF16_MK_NK, MidLargeM) -{ - std::vector Ms{127, 255, 312, 799, 1573}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = K; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - TYPED_TEST(TestGemmUniversal_Streamk_BF16_KM_KN, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573};