mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
Unify the naming of the math functions used by the host and kernel (#262)
* Use the unified naming for math functions on host and HIP kernel
* Corresponding change/simplification in reduction host/profiler/examples due to unified math functions renaming
* Renaming GetReductionZeroVal() to GetIdentityValue()
* Tiny renaming in profile_reduce_impl.hpp
* More renaming in profile_reduce_impl.hpp
* Replace zeroVal by identiyVal
* Remove ck_ prefix in the naming of ck::math provided functions
[ROCm/composable_kernel commit: 86185bd7ce]
This commit is contained in:
@@ -147,8 +147,6 @@ class SimpleAppArgs
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck::host_reduce;
|
||||
|
||||
const std::vector<int> reduceDims{0, 1, 2};
|
||||
const std::vector<int> invariantDims{3};
|
||||
|
||||
@@ -254,7 +252,9 @@ int main(int argc, char* argv[])
|
||||
ReductionHost<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
ReduceOpId,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
PropagateNan,
|
||||
|
||||
@@ -108,8 +108,6 @@ int main(int argc, char* argv[])
|
||||
|
||||
const std::vector<size_t> outLengths = {64, 320, 80};
|
||||
|
||||
using namespace ck::host_reduce;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
do_verify = true;
|
||||
@@ -191,7 +189,9 @@ int main(int argc, char* argv[])
|
||||
ReductionHost<InOutDataType,
|
||||
AccDataType,
|
||||
InOutDataType,
|
||||
ReduceOpId,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
5, // Rank
|
||||
2, // NumReduceDim
|
||||
PropagateNan,
|
||||
|
||||
@@ -8,10 +8,12 @@
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "host_reduce_util.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "reduction_functions_accumulate.hpp"
|
||||
|
||||
#include "device_pool2d_fwd_nhwc_nhwc.hpp"
|
||||
|
||||
template <typename InDataType,
|
||||
@@ -29,19 +31,24 @@ static void pool_host_verify(const Tensor<InDataType>& in,
|
||||
const std::array<ck::index_t, 2>& in_left_pads,
|
||||
const std::array<ck::index_t, 2>& /*in_right_pads*/)
|
||||
{
|
||||
using namespace ck::host_reduce;
|
||||
|
||||
const int32_t divider = window_spatial_lengths[0] * window_spatial_lengths[1];
|
||||
|
||||
const auto PreUnaryOp = PreUnaryOpFn<AccDataType, ReduceOpId>(divider);
|
||||
const auto PosUnaryOp = PosUnaryOpFn<AccDataType, ReduceOpId>(divider);
|
||||
using ReduceOperation = typename ck::reduce_binary_operator<AccDataType, ReduceOpId>::opType;
|
||||
using InElementwiseOperation = typename ck::
|
||||
reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
|
||||
using AccElementwiseOperation = typename ck::
|
||||
reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation;
|
||||
|
||||
const InElementwiseOperation in_elementwise_op(divider);
|
||||
const AccElementwiseOperation acc_elementwise_op(divider);
|
||||
|
||||
if constexpr(!OutputIndex)
|
||||
{
|
||||
auto opReduce = ReduceOpFn<AccDataType, ReduceOpId>();
|
||||
using Accumulation =
|
||||
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
|
||||
|
||||
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
|
||||
auto accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
|
||||
auto accuVal = ReduceOperation::GetIdentityValue();
|
||||
|
||||
for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
|
||||
{
|
||||
@@ -54,14 +61,14 @@ static void pool_host_verify(const Tensor<InDataType>& in,
|
||||
{
|
||||
AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
|
||||
|
||||
PreUnaryOp(currVal);
|
||||
in_elementwise_op(currVal, currVal);
|
||||
|
||||
binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal);
|
||||
Accumulation::Calculate(accuVal, currVal);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
PosUnaryOp(accuVal);
|
||||
acc_elementwise_op(accuVal, accuVal);
|
||||
|
||||
out(n, c, ho, wo) = accuVal;
|
||||
};
|
||||
@@ -74,10 +81,12 @@ static void pool_host_verify(const Tensor<InDataType>& in,
|
||||
}
|
||||
else
|
||||
{
|
||||
auto opReduce = ReduceOpFn2<AccDataType, ReduceOpId>();
|
||||
|
||||
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
|
||||
auto accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
|
||||
using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
|
||||
ReduceOperation,
|
||||
AccDataType,
|
||||
IndexDataType>;
|
||||
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
|
||||
auto accuVal = ReduceOperation::GetIdentityValue();
|
||||
IndexDataType accuIndex = 0;
|
||||
|
||||
for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
|
||||
@@ -92,15 +101,14 @@ static void pool_host_verify(const Tensor<InDataType>& in,
|
||||
AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
|
||||
IndexDataType currIndex = y * window_spatial_lengths[1] + x;
|
||||
|
||||
PreUnaryOp(currVal);
|
||||
in_elementwise_op(currVal, currVal);
|
||||
|
||||
binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>(
|
||||
opReduce, accuVal, currVal, accuIndex, currIndex);
|
||||
Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
PosUnaryOp(accuVal);
|
||||
acc_elementwise_op(accuVal, accuVal);
|
||||
|
||||
out(n, c, ho, wo) = accuVal;
|
||||
out_indices(n, c, ho, wo) = accuIndex;
|
||||
@@ -139,8 +147,6 @@ bool pool_test(bool do_verification,
|
||||
ck::index_t in_right_pad_h,
|
||||
ck::index_t in_right_pad_w)
|
||||
{
|
||||
using namespace ck::host_reduce;
|
||||
|
||||
using DevicePoolFwdInstance =
|
||||
ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<
|
||||
InDataType, // InDataType
|
||||
|
||||
@@ -27,8 +27,6 @@ static constexpr bool PropagateNan = false;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck::host_reduce;
|
||||
|
||||
bool do_verification;
|
||||
int init_method;
|
||||
bool time_kernel;
|
||||
|
||||
@@ -27,8 +27,6 @@ static constexpr bool PropagateNan = false;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck::host_reduce;
|
||||
|
||||
bool do_verification;
|
||||
int init_method;
|
||||
bool time_kernel;
|
||||
|
||||
@@ -236,7 +236,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
ReduceAccDataType d_acc = d_reduce_op.GetReductionZeroVal();
|
||||
ReduceAccDataType d_acc = d_reduce_op.GetIdentityValue();
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
d_reduce_op(d_acc, c_m_n_host_result(m, n));
|
||||
|
||||
@@ -261,8 +261,8 @@ int main(int argc, char* argv[])
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
float d0_acc = d0_reduce_op.GetReductionZeroVal();
|
||||
float d1_acc = d1_reduce_op.GetReductionZeroVal();
|
||||
float d0_acc = d0_reduce_op.GetIdentityValue();
|
||||
float d1_acc = d1_reduce_op.GetIdentityValue();
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
|
||||
@@ -259,8 +259,8 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
float d0_acc = d0_reduce_op.GetReductionZeroVal();
|
||||
float d1_acc = d1_reduce_op.GetReductionZeroVal();
|
||||
float d0_acc = d0_reduce_op.GetIdentityValue();
|
||||
float d1_acc = d1_reduce_op.GetIdentityValue();
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
|
||||
@@ -157,8 +157,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
|
||||
auto reduceSumOpInst = ReduceSumOp{};
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
float mean_acc = reduceSumOpInst.GetReductionZeroVal();
|
||||
float square_mean_acc = reduceSumOpInst.GetReductionZeroVal();
|
||||
float mean_acc = reduceSumOpInst.GetIdentityValue();
|
||||
float square_mean_acc = reduceSumOpInst.GetIdentityValue();
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user