This commit is contained in:
Rostyslav Geyyer
2025-05-01 17:22:34 +00:00
parent 0fc2f528e0
commit 94e5175ba3
5 changed files with 13 additions and 37 deletions

View File

@@ -122,11 +122,6 @@ struct ReferenceGemm : public device::BaseOperator
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
// if ((m == 2) && (n == 0))
// {
// printf("K:%i A:%f, B:%f, C:%f \n", k, v_a, v_b, v_acc);
// }
}
CDataType v_c{0};

View File

@@ -120,6 +120,7 @@ struct ReferenceMXGemm : public device::BaseOperator
{
if constexpr(is_same_v<BDataType, f4x2_pk_t>)
{
// TODO: add support for RowMajor layout as well
if(k % 2 == 1)
b_k_n_scaled(k, n) =
type_convert<ComputeTypeB>(

View File

@@ -240,10 +240,7 @@ TEST(MXFP4, HostScaledConvert)
EXPECT_EQ(test_size, i);
}
<<<<<<< HEAD
=======
#if !CK_TEMP_DISABLE_FP4_TESTS
>>>>>>> develop
__global__ void test_mx_fp4_device_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
test_mx_fp4_scaled_convert(N, p_test, p_completed);
@@ -543,7 +540,4 @@ TEST(MXFP4, DeviceF4x32ToF32x32ScaledConvert)
EXPECT_EQ(N, completed);
EXPECT_EQ(N, i);
}
<<<<<<< HEAD
=======
#endif // CK_TEMP_DISABLE_FP4_TESTS
>>>>>>> develop

View File

@@ -54,14 +54,14 @@ bool run_mfma_test(ck::index_t init)
TEST(MFMA, FP8MFMA16x16x128)
{
auto AB_init = 7;
auto AB_init = 5;
auto pass = run_mfma_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MFMA, FP8MFMA32x32x64)
{
auto AB_init = 7;
auto AB_init = 5;
auto pass = run_mfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
@@ -127,14 +127,14 @@ bool run_mxmfma_test(ck::index_t init)
TEST(MXMFMA, MXFP8MFMA16x16x128)
{
auto AB_init = 7;
auto AB_init = 5;
auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXFP8MFMA32x32x64)
{
auto AB_init = 7;
auto AB_init = 5;
auto pass = run_mxmfma_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}

View File

@@ -296,7 +296,8 @@ __device__ AFragT load_A_row_major(AType const* input_ptr)
// BLOCK_K is a stride in A matrix
auto startOffset = row_major(
startCoord2D, BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
// auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K);
// auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K /
// (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
auto kMajorOffset =
row_major(majorStepCoord2D,
BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
@@ -513,7 +514,8 @@ __device__ BFragT load_B_col_major(BType const* input_ptr)
// BLOCK_K is a stride in B matrix
auto startOffset = col_major(
startCoord2D, BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1));
// auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K);
// auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K /
// (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1));
auto kMajorOffset =
col_major(majorStepCoord2D,
BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1));
@@ -937,7 +939,6 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
fragC[i] = type_convert<CType>(fragAcc.template AsType<RawAccumFragT>()[Number<0>{}][i]);
}
// auto storeC = store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
auto storeC = store_C_row_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
storeC(c, fragC);
}
@@ -1134,20 +1135,12 @@ struct TestMXMFMA
{
case 0:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}}); // 1/64
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{0.015625f}}); // 1/64
// NOTE: not all numbers are representable in FP8, BF8, etc.
// 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 16 18 20 20 20 22 24 24 24 26 28 28 28 30 32
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
break;
// b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
// a_scales.GenerateTensorValue(
// GeneratorTensor_1<ScaleType>{ScaleType{1.0f}}); // 1/64
// // NOTE: not all numbers are representable in FP8, BF8, etc.
// // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 16 18 20 20 20 22 24 24 24 26 28 28 28 30
// 32 a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 1>{});
// b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
// break;
case 1:
// results in C = {K}
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
@@ -1158,11 +1151,9 @@ struct TestMXMFMA
case 2:
// expect small round off errors
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
a_scales.GenerateTensorValue(
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{512.0f}});
b_n_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
b_scales.GenerateTensorValue(GeneratorTensor_2<ScaleType>{126, 129});
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f / 512}});
break;
case 3:
// expect small round off errors
@@ -1343,15 +1334,10 @@ struct TestMFMA
switch(init)
{
case 0:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{0.015625f});
// NOTE: not all numbers are representable in FP8, BF8, etc.
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
break;
// case 0:
// b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
// // NOTE: not all numbers are representable in FP8, BF8, etc.
// a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 1>{});
// break;
case 1:
// results in C = {K}
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});