mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 23:05:54 +00:00
Reduction for int8 and bfloat16 (#125)
* Use thread cluster descriptor and explicit M_K 2d descriptor to simply Blockwise Reduction * Change by replacing ReduceDims by NumReduceDims as Device Reduce interface template parameter * Rename the folder name for the pool2d and reduce examples * Update to reduction test scripts * Add Readme for pool2d_fwd and reduce_blockwise examples * Add support for int8_t reduction (ADD/AVG, MIN/MAX/AMAX) * Tiny fix in reduce profiler and tiny update in reduce testing scripts * Tiny fix in testing script profile_reduce_no_index.sh * Tiny fix in testing script profile_reduce_no_index.sh * Add support for bfp16 reduction (using bhalf_t = ushort) * Tiny fix in amd_buffer_addressing.hpp * Tiny change in script/profile_reduce_with_index.sh * Use AccDataType for Beta value and use element_wise::PassThrough * Use type_convert for type converting in host layer reduction * Renaming and refining in Reduction profiler/device layer/examples * Renaming and refining in Reduction profiler/device layer/examples * Renaming all NumReduceDims to NumReduceDim * Fix the leaked type_convert in ThreadwiseTensorSliceTransfer_v2 * Update to testing scripts to add bf16 support * added more static_assert * Remove buggy tunable configurations defined in device_reduce_instance_xxx.hpp * Add static_assert to give compile-time warning for incorrect thread slice-size/vector-size configurations * minor change * Refine and fix (in GetWorkspaceSizeInBytes of MultiBlockPartialReduce) to make int8 completely pass * Tiny renaming in gridwise_2d_reduction_multiblock_partial_reduce.hpp * Tiny fix in script/profile_reduce_no_index.sh * Refine in DeviceReduce layer with regard to using NumInvariantDim/NumReduceDim or InvariantDims/ReduceDims * Generic renaming in host reduction and DeviceReduce layer * Add support for 4-d all dimension reduction in the profiler and add_device_reduce_xxx instances * Use multi-thread and simplification for host Reduction implementation * Add ctest for reduction * Update to clarify the using of data init method in produce_reduce/example_reduce/test_reduce/ * Update to the reduce CTest executables to enable default testing behavior when no command argument * Renaming Co-authored-by: Jianfeng yan <jfyan008@gmail.com>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
#include "device_reduce.hpp"
|
||||
#include "device_reduce_instance.hpp"
|
||||
#include "reduction_enums.hpp"
|
||||
#include "host_generic_reduction.hpp"
|
||||
#include "host_reduction.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -20,34 +20,43 @@ struct ReduceDescription
|
||||
};
|
||||
|
||||
using reduce_description_instances = std::tuple<ReduceDescription<4, 3, 0, 0, 0>, // for ADD
|
||||
ReduceDescription<4, 4, 0, 0, 0>,
|
||||
ReduceDescription<4, 1, 0, 0, 0>,
|
||||
ReduceDescription<2, 1, 0, 0, 0>,
|
||||
|
||||
ReduceDescription<4, 3, 5, 0, 0>, // for AVG
|
||||
ReduceDescription<4, 4, 5, 0, 0>,
|
||||
ReduceDescription<4, 1, 5, 0, 0>,
|
||||
ReduceDescription<2, 1, 5, 0, 0>,
|
||||
|
||||
ReduceDescription<4, 3, 7, 0, 0>, // for NORM2
|
||||
ReduceDescription<4, 4, 7, 0, 0>,
|
||||
ReduceDescription<4, 1, 7, 0, 0>,
|
||||
ReduceDescription<2, 1, 7, 0, 0>,
|
||||
|
||||
ReduceDescription<4, 3, 2, 0, 0>, // for MIN
|
||||
ReduceDescription<4, 4, 2, 0, 0>,
|
||||
ReduceDescription<4, 1, 2, 0, 0>,
|
||||
ReduceDescription<2, 1, 2, 0, 0>,
|
||||
ReduceDescription<4, 3, 3, 0, 0>, // for MAX
|
||||
ReduceDescription<4, 4, 3, 0, 0>,
|
||||
ReduceDescription<4, 1, 3, 0, 0>,
|
||||
ReduceDescription<2, 1, 3, 0, 0>,
|
||||
ReduceDescription<4, 3, 4, 0, 0>, // for AMAX
|
||||
ReduceDescription<4, 4, 4, 0, 0>,
|
||||
ReduceDescription<4, 1, 4, 0, 0>,
|
||||
ReduceDescription<2, 1, 4, 0, 0>,
|
||||
|
||||
ReduceDescription<4, 3, 2, 0, 1>, // for MIN
|
||||
ReduceDescription<4, 4, 2, 0, 1>,
|
||||
ReduceDescription<4, 1, 2, 0, 1>,
|
||||
ReduceDescription<2, 1, 2, 0, 1>,
|
||||
ReduceDescription<4, 3, 3, 0, 1>, // for MAX
|
||||
ReduceDescription<4, 4, 3, 0, 1>,
|
||||
ReduceDescription<4, 1, 3, 0, 1>,
|
||||
ReduceDescription<2, 1, 3, 0, 1>,
|
||||
ReduceDescription<4, 3, 4, 0, 1>, // for AMAX
|
||||
ReduceDescription<4, 4, 4, 0, 1>,
|
||||
ReduceDescription<4, 1, 4, 0, 1>,
|
||||
ReduceDescription<2, 1, 4, 0, 1>>;
|
||||
|
||||
@@ -122,16 +131,16 @@ static void dumpBufferToFile(const char* fileName, T* data, size_t dataNumItems)
|
||||
};
|
||||
|
||||
// map the data type used by the GPU kernels to the corresponding type used by the host codes
|
||||
template <typename inDataType>
|
||||
template <typename InType>
|
||||
struct type_mapping
|
||||
{
|
||||
using outDataType = inDataType;
|
||||
using OutType = InType;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_mapping<ck::half_t>
|
||||
{
|
||||
using outDataType = half_float::half;
|
||||
using OutType = half_float::half;
|
||||
};
|
||||
|
||||
template <typename InDataType,
|
||||
@@ -187,7 +196,26 @@ void profile_reduce_impl_impl(bool do_verification,
|
||||
constexpr bool invalid_reduce_3 =
|
||||
(!op_support_indices && IndicesOpt != ReduceTensorIndices_t::NO_INDICES);
|
||||
|
||||
constexpr bool invalid_reduce = (invalid_reduce_1 || invalid_reduce_2 || invalid_reduce_3);
|
||||
// 1) If InDataType is int8_t, must use int8_t as AccDataType for indexable reduction operations
|
||||
// 2) If InDataType is int8_t, must use int32_t as AccDataType for non-indexable reduction
|
||||
// operations
|
||||
constexpr bool invalid_reduce_4 =
|
||||
std::is_same<InDataType, int8_t>::value &&
|
||||
((!op_support_indices && !std::is_same<AccDataType, int32_t>::value) ||
|
||||
(op_support_indices && !std::is_same<AccDataType, int8_t>::value));
|
||||
|
||||
// 1) If InDataType is int8_t, the supported operation must be either indexable operations or
|
||||
// ADD/AVG
|
||||
constexpr bool invalid_reduce_5 = std::is_same<InDataType, int8_t>::value &&
|
||||
(!op_support_indices && ReduceOpId != ReduceTensorOp_t::ADD &&
|
||||
ReduceOpId != ReduceTensorOp_t::AVG);
|
||||
|
||||
// 1) If InDataType is bhalf_t, must use float as AccDataType for all reduction operations
|
||||
constexpr bool invalid_reduce_6 =
|
||||
std::is_same<InDataType, bhalf_t>::value && !std::is_same<AccDataType, float>::value;
|
||||
|
||||
constexpr bool invalid_reduce = (invalid_reduce_1 || invalid_reduce_2 || invalid_reduce_3 ||
|
||||
invalid_reduce_4 || invalid_reduce_5 || invalid_reduce_6);
|
||||
|
||||
if constexpr(!invalid_reduce)
|
||||
{
|
||||
@@ -205,8 +233,8 @@ void profile_reduce_impl_impl(bool do_verification,
|
||||
|
||||
Tensor<OutDataType> out_ref(outLengths);
|
||||
Tensor<OutDataType> out(outLengths);
|
||||
Tensor<int> out_indices_ref(outLengths);
|
||||
Tensor<int> out_indices(outLengths);
|
||||
Tensor<int32_t> out_indices_ref(outLengths);
|
||||
Tensor<int32_t> out_indices(outLengths);
|
||||
|
||||
auto inStrides = in.mDesc.GetStrides();
|
||||
auto outStrides = out.mDesc.GetStrides();
|
||||
@@ -220,20 +248,22 @@ void profile_reduce_impl_impl(bool do_verification,
|
||||
{
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
in.GenerateTensorValue(GeneratorTensor_1<InDataType>{}, num_thread);
|
||||
if(beta != 0.0f)
|
||||
out_ref.GenerateTensorValue(GeneratorTensor_1<InDataType>{}, num_thread);
|
||||
break;
|
||||
case 0: break;
|
||||
case 1:
|
||||
in.GenerateTensorValue(GeneratorTensor_1<InDataType>{1}, num_thread);
|
||||
if(beta != 0.0f)
|
||||
out_ref.GenerateTensorValue(GeneratorTensor_1<InDataType>{1}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}, num_thread);
|
||||
if(beta != 0.0f)
|
||||
out_ref.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{1, 5}, num_thread);
|
||||
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0}, num_thread);
|
||||
if(beta != 0.0f)
|
||||
out_ref.GenerateTensorValue(GeneratorTensor_2<InDataType>{1, 5}, num_thread);
|
||||
out_ref.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0},
|
||||
num_thread);
|
||||
}
|
||||
|
||||
if(beta != 0.0f)
|
||||
@@ -306,6 +336,7 @@ void profile_reduce_impl_impl(bool do_verification,
|
||||
IndicesOpt>(reduce0_ptrs);
|
||||
|
||||
if constexpr(use_atomic_add)
|
||||
{
|
||||
add_device_reduce_instance_multiblock_atomic_add<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
@@ -314,7 +345,9 @@ void profile_reduce_impl_impl(bool do_verification,
|
||||
ReduceOpId,
|
||||
NanOpt,
|
||||
IndicesOpt>(reduce0_ptrs);
|
||||
}
|
||||
else
|
||||
{
|
||||
add_device_reduce_instance_multiblock_partial_reduce<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
@@ -323,9 +356,11 @@ void profile_reduce_impl_impl(bool do_verification,
|
||||
ReduceOpId,
|
||||
NanOpt,
|
||||
IndicesOpt>(reduce1_ptrs);
|
||||
};
|
||||
|
||||
// used for secondary reduction
|
||||
if constexpr(!use_atomic_add)
|
||||
{
|
||||
add_device_reduce_instance_blockwise_second_call<AccDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
@@ -334,6 +369,7 @@ void profile_reduce_impl_impl(bool do_verification,
|
||||
ReduceOpId,
|
||||
NanOpt,
|
||||
IndicesOpt>(reduce2_ptrs);
|
||||
};
|
||||
|
||||
if(reduce0_ptrs.empty() && reduce1_ptrs.empty())
|
||||
{
|
||||
@@ -342,17 +378,24 @@ void profile_reduce_impl_impl(bool do_verification,
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
using hInType = typename type_mapping<InDataType>::outDataType;
|
||||
using hOutType = typename type_mapping<OutDataType>::outDataType;
|
||||
using hCompType = typename type_mapping<AccDataType>::outDataType;
|
||||
using HostInDataType = typename type_mapping<InDataType>::OutType;
|
||||
using HostOutDataType = typename type_mapping<OutDataType>::OutType;
|
||||
using HostAccDataType = typename type_mapping<AccDataType>::OutType;
|
||||
|
||||
ReductionHost<hInType, hCompType, hOutType, ReduceOpId, PropagateNan, NeedIndices>
|
||||
ReductionHost<HostInDataType,
|
||||
HostAccDataType,
|
||||
HostOutDataType,
|
||||
ReduceOpId,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
PropagateNan,
|
||||
NeedIndices>
|
||||
hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims);
|
||||
|
||||
hostReduce.Run(alpha,
|
||||
reinterpret_cast<const hInType*>(in.mData.data()),
|
||||
reinterpret_cast<const HostInDataType*>(in.mData.data()),
|
||||
beta,
|
||||
reinterpret_cast<hOutType*>(out_ref.mData.data()),
|
||||
reinterpret_cast<HostOutDataType*>(out_ref.mData.data()),
|
||||
out_indices_ref.mData.data());
|
||||
};
|
||||
|
||||
@@ -363,24 +406,27 @@ void profile_reduce_impl_impl(bool do_verification,
|
||||
|
||||
for(auto& reduce_ptr : reduce0_ptrs)
|
||||
{
|
||||
auto wsSizeInBytes = reduce_ptr->GetWorkspaceSizeInBytes(i_inLengths);
|
||||
auto wsSizeInBytes = reduce_ptr->GetWorkspaceSizeInBytes(i_inLengths, reduceDims);
|
||||
|
||||
DeviceMem ws_dev(wsSizeInBytes);
|
||||
|
||||
auto argument_ptr = reduce_ptr->MakeArgumentPointer(
|
||||
i_inLengths,
|
||||
i_inStrides,
|
||||
i_outLengths,
|
||||
i_outStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
in_dev.GetDeviceBuffer(),
|
||||
out_dev.GetDeviceBuffer(),
|
||||
out_indices_dev.GetDeviceBuffer(),
|
||||
ws_dev.GetDeviceBuffer(),
|
||||
InElementwiseOperation_0{static_cast<int32_t>(reduce_total_length)},
|
||||
AccElementwiseOperation_0{static_cast<int32_t>(reduce_total_length)});
|
||||
InElementwiseOperation_0 in_elementwise_op_0(static_cast<int32_t>(reduce_total_length));
|
||||
AccElementwiseOperation_0 acc_elementwise_op_0(
|
||||
static_cast<int32_t>(reduce_total_length));
|
||||
|
||||
auto argument_ptr = reduce_ptr->MakeArgumentPointer(i_inLengths,
|
||||
i_inStrides,
|
||||
i_outLengths,
|
||||
i_outStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
in_dev.GetDeviceBuffer(),
|
||||
out_dev.GetDeviceBuffer(),
|
||||
out_indices_dev.GetDeviceBuffer(),
|
||||
ws_dev.GetDeviceBuffer(),
|
||||
in_elementwise_op_0,
|
||||
acc_elementwise_op_0);
|
||||
|
||||
if(!reduce_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
continue;
|
||||
@@ -445,24 +491,27 @@ void profile_reduce_impl_impl(bool do_verification,
|
||||
|
||||
for(auto& reduce_ptr : reduce1_ptrs)
|
||||
{
|
||||
auto wsSizeInBytes = reduce_ptr->GetWorkspaceSizeInBytes(i_inLengths);
|
||||
auto wsSizeInBytes = reduce_ptr->GetWorkspaceSizeInBytes(i_inLengths, reduceDims);
|
||||
|
||||
DeviceMem ws_dev(wsSizeInBytes);
|
||||
|
||||
auto argument_ptr = reduce_ptr->MakeArgumentPointer(
|
||||
i_inLengths,
|
||||
i_inStrides,
|
||||
i_outLengths,
|
||||
i_outStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
in_dev.GetDeviceBuffer(),
|
||||
out_dev.GetDeviceBuffer(),
|
||||
out_indices_dev.GetDeviceBuffer(),
|
||||
ws_dev.GetDeviceBuffer(),
|
||||
InElementwiseOperation_1{static_cast<int32_t>(reduce_total_length)},
|
||||
AccElementwiseOperation_1{static_cast<int32_t>(reduce_total_length)});
|
||||
InElementwiseOperation_1 in_elementwise_op_1(static_cast<int32_t>(reduce_total_length));
|
||||
AccElementwiseOperation_1 acc_elementwise_op_1(
|
||||
static_cast<int32_t>(reduce_total_length));
|
||||
|
||||
auto argument_ptr = reduce_ptr->MakeArgumentPointer(i_inLengths,
|
||||
i_inStrides,
|
||||
i_outLengths,
|
||||
i_outStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
in_dev.GetDeviceBuffer(),
|
||||
out_dev.GetDeviceBuffer(),
|
||||
out_indices_dev.GetDeviceBuffer(),
|
||||
ws_dev.GetDeviceBuffer(),
|
||||
in_elementwise_op_1,
|
||||
acc_elementwise_op_1);
|
||||
|
||||
if(!reduce_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
continue;
|
||||
@@ -482,20 +531,25 @@ void profile_reduce_impl_impl(bool do_verification,
|
||||
|
||||
for(auto& reduce2_ptr : reduce2_ptrs)
|
||||
{
|
||||
auto argument2_ptr = reduce2_ptr->MakeArgumentPointer(
|
||||
inLengths2,
|
||||
inStrides2,
|
||||
i_outLengths,
|
||||
i_outStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
ws_dev.GetDeviceBuffer(),
|
||||
out_dev.GetDeviceBuffer(),
|
||||
out_indices_dev.GetDeviceBuffer(),
|
||||
ws_dev.GetDeviceBuffer(),
|
||||
InElementwiseOperation_2{static_cast<int32_t>(reduce_total_length)},
|
||||
AccElementwiseOperation_2{static_cast<int32_t>(reduce_total_length)});
|
||||
InElementwiseOperation_2 in_elementwise_op_2(
|
||||
static_cast<int32_t>(reduce_total_length));
|
||||
AccElementwiseOperation_2 acc_elementwise_op_2(
|
||||
static_cast<int32_t>(reduce_total_length));
|
||||
|
||||
auto argument2_ptr =
|
||||
reduce2_ptr->MakeArgumentPointer(inLengths2,
|
||||
inStrides2,
|
||||
i_outLengths,
|
||||
i_outStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
ws_dev.GetDeviceBuffer(),
|
||||
out_dev.GetDeviceBuffer(),
|
||||
out_indices_dev.GetDeviceBuffer(),
|
||||
ws_dev.GetDeviceBuffer(),
|
||||
in_elementwise_op_2,
|
||||
acc_elementwise_op_2);
|
||||
|
||||
if(!reduce2_ptr->IsSupportedArgument(argument2_ptr.get()))
|
||||
continue;
|
||||
|
||||
Reference in New Issue
Block a user