Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-04-20 14:59:17 +00:00
MX FP GEMM - Test MX FP8 MFMA Instructions (#1902)
* Refactored `load_A_row_major` to follow scale mapping
* Refactored `load_A_col_major` to follow scale mapping
* Refactored `load_B_col_major` to follow scale mapping
* Verified non-scaled test
* Verified scaled tests
* Used ReferenceMXGemm for verification
* Updated license headers
committed by GitHub
parent 3ace125c30
commit ffa13455a2
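For context: in MX (microscaling) GEMM, each block of FP8 elements shares a scale stored as an e8m0 biased exponent (ck::e8m0_bexp_t in the diff below). Here is a minimal sketch of decoding such a scale, assuming the OCP MX bias of 127; the helper name decode_e8m0 is illustrative and not part of composable_kernel:

#include <cmath>
#include <cstdint>

// Hypothetical helper: decode an e8m0 biased-exponent byte into a float scale.
// e8m0 stores only an 8-bit exponent; with a bias of 127, a stored value of
// 127 decodes to 2^0 = 1.0f. (The NaN encoding 0xFF is ignored for brevity.)
inline float decode_e8m0(std::uint8_t biased_exp)
{
    return std::exp2f(static_cast<int>(biased_exp) - 127);
}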
@@ -30,11 +30,11 @@ bool run_mfma_test(ck::index_t init)
     constexpr auto BLOCK_N = mfma_instr.n_per_blk;
     constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;
 
-    const auto mx_mfma_kernel = ck::matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K>;
+    const auto mfma_kernel = ck::matmul<AType, BType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K>;
 
     bool pass = true;
 
-    pass = ck::mfma_test::TestMFMA<decltype(mx_mfma_kernel),
+    pass = ck::mfma_test::TestMFMA<decltype(mfma_kernel),
                                    AType,
                                    BType,
                                    CType,
@@ -45,21 +45,80 @@ bool run_mfma_test(ck::index_t init)
                                    CLayout,
                                    BLOCK_M,
                                    BLOCK_N,
-                                   BLOCK_K>{}(mx_mfma_kernel, init);
+                                   BLOCK_K>{}(mfma_kernel, init);
 
     return pass;
 }
 
 TEST(MFMA, FP8MFMA16x16x128)
 {
-    auto AB_init = 0;
+    auto AB_init = 4;
     auto pass = run_mfma_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
     EXPECT_TRUE(pass);
 }
 
 TEST(MFMA, FP8MFMA32x32x64)
 {
-    auto AB_init = 0;
+    auto AB_init = 4;
     auto pass = run_mfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
     EXPECT_TRUE(pass);
 }
 
+/**
+ * @brief Run the test for the given MX MFMA instruction
+ *
+ * @param init - selects initialization algorithm for A and B tensors
+ */
+template <typename AType, typename BType, typename CType, ck::MFMA_F8F6F4 mfma>
+bool run_mxmfma_test(ck::index_t init)
+{
+    static_assert(mfma == ck::MFMA_F8F6F4::SCALE_F32_16x16x128 ||
+                      mfma == ck::MFMA_F8F6F4::SCALE_F32_32x32x64,
+                  "Only SCALE_F32_16x16x128 and SCALE_F32_32x32x64 are supported");
+    using ALayout = ck::tensor_layout::gemm::RowMajor;
+    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck::tensor_layout::gemm::RowMajor;
+
+    using AccType   = float;           // only MFMA_F32 instructions supported
+    using ScaleType = ck::e8m0_bexp_t; // biased exponent type
+
+    ck::mfma_type<static_cast<ck::MfmaInstr>(mfma)> mfma_instr;
+    constexpr auto BLOCK_M = mfma_instr.m_per_blk;
+    constexpr auto BLOCK_N = mfma_instr.n_per_blk;
+    constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk;
+    constexpr auto BLOCK_X = 32; // scaling vector size
+
+    const auto mx_mfma_kernel =
+        ck::matmul<AType, BType, ScaleType, CType, AccType, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_X>;
+
+    bool pass = true;
+
+    pass = ck::mxmfma_test::TestMXMFMA<decltype(mx_mfma_kernel),
+                                       AType,
+                                       BType,
+                                       ScaleType,
+                                       CType,
+                                       ALayout,
+                                       BLayout,
+                                       CLayout,
+                                       BLOCK_M,
+                                       BLOCK_N,
+                                       BLOCK_K,
+                                       BLOCK_X>{}(mx_mfma_kernel, init);
+
+    return pass;
+}
+
+TEST(MXMFMA, MXFP8MFMA16x16x128)
+{
+    auto AB_init = 7;
+    auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
+    EXPECT_TRUE(pass);
+}
+
+TEST(MXMFMA, MXFP8MFMA32x32x64)
+{
+    auto AB_init = 7;
+    auto pass = run_mxmfma_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
+    EXPECT_TRUE(pass);
+}
+
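To make the verification step concrete, here is a hedged, self-contained sketch of the block-scaled reference computation that a ReferenceMXGemm-style check performs: every run of BLOCK_X = 32 consecutive K elements of A and B shares one e8m0 scale per operand. The function name and the exact scale layout are assumptions for illustration; the actual reference implementation lives in composable_kernel.

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative block-scaled reference GEMM: C = (scaled A) * (scaled B).
// A is row-major MxK, B is column-major KxN (stored as N rows of K), and C is
// row-major MxN, matching the layouts used by run_mxmfma_test above.
void reference_mx_gemm(const std::vector<float>& a,         // A elements, pre-converted from f8
                       const std::vector<std::uint8_t>& sa, // e8m0 scales for A: M x (K / block_x)
                       const std::vector<float>& b,         // B elements, pre-converted from f8
                       const std::vector<std::uint8_t>& sb, // e8m0 scales for B: N x (K / block_x)
                       std::vector<float>& c,
                       std::size_t M, std::size_t N, std::size_t K,
                       std::size_t block_x = 32) // scaling vector size, as in BLOCK_X
{
    const std::size_t k_blocks = K / block_x; // assume K is divisible by block_x
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.0f; // accumulate in float, matching AccType
            for(std::size_t k = 0; k < K; ++k)
            {
                const std::size_t blk     = k / block_x;
                const float       scale_a = std::exp2f(int(sa[m * k_blocks + blk]) - 127);
                const float       scale_b = std::exp2f(int(sb[n * k_blocks + blk]) - 127);
                acc += (scale_a * a[m * K + k]) * (scale_b * b[n * K + k]);
            }
            c[m * N + n] = acc;
        }
}

A test harness would then compare the kernel's output against this reference, presumably within a tolerance appropriate for FP8 inputs.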
File diff suppressed because it is too large