Simulate TF32 with BF16x3 (#3142)

* tf32:bf16x3: use bf16x3 to emulate tf32 gemm

* change blockwiseGemm to demo bf16x3

* temp push

* self review

* self review

* fix multi-device compile error

* bug fix

* code refactor

* limit to gfx950

* enhance gemm gfx942 threshold

* lower change from blockwise to warpwise

* refactor code

* refactor code

* error fix

* change threshold

* bug fix

* fix threshold error

* change host reference implementation to match the device implementation

* bug fix

* bug fix

* refactor code

* fix clang-format fail

* code refine
This commit is contained in:
yinglu
2025-11-14 08:21:09 +08:00
committed by GitHub
parent f2cfc6b94e
commit 2a73eb3bc0
16 changed files with 419 additions and 49 deletions

View File

@@ -14,6 +14,8 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/fill.hpp"
@@ -92,7 +94,8 @@ struct ReferenceConvFwd : public device::BaseOperator
in_right_pads_{input_right_pads},
in_element_op_{in_element_op},
wei_element_op_{wei_element_op},
out_element_op_{out_element_op}
out_element_op_{out_element_op},
device_name_{ck::get_device_name()}
{
}
@@ -112,6 +115,7 @@ struct ReferenceConvFwd : public device::BaseOperator
InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_;
OutElementwiseOperation out_element_op_;
::std::string device_name_; // the device which this conv is compared with
};
struct Invoker : public device::BaseInvoker
@@ -251,10 +255,39 @@ struct ReferenceConvFwd : public device::BaseOperator
x);
if constexpr(is_same_v<ComputeDataType, ck::tf32_t>)
{
v_acc += ck::type_convert<float>(
ck::type_convert<ComputeDataType>(v_in)) *
ck::type_convert<float>(
ck::type_convert<ComputeDataType>(v_wei));
if(arg.device_name_ == "gfx942")
{
v_acc += ck::type_convert<float>(
ck::type_convert<ck::tf32_t>(v_in)) *
ck::type_convert<float>(
ck::type_convert<ck::tf32_t>(v_wei));
}
else if(arg.device_name_ == "gfx950")
{
ck::bhalf_t v_in_bf16_big =
ck::type_convert<ck::bhalf_t>(v_in);
ck::bhalf_t v_in_bf16_small =
ck::type_convert<ck::bhalf_t>(
v_in - type_convert<float>(v_in_bf16_big));
ck::bhalf_t v_wei_bf16_big =
ck::type_convert<ck::bhalf_t>(v_wei);
ck::bhalf_t v_wei_bf16_small =
ck::type_convert<ck::bhalf_t>(
v_wei - type_convert<float>(v_wei_bf16_big));
v_acc += ck::type_convert<float>(v_in_bf16_big) *
ck::type_convert<float>(v_wei_bf16_small) +
ck::type_convert<float>(v_in_bf16_small) *
ck::type_convert<float>(v_wei_bf16_big) +
ck::type_convert<float>(v_in_bf16_big) *
ck::type_convert<float>(v_wei_bf16_big);
}
else
{
throw std::runtime_error(
"Unsupported device: " + arg.device_name_ +
" for tf32 computation");
}
}
else
{
@@ -350,10 +383,41 @@ struct ReferenceConvFwd : public device::BaseOperator
x);
if constexpr(is_same_v<ComputeDataType, ck::tf32_t>)
{
v_acc += ck::type_convert<float>(
ck::type_convert<ComputeDataType>(v_in)) *
ck::type_convert<float>(
ck::type_convert<ComputeDataType>(v_wei));
if(arg.device_name_ == "gfx942")
{
v_acc += ck::type_convert<float>(
ck::type_convert<ck::tf32_t>(v_in)) *
ck::type_convert<float>(
ck::type_convert<ck::tf32_t>(v_wei));
}
else if(arg.device_name_ == "gfx950")
{
ck::bhalf_t v_in_bf16_big =
ck::type_convert<ck::bhalf_t>(v_in);
ck::bhalf_t v_in_bf16_small =
ck::type_convert<ck::bhalf_t>(
v_in - type_convert<float>(v_in_bf16_big));
ck::bhalf_t v_wei_bf16_big =
ck::type_convert<ck::bhalf_t>(v_wei);
ck::bhalf_t v_wei_bf16_small =
ck::type_convert<ck::bhalf_t>(
v_wei -
type_convert<float>(v_wei_bf16_big));
v_acc +=
ck::type_convert<float>(v_in_bf16_big) *
ck::type_convert<float>(v_wei_bf16_small) +
ck::type_convert<float>(v_in_bf16_small) *
ck::type_convert<float>(v_wei_bf16_big) +
ck::type_convert<float>(v_in_bf16_big) *
ck::type_convert<float>(v_wei_bf16_big);
}
else
{
throw std::runtime_error(
"Unsupported device: " + arg.device_name_ +
" for tf32 computation");
}
}
else
{

View File

@@ -6,6 +6,7 @@
#include <iostream>
#include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
@@ -45,7 +46,8 @@ struct ReferenceGemm : public device::BaseOperator
c_m_n_{c_m_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
c_element_op_{c_element_op},
device_name_{ck::get_device_name()}
{
}
@@ -56,6 +58,7 @@ struct ReferenceGemm : public device::BaseOperator
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
::std::string device_name_; // the device which this gemm is compared with
};
// Invoker
@@ -142,12 +145,37 @@ struct ReferenceGemm : public device::BaseOperator
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
}
if constexpr(is_same_v<ComputeTypeA, ComputeTypeB> &&
is_same_v<ComputeTypeA, ck::tf32_t>)
{ // only for tf32 now
v_acc +=
ck::type_convert<AccDataType>(ck::type_convert<ComputeTypeA>(v_a)) *
ck::type_convert<AccDataType>(ck::type_convert<ComputeTypeB>(v_b));
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
is_same_v<CDataType, float> && is_same_v<AccDataType, float> &&
is_same_v<ComputeTypeA, ck::tf32_t> &&
is_same_v<ComputeTypeB, ck::tf32_t>)
{
if(arg.device_name_ == "gfx942")
{
v_acc +=
ck::type_convert<AccDataType>(ck::type_convert<ck::tf32_t>(v_a)) *
ck::type_convert<AccDataType>(ck::type_convert<ck::tf32_t>(v_b));
}
else if(arg.device_name_ == "gfx950")
{
ck::bhalf_t v_a_bf16_big = ck::type_convert<ck::bhalf_t>(v_a);
ck::bhalf_t v_a_bf16_small = ck::type_convert<ck::bhalf_t>(
v_a - type_convert<float>(v_a_bf16_big));
ck::bhalf_t v_b_bf16_big = ck::type_convert<ck::bhalf_t>(v_b);
ck::bhalf_t v_b_bf16_small = ck::type_convert<ck::bhalf_t>(
v_b - type_convert<float>(v_b_bf16_big));
v_acc += ck::type_convert<AccDataType>(v_a_bf16_big) *
ck::type_convert<AccDataType>(v_b_bf16_small) +
ck::type_convert<AccDataType>(v_a_bf16_small) *
ck::type_convert<AccDataType>(v_b_bf16_big) +
ck::type_convert<AccDataType>(v_a_bf16_big) *
ck::type_convert<AccDataType>(v_b_bf16_big);
}
else
{
throw std::runtime_error("Unsupported device: " + arg.device_name_);
}
}
else
{

View File

@@ -82,9 +82,27 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
// multiply and accumulate
if constexpr(is_same_v<ComputeTypeA, ComputeTypeB> &&
is_same_v<ComputeTypeA, ck::tf32_t>)
{ // only for tf32 now
v_acc += ck::type_convert<AccDataType>(ck::type_convert<ComputeTypeA>(v_a)) *
ck::type_convert<AccDataType>(ck::type_convert<ComputeTypeB>(v_b));
{
#if defined(__gfx942__)
v_acc += ck::type_convert<AccDataType>(ck::type_convert<ck::tf32_t>(v_a)) *
ck::type_convert<AccDataType>(ck::type_convert<ck::tf32_t>(v_b));
#elif defined(__gfx950__)
ck::bhalf_t v_a_bf16_big = ck::type_convert<ck::bhalf_t>(v_a);
ck::bhalf_t v_a_bf16_small =
ck::type_convert<ck::bhalf_t>(v_a - type_convert<float>(v_a_bf16_big));
ck::bhalf_t v_b_bf16_big = ck::type_convert<ck::bhalf_t>(v_b);
ck::bhalf_t v_b_bf16_small =
ck::type_convert<ck::bhalf_t>(v_b - type_convert<float>(v_b_bf16_big));
v_acc += ck::type_convert<AccDataType>(v_a_bf16_big) *
ck::type_convert<AccDataType>(v_b_bf16_small) +
ck::type_convert<AccDataType>(v_a_bf16_small) *
ck::type_convert<AccDataType>(v_b_bf16_big) +
ck::type_convert<AccDataType>(v_a_bf16_big) *
ck::type_convert<AccDataType>(v_b_bf16_big);
#else
v_acc += type_convert<AccDataType>(v_a) * type_convert<AccDataType>(v_b);
#endif
}
else
{