mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
* [What] Rename the example [Why] Prepare to add unary reduction * Add global operation to the parameter * Add atomicmax * Fix compile error * Support atomicMax (hip library) * Rename the reduction example * Fix target name * use p_d1_grid as the indicator directly * Prevent performance issue. Let passthrough handle it. * Implement the function template and specialize it for float2 * No need to separate into two lines * Remove empty line * add comment * Fix compile error due to merge from develop * make the implementation of atomic_max / atomic_add explicit for each datatype * Refine typo * For future CI test * Fix compiler error in ckProfiler * Merge commit 'de2769e3a6695b38a20529261273ddc5cdaab2fe' * simply use remove_pointer * Rename type and var * Refine example * Modify reducemax example * Fix bug in reduction * Change initialize range * Implement F64 version of atomicMax * Move reduction code together * Add buffer atomic_max * Fix coding style by clang-format * Integrate new api of DeviceGemmReduce_Xdl_CShuffle * Integrate Batch gemm reduction * Fix example * fix example * clean up * Fix batch gemm tensor operation * Fix coding style * Fix template argument * Fix clang format * Keep flexible of different stride for each D tensor * Fix compile error for ckProfiler * Fix typo * [What] Fix naming [Why] Prepare to add out elementop * Add DoutElementOp Co-authored-by: Chao Liu <chao.liu2@amd.com> Co-authored-by: rocking <chunylai@amd.com>
98 lines
2.6 KiB
C++
98 lines
2.6 KiB
C++
#pragma once
|
|
#include "data_type.hpp"
|
|
|
|
namespace ck {
|
|
|
|
// Caution: DO NOT REMOVE
|
|
// intentionally have only declaration but no definition to cause compilation failure when trying to
|
|
// instantiate this template. The purpose is to make the implementation of atomic_add explicit for
|
|
// each datatype.
|
|
template <typename X>
|
|
__device__ X atomic_add(X* p_dst, const X& x);
|
|
|
|
template <>
|
|
__device__ int32_t atomic_add<int32_t>(int32_t* p_dst, const int32_t& x)
|
|
{
|
|
return atomicAdd(p_dst, x);
|
|
}
|
|
|
|
template <>
|
|
__device__ uint32_t atomic_add<uint32_t>(uint32_t* p_dst, const uint32_t& x)
|
|
{
|
|
return atomicAdd(p_dst, x);
|
|
}
|
|
|
|
template <>
|
|
__device__ float atomic_add<float>(float* p_dst, const float& x)
|
|
{
|
|
return atomicAdd(p_dst, x);
|
|
}
|
|
|
|
template <>
|
|
__device__ float2_t atomic_add<float2_t>(float2_t* p_dst, const float2_t& x)
|
|
{
|
|
constexpr auto I0 = Number<0>{};
|
|
constexpr auto I1 = Number<1>{};
|
|
|
|
const vector_type<float, 2> vx{x};
|
|
vector_type<float, 2> vy{0};
|
|
|
|
vy.template AsType<float>()(I0) =
|
|
atomicAdd(c_style_pointer_cast<float*>(p_dst), vx.template AsType<float>()[I0]);
|
|
vy.template AsType<float>()(I1) =
|
|
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, vx.template AsType<float>()[I1]);
|
|
|
|
return vy.template AsType<float2_t>()[I0];
|
|
}
|
|
|
|
// Caution: DO NOT REMOVE
|
|
// intentionally have only declaration but no definition to cause compilation failure when trying to
|
|
// instantiate this template. The purpose is to make the implementation of atomic_max explicit for
|
|
// each datatype.
|
|
|
|
template <typename X>
|
|
__device__ X atomic_max(X* p_dst, const X& x);
|
|
|
|
template <>
|
|
__device__ int32_t atomic_max<int32_t>(int32_t* p_dst, const int32_t& x)
|
|
{
|
|
return atomicMax(p_dst, x);
|
|
}
|
|
|
|
template <>
|
|
__device__ uint32_t atomic_max<uint32_t>(uint32_t* p_dst, const uint32_t& x)
|
|
{
|
|
return atomicMax(p_dst, x);
|
|
}
|
|
|
|
template <>
|
|
__device__ float atomic_max<float>(float* p_dst, const float& x)
|
|
{
|
|
return atomicMax(p_dst, x);
|
|
}
|
|
|
|
template <>
|
|
__device__ double atomic_max<double>(double* p_dst, const double& x)
|
|
{
|
|
return atomicMax(p_dst, x);
|
|
}
|
|
|
|
template <>
|
|
__device__ float2_t atomic_max<float2_t>(float2_t* p_dst, const float2_t& x)
|
|
{
|
|
constexpr auto I0 = Number<0>{};
|
|
constexpr auto I1 = Number<1>{};
|
|
|
|
const vector_type<float, 2> vx{x};
|
|
vector_type<float, 2> vy{0};
|
|
|
|
vy.template AsType<float>()(I0) =
|
|
atomicMax(c_style_pointer_cast<float*>(p_dst), vx.template AsType<float>()[I0]);
|
|
vy.template AsType<float>()(I1) =
|
|
atomicMax(c_style_pointer_cast<float*>(p_dst) + 1, vx.template AsType<float>()[I1]);
|
|
|
|
return vy.template AsType<float2_t>()[I0];
|
|
}
|
|
|
|
} // namespace ck
|