mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
support dynamic buffer using memory coherence glc_slc bit from template (#725)
This commit is contained in:
@@ -286,7 +286,22 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
|
||||
int soffset, // dst_wave_addr_offset
|
||||
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
|
||||
|
||||
template <typename T, index_t N>
|
||||
// memory coherency bit for buffer store/load instruction
|
||||
// check ISA manual for each GFX target
|
||||
// e.g. for
|
||||
// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf,
|
||||
// page 67~68
|
||||
enum struct AmdBufferCoherenceEnum
|
||||
{
|
||||
DefaultCoherence = 0, // default value
|
||||
GLC = 1,
|
||||
SLC = 2,
|
||||
GLC_SLC = 3,
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
index_t N,
|
||||
AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
|
||||
__device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset)
|
||||
@@ -305,28 +320,37 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
|
||||
// use fp32 load to mimic fp64 load
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
const float2_t tmp = llvm_amdgcn_raw_buffer_load_fp32x2(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
const float2_t tmp =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<double>(tmp);
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
const float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
const float4_t tmp =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<double2_t>(tmp);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
const float4_t f32_0 = llvm_amdgcn_raw_buffer_load_fp32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
const float4_t f32_0 =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
const float4_t f32_1 =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(float),
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
vector_type<double, 4> tmp;
|
||||
|
||||
tmp.AsType<double2_t>()(Number<0>{}) = bit_cast<double2_t>(f32_0);
|
||||
@@ -339,31 +363,40 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_fp32(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
return llvm_amdgcn_raw_buffer_load_fp32(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_fp32x2(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
return llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_fp32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
return llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
vector_type<float, 8> tmp;
|
||||
|
||||
tmp.AsType<float4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
tmp.AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.AsType<float4_t>()(Number<1>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(float),
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return tmp.AsType<float8_t>()(Number<0>{});
|
||||
}
|
||||
@@ -372,24 +405,32 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_fp16(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
return llvm_amdgcn_raw_buffer_load_fp16(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_fp16x2(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
return llvm_amdgcn_raw_buffer_load_fp16x2(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_fp16x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
return llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
// use fp32 load to mimic fp16 load
|
||||
float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<half8_t>(tmp);
|
||||
}
|
||||
@@ -398,23 +439,31 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i16(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
return llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i16x2(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
return llvm_amdgcn_raw_buffer_load_i16x2(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i16x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
return llvm_amdgcn_raw_buffer_load_i16x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<bhalf8_t>(tmp);
|
||||
}
|
||||
@@ -423,31 +472,40 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i32(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
return llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i32x2(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
return llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
return llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
vector_type<int32_t, 8> tmp;
|
||||
|
||||
tmp.AsType<int32x4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
tmp.AsType<int32x4_t>()(Number<0>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.AsType<int32x4_t>()(Number<1>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(int32_t),
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
return tmp.AsType<int32x8_t>()(Number<0>{});
|
||||
}
|
||||
}
|
||||
@@ -455,17 +513,23 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i8(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
return llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
|
||||
return llvm_amdgcn_raw_buffer_load_i8x2(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
return llvm_amdgcn_raw_buffer_load_i8x2(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
#else
|
||||
int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<int8x2_t>(tmp);
|
||||
#endif
|
||||
@@ -473,11 +537,15 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
|
||||
return llvm_amdgcn_raw_buffer_load_i8x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
return llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
#else
|
||||
int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<int8x4_t>(tmp);
|
||||
#endif
|
||||
@@ -487,19 +555,24 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
|
||||
vector_type<int8_t, 8> tmp;
|
||||
|
||||
tmp.AsType<int8x4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
tmp.AsType<int8x4_t>()(Number<0>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.AsType<int8x4_t>()(Number<1>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(int8_t),
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return tmp.AsType<int8x8_t>()(Number<0>{});
|
||||
#else
|
||||
int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<int8x8_t>(tmp);
|
||||
#endif
|
||||
@@ -509,31 +582,36 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
|
||||
vector_type<int8_t, 16> tmp;
|
||||
|
||||
tmp.AsType<int8x4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
tmp.AsType<int8x4_t>()(Number<0>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.AsType<int8x4_t>()(Number<1>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(int8_t),
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.AsType<int8x4_t>()(Number<2>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 8 * sizeof(int8_t),
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.AsType<int8x4_t>()(Number<3>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 12 * sizeof(int8_t),
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return tmp.AsType<int8x16_t>()(Number<0>{});
|
||||
#else
|
||||
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<int8x16_t>(tmp);
|
||||
#endif
|
||||
@@ -541,7 +619,9 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
template <typename T,
|
||||
index_t N,
|
||||
AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
|
||||
__device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src_thread_data,
|
||||
int32x4_t dst_wave_buffer_resource,
|
||||
index_t dst_thread_addr_offset,
|
||||
@@ -565,7 +645,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
@@ -573,7 +653,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, float>::value)
|
||||
@@ -584,7 +664,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
@@ -592,7 +672,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
@@ -600,7 +680,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, half_t>::value)
|
||||
@@ -611,7 +691,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
@@ -619,7 +699,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
@@ -627,7 +707,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
@@ -638,19 +718,19 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 4 * sizeof(half_t),
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
#else
|
||||
llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<float4_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@@ -662,7 +742,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
@@ -670,7 +750,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
@@ -678,7 +758,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
@@ -688,13 +768,13 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<bhalf4_t>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 4 * sizeof(bhalf_t),
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, int32_t>::value)
|
||||
@@ -705,7 +785,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
@@ -713,7 +793,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
@@ -721,7 +801,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, int8_t>::value)
|
||||
@@ -732,7 +812,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
@@ -741,13 +821,13 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
#else
|
||||
llvm_amdgcn_raw_buffer_store_i16(bit_cast<int16_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
#endif
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
@@ -757,13 +837,13 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
#else
|
||||
llvm_amdgcn_raw_buffer_store_i32(bit_cast<int32_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
#endif
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
@@ -772,7 +852,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 16)
|
||||
{
|
||||
@@ -780,7 +860,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1012,7 +1092,9 @@ __device__ void amd_buffer_atomic_max_impl(const typename vector_type<T, N>::typ
|
||||
// 1) p_src_wave must point to global memory space
|
||||
// 2) p_src_wave must be a wavewise pointer.
|
||||
// It is user's responsibility to make sure that is true.
|
||||
template <typename T, index_t N>
|
||||
template <typename T,
|
||||
index_t N,
|
||||
AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
|
||||
__device__ typename vector_type_maker<T, N>::type::type
|
||||
amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
|
||||
index_t src_thread_element_offset,
|
||||
@@ -1032,10 +1114,10 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
|
||||
|
||||
return amd_buffer_load_impl<scalar_t, vector_size>(
|
||||
return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
|
||||
#else
|
||||
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
|
||||
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
|
||||
return src_thread_element_valid ? tmp : vector_t(0);
|
||||
@@ -1046,7 +1128,9 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
|
||||
// 1) p_src_wave must point to global memory space
|
||||
// 2) p_src_wave must be a wavewise pointer.
|
||||
// It is user's responsibility to make sure that is true.
|
||||
template <typename T, index_t N>
|
||||
template <typename T,
|
||||
index_t N,
|
||||
AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
|
||||
__device__ typename vector_type_maker<T, N>::type::type
|
||||
amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
|
||||
index_t src_thread_element_offset,
|
||||
@@ -1064,7 +1148,7 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
|
||||
|
||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||
|
||||
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
|
||||
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
|
||||
return src_thread_element_valid ? tmp : vector_t(customized_value);
|
||||
@@ -1074,7 +1158,9 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
|
||||
// 1) p_dst_wave must point to global memory
|
||||
// 2) p_dst_wave must be a wavewise pointer.
|
||||
// It is user's responsibility to make sure that is true.
|
||||
template <typename T, index_t N>
|
||||
template <typename T,
|
||||
index_t N,
|
||||
AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
|
||||
__device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data,
|
||||
T* p_dst_wave,
|
||||
const index_t dst_thread_element_offset,
|
||||
@@ -1093,12 +1179,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
|
||||
|
||||
amd_buffer_store_impl<scalar_t, vector_size>(
|
||||
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||
#else
|
||||
if(dst_thread_element_valid)
|
||||
{
|
||||
amd_buffer_store_impl<scalar_t, vector_size>(
|
||||
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user