diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py index da71578f20..a8955cec91 100644 --- a/tile_engine/ops/gemm/codegen_utils.py +++ b/tile_engine/ops/gemm/codegen_utils.py @@ -168,19 +168,19 @@ def BOOL_MAP(b_): return {True: 'true', False: 'false'}[bool(b_)] warp_tile_supported_combinations = { "gfx90a": { 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], - 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + 'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]], 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]] }, "gfx942": { 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], - 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + 'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]] }, "gfx950": { 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], - 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + 'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]], 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]] } diff --git a/tile_engine/ops/gemm/configs/user_provided_config.json b/tile_engine/ops/gemm/configs/user_provided_config.json index 312cefb4ed..6a6e726e40 100644 --- a/tile_engine/ops/gemm/configs/user_provided_config.json +++ b/tile_engine/ops/gemm/configs/user_provided_config.json @@ -34,17 +34,16 @@ "tile_config": { "tile_m": { "values": [ - 256 + 128 ] }, "tile_n": { "values": [ - 256 + 128 ] }, "tile_k": { "values": [ - 64, 32 ] }, diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 1226169cf2..23675e73e8 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -526,10 +526,10 @@ struct GemmDispatcher { for trait, tile_valid_params in self.valid_trait_tile_combinations.items(): content += f""" kernel_map["{trait}"] = {{""" - for i, tile in enumerate(tile_valid_params): + for _, tile in enumerate(tile_valid_params): for j in range(len(tile)): tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k = tile[j] - content += f"""[&](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ """ + content += f"""[=](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ """ content += f""" if(structured_sparsity){{ // SMFMA""" sparse = self.config.problem.datatype_map['matrix_a'] == 'fp16' and \