mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
Unify the naming of the math functions used by the host and kernel (#262)
* Use the unified naming for math functions on host and HIP kernel
* Corresponding change/simplification in reduction host/profiler/examples due to unified math functions renaming
* Renaming GetReductionZeroVal() to GetIdentityValue()
* Tiny renaming in profile_reduce_impl.hpp
* More renaming in profile_reduce_impl.hpp
* Replace zeroVal by identityVal
* Remove ck_ prefix in the naming of ck::math provided functions
This commit is contained in:
@@ -348,8 +348,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
|
||||
|
||||
if constexpr(use_multiblock)
|
||||
{
|
||||
const auto zeroVal =
|
||||
ck::reduce::GetReductionZeroValueForInMemoryDataOperation<OutDataType>(
|
||||
const auto identityVal =
|
||||
ck::reduce::GetIdentityValueueForInMemoryDataOperation<OutDataType>(
|
||||
OutMemoryDataOperation);
|
||||
|
||||
const auto kernel_pre =
|
||||
@@ -362,7 +362,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
|
||||
0,
|
||||
out_grid_desc_m_2,
|
||||
arg.out_dev_,
|
||||
zeroVal);
|
||||
identityVal);
|
||||
};
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
#include "data_type.hpp"
|
||||
#include "math_v2.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -296,7 +297,7 @@ struct UnaryAbs<float, float>
|
||||
{
|
||||
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(float& y, const float& x) const { y = abs(x); };
|
||||
__host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::abs(x); };
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -304,7 +305,7 @@ struct UnaryAbs<half_t, half_t>
|
||||
{
|
||||
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = __habs(x); };
|
||||
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = ck::math::abs(x); };
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -312,7 +313,7 @@ struct UnaryAbs<double, double>
|
||||
{
|
||||
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(double& y, const double& x) const { y = abs(x); };
|
||||
__host__ __device__ void operator()(double& y, const double& x) const { y = ck::math::abs(x); };
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -320,12 +321,7 @@ struct UnaryAbs<int8_t, int8_t>
|
||||
{
|
||||
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const
|
||||
{
|
||||
int8_t sgn = x >> (8 - 1);
|
||||
|
||||
y = (x ^ sgn) - sgn;
|
||||
};
|
||||
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = ck::math::abs(x); };
|
||||
};
|
||||
|
||||
template <typename Y, typename X>
|
||||
@@ -336,7 +332,7 @@ struct UnarySqrt<float, float>
|
||||
{
|
||||
__host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(float& y, const float& x) const { y = sqrtf(x); };
|
||||
__host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::sqrt(x); };
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -344,7 +340,10 @@ struct UnarySqrt<double, double>
|
||||
{
|
||||
__host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };
|
||||
|
||||
__host__ __device__ void operator()(double& y, const double& x) const { y = sqrt(x); };
|
||||
__host__ __device__ void operator()(double& y, const double& x) const
|
||||
{
|
||||
y = ck::math::sqrt(x);
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace element_wise
|
||||
|
||||
@@ -171,7 +171,7 @@ struct GridwiseReduction_mk_to_m_multiblock
|
||||
AccDataType beta,
|
||||
OutDataType* const __restrict__ p_out_value_global)
|
||||
{
|
||||
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
|
||||
const auto identityVal = ReduceOperation::GetIdentityValue();
|
||||
|
||||
// LDS
|
||||
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
|
||||
@@ -179,7 +179,7 @@ struct GridwiseReduction_mk_to_m_multiblock
|
||||
const auto in_global_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
type_convert<InDataType>(zeroVal));
|
||||
type_convert<InDataType>(identityVal));
|
||||
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
@@ -191,7 +191,7 @@ struct GridwiseReduction_mk_to_m_multiblock
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_id = get_block_1d_id();
|
||||
@@ -358,12 +358,12 @@ struct GridwiseReduction_mk_to_m_multiblock
|
||||
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
|
||||
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
|
||||
|
||||
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
|
||||
const auto identityVal = ReduceOperation::GetIdentityValue();
|
||||
|
||||
const auto in_global_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
type_convert<InDataType>(zeroVal));
|
||||
type_convert<InDataType>(identityVal));
|
||||
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
|
||||
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
@@ -418,7 +418,7 @@ struct GridwiseReduction_mk_to_m_multiblock
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
accu_value_buf(I) = zeroVal;
|
||||
accu_value_buf(I) = identityVal;
|
||||
accu_index_buf(I) = 0;
|
||||
});
|
||||
|
||||
@@ -459,7 +459,7 @@ struct GridwiseReduction_mk_to_m_multiblock
|
||||
in_thread_idx_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
AccDataType tmpValue = zeroVal;
|
||||
AccDataType tmpValue = identityVal;
|
||||
IndexDataType tmpIndex = 0;
|
||||
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
@@ -512,7 +512,7 @@ struct GridwiseReduction_mk_to_m_multiblock
|
||||
in_thread_val_buf(Number<offset>{}));
|
||||
});
|
||||
|
||||
AccDataType tmpValue = zeroVal;
|
||||
AccDataType tmpValue = identityVal;
|
||||
IndexDataType tmpIndex = 0;
|
||||
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
|
||||
@@ -135,12 +135,12 @@ struct GridwiseReduction_mk_to_m_threadwise
|
||||
ReduceOperation,
|
||||
PropagateNan>;
|
||||
|
||||
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
|
||||
const auto identityVal = ReduceOperation::GetIdentityValue();
|
||||
|
||||
const auto in_global_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
type_convert<InDataType>(zeroVal));
|
||||
type_convert<InDataType>(identityVal));
|
||||
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
|
||||
|
||||
@@ -149,7 +149,7 @@ struct GridwiseReduction_mk_to_m_threadwise
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });
|
||||
|
||||
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
|
||||
|
||||
@@ -276,12 +276,12 @@ struct GridwiseReduction_mk_to_m_threadwise
|
||||
|
||||
(void)acc_elementwise_op;
|
||||
|
||||
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
|
||||
const auto identityVal = ReduceOperation::GetIdentityValue();
|
||||
|
||||
const auto in_global_val_buf =
|
||||
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
|
||||
in_grid_desc_m_k.GetElementSpaceSize(),
|
||||
type_convert<InDataType>(zeroVal));
|
||||
type_convert<InDataType>(identityVal));
|
||||
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
|
||||
|
||||
@@ -303,7 +303,7 @@ struct GridwiseReduction_mk_to_m_threadwise
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
accu_value_buf(I) = zeroVal;
|
||||
accu_value_buf(I) = identityVal;
|
||||
accu_index_buf(I) = 0;
|
||||
});
|
||||
|
||||
|
||||
@@ -816,10 +816,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
false>;
|
||||
|
||||
// Global write Gemm shuffle + reduction
|
||||
const auto d_zeroVal = DReduceOperation::GetReductionZeroVal();
|
||||
const auto d_identityVal = DReduceOperation::GetIdentityValue();
|
||||
|
||||
static_for<0, mreduce_per_thread, 1>{}(
|
||||
[&](auto I) { d_thread_buf(I) = d_zeroVal; });
|
||||
[&](auto I) { d_thread_buf(I) = d_identityVal; });
|
||||
|
||||
// reduce in VGPR
|
||||
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
|
||||
|
||||
Reference in New Issue
Block a user