From 9510d3df1faefc2dc3b1bbb5ee1b239be4450a2b Mon Sep 17 00:00:00 2001 From: Yanxing-Shi Date: Mon, 26 May 2025 04:33:44 +0000 Subject: [PATCH] default config test & fix codegen bug --- tile_engine/ops/gemm/codegen_utils.py | 6 +++--- tile_engine/ops/gemm/configs/user_provided_config.json | 5 ++--- tile_engine/ops/gemm/gemm_instance_builder.py | 4 ++-- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py index da71578f20..a8955cec91 100644 --- a/tile_engine/ops/gemm/codegen_utils.py +++ b/tile_engine/ops/gemm/codegen_utils.py @@ -168,19 +168,19 @@ def BOOL_MAP(b_): return {True: 'true', False: 'false'}[bool(b_)] warp_tile_supported_combinations = { "gfx90a": { 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], - 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + 'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]], 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]] }, "gfx942": { 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], - 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + 'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]] }, "gfx950": { 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], - 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + 'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]], 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]] } diff --git a/tile_engine/ops/gemm/configs/user_provided_config.json b/tile_engine/ops/gemm/configs/user_provided_config.json index 312cefb4ed..6a6e726e40 100644 --- a/tile_engine/ops/gemm/configs/user_provided_config.json +++ b/tile_engine/ops/gemm/configs/user_provided_config.json @@ -34,17 +34,16 @@ "tile_config": { "tile_m": { "values": [ - 256 + 128 ] }, "tile_n": { "values": [ - 256 + 128 ] }, "tile_k": { "values": [ - 64, 32 ] }, diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 1226169cf2..23675e73e8 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -526,10 +526,10 @@ struct GemmDispatcher { for trait, tile_valid_params in self.valid_trait_tile_combinations.items(): content += f""" kernel_map["{trait}"] = {{""" - for i, tile in enumerate(tile_valid_params): + for _, tile in enumerate(tile_valid_params): for j in range(len(tile)): tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k = tile[j] - content += f"""[&](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ """ + content += f"""[=](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ """ content += f""" if(structured_sparsity){{ // SMFMA""" sparse = self.config.problem.datatype_map['matrix_a'] == 'fp16' and \