mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-05 20:55:59 +00:00
Clean up
This commit is contained in:
@@ -122,11 +122,6 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
|
||||
v_acc +=
|
||||
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
|
||||
|
||||
// if ((m == 2) && (n == 0))
|
||||
// {
|
||||
// printf("K:%i A:%f, B:%f, C:%f \n", k, v_a, v_b, v_acc);
|
||||
// }
|
||||
}
|
||||
|
||||
CDataType v_c{0};
|
||||
|
||||
@@ -120,6 +120,7 @@ struct ReferenceMXGemm : public device::BaseOperator
|
||||
{
|
||||
if constexpr(is_same_v<BDataType, f4x2_pk_t>)
|
||||
{
|
||||
// TODO: add support for RowMajor layout as well
|
||||
if(k % 2 == 1)
|
||||
b_k_n_scaled(k, n) =
|
||||
type_convert<ComputeTypeB>(
|
||||
|
||||
@@ -240,10 +240,7 @@ TEST(MXFP4, HostScaledConvert)
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
#if !CK_TEMP_DISABLE_FP4_TESTS
|
||||
>>>>>>> develop
|
||||
__global__ void test_mx_fp4_device_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
|
||||
{
|
||||
test_mx_fp4_scaled_convert(N, p_test, p_completed);
|
||||
@@ -543,7 +540,4 @@ TEST(MXFP4, DeviceF4x32ToF32x32ScaledConvert)
|
||||
EXPECT_EQ(N, completed);
|
||||
EXPECT_EQ(N, i);
|
||||
}
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
#endif // CK_TEMP_DISABLE_FP4_TESTS
|
||||
>>>>>>> develop
|
||||
|
||||
@@ -54,14 +54,14 @@ bool run_mfma_test(ck::index_t init)
|
||||
|
||||
TEST(MFMA, FP8MFMA16x16x128)
|
||||
{
|
||||
auto AB_init = 7;
|
||||
auto AB_init = 5;
|
||||
auto pass = run_mfma_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::F32_16x16x128>(AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
TEST(MFMA, FP8MFMA32x32x64)
|
||||
{
|
||||
auto AB_init = 7;
|
||||
auto AB_init = 5;
|
||||
auto pass = run_mfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::F32_32x32x64>(AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
@@ -127,14 +127,14 @@ bool run_mxmfma_test(ck::index_t init)
|
||||
|
||||
TEST(MXMFMA, MXFP8MFMA16x16x128)
|
||||
{
|
||||
auto AB_init = 7;
|
||||
auto AB_init = 5;
|
||||
auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
TEST(MXMFMA, MXFP8MFMA32x32x64)
|
||||
{
|
||||
auto AB_init = 7;
|
||||
auto AB_init = 5;
|
||||
auto pass = run_mxmfma_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
@@ -296,7 +296,8 @@ __device__ AFragT load_A_row_major(AType const* input_ptr)
|
||||
// BLOCK_K is a stride in A matrix
|
||||
auto startOffset = row_major(
|
||||
startCoord2D, BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
|
||||
// auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K);
|
||||
// auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K /
|
||||
// (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
|
||||
auto kMajorOffset =
|
||||
row_major(majorStepCoord2D,
|
||||
BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<AType>, ck::f4x2_pk_t> ? 2 : 1));
|
||||
@@ -513,7 +514,8 @@ __device__ BFragT load_B_col_major(BType const* input_ptr)
|
||||
// BLOCK_K is a stride in B matrix
|
||||
auto startOffset = col_major(
|
||||
startCoord2D, BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1));
|
||||
// auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K);
|
||||
// auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K /
|
||||
// (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1));
|
||||
auto kMajorOffset =
|
||||
col_major(majorStepCoord2D,
|
||||
BLOCK_K / (ck::is_same_v<ck::remove_cvref_t<BType>, ck::f4x2_pk_t> ? 2 : 1));
|
||||
@@ -937,7 +939,6 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
|
||||
fragC[i] = type_convert<CType>(fragAcc.template AsType<RawAccumFragT>()[Number<0>{}][i]);
|
||||
}
|
||||
|
||||
// auto storeC = store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
|
||||
auto storeC = store_C_row_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
|
||||
storeC(c, fragC);
|
||||
}
|
||||
@@ -1134,20 +1135,12 @@ struct TestMXMFMA
|
||||
{
|
||||
case 0:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
|
||||
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}}); // 1/64
|
||||
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{0.015625f}}); // 1/6
|
||||
// NOTE: not all numbers are representable in FP8, BF8, etc.
|
||||
// 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 16 18 20 20 20 22 24 24 24 26 28 28 28 30 32
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
|
||||
break;
|
||||
// b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
|
||||
// a_scales.GenerateTensorValue(
|
||||
// GeneratorTensor_1<ScaleType>{ScaleType{1.0f}}); // 1/64
|
||||
// // NOTE: not all numbers are representable in FP8, BF8, etc.
|
||||
// // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 16 18 20 20 20 22 24 24 24 26 28 28 28 30
|
||||
// 32 a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 1>{});
|
||||
// b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
|
||||
// break;
|
||||
case 1:
|
||||
// results in C = {K}
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
|
||||
@@ -1158,11 +1151,9 @@ struct TestMXMFMA
|
||||
case 2:
|
||||
// expect small round off errors
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
|
||||
a_scales.GenerateTensorValue(
|
||||
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
|
||||
|
||||
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{512.0f}});
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
|
||||
b_scales.GenerateTensorValue(GeneratorTensor_2<ScaleType>{126, 129});
|
||||
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f / 512}});
|
||||
break;
|
||||
case 3:
|
||||
// expect small round off errors
|
||||
@@ -1343,15 +1334,10 @@ struct TestMFMA
|
||||
switch(init)
|
||||
{
|
||||
case 0:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{0.015625f});
|
||||
// NOTE: not all numbers are representable in FP8, BF8, etc.
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
break;
|
||||
// case 0:
|
||||
// b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
|
||||
// // NOTE: not all numbers are representable in FP8, BF8, etc.
|
||||
// a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 1>{});
|
||||
// break;
|
||||
case 1:
|
||||
// results in C = {K}
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
|
||||
|
||||
Reference in New Issue
Block a user