// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"
#include "mx_mfma_op.hpp"

using ck::e8m0_bexp_t;
using ck::f4_t;
using ck::f4x2_pk_t;
using ck::f8_t;
using ck::half_t;
using ck::type_convert;

/**
 * @brief Run the test for the given MFMA instruction
 *
 * @param init - selects initialization algorithm for A and B tensors
 */
template <ck::MFMA_F8F6F4 mfma, typename AType, typename BType, typename CType>
bool run_mfma_km_kn_nm_test(ck::index_t init)
{
    using ALayout = ck::tensor_layout::gemm::ColumnMajor;
    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
    using CLayout = ck::tensor_layout::gemm::ColumnMajor;

    using AccType    = float; // only MFMA_F32 instructions supported
    using CPUAccType = AccType;

    ck::mfma_type<static_cast<ck::MfmaInstr>(mfma)> mfma_instr;

    constexpr auto BLOCK_M = mfma_instr.m_per_blk;
    constexpr auto BLOCK_N = mfma_instr.n_per_blk;
    // one MFMA consumes num_input_blks input blocks of k_per_blk elements each
    constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;

    const auto mfma_kernel =
        ck::matmul<mfma, AType, BType, AccType, CType, BLOCK_M, BLOCK_N, BLOCK_K>;

    bool pass = true;

    pass = ck::mfma_test::TestMFMA<decltype(mfma_kernel),
                                   AType,
                                   BType,
                                   CType,
                                   AccType,
                                   CPUAccType,
                                   ALayout,
                                   BLayout,
                                   CLayout,
                                   BLOCK_M,
                                   BLOCK_N,
                                   BLOCK_K>{}(mfma_kernel, init);

    return pass;
}

TEST(MFMA, FP8MFMA16x16x128)
{
    auto AB_init = 5;
    auto pass =
        run_mfma_km_kn_nm_test<ck::MFMA_F8F6F4::F32_16x16x128, f8_t, f8_t, half_t>(AB_init);
    EXPECT_TRUE(pass);
}

TEST(MFMA, FP8MFMA32x32x64)
{
    auto AB_init = 5;
    auto pass =
        run_mfma_km_kn_nm_test<ck::MFMA_F8F6F4::F32_32x32x64, f8_t, f8_t, half_t>(AB_init);
    EXPECT_TRUE(pass);
}

/**
 * @brief Run the test for the given MFMA instruction
 *
 * @param init - selects initialization algorithm for A and B tensors
 */
template <ck::MFMA_F8F6F4 mfma, typename AType, typename BType, typename CType>
bool run_mfma_mk_kn_mn_test(ck::index_t init)
{
    using ALayout = ck::tensor_layout::gemm::RowMajor;
    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
    using CLayout = ck::tensor_layout::gemm::RowMajor;

    using AccType    = float; // only MFMA_F32 instructions supported
    using CPUAccType = AccType;

    ck::mfma_type<static_cast<ck::MfmaInstr>(mfma)> mfma_instr;

    constexpr auto BLOCK_M = mfma_instr.m_per_blk;
    constexpr auto BLOCK_N = mfma_instr.n_per_blk;
    constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;

    const auto mfma_kernel =
        ck::matmul<mfma, AType, BType, AccType, CType, BLOCK_M, BLOCK_N, BLOCK_K>;

    bool pass = true;

    pass = ck::mfma_test::TestMFMA<decltype(mfma_kernel),
                                   AType,
                                   BType,
                                   CType,
                                   AccType,
                                   CPUAccType,
                                   ALayout,
                                   BLayout,
                                   CLayout,
                                   BLOCK_M,
                                   BLOCK_N,
                                   BLOCK_K>{}(mfma_kernel, init);

    return pass;
}

TEST(MFMA, FP4MFMA16x16x128)
{
    auto AB_init = 4;
    auto pass =
        run_mfma_mk_kn_mn_test<ck::MFMA_F8F6F4::F32_16x16x128, f4_t, f4_t, half_t>(AB_init);
    EXPECT_TRUE(pass);
}

TEST(MFMA, FP4MFMA32x32x64)
{
    auto AB_init = 4;
    auto pass =
        run_mfma_mk_kn_mn_test<ck::MFMA_F8F6F4::F32_32x32x64, f4_t, f4_t, half_t>(AB_init);
    EXPECT_TRUE(pass);
}

/**
 * @brief Run the test for the given MX MFMA instruction
 *
 * @param init - selects initialization algorithm for A and B tensors
 */
template <ck::MFMA_F8F6F4 mfma, typename AType, typename BType, typename CType>
bool run_mxmfma_mk_kn_mn_test(ck::index_t init)
{
    static_assert(mfma == ck::MFMA_F8F6F4::SCALE_F32_16x16x128 ||
                      mfma == ck::MFMA_F8F6F4::SCALE_F32_32x32x64,
                  "Only SCALE_F32_16x16x128 and SCALE_F32_32x32x64 are supported");

    using ALayout = ck::tensor_layout::gemm::RowMajor;
    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
    using CLayout = ck::tensor_layout::gemm::RowMajor;

    using AccType   = float;           // only MFMA_F32 instructions supported
    using ScaleType = ck::e8m0_bexp_t; // biased exponent type

    ck::mfma_type<static_cast<ck::MfmaInstr>(mfma)> mfma_instr;

    constexpr auto BLOCK_M = mfma_instr.m_per_blk;
    constexpr auto BLOCK_N = mfma_instr.n_per_blk;
    constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;
    constexpr auto BLOCK_X = 32; // scaling vector size: one scale per 32 elements along K

    const auto mx_mfma_kernel = ck::matmul<mfma,
                                           AType,
                                           BType,
                                           AccType,
                                           CType,
                                           ScaleType,
                                           BLOCK_M,
                                           BLOCK_N,
                                           BLOCK_K,
                                           BLOCK_X>;

    bool pass = true;

    pass = ck::mxmfma_test::TestMXMFMA<decltype(mx_mfma_kernel),
                                       AType,
                                       BType,
                                       CType,
                                       AccType,
                                       ScaleType,
                                       ALayout,
                                       BLayout,
                                       CLayout,
                                       BLOCK_M,
                                       BLOCK_N,
                                       BLOCK_K,
                                       BLOCK_X>{}(mx_mfma_kernel, init);

    return pass;
}

TEST(MXMFMA, MXFP8MFMA16x16x128)
{
    auto AB_init = 5;
    auto pass =
        run_mxmfma_mk_kn_mn_test<ck::MFMA_F8F6F4::SCALE_F32_16x16x128, f8_t, f8_t, half_t>(
            AB_init);
    EXPECT_TRUE(pass);
}

TEST(MXMFMA, MXFP8MFMA32x32x64)
{
    auto AB_init = 5;
    auto pass =
        run_mxmfma_mk_kn_mn_test<ck::MFMA_F8F6F4::SCALE_F32_32x32x64, f8_t, f8_t, half_t>(
            AB_init);
    EXPECT_TRUE(pass);
}

TEST(MXMFMA, MXFP4MFMA16x16x128)
{
    auto AB_init = 4;
    auto pass =
        run_mxmfma_mk_kn_mn_test<ck::MFMA_F8F6F4::SCALE_F32_16x16x128, f4_t, f4_t, half_t>(
            AB_init);
    EXPECT_TRUE(pass);
}

TEST(MXMFMA, MXFP4MFMA32x32x64)
{
    auto AB_init = 4;
    auto pass =
        run_mxmfma_mk_kn_mn_test<ck::MFMA_F8F6F4::SCALE_F32_32x32x64, f4_t, f4_t, half_t>(
            AB_init);
    EXPECT_TRUE(pass);
}
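
// A minimal sketch of the MX scaling convention the MXMFMA tests above exercise,
// assuming each group of BLOCK_X (= 32) elements along K shares one e8m0 power-of-two
// scale. The helper name and its plain-float parameters are illustrative only and are
// not part of the test harness; the actual reference computation lives in TestMXMFMA.
[[maybe_unused]] static float mx_scaled_block_dot(const float* a, // 32 decoded A elements
                                                  const float* b, // 32 decoded B elements
                                                  float a_scale,  // decoded e8m0 scale for the A block
                                                  float b_scale)  // decoded e8m0 scale for the B block
{
    float acc = 0.f;
    for(int k = 0; k < 32; ++k)
    {
        acc += a[k] * b[k];
    }
    // the block's contribution to C is the unscaled dot product times both block scales
    return a_scale * b_scale * acc;
}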