mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Gemm reduce max (#209)
* [What] Rename the example [Why] Prepare to add unary reduction * Add global operation to the parameter * Add atomicmax * Fix compile error * Support atomicMax (hip library) * Rename the reduction example * Fix target name * use p_d1_grid as the indicator directly * Prevent performance issue. Let passthrough handle it. * Implement the function template the specialize the float2 * No need to separate into two lines * Remove empty line * add comment * Fix compile error due to merge from develop * make the implementation of atomic_max / atomic_add explicit for each datatype * Refine typo * For future CI test * Fix compiler error in ckProfiler * Merge commit 'de2769e3a6695b38a20529261273ddc5cdaab2fe' * simply use remove_pointer * Rename type and var * Refine example * Modify reducemax example * Fix bug in reduction * Change initialize range * Implement F64 version of atomicMax * Move reduction code together * Add buffer atomic_max * Fix coding style by clang-format * Integrate new api of DeviceGemmReduce_Xdl_CShuffle * Integrate Batch gemm reduction * Fix example * fix example * clean up * Fix batch gemm tensor operation * Fix coding style * Fix template argument * Fix clang format * Keep flexible of different stride for each D tensor * Fix compile error for ckProfiler * Fix typo * [What] Fix naming [Why] Prepare to add out elementop * Add DoutElementOp Co-authored-by: Chao Liu <chao.liu2@amd.com> Co-authored-by: rocking <chunylai@amd.com>
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
#include "enable_if.hpp"
|
||||
#include "c_style_pointer_cast.hpp"
|
||||
#include "amd_buffer_addressing.hpp"
|
||||
#include "generic_memory_space_atomic_add.hpp"
|
||||
#include "generic_memory_space_atomic.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -125,6 +125,10 @@ struct DynamicBuffer
|
||||
{
|
||||
this->template AtomicAdd<X>(i, is_valid_element, x);
|
||||
}
|
||||
else if constexpr(Op == InMemoryDataOperationEnum::AtomicMax)
|
||||
{
|
||||
this->template AtomicMax<X>(i, is_valid_element, x);
|
||||
}
|
||||
else if constexpr(Op == InMemoryDataOperationEnum::Add)
|
||||
{
|
||||
auto tmp = this->template Get<X>(i, is_valid_element);
|
||||
@@ -326,6 +330,42 @@ struct DynamicBuffer
|
||||
}
|
||||
}
|
||||
|
||||
// AtomicMax: applies an atomic max of `x` to the buffer at element index `i`.
//
// Template constraint: the overload is only enabled when the scalar type of X
// is the same as the scalar type of the buffer's element type T.
//
// Parameters:
//   i                - element index into p_data_
//   is_valid_element - when false, the generic path below performs no access
//   x                - value (scalar or vector of T's scalar type) to fold in
template <typename X,
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
__host__ __device__ void AtomicMax(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
// Compile-time check: the scalar count of X must be a whole multiple of the
// scalar count of T.
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
|
||||
// Compile-time check: this operation is rejected for any address space other
// than Global.
static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem");
|
||||
|
||||
// Select the implementation at compile time: the AMD buffer-addressing path is
// used only when CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 is enabled AND the
// scalar type of T is double; every other case takes the generic path.
#if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
|
||||
using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
|
||||
bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, double>;
|
||||
#else
|
||||
bool constexpr use_amd_buffer_addressing = false;
|
||||
#endif
|
||||
|
||||
if constexpr(use_amd_buffer_addressing)
|
||||
{
|
||||
// Number of T elements packed into one X.
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
// is_valid_element and element_space_size_ are forwarded rather than checked
// here. NOTE(review): presumably amd_buffer_atomic_max suppresses the access
// for invalid elements itself -- confirm in amd_buffer_addressing.hpp.
amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
|
||||
x, p_data_, i, is_valid_element, element_space_size_);
|
||||
}
|
||||
// Generic path: validity is checked here at run time, so an invalid element
// results in no memory access at all.
else if(is_valid_element)
|
||||
{
|
||||
// Reinterpret the address of element i as X* so the whole X is passed to
// atomic_max in one call.
atomic_max<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
|
||||
}
|
||||
}
|
||||
|
||||
// Compile-time trait: always reports false, i.e. this buffer type does not
// identify itself as a static buffer.
__host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
|
||||
|
||||
// Compile-time trait: always reports true, i.e. this buffer type identifies
// itself as a dynamic buffer.
__host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
|
||||
|
||||
Reference in New Issue
Block a user