mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
refactor
This commit is contained in:
@@ -14,7 +14,7 @@ template <index_t BlockSize,
|
||||
class DstAccessOrder,
|
||||
index_t SrcDataPerRead,
|
||||
index_t DstDataPerRead>
|
||||
struct BlockwiseTensorSliceCopy_generic_v1
|
||||
struct BlockwiseGenericTensorSliceCopy_v1
|
||||
{
|
||||
static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
|
||||
|
||||
@@ -22,8 +22,8 @@ struct BlockwiseTensorSliceCopy_generic_v1
|
||||
index_t mDstMyThreadOffset;
|
||||
|
||||
__device__
|
||||
BlockwiseTensorSliceCopy_generic_v1(Array<index_t, nDim> src_block_data_multi_id_begin,
|
||||
Array<index_t, nDim> dst_block_data_multi_id_begin)
|
||||
BlockwiseGenericTensorSliceCopy_v1(Array<index_t, nDim> src_block_data_multi_id_begin,
|
||||
Array<index_t, nDim> dst_block_data_multi_id_begin)
|
||||
{
|
||||
// check NDim consistent
|
||||
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
|
||||
@@ -155,7 +155,7 @@ struct BlockwiseTensorSliceCopy_generic_v1
|
||||
const index_t clipboard_offset = thread_tensor_desc.GetOffsetFromMultiIndex(
|
||||
clipboard_data_multi_id_begin); // cannot not constexpr, why?
|
||||
|
||||
threadwise_tensor_slice_copy_generic(SrcDesc{},
|
||||
threadwise_generic_tensor_slice_copy(SrcDesc{},
|
||||
p_src + src_offset + mSrcMyThreadOffset,
|
||||
make_zero_array<index_t, nDim>(),
|
||||
thread_tensor_desc,
|
||||
@@ -193,7 +193,7 @@ struct BlockwiseTensorSliceCopy_generic_v1
|
||||
const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(
|
||||
dst_data_multi_id_begin); // cannot not constexpr, why?
|
||||
|
||||
threadwise_tensor_slice_copy_generic(thread_tensor_desc,
|
||||
threadwise_generic_tensor_slice_copy(thread_tensor_desc,
|
||||
p_clipboard + clipboard_offset,
|
||||
make_zero_array<index_t, nDim>(),
|
||||
DstDesc{},
|
||||
@@ -474,7 +474,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
|
||||
map_out_global2thread,
|
||||
Number<OutThreadCopyDataPerWrite_W>{});
|
||||
#else
|
||||
threadwise_tensor_slice_copy_generic(
|
||||
threadwise_generic_tensor_slice_copy(
|
||||
out_10d_thread_desc.ReorderGivenNew2Old(map_out_global2thread),
|
||||
p_out_thread,
|
||||
make_zero_array<index_t, 10>(),
|
||||
|
||||
@@ -423,7 +423,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
|
||||
map_out_global2thread,
|
||||
Number<OutThreadCopyDataPerWrite_W>{});
|
||||
#else
|
||||
threadwise_tensor_slice_copy_generic(
|
||||
threadwise_generic_tensor_slice_copy(
|
||||
out_10d_thread_desc.ReorderGivenNew2Old(map_out_global2thread),
|
||||
p_out_thread,
|
||||
make_zero_array<index_t, 10>(),
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
#include "ConstantTensorDescriptor.hip.hpp"
|
||||
#include "ConstantMergedTensorDescriptor.hip.hpp"
|
||||
#include "ConstantMatrixDescriptor.hip.hpp"
|
||||
#include "blockwise_merged_tensor_slice_op.hip.hpp"
|
||||
#include "blockwise_generic_tensor_slice_op.hip.hpp"
|
||||
#include "blockwise_gemm.hip.hpp"
|
||||
#include "threadwise_tensor_slice_op.hip.hpp"
|
||||
|
||||
@@ -123,7 +123,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
|
||||
// input blockwise copy
|
||||
// slice a merged tensor, reorder and copy to a normal tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
const auto blockwise_in_copy = BlockwiseTensorSliceCopy_generic_v1<
|
||||
const auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
|
||||
BlockSize,
|
||||
Float,
|
||||
decltype(in_c_n1_b_n2_global_merged_desc),
|
||||
@@ -152,7 +152,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
|
||||
// this copy operator already have blockwise offset built-in
|
||||
const auto blockwise_wei_copy =
|
||||
#if 0
|
||||
BlockwiseTensorSliceCopy_generic_v1<BlockSize,
|
||||
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_c_k_global_desc),
|
||||
decltype(wei_c_k_block_desc),
|
||||
@@ -318,7 +318,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
|
||||
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
|
||||
k_thread_data_on_global, 0, b_thread_data_on_global, 0);
|
||||
|
||||
threadwise_tensor_slice_copy_generic(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
|
||||
threadwise_generic_tensor_slice_copy(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
|
||||
p_out_thread,
|
||||
{0, 0, 0, 0, 0, 0, 0, 0},
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc,
|
||||
|
||||
@@ -194,7 +194,7 @@ threadwise_tensor_slice_copy_reorder_given_dst2src_v3(SrcDesc,
|
||||
}
|
||||
|
||||
template <class Float, class SrcDesc, class DstDesc, class SliceLengths, class DimAccessOrder>
|
||||
__device__ void threadwise_tensor_slice_copy_generic(
|
||||
__device__ void threadwise_generic_tensor_slice_copy(
|
||||
SrcDesc,
|
||||
const Float* __restrict__ p_src,
|
||||
Array<index_t, SrcDesc::GetNumOfDimension()> src_multi_id_begin,
|
||||
|
||||
Reference in New Issue
Block a user