mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
Regulate reduction accumulator operations and Element-wise operations (#274)
* Remove template from Reduction operation classes and add template to their operator() and GetIdentityValue() interfaces * Change to unary elementwise operators and the reduce_unary_operator (class for mapping) and dependent variations in all host layers * Remove the data type template parameter from reduce_binary_operator (class for mapping) and dependent variations in host layers * Add InMemoryDataOperatonSupportedOnDataType to check the matching between data type and InMemoryDataOperation * Use struct-scope operator template instantiation for binary and unary element-wise operations * Change a few more elementwise operations to use template for operator() * Tiny correction in Normalize operator * Add static_assert to check the data type applicability for some reduction accumulator and element-wise operations * Correction in some examples with regard to using ReduceAccDataType * Use static_assert for UnaryDivide * Update to merged codes to use Element-wise operations and Reduction Accumulator operations correctly * Tiny fix with regard to SetWorkSpacePointer()
This commit is contained in:
@@ -171,15 +171,15 @@ struct GridwiseReduction_mk_to_m_multiblock
|
||||
AccDataType beta,
|
||||
OutDataType* const __restrict__ p_out_value_global)
|
||||
{
|
||||
const auto identityVal = ReduceOperation::GetIdentityValue();
|
||||
const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
|
||||
|
||||
// LDS
|
||||
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
|
||||
|
||||
const auto in_global_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
type_convert<InDataType>(identityVal));
|
||||
const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_value_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
ReduceOperation::template GetIdentityValue<InDataType>());
|
||||
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
@@ -358,12 +358,12 @@ struct GridwiseReduction_mk_to_m_multiblock
|
||||
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
|
||||
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
|
||||
|
||||
const auto identityVal = ReduceOperation::GetIdentityValue();
|
||||
const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
|
||||
|
||||
const auto in_global_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
type_convert<InDataType>(identityVal));
|
||||
const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_value_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
ReduceOperation::template GetIdentityValue<InDataType>());
|
||||
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
|
||||
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
|
||||
@@ -135,12 +135,12 @@ struct GridwiseReduction_mk_to_m_threadwise
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
|
||||
const auto identityVal = ReduceOperation::GetIdentityValue();
|
||||
const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
|
||||
|
||||
const auto in_global_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
type_convert<InDataType>(identityVal));
|
||||
const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_value_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
ReduceOperation::template GetIdentityValue<InDataType>());
|
||||
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
@@ -276,12 +276,12 @@ struct GridwiseReduction_mk_to_m_threadwise
|
||||
|
||||
(void)acc_elementwise_op;
|
||||
|
||||
const auto identityVal = ReduceOperation::GetIdentityValue();
|
||||
const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
|
||||
|
||||
const auto in_global_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
type_convert<InDataType>(identityVal));
|
||||
const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_value_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
ReduceOperation::template GetIdentityValue<InDataType>());
|
||||
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
|
||||
|
||||
|
||||
@@ -927,7 +927,8 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
false>;
|
||||
|
||||
// Global write Gemm shuffle + reduction
|
||||
const auto d_zeroVal = DReduceOperation::GetIdentityValue();
|
||||
const auto d_zeroVal =
|
||||
DReduceOperation::template GetIdentityValue<FloatReduceAcc>();
|
||||
|
||||
static_for<0, mreduce_per_thread, 1>{}(
|
||||
[&](auto I) { d_thread_buf(I) = d_zeroVal; });
|
||||
|
||||
@@ -816,7 +816,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
false>;
|
||||
|
||||
// Global write Gemm shuffle + reduction
|
||||
const auto d_identityVal = DReduceOperation::GetIdentityValue();
|
||||
const auto d_identityVal =
|
||||
DReduceOperation::template GetIdentityValue<FloatReduceAcc>();
|
||||
|
||||
static_for<0, mreduce_per_thread, 1>{}(
|
||||
[&](auto I) { d_thread_buf(I) = d_identityVal; });
|
||||
|
||||
@@ -37,7 +37,7 @@ __global__ void kernel_buffer_set_value(const Grid1dBufferDescType grid_1d_buffe
|
||||
|
||||
{
|
||||
|
||||
using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<DataType, DataType>;
|
||||
using PassThroughOp = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user