From 65b57a61c8ec2c0a4bf785783ee57b68c8a50d4d Mon Sep 17 00:00:00 2001 From: Philip Maybank Date: Tue, 9 Sep 2025 15:04:20 +0100 Subject: [PATCH] use values in json config, change valid warp configurations --- tile_engine/ops/gemm/configs/gfx120x_config.json | 15 ++++++++------- tile_engine/ops/gemm/validation_utils.py | 3 ++- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tile_engine/ops/gemm/configs/gfx120x_config.json b/tile_engine/ops/gemm/configs/gfx120x_config.json index 900cc16d1d..11a956c656 100644 --- a/tile_engine/ops/gemm/configs/gfx120x_config.json +++ b/tile_engine/ops/gemm/configs/gfx120x_config.json @@ -9,14 +9,16 @@ ] }, "tile_n": { - "max": 256, - "min": 64, - "step": 64 + "values": [ + 256, + 128 + ] }, "tile_k": { - "max": 256, - "min": 64, - "step": 64 + "values": [ + 256, + 128 + ] }, "warp_m": { "values": [ @@ -57,7 +59,6 @@ "pipeline": { "values": [ "compv3", - "compv4", "mem" ] }, diff --git a/tile_engine/ops/gemm/validation_utils.py b/tile_engine/ops/gemm/validation_utils.py index 7367f2446d..717521e3e2 100644 --- a/tile_engine/ops/gemm/validation_utils.py +++ b/tile_engine/ops/gemm/validation_utils.py @@ -157,7 +157,8 @@ def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) -> def validate_warp_configuration(warp_m: int, warp_n: int, warp_k: int) -> bool: """Validate warp configuration.""" - return (warp_m, warp_n, warp_k) in [(1, 4, 1), (2, 2, 1), (4, 1, 1)] + # return (warp_m, warp_n, warp_k) in [(1, 4, 1), (2, 2, 1), (4, 1, 1)] + return (warp_m, warp_n, warp_k) in [(2, 4, 1), (1, 8, 1), (8, 1, 1), (4, 2, 1)] def validate_dimension_alignment(