mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Add f4x2 init mode
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -81,6 +81,18 @@ struct GeneratorTensor_1<ck::f4_t>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_1<ck::f4x2_pk_t>
|
||||
{
|
||||
float value = 1.0;
|
||||
|
||||
template <typename... Is>
|
||||
ck::f4x2_pk_t operator()(Is...)
|
||||
{
|
||||
return ck::f4x2_pk_t{ck::type_convert<ck::f4x2_t>(ck::float2_t{value, value})};
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_1<int8_t>
|
||||
{
|
||||
@@ -209,6 +221,21 @@ struct GeneratorTensor_2<ck::f4_t>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_2<ck::f4x2_pk_t>
|
||||
{
|
||||
int min_value = 0;
|
||||
int max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ck::f4x2_pk_t operator()(Is...)
|
||||
{
|
||||
float tmp0 = (std::rand() % (max_value - min_value)) + min_value;
|
||||
float tmp1 = (std::rand() % (max_value - min_value)) + min_value;
|
||||
return ck::f4x2_pk_t{ck::type_convert<ck::f4x2_t>(ck::float2_t{tmp0, tmp1})};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct GeneratorTensor_3
|
||||
{
|
||||
@@ -296,6 +323,25 @@ struct GeneratorTensor_3<ck::f4_t>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_3<ck::f4x2_pk_t>
|
||||
{
|
||||
float min_value = 0;
|
||||
float max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ck::f4x2_pk_t operator()(Is...)
|
||||
{
|
||||
float tmp0 = float(std::rand()) / float(RAND_MAX);
|
||||
float tmp1 = float(std::rand()) / float(RAND_MAX);
|
||||
|
||||
float fp32_tmp0 = min_value + tmp0 * (max_value - min_value);
|
||||
float fp32_tmp1 = min_value + tmp1 * (max_value - min_value);
|
||||
|
||||
return ck::f4x2_pk_t{ck::type_convert<ck::f4x2_t>(ck::float2_t{fp32_tmp0, fp32_tmp1})};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct GeneratorTensor_4
|
||||
{
|
||||
|
||||
@@ -127,7 +127,7 @@ TEST(MXMFMA, MXFP8MFMA32x32x64)
|
||||
|
||||
TEST(MXMFMA, MXFP4MFMA16x16x128)
|
||||
{
|
||||
auto AB_init = 7;
|
||||
auto AB_init = 4;
|
||||
auto pass =
|
||||
run_mxmfma_test<f4x2_pk_t, f4x2_pk_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
@@ -135,7 +135,7 @@ TEST(MXMFMA, MXFP4MFMA16x16x128)
|
||||
|
||||
TEST(MXMFMA, MXFP4MFMA32x32x64)
|
||||
{
|
||||
auto AB_init = 7;
|
||||
auto AB_init = 4;
|
||||
auto pass =
|
||||
run_mxmfma_test<f4x2_pk_t, f4x2_pk_t, half_t, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
|
||||
EXPECT_TRUE(pass);
|
||||
|
||||
@@ -997,7 +997,6 @@ struct TestMXMFMA
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
|
||||
b_scales.GenerateTensorValue(GeneratorTensor_2<ScaleType>{126, 129});
|
||||
break;
|
||||
|
||||
case 3:
|
||||
// expect small round off errors
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_4<ADataType>(0, 1));
|
||||
@@ -1007,6 +1006,14 @@ struct TestMXMFMA
|
||||
b_scales.GenerateTensorValue(
|
||||
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
|
||||
break;
|
||||
case 4:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1., 1.});
|
||||
a_scales.GenerateTensorValue(
|
||||
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
|
||||
b_n_k.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1., 1.});
|
||||
b_scales.GenerateTensorValue(
|
||||
GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
|
||||
break;
|
||||
default:
|
||||
// all initial values are representable in FP8, BF8
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6}); // Z[-5,5]
|
||||
|
||||
Reference in New Issue
Block a user