From e04000a0d0077c6079fe20bd79d202da7ac0528c Mon Sep 17 00:00:00 2001 From: Thrupti Raj Lakshmana Gowda Date: Mon, 15 Sep 2025 15:22:11 -0500 Subject: [PATCH] [CK Tile Engine] k_block_per_cu changes in Preshuffle (#2842) * kperblock changes in CK Tile Engine Preshuffle * Config file formatting changes [ROCm/composable_kernel commit: 7d7ded62d3cc0e8a3160b546becdc3d8847799db] --- .../configs/default_config.json | 95 ++++++++++--------- .../configs/user_provided_config.json | 5 +- .../gemm_preshuffle_instance_builder.py | 16 +++- 3 files changed, 62 insertions(+), 54 deletions(-) diff --git a/tile_engine/ops/gemm_preshuffle/configs/default_config.json b/tile_engine/ops/gemm_preshuffle/configs/default_config.json index d48b7b0ac2..d4c3537c65 100644 --- a/tile_engine/ops/gemm_preshuffle/configs/default_config.json +++ b/tile_engine/ops/gemm_preshuffle/configs/default_config.json @@ -1,51 +1,51 @@ { - "tile_config": { - "tile_m": { - "values": [ - 128 - ] - }, - "tile_n": { - "values": [ - 128 - ] - }, - "tile_k": { - "values": [ - 128 - ] - }, - "warp_m": { - "values": [ - 1 - ] - }, - "warp_n": { - "values": [ - 4 - ] - }, - "warp_k": { - "values": [ - 1 - ] - }, - "warp_tile_m": { - "values": [ - 16 - ] - }, - "warp_tile_n": { - "values": [ - 16 - ] - }, - "warp_tile_k": { - "values": [ - 16,32 - ] - } + "tile_config": { + "tile_m": { + "values": [ + 128 + ] }, + "tile_n": { + "values": [ + 128 + ] + }, + "tile_k": { + "values": [ + 128 + ] + }, + "warp_m": { + "values": [ + 1 + ] + }, + "warp_n": { + "values": [ + 4 + ] + }, + "warp_k": { + "values": [ + 1 + ] + }, + "warp_tile_m": { + "values": [ + 16 + ] + }, + "warp_tile_n": { + "values": [ + 16 + ] + }, + "warp_tile_k": { + "values": [ + 16,32 + ] + } + }, "trait_config": { "pipeline": { "values": [ @@ -86,5 +86,6 @@ false ] } - } + }, + "k_block_per_cu": 2 } \ No newline at end of file diff --git a/tile_engine/ops/gemm_preshuffle/configs/user_provided_config.json b/tile_engine/ops/gemm_preshuffle/configs/user_provided_config.json index a9937698d3..c0fc1f6cf8 100644 --- a/tile_engine/ops/gemm_preshuffle/configs/user_provided_config.json +++ b/tile_engine/ops/gemm_preshuffle/configs/user_provided_config.json @@ -80,7 +80,8 @@ "persistent": { "values": [ false - ] + ] } - } + }, + "k_block_per_cu": 8 } \ No newline at end of file diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py index 87b826467b..7734cb3a5e 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py @@ -296,7 +296,9 @@ class GemmPreshuffleKernelBuilder: pipeline, ) - def _generate_kernel_instance(self, tile_config, trait_combo, is_header=True): + def _generate_kernel_instance( + self, tile_config, trait_combo, k_block_per_cu, is_header=True + ): """Generate a single kernel instance""" ( pipeline, @@ -531,7 +533,7 @@ struct SelectedKernel {{ }} // Launch kernel - constexpr int kBlockPerCu = 1; + constexpr int kBlockPerCu = {k_block_per_cu}; ave_time = ck_tile::launch_kernel( stream, ck_tile::make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); @@ -575,6 +577,7 @@ struct SelectedKernel {{ tile_configs = self._get_tile_configs() trait_combos = self._generate_trait_combinations() + k_block_per_cu = self.config.get("k_block_per_cu") # Prepare work items for parallel processing work_items = [] @@ -584,6 +587,7 @@ struct SelectedKernel {{ ( tile_config, trait_combo, + k_block_per_cu, self.working_path, self.datatype, self.layout, @@ -675,14 +679,14 @@ struct SelectedKernel {{ def _generate_single_kernel_individual(work_item): """Worker function to generate a single individual kernel file""" - tile_config, trait_combo, working_path, datatype, layout = work_item + tile_config, trait_combo, k_block_per_cu, working_path, datatype, layout = work_item # Create a temporary builder instance for this worker builder = GemmPreshuffleKernelBuilder(working_path, datatype, layout) try: kernel_name, instance_code = builder._generate_kernel_instance( - tile_config, trait_combo + tile_config, trait_combo, k_block_per_cu ) # Create simplified filename without the "gemm_" prefix @@ -804,9 +808,11 @@ def main(): trait_parts[6] == "True", # persistent ) + k_block_per_cu = builder.config.get("k_block_per_cu") + # Generate the kernel kernel_name, instance_code = builder._generate_kernel_instance( - tile_config, trait_combo + tile_config, trait_combo, k_block_per_cu ) # Write the file