mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 23:05:54 +00:00
Unify the naming of the math functions used by the host and kernel (#262)
* Use the unified naming for math functions on host and HIP kernel
* Corresponding change/simplification in reduction host/profiler/examples due to unified math functions renaming
* Renaming GetReductionZeroVal() to GetIdentityValue()
* Tiny renaming in profile_reduce_impl.hpp
* More renaming in profile_reduce_impl.hpp
* Replace zeroVal by identityVal
* Remove ck_ prefix in the naming of ck::math provided functions
[ROCm/composable_kernel commit: 86185bd7ce]
This commit is contained in:
@@ -348,8 +348,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
|
||||
|
||||
if constexpr(use_multiblock)
|
||||
{
|
||||
const auto zeroVal =
|
||||
ck::reduce::GetReductionZeroValueForInMemoryDataOperation<OutDataType>(
|
||||
const auto identityVal =
|
||||
ck::reduce::GetIdentityValueueForInMemoryDataOperation<OutDataType>(
|
||||
OutMemoryDataOperation);
|
||||
|
||||
const auto kernel_pre =
|
||||
@@ -362,7 +362,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
|
||||
0,
|
||||
out_grid_desc_m_2,
|
||||
arg.out_dev_,
|
||||
zeroVal);
|
||||
identityVal);
|
||||
};
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
#include "data_type.hpp"
|
||||
#include "math_v2.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -296,7 +297,7 @@ struct UnaryAbs<float, float>
|
||||
{
|
||||
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(float& y, const float& x) const { y = abs(x); };
|
||||
__host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::abs(x); };
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -304,7 +305,7 @@ struct UnaryAbs<half_t, half_t>
|
||||
{
|
||||
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = __habs(x); };
|
||||
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = ck::math::abs(x); };
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -312,7 +313,7 @@ struct UnaryAbs<double, double>
|
||||
{
|
||||
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(double& y, const double& x) const { y = abs(x); };
|
||||
__host__ __device__ void operator()(double& y, const double& x) const { y = ck::math::abs(x); };
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -320,12 +321,7 @@ struct UnaryAbs<int8_t, int8_t>
|
||||
{
|
||||
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const
|
||||
{
|
||||
int8_t sgn = x >> (8 - 1);
|
||||
|
||||
y = (x ^ sgn) - sgn;
|
||||
};
|
||||
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = ck::math::abs(x); };
|
||||
};
|
||||
|
||||
template <typename Y, typename X>
|
||||
@@ -336,7 +332,7 @@ struct UnarySqrt<float, float>
|
||||
{
|
||||
__host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(float& y, const float& x) const { y = sqrtf(x); };
|
||||
__host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::sqrt(x); };
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -344,7 +340,10 @@ struct UnarySqrt<double, double>
|
||||
{
|
||||
__host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(double& y, const double& x) const { y = sqrt(x); };
|
||||
__host__ __device__ void operator()(double& y, const double& x) const
|
||||
{
|
||||
y = ck::math::sqrt(x);
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace element_wise
|
||||
|
||||
@@ -171,7 +171,7 @@ struct GridwiseReduction_mk_to_m_multiblock
|
||||
AccDataType beta,
|
||||
OutDataType* const __restrict__ p_out_value_global)
|
||||
{
|
||||
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
|
||||
const auto identityVal = ReduceOperation::GetIdentityValue();
|
||||
|
||||
// LDS
|
||||
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
|
||||
@@ -179,7 +179,7 @@ struct GridwiseReduction_mk_to_m_multiblock
|
||||
const auto in_global_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
type_convert<InDataType>(zeroVal));
|
||||
type_convert<InDataType>(identityVal));
|
||||
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
@@ -191,7 +191,7 @@ struct GridwiseReduction_mk_to_m_multiblock
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_id = get_block_1d_id();
|
||||
@@ -358,12 +358,12 @@ struct GridwiseReduction_mk_to_m_multiblock
|
||||
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
|
||||
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
|
||||
|
||||
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
|
||||
const auto identityVal = ReduceOperation::GetIdentityValue();
|
||||
|
||||
const auto in_global_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
type_convert<InDataType>(zeroVal));
|
||||
type_convert<InDataType>(identityVal));
|
||||
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
|
||||
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
@@ -418,7 +418,7 @@ struct GridwiseReduction_mk_to_m_multiblock
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
accu_value_buf(I) = zeroVal;
|
||||
accu_value_buf(I) = identityVal;
|
||||
accu_index_buf(I) = 0;
|
||||
});
|
||||
|
||||
@@ -459,7 +459,7 @@ struct GridwiseReduction_mk_to_m_multiblock
|
||||
in_thread_idx_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
AccDataType tmpValue = zeroVal;
|
||||
AccDataType tmpValue = identityVal;
|
||||
IndexDataType tmpIndex = 0;
|
||||
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
@@ -512,7 +512,7 @@ struct GridwiseReduction_mk_to_m_multiblock
|
||||
in_thread_val_buf(Number<offset>{}));
|
||||
});
|
||||
|
||||
AccDataType tmpValue = zeroVal;
|
||||
AccDataType tmpValue = identityVal;
|
||||
IndexDataType tmpIndex = 0;
|
||||
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
|
||||
@@ -135,12 +135,12 @@ struct GridwiseReduction_mk_to_m_threadwise
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
|
||||
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
|
||||
const auto identityVal = ReduceOperation::GetIdentityValue();
|
||||
|
||||
const auto in_global_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
type_convert<InDataType>(zeroVal));
|
||||
type_convert<InDataType>(identityVal));
|
||||
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
@@ -149,7 +149,7 @@ struct GridwiseReduction_mk_to_m_threadwise
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });
|
||||
|
||||
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
|
||||
|
||||
@@ -276,12 +276,12 @@ struct GridwiseReduction_mk_to_m_threadwise
|
||||
|
||||
(void)acc_elementwise_op;
|
||||
|
||||
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
|
||||
const auto identityVal = ReduceOperation::GetIdentityValue();
|
||||
|
||||
const auto in_global_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
type_convert<InDataType>(zeroVal));
|
||||
type_convert<InDataType>(identityVal));
|
||||
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
|
||||
|
||||
@@ -303,7 +303,7 @@ struct GridwiseReduction_mk_to_m_threadwise
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
accu_value_buf(I) = zeroVal;
|
||||
accu_value_buf(I) = identityVal;
|
||||
accu_index_buf(I) = 0;
|
||||
});
|
||||
|
||||
|
||||
@@ -816,10 +816,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
false>;
|
||||
|
||||
// Global write Gemm shuffle + reduction
|
||||
const auto d_zeroVal = DReduceOperation::GetReductionZeroVal();
|
||||
const auto d_identityVal = DReduceOperation::GetIdentityValue();
|
||||
|
||||
static_for<0, mreduce_per_thread, 1>{}(
|
||||
[&](auto I) { d_thread_buf(I) = d_zeroVal; });
|
||||
[&](auto I) { d_thread_buf(I) = d_identityVal; });
|
||||
|
||||
// reduce in VGPR
|
||||
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
|
||||
|
||||
@@ -3,11 +3,13 @@
|
||||
|
||||
#include <cmath>
|
||||
#include "data_type.hpp"
|
||||
#include "half.hpp"
|
||||
#include "type.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace math {
|
||||
|
||||
// math functions for the host, some are implemented by calling C++ std functions
|
||||
|
||||
// Host-only overload: absolute value of a float, delegated to std::abs.
static inline __host__ float abs(float x) { return std::abs(x); };
|
||||
|
||||
// Host-only overload: absolute value of a double, delegated to std::abs.
static inline __host__ double abs(double x) { return std::abs(x); };
|
||||
@@ -28,26 +30,26 @@ static inline __host__ int32_t abs(int32_t x)
|
||||
|
||||
static inline __host__ half_t abs(half_t x)
|
||||
{
|
||||
half_float::half xx = *reinterpret_cast<half_float::half*>(&x);
|
||||
uint16_t xx = ck::bit_cast<uint16_t>(x);
|
||||
|
||||
half_float::half abs_xx = half_float::abs(xx);
|
||||
uint16_t abs_xx = xx & 0x7fff;
|
||||
|
||||
half_t abs_x = *reinterpret_cast<half_t*>(&abs_xx);
|
||||
half_t abs_x = ck::bit_cast<half_t>(abs_xx);
|
||||
|
||||
return abs_x;
|
||||
};
|
||||
|
||||
static inline __host__ float isnan(float x) { return std::isnan(x); };
|
||||
static inline __host__ bool isnan(float x) { return std::isnan(x); };
|
||||
|
||||
static inline __host__ double isnan(double x) { return std::isnan(x); };
|
||||
static inline __host__ bool isnan(double x) { return std::isnan(x); };
|
||||
|
||||
static inline __host__ int8_t isnan(int8_t x)
|
||||
static inline __host__ bool isnan(int8_t x)
|
||||
{
|
||||
(void)x;
|
||||
return false;
|
||||
};
|
||||
|
||||
static inline __host__ int32_t isnan(int32_t x)
|
||||
static inline __host__ bool isnan(int32_t x)
|
||||
{
|
||||
(void)x;
|
||||
return false;
|
||||
@@ -55,11 +57,59 @@ static inline __host__ int32_t isnan(int32_t x)
|
||||
|
||||
static inline __host__ bool isnan(half_t x)
|
||||
{
|
||||
half_float::half xx = *reinterpret_cast<half_float::half*>(&x);
|
||||
uint16_t xx = ck::bit_cast<uint16_t>(x);
|
||||
|
||||
return half_float::isnan(xx);
|
||||
return (xx & 0x7FFF) > 0x7C00;
|
||||
};
|
||||
|
||||
// Host-only overload: square root of a float, delegated to std::sqrt.
static inline __host__ float sqrt(float x) { return std::sqrt(x); };
|
||||
|
||||
// Host-only overload: square root of a double, delegated to std::sqrt.
static inline __host__ double sqrt(double x) { return std::sqrt(x); };
|
||||
|
||||
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
|
||||
|
||||
// Device-only overload: absolute value of a float, forwarded to the
// global-namespace abs() (the leading :: avoids recursing into this overload).
static inline __device__ float abs(float x) { return ::abs(x); };
|
||||
|
||||
// Device-only overload: absolute value of a double, forwarded to the
// global-namespace abs() (the leading :: avoids recursing into this overload).
static inline __device__ double abs(double x) { return ::abs(x); };
|
||||
|
||||
// Device-only overload: branch-free absolute value of an 8-bit signed integer.
// sgn is x shifted right by 7 bits; assuming an arithmetic shift it is 0 for
// non-negative x and -1 (all ones) for negative x, so (x ^ sgn) - sgn leaves
// non-negative values unchanged and negates negative ones.
// NOTE(review): for the most negative int8_t value the mathematical result does
// not fit in int8_t -- confirm callers tolerate the narrowed return value.
static inline __device__ int8_t abs(int8_t x)
|
||||
{
|
||||
int8_t sgn = x >> (8 - 1);
|
||||
|
||||
return (x ^ sgn) - sgn;
|
||||
};
|
||||
|
||||
// Device-only overload: branch-free absolute value of a 32-bit signed integer.
// sgn is x shifted right by 31 bits; assuming an arithmetic shift it is 0 for
// non-negative x and -1 (all ones) for negative x, so (x ^ sgn) - sgn leaves
// non-negative values unchanged and negates negative ones.
// NOTE(review): the most negative int32_t value has no positive counterpart --
// confirm callers never pass it.
static inline __device__ int32_t abs(int32_t x)
|
||||
{
|
||||
int32_t sgn = x >> (32 - 1);
|
||||
|
||||
return (x ^ sgn) - sgn;
|
||||
};
|
||||
|
||||
// Device-only overload: absolute value of a half_t, forwarded to the
// global-namespace __habs() intrinsic.
static inline __device__ half_t abs(half_t x) { return ::__habs(x); };
|
||||
|
||||
// Device-only overload: NaN test for a float, forwarded to the
// global-namespace isnan(); the result is returned as bool.
static inline __device__ bool isnan(float x) { return ::isnan(x); };
|
||||
|
||||
// Device-only overload: NaN test for a double, forwarded to the
// global-namespace isnan(); the result is returned as bool.
static inline __device__ bool isnan(double x) { return ::isnan(x); };
|
||||
|
||||
// Device-only overload: NaN test for an 8-bit integer. Always returns false;
// the argument is not inspected. The overload presumably exists so generic code
// can call isnan() uniformly across element types -- verify against callers.
static inline __device__ bool isnan(int8_t x)
|
||||
{
|
||||
// Explicitly discard the unused parameter.
(void)x;
|
||||
return false;
|
||||
};
|
||||
|
||||
// Device-only overload: NaN test for a 32-bit integer. Always returns false;
// the argument is not inspected. The overload presumably exists so generic code
// can call isnan() uniformly across element types -- verify against callers.
static inline __device__ bool isnan(int32_t x)
|
||||
{
|
||||
// Explicitly discard the unused parameter.
(void)x;
|
||||
return false;
|
||||
};
|
||||
|
||||
// Device-only overload: NaN test for a half_t, forwarded to the
// global-namespace __hisnan() intrinsic.
static inline __device__ bool isnan(half_t x) { return ::__hisnan(x); };
|
||||
|
||||
// Device-only overload: square root of a float, forwarded to the
// single-precision global-namespace sqrtf().
static inline __device__ float sqrt(float x) { return ::sqrtf(x); };
|
||||
|
||||
// Device-only overload: square root of a double, forwarded to the
// global-namespace sqrt() (the leading :: avoids recursing into this overload).
static inline __device__ double sqrt(double x) { return ::sqrt(x); };
|
||||
|
||||
} // namespace math
|
||||
} // namespace ck
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
#define CK_REDUCTION_FUNCTIONS_BINOP_HPP
|
||||
|
||||
#include "data_type.hpp"
|
||||
#include "math_v2.hpp"
|
||||
|
||||
#include "reduction_common.hpp"
|
||||
#include "reduction_operator.hpp"
|
||||
@@ -34,18 +35,6 @@
|
||||
namespace ck {
|
||||
namespace detail {
|
||||
|
||||
template <typename T>
|
||||
static inline __device__ bool is_nan(T x)
|
||||
{
|
||||
return (isnan(x));
|
||||
};
|
||||
|
||||
template <>
|
||||
inline __device__ bool is_nan<half_t>(half_t x)
|
||||
{
|
||||
return (__hisnan(x));
|
||||
};
|
||||
|
||||
template <bool PropagateNan, typename ReduceOperation, typename AccDataType>
|
||||
struct AccumulateWithNanCheck;
|
||||
|
||||
@@ -53,7 +42,7 @@ template <typename ReduceOperation, typename AccDataType>
|
||||
struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
|
||||
{
|
||||
// cppcheck-suppress constParameter
|
||||
__device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
|
||||
__host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
|
||||
{
|
||||
ReduceOperation{}(accuVal, currVal);
|
||||
};
|
||||
@@ -62,9 +51,11 @@ struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
|
||||
template <typename ReduceOperation, typename AccDataType>
|
||||
struct AccumulateWithNanCheck<true, ReduceOperation, AccDataType>
|
||||
{
|
||||
__device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
|
||||
__host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
|
||||
{
|
||||
if(is_nan(currVal))
|
||||
using ck::math::isnan;
|
||||
|
||||
if(isnan(currVal))
|
||||
{
|
||||
accuVal = currVal;
|
||||
}
|
||||
@@ -81,7 +72,7 @@ struct AccumulateWithIndexAndNanCheck;
|
||||
template <typename ReduceOperation, typename AccDataType, typename IndexDataType>
|
||||
struct AccumulateWithIndexAndNanCheck<false, ReduceOperation, AccDataType, IndexDataType>
|
||||
{
|
||||
__device__ static inline void
|
||||
__host__ __device__ static inline void
|
||||
// cppcheck-suppress constParameter
|
||||
Calculate(AccDataType& accuVal,
|
||||
AccDataType currVal,
|
||||
@@ -101,12 +92,14 @@ template <typename ReduceOperation, typename AccDataType, typename IndexDataType
|
||||
struct AccumulateWithIndexAndNanCheck<true, ReduceOperation, AccDataType, IndexDataType>
|
||||
{
|
||||
// The method is called when the ReduceOperation is indexable and the user asked for indices
|
||||
__device__ static inline void Calculate(AccDataType& accuVal,
|
||||
AccDataType currVal,
|
||||
IndexDataType& accuIndex,
|
||||
IndexDataType currIndex)
|
||||
__host__ __device__ static inline void Calculate(AccDataType& accuVal,
|
||||
AccDataType currVal,
|
||||
IndexDataType& accuIndex,
|
||||
IndexDataType currIndex)
|
||||
{
|
||||
if(is_nan(currVal))
|
||||
using ck::math::isnan;
|
||||
|
||||
if(isnan(currVal))
|
||||
{
|
||||
accuVal = currVal;
|
||||
accuIndex = currIndex;
|
||||
|
||||
@@ -36,7 +36,7 @@ namespace reduce {
|
||||
// Every binary operator used in reduction is represented by a templated functor class. Each functor
|
||||
// class must provide at least
|
||||
// three members:
|
||||
// 1) GetReductionZeroVal() -- the interface to return the "identity element" for the binary
|
||||
// 1) GetIdentityValue() -- the interface to return the "identity element" for the binary
|
||||
// operator, "identity element" is the unique
|
||||
// element in the algebraic space that doesn't affect the value of other elements
|
||||
// when operated against them, and the concept is similar to zero vector in
|
||||
@@ -59,7 +59,7 @@ struct Add
|
||||
{
|
||||
using dataType = T;
|
||||
|
||||
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
|
||||
__host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };
|
||||
|
||||
__device__ static constexpr bool
|
||||
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
|
||||
@@ -76,7 +76,7 @@ struct Mul
|
||||
{
|
||||
using dataType = T;
|
||||
|
||||
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(1.0f); };
|
||||
__host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(1.0f); };
|
||||
|
||||
__device__ static constexpr bool
|
||||
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
|
||||
@@ -92,7 +92,7 @@ struct Max
|
||||
{
|
||||
using dataType = T;
|
||||
|
||||
__host__ __device__ static constexpr T GetReductionZeroVal()
|
||||
__host__ __device__ static constexpr T GetIdentityValue()
|
||||
{
|
||||
return NumericLimits<T>::Lowest();
|
||||
};
|
||||
@@ -125,10 +125,7 @@ struct Min
|
||||
{
|
||||
using dataType = T;
|
||||
|
||||
__host__ __device__ static constexpr T GetReductionZeroVal()
|
||||
{
|
||||
return NumericLimits<T>::Max();
|
||||
};
|
||||
__host__ __device__ static constexpr T GetIdentityValue() { return NumericLimits<T>::Max(); };
|
||||
|
||||
__device__ static constexpr bool
|
||||
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
|
||||
@@ -158,7 +155,7 @@ struct AMax
|
||||
{
|
||||
using dataType = T;
|
||||
|
||||
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
|
||||
__host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };
|
||||
|
||||
__device__ static constexpr bool
|
||||
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
|
||||
@@ -184,7 +181,7 @@ struct AMax
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
T GetReductionZeroValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
|
||||
T GetIdentityValueueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
|
||||
{
|
||||
T result = ck::type_convert<T>(0.0f);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user