diff --git a/test/mx_mfma_op/mx_mfma_op.cpp b/test/mx_mfma_op/mx_mfma_op.cpp index 4c29f6f137..a0ec7483f5 100644 --- a/test/mx_mfma_op/mx_mfma_op.cpp +++ b/test/mx_mfma_op/mx_mfma_op.cpp @@ -73,12 +73,12 @@ TEST(MFMA, FP4MFMA16x16x128) EXPECT_TRUE(pass); } -TEST(MFMA, FP4MFMA32x32x64) -{ - auto AB_init = 4; - auto pass = run_mfma_test(AB_init); - EXPECT_TRUE(pass); -} +// TEST(MFMA, FP4MFMA32x32x64) +// { +// auto AB_init = 4; +// auto pass = run_mfma_test(AB_init); EXPECT_TRUE(pass); +// } /** * @brief Run the test for the given MX MFMA instruction @@ -125,32 +125,34 @@ bool run_mxmfma_test(ck::index_t init) return pass; } -TEST(MXMFMA, MXFP8MFMA16x16x128) -{ - auto AB_init = 7; - auto pass = run_mxmfma_test(AB_init); - EXPECT_TRUE(pass); -} +// TEST(MXMFMA, MXFP8MFMA16x16x128) +// { +// auto AB_init = 7; +// auto pass = run_mxmfma_test(AB_init); EXPECT_TRUE(pass); +// } -TEST(MXMFMA, MXFP8MFMA32x32x64) -{ - auto AB_init = 7; - auto pass = run_mxmfma_test(AB_init); - EXPECT_TRUE(pass); -} +// TEST(MXMFMA, MXFP8MFMA32x32x64) +// { +// auto AB_init = 7; +// auto pass = run_mxmfma_test(AB_init); EXPECT_TRUE(pass); +// } -TEST(MXMFMA, MXFP4MFMA16x16x128) -{ - auto AB_init = 4; - auto pass = - run_mxmfma_test(AB_init); - EXPECT_TRUE(pass); -} +// TEST(MXMFMA, MXFP4MFMA16x16x128) +// { +// auto AB_init = 4; +// auto pass = +// run_mxmfma_test(AB_init); +// EXPECT_TRUE(pass); +// } -TEST(MXMFMA, MXFP4MFMA32x32x64) -{ - auto AB_init = 4; - auto pass = - run_mxmfma_test(AB_init); - EXPECT_TRUE(pass); -} +// TEST(MXMFMA, MXFP4MFMA32x32x64) +// { +// auto AB_init = 4; +// auto pass = +// run_mxmfma_test(AB_init); +// EXPECT_TRUE(pass); +// } diff --git a/test/mx_mfma_op/mx_mfma_op.hpp b/test/mx_mfma_op/mx_mfma_op.hpp index 5e00966521..49a7b766eb 100644 --- a/test/mx_mfma_op/mx_mfma_op.hpp +++ b/test/mx_mfma_op/mx_mfma_op.hpp @@ -131,7 +131,7 @@ __device__ AFragT load_A_col_major(AType const* input_ptr) // Register Mapping for 128x16 for FP4: || Register Mapping for 64x32 for FP4: // Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | | - // M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector | + // M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector | // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| // Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------| // Reg 0 [0:7] | K0K1 | K32K33 | K64K65 | K96K97 | v[0] || Reg 0 [0:7] | K0K1 | K32K33 | v[0] | @@ -1279,7 +1279,7 @@ struct TestMFMA switch(init) { case 0: - a_m_k.GenerateTensorValue(GeneratorTensor_1{0.015625f}); + a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); // NOTE: not all numbers are representable in FP8, BF8, etc. b_n_k.GenerateTensorValue(GeneratorTensor_Sequential{}); break; @@ -1299,6 +1299,7 @@ struct TestMFMA b_n_k.GenerateTensorValue(GeneratorTensor_4(1, 3)); break; case 4: + // FP4 values case a_m_k.GenerateTensorValue(GeneratorTensor_2{-4, 5}); b_n_k.GenerateTensorValue(GeneratorTensor_2{-4, 5}); break;