mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
added type conversion in threadwise and blockwise copy
This commit is contained in:
@@ -287,9 +287,9 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
blockwise_in_copy.template Run<Float, address_space_t::global, address_space_t::lds>(
|
||||
blockwise_in_copy.template Run<Float, Float, address_space_t::global>(
|
||||
p_in_global, p_in_block_double);
|
||||
blockwise_wei_copy.template Run<Float, address_space_t::global, address_space_t::lds>(
|
||||
blockwise_wei_copy.template Run<Float, Float, address_space_t::global>(
|
||||
p_wei_global, p_wei_block_double);
|
||||
}
|
||||
|
||||
@@ -312,8 +312,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
Float* p_wei_block_next =
|
||||
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
|
||||
|
||||
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
|
||||
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
|
||||
|
||||
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
|
||||
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
@@ -321,25 +321,27 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
blockwise_in_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
|
||||
p_in_global, p_in_register_buffer);
|
||||
blockwise_wei_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
|
||||
p_wei_global, p_wei_register_buffer);
|
||||
blockwise_in_copy
|
||||
.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
|
||||
p_in_global, p_in_thread_buffer);
|
||||
blockwise_wei_copy
|
||||
.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
|
||||
p_wei_global, p_wei_thread_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
|
||||
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
|
||||
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
|
||||
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
|
||||
}
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
{
|
||||
// even iteration
|
||||
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
|
||||
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
|
||||
|
||||
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
|
||||
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
@@ -347,19 +349,19 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
blockwise_in_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
|
||||
p_in_global, p_in_register_buffer);
|
||||
blockwise_wei_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
|
||||
p_wei_global, p_wei_register_buffer);
|
||||
blockwise_in_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
|
||||
p_in_global, p_in_thread_buffer);
|
||||
blockwise_wei_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
|
||||
p_wei_global, p_wei_thread_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
|
||||
p_in_block_double + in_block_space);
|
||||
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
|
||||
p_wei_block_double + wei_block_space);
|
||||
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
|
||||
p_in_block_double + in_block_space);
|
||||
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
|
||||
p_wei_block_double + wei_block_space);
|
||||
|
||||
// odd iteration
|
||||
__syncthreads();
|
||||
@@ -431,9 +433,14 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
b_thread_data_on_global,
|
||||
0})
|
||||
#if 1
|
||||
.template Run_generic<Float, Float, address_space_t::generic, address_space_t::global>
|
||||
.template Run_generic<Float,
|
||||
Float,
|
||||
address_space_t::generic,
|
||||
address_space_t::global>
|
||||
#elif 1
|
||||
.template Run_optimized_dst_address_calculation<Float, Float, address_space_t::global>
|
||||
.template Run_optimized_dst_address_calculation<Float,
|
||||
Float,
|
||||
address_space_t::global>
|
||||
#endif
|
||||
(p_out_thread, p_out_global);
|
||||
}
|
||||
|
||||
@@ -678,10 +678,10 @@ struct BlockwiseGenericTensorSliceCopy_v3
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SliceLengths,
|
||||
typename SubLengths,
|
||||
typename BlockSrcDesc,
|
||||
typename BlockDstDesc,
|
||||
typename BlockSliceLengths,
|
||||
typename ThreadSliceLengths,
|
||||
typename ThreadClusterLengths,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename SrcDimAccessOrder,
|
||||
@@ -692,24 +692,49 @@ template <index_t BlockSize,
|
||||
index_t DstDataPerAccess>
|
||||
struct BlockwiseGenericTensorSliceCopy_v4
|
||||
{
|
||||
static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
|
||||
static constexpr index_t nDim = BlockSrcDesc::GetNumOfDimension();
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
__device__ constexpr BlockwiseGenericTensorSliceCopy_v4(const Index& src_block_slice_origin,
|
||||
const Index& dst_block_slice_origin)
|
||||
{
|
||||
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
|
||||
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::Size() &&
|
||||
nDim == SubLengths::Size() && nDim == ThreadClusterLengths::Size() &&
|
||||
static_assert(nDim == BlockSrcDesc::GetNumOfDimension() &&
|
||||
nDim == BlockDstDesc::GetNumOfDimension() &&
|
||||
nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
|
||||
nDim == ThreadClusterLengths::Size() &&
|
||||
nDim == ThreadClusterArrangeOrder::Size() &&
|
||||
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
|
||||
"wrong! nDim not consistent");
|
||||
|
||||
static_assert(is_same<SliceLengths, decltype(SubLengths{} * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
static_assert(
|
||||
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
|
||||
#if 1
|
||||
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
|
||||
ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
|
||||
#else
|
||||
constexpr auto thread_cluster_lengths_in_arrange_order =
|
||||
ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{});
|
||||
|
||||
constexpr auto thread_cluster_desc = transform_tensor_descriptor(
|
||||
make_native_tensor_descriptor_packed(thread_cluster_lengths_in_arrange_order),
|
||||
make_tuple(Merge<decltype(thread_cluster_lengths_in_arrange_order)>{}),
|
||||
make_tuple(arithmetic)
|
||||
|
||||
::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
|
||||
|
||||
static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
|
||||
"wrong! BlockSize not consistent with ThreadClusterLengths");
|
||||
|
||||
constexpr auto thread_cluster_id = transform_tensor_descriptor(
|
||||
make_native_tensor_descriptor_packed(Sequence<KBlockWork, BBlockWork>{}),
|
||||
make_tuple(Merge<Sequence<KBlockWork, BBlockWork>>{}),
|
||||
make_tuple(Sequence<0, 1>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto block_work_multi_id = block_work_desc.CalculateLowerIndex(get_block_1d_id());
|
||||
#endif
|
||||
|
||||
static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
|
||||
"wrong! BlockSize not consistent with ThreadClusterLengths");
|
||||
@@ -720,7 +745,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
|
||||
const auto data_cluster_id =
|
||||
reorder_array_given_old2new(thread_cluster_id, ThreadClusterArrangeOrder{});
|
||||
|
||||
const auto thread_data_id_begin = data_cluster_id * SubLengths{};
|
||||
const auto thread_data_id_begin = data_cluster_id * ThreadSliceLengths{};
|
||||
|
||||
mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin);
|
||||
mThreadwiseLoad.SetDstSliceOrigin(make_zero_array<index_t, nDim>());
|
||||
@@ -729,51 +754,70 @@ struct BlockwiseGenericTensorSliceCopy_v4
|
||||
mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin);
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetRegisterBufferSize()
|
||||
__device__ static constexpr index_t GetThreadBufferSize()
|
||||
{
|
||||
return RegisterBufferDesc::GetElementSpace();
|
||||
return ThreadBufferDesc::GetElementSpace();
|
||||
}
|
||||
|
||||
template <typename SrcData, typename BufferData, address_space_t SrcAddressSpace = address_space_t::generic>
|
||||
__device__ void RunLoadRegisterBuffer(const SrcData* p_src, BufferData* p_buffer) const
|
||||
template <typename BlockSrcData,
|
||||
typename ThreadBufferData,
|
||||
address_space_t BlockSrcAddressSpace = address_space_t::generic,
|
||||
address_space_t ThreadBufferAddressSpace = address_space_t::generic>
|
||||
__device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src,
|
||||
ThreadBufferData* p_thread_buffer) const
|
||||
{
|
||||
#if 1
|
||||
mThreadwiseLoad.template Run_generic<SrcData, BufferData, SrcAddressSpace, address_space_t::generic>(
|
||||
p_src, p_buffer);
|
||||
mThreadwiseLoad.template Run_generic<BlockSrcData,
|
||||
ThreadBufferData,
|
||||
BlockSrcAddressSpace,
|
||||
ThreadBufferAddressSpace>(p_block_src,
|
||||
p_thread_buffer);
|
||||
#else
|
||||
mThreadwiseLoad.template Run_optimized_src_address_calculation<SrcData,
|
||||
BufferData,
|
||||
SrcAddressSpace,
|
||||
address_space_t::generic>(
|
||||
p_src, p_buffer);
|
||||
mThreadwiseLoad.template Run_optimized_src_address_calculation<BlockSrcData,
|
||||
ThreadBufferData,
|
||||
BlockSrcAddressSpace,
|
||||
ThreadBufferAddressSpace>(
|
||||
p_block_src, p_thread_buffer);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename BufferData, typename DstData, address_space_t DstAddressSpace = address_space_t::generic>
|
||||
__device__ void RunStoreRegisterBuffer(const BufferData* p_buffer, DstData* p_dst) const
|
||||
template <typename ThreadBufferData,
|
||||
typename BlockDstData,
|
||||
address_space_t ThreadBufferAddressSpace = address_space_t::generic,
|
||||
address_space_t BlockDstAddressSpace = address_space_t::generic>
|
||||
__device__ void RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer,
|
||||
BlockDstData* p_block_dst) const
|
||||
{
|
||||
#if 1
|
||||
mThreadwiseStore.template Run_generic<BufferData, DstData, address_space_t::generic, DstAddressSpace>(
|
||||
p_buffer, p_dst);
|
||||
mThreadwiseStore.template Run_generic<ThreadBufferData,
|
||||
BlockDstData,
|
||||
ThreadBufferAddressSpace,
|
||||
BlockDstAddressSpace>(p_thread_buffer, p_block_dst);
|
||||
#else
|
||||
mThreadwiseStore.template Run_optimized_dst_address_calculation<BufferData,
|
||||
DstData,
|
||||
address_space_t::generic,
|
||||
DstAddressSpace>(p_buffer,
|
||||
p_dst);
|
||||
mThreadwiseStore.template Run_optimized_dst_address_calculation<ThreadBufferData,
|
||||
BlockDstData,
|
||||
ThreadBufferAddressSpace,
|
||||
BlockDstAddressSpace>(
|
||||
p_thread_buffer, p_block_dst);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
address_space_t SrcAddressSpace = address_space_t::generic,
|
||||
address_space_t DstAddressSpace = address_space_t::generic>
|
||||
__device__ void Run(const SrcData* p_src, DstData* p_dst) const
|
||||
template <typename BlockSrcData,
|
||||
typename BlockDstData,
|
||||
address_space_t BlockSrcAddressSpace = address_space_t::generic,
|
||||
address_space_t BlockDstAddressSpace = address_space_t::generic>
|
||||
__device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const
|
||||
{
|
||||
SrcData p_src_buffer[GetRegisterBufferSize()];
|
||||
BlockSrcData p_thread_buffer[GetThreadBufferSize()];
|
||||
|
||||
RunLoadRegisterBuffer<SrcData, SrcData, SrcAddressSpace>(p_src, p_buffer);
|
||||
RunStoreRegisterBuffer<SrcData, DstData, DstAddressSpace>(p_buffer, p_dst);
|
||||
RunLoadThreadBuffer<BlockSrcData,
|
||||
BlockSrcData,
|
||||
BlockSrcAddressSpace,
|
||||
address_space_t::generic>(p_block_src, p_thread_buffer);
|
||||
RunStoreThreadBuffer<BlockSrcData,
|
||||
BlockDstData,
|
||||
address_space_t::generic,
|
||||
BlockDstAddressSpace>(p_thread_buffer, p_block_dst);
|
||||
}
|
||||
|
||||
template <typename T, bool PositiveDirection>
|
||||
@@ -793,19 +837,19 @@ struct BlockwiseGenericTensorSliceCopy_v4
|
||||
}
|
||||
|
||||
private:
|
||||
using RegisterBufferDesc = decltype(make_native_tensor_descriptor_packed(SubLengths{}));
|
||||
using ThreadBufferDesc = decltype(make_native_tensor_descriptor_packed(ThreadSliceLengths{}));
|
||||
|
||||
using ThreadwiseLoad = ThreadwiseGenericTensorSliceCopy_v4r2<SrcDesc,
|
||||
RegisterBufferDesc,
|
||||
SubLengths,
|
||||
using ThreadwiseLoad = ThreadwiseGenericTensorSliceCopy_v4r2<BlockSrcDesc,
|
||||
ThreadBufferDesc,
|
||||
ThreadSliceLengths,
|
||||
SrcDimAccessOrder,
|
||||
SrcVectorAccessDim,
|
||||
SrcDataPerAccess,
|
||||
1>;
|
||||
|
||||
using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v4r2<RegisterBufferDesc,
|
||||
DstDesc,
|
||||
SubLengths,
|
||||
using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v4r2<ThreadBufferDesc,
|
||||
BlockDstDesc,
|
||||
ThreadSliceLengths,
|
||||
DstDimAccessOrder,
|
||||
DstVectorAccessDim,
|
||||
1,
|
||||
|
||||
@@ -1180,7 +1180,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// Will do padding check on src data: Read 0 if src data is in padding area.
|
||||
// Will do padding check on dst data: No write if dst data is in paddin area.
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
typename DstData,
|
||||
address_space_t SrcAddressSpace = address_space_t::generic,
|
||||
address_space_t DstAddressSpace = address_space_t::generic>
|
||||
__device__ void Run_generic(const SrcData* p_src, DstData* p_dst) const
|
||||
@@ -1233,7 +1233,8 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
static_if<SrcAddressSpace == address_space_t::global>{}([&](auto) {
|
||||
#if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
|
||||
*reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
|
||||
__buffer_load<SrcData, SrcDataPerAccess>(p_src, src_coord.GetOffset(), 0);
|
||||
__buffer_load<SrcData, SrcDataPerAccess>(
|
||||
p_src, src_coord.GetOffset(), 0);
|
||||
#else
|
||||
*reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
|
||||
*reinterpret_cast<const src_vector_t*>(&p_src[src_coord.GetOffset()]);
|
||||
@@ -1246,12 +1247,12 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
}
|
||||
}
|
||||
|
||||
// SrcData to DstData conversion
|
||||
// SrcData to DstData conversion
|
||||
DstData p_dst_long_vector[long_vector_size];
|
||||
|
||||
for(index_t i = 0; i < long_vector_size; ++i)
|
||||
for(index_t i = 0; i < long_vector_size; ++i)
|
||||
{
|
||||
p_dst_long_vector[i] = type_convert<DstData>(p_src_long_vector[i]);
|
||||
p_dst_long_vector[i] = type_convert<DstData>{}(p_src_long_vector[i]);
|
||||
}
|
||||
|
||||
// store data from the long-vector buffer to dst
|
||||
|
||||
@@ -38,11 +38,11 @@ typedef float float4_t __attribute__((ext_vector_type(4)));
|
||||
typedef int32_t int32x4_t __attribute__((ext_vector_type(4)));
|
||||
|
||||
// data type conversion
|
||||
template <class T>
|
||||
template <typename T>
|
||||
struct type_convert
|
||||
{
|
||||
template <class X>
|
||||
__device__ T operator()(X x) const
|
||||
template <typename X>
|
||||
__device__ T operator()(const X& x) const
|
||||
{
|
||||
return static_cast<T>(x);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user