mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Add rounding for float to bf16 conversion as default (#1812)
* Add rounding for float to bf16 conversion
* Add bhalf test
* Add inf test bhalf
* Refactor
* update cmake
* Fixes
[ROCm/composable_kernel commit: 7790e8c3f7]
This commit is contained in:
@@ -48,9 +48,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
|
||||
add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16)
|
||||
|
||||
add_example_executable(example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_rtn)
|
||||
|
||||
add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_int8)
|
||||
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
|
||||
|
||||
using ADataType = ck::bhalf_t;
|
||||
using BDataType = ck::bhalf_t;
|
||||
using CDataType = ck::bhalf_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::ConvertBF16RTN;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
using ReferenceComputeType = float;
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
ReferenceComputeType,
|
||||
ReferenceComputeType>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -155,6 +155,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
|
||||
// LDS direct loads using inline assembly
|
||||
#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0
|
||||
|
||||
// set rounding to nearest even as default for bf16 conversions
|
||||
#define CK_USE_RNE_BF16_CONVERSION 1
|
||||
|
||||
// set rounding to nearest even as default for f8 conversions
|
||||
#define CK_USE_SR_F8_CONVERSION 0
|
||||
|
||||
|
||||
@@ -14,6 +14,41 @@ namespace ck {
|
||||
#define __gfx94__
|
||||
#endif
|
||||
|
||||
// Declare a template function for bf16 conversion using RTN
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
|
||||
|
||||
// Convert fp32 to bf16 with RTN if higher precision is needed
|
||||
template <>
|
||||
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
|
||||
{
|
||||
// Nan check
|
||||
if(x != x)
|
||||
{
|
||||
return uint16_t(0x7FC0);
|
||||
}
|
||||
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
uint32_t int32;
|
||||
} u = {x};
|
||||
|
||||
const uint32_t first_bf16_mantisa_bit = ((u.int32 >> 16) & 1);
|
||||
constexpr uint32_t rounding_bias = uint32_t((1 << 15) - 1);
|
||||
|
||||
return uint16_t((u.int32 + first_bf16_mantisa_bit + rounding_bias) >> 16);
|
||||
}
|
||||
|
||||
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
|
||||
template <>
|
||||
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
|
||||
{
|
||||
float x_fp32 = static_cast<float>(x);
|
||||
|
||||
return bf16_convert_rtn<bhalf_t>(x_fp32);
|
||||
}
|
||||
|
||||
// Convert X to Y, both X and Y are non-const data types.
|
||||
template <typename Y,
|
||||
typename X,
|
||||
@@ -51,17 +86,15 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t
|
||||
return u.fp32;
|
||||
}
|
||||
|
||||
// convert fp32 to bfp16
|
||||
// convert fp32 to bfp16, round to nearest even
|
||||
template <>
|
||||
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
|
||||
{
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
uint32_t int32;
|
||||
} u = {x};
|
||||
|
||||
#if CK_USE_RNE_BF16_CONVERSION
|
||||
return bf16_convert_rtn<bhalf_t>(x);
|
||||
#else
|
||||
return uint16_t(u.int32 >> 16);
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert bfp16 to fp16 via fp32
|
||||
@@ -615,60 +648,4 @@ inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array
|
||||
}
|
||||
}
|
||||
|
||||
// Declare a template function for bf16 conversion using RTN
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
|
||||
|
||||
// Convert fp32 to bf16 with RTN if higher precision is needed
|
||||
template <>
|
||||
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
|
||||
{
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
uint32_t int32;
|
||||
} u = {x};
|
||||
|
||||
// When the exponent bits are not all 1s, then the value is zero, normal,
|
||||
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
|
||||
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
|
||||
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
|
||||
// least significant bits of the float mantissa are greater than 0x8000,
|
||||
// or if they are equal to 0x8000 and the least significant bit of the
|
||||
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
|
||||
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
|
||||
// has the value 0x7f, then incrementing it causes it to become 0x00 and
|
||||
// the exponent is incremented by one, which is the next higher FP value
|
||||
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
|
||||
// with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
|
||||
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
|
||||
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
|
||||
// incrementing it causes it to become an exponent of 0xFF and a mantissa
|
||||
// of 0x00, which is Inf, the next higher value to the unrounded value.
|
||||
bool flag0 = ~u.int32 & 0x7f800000;
|
||||
|
||||
// When all of the exponent bits are 1, the value is Inf or NaN.
|
||||
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
|
||||
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
|
||||
// bit being 1. Signaling NaN is indicated by the most significant
|
||||
// mantissa bit being 0 but some other bit(s) being 1. If any of the
|
||||
// lower 16 bits of the mantissa are 1, we set the least significant bit
|
||||
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
|
||||
// the bfloat16's mantissa bits are all 0.
|
||||
bool flag1 = !flag0 && (u.int32 & 0xffff);
|
||||
|
||||
u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
|
||||
u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN
|
||||
|
||||
return uint16_t(u.int32 >> 16);
|
||||
}
|
||||
|
||||
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
|
||||
template <>
|
||||
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
|
||||
{
|
||||
float x_fp32 = static_cast<float>(x);
|
||||
|
||||
return bf16_convert_rtn<bhalf_t>(x_fp32);
|
||||
}
|
||||
} // namespace ck
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -73,39 +73,9 @@ struct ReferencefpAintBGemm : public device::BaseOperator
|
||||
ScaleDataType v_scale;
|
||||
ADataType v_converted_b;
|
||||
|
||||
// use PassThrough instead of ConvertBF16RTN for reference calculation
|
||||
if constexpr(is_same_v<AElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::ConvertBF16RTN>)
|
||||
{
|
||||
ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
|
||||
}
|
||||
else
|
||||
{
|
||||
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
|
||||
}
|
||||
|
||||
// same for B matrix
|
||||
if constexpr(is_same_v<BElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::ConvertBF16RTN>)
|
||||
{
|
||||
ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
|
||||
}
|
||||
else
|
||||
{
|
||||
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
|
||||
}
|
||||
|
||||
// same for scale matrix
|
||||
if constexpr(is_same_v<BElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::ConvertBF16RTN>)
|
||||
{
|
||||
ck::tensor_operation::element_wise::PassThrough{}(v_scale,
|
||||
arg.scale_k_n_(k, n));
|
||||
}
|
||||
else
|
||||
{
|
||||
arg.b_element_op_(v_scale, arg.scale_k_n_(k, n));
|
||||
}
|
||||
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
|
||||
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
|
||||
arg.b_element_op_(v_scale, arg.scale_k_n_(k, n));
|
||||
|
||||
v_converted_b = type_convert<ADataType>(v_b) * v_scale;
|
||||
v_acc += ck::type_convert<AccDataType>(v_a) *
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -68,13 +68,7 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
// use PassThrough instead of ConvertBF16RTN for reference calculation
|
||||
if constexpr(is_same_v<AElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::ConvertBF16RTN>)
|
||||
{
|
||||
ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
|
||||
}
|
||||
else if constexpr(is_same_v<ADataType, pk_i4_t>)
|
||||
if constexpr(is_same_v<ADataType, pk_i4_t>)
|
||||
{
|
||||
uint8_t i4x2 = arg.a_m_k_(m, k).data;
|
||||
int8_t i4 = 0;
|
||||
@@ -89,13 +83,8 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
{
|
||||
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
|
||||
}
|
||||
// same for B matrix
|
||||
if constexpr(is_same_v<BElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::ConvertBF16RTN>)
|
||||
{
|
||||
ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
|
||||
}
|
||||
else if constexpr(is_same_v<BDataType, pk_i4_t>)
|
||||
|
||||
if constexpr(is_same_v<BDataType, pk_i4_t>)
|
||||
{
|
||||
uint8_t i4x2 = arg.b_k_n_(k, n).data;
|
||||
int8_t i4 = 0;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -74,26 +74,8 @@ struct ReferenceGemmMultipleD : public device::BaseOperator
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
// use PassThrough instead of ConvertBF16RTN for reference calculation
|
||||
if constexpr(is_same_v<AElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::ConvertBF16RTN>)
|
||||
{
|
||||
ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
|
||||
}
|
||||
else
|
||||
{
|
||||
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
|
||||
}
|
||||
// same for B matrix
|
||||
if constexpr(is_same_v<BElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::ConvertBF16RTN>)
|
||||
{
|
||||
ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
|
||||
}
|
||||
else
|
||||
{
|
||||
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
|
||||
}
|
||||
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
|
||||
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
|
||||
|
||||
v_acc +=
|
||||
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
|
||||
|
||||
@@ -49,3 +49,4 @@ if(result EQUAL 0)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_type_convert_const type_convert_const.cpp)
|
||||
add_gtest_executable(test_bhalf test_bhalf.cpp)
|
||||
|
||||
48
test/data_type/test_bhalf.cpp
Normal file
48
test/data_type/test_bhalf.cpp
Normal file
@@ -0,0 +1,48 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
|
||||
using ck::bhalf_t;
|
||||
using ck::type_convert;
|
||||
|
||||
TEST(BHALF_T, Nan)
|
||||
{
|
||||
const uint16_t binary_bhalf_nan = 0x7FC0;
|
||||
const bhalf_t bhalf_nan = ck::bit_cast<bhalf_t>(binary_bhalf_nan);
|
||||
EXPECT_EQ(bhalf_nan, type_convert<bhalf_t>(ck::NumericLimits<float>::QuietNaN()));
|
||||
}
|
||||
|
||||
TEST(BHALF_T, Inf)
|
||||
{
|
||||
const uint16_t binary_bhalf_inf = 0x7F80;
|
||||
const bhalf_t bhalf_inf = ck::bit_cast<bhalf_t>(binary_bhalf_inf);
|
||||
EXPECT_EQ(bhalf_inf, type_convert<bhalf_t>(ck::NumericLimits<float>::Infinity()));
|
||||
}
|
||||
|
||||
TEST(BHALF_T, MantisaOverflow)
|
||||
{
|
||||
const float abs_tol = std::pow(2, -7);
|
||||
const uint32_t val = 0x81FFFFFF;
|
||||
const float float_val = ck::bit_cast<float>(val);
|
||||
|
||||
ASSERT_NEAR(float_val, type_convert<float>(type_convert<bhalf_t>(float_val)), abs_tol);
|
||||
}
|
||||
|
||||
TEST(BHALF_T, ExpOverflow)
|
||||
{
|
||||
const uint32_t val = 0xFF800000;
|
||||
const float float_val = ck::bit_cast<float>(val);
|
||||
ASSERT_EQ(type_convert<float>(type_convert<bhalf_t>(float_val)), float_val);
|
||||
}
|
||||
|
||||
TEST(BHALF_T, MantisaExpOverflow)
|
||||
{
|
||||
const uint32_t val = 0xFFFFFFFF;
|
||||
const float float_val = ck::bit_cast<float>(val);
|
||||
|
||||
ASSERT_TRUE(std::isnan(float_val));
|
||||
ASSERT_TRUE(std::isnan(type_convert<float>(type_convert<bhalf_t>(float_val))));
|
||||
}
|
||||
Reference in New Issue
Block a user