diff --git a/test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt b/test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt index aa1a2d2d1c..4acab26c41 100644 --- a/test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt +++ b/test/ck_tile/gemm_streamk_tile_engine/CMakeLists.txt @@ -280,14 +280,18 @@ message(STATUS "Building StreamK GEMM tile engine tests for GPU targets: ${GEMM_ # All supported data types and layouts for comprehensive testing # Note: fp64 not included (no MFMA hardware support) set(TEST_DATATYPES "fp16;bf16") -set(TEST_LAYOUTS "rcr;rrr;ccr;crr") +# Temporarily only test rcr and crr +# set(TEST_LAYOUTS "rcr;rrr;ccr;crr") +set(TEST_LAYOUTS "rcr;crr") # ============================================================================ # Test Target Generation - Datatype-Specific Categories # ============================================================================ # 1. SMOKE TESTS: Test for basic functionality with data types (fp8, bf8, fp16, bf16) -set(SMALL_DATATYPES "fp16;bf16;fp8;bf8") +# Temporarily only consider fp16 +# set(SMALL_DATATYPES "fp16;bf16;fp8;bf8") +set(SMALL_DATATYPES "fp16") set(SIXTEEN_BIT_DATATYPES "fp16;bf16") set(EIGHT_BIT_DATATYPES "fp8;bf8") set(LARGE_TILES "256,256,32") diff --git a/test/ck_tile/gemm_streamk_tile_engine/generate_configs.py b/test/ck_tile/gemm_streamk_tile_engine/generate_configs.py index 0f2673c6dd..2795303684 100644 --- a/test/ck_tile/gemm_streamk_tile_engine/generate_configs.py +++ b/test/ck_tile/gemm_streamk_tile_engine/generate_configs.py @@ -23,7 +23,9 @@ class TileConfig: warp_k: List[int] = field(default_factory=lambda: [1]) warp_tile_m: List[int] = field(default_factory=lambda: [16, 32]) warp_tile_n: List[int] = field(default_factory=lambda: [16, 32]) - warp_tile_k: List[int] = field(default_factory=lambda: [8, 16, 32]) + # Temporarily only consider 16 for warp_tile_k + # warp_tile_k: List[int] = field(default_factory=lambda: [8, 16, 32]) + warp_tile_k: List[int] = field(default_factory=lambda: [16]) def to_dict(self) -> Dict: return {k: {"values": v} for k, v in asdict(self).items()} @@ -33,7 +35,9 @@ class TileConfig: class TraitConfig: """Represents the Trait Config section of a Tile Engine config""" - pipeline: List[str] = field(default_factory=lambda: ["compv3", "mem"]) + # Temporarily only consider compv3 + # pipeline: List[str] = field(default_factory=lambda: ["compv3", "mem"]) + pipeline: List[str] = field(default_factory=lambda: ["compv3"]) epilogue: List[str] = field(default_factory=lambda: ["cshuffle"]) scheduler: List[str] = field(default_factory=lambda: ["intrawave"]) pad_m: List[bool] = field(default_factory=lambda: [False]) @@ -67,21 +71,27 @@ class TestVariant(Enum): 0, ["atomic"], [True, False], - ["fp16", "bf16", "fp8", "bf8"], + # Temporarily only run fp16 tests + # ["fp16", "bf16", "fp8", "bf8"], + ["fp16"], "Stream-K atomic smoke tests", ) REDUCTION_SMOKE = ( 2, ["linear", "tree"], [True, False], - ["fp16", "bf16", "fp8", "bf8"], + # Temporarily only run fp16 tests + # ["fp16", "bf16", "fp8", "bf8"], + ["fp16"], "Stream-K reduction smoke tests", ) EXTENDED = ( 3, ["atomic"], [True, False], - ["fp16", "bf16", "fp8", "bf8"], + # Temporarily only run fp16 tests + # ["fp16", "bf16", "fp8", "bf8"], + ["fp16"], "Stream-K extended smoke tests", )