diff --git a/CMakeLists.txt b/CMakeLists.txt index 06d270c16e..6849001372 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -396,6 +396,12 @@ if(ENABLE_ASM_DUMP) message("CK compiled with ENABLE_ASM_DUMP set to ${ENABLE_ASM_DUMP}") endif() +option(ENABLE_KERNEL_RESOURCE_USAGE "Print kernel resource usage (VGPR, SGPR, scratch) during compilation." OFF) +if(ENABLE_KERNEL_RESOURCE_USAGE) + add_compile_options(-Rpass-analysis=kernel-resource-usage) + message("CK compiled with ENABLE_KERNEL_RESOURCE_USAGE set to ${ENABLE_KERNEL_RESOURCE_USAGE}") +endif() + if (ENABLE_JSON_DUMP) add_compile_definitions(CK_ENABLE_JSON_DUMP) message("CK compiled with ENABLE_JSON_DUMP set to ${ENABLE_JSON_DUMP}") diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 4d88f5edec..feb5cc2711 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -279,6 +279,14 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, std:: CDataType, AccDataType>(argc, argv, Col{}, Row{}, Col{}); } + else if(a_layout == "R" && b_layout == "R" && c_layout == "C") + { + return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Col{}); + } else { throw std::runtime_error("Unsupported data layout configuration for A, B and C tensors!"); diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 67b411c1f0..208e4120d2 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -67,8 +67,8 @@ struct GemmConfigBase template struct GemmConfigComputeV3_2 : public GemmConfigBase { - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); static constexpr ck_tile::index_t M_Warp = 2; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 343e37ed66..159d0b30f8 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -45,7 +45,10 @@ struct GemmPipelineAgBgCrImplBase else return std::is_same_v; }(); - + + // TEMP: Disable transpose load for B matrix to test RRR performance regression + static constexpr bool is_b_load_tr = false; + #if 0 // Original code - disabled for testing static constexpr bool is_b_load_tr = []() { using WarpTile = typename BlockGemmShape::WarpTile; constexpr index_t kKWarpTile = WarpTile::at(number<2>{}); @@ -57,6 +60,7 @@ struct GemmPipelineAgBgCrImplBase else return std::is_same_v; }(); + #endif #else static constexpr bool is_a_load_tr = false; static constexpr bool is_b_load_tr = false; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index a45d41189b..a1cb772a20 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -57,7 +57,11 @@ struct UniversalGemmBasePolicy return std::is_same_v, tensor_layout::gemm::ColumnMajor>; }(); - + + // TEMP: Disable transpose load for B matrix to test RRR performance regression + template + static constexpr bool is_b_load_tr = false; + #if 0 // Original code - disabled for testing template static constexpr bool is_b_load_tr = []() { using BDataType = remove_cvref_t; @@ -73,6 +77,7 @@ struct UniversalGemmBasePolicy return std::is_same_v, tensor_layout::gemm::RowMajor>; }(); + #endif #else template static constexpr bool is_a_load_tr = false; diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 27ca805c2e..703f5fe40a 100644 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -78,6 +78,21 @@ class GemmKernelBuilder: persistent, ) = trait_combo + # Validate that this tile_config is valid for this specific pipeline + if not self._validate_tile_config( + tile_config['tile_m'], + tile_config['tile_n'], + tile_config['tile_k'], + tile_config['warp_m'], + tile_config['warp_n'], + tile_config['warp_k'], + tile_config['warp_tile_m'], + tile_config['warp_tile_n'], + tile_config['warp_tile_k'], + pipeline, + ): + continue + # Create kernel name with proper boolean capitalization kernel_name = f"{self.kernel_name_prefix}_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}" @@ -160,14 +175,18 @@ class GemmKernelBuilder: warp_tile_n_values = tile_config.get("warp_tile_n").get("values") warp_tile_k_values = tile_config.get("warp_tile_k").get("values") - # Generate all combinations - default_pipeline = "" - if self.kernel_name_prefix == "gemm_universal": - default_pipeline = "compv4" - elif self.kernel_name_prefix == "gemm_multi_d": - default_pipeline = "compv4" - elif self.kernel_name_prefix == "gemm_preshuffle": - default_pipeline = "preshufflev2" + # Get pipelines from trait config to validate against + trait_config = self.config.get("trait_config", {}) + pipelines = trait_config.get("pipeline", {}).get("values", []) + + # Fallback to default pipeline if no pipelines in config + if not pipelines: + if self.kernel_name_prefix == "gemm_universal": + pipelines = ["compv4"] + elif self.kernel_name_prefix == "gemm_multi_d": + pipelines = ["compv4"] + elif self.kernel_name_prefix == "gemm_preshuffle": + pipelines = ["preshufflev2"] configs = [] for tile_m in tile_m_values: @@ -179,19 +198,25 @@ class GemmKernelBuilder: for warp_tile_m in warp_tile_m_values: for warp_tile_n in warp_tile_n_values: for warp_tile_k in warp_tile_k_values: - # Validate configuration - if self._validate_tile_config( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, - default_pipeline, - ): + # Validate configuration against any pipeline + # A tile config is valid if it works for at least one pipeline + is_valid = False + for pipeline in pipelines: + if self._validate_tile_config( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + pipeline, + ): + is_valid = True + break + if is_valid: configs.append( { "tile_m": tile_m, diff --git a/tile_engine/ops/gemm/gemm_universal/configs/default_config.json b/tile_engine/ops/gemm/gemm_universal/configs/default_config.json index 2447428158..8031220f13 100644 --- a/tile_engine/ops/gemm/gemm_universal/configs/default_config.json +++ b/tile_engine/ops/gemm/gemm_universal/configs/default_config.json @@ -63,19 +63,16 @@ "values": [ "compv3", "compv4", - "mem" ] }, "scheduler": { "values": [ - "intrawave", - "interwave" + "intrawave" ] }, "epilogue": { "values": [ - "cshuffle", - "default" + "cshuffle" ] }, "pad_m": { @@ -95,7 +92,6 @@ }, "persistent": { "values": [ - false, true ] } diff --git a/tile_engine/ops/gemm/gemm_universal/configs/rrr_test_config.json b/tile_engine/ops/gemm/gemm_universal/configs/rrr_test_config.json new file mode 100644 index 0000000000..d98cd6ff6e --- /dev/null +++ b/tile_engine/ops/gemm/gemm_universal/configs/rrr_test_config.json @@ -0,0 +1,62 @@ +{ + "tile_config": { + "tile_m": { + "values": [128, 256] + }, + "tile_n": { + "values": [128, 256] + }, + "tile_k": { + "values": [64] + }, + "warp_m": { + "values": [1, 2, 4] + }, + "warp_n": { + "values": [1, 2, 4] + }, + "warp_k": { + "values": [1] + }, + "warp_tile_m": { + "values": [16, 32] + }, + "warp_tile_n": { + "values": [16, 32] + }, + "warp_tile_k": { + "values": [16, 32] + } + }, + "trait_config": { + "pipeline": { + "values": [ + "compv3", + "compv4" + ] + }, + "scheduler": { + "values": [ + "intrawave" + ] + }, + "epilogue": { + "values": [ + "cshuffle" + ] + }, + "pad_m": { + "values": [false] + }, + "pad_n": { + "values": [false] + }, + "pad_k": { + "values": [false] + }, + "persistent": { + "values": [true] + } + }, + "k_block_per_cu": 1 +}