From a1245351330954de1d4363edd7551586a6efc036 Mon Sep 17 00:00:00 2001
From: Aviral Goel
Date: Wed, 11 Mar 2026 18:46:58 -0400
Subject: [PATCH] ck_tile: add gtest unit tests for MX flatmm (gfx950) (#5082)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary
- Add correctness unit tests for the MX-format flatmm kernel
  (`example/ck_tile/18_flatmm/mxgemm`) under `test/ck_tile/flatmm/`
- Tests cover all five dtype combinations: FP4×FP4, FP8×FP8, FP6×FP6, FP8×FP4, FP4×FP8
- Tests cover all four kernel dispatch paths (the `has_hot_loop` × `tail_num` product):
  - `has_hot_loop=false, tail=ODD` (K=256, num_loop=1)
  - `has_hot_loop=false, tail=EVEN` (K=512, num_loop=2)
  - `has_hot_loop=true, tail=ODD` (K=768, num_loop=3)
  - `has_hot_loop=true, tail=EVEN` (K=1024, num_loop=4)
- Remove the unsupported `-split_k` CLI option from `tile_example_mx_flatmm`: the
  pre-shuffled B layout is incompatible with K-splitting, and the option silently
  produced wrong results

## Changes

**New files (`test/ck_tile/flatmm/`):**
- `CMakeLists.txt` — builds the 20 kernel instances once as an OBJECT library shared
  by the 5 per-dtype test executables; forwards `-DCK_TILE_USE_OCP_FP8` when
  `CK_USE_OCP_FP8` is ON
- `test_mx_flatmm_base.hpp` — base test fixture with `run_test_with_validation(M, N, K, kbatch=1)`
- `test_mx_flatmm_fixtures.hpp` — concrete `TestMXFlatmm` typed test class and type aliases
- `test_mx_flatmm_fp{4fp4,8fp8,6fp6,8fp4,4fp8}.cpp` — per-dtype `TYPED_TEST_SUITE` files

**Modified files:**
- `example/ck_tile/18_flatmm/mxgemm/mx_flatmm_arch_traits.hpp` — moved `preShuffleWeight`
  here (was in `mx_flatmm.cpp`) so it is includable by both the example and the tests
- `example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp` / `run_mx_flatmm.inc` — removed the
  `-split_k` CLI arg, hardcoded `k_batch=1`, fixed the `k_split` formula, updated call
  sites after the `preShuffleWeight` move
- `test/ck_tile/CMakeLists.txt` — added `add_subdirectory(flatmm)`

---------

Co-authored-by: Thomas Ning
---
 .../ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp    |  88 ++----
 .../mxgemm/mx_flatmm_arch_traits.hpp          |  41 +++
 .../18_flatmm/mxgemm/run_mx_flatmm.inc        |   4 +-
 test/ck_tile/CMakeLists.txt                   |   1 +
 test/ck_tile/flatmm/CMakeLists.txt            |  79 ++++++
 test/ck_tile/flatmm/test_mx_flatmm_base.hpp   | 251 ++++++++++++++++++
 .../flatmm/test_mx_flatmm_fixtures.hpp        |  20 ++
 test/ck_tile/flatmm/test_mx_flatmm_fp4fp4.cpp |  41 +++
 test/ck_tile/flatmm/test_mx_flatmm_fp4fp8.cpp |  40 +++
 test/ck_tile/flatmm/test_mx_flatmm_fp6fp6.cpp |  40 +++
 test/ck_tile/flatmm/test_mx_flatmm_fp8fp4.cpp |  40 +++
 test/ck_tile/flatmm/test_mx_flatmm_fp8fp8.cpp |  40 +++
 .../test_gemm_streamk_simple.cpp              |   7 +-
 .../grouped_gemm/test_grouped_gemm_util.hpp   |   8 +-
 14 files changed, 627 insertions(+), 73 deletions(-)
 create mode 100644 test/ck_tile/flatmm/CMakeLists.txt
 create mode 100644 test/ck_tile/flatmm/test_mx_flatmm_base.hpp
 create mode 100644 test/ck_tile/flatmm/test_mx_flatmm_fixtures.hpp
 create mode 100644 test/ck_tile/flatmm/test_mx_flatmm_fp4fp4.cpp
 create mode 100644 test/ck_tile/flatmm/test_mx_flatmm_fp4fp8.cpp
 create mode 100644 test/ck_tile/flatmm/test_mx_flatmm_fp6fp6.cpp
 create mode 100644 test/ck_tile/flatmm/test_mx_flatmm_fp8fp4.cpp
 create mode 100644 test/ck_tile/flatmm/test_mx_flatmm_fp8fp8.cpp

diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp
index 3d51fd9907..702a89aa25 100644
--- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp
+++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp
@@ -43,7 +43,6 @@
float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf, ck_tile::index_t stride_A, ck_tile::index_t stride_B, ck_tile::index_t stride_C, - ck_tile::index_t kbatch, ScaleA scale_a, ScaleB scale_b, int n_warmup, @@ -55,7 +54,7 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf, b_shuffle_dev_buf.GetDeviceBuffer(), {}, c_dev_buf.GetDeviceBuffer(), - kbatch, + 1, M, N, K, @@ -90,8 +89,8 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf, using BaseFlatmmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1; - const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile; - const ck_tile::index_t k_split = (K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile; + const ck_tile::index_t k_grain = FlatmmConfig::K_Tile; + const ck_tile::index_t k_split = (K + k_grain - 1) / k_grain * k_grain; const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(k_split); const bool has_hot_loop = BaseFlatmmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseFlatmmPipeline::GetBlockLoopTailNum(num_loop); @@ -100,29 +99,24 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf, [&](auto has_hot_loop_, auto tail_num_) { constexpr auto has_hot_loop_v = has_hot_loop_.value; constexpr auto tail_num_v = tail_num_.value; - auto invoke_splitk_path = [&](auto split_k_) { - return mx_flatmm_calc( - args, - ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); - }; - return (args.k_batch == 1) ? invoke_splitk_path(std::false_type{}) - : invoke_splitk_path(std::true_type{}); + return mx_flatmm_calc( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); }, has_hot_loop, tail_num); @@ -166,7 +160,6 @@ auto create_args(int argc, char* argv[]) .insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") - .insert("split_k", "1", "splitK value") .insert("init", "0", "0:random, 1:constant(1)") .insert("persistent", "0", "0: no persistent, 1: persistent kernel") .insert("warp_tile", "0", "0: 16x16x128 on gfx950."); @@ -174,45 +167,6 @@ auto create_args(int argc, char* argv[]) return std::make_tuple(result, arg_parser); } -template -auto preShuffleWeight(ck_tile::HostTensor& src) -{ - auto src_lengths = src.get_lengths(); - const int K = src_lengths[0]; - const int N = src_lengths[1]; - constexpr int packed_size = ck_tile::numeric_traits::PackedSize; - int KPack = - std::is_same_v ? 
32 : 16 * packed_size; // fp4/fp6:32 or fp8:16
-
-    int KLane = ck_tile::get_warp_size() / NLane;
-    int K0    = K / (KLane * KPack);
-
-    ck_tile::HostTensor shuffled(ck_tile::HostTensorDescriptor({N * K}, {1}));
-
-    // K -> K0 KLane KPack
-    // N -> N0 NLane
-    // N, K -> N0 K0 KLane NLane KPack
-    for(int n = 0; n < N; ++n)
-    {
-        for(int k = 0; k < K; k += packed_size)
-        {
-            int n0 = n / NLane;
-            int n1 = n % NLane;
-
-            int k0    = k / (KLane * KPack);
-            int tempk = k % (KLane * KPack);
-            int k1    = tempk / KPack;
-            int k2    = tempk % KPack;
-
-            int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
-                              k1 * KPack * NLane + n1 * KPack + k2;
-
-            shuffled(outputIndex) = src(k, n);
-        }
-    }
-    return shuffled;
-}
-
 #include "run_mx_flatmm.inc"
 
 int run_mx_flatmm_example(const ck_tile::ArgParser& arg_parser)
diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_arch_traits.hpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_arch_traits.hpp
index d8c3d41c5c..b496b37686 100644
--- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_arch_traits.hpp
+++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_arch_traits.hpp
@@ -70,6 +70,47 @@ struct MXFlatmmArchTraits
 
     static constexpr int GetNLane() { return Config::N_Warp_Tile; }
 
+    template <typename BDataType>
+    static auto preShuffleWeight(ck_tile::HostTensor<BDataType>& src)
+    {
+        constexpr ck_tile::index_t NLane = Config::N_Warp_Tile;
+        auto src_lengths          = src.get_lengths();
+        const int K               = src_lengths[0];
+        const int N               = src_lengths[1];
+        constexpr int packed_size = ck_tile::numeric_traits<BDataType>::PackedSize;
+        int KPack                 = std::is_same_v<BDataType, ck_tile::pk_fp6x16_t>
+                                        ? 32
+                                        : 16 * packed_size; // fp4/fp6:32 or fp8:16
+
+        int KLane = ck_tile::get_warp_size() / NLane;
+        int K0    = K / (KLane * KPack);
+
+        ck_tile::HostTensor<BDataType> shuffled(ck_tile::HostTensorDescriptor({N * K}, {1}));
+
+        // K -> K0 KLane KPack
+        // N -> N0 NLane
+        // N, K -> N0 K0 KLane NLane KPack
+        for(int n = 0; n < N; ++n)
+        {
+            for(int k = 0; k < K; k += packed_size)
+            {
+                int n0 = n / NLane;
+                int n1 = n % NLane;
+
+                int k0    = k / (KLane * KPack);
+                int tempk = k % (KLane * KPack);
+                int k1    = tempk / KPack;
+                int k2    = tempk % KPack;
+
+                int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
+                                  k1 * KPack * NLane + n1 * KPack + k2;
+
+                shuffled(outputIndex) = src(k, n);
+            }
+        }
+        return shuffled;
+    }
+
     template
     static auto preShuffleScale(ck_tile::HostTensor& src)
     {
diff --git a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc
index 0e49c6452a..2779dc6208 100644
--- a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc
+++ b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc
@@ -32,7 +32,6 @@ int run_mx_flatmm_with_layouts(const ck_tile::ArgParser& arg_parser,
     ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
     ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
 
-    ck_tile::index_t kbatch      = arg_parser.get_int("split_k");
     ck_tile::index_t init_method = arg_parser.get_int("init");
     ck_tile::index_t n_warmup    = arg_parser.get_int("warmup");
     ck_tile::index_t n_repeat    = arg_parser.get_int("repeat");
@@ -106,7 +105,7 @@ int run_mx_flatmm_with_layouts(const ck_tile::ArgParser& arg_parser,
         }
     }
 
-    const auto b_shuffled_host  = preShuffleWeight(b_origin_host);
+    const auto b_shuffled_host  = MXFlatmmArchTraits::preShuffleWeight(b_origin_host);
     const auto scale_a_shuffled = MXFlatmmArchTraits::template preShuffleScale(scale_a);
     const auto scale_b_shuffled = MXFlatmmArchTraits::template preShuffleScale(scale_b);
 
@@ -151,7 +150,6 @@ int run_mx_flatmm_with_layouts(const
ck_tile::ArgParser& arg_parser,
                             stride_A,
                             stride_B,
                             stride_C,
-                            kbatch,
                             scale_a_dev_ptr,
                             scale_b_dev_ptr,
                             n_warmup,
diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt
index ef2897a8ef..320e5b1e91 100644
--- a/test/ck_tile/CMakeLists.txt
+++ b/test/ck_tile/CMakeLists.txt
@@ -57,6 +57,7 @@ add_subdirectory(add_rmsnorm2d_rdquant)
 # add_subdirectory(layernorm2d)
 # add_subdirectory(rmsnorm2d)
 add_subdirectory(gemm_block_scale)
+add_subdirectory(flatmm)
 add_subdirectory(gemm_mx)
 add_subdirectory(utility)
 add_subdirectory(warp_gemm)
diff --git a/test/ck_tile/flatmm/CMakeLists.txt b/test/ck_tile/flatmm/CMakeLists.txt
new file mode 100644
index 0000000000..0568ba6b15
--- /dev/null
+++ b/test/ck_tile/flatmm/CMakeLists.txt
@@ -0,0 +1,79 @@
+# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+# SPDX-License-Identifier: MIT
+
+set(TEST_FLATMM_COMPILE_OPTIONS)
+list(APPEND TEST_FLATMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
+
+if(CK_USE_OCP_FP8)
+    list(APPEND TEST_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
+endif()
+
+if(GPU_TARGETS MATCHES "gfx95")
+    set(MXGEMM_EXAMPLE_DIR ${CMAKE_SOURCE_DIR}/example/ck_tile/18_flatmm/mxgemm)
+
+    # Generate the 20 kernel instance .cpp files.
+    # We inline the generation here (rather than calling mx_flatmm_instance_generate)
+    # so that configure_file paths resolve correctly from this directory.
+    set(C_DATA_TYPE FP16)
+    set(A_LAYOUT ROW)
+    set(B_LAYOUT COL)
+    set(C_LAYOUT ROW)
+
+    set(FLATMM_INSTANCE_FILES)
+    foreach(PERSISTENT false)
+        foreach(DATA_TYPE FP4xFP4 FP8xFP8 FP6xFP6 FP8xFP4 FP4xFP8)
+            string(REPLACE "x" ";" DATA_TYPE_AB ${DATA_TYPE})
+            list(GET DATA_TYPE_AB 0 A_DATA_TYPE)
+            list(GET DATA_TYPE_AB 1 B_DATA_TYPE)
+            set(ARCH MXFlatmm_GFX950_)
+            set(MXFLATMM_ARCH_TRAITS "${ARCH}${A_DATA_TYPE}${B_DATA_TYPE}_Traits")
+            foreach(SPLIT_K false)
+                foreach(HAS_HOT_LOOP false true)
+                    foreach(TAIL_NUMBER ODD EVEN)
+                        set(KERNEL_FILE instance_${ARCH}${DATA_TYPE}_${PERSISTENT}_${SPLIT_K}_${HAS_HOT_LOOP}_${TAIL_NUMBER}.cpp)
+                        string(TOLOWER ${KERNEL_FILE} KERNEL_FILE)
+                        configure_file(
+                            ${MXGEMM_EXAMPLE_DIR}/mx_flatmm_instance.cpp.in
+                            ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE}
+                            @ONLY)
+                        list(APPEND FLATMM_INSTANCE_FILES ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE})
+                    endforeach()
+                endforeach()
+            endforeach()
+        endforeach()
+    endforeach()
+
+    # Compile the 20 kernel instances once into an object library,
+    # shared across all 5 test executables to avoid redundant GPU compilation.
+    # SPLIT_K=true instances are omitted: split-K is confirmed broken at the
+    # kernel level for all dtype combinations and is not tested.
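+    # For reference, one iteration of the loops above configures a file named e.g.
+    #   instance_mxflatmm_gfx950_fp4xfp4_false_false_true_odd.cpp
+    # (DATA_TYPE=FP4xFP4, PERSISTENT=false, SPLIT_K=false, HAS_HOT_LOOP=true, TAIL_NUMBER=ODD).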
+    add_library(mx_flatmm_test_instances OBJECT ${FLATMM_INSTANCE_FILES})
+    target_include_directories(mx_flatmm_test_instances PRIVATE
+        ${MXGEMM_EXAMPLE_DIR}
+    )
+    target_compile_options(mx_flatmm_test_instances PRIVATE ${TEST_FLATMM_COMPILE_OPTIONS})
+
+    foreach(DTYPE fp4fp4 fp8fp8 fp6fp6 fp8fp4 fp4fp8)
+        add_gtest_executable(test_tile_mx_flatmm_${DTYPE}
+            test_mx_flatmm_${DTYPE}.cpp
+        )
+        target_include_directories(test_tile_mx_flatmm_${DTYPE} PRIVATE
+            ${CMAKE_CURRENT_SOURCE_DIR}
+            ${MXGEMM_EXAMPLE_DIR}
+        )
+        target_compile_options(test_tile_mx_flatmm_${DTYPE} PRIVATE ${TEST_FLATMM_COMPILE_OPTIONS})
+        target_link_libraries(test_tile_mx_flatmm_${DTYPE} PRIVATE mx_flatmm_test_instances)
+    endforeach()
+
+    # Umbrella target to build all flatmm tests at once
+    add_custom_target(test_tile_mx_flatmm_all)
+    add_dependencies(test_tile_mx_flatmm_all
+        test_tile_mx_flatmm_fp4fp4
+        test_tile_mx_flatmm_fp8fp8
+        test_tile_mx_flatmm_fp6fp6
+        test_tile_mx_flatmm_fp8fp4
+        test_tile_mx_flatmm_fp4fp8
+    )
+else()
+    message(DEBUG "Skipping ck_tile flatmm tests for current target")
+endif()
diff --git a/test/ck_tile/flatmm/test_mx_flatmm_base.hpp b/test/ck_tile/flatmm/test_mx_flatmm_base.hpp
new file mode 100644
index 0000000000..3cfd861d0a
--- /dev/null
+++ b/test/ck_tile/flatmm/test_mx_flatmm_base.hpp
@@ -0,0 +1,251 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <cstring>
+#include <gtest/gtest.h>
+#include <optional>
+#include <random>
+#include <tuple>
+#include <vector>
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/host.hpp"
+#include "ck_tile/host/check_err.hpp"
+#include "ck_tile/host/reference/reference_gemm.hpp"
+#include "ck_tile/ops/flatmm.hpp"
+#include "ck_tile/ops/gemm.hpp"
+
+#include "mx_flatmm.hpp"
+
+// Base class for MX Flatmm unit tests.
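+// Allocates host/device tensors, pre-shuffles B and the scale tensors, launches
+// the kernel once (warmup=0, repeat=1), and validates against ck_tile::reference_mx_gemm.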
+//
+// Tuple layout: <ADataType, BDataType, CDataType, MXFlatmmArchTraits>
+template <typename Tuple>
+class TestMXFlatmmBase : public ::testing::Test
+{
+    protected:
+    using ADataType          = std::tuple_element_t<0, Tuple>;
+    using BDataType          = std::tuple_element_t<1, Tuple>;
+    using CDataType          = std::tuple_element_t<2, Tuple>;
+    using MXFlatmmArchTraits = std::tuple_element_t<3, Tuple>;
+
+    using FlatmmConfig = typename MXFlatmmArchTraits::Config;
+    using AccDataType  = float;
+    using ScaleType    = ck_tile::e8m0_t;
+
+    using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
+    using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck_tile::tensor_layout::gemm::RowMajor;
+
+    static constexpr int ScaleGranularityM = 1;
+    static constexpr int ScaleGranularityN = 1;
+    static constexpr int ScaleGranularityK = 32;
+
+    using ScaleA = ck_tile::FlatmmScalePointer;
+    using ScaleB = ck_tile::FlatmmScalePointer;
+
+    void
+    run_test_with_validation(ck_tile::index_t M,
+                             ck_tile::index_t N,
+                             ck_tile::index_t K,
+                             ck_tile::index_t kbatch = 1,
+                             std::optional<bool> expected_has_hot_loop = std::nullopt,
+                             std::optional<ck_tile::TailNumber> expected_tail_num = std::nullopt)
+    {
+        constexpr int APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
+        constexpr int BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
+
+        ASSERT_EQ(K % ScaleGranularityK, 0) << "K must be a multiple of ScaleGranularityK (32)";
+        ASSERT_EQ(K % APackedSize, 0) << "K must be a multiple of A PackedSize";
+        ASSERT_EQ(K % BPackedSize, 0) << "K must be a multiple of B PackedSize";
+
+        constexpr bool a_row_major = true;
+        constexpr bool b_row_major = false;
+        constexpr bool c_row_major = true;
+
+        const ck_tile::index_t stride_A =
+            ck_tile::get_default_stride(M, K, 0, ck_tile::bool_constant<a_row_major>{});
+        const ck_tile::index_t stride_B =
+            ck_tile::get_default_stride(K, N, 0, ck_tile::bool_constant<b_row_major>{});
+        const ck_tile::index_t stride_C =
+            ck_tile::get_default_stride(M, N, 0, ck_tile::bool_constant<c_row_major>{});
+
+        const auto scale_stride_A = ck_tile::get_default_stride(
+            M / ScaleGranularityM, K / ScaleGranularityK, 0, ck_tile::bool_constant<a_row_major>{});
+        const auto scale_stride_B = ck_tile::get_default_stride(
+            K / ScaleGranularityK, N / ScaleGranularityN, 0, ck_tile::bool_constant<b_row_major>{});
+
+        // Host tensors
+        ck_tile::HostTensor<ADataType> a_host(
+            ck_tile::host_tensor_descriptor(M, K, stride_A, ck_tile::bool_constant<a_row_major>{}));
+        ck_tile::HostTensor<BDataType> b_origin_host(
+            ck_tile::host_tensor_descriptor(K, N, stride_B, ck_tile::bool_constant<b_row_major>{}));
+        ck_tile::HostTensor<CDataType> c_rslt_host(
+            ck_tile::host_tensor_descriptor(M, N, stride_C, ck_tile::bool_constant<c_row_major>{}));
+
+        ck_tile::HostTensor<ScaleType> scale_a(
+            ck_tile::host_tensor_descriptor(M / ScaleGranularityM,
+                                            K / ScaleGranularityK,
+                                            scale_stride_A,
+                                            ck_tile::bool_constant<a_row_major>{}));
+        ck_tile::HostTensor<ScaleType> scale_b(
+            ck_tile::host_tensor_descriptor(K / ScaleGranularityK,
+                                            N / ScaleGranularityN,
+                                            scale_stride_B,
+                                            ck_tile::bool_constant<b_row_major>{}));
+
+        // Initialize data
+        if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp6x16_t>)
+        {
+            // FP6: fill raw bytes with values 1..4 (avoids denormals)
+            auto a_bytes = a_host.get_element_space_size_in_bytes();
+            auto b_bytes = b_origin_host.get_element_space_size_in_bytes();
+            std::vector<uint8_t> buf_a(a_bytes), buf_b(b_bytes);
+            std::mt19937 gen(42);
+            std::uniform_int_distribution<int> dis(1, 4);
+            for(auto& v : buf_a)
+                v = static_cast<uint8_t>(dis(gen));
+            for(auto& v : buf_b)
+                v = static_cast<uint8_t>(dis(gen));
+            memcpy(a_host.data(), buf_a.data(), a_bytes);
+            memcpy(b_origin_host.data(), buf_b.data(), b_bytes);
+            ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_a);
+            ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_b);
+        }
+        else
+        {
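+            // FP4/FP8: a uniform fill is safe for these dtypes; only FP6 needs the
+            // raw-byte workaround above.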
ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host); + ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host); + ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a); + ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b); + } + + // Preshuffle B and scales + const auto b_shuffled_host = MXFlatmmArchTraits::preShuffleWeight(b_origin_host); + const auto scale_a_shuffled = MXFlatmmArchTraits::template preShuffleScale(scale_a); + const auto scale_b_shuffled = MXFlatmmArchTraits::template preShuffleScale(scale_b); + + // Device buffers + ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_shuffled_dev_buf(b_shuffled_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem scale_a_dev_buf(scale_a_shuffled.get_element_space_size_in_bytes()); + ck_tile::DeviceMem scale_b_dev_buf(scale_b_shuffled.get_element_space_size_in_bytes()); + + a_dev_buf.ToDevice(a_host.data()); + b_shuffled_dev_buf.ToDevice(b_shuffled_host.data()); + c_rslt_host.SetZero(); + c_dev_buf.ToDevice(c_rslt_host.data()); + scale_a_dev_buf.ToDevice(scale_a_shuffled.data()); + scale_b_dev_buf.ToDevice(scale_b_shuffled.data()); + + auto scale_a_dev_ptr = ScaleA{static_cast(scale_a_dev_buf.GetDeviceBuffer()), + M / ScaleGranularityM}; + auto scale_b_dev_ptr = ScaleB{static_cast(scale_b_dev_buf.GetDeviceBuffer()), + N / ScaleGranularityN}; + + // Build args + ck_tile::ScaleFlatmmHostArgs args{a_dev_buf.GetDeviceBuffer(), + b_shuffled_dev_buf.GetDeviceBuffer(), + {}, + c_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + {}, + stride_C, + scale_a_dev_ptr, + scale_b_dev_ptr}; + + // Compute hot_loop / tail_num + using FlatmmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using GemmTraits = ck_tile::TileGemmTraits; + using GemmPipelineProblem = ck_tile:: + GemmPipelineProblem; + using BaseFlatmmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1; + + const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile; + const ck_tile::index_t k_split = (K + k_grain - 1) / k_grain * k_grain; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(k_split); + const bool has_hot_loop = BaseFlatmmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseFlatmmPipeline::GetBlockLoopTailNum(num_loop); + + if(expected_has_hot_loop.has_value()) + ASSERT_EQ(has_hot_loop, *expected_has_hot_loop) + << "has_hot_loop mismatch for (M=" << M << ", N=" << N << ", K=" << K << ")"; + if(expected_tail_num.has_value()) + ASSERT_EQ(tail_num, *expected_tail_num) + << "tail_num mismatch for (M=" << M << ", N=" << N << ", K=" << K << ")"; + + // Launch kernel (warmup=0, repeat=1 for correctness testing) + // mx_flatmm_calc is explicitly instantiated in the linked object library; + // suppress the -Wundefined-func-template warning that fires when the + // compiler sees only the forward declaration in mx_flatmm.hpp. +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wundefined-func-template" + BaseFlatmmPipeline::template TailHandler( + [&](auto has_hot_loop_, auto tail_num_) { + constexpr auto has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_num_v = tail_num_.value; + // SplitK (kbatch>1) is excluded: confirmed broken at the kernel level. + // Always dispatch the kbatch=1 (SPLIT_K=false) path. 
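+                // has_hot_loop_v / tail_num_v are the compile-time copies of the
+                // runtime (has_hot_loop, tail_num) pair selected by TailHandler.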
+                mx_flatmm_calc<FlatmmConfig,
+                               ADataType,
+                               BDataType,
+                               ck_tile::tuple<>,
+                               AccDataType,
+                               CDataType,
+                               ALayout,
+                               BLayout,
+                               ck_tile::tuple<>,
+                               CLayout,
+                               ScaleA,
+                               ScaleB,
+                               /*persistent=*/false,
+                               ck_tile::element_wise::PassThrough,
+                               /*split_k=*/false,
+                               has_hot_loop_v,
+                               tail_num_v>(args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
+            },
+            has_hot_loop,
+            tail_num);
+#pragma clang diagnostic pop
+
+        c_dev_buf.FromDevice(c_rslt_host.data());
+
+        // CPU reference
+        ck_tile::HostTensor<CDataType> c_ref(
+            ck_tile::host_tensor_descriptor(M, N, stride_C, ck_tile::bool_constant<c_row_major>{}));
+        c_ref.SetZero();
+
+        ck_tile::reference_mx_gemm(
+            a_host, b_origin_host, c_ref, scale_a, scale_b);
+
+        const float rtol = 1e-2f;
+        const float atol = 1e-2f;
+        EXPECT_TRUE(
+            ck_tile::check_err(c_rslt_host, c_ref, "MX Flatmm result mismatch", rtol, atol));
+    }
+};
diff --git a/test/ck_tile/flatmm/test_mx_flatmm_fixtures.hpp b/test/ck_tile/flatmm/test_mx_flatmm_fixtures.hpp
new file mode 100644
index 0000000000..c4adb3e2da
--- /dev/null
+++ b/test/ck_tile/flatmm/test_mx_flatmm_fixtures.hpp
@@ -0,0 +1,20 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include "test_mx_flatmm_base.hpp"
+#include "mx_flatmm_arch_traits.hpp"
+
+// Convenience type aliases for use in test .cpp files
+using FP4  = ck_tile::pk_fp4_t;
+using FP6  = ck_tile::pk_fp6x16_t;
+using FP8  = ck_tile::fp8_t;
+using FP16 = ck_tile::fp16_t;
+
+// Concrete test fixture — inherits all logic from TestMXFlatmmBase.
+// Tuple layout: <ADataType, BDataType, CDataType, MXFlatmmArchTraits>
+template <typename Tuple>
+class TestMXFlatmm : public TestMXFlatmmBase<Tuple>
+{
+};
diff --git a/test/ck_tile/flatmm/test_mx_flatmm_fp4fp4.cpp b/test/ck_tile/flatmm/test_mx_flatmm_fp4fp4.cpp
new file mode 100644
index 0000000000..46ae98b16b
--- /dev/null
+++ b/test/ck_tile/flatmm/test_mx_flatmm_fp4fp4.cpp
@@ -0,0 +1,41 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck_tile/host.hpp"
+#include <gtest/gtest.h>
+#include "test_mx_flatmm_fixtures.hpp"
+
+// FP4 x FP4 -> FP16
+// N_Tile = 512 (MXfp4_FlatmmConfig16), so N must be a multiple of 512.
+// K must be a multiple of 32 (ScaleGranularityK) and 8 (FP4 PackedSize) -> multiple of 32.
+// clang-format off
+using FP4FP4Types = ::testing::Types<
+    std::tuple<FP4, FP4, FP16, MXFlatmm_GFX950_FP4FP4_Traits>
+>;
+// clang-format on
+
+TYPED_TEST_SUITE(TestMXFlatmm, FP4FP4Types);
+
+// K=256 -> num_loop=1: has_hot_loop=false, tail=Odd
+TYPED_TEST(TestMXFlatmm, SmallMNK)
+{
+    this->run_test_with_validation(128, 512, 256, 1, false, ck_tile::TailNumber::Odd);
+}
+
+// K=512 -> num_loop=2: has_hot_loop=false, tail=Even
+TYPED_TEST(TestMXFlatmm, MediumMNK)
+{
+    this->run_test_with_validation(256, 1024, 512, 1, false, ck_tile::TailNumber::Even);
+}
+
+// K=768 -> num_loop=3: has_hot_loop=true, tail=Odd
+TYPED_TEST(TestMXFlatmm, LargeK_HotLoopOdd)
+{
+    this->run_test_with_validation(128, 512, 768, 1, true, ck_tile::TailNumber::Odd);
+}
+
+// K=1024 -> num_loop=4: has_hot_loop=true, tail=Even
+TYPED_TEST(TestMXFlatmm, LargeK_HotLoopEven)
+{
+    this->run_test_with_validation(128, 512, 1024, 1, true, ck_tile::TailNumber::Even);
+}
diff --git a/test/ck_tile/flatmm/test_mx_flatmm_fp4fp8.cpp b/test/ck_tile/flatmm/test_mx_flatmm_fp4fp8.cpp
new file mode 100644
index 0000000000..dcbb080090
--- /dev/null
+++ b/test/ck_tile/flatmm/test_mx_flatmm_fp4fp8.cpp
@@ -0,0 +1,40 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck_tile/host.hpp"
+#include <gtest/gtest.h>
+#include "test_mx_flatmm_fixtures.hpp"
+
+// FP4 x FP8 -> FP16
+// N_Tile = 256, K must be a multiple of lcm(32, 8) = 32.
+// clang-format off
+using FP4FP8Types = ::testing::Types<
+    std::tuple<FP4, FP8, FP16, MXFlatmm_GFX950_FP4FP8_Traits>
+>;
+// clang-format on
+
+TYPED_TEST_SUITE(TestMXFlatmm, FP4FP8Types);
+
+// K=256 -> num_loop=1: has_hot_loop=false, tail=Odd
+TYPED_TEST(TestMXFlatmm, SmallMNK)
+{
+    this->run_test_with_validation(128, 256, 256, 1, false, ck_tile::TailNumber::Odd);
+}
+
+// K=512 -> num_loop=2: has_hot_loop=false, tail=Even
+TYPED_TEST(TestMXFlatmm, MediumMNK)
+{
+    this->run_test_with_validation(256, 512, 512, 1, false, ck_tile::TailNumber::Even);
+}
+
+// K=768 -> num_loop=3: has_hot_loop=true, tail=Odd
+TYPED_TEST(TestMXFlatmm, LargeK_HotLoopOdd)
+{
+    this->run_test_with_validation(128, 256, 768, 1, true, ck_tile::TailNumber::Odd);
+}
+
+// K=1024 -> num_loop=4: has_hot_loop=true, tail=Even
+TYPED_TEST(TestMXFlatmm, LargeK_HotLoopEven)
+{
+    this->run_test_with_validation(128, 256, 1024, 1, true, ck_tile::TailNumber::Even);
+}
diff --git a/test/ck_tile/flatmm/test_mx_flatmm_fp6fp6.cpp b/test/ck_tile/flatmm/test_mx_flatmm_fp6fp6.cpp
new file mode 100644
index 0000000000..e94287b0ad
--- /dev/null
+++ b/test/ck_tile/flatmm/test_mx_flatmm_fp6fp6.cpp
@@ -0,0 +1,40 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck_tile/host.hpp"
+#include <gtest/gtest.h>
+#include "test_mx_flatmm_fixtures.hpp"
+
+// FP6 x FP6 -> FP16
+// N_Tile = 256, K must be a multiple of lcm(32, 16) = 32 (ScaleGranularityK=32, FP6 PackedSize=16).
+// clang-format off
+using FP6FP6Types = ::testing::Types<
+    std::tuple<FP6, FP6, FP16, MXFlatmm_GFX950_FP6FP6_Traits>
+>;
+// clang-format on
+
+TYPED_TEST_SUITE(TestMXFlatmm, FP6FP6Types);
+
+// K=256 -> num_loop=1: has_hot_loop=false, tail=Odd
+TYPED_TEST(TestMXFlatmm, SmallMNK)
+{
+    this->run_test_with_validation(128, 256, 256, 1, false, ck_tile::TailNumber::Odd);
+}
+
+// K=512 -> num_loop=2: has_hot_loop=false, tail=Even
+TYPED_TEST(TestMXFlatmm, MediumMNK)
+{
+    this->run_test_with_validation(256, 512, 512, 1, false, ck_tile::TailNumber::Even);
+}
+
+// K=768 -> num_loop=3: has_hot_loop=true, tail=Odd
+TYPED_TEST(TestMXFlatmm, LargeK_HotLoopOdd)
+{
+    this->run_test_with_validation(128, 256, 768, 1, true, ck_tile::TailNumber::Odd);
+}
+
+// K=1024 -> num_loop=4: has_hot_loop=true, tail=Even
+TYPED_TEST(TestMXFlatmm, LargeK_HotLoopEven)
+{
+    this->run_test_with_validation(128, 256, 1024, 1, true, ck_tile::TailNumber::Even);
+}
diff --git a/test/ck_tile/flatmm/test_mx_flatmm_fp8fp4.cpp b/test/ck_tile/flatmm/test_mx_flatmm_fp8fp4.cpp
new file mode 100644
index 0000000000..15a767e163
--- /dev/null
+++ b/test/ck_tile/flatmm/test_mx_flatmm_fp8fp4.cpp
@@ -0,0 +1,40 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck_tile/host.hpp"
+#include <gtest/gtest.h>
+#include "test_mx_flatmm_fixtures.hpp"
+
+// FP8 x FP4 -> FP16
+// N_Tile = 256, K must be a multiple of lcm(32, 8) = 32.
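+// The four shapes below exercise all four (has_hot_loop, tail) dispatch paths,
+// mirroring the other dtype suites.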
+// clang-format off
+using FP8FP4Types = ::testing::Types<
+    std::tuple<FP8, FP4, FP16, MXFlatmm_GFX950_FP8FP4_Traits>
+>;
+// clang-format on
+
+TYPED_TEST_SUITE(TestMXFlatmm, FP8FP4Types);
+
+// K=256 -> num_loop=1: has_hot_loop=false, tail=Odd
+TYPED_TEST(TestMXFlatmm, SmallMNK)
+{
+    this->run_test_with_validation(128, 256, 256, 1, false, ck_tile::TailNumber::Odd);
+}
+
+// K=512 -> num_loop=2: has_hot_loop=false, tail=Even
+TYPED_TEST(TestMXFlatmm, MediumMNK)
+{
+    this->run_test_with_validation(256, 512, 512, 1, false, ck_tile::TailNumber::Even);
+}
+
+// K=768 -> num_loop=3: has_hot_loop=true, tail=Odd
+TYPED_TEST(TestMXFlatmm, LargeK_HotLoopOdd)
+{
+    this->run_test_with_validation(128, 256, 768, 1, true, ck_tile::TailNumber::Odd);
+}
+
+// K=1024 -> num_loop=4: has_hot_loop=true, tail=Even
+TYPED_TEST(TestMXFlatmm, LargeK_HotLoopEven)
+{
+    this->run_test_with_validation(128, 256, 1024, 1, true, ck_tile::TailNumber::Even);
+}
diff --git a/test/ck_tile/flatmm/test_mx_flatmm_fp8fp8.cpp b/test/ck_tile/flatmm/test_mx_flatmm_fp8fp8.cpp
new file mode 100644
index 0000000000..a0e85a60f9
--- /dev/null
+++ b/test/ck_tile/flatmm/test_mx_flatmm_fp8fp8.cpp
@@ -0,0 +1,40 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "ck_tile/host.hpp"
+#include <gtest/gtest.h>
+#include "test_mx_flatmm_fixtures.hpp"
+
+// FP8 x FP8 -> FP16
+// N_Tile = 256, K must be a multiple of 32.
+// clang-format off
+using FP8FP8Types = ::testing::Types<
+    std::tuple<FP8, FP8, FP16, MXFlatmm_GFX950_FP8FP8_Traits>
+>;
+// clang-format on
+
+TYPED_TEST_SUITE(TestMXFlatmm, FP8FP8Types);
+
+// K=256 -> num_loop=1: has_hot_loop=false, tail=Odd
+TYPED_TEST(TestMXFlatmm, SmallMNK)
+{
+    this->run_test_with_validation(128, 256, 256, 1, false, ck_tile::TailNumber::Odd);
+}
+
+// K=512 -> num_loop=2: has_hot_loop=false, tail=Even
+TYPED_TEST(TestMXFlatmm, MediumMNK)
+{
+    this->run_test_with_validation(256, 512, 512, 1, false, ck_tile::TailNumber::Even);
+}
+
+// K=768 -> num_loop=3: has_hot_loop=true, tail=Odd
+TYPED_TEST(TestMXFlatmm, LargeK_HotLoopOdd)
+{
+    this->run_test_with_validation(128, 256, 768, 1, true, ck_tile::TailNumber::Odd);
+}
+
+// K=1024 -> num_loop=4: has_hot_loop=true, tail=Even
+TYPED_TEST(TestMXFlatmm, LargeK_HotLoopEven)
+{
+    this->run_test_with_validation(128, 256, 1024, 1, true, ck_tile::TailNumber::Even);
+}
diff --git a/test/ck_tile/gemm_streamk_tile_engine/test_gemm_streamk_simple.cpp b/test/ck_tile/gemm_streamk_tile_engine/test_gemm_streamk_simple.cpp
index 1c06d33e77..284feb477d 100644
--- a/test/ck_tile/gemm_streamk_tile_engine/test_gemm_streamk_simple.cpp
+++ b/test/ck_tile/gemm_streamk_tile_engine/test_gemm_streamk_simple.cpp
@@ -52,7 +52,12 @@ bool compare_results(std::string instanceName,
                      ck_tile::HostTensor<CDataType>& c_m_n_host_result)
 {
     const float max_accumulated_value =
-        *std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end());
+        std::abs(static_cast<float>(*std::max_element(c_m_n_host_result.mData.begin(),
+                                                      c_m_n_host_result.mData.end(),
+                                                      [](CDataType a, CDataType b) {
+                                                          return std::abs(static_cast<float>(a)) <
+                                                                 std::abs(static_cast<float>(b));
+                                                      })));
     const auto rtol_atol = calculate_rtol_atol(
         K, kbatch, max_accumulated_value);
     bool pass = ck_tile::check_err(c_m_n_dev_result,
diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp
index 4cc111b7cf..58e9168c6a 100644
--- a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp
+++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp
@@ -447,8 +447,12 @@ class TestCkTileGroupedGemm : public ::testing::Test
             c_m_n_host_ref.SetZero();
             ck_tile::reference_gemm(
                 a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
-            const float max_accumulated_value =
-                *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
+            const float max_accumulated_value = std::abs(static_cast<float>(*std::max_element(
+                c_m_n_host_ref.mData.begin(),
+                c_m_n_host_ref.mData.end(),
+                [](CDataType a, CDataType b) {
+                    return std::abs(static_cast<float>(a)) < std::abs(static_cast<float>(b));
+                })));
             const auto rtol_atol = calculate_rtol_atol(Ks[i], kbatch, max_accumulated_value);
             pass &= ck_tile::check_err(c_m_n_tensors[i],
                                        c_m_n_host_ref,