Regulate reduction accumulator operations and Element-wise operations (#274)

* Remove template from Reducton operation classes and add template to their operator() and GetIdentityValue() interfaces

* Change to unary elementwise operators and the reduce_unary_operator (class for mapping) and dependent variations in all host layers

* Remove the data type template parameter from reduce_binary_operator (class for mapping) and dependent variations in host layers

* Add InMemoryDataOperatonSupportedOnDataType to check the matching between data type and InMemoryDataOperation

* Use struct-scope operator template instantiation for binary and unary element-wise operations

* Change a few more elementwise operations to use template for operator()

* Tiny correction in Normalize operator

* Add static_assert to check the data type appliability for some reduction accumulator and element-wise operatons

* Correction in some examples with regard to using ReduceAccDataType

* Use static_assert for UnaryDivide

* Update to merged codes to use Element-wise operations and Reduction Accumulator operations correctly

* Tiny fix with regard to SetWorkSpacePointer()
This commit is contained in:
Qianfeng
2022-06-18 04:10:25 +08:00
committed by GitHub
parent 63cdd92398
commit 1f543bfa79
48 changed files with 891 additions and 837 deletions

View File

@@ -171,15 +171,15 @@ struct GridwiseReduction_mk_to_m_multiblock
AccDataType beta,
OutDataType* const __restrict__ p_out_value_global)
{
const auto identityVal = ReduceOperation::GetIdentityValue();
const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
// LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(identityVal));
const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
ReduceOperation::template GetIdentityValue<InDataType>());
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
@@ -358,12 +358,12 @@ struct GridwiseReduction_mk_to_m_multiblock
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto identityVal = ReduceOperation::GetIdentityValue();
const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(identityVal));
const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
ReduceOperation::template GetIdentityValue<InDataType>());
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(

View File

@@ -135,12 +135,12 @@ struct GridwiseReduction_mk_to_m_threadwise
ReduceOperation,
PropagateNan>;
const auto identityVal = ReduceOperation::GetIdentityValue();
const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(identityVal));
const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
ReduceOperation::template GetIdentityValue<InDataType>());
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
@@ -276,12 +276,12 @@ struct GridwiseReduction_mk_to_m_threadwise
(void)acc_elementwise_op;
const auto identityVal = ReduceOperation::GetIdentityValue();
const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(identityVal));
const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
ReduceOperation::template GetIdentityValue<InDataType>());
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());

View File

@@ -927,7 +927,8 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
false>;
// Global write Gemm shuffle + reduction
const auto d_zeroVal = DReduceOperation::GetIdentityValue();
const auto d_zeroVal =
DReduceOperation::template GetIdentityValue<FloatReduceAcc>();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d_thread_buf(I) = d_zeroVal; });

View File

@@ -816,7 +816,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
false>;
// Global write Gemm shuffle + reduction
const auto d_identityVal = DReduceOperation::GetIdentityValue();
const auto d_identityVal =
DReduceOperation::template GetIdentityValue<FloatReduceAcc>();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d_thread_buf(I) = d_identityVal; });

View File

@@ -37,7 +37,7 @@ __global__ void kernel_buffer_set_value(const Grid1dBufferDescType grid_1d_buffe
{
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<DataType, DataType>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
constexpr auto I0 = Number<0>{};