diff --git a/test/mx_mfma_op/mx_mfma_op.cpp b/test/mx_mfma_op/mx_mfma_op.cpp index f65e89bb82..cafd5179f4 100644 --- a/test/mx_mfma_op/mx_mfma_op.cpp +++ b/test/mx_mfma_op/mx_mfma_op.cpp @@ -6,6 +6,8 @@ #include "mx_mfma_op.hpp" using ck::e8m0_bexp_t; +using ck::f4_t; +using ck::f4x2_pk_t; using ck::f8_t; using ck::half_t; using ck::type_convert; @@ -122,3 +124,19 @@ TEST(MXMFMA, MXFP8MFMA32x32x64) auto pass = run_mxmfma_test(AB_init); EXPECT_TRUE(pass); } + +TEST(MXMFMA, MXFP4MFMA16x16x128) +{ + auto AB_init = 7; + auto pass = + run_mxmfma_test(AB_init); + EXPECT_TRUE(pass); +} + +TEST(MXMFMA, MXFP4MFMA32x32x64) +{ + auto AB_init = 7; + auto pass = + run_mxmfma_test(AB_init); + EXPECT_TRUE(pass); +} diff --git a/test/mx_mfma_op/mx_mfma_op.hpp b/test/mx_mfma_op/mx_mfma_op.hpp index 1f9091ebc5..fee1047a8d 100644 --- a/test/mx_mfma_op/mx_mfma_op.hpp +++ b/test/mx_mfma_op/mx_mfma_op.hpp @@ -5,6 +5,7 @@ #include "ck/ck.hpp" +#include "ck/utility/data_type.hpp" #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -319,8 +320,8 @@ __device__ AFragT load_mx_A_row_major(AType const* input_ptr, // Reg 7 [16:23] | K78 | K94 | x(M,2) | K110 | K126 | x(M,3) | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | x(M,1) | // Reg 7 [24:31] | K79 | K95 | x(M,2) | K111 | K127 | x(M,3) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(M,1) | // clang-format on - static constexpr uint32_t VW = vectorSize(AFragT{}); - static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X"); + const uint32_t VW = vectorSize(AFragT{}); + // static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X"); // To start the loading process, let's visualize in 2D coords. // Each thread will load 1 element @@ -487,8 +488,8 @@ __device__ BFragT load_mx_B_col_major(BType const* input_ptr, // Reg 7 [24:31] | K79 | K95 | x(2,N) | K111 | K127 | x(3,N) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(1,N) | // clang-format on - static constexpr uint32_t VW = vectorSize(BFragT{}); - static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X"); + const uint32_t VW = vectorSize(BFragT{}); + // static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X"); // To start the loading process, let's visualize in 2D coords. // Each thread will load 1 element @@ -800,8 +801,14 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, assert(threadIdx.x < WAVE_SIZE); assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1); - using AFragT = vector_type::type; - using BFragT = vector_type::type; + using AFragT = + vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; + using BFragT = + vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; using CFragT = vector_type::type; using AccumFragT = vector_type; using RawAccumFragT = vector_type::type;