From 94e5175ba358aa28516194f1b1daef74fb136a26 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer Date: Thu, 1 May 2025 17:22:34 +0000 Subject: [PATCH] Clean up --- .../cpu/reference_gemm.hpp | 5 ---- .../cpu/reference_mx_gemm.hpp | 1 + test/data_type/test_mx_fp4.cpp | 6 ---- test/mx_mfma_op/mx_mfma_op.cpp | 8 ++--- test/mx_mfma_op/mx_mfma_op.hpp | 30 +++++-------------- 5 files changed, 13 insertions(+), 37 deletions(-) diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 1029dcec2f..c8d284a1d7 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -122,11 +122,6 @@ struct ReferenceGemm : public device::BaseOperator v_acc += ck::type_convert(v_a) * ck::type_convert(v_b); - - // if ((m == 2) && (n == 0)) - // { - // printf("K:%i A:%f, B:%f, C:%f \n", k, v_a, v_b, v_acc); - // } } CDataType v_c{0}; diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp index e0697e3360..e8fdcf1acd 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp @@ -120,6 +120,7 @@ struct ReferenceMXGemm : public device::BaseOperator { if constexpr(is_same_v) { + // TODO: add support for RowMajor layout as well if(k % 2 == 1) b_k_n_scaled(k, n) = type_convert( diff --git a/test/data_type/test_mx_fp4.cpp b/test/data_type/test_mx_fp4.cpp index a4fe044bc5..7aca42567c 100644 --- a/test/data_type/test_mx_fp4.cpp +++ b/test/data_type/test_mx_fp4.cpp @@ -240,10 +240,7 @@ TEST(MXFP4, HostScaledConvert) EXPECT_EQ(test_size, i); } -<<<<<<< HEAD -======= #if !CK_TEMP_DISABLE_FP4_TESTS ->>>>>>> develop __global__ void test_mx_fp4_device_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed) { test_mx_fp4_scaled_convert(N, p_test, p_completed); @@ -543,7 +540,4 @@ TEST(MXFP4, DeviceF4x32ToF32x32ScaledConvert) EXPECT_EQ(N, completed); EXPECT_EQ(N, i); } -<<<<<<< HEAD -======= #endif // CK_TEMP_DISABLE_FP4_TESTS ->>>>>>> develop diff --git a/test/mx_mfma_op/mx_mfma_op.cpp b/test/mx_mfma_op/mx_mfma_op.cpp index 30c57632ac..7db3d722c4 100644 --- a/test/mx_mfma_op/mx_mfma_op.cpp +++ b/test/mx_mfma_op/mx_mfma_op.cpp @@ -54,14 +54,14 @@ bool run_mfma_test(ck::index_t init) TEST(MFMA, FP8MFMA16x16x128) { - auto AB_init = 7; + auto AB_init = 5; auto pass = run_mfma_test(AB_init); EXPECT_TRUE(pass); } TEST(MFMA, FP8MFMA32x32x64) { - auto AB_init = 7; + auto AB_init = 5; auto pass = run_mfma_test(AB_init); EXPECT_TRUE(pass); } @@ -127,14 +127,14 @@ bool run_mxmfma_test(ck::index_t init) TEST(MXMFMA, MXFP8MFMA16x16x128) { - auto AB_init = 7; + auto AB_init = 5; auto pass = run_mxmfma_test(AB_init); EXPECT_TRUE(pass); } TEST(MXMFMA, MXFP8MFMA32x32x64) { - auto AB_init = 7; + auto AB_init = 5; auto pass = run_mxmfma_test(AB_init); EXPECT_TRUE(pass); } diff --git a/test/mx_mfma_op/mx_mfma_op.hpp b/test/mx_mfma_op/mx_mfma_op.hpp index a0f130f3b4..362884573d 100644 --- a/test/mx_mfma_op/mx_mfma_op.hpp +++ b/test/mx_mfma_op/mx_mfma_op.hpp @@ -296,7 +296,8 @@ __device__ AFragT load_A_row_major(AType const* input_ptr) // BLOCK_K is a stride in A matrix auto startOffset = row_major( startCoord2D, BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); - // auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K); + // auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K / + // (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); auto kMajorOffset = row_major(majorStepCoord2D, BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); @@ -513,7 +514,8 @@ __device__ BFragT load_B_col_major(BType const* input_ptr) // BLOCK_K is a stride in B matrix auto startOffset = col_major( startCoord2D, BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); - // auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K); + // auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K / + // (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); @@ -937,7 +939,6 @@ __global__ void matmul(const AType* a, const BType* b, CType* c) fragC[i] = type_convert(fragAcc.template AsType()[Number<0>{}][i]); } - // auto storeC = store_C_col_major{}; auto storeC = store_C_row_major{}; storeC(c, fragC); } @@ -1134,20 +1135,12 @@ struct TestMXMFMA { case 0: a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); - a_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{1.0f}}); // 1/64 + a_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{0.015625f}}); // 1/6 // NOTE: not all numbers are representable in FP8, BF8, etc. // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 16 18 20 20 20 22 24 24 24 26 28 28 28 30 32 b_n_k.GenerateTensorValue(GeneratorTensor_Sequential{}); b_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{1.0f}}); break; - // b_n_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); - // a_scales.GenerateTensorValue( - // GeneratorTensor_1{ScaleType{1.0f}}); // 1/64 - // // NOTE: not all numbers are representable in FP8, BF8, etc. - // // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 16 18 20 20 20 22 24 24 24 26 28 28 28 30 - // 32 a_m_k.GenerateTensorValue(GeneratorTensor_Sequential{}); - // b_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{1.0f}}); - // break; case 1: // results in C = {K} a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); @@ -1158,11 +1151,9 @@ struct TestMXMFMA case 2: // expect small round off errors a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); - a_scales.GenerateTensorValue( - GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} - + a_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{512.0f}}); b_n_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); - b_scales.GenerateTensorValue(GeneratorTensor_2{126, 129}); + b_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{1.0f / 512}}); break; case 3: // expect small round off errors @@ -1343,15 +1334,10 @@ struct TestMFMA switch(init) { case 0: - a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + a_m_k.GenerateTensorValue(GeneratorTensor_1{0.015625f}); // NOTE: not all numbers are representable in FP8, BF8, etc. b_n_k.GenerateTensorValue(GeneratorTensor_Sequential{}); break; - // case 0: - // b_n_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); - // // NOTE: not all numbers are representable in FP8, BF8, etc. - // a_m_k.GenerateTensorValue(GeneratorTensor_Sequential{}); - // break; case 1: // results in C = {K} a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f});