diff --git a/test/mx_mfma_op/mx_mfma_op.cpp b/test/mx_mfma_op/mx_mfma_op.cpp index a0ec7483f5..30c57632ac 100644 --- a/test/mx_mfma_op/mx_mfma_op.cpp +++ b/test/mx_mfma_op/mx_mfma_op.cpp @@ -20,9 +20,9 @@ using ck::type_convert; template bool run_mfma_test(ck::index_t init) { - using ALayout = ck::tensor_layout::gemm::ColumnMajor; + using ALayout = ck::tensor_layout::gemm::RowMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor; - using CLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; using AccType = float; // only MFMA_F32 instructions supported using CPUAccType = AccType; @@ -52,19 +52,19 @@ bool run_mfma_test(ck::index_t init) return pass; } -// TEST(MFMA, FP8MFMA16x16x128) -// { -// auto AB_init = 7; -// auto pass = run_mfma_test(AB_init); -// EXPECT_TRUE(pass); -// } +TEST(MFMA, FP8MFMA16x16x128) +{ + auto AB_init = 7; + auto pass = run_mfma_test(AB_init); + EXPECT_TRUE(pass); +} -// TEST(MFMA, FP8MFMA32x32x64) -// { -// auto AB_init = 7; -// auto pass = run_mfma_test(AB_init); -// EXPECT_TRUE(pass); -// } +TEST(MFMA, FP8MFMA32x32x64) +{ + auto AB_init = 7; + auto pass = run_mfma_test(AB_init); + EXPECT_TRUE(pass); +} TEST(MFMA, FP4MFMA16x16x128) { @@ -73,12 +73,12 @@ TEST(MFMA, FP4MFMA16x16x128) EXPECT_TRUE(pass); } -// TEST(MFMA, FP4MFMA32x32x64) -// { -// auto AB_init = 4; -// auto pass = run_mfma_test(AB_init); EXPECT_TRUE(pass); -// } +TEST(MFMA, FP4MFMA32x32x64) +{ + auto AB_init = 4; + auto pass = run_mfma_test(AB_init); + EXPECT_TRUE(pass); +} /** * @brief Run the test for the given MX MFMA instruction @@ -125,34 +125,32 @@ bool run_mxmfma_test(ck::index_t init) return pass; } -// TEST(MXMFMA, MXFP8MFMA16x16x128) -// { -// auto AB_init = 7; -// auto pass = run_mxmfma_test(AB_init); EXPECT_TRUE(pass); -// } +TEST(MXMFMA, MXFP8MFMA16x16x128) +{ + auto AB_init = 7; + auto pass = run_mxmfma_test(AB_init); + EXPECT_TRUE(pass); +} -// TEST(MXMFMA, MXFP8MFMA32x32x64) -// { -// auto AB_init = 7; -// auto pass = run_mxmfma_test(AB_init); EXPECT_TRUE(pass); -// } +TEST(MXMFMA, MXFP8MFMA32x32x64) +{ + auto AB_init = 7; + auto pass = run_mxmfma_test(AB_init); + EXPECT_TRUE(pass); +} -// TEST(MXMFMA, MXFP4MFMA16x16x128) -// { -// auto AB_init = 4; -// auto pass = -// run_mxmfma_test(AB_init); -// EXPECT_TRUE(pass); -// } +TEST(MXMFMA, MXFP4MFMA16x16x128) +{ + auto AB_init = 4; + auto pass = + run_mxmfma_test(AB_init); + EXPECT_TRUE(pass); +} -// TEST(MXMFMA, MXFP4MFMA32x32x64) -// { -// auto AB_init = 4; -// auto pass = -// run_mxmfma_test(AB_init); -// EXPECT_TRUE(pass); -// } +TEST(MXMFMA, MXFP4MFMA32x32x64) +{ + auto AB_init = 4; + auto pass = + run_mxmfma_test(AB_init); + EXPECT_TRUE(pass); +}