mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Add rounding for float to bf16 conversion as default (#1812)
* Add rounding for float to bf16 conversion * Add bhalf test * Add inf test bhalf * Refactor * update cmake * Fixes
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -73,39 +73,9 @@ struct ReferencefpAintBGemm : public device::BaseOperator
|
||||
ScaleDataType v_scale;
|
||||
ADataType v_converted_b;
|
||||
|
||||
// use PassThrough instead of ConvertBF16RTN for reference calculation
|
||||
if constexpr(is_same_v<AElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::ConvertBF16RTN>)
|
||||
{
|
||||
ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
|
||||
}
|
||||
else
|
||||
{
|
||||
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
|
||||
}
|
||||
|
||||
// same for B matrix
|
||||
if constexpr(is_same_v<BElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::ConvertBF16RTN>)
|
||||
{
|
||||
ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
|
||||
}
|
||||
else
|
||||
{
|
||||
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
|
||||
}
|
||||
|
||||
// same for scale matrix
|
||||
if constexpr(is_same_v<BElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::ConvertBF16RTN>)
|
||||
{
|
||||
ck::tensor_operation::element_wise::PassThrough{}(v_scale,
|
||||
arg.scale_k_n_(k, n));
|
||||
}
|
||||
else
|
||||
{
|
||||
arg.b_element_op_(v_scale, arg.scale_k_n_(k, n));
|
||||
}
|
||||
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
|
||||
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
|
||||
arg.b_element_op_(v_scale, arg.scale_k_n_(k, n));
|
||||
|
||||
v_converted_b = type_convert<ADataType>(v_b) * v_scale;
|
||||
v_acc += ck::type_convert<AccDataType>(v_a) *
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -68,13 +68,7 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
// use PassThrough instead of ConvertBF16RTN for reference calculation
|
||||
if constexpr(is_same_v<AElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::ConvertBF16RTN>)
|
||||
{
|
||||
ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
|
||||
}
|
||||
else if constexpr(is_same_v<ADataType, pk_i4_t>)
|
||||
if constexpr(is_same_v<ADataType, pk_i4_t>)
|
||||
{
|
||||
uint8_t i4x2 = arg.a_m_k_(m, k).data;
|
||||
int8_t i4 = 0;
|
||||
@@ -89,13 +83,8 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
{
|
||||
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
|
||||
}
|
||||
// same for B matrix
|
||||
if constexpr(is_same_v<BElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::ConvertBF16RTN>)
|
||||
{
|
||||
ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
|
||||
}
|
||||
else if constexpr(is_same_v<BDataType, pk_i4_t>)
|
||||
|
||||
if constexpr(is_same_v<BDataType, pk_i4_t>)
|
||||
{
|
||||
uint8_t i4x2 = arg.b_k_n_(k, n).data;
|
||||
int8_t i4 = 0;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -74,26 +74,8 @@ struct ReferenceGemmMultipleD : public device::BaseOperator
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
// use PassThrough instead of ConvertBF16RTN for reference calculation
|
||||
if constexpr(is_same_v<AElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::ConvertBF16RTN>)
|
||||
{
|
||||
ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
|
||||
}
|
||||
else
|
||||
{
|
||||
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
|
||||
}
|
||||
// same for B matrix
|
||||
if constexpr(is_same_v<BElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::ConvertBF16RTN>)
|
||||
{
|
||||
ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
|
||||
}
|
||||
else
|
||||
{
|
||||
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
|
||||
}
|
||||
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
|
||||
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
|
||||
|
||||
v_acc +=
|
||||
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
|
||||
|
||||
Reference in New Issue
Block a user