mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Simulate TF32 with BF16x3 (#3142)
* tf32:bf16x3: use bf16x3 to emulate tf32 gemm * change blockwiseGemm to demo bf16x3 * temp push * self review * self review * fix multi-device compile error * bug fix * code refactor * limit to gfx950 * enhance gemm gfx942 threshold * lower change from blockwise to warpwise * refactor code * refactor code * error fix * change threshold * bug fix * fix threshold error * change host reference implementation to match the device * bug fix * bug fix * code refactor * fix clang-format failure * code refinement
This commit is contained in:
@@ -14,6 +14,8 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
@@ -92,7 +94,8 @@ struct ReferenceConvFwd : public device::BaseOperator
|
||||
in_right_pads_{input_right_pads},
|
||||
in_element_op_{in_element_op},
|
||||
wei_element_op_{wei_element_op},
|
||||
out_element_op_{out_element_op}
|
||||
out_element_op_{out_element_op},
|
||||
device_name_{ck::get_device_name()}
|
||||
{
|
||||
}
|
||||
|
||||
@@ -112,6 +115,7 @@ struct ReferenceConvFwd : public device::BaseOperator
|
||||
InElementwiseOperation in_element_op_;
|
||||
WeiElementwiseOperation wei_element_op_;
|
||||
OutElementwiseOperation out_element_op_;
|
||||
::std::string device_name_; // the device which this conv is compared with
|
||||
};
|
||||
|
||||
struct Invoker : public device::BaseInvoker
|
||||
@@ -251,10 +255,39 @@ struct ReferenceConvFwd : public device::BaseOperator
|
||||
x);
|
||||
if constexpr(is_same_v<ComputeDataType, ck::tf32_t>)
|
||||
{
|
||||
v_acc += ck::type_convert<float>(
|
||||
ck::type_convert<ComputeDataType>(v_in)) *
|
||||
ck::type_convert<float>(
|
||||
ck::type_convert<ComputeDataType>(v_wei));
|
||||
if(arg.device_name_ == "gfx942")
|
||||
{
|
||||
v_acc += ck::type_convert<float>(
|
||||
ck::type_convert<ck::tf32_t>(v_in)) *
|
||||
ck::type_convert<float>(
|
||||
ck::type_convert<ck::tf32_t>(v_wei));
|
||||
}
|
||||
else if(arg.device_name_ == "gfx950")
|
||||
{
|
||||
ck::bhalf_t v_in_bf16_big =
|
||||
ck::type_convert<ck::bhalf_t>(v_in);
|
||||
ck::bhalf_t v_in_bf16_small =
|
||||
ck::type_convert<ck::bhalf_t>(
|
||||
v_in - type_convert<float>(v_in_bf16_big));
|
||||
ck::bhalf_t v_wei_bf16_big =
|
||||
ck::type_convert<ck::bhalf_t>(v_wei);
|
||||
ck::bhalf_t v_wei_bf16_small =
|
||||
ck::type_convert<ck::bhalf_t>(
|
||||
v_wei - type_convert<float>(v_wei_bf16_big));
|
||||
|
||||
v_acc += ck::type_convert<float>(v_in_bf16_big) *
|
||||
ck::type_convert<float>(v_wei_bf16_small) +
|
||||
ck::type_convert<float>(v_in_bf16_small) *
|
||||
ck::type_convert<float>(v_wei_bf16_big) +
|
||||
ck::type_convert<float>(v_in_bf16_big) *
|
||||
ck::type_convert<float>(v_wei_bf16_big);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Unsupported device: " + arg.device_name_ +
|
||||
" for tf32 computation");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -350,10 +383,41 @@ struct ReferenceConvFwd : public device::BaseOperator
|
||||
x);
|
||||
if constexpr(is_same_v<ComputeDataType, ck::tf32_t>)
|
||||
{
|
||||
v_acc += ck::type_convert<float>(
|
||||
ck::type_convert<ComputeDataType>(v_in)) *
|
||||
ck::type_convert<float>(
|
||||
ck::type_convert<ComputeDataType>(v_wei));
|
||||
if(arg.device_name_ == "gfx942")
|
||||
{
|
||||
v_acc += ck::type_convert<float>(
|
||||
ck::type_convert<ck::tf32_t>(v_in)) *
|
||||
ck::type_convert<float>(
|
||||
ck::type_convert<ck::tf32_t>(v_wei));
|
||||
}
|
||||
else if(arg.device_name_ == "gfx950")
|
||||
{
|
||||
ck::bhalf_t v_in_bf16_big =
|
||||
ck::type_convert<ck::bhalf_t>(v_in);
|
||||
ck::bhalf_t v_in_bf16_small =
|
||||
ck::type_convert<ck::bhalf_t>(
|
||||
v_in - type_convert<float>(v_in_bf16_big));
|
||||
ck::bhalf_t v_wei_bf16_big =
|
||||
ck::type_convert<ck::bhalf_t>(v_wei);
|
||||
ck::bhalf_t v_wei_bf16_small =
|
||||
ck::type_convert<ck::bhalf_t>(
|
||||
v_wei -
|
||||
type_convert<float>(v_wei_bf16_big));
|
||||
|
||||
v_acc +=
|
||||
ck::type_convert<float>(v_in_bf16_big) *
|
||||
ck::type_convert<float>(v_wei_bf16_small) +
|
||||
ck::type_convert<float>(v_in_bf16_small) *
|
||||
ck::type_convert<float>(v_wei_bf16_big) +
|
||||
ck::type_convert<float>(v_in_bf16_big) *
|
||||
ck::type_convert<float>(v_wei_bf16_big);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Unsupported device: " + arg.device_name_ +
|
||||
" for tf32 computation");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
@@ -45,7 +46,8 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
c_m_n_{c_m_n},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
c_element_op_{c_element_op},
|
||||
device_name_{ck::get_device_name()}
|
||||
{
|
||||
}
|
||||
|
||||
@@ -56,6 +58,7 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
::std::string device_name_; // the device which this gemm is compared with
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -142,12 +145,37 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<ComputeTypeA, ComputeTypeB> &&
|
||||
is_same_v<ComputeTypeA, ck::tf32_t>)
|
||||
{ // only for tf32 now
|
||||
v_acc +=
|
||||
ck::type_convert<AccDataType>(ck::type_convert<ComputeTypeA>(v_a)) *
|
||||
ck::type_convert<AccDataType>(ck::type_convert<ComputeTypeB>(v_b));
|
||||
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
|
||||
is_same_v<CDataType, float> && is_same_v<AccDataType, float> &&
|
||||
is_same_v<ComputeTypeA, ck::tf32_t> &&
|
||||
is_same_v<ComputeTypeB, ck::tf32_t>)
|
||||
{
|
||||
if(arg.device_name_ == "gfx942")
|
||||
{
|
||||
v_acc +=
|
||||
ck::type_convert<AccDataType>(ck::type_convert<ck::tf32_t>(v_a)) *
|
||||
ck::type_convert<AccDataType>(ck::type_convert<ck::tf32_t>(v_b));
|
||||
}
|
||||
else if(arg.device_name_ == "gfx950")
|
||||
{
|
||||
ck::bhalf_t v_a_bf16_big = ck::type_convert<ck::bhalf_t>(v_a);
|
||||
ck::bhalf_t v_a_bf16_small = ck::type_convert<ck::bhalf_t>(
|
||||
v_a - type_convert<float>(v_a_bf16_big));
|
||||
ck::bhalf_t v_b_bf16_big = ck::type_convert<ck::bhalf_t>(v_b);
|
||||
ck::bhalf_t v_b_bf16_small = ck::type_convert<ck::bhalf_t>(
|
||||
v_b - type_convert<float>(v_b_bf16_big));
|
||||
|
||||
v_acc += ck::type_convert<AccDataType>(v_a_bf16_big) *
|
||||
ck::type_convert<AccDataType>(v_b_bf16_small) +
|
||||
ck::type_convert<AccDataType>(v_a_bf16_small) *
|
||||
ck::type_convert<AccDataType>(v_b_bf16_big) +
|
||||
ck::type_convert<AccDataType>(v_a_bf16_big) *
|
||||
ck::type_convert<AccDataType>(v_b_bf16_big);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported device: " + arg.device_name_);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -82,9 +82,27 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
// multiply and accumulate
|
||||
if constexpr(is_same_v<ComputeTypeA, ComputeTypeB> &&
|
||||
is_same_v<ComputeTypeA, ck::tf32_t>)
|
||||
{ // only for tf32 now
|
||||
v_acc += ck::type_convert<AccDataType>(ck::type_convert<ComputeTypeA>(v_a)) *
|
||||
ck::type_convert<AccDataType>(ck::type_convert<ComputeTypeB>(v_b));
|
||||
{
|
||||
#if defined(__gfx942__)
|
||||
v_acc += ck::type_convert<AccDataType>(ck::type_convert<ck::tf32_t>(v_a)) *
|
||||
ck::type_convert<AccDataType>(ck::type_convert<ck::tf32_t>(v_b));
|
||||
#elif defined(__gfx950__)
|
||||
ck::bhalf_t v_a_bf16_big = ck::type_convert<ck::bhalf_t>(v_a);
|
||||
ck::bhalf_t v_a_bf16_small =
|
||||
ck::type_convert<ck::bhalf_t>(v_a - type_convert<float>(v_a_bf16_big));
|
||||
ck::bhalf_t v_b_bf16_big = ck::type_convert<ck::bhalf_t>(v_b);
|
||||
ck::bhalf_t v_b_bf16_small =
|
||||
ck::type_convert<ck::bhalf_t>(v_b - type_convert<float>(v_b_bf16_big));
|
||||
|
||||
v_acc += ck::type_convert<AccDataType>(v_a_bf16_big) *
|
||||
ck::type_convert<AccDataType>(v_b_bf16_small) +
|
||||
ck::type_convert<AccDataType>(v_a_bf16_small) *
|
||||
ck::type_convert<AccDataType>(v_b_bf16_big) +
|
||||
ck::type_convert<AccDataType>(v_a_bf16_big) *
|
||||
ck::type_convert<AccDataType>(v_b_bf16_big);
|
||||
#else
|
||||
v_acc += type_convert<AccDataType>(v_a) * type_convert<AccDataType>(v_b);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user