mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
simplified buffer_load/store (#971)
* simplified buffer_load/store * add bfp8/fp8 * fixed * fixed buffer_load * fixed buffer_store --------- Co-authored-by: Jing Zhang <jizha@amd.com>
This commit is contained in:
@@ -299,6 +299,109 @@ enum struct AmdBufferCoherenceEnum
|
||||
GLC_SLC = 3,
|
||||
};
|
||||
|
||||
template <index_t N, AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
|
||||
__device__ typename vector_type<int8_t, N>::type
|
||||
amd_buffer_load_impl_raw(int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset)
|
||||
{
|
||||
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
|
||||
"wrong! not implemented");
|
||||
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
|
||||
int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<int8x2_t>(tmp);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<int8x4_t>(tmp);
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<int8x8_t>(tmp);
|
||||
}
|
||||
else if constexpr(N == 16)
|
||||
{
|
||||
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
return bit_cast<int8x16_t>(tmp);
|
||||
}
|
||||
else if constexpr(N == 32)
|
||||
{
|
||||
int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
int32x4_t tmp1 =
|
||||
llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(int32_t),
|
||||
static_cast<index_t>(coherence));
|
||||
vector_type<int32_t, 8> tmp;
|
||||
|
||||
tmp.AsType<int32x4_t>()(Number<0>{}) = tmp0;
|
||||
tmp.AsType<int32x4_t>()(Number<1>{}) = tmp1;
|
||||
|
||||
return bit_cast<int8x32_t>(tmp);
|
||||
}
|
||||
else if constexpr(N == 64)
|
||||
{
|
||||
int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
int32x4_t tmp1 =
|
||||
llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(int32_t),
|
||||
static_cast<index_t>(coherence));
|
||||
int32x4_t tmp2 =
|
||||
llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 8 * sizeof(int32_t),
|
||||
static_cast<index_t>(coherence));
|
||||
int32x4_t tmp3 =
|
||||
llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 12 * sizeof(int32_t),
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
vector_type<int32_t, 16> tmp;
|
||||
|
||||
tmp.AsType<int32x4_t>()(Number<0>{}) = tmp0;
|
||||
tmp.AsType<int32x4_t>()(Number<1>{}) = tmp1;
|
||||
tmp.AsType<int32x4_t>()(Number<2>{}) = tmp2;
|
||||
tmp.AsType<int32x4_t>()(Number<3>{}) = tmp3;
|
||||
|
||||
return bit_cast<int8x64_t>(tmp);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
index_t N,
|
||||
AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
|
||||
@@ -307,315 +410,116 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
|
||||
index_t src_wave_addr_offset)
|
||||
{
|
||||
static_assert(
|
||||
(is_same<T, double>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
|
||||
"wrong! not implemented");
|
||||
|
||||
if constexpr(is_same<T, double>::value)
|
||||
using r_t = typename vector_type<T, N>::type;
|
||||
auto raw_data = amd_buffer_load_impl_raw<sizeof(T) * N, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset);
|
||||
return bit_cast<r_t>(raw_data);
|
||||
}
|
||||
|
||||
template <index_t N, AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
|
||||
__device__ void
|
||||
amd_buffer_store_impl_raw(const typename vector_type<int8_t, N>::type src_thread_data,
|
||||
int32x4_t dst_wave_buffer_resource,
|
||||
index_t dst_thread_addr_offset,
|
||||
index_t dst_wave_addr_offset)
|
||||
{
|
||||
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
|
||||
"wrong! not implemented");
|
||||
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
// use fp32 load to mimic fp64 load
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
const float2_t tmp =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<double>(tmp);
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
const float4_t tmp =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<double2_t>(tmp);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
const float4_t f32_0 =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
const float4_t f32_1 =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(float),
|
||||
static_cast<index_t>(coherence));
|
||||
vector_type<double, 4> tmp;
|
||||
|
||||
tmp.AsType<double2_t>()(Number<0>{}) = bit_cast<double2_t>(f32_0);
|
||||
tmp.AsType<double2_t>()(Number<1>{}) = bit_cast<double2_t>(f32_1);
|
||||
|
||||
return tmp.AsType<double4_t>()(Number<0>{});
|
||||
}
|
||||
llvm_amdgcn_raw_buffer_store_i8(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(is_same<T, float>::value)
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_fp32(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
vector_type<float, 8> tmp;
|
||||
|
||||
tmp.AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.AsType<float4_t>()(Number<1>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(float),
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return tmp.AsType<float8_t>()(Number<0>{});
|
||||
}
|
||||
llvm_amdgcn_raw_buffer_store_i16(bit_cast<int16_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(is_same<T, half_t>::value)
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_fp16(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_fp16x2(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
// use fp32 load to mimic fp16 load
|
||||
float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<half8_t>(tmp);
|
||||
}
|
||||
llvm_amdgcn_raw_buffer_store_i32(bit_cast<int32_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(is_same<T, bhalf_t>::value)
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i16x2(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i16x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<bhalf8_t>(tmp);
|
||||
}
|
||||
llvm_amdgcn_raw_buffer_store_i32x2(bit_cast<int32x2_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(is_same<T, int32_t>::value)
|
||||
else if constexpr(N == 16)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
vector_type<int32_t, 8> tmp;
|
||||
|
||||
tmp.AsType<int32x4_t>()(Number<0>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.AsType<int32x4_t>()(Number<1>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(int32_t),
|
||||
static_cast<index_t>(coherence));
|
||||
return tmp.AsType<int32x8_t>()(Number<0>{});
|
||||
}
|
||||
llvm_amdgcn_raw_buffer_store_i32x4(bit_cast<int32x4_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(is_same<T, int8_t>::value)
|
||||
else if constexpr(N == 32)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
|
||||
return llvm_amdgcn_raw_buffer_load_i8x2(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
#else
|
||||
int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
vector_type<int32_t, 8> tmp{bit_cast<int32x8_t>(src_thread_data)};
|
||||
|
||||
return bit_cast<int8x2_t>(tmp);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
|
||||
return llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
#else
|
||||
int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<int8x4_t>(tmp);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
|
||||
vector_type<int8_t, 8> tmp;
|
||||
llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + sizeof(int32_t) * 4,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 64)
|
||||
{
|
||||
vector_type<int32_t, 16> tmp{bit_cast<int32x16_t>(src_thread_data)};
|
||||
|
||||
tmp.AsType<int8x4_t>()(Number<0>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.AsType<int8x4_t>()(Number<1>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(int8_t),
|
||||
static_cast<index_t>(coherence));
|
||||
llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + sizeof(int32_t) * 4,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return tmp.AsType<int8x8_t>()(Number<0>{});
|
||||
#else
|
||||
int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<2>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + sizeof(int32_t) * 8,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<int8x8_t>(tmp);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(N == 16)
|
||||
{
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
|
||||
vector_type<int8_t, 16> tmp;
|
||||
|
||||
tmp.AsType<int8x4_t>()(Number<0>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.AsType<int8x4_t>()(Number<1>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(int8_t),
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.AsType<int8x4_t>()(Number<2>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 8 * sizeof(int8_t),
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.AsType<int8x4_t>()(Number<3>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 12 * sizeof(int8_t),
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return tmp.AsType<int8x16_t>()(Number<0>{});
|
||||
#else
|
||||
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<int8x16_t>(tmp);
|
||||
#endif
|
||||
}
|
||||
llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<3>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + sizeof(int32_t) * 12,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -628,255 +532,22 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
index_t dst_wave_addr_offset)
|
||||
{
|
||||
static_assert(
|
||||
(is_same<T, double>::value && (N == 1 || N == 2)) ||
|
||||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
|
||||
"wrong! not implemented");
|
||||
|
||||
if constexpr(is_same<T, double>::value)
|
||||
{
|
||||
// use fp32 store to mimic fp64 store
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_fp32x2(bit_cast<float2_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<float4_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, float>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_fp32(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
vector_type<float, 8> tmp{src_thread_data};
|
||||
llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType<float4_t>()[Number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType<float4_t>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 4 * sizeof(float),
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, half_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_fp16(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
#if 0
|
||||
vector_type<half_t, 8> tmp{src_thread_data};
|
||||
using r_t = typename vector_type<int8_t, sizeof(T) * N>::type;
|
||||
|
||||
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 4 * sizeof(half_t),
|
||||
static_cast<index_t>(coherence));
|
||||
#else
|
||||
llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<float4_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, bhalf_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i16(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i16x2(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i16x4(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
vector_type<bhalf_t, 8> tmp{src_thread_data};
|
||||
|
||||
llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<bhalf4_t>()[Number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<bhalf4_t>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 4 * sizeof(bhalf_t),
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, int32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i32(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, int8_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i8(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
|
||||
llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
#else
|
||||
llvm_amdgcn_raw_buffer_store_i16(bit_cast<int16_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
#endif
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
|
||||
llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
#else
|
||||
llvm_amdgcn_raw_buffer_store_i32(bit_cast<int32_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
#endif
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i32x2(bit_cast<int32x2_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 16)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i32x4(bit_cast<int32x4_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
}
|
||||
amd_buffer_store_impl_raw<sizeof(T) * N, coherence>(bit_cast<r_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset);
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
@@ -1127,54 +798,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
|
||||
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
if constexpr(is_same<scalar_t, f8_t>::value || is_same<scalar_t, bf8_t>::value)
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
|
||||
if constexpr(is_same<scalar_t, f8_t>::value)
|
||||
#endif
|
||||
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
if constexpr(is_same<scalar_t, bf8_t>::value)
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|
||||
{
|
||||
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
|
||||
return bit_cast<vector_t>(tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
#endif
|
||||
return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
|
||||
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|
||||
}
|
||||
#endif
|
||||
return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
|
||||
|
||||
#else
|
||||
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
if constexpr(is_same<scalar_t, f8_t>::value || is_same<scalar_t, bf8_t>::value)
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
|
||||
if constexpr(is_same<scalar_t, f8_t>::value)
|
||||
#endif
|
||||
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
if constexpr(is_same<scalar_t, bf8_t>::value)
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|
||||
{
|
||||
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
return src_thread_element_valid ? bit_cast<vector_t>(tmp) : vector_t(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
#endif
|
||||
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
return src_thread_element_valid ? tmp : vector_t(0);
|
||||
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|
||||
}
|
||||
#endif
|
||||
|
||||
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
return src_thread_element_valid ? tmp : vector_t(0);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -1232,62 +863,13 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
|
||||
|
||||
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
if constexpr(is_same<scalar_t, f8_t>::value || is_same<scalar_t, bf8_t>::value)
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
|
||||
if constexpr(is_same<scalar_t, f8_t>::value)
|
||||
#endif
|
||||
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
if constexpr(is_same<scalar_t, bf8_t>::value)
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|
||||
{
|
||||
auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
|
||||
src_thread_data);
|
||||
amd_buffer_store_impl<int8_t, vector_size, coherence>(
|
||||
tmp, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
#endif
|
||||
amd_buffer_store_impl<scalar_t, vector_size, coherence>(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_addr_shift +
|
||||
dst_thread_addr_offset,
|
||||
0);
|
||||
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|
||||
}
|
||||
#endif
|
||||
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||
#else
|
||||
if(dst_thread_element_valid)
|
||||
{
|
||||
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
if constexpr(is_same<scalar_t, f8_t>::value || is_same<scalar_t, bf8_t>::value)
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
|
||||
if constexpr(is_same<scalar_t, f8_t>::value)
|
||||
#endif
|
||||
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
if constexpr(is_same<scalar_t, bf8_t>::value)
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|
||||
{
|
||||
auto tmp =
|
||||
bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
|
||||
src_thread_data);
|
||||
amd_buffer_store_impl<int8_t, vector_size, coherence>(
|
||||
tmp, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
#endif
|
||||
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|
||||
}
|
||||
#endif
|
||||
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user