diff --git a/CMakeLists.txt b/CMakeLists.txt index 121c663f64..349e7e15e1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -390,6 +390,12 @@ if(ENABLE_ASM_DUMP) message("CK compiled with ENABLE_ASM_DUMP set to ${ENABLE_ASM_DUMP}") endif() +option(ENABLE_RESOURCE_USAGE "Enable printing kernel resource usage (VGPR/SGPR/LDS)" OFF) +if(ENABLE_RESOURCE_USAGE) + add_compile_options(-Rpass-analysis=kernel-resource-usage) + message(STATUS "CK compiled with ENABLE_RESOURCE_USAGE - will print VGPR/SGPR/LDS usage") +endif() + if (ENABLE_JSON_DUMP) add_compile_definitions(CK_ENABLE_JSON_DUMP) message("CK compiled with ENABLE_JSON_DUMP set to ${ENABLE_JSON_DUMP}") diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 8eff0e7469..2a6f46f9ce 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -98,9 +98,34 @@ struct GemmConfigComputeV3 : public GemmConfigBase template struct GemmConfigComputeV3_1 : public GemmConfigBase { + // Optimized config: 256x128x64, LDS=48KB, 16x16 warp tile for better occupancy static constexpr ck_tile::index_t M_Tile = 256; static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); // 64 for bf16 + + static constexpr ck_tile::index_t M_Warp = 2; // 4 warps in M + static constexpr ck_tile::index_t N_Warp = 2; // 2 warps in N + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + + static constexpr int kBlockPerCu = 1; +}; + +// Optimized config: 256x128x64 with COMPUTE_V3 +// Smaller N tile for better L2 cache utilization +template +struct GemmConfigOptimized : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); // 64 for bf16 static constexpr ck_tile::index_t M_Warp = 2; static constexpr ck_tile::index_t N_Warp = 2; @@ -113,6 +138,31 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + + static constexpr int kBlockPerCu = 2; // Better occupancy with smaller tile +}; + +// Alternative: 128x128x128 config for higher K throughput +template +struct GemmConfigOptimized2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); // 128 for bf16 + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + + static constexpr int kBlockPerCu = 2; }; template @@ -224,9 +274,11 @@ struct GemmConfigComputeV5 : public GemmConfigBase template struct GemmConfigComputeV6 : public GemmConfigBase { + // V6 pipeline: PrefetchStages=3, GlobalBufferNum=2, HotloopUnroll=2 + // Use 256x128x64 for better balance static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 32; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); // 64 for bf16 static constexpr ck_tile::index_t M_Warp = 2; static constexpr ck_tile::index_t N_Warp = 2; @@ -234,11 +286,13 @@ struct GemmConfigComputeV6 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V6; static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr int kBlockPerCu = 1; }; template diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index c1c8a2fc89..3b9f77e3e0 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -21,85 +21,24 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) using Invoker = UniversalInvoker; - if(data_type == "fp16") + if(data_type == "bf16") { - return run_gemm_example_prec_type, Invoker, ck_tile::half_t>( - a_layout, b_layout, arg_parser); - } - else if(data_type == "bf16") - { - return run_gemm_example_prec_type, Invoker, ck_tile::bf16_t>( - a_layout, b_layout, arg_parser); - } - else if(data_type == "fp8") - { - return run_gemm_example_prec_type, - Invoker, - ck_tile::fp8_t, - ck_tile::fp8_t, - ck_tile::half_t>(a_layout, b_layout, arg_parser); - } - else if(data_type == "bf8") - { - return run_gemm_example_prec_type, - Invoker, - ck_tile::bf8_t, - ck_tile::bf8_t, - ck_tile::half_t>(a_layout, b_layout, arg_parser); - } - else if(data_type == "int8") - { - return run_gemm_example_prec_type, - Invoker, - ck_tile::int8_t, - ck_tile::int8_t, - ck_tile::int32_t>(a_layout, b_layout, arg_parser); - } - else if(data_type == "fp16i4") - { - // TODO: Add support for bhalf_t ADataType - if constexpr(GemmConfig::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3) + // Only support RC layout (A=Row, B=Column) to reduce compile time + if(a_layout != "R" || b_layout != "C") { - return run_gemm_example_prec_type, - Invoker, - ck_tile::half_t, - ck_tile::pk_int4_t, - ck_tile::half_t>(a_layout, b_layout, arg_parser); - } - else - { - throw std::runtime_error("Unsupported pipeline for this operation !!!"); - } - } - else if(data_type == "fp8i4") - { - if constexpr(GemmConfig::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3) - { - return run_gemm_example_prec_type, - Invoker, - ck_tile::fp8_t, - ck_tile::pk_int4_t, - ck_tile::half_t>(a_layout, b_layout, arg_parser); - } - else - { - throw std::runtime_error("Unsupported pipeline for this operation !!!"); - } - } - else if(data_type == "bf8i4") - { - if constexpr(GemmConfig::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3) - { - return run_gemm_example_prec_type, - Invoker, - ck_tile::bf8_t, - ck_tile::pk_int4_t, - ck_tile::half_t>(a_layout, b_layout, arg_parser); - } - else - { - throw std::runtime_error("Unsupported pipeline for this operation !!!"); + throw std::runtime_error("Only RC layout (A=Row, B=Column) is supported!"); } + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + // Directly call with fixed layout to avoid std::visit instantiating all combinations + return run_gemm_example_with_layouts, + Invoker, + ck_tile::bf16_t, + ck_tile::bf16_t, + ck_tile::bf16_t, + Row, + Col, + Row>(arg_parser, Row{}, Col{}, Row{}); } else { @@ -117,11 +56,8 @@ int main(int argc, char* argv[]) try { -#if CK_TILE_USE_WMMA - return !run_gemm_example(arg_parser); -#else - return !run_gemm_example(arg_parser); -#endif + return !run_gemm_example(arg_parser); + } catch(const std::runtime_error& e) {