mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
DynamicBuffer, StaticBuffer, amd_buffer_load support customized value for invalid element
This commit is contained in:
@@ -133,7 +133,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
|
||||
static_assert(WPerThread % WoPerThreadSubC == 0, "");
|
||||
|
||||
// thread A buffer for GEMM
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize()>
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
|
||||
a_thread_buf;
|
||||
|
||||
constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
|
||||
|
||||
@@ -227,7 +227,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
|
||||
// register allocation for output
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
||||
FloatAcc,
|
||||
c_k_n_ho_wo_thread_desc.GetElementSpaceSize()>
|
||||
c_k_n_ho_wo_thread_desc.GetElementSpaceSize(),
|
||||
true>
|
||||
c_thread_buf;
|
||||
|
||||
// initialize output thread tensor
|
||||
@@ -251,7 +252,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
|
||||
// double register buffer for b
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
||||
FloatAB,
|
||||
b_e_n_ho_wo_thread_desc.GetElementSpaceSize()>
|
||||
b_e_n_ho_wo_thread_desc.GetElementSpaceSize(),
|
||||
true>
|
||||
b_thread_even_buf, b_thread_odd_buf;
|
||||
|
||||
// LDS double buffer: preload data
|
||||
|
||||
@@ -402,7 +402,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
||||
vector_type<FloatAcc, BlkSize>,
|
||||
c_mr_nr_blk_desc.GetElementSpaceSize()>
|
||||
c_mr_nr_blk_desc.GetElementSpaceSize(),
|
||||
true>
|
||||
c_thread_buf;
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
@@ -493,7 +494,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
Number<M2>{},
|
||||
Number<1>{}));
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize()>
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize(), true>
|
||||
c_blk_buf_;
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
|
||||
|
||||
@@ -1242,7 +1242,7 @@ struct ThreadwiseTensorSliceTransfer_v3
|
||||
|
||||
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_> buffer_;
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_, true> buffer_;
|
||||
|
||||
SrcCoord src_coord_;
|
||||
DstCoord dst_coord_;
|
||||
|
||||
@@ -602,7 +602,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
|
||||
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_> buffer_;
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_, true> buffer_;
|
||||
|
||||
SrcCoord src_coord_;
|
||||
DstCoord dst_coord_;
|
||||
|
||||
@@ -10,25 +10,25 @@ union BufferResource
|
||||
{
|
||||
// 128 bit SGPRs to supply buffer resource in buffer instructions
|
||||
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
|
||||
int32x4_t data;
|
||||
int32x4_t content;
|
||||
StaticallyIndexedArray<T*, 2> address;
|
||||
StaticallyIndexedArray<int32_t, 4> range;
|
||||
StaticallyIndexedArray<int32_t, 4> config;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_size)
|
||||
__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size)
|
||||
{
|
||||
BufferResource<T> wave_buffer_resource;
|
||||
|
||||
// wavewise base address (64 bit)
|
||||
wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
|
||||
// wavewise range (32 bit)
|
||||
wave_buffer_resource.range(Number<2>{}) = data_space_size * sizeof(T);
|
||||
wave_buffer_resource.range(Number<2>{}) = element_space_size * sizeof(T);
|
||||
// wavewise setting (32 bit)
|
||||
wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;
|
||||
|
||||
return wave_buffer_resource.data;
|
||||
return wave_buffer_resource.content;
|
||||
}
|
||||
|
||||
// load
|
||||
@@ -204,10 +204,9 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
|
||||
|
||||
template <typename T, index_t N>
|
||||
__device__ typename vector_type<T, N>::type
|
||||
amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset)
|
||||
__device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset)
|
||||
{
|
||||
static_assert(
|
||||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
@@ -412,10 +411,10 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
__device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type src_thread_data,
|
||||
int32x4_t dst_wave_buffer_resource,
|
||||
index_t dst_thread_addr_offset,
|
||||
index_t dst_wave_addr_offset)
|
||||
__device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src_thread_data,
|
||||
int32x4_t dst_wave_buffer_resource,
|
||||
index_t dst_thread_addr_offset,
|
||||
index_t dst_wave_addr_offset)
|
||||
{
|
||||
static_assert(
|
||||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
@@ -584,67 +583,95 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
|
||||
|
||||
// buffer_load requires:
|
||||
// 1) p_src_wave must be in global memory space
|
||||
// 2) p_src_wave to be a wavewise pointer.
|
||||
// 2) p_src_wave must be a wavewise pointer.
|
||||
// It is the user's responsibility to make sure that is true.
|
||||
template <typename T, index_t N>
|
||||
__device__ typename vector_type_maker<T, N>::type::type
|
||||
amd_buffer_load_v2(const T* p_src_wave,
|
||||
index_t src_thread_data_offset,
|
||||
bool src_thread_data_valid,
|
||||
index_t src_element_space)
|
||||
amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave,
|
||||
index_t src_thread_element_offset,
|
||||
bool src_thread_element_valid,
|
||||
index_t src_element_space_size)
|
||||
{
|
||||
const int32x4_t src_wave_buffer_resource =
|
||||
make_wave_buffer_resource(p_src_wave, src_element_space);
|
||||
make_wave_buffer_resource(p_src_wave, src_element_space_size);
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(T);
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
|
||||
using vector_t = typename vector_type_maker<T, N>::type::type;
|
||||
using scalar_t = typename scalar_type<vector_t>::type;
|
||||
|
||||
using vector_t = typename vector_type_maker<T, N>::type::type;
|
||||
using scalar_t = typename scalar_type<vector_t>::type;
|
||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
|
||||
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x7fffffff;
|
||||
|
||||
return amd_buffer_load_impl_v2<scalar_t, vector_size>(
|
||||
return amd_buffer_load_impl<scalar_t, vector_size>(
|
||||
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
|
||||
#else
|
||||
vector_t tmp = amd_buffer_load_impl_v2<scalar_t, vector_size>(
|
||||
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
|
||||
return src_thread_data_valid ? tmp : vector_t(0);
|
||||
return src_thread_element_valid ? tmp : vector_t(0);
|
||||
#endif
|
||||
}
|
||||
|
||||
// buffer_load requires:
|
||||
// 1) p_src_wave must be in global memory space
|
||||
// 2) p_src_wave must be a wavewise pointer.
|
||||
// It is the user's responsibility to make sure that is true.
|
||||
template <typename T, index_t N>
|
||||
__device__ typename vector_type_maker<T, N>::type::type
|
||||
amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
|
||||
index_t src_thread_element_offset,
|
||||
bool src_thread_element_valid,
|
||||
index_t src_element_space_size,
|
||||
T customized_value)
|
||||
{
|
||||
const int32x4_t src_wave_buffer_resource =
|
||||
make_wave_buffer_resource(p_src_wave, src_element_space_size);
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
|
||||
using vector_t = typename vector_type_maker<T, N>::type::type;
|
||||
using scalar_t = typename scalar_type<vector_t>::type;
|
||||
|
||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||
|
||||
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
|
||||
return src_thread_element_valid ? tmp : vector_t(customized_value);
|
||||
}
|
||||
|
||||
// buffer_store requires:
|
||||
// 1) p_dst_wave must be in global memory space
|
||||
// 2) p_dst_wave must be a wavewise pointer.
|
||||
// It is the user's responsibility to make sure that is true.
|
||||
template <typename T, index_t N>
|
||||
__device__ void
|
||||
amd_buffer_store_v2(const typename vector_type_maker<T, N>::type::type src_thread_data,
|
||||
T* p_dst_wave,
|
||||
const index_t dst_thread_data_offset,
|
||||
const bool dst_thread_data_valid,
|
||||
const index_t dst_element_space)
|
||||
__device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data,
|
||||
T* p_dst_wave,
|
||||
const index_t dst_thread_element_offset,
|
||||
const bool dst_thread_element_valid,
|
||||
const index_t dst_element_space_size)
|
||||
{
|
||||
const int32x4_t dst_wave_buffer_resource =
|
||||
make_wave_buffer_resource(p_dst_wave, dst_element_space);
|
||||
make_wave_buffer_resource(p_dst_wave, dst_element_space_size);
|
||||
|
||||
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(T);
|
||||
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
|
||||
|
||||
using vector_t = typename vector_type_maker<T, N>::type::type;
|
||||
using scalar_t = typename scalar_type<vector_t>::type;
|
||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
|
||||
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
|
||||
|
||||
amd_buffer_store_impl_v2<scalar_t, vector_size>(
|
||||
amd_buffer_store_impl<scalar_t, vector_size>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||
#else
|
||||
if(dst_thread_data_valid)
|
||||
if(dst_thread_element_valid)
|
||||
{
|
||||
amd_buffer_store_impl_v2<scalar_t, vector_size>(
|
||||
amd_buffer_store_impl<scalar_t, vector_size>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -6,34 +6,43 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
|
||||
template <AddressSpaceEnum_t BufferAddressSpace,
|
||||
typename T,
|
||||
typename ElementSpaceSize,
|
||||
bool InvalidElementUseNumericalZeroValue>
|
||||
struct DynamicBuffer
|
||||
{
|
||||
using type = T;
|
||||
|
||||
T* p_data_;
|
||||
ElementSpaceSize element_space_size_;
|
||||
T invalid_element_value_ = T{0};
|
||||
|
||||
__host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
|
||||
: p_data_{p_data}, element_space_size_{element_space_size}
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr DynamicBuffer(T* p_data,
|
||||
ElementSpaceSize element_space_size,
|
||||
T invalid_element_value)
|
||||
: p_data_{p_data},
|
||||
element_space_size_{element_space_size},
|
||||
invalid_element_value_{invalid_element_value}
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
|
||||
{
|
||||
return BufferAddressSpace;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }
|
||||
|
||||
__host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }
|
||||
|
||||
template <typename X,
|
||||
typename std::enable_if<
|
||||
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
|
||||
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
|
||||
bool>::type = false>
|
||||
__host__ __device__ constexpr auto Get(index_t i, bool is_valid_offset) const
|
||||
__host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector =
|
||||
@@ -45,20 +54,41 @@ struct DynamicBuffer
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X need to be multiple T");
|
||||
|
||||
if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
|
||||
{
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
bool constexpr use_amd_buffer_addressing = true;
|
||||
#else
|
||||
bool constexpr use_amd_buffer_addressing = false;
|
||||
#endif
|
||||
|
||||
if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global && use_amd_buffer_addressing)
|
||||
{
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
return amd_buffer_load_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
|
||||
p_data_, i, is_valid_offset, element_space_size_);
|
||||
#else
|
||||
return is_valid_offset ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
|
||||
#endif
|
||||
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||
{
|
||||
return amd_buffer_load_invalid_element_return_return_zero<
|
||||
remove_cv_t<remove_reference_t<T>>,
|
||||
t_per_x>(p_data_, i, is_valid_element, element_space_size_);
|
||||
}
|
||||
else
|
||||
{
|
||||
return amd_buffer_load_invalid_element_return_customized_value<
|
||||
remove_cv_t<remove_reference_t<T>>,
|
||||
t_per_x>(
|
||||
p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return is_valid_offset ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
|
||||
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||
{
|
||||
return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
|
||||
}
|
||||
else
|
||||
{
|
||||
return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i])
|
||||
: X{invalid_element_value_};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,7 +97,7 @@ struct DynamicBuffer
|
||||
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
|
||||
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
|
||||
bool>::type = false>
|
||||
__host__ __device__ void Set(index_t i, bool is_valid_offset, const X& x)
|
||||
__host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector =
|
||||
@@ -84,10 +114,10 @@ struct DynamicBuffer
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_buffer_store_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
|
||||
x, p_data_, i, is_valid_offset, element_space_size_);
|
||||
amd_buffer_store<remove_cv_t<remove_reference_t<T>>, t_per_x>(
|
||||
x, p_data_, i, is_valid_element, element_space_size_);
|
||||
#else
|
||||
if(is_valid_offset)
|
||||
if(is_valid_element)
|
||||
{
|
||||
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||
}
|
||||
@@ -95,7 +125,7 @@ struct DynamicBuffer
|
||||
}
|
||||
else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds)
|
||||
{
|
||||
if(is_valid_offset)
|
||||
if(is_valid_element)
|
||||
{
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
|
||||
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||
@@ -185,7 +215,7 @@ struct DynamicBuffer
|
||||
}
|
||||
else
|
||||
{
|
||||
if(is_valid_offset)
|
||||
if(is_valid_element)
|
||||
{
|
||||
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||
}
|
||||
@@ -197,12 +227,18 @@ struct DynamicBuffer
|
||||
__host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
|
||||
};
|
||||
|
||||
template <AddressSpaceEnum_t BufferAddressSpace = AddressSpaceEnum_t::Generic,
|
||||
typename T,
|
||||
typename ElementSpaceSize>
|
||||
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
|
||||
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
|
||||
{
|
||||
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize>{p, element_space_size};
|
||||
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true>{p, element_space_size};
|
||||
}
|
||||
|
||||
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
|
||||
__host__ __device__ constexpr auto
|
||||
make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, T invalid_element_value)
|
||||
{
|
||||
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, false>{
|
||||
p, element_space_size, invalid_element_value};
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -5,30 +5,66 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
|
||||
template <AddressSpaceEnum_t BufferAddressSpace,
|
||||
typename T,
|
||||
index_t N,
|
||||
bool InvalidElementUseNumericalZeroValue>
|
||||
struct StaticBuffer : public StaticallyIndexedArray<T, N>
|
||||
{
|
||||
using type = T;
|
||||
using base = StaticallyIndexedArray<T, N>;
|
||||
|
||||
T invalid_element_value_ = T{0};
|
||||
|
||||
__host__ __device__ constexpr StaticBuffer() : base{} {}
|
||||
|
||||
__host__ __device__ constexpr StaticBuffer(T invalid_element_value)
|
||||
: base{}, invalid_element_value_{invalid_element_value}
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
|
||||
{
|
||||
return BufferAddressSpace;
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto Get(Number<I> i, bool is_valid_element) const
|
||||
{
|
||||
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||
{
|
||||
return is_valid_element ? At(i) : T{0};
|
||||
}
|
||||
else
|
||||
{
|
||||
return is_valid_element ? At(i) : invalid_element_value_;
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ void Set(Number<I> i, bool is_valid_element, const T& x)
|
||||
{
|
||||
if(is_valid_element)
|
||||
{
|
||||
At(i) = x;
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
|
||||
|
||||
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
|
||||
};
|
||||
|
||||
template <AddressSpaceEnum_t BufferAddressSpace = AddressSpaceEnum_t::Generic,
|
||||
typename T,
|
||||
index_t N>
|
||||
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
|
||||
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
|
||||
{
|
||||
return StaticBuffer<BufferAddressSpace, T, N>{};
|
||||
return StaticBuffer<BufferAddressSpace, T, N, true>{};
|
||||
}
|
||||
|
||||
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
|
||||
__host__ __device__ constexpr auto make_static_buffer(Number<N>, T invalid_element_value)
|
||||
{
|
||||
return StaticBuffer<BufferAddressSpace, T, N, false>{invalid_element_value};
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
|
||||
#define USE_MODE 1
|
||||
#define USE_CONV_FWD_V4R4_NCHW 1
|
||||
#define USE_CONV_FWD_V4R4R2_NHWC 0
|
||||
#define USE_CONV_FWD_V4R4R2_NHWC 1
|
||||
#define USE_CONV_FWD_V6R1_NCHW 0
|
||||
#define USE_CONV_FWD_V5R1_NCHW 0
|
||||
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
|
||||
|
||||
Reference in New Issue
Block a user