Files
composable_kernel/test/ck_tile/flatmm/CMakeLists.txt
Aviral Goel 1a4aa7fd89 [rocm-libraries] ROCm/rocm-libraries#5082 (commit 9313659)
ck_tile: add gtest unit tests for MX flatmm (gfx950)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

- Add correctness unit tests for the MX-format flatmm kernel
(`example/ck_tile/18_flatmm/mxgemm`) under `test/ck_tile/flatmm/`
- Tests cover all five dtype combinations: FP4×FP4, FP8×FP8, FP6×FP6,
FP8×FP4, FP4×FP8
- Tests cover all four kernel dispatch paths (the `has_hot_loop` ×
`tail_num` product):
  - `has_hot_loop=false, tail=ODD` (K=256, num_loop=1)
  - `has_hot_loop=false, tail=EVEN` (K=512, num_loop=2)
  - `has_hot_loop=true, tail=ODD` (K=768, num_loop=3)
  - `has_hot_loop=true, tail=EVEN` (K=1024, num_loop=4)
- Remove unsupported `-split_k` CLI option from
`tile_example_mx_flatmm`; the pre-shuffled B layout is incompatible with
K-splitting and the option silently produced wrong results

## Changes

**New files (`test/ck_tile/flatmm/`):**
- `CMakeLists.txt` — builds 20 kernel instances (5 dtype combinations ×
2 hot-loop settings × 2 tail types) as a shared OBJECT library, links
them into 5 per-dtype test executables; forwards
`-DCK_TILE_USE_OCP_FP8` when `CK_USE_OCP_FP8` is ON
- `test_mx_flatmm_base.hpp` — base test fixture with
`run_test_with_validation(M, N, K, kbatch=1)`
- `test_mx_flatmm_fixtures.hpp` — concrete `TestMXFlatmm` typed test
class and type aliases
- `test_mx_flatmm_fp{4fp4,8fp8,6fp6,8fp4,4fp8}.cpp` — per-dtype
`TYPED_TEST_SUITE` files

**Modified files:**
- `example/ck_tile/18_flatmm/mxgemm/mx_flatmm_arch_traits.hpp` — moved
`preShuffleWeight` here (was in `mx_flatmm.cpp`) so it is includeable by
both the example and the tests
- `example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp` / `run_mx_flatmm.inc`
— removed `-split_k` CLI arg, hardcoded `k_batch=1`, fixed `k_split`
formula, updated call sites after `preShuffleWeight` move
- `test/ck_tile/CMakeLists.txt` — added `add_subdirectory(flatmm)`
2026-03-11 22:47:59 +00:00

80 lines
3.3 KiB
CMake

# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# Compile options shared by the MX flatmm kernel-instance object library and
# the per-dtype test executables.
set(TEST_FLATMM_COMPILE_OPTIONS)
# The SHELL: prefix keeps "-mllvm <arg>" together as a single unit. Without it,
# CMake's de-duplication of a target's final compile options could drop this
# "-mllvm" whenever another "-mllvm" flag reaches the same target, leaving
# "-enable-noalias-to-md-conversion=0" as a dangling, unknown driver option.
list(APPEND TEST_FLATMM_COMPILE_OPTIONS "SHELL:-mllvm -enable-noalias-to-md-conversion=0")
# Define CK_TILE_USE_OCP_FP8 when the build enables CK_USE_OCP_FP8.
if(CK_USE_OCP_FP8)
    list(APPEND TEST_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
if(GPU_TARGETS MATCHES "gfx95")
    # Locate the MX flatmm example relative to this file instead of via
    # CMAKE_SOURCE_DIR, so the path stays correct when CK is configured as a
    # subproject of a larger build (this file lives in test/ck_tile/flatmm/).
    get_filename_component(MXGEMM_EXAMPLE_DIR
        "${CMAKE_CURRENT_SOURCE_DIR}/../../../example/ck_tile/18_flatmm/mxgemm"
        ABSOLUTE)

    # Single source of truth for the A/B data-type combinations under test.
    # Each entry drives both kernel-instance generation and one test executable,
    # named by lowercasing and dropping the "x" (e.g. FP4xFP4 -> fp4fp4 ->
    # test_tile_mx_flatmm_fp4fp4 built from test_mx_flatmm_fp4fp4.cpp).
    set(MX_FLATMM_DATA_TYPES FP4xFP4 FP8xFP8 FP6xFP6 FP8xFP4 FP4xFP8)

    # Generate the 20 kernel instance .cpp files:
    #   5 dtype combos x PERSISTENT{false} x SPLIT_K{false}
    #   x HAS_HOT_LOOP{false,true} x TAIL_NUMBER{ODD,EVEN}.
    # Generation is inlined here (rather than calling mx_flatmm_instance_generate)
    # so that configure_file paths resolve correctly from this directory.
    # configure_file(@ONLY) expands @VAR@ references in the template from the
    # variables in scope here, so the names below (including the loop
    # variables) are assumed to be what mx_flatmm_instance.cpp.in expects —
    # do not rename them without checking the template.
    # SPLIT_K=true instances are omitted: split-K is confirmed broken at the
    # kernel level for all dtype combinations and is not tested.
    set(C_DATA_TYPE FP16)
    set(A_LAYOUT ROW)
    set(B_LAYOUT COL)
    set(C_LAYOUT ROW)
    set(ARCH MXFlatmm_GFX950_)
    set(FLATMM_INSTANCE_FILES)
    foreach(PERSISTENT false)
        foreach(DATA_TYPE IN LISTS MX_FLATMM_DATA_TYPES)
            # Split "FP8xFP4" into A_DATA_TYPE=FP8 and B_DATA_TYPE=FP4.
            string(REPLACE "x" ";" DATA_TYPE_AB ${DATA_TYPE})
            list(GET DATA_TYPE_AB 0 A_DATA_TYPE)
            list(GET DATA_TYPE_AB 1 B_DATA_TYPE)
            set(MXFLATMM_ARCH_TRAITS "${ARCH}${A_DATA_TYPE}${B_DATA_TYPE}_Traits")
            foreach(SPLIT_K false)
                foreach(HAS_HOT_LOOP false true)
                    foreach(TAIL_NUMBER ODD EVEN)
                        set(KERNEL_FILE instance_${ARCH}${DATA_TYPE}_${PERSISTENT}_${SPLIT_K}_${HAS_HOT_LOOP}_${TAIL_NUMBER}.cpp)
                        string(TOLOWER ${KERNEL_FILE} KERNEL_FILE)
                        configure_file(
                            ${MXGEMM_EXAMPLE_DIR}/mx_flatmm_instance.cpp.in
                            ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE}
                            @ONLY)
                        list(APPEND FLATMM_INSTANCE_FILES ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE})
                    endforeach()
                endforeach()
            endforeach()
        endforeach()
    endforeach()

    # Compile the 20 kernel instances once into an object library,
    # shared across all 5 test executables to avoid redundant GPU compilation.
    add_library(mx_flatmm_test_instances OBJECT ${FLATMM_INSTANCE_FILES})
    target_include_directories(mx_flatmm_test_instances PRIVATE
        ${MXGEMM_EXAMPLE_DIR}
    )
    target_compile_options(mx_flatmm_test_instances PRIVATE ${TEST_FLATMM_COMPILE_OPTIONS})

    # One gtest executable per data-type combination, each linking the shared
    # instance objects. Target names are collected for the umbrella target.
    set(TEST_FLATMM_TARGETS)
    foreach(DATA_TYPE IN LISTS MX_FLATMM_DATA_TYPES)
        string(TOLOWER "${DATA_TYPE}" DTYPE)
        string(REPLACE "x" "" DTYPE "${DTYPE}")
        set(TEST_TARGET test_tile_mx_flatmm_${DTYPE})
        add_gtest_executable(${TEST_TARGET}
            test_mx_flatmm_${DTYPE}.cpp
        )
        target_include_directories(${TEST_TARGET} PRIVATE
            ${CMAKE_CURRENT_SOURCE_DIR}
            ${MXGEMM_EXAMPLE_DIR}
        )
        target_compile_options(${TEST_TARGET} PRIVATE ${TEST_FLATMM_COMPILE_OPTIONS})
        target_link_libraries(${TEST_TARGET} PRIVATE mx_flatmm_test_instances)
        list(APPEND TEST_FLATMM_TARGETS ${TEST_TARGET})
    endforeach()

    # Umbrella target to build all flatmm tests at once.
    add_custom_target(test_tile_mx_flatmm_all)
    add_dependencies(test_tile_mx_flatmm_all ${TEST_FLATMM_TARGETS})
else()
    message(DEBUG "Skipping ck_tile flatmm tests for current target")
endif()