mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
GlobalAtomicAdd for fp32/int32 (#23)
* add f32/i32 atomicAdd support into dynamicBuffer, and enable it in v1r3 * fixed * fixed * update comment Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -208,10 +208,20 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
|
||||
|
||||
// copy data from dst_vector into dst_buf
|
||||
dst_buf.template Set<dst_vector_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
|
||||
if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Set)
|
||||
{
|
||||
dst_buf.template Set<dst_vector_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
|
||||
}
|
||||
else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd)
|
||||
{
|
||||
dst_buf.template AtomicAdd<dst_vector_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
|
||||
}
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
|
||||
@@ -202,6 +202,22 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
|
||||
// atomic add
// int
// Compiler intrinsic: raw buffer atomic integer add of a single i32 against the
// buffer described by rsrc. Per the LLVM AMDGPU intrinsic convention, the final
// byte address is base(rsrc) + voffset + soffset; glc_slc carries the cache
// policy bits (0 = default).
__device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
    int32_t vdata,   // value to add
    int32x4_t rsrc,  // 128-bit buffer resource descriptor
    index_t voffset, // per-thread byte offset
    index_t soffset, // wave-uniform byte offset
    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32");

// float
// Compiler intrinsic: raw buffer atomic float add (fadd) of a single fp32.
// Same addressing as the i32 variant above.
// NOTE(review): buffer atomic fadd is only available on hardware that supports
// it (e.g. gfx908+) — confirm against the build targets.
__device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
    float vdata,     // value to add
    int32x4_t rsrc,  // 128-bit buffer resource descriptor
    index_t voffset, // per-thread byte offset
    index_t soffset, // wave-uniform byte offset
    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");
||||
template <typename T, index_t N>
|
||||
__device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
|
||||
@@ -581,6 +597,128 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
}
|
||||
}
|
||||
|
||||
// Atomically adds src_thread_data into the buffer described by
// dst_wave_buffer_resource, one hardware buffer-atomic instruction per scalar
// element. Element k is placed at byte offset k * sizeof(T), applied through
// the wave-uniform (soffset) operand on top of dst_wave_addr_offset.
// Only fp32 and int32 with vector sizes 1, 2, or 4 are implemented.
template <typename T, index_t N>
__device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
                                           int32x4_t dst_wave_buffer_resource,
                                           index_t dst_thread_addr_offset,
                                           index_t dst_wave_addr_offset)
{
    static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
                      (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
                  "wrong! not implemented");

    if constexpr(is_same<T, float>::value)
    {
        if constexpr(N == 1)
        {
            // scalar case: add directly, no unpacking needed
            llvm_amdgcn_raw_buffer_atomic_add_fp32(src_thread_data,
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset,
                                                   0);
        }
        else if constexpr(N == 2)
        {
            // unpack the 2-wide vector and issue one scalar atomic per lane
            vector_type<float, 2> tmp{src_thread_data};

            llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<0>{}],
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset,
                                                   0);

            llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<1>{}],
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset + sizeof(float),
                                                   0);
        }
        else if constexpr(N == 4)
        {
            // unpack the 4-wide vector and issue one scalar atomic per lane
            vector_type<float, 4> tmp{src_thread_data};

            llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<0>{}],
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset,
                                                   0);

            llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<1>{}],
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset + sizeof(float),
                                                   0);

            llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<2>{}],
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset + 2 * sizeof(float),
                                                   0);

            llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<3>{}],
                                                   dst_wave_buffer_resource,
                                                   dst_thread_addr_offset,
                                                   dst_wave_addr_offset + 3 * sizeof(float),
                                                   0);
        }
    }
    else if constexpr(is_same<T, int32_t>::value)
    {
        // int32 path mirrors the float path, using the integer add intrinsic
        if constexpr(N == 1)
        {
            llvm_amdgcn_raw_buffer_atomic_add_i32(src_thread_data,
                                                  dst_wave_buffer_resource,
                                                  dst_thread_addr_offset,
                                                  dst_wave_addr_offset,
                                                  0);
        }
        else if constexpr(N == 2)
        {
            vector_type<int32_t, 2> tmp{src_thread_data};

            llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<0>{}],
                                                  dst_wave_buffer_resource,
                                                  dst_thread_addr_offset,
                                                  dst_wave_addr_offset,
                                                  0);

            llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<1>{}],
                                                  dst_wave_buffer_resource,
                                                  dst_thread_addr_offset,
                                                  dst_wave_addr_offset + sizeof(int32_t),
                                                  0);
        }
        else if constexpr(N == 4)
        {
            vector_type<int32_t, 4> tmp{src_thread_data};

            llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<0>{}],
                                                  dst_wave_buffer_resource,
                                                  dst_thread_addr_offset,
                                                  dst_wave_addr_offset,
                                                  0);

            llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<1>{}],
                                                  dst_wave_buffer_resource,
                                                  dst_thread_addr_offset,
                                                  dst_wave_addr_offset + sizeof(int32_t),
                                                  0);

            llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<2>{}],
                                                  dst_wave_buffer_resource,
                                                  dst_thread_addr_offset,
                                                  dst_wave_addr_offset + 2 * sizeof(int32_t),
                                                  0);

            llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<3>{}],
                                                  dst_wave_buffer_resource,
                                                  dst_thread_addr_offset,
                                                  dst_wave_addr_offset + 3 * sizeof(int32_t),
                                                  0);
        }
    }
}
|
||||
|
||||
// buffer_load requires:
|
||||
// 1) p_src_wave must be in global memory space
|
||||
// 2) p_src_wave must be a wavewise pointer.
|
||||
@@ -645,7 +783,7 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
|
||||
|
||||
// buffer_store requires:
// 1) p_dst_wave must be global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
|
||||
template <typename T, index_t N>
|
||||
__device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data,
|
||||
@@ -677,5 +815,40 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
|
||||
#endif
|
||||
}
|
||||
|
||||
// buffer_atomic_add requires:
|
||||
// 1) p_dst_wave must be global memory
|
||||
// 2) p_dst_wave must be a wavewise pointer.
|
||||
// It is user's responsibility to make sure that is true.
|
||||
template <typename T, index_t N>
|
||||
__device__ void
|
||||
amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thread_data,
|
||||
T* p_dst_wave,
|
||||
const index_t dst_thread_element_offset,
|
||||
const bool dst_thread_element_valid,
|
||||
const index_t dst_element_space_size)
|
||||
{
|
||||
const int32x4_t dst_wave_buffer_resource =
|
||||
make_wave_buffer_resource(p_dst_wave, dst_element_space_size);
|
||||
|
||||
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
|
||||
|
||||
using vector_t = typename vector_type_maker<T, N>::type::type;
|
||||
using scalar_t = typename scalar_type<vector_t>::type;
|
||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
|
||||
|
||||
amd_buffer_atomic_add_impl<scalar_t, vector_size>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||
#else
|
||||
if(dst_thread_element_valid)
|
||||
{
|
||||
amd_buffer_atomic_add_impl<scalar_t, vector_size>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -85,8 +85,8 @@
|
||||
#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
|
||||
#endif
|
||||
|
||||
// Default-on switch for the branchless out-of-bounds offset trick used by
// amd_buffer_atomic_add. Fix: the pasted region contained two unmatched
// #ifndef/#define pairs (the pre-rename CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_
// CHECK_OFFSET_TRICK and the renamed ..._ATOMIC_ADD_... guard) sharing a single
// #endif, which is an unbalanced preprocessor conditional. Only the renamed
// macro is referenced by the code, so only its guard is kept.
#ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
#endif
|
||||
|
||||
// pass tensor descriptor by value or void*
|
||||
|
||||
@@ -223,6 +223,38 @@ struct DynamicBuffer
|
||||
}
|
||||
}
|
||||
|
||||
// Atomically accumulates x (a vector of one or more T) into the buffer starting
// at element index i. Enabled only when X and T share the same scalar type.
// Invalid elements (is_valid_element == false) contribute nothing.
template <typename X,
          typename enable_if<
              is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                      typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
              bool>::type = false>
__host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& x)
{
    // X contains multiple T
    constexpr index_t scalar_per_t_vector =
        scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;

    constexpr index_t scalar_per_x_vector =
        scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;

    static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                  "wrong! X need to be multiple T");

    // buffer atomics only exist for global memory
    static_assert(GetAddressSpace() == AddressSpaceEnum_t::Global, "only support global mem");

#if CK_USE_AMD_BUFFER_ADDRESSING
    // number of T elements packed in one X
    constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

    amd_buffer_atomic_add<remove_cv_t<remove_reference_t<T>>, t_per_x>(
        x, p_data_, i, is_valid_element, element_space_size_);
#else
    // Fallback: generic HIP/CUDA atomicAdd.
    // NOTE(review): this passes X directly to atomicAdd(&p_data_[i], x), which
    // only compiles when X is the same scalar type as T — verify the non-buffer
    // path for vector-valued X.
    if(is_valid_element)
    {
        atomicAdd(&p_data_[i], x);
    }
#endif
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
|
||||
|
||||
__host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
|
||||
|
||||
Reference in New Issue
Block a user