mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
DynamicBuffer, StaticBuffer, amd_buffer_load support customized value for invalid element
This commit is contained in:
@@ -133,7 +133,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
|
|||||||
static_assert(WPerThread % WoPerThreadSubC == 0, "");
|
static_assert(WPerThread % WoPerThreadSubC == 0, "");
|
||||||
|
|
||||||
// thread A buffer for GEMM
|
// thread A buffer for GEMM
|
||||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize()>
|
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
|
||||||
a_thread_buf;
|
a_thread_buf;
|
||||||
|
|
||||||
constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
|
constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
|
||||||
|
|||||||
@@ -227,7 +227,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
|
|||||||
// register allocation for output
|
// register allocation for output
|
||||||
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
||||||
FloatAcc,
|
FloatAcc,
|
||||||
c_k_n_ho_wo_thread_desc.GetElementSpaceSize()>
|
c_k_n_ho_wo_thread_desc.GetElementSpaceSize(),
|
||||||
|
true>
|
||||||
c_thread_buf;
|
c_thread_buf;
|
||||||
|
|
||||||
// initialize output thread tensor
|
// initialize output thread tensor
|
||||||
@@ -251,7 +252,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
|
|||||||
// double register buffer for b
|
// double register buffer for b
|
||||||
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
b_e_n_ho_wo_thread_desc.GetElementSpaceSize()>
|
b_e_n_ho_wo_thread_desc.GetElementSpaceSize(),
|
||||||
|
true>
|
||||||
b_thread_even_buf, b_thread_odd_buf;
|
b_thread_even_buf, b_thread_odd_buf;
|
||||||
|
|
||||||
// LDS double buffer: preload data
|
// LDS double buffer: preload data
|
||||||
|
|||||||
@@ -402,7 +402,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
|
|
||||||
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
||||||
vector_type<FloatAcc, BlkSize>,
|
vector_type<FloatAcc, BlkSize>,
|
||||||
c_mr_nr_blk_desc.GetElementSpaceSize()>
|
c_mr_nr_blk_desc.GetElementSpaceSize(),
|
||||||
|
true>
|
||||||
c_thread_buf;
|
c_thread_buf;
|
||||||
|
|
||||||
// LDS allocation for A and B: be careful of alignment
|
// LDS allocation for A and B: be careful of alignment
|
||||||
@@ -493,7 +494,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
Number<M2>{},
|
Number<M2>{},
|
||||||
Number<1>{}));
|
Number<1>{}));
|
||||||
|
|
||||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize()>
|
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize(), true>
|
||||||
c_blk_buf_;
|
c_blk_buf_;
|
||||||
|
|
||||||
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
|
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
|
||||||
|
|||||||
@@ -1242,7 +1242,7 @@ struct ThreadwiseTensorSliceTransfer_v3
|
|||||||
|
|
||||||
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
|
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
|
||||||
|
|
||||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_> buffer_;
|
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_, true> buffer_;
|
||||||
|
|
||||||
SrcCoord src_coord_;
|
SrcCoord src_coord_;
|
||||||
DstCoord dst_coord_;
|
DstCoord dst_coord_;
|
||||||
|
|||||||
@@ -602,7 +602,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
|||||||
|
|
||||||
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
|
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
|
||||||
|
|
||||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_> buffer_;
|
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_, true> buffer_;
|
||||||
|
|
||||||
SrcCoord src_coord_;
|
SrcCoord src_coord_;
|
||||||
DstCoord dst_coord_;
|
DstCoord dst_coord_;
|
||||||
|
|||||||
@@ -10,25 +10,25 @@ union BufferResource
|
|||||||
{
|
{
|
||||||
// 128 bit SGPRs to supply buffer resource in buffer instructions
|
// 128 bit SGPRs to supply buffer resource in buffer instructions
|
||||||
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
|
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
|
||||||
int32x4_t data;
|
int32x4_t content;
|
||||||
StaticallyIndexedArray<T*, 2> address;
|
StaticallyIndexedArray<T*, 2> address;
|
||||||
StaticallyIndexedArray<int32_t, 4> range;
|
StaticallyIndexedArray<int32_t, 4> range;
|
||||||
StaticallyIndexedArray<int32_t, 4> config;
|
StaticallyIndexedArray<int32_t, 4> config;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_size)
|
__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size)
|
||||||
{
|
{
|
||||||
BufferResource<T> wave_buffer_resource;
|
BufferResource<T> wave_buffer_resource;
|
||||||
|
|
||||||
// wavewise base address (64 bit)
|
// wavewise base address (64 bit)
|
||||||
wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
|
wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
|
||||||
// wavewise range (32 bit)
|
// wavewise range (32 bit)
|
||||||
wave_buffer_resource.range(Number<2>{}) = data_space_size * sizeof(T);
|
wave_buffer_resource.range(Number<2>{}) = element_space_size * sizeof(T);
|
||||||
// wavewise setting (32 bit)
|
// wavewise setting (32 bit)
|
||||||
wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;
|
wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;
|
||||||
|
|
||||||
return wave_buffer_resource.data;
|
return wave_buffer_resource.content;
|
||||||
}
|
}
|
||||||
|
|
||||||
// load
|
// load
|
||||||
@@ -204,10 +204,9 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
|
|||||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
|
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
|
||||||
|
|
||||||
template <typename T, index_t N>
|
template <typename T, index_t N>
|
||||||
__device__ typename vector_type<T, N>::type
|
__device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
|
||||||
amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
|
index_t src_thread_addr_offset,
|
||||||
index_t src_thread_addr_offset,
|
index_t src_wave_addr_offset)
|
||||||
index_t src_wave_addr_offset)
|
|
||||||
{
|
{
|
||||||
static_assert(
|
static_assert(
|
||||||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||||
@@ -412,10 +411,10 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, index_t N>
|
template <typename T, index_t N>
|
||||||
__device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type src_thread_data,
|
__device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src_thread_data,
|
||||||
int32x4_t dst_wave_buffer_resource,
|
int32x4_t dst_wave_buffer_resource,
|
||||||
index_t dst_thread_addr_offset,
|
index_t dst_thread_addr_offset,
|
||||||
index_t dst_wave_addr_offset)
|
index_t dst_wave_addr_offset)
|
||||||
{
|
{
|
||||||
static_assert(
|
static_assert(
|
||||||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
|
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||||
@@ -584,67 +583,95 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
|
|||||||
|
|
||||||
// buffer_load requires:
|
// buffer_load requires:
|
||||||
// 1) p_src_wave must be in global memory space
|
// 1) p_src_wave must be in global memory space
|
||||||
// 2) p_src_wave to be a wavewise pointer.
|
// 2) p_src_wave must be a wavewise pointer.
|
||||||
// It is user's responsibility to make sure that is true.
|
// It is user's responsibility to make sure that is true.
|
||||||
template <typename T, index_t N>
|
template <typename T, index_t N>
|
||||||
__device__ typename vector_type_maker<T, N>::type::type
|
__device__ typename vector_type_maker<T, N>::type::type
|
||||||
amd_buffer_load_v2(const T* p_src_wave,
|
amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave,
|
||||||
index_t src_thread_data_offset,
|
index_t src_thread_element_offset,
|
||||||
bool src_thread_data_valid,
|
bool src_thread_element_valid,
|
||||||
index_t src_element_space)
|
index_t src_element_space_size)
|
||||||
{
|
{
|
||||||
const int32x4_t src_wave_buffer_resource =
|
const int32x4_t src_wave_buffer_resource =
|
||||||
make_wave_buffer_resource(p_src_wave, src_element_space);
|
make_wave_buffer_resource(p_src_wave, src_element_space_size);
|
||||||
|
|
||||||
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(T);
|
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||||
|
|
||||||
|
using vector_t = typename vector_type_maker<T, N>::type::type;
|
||||||
|
using scalar_t = typename scalar_type<vector_t>::type;
|
||||||
|
|
||||||
using vector_t = typename vector_type_maker<T, N>::type::type;
|
|
||||||
using scalar_t = typename scalar_type<vector_t>::type;
|
|
||||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||||
|
|
||||||
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
|
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
|
||||||
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
|
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x7fffffff;
|
||||||
|
|
||||||
return amd_buffer_load_impl_v2<scalar_t, vector_size>(
|
return amd_buffer_load_impl<scalar_t, vector_size>(
|
||||||
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
|
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
|
||||||
#else
|
#else
|
||||||
vector_t tmp = amd_buffer_load_impl_v2<scalar_t, vector_size>(
|
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
|
||||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||||
|
|
||||||
return src_thread_data_valid ? tmp : vector_t(0);
|
return src_thread_element_valid ? tmp : vector_t(0);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// buffer_load requires:
|
||||||
|
// 1) p_src_wave must be in global memory space
|
||||||
|
// 2) p_src_wave must be a wavewise pointer.
|
||||||
|
// It is user's responsibility to make sure that is true.
|
||||||
|
template <typename T, index_t N>
|
||||||
|
__device__ typename vector_type_maker<T, N>::type::type
|
||||||
|
amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
|
||||||
|
index_t src_thread_element_offset,
|
||||||
|
bool src_thread_element_valid,
|
||||||
|
index_t src_element_space_size,
|
||||||
|
T customized_value)
|
||||||
|
{
|
||||||
|
const int32x4_t src_wave_buffer_resource =
|
||||||
|
make_wave_buffer_resource(p_src_wave, src_element_space_size);
|
||||||
|
|
||||||
|
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||||
|
|
||||||
|
using vector_t = typename vector_type_maker<T, N>::type::type;
|
||||||
|
using scalar_t = typename scalar_type<vector_t>::type;
|
||||||
|
|
||||||
|
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||||
|
|
||||||
|
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
|
||||||
|
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||||
|
|
||||||
|
return src_thread_element_valid ? tmp : vector_t(customized_value);
|
||||||
|
}
|
||||||
|
|
||||||
// buffer_store requires:
|
// buffer_store requires:
|
||||||
// 1) p_dst_wave must be global memory
|
// 1) p_dst_wave must be global memory
|
||||||
// 2) p_dst_wave must be a wavewise pointer.
|
// 2) p_dst_wave must be a wavewise pointer.
|
||||||
// It is user's responsibility to make sure that is true.
|
// It is user's responsibility to make sure that is true.
|
||||||
template <typename T, index_t N>
|
template <typename T, index_t N>
|
||||||
__device__ void
|
__device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data,
|
||||||
amd_buffer_store_v2(const typename vector_type_maker<T, N>::type::type src_thread_data,
|
T* p_dst_wave,
|
||||||
T* p_dst_wave,
|
const index_t dst_thread_element_offset,
|
||||||
const index_t dst_thread_data_offset,
|
const bool dst_thread_element_valid,
|
||||||
const bool dst_thread_data_valid,
|
const index_t dst_element_space_size)
|
||||||
const index_t dst_element_space)
|
|
||||||
{
|
{
|
||||||
const int32x4_t dst_wave_buffer_resource =
|
const int32x4_t dst_wave_buffer_resource =
|
||||||
make_wave_buffer_resource(p_dst_wave, dst_element_space);
|
make_wave_buffer_resource(p_dst_wave, dst_element_space_size);
|
||||||
|
|
||||||
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(T);
|
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
|
||||||
|
|
||||||
using vector_t = typename vector_type_maker<T, N>::type::type;
|
using vector_t = typename vector_type_maker<T, N>::type::type;
|
||||||
using scalar_t = typename scalar_type<vector_t>::type;
|
using scalar_t = typename scalar_type<vector_t>::type;
|
||||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||||
|
|
||||||
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
|
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
|
||||||
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
|
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
|
||||||
|
|
||||||
amd_buffer_store_impl_v2<scalar_t, vector_size>(
|
amd_buffer_store_impl<scalar_t, vector_size>(
|
||||||
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||||
#else
|
#else
|
||||||
if(dst_thread_data_valid)
|
if(dst_thread_element_valid)
|
||||||
{
|
{
|
||||||
amd_buffer_store_impl_v2<scalar_t, vector_size>(
|
amd_buffer_store_impl<scalar_t, vector_size>(
|
||||||
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -6,34 +6,43 @@
|
|||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
|
template <AddressSpaceEnum_t BufferAddressSpace,
|
||||||
|
typename T,
|
||||||
|
typename ElementSpaceSize,
|
||||||
|
bool InvalidElementUseNumericalZeroValue>
|
||||||
struct DynamicBuffer
|
struct DynamicBuffer
|
||||||
{
|
{
|
||||||
using type = T;
|
using type = T;
|
||||||
|
|
||||||
T* p_data_;
|
T* p_data_;
|
||||||
ElementSpaceSize element_space_size_;
|
ElementSpaceSize element_space_size_;
|
||||||
|
T invalid_element_value_ = T{0};
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
|
__host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
|
||||||
: p_data_{p_data}, element_space_size_{element_space_size}
|
: p_data_{p_data}, element_space_size_{element_space_size}
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__host__ __device__ constexpr DynamicBuffer(T* p_data,
|
||||||
|
ElementSpaceSize element_space_size,
|
||||||
|
T invalid_element_value)
|
||||||
|
: p_data_{p_data},
|
||||||
|
element_space_size_{element_space_size},
|
||||||
|
invalid_element_value_{invalid_element_value}
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
|
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
|
||||||
{
|
{
|
||||||
return BufferAddressSpace;
|
return BufferAddressSpace;
|
||||||
}
|
}
|
||||||
|
|
||||||
__host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }
|
|
||||||
|
|
||||||
__host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }
|
|
||||||
|
|
||||||
template <typename X,
|
template <typename X,
|
||||||
typename std::enable_if<
|
typename std::enable_if<
|
||||||
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
|
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
|
||||||
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
|
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
|
||||||
bool>::type = false>
|
bool>::type = false>
|
||||||
__host__ __device__ constexpr auto Get(index_t i, bool is_valid_offset) const
|
__host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
|
||||||
{
|
{
|
||||||
// X contains multiple T
|
// X contains multiple T
|
||||||
constexpr index_t scalar_per_t_vector =
|
constexpr index_t scalar_per_t_vector =
|
||||||
@@ -45,20 +54,41 @@ struct DynamicBuffer
|
|||||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||||
"wrong! X need to be multiple T");
|
"wrong! X need to be multiple T");
|
||||||
|
|
||||||
if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
|
|
||||||
{
|
|
||||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||||
|
bool constexpr use_amd_buffer_addressing = true;
|
||||||
|
#else
|
||||||
|
bool constexpr use_amd_buffer_addressing = false;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global && use_amd_buffer_addressing)
|
||||||
|
{
|
||||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||||
|
|
||||||
return amd_buffer_load_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
|
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||||
p_data_, i, is_valid_offset, element_space_size_);
|
{
|
||||||
#else
|
return amd_buffer_load_invalid_element_return_return_zero<
|
||||||
return is_valid_offset ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
|
remove_cv_t<remove_reference_t<T>>,
|
||||||
#endif
|
t_per_x>(p_data_, i, is_valid_element, element_space_size_);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
return amd_buffer_load_invalid_element_return_customized_value<
|
||||||
|
remove_cv_t<remove_reference_t<T>>,
|
||||||
|
t_per_x>(
|
||||||
|
p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
return is_valid_offset ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
|
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||||
|
{
|
||||||
|
return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i])
|
||||||
|
: X{invalid_element_value_};
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -67,7 +97,7 @@ struct DynamicBuffer
|
|||||||
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
|
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
|
||||||
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
|
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
|
||||||
bool>::type = false>
|
bool>::type = false>
|
||||||
__host__ __device__ void Set(index_t i, bool is_valid_offset, const X& x)
|
__host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
|
||||||
{
|
{
|
||||||
// X contains multiple T
|
// X contains multiple T
|
||||||
constexpr index_t scalar_per_t_vector =
|
constexpr index_t scalar_per_t_vector =
|
||||||
@@ -84,10 +114,10 @@ struct DynamicBuffer
|
|||||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||||
|
|
||||||
amd_buffer_store_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
|
amd_buffer_store<remove_cv_t<remove_reference_t<T>>, t_per_x>(
|
||||||
x, p_data_, i, is_valid_offset, element_space_size_);
|
x, p_data_, i, is_valid_element, element_space_size_);
|
||||||
#else
|
#else
|
||||||
if(is_valid_offset)
|
if(is_valid_element)
|
||||||
{
|
{
|
||||||
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||||
}
|
}
|
||||||
@@ -95,7 +125,7 @@ struct DynamicBuffer
|
|||||||
}
|
}
|
||||||
else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds)
|
else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds)
|
||||||
{
|
{
|
||||||
if(is_valid_offset)
|
if(is_valid_element)
|
||||||
{
|
{
|
||||||
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
|
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
|
||||||
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||||
@@ -185,7 +215,7 @@ struct DynamicBuffer
|
|||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
if(is_valid_offset)
|
if(is_valid_element)
|
||||||
{
|
{
|
||||||
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||||
}
|
}
|
||||||
@@ -197,12 +227,18 @@ struct DynamicBuffer
|
|||||||
__host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
|
__host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
|
||||||
};
|
};
|
||||||
|
|
||||||
template <AddressSpaceEnum_t BufferAddressSpace = AddressSpaceEnum_t::Generic,
|
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
|
||||||
typename T,
|
|
||||||
typename ElementSpaceSize>
|
|
||||||
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
|
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
|
||||||
{
|
{
|
||||||
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize>{p, element_space_size};
|
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true>{p, element_space_size};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
|
||||||
|
__host__ __device__ constexpr auto
|
||||||
|
make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, T invalid_element_value)
|
||||||
|
{
|
||||||
|
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, false>{
|
||||||
|
p, element_space_size, invalid_element_value};
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace ck
|
} // namespace ck
|
||||||
|
|||||||
@@ -5,30 +5,66 @@
|
|||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
|
template <AddressSpaceEnum_t BufferAddressSpace,
|
||||||
|
typename T,
|
||||||
|
index_t N,
|
||||||
|
bool InvalidElementUseNumericalZeroValue>
|
||||||
struct StaticBuffer : public StaticallyIndexedArray<T, N>
|
struct StaticBuffer : public StaticallyIndexedArray<T, N>
|
||||||
{
|
{
|
||||||
using type = T;
|
using type = T;
|
||||||
using base = StaticallyIndexedArray<T, N>;
|
using base = StaticallyIndexedArray<T, N>;
|
||||||
|
|
||||||
|
T invalid_element_value_ = T{0};
|
||||||
|
|
||||||
__host__ __device__ constexpr StaticBuffer() : base{} {}
|
__host__ __device__ constexpr StaticBuffer() : base{} {}
|
||||||
|
|
||||||
|
__host__ __device__ constexpr StaticBuffer(T invalid_element_value)
|
||||||
|
: base{}, invalid_element_value_{invalid_element_value}
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
|
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
|
||||||
{
|
{
|
||||||
return BufferAddressSpace;
|
return BufferAddressSpace;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <index_t I>
|
||||||
|
__host__ __device__ constexpr auto Get(Number<I> i, bool is_valid_element) const
|
||||||
|
{
|
||||||
|
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||||
|
{
|
||||||
|
return is_valid_element ? At(i) : T{0};
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
return is_valid_element ? At(i) : invalid_element_value_;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <index_t I>
|
||||||
|
__host__ __device__ void Set(Number<I> i, bool is_valid_element, const T& x)
|
||||||
|
{
|
||||||
|
if(is_valid_element)
|
||||||
|
{
|
||||||
|
At(i) = x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
|
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
|
||||||
|
|
||||||
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
|
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
|
||||||
};
|
};
|
||||||
|
|
||||||
template <AddressSpaceEnum_t BufferAddressSpace = AddressSpaceEnum_t::Generic,
|
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
|
||||||
typename T,
|
|
||||||
index_t N>
|
|
||||||
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
|
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
|
||||||
{
|
{
|
||||||
return StaticBuffer<BufferAddressSpace, T, N>{};
|
return StaticBuffer<BufferAddressSpace, T, N, true>{};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
|
||||||
|
__host__ __device__ constexpr auto make_static_buffer(Number<N>, T invalid_element_value)
|
||||||
|
{
|
||||||
|
return StaticBuffer<BufferAddressSpace, T, N, false>{invalid_element_value};
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace ck
|
} // namespace ck
|
||||||
|
|||||||
@@ -21,7 +21,7 @@
|
|||||||
|
|
||||||
#define USE_MODE 1
|
#define USE_MODE 1
|
||||||
#define USE_CONV_FWD_V4R4_NCHW 1
|
#define USE_CONV_FWD_V4R4_NCHW 1
|
||||||
#define USE_CONV_FWD_V4R4R2_NHWC 0
|
#define USE_CONV_FWD_V4R4R2_NHWC 1
|
||||||
#define USE_CONV_FWD_V6R1_NCHW 0
|
#define USE_CONV_FWD_V6R1_NCHW 0
|
||||||
#define USE_CONV_FWD_V5R1_NCHW 0
|
#define USE_CONV_FWD_V5R1_NCHW 0
|
||||||
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
|
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
|
||||||
|
|||||||
Reference in New Issue
Block a user