From 7f9758a186a5af61e82653fdac2b32e00a8b1a8f Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Tue, 3 Jun 2025 20:16:10 -0700 Subject: [PATCH] [CK_Tile] Fix gemm kernel for 4,64,16 and 64,4,16 warp tile sizes (#2262) * debugging issue * debugging issue * debugging * debugging * reverting debugging code * clang formatted * updating default_config.json * fix ci failure * clang formatted [ROCm/composable_kernel commit: 59a85cb4bcca9482fbccef570f6e9dc818d6deef] --- example/ck_tile/03_gemm/universal_gemm.cpp | 7 +++++-- .../ck_tile/ops/epilogue/cshuffle_epilogue.hpp | 15 ++++++++++++++- .../ck_tile/ops/epilogue/default_2d_epilogue.hpp | 8 ++++++-- tile_engine/ops/gemm/codegen_utils.py | 12 ++++++------ tile_engine/ops/gemm/configs/default_config.json | 12 +++++++++--- tile_engine/ops/gemm/gemm_instance_builder.py | 9 ++++++--- 6 files changed, 46 insertions(+), 17 deletions(-) diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 5dcb685839..0a094c29fe 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -121,8 +121,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& if(s.log_level_ > 0) { - std::cout << "Launching kernel with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << GemmPipelineProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 9b8dde1905..1f53dfd93c 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -92,7 +92,20 @@ struct CShuffleEpilogue CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC() { constexpr index_t MaxVectorStoreSize = 16; - return MaxVectorStoreSize / sizeof(ODataType); + if constexpr(std::is_same_v) + { + return std::min(static_cast(kNPerIteration), + static_cast(MaxVectorStoreSize / sizeof(ODataType))); + } + else if constexpr(std::is_same_v) + { + return std::min(static_cast(kMPerIteration), + static_cast(MaxVectorStoreSize / sizeof(ODataType))); + } + else + { + static_assert(false, "Unsupported CLayout!"); + } } template diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index a2915f5c8f..ab3c0df88d 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -149,7 +149,9 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue else { // In this case each thread has just a single item in Ndim - return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN; + return (WG::WarpGemmAttribute::Impl::kCNLane * + WG::WarpGemmAttribute::Impl::kBNBlock) / + WG::kN; } } // M is contiguous dimension @@ -158,7 +160,9 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue if constexpr(isCTransposed) { // In this case each thread has just a single item in Mdim - return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN; + return (WG::WarpGemmAttribute::Impl::kCNLane * + WG::WarpGemmAttribute::Impl::kAMBlock) / + WG::kN; } else { diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py index a8955cec91..58eed45dc6 100644 --- a/tile_engine/ops/gemm/codegen_utils.py +++ b/tile_engine/ops/gemm/codegen_utils.py @@ -167,20 +167,20 @@ def BOOL_MAP(b_): return {True: 'true', False: 'false'}[bool(b_)] # To Do: add some more supported combinations warp_tile_supported_combinations = { "gfx90a": { - 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], - 'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]], 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]] }, "gfx942": { - 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], - 'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]] }, "gfx950": { - 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], - 'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]], 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]] } diff --git a/tile_engine/ops/gemm/configs/default_config.json b/tile_engine/ops/gemm/configs/default_config.json index 09fe3b83ac..d20c5eef7d 100644 --- a/tile_engine/ops/gemm/configs/default_config.json +++ b/tile_engine/ops/gemm/configs/default_config.json @@ -48,7 +48,7 @@ "max": 512, "min": 64, "step": 64, - "exclude": [] + "exclude": [192] }, "warp_m": { "values": [ @@ -71,14 +71,20 @@ }, "warp_tile_m": { "values": [ + 4, + 8, 16, - 32 + 32, + 64 ] }, "warp_tile_n": { "values": [ + 4, + 8, 16, - 32 + 32, + 64 ] }, "warp_tile_k": { diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index ea7fa4e67c..a677b842c5 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -420,12 +420,12 @@ struct GemmKernel {{ # LDS capacity verification matrix_a_size = (tile_m * tile_k) * \ - pow(2, element_size(self.config.problem.datatype_map['matrix_a'])) + element_size(self.config.problem.datatype_map['matrix_a']) matrix_b_size = (tile_n * tile_k) * \ - pow(2, element_size(self.config.problem.datatype_map['matrix_b'])) + element_size(self.config.problem.datatype_map['matrix_b']) total_tile_in_lds = matrix_a_size + matrix_b_size - max_tile_size = 2**16 if pipeline == "compv4" else 2**15 + max_tile_size = 2**15 if pipeline == "compv4" else 2**16 if total_tile_in_lds > max_tile_size: logging.debug( f"LDS capacity exceeded [{trait}]: Total required {total_tile_in_lds:,}B ({total_tile_in_lds/1024:.1f}KB) > " @@ -493,6 +493,9 @@ struct GemmKernel {{ for trait in self.valid_trait_names: tile_valid_params = list( filter(lambda t: self.is_tile_valid(t, trait), tile_params)) + + # if len(tile_valid_params) == 0: + # raise RuntimeError(f"No valid kernel instance selected for trait: {trait}") if trait not in self.valid_trait_tile_combinations: self.valid_trait_tile_combinations[trait] = [] self.valid_trait_tile_combinations[trait].append(tile_valid_params)