diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
index 2d729ab10f..7e5c727d15 100644
--- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
+++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
@@ -344,64 +344,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
         // copy output: register to global memory
         {
-#if 0
-            constexpr index_t K2 = GemmMPerThreadSubC;
-            constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
-
-            // define tensor descriptor for threadwise copy
-            //     output memory layout descriptor in register
-            constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
-                make_ConstantTensorDescriptor_packed(
-                    Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});
-
-            // output tensor descriptor in register, src of threadwise copy
-            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
-                out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
-                    Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});
-
-            // output memory layout descriptor in device memory, dst of threadwise copy
-            constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
-                out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
-                    .Fold(I0, Number<N1>{}, Number<N2>{});
-
-            // calculate origin of thread output tensor on global memory
-            //     blockwise GEMM c matrix starting index
-            const auto c_thread_mtx_on_block =
-                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
-
-            const index_t k_thread_data_on_global =
-                k_block_data_on_global + c_thread_mtx_on_block.row;
-
-            const index_t b_thread_data_on_global =
-                b_block_data_on_global + c_thread_mtx_on_block.col / N2;
-
-            // output merged global tensor descriptor, for calculating origin of thread tensor
-            // in global memory
-            constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
-                out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
-                Sequence<3>{},
-                Sequence<1>{},
-                Sequence<0, 4, 5>{},
-                Sequence<2>{});
-
-            // origin of dst in device memory
-            Float* p_out_thread_on_global =
-                p_out_global +
-                out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
-                    k_thread_data_on_global, 0, b_thread_data_on_global, 0);
-
-            ThreadwiseGenericTensorSliceCopy_v2r1<
-                decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
-                decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
-                decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()),
-                arithmetic_sequence_gen<0, 8, 1>::type,
-                arithmetic_sequence_gen<0, 8, 1>::type,
-                7,
-                7,
-                1,
-                1>({0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0})
-                .Run(p_out_thread, p_out_thread_on_global);
-#else
             constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
 
             // define tensor descriptor for threadwise copy
@@ -449,7 +391,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
                                                       b_thread_data_on_global,
                                                       0})
                 .template Run_amd_experiment<0, 2>(p_out_thread, p_out_global);
-#endif
         }
     }
 };
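Both gridwise kernels (this one and the padded variant below) get the same refactor: the old two-level fold of the K dimension (K1 = GemmMLevel0Cluster * GemmMLevel1Cluster clusters, K2 = GemmMPerThreadSubC sub-chunk) collapses into a single fold constant K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster. The index algebra behind the collapse is plain row-major folding; a minimal host-side sketch with example values (not part of the patch):

#include <cassert>

// Folding K into (k0, k1_old, k2) and then re-folding (k1_old, k2) into one
// k1 coordinate is the identity on the flat offset k.
int main()
{
    constexpr int K2     = 4;       // old K2 = GemmMPerThreadSubC
    constexpr int K1_old = 16;      // old K1 = GemmMLevel0Cluster * GemmMLevel1Cluster
    constexpr int K1     = K1_old * K2; // new single fold constant

    for(int k = 0; k < 2 * K1; ++k)
    {
        const int k0     = k / K1;            // new outer coordinate
        const int k1     = k % K1;            // new inner coordinate
        const int k1_old = (k % K1) / K2;     // old middle (cluster) coordinate
        const int k2     = k % K2;            // old innermost (sub-chunk) coordinate

        assert(k == (k0 * K1_old + k1_old) * K2 + k2); // old 3-level fold
        assert(k == k0 * K1 + k1);                     // new 2-level fold
    }
    return 0;
}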
diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp
index 011112e49f..0a5b4a3c34 100644
--- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp
+++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp
@@ -369,31 +369,37 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
         // copy output: register to global memory
         {
-            constexpr index_t K2 = GemmMPerThreadSubC;
-            constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
-
-            static_assert(K % (K1 * K2) == 0, "wrong!");
+            constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
+            constexpr index_t K0 = K / K1;
 
             // define tensor descriptor for threadwise copy
-            //     output memory layout descriptor in register
-            constexpr auto out_k0_k1_k2_n1_n0_ho_wo_n2_thread_desc =
-                make_native_tensor_descriptor_packed(
-                    Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});
+            //     output memory layout descriptor in register, src of threadwise copy
+            constexpr auto out_k0_k1_n1_b_n2_thread_desc = make_native_tensor_descriptor_packed(
+                Sequence<KPerBlock / K1, GemmMPerThreadSubC, N1, 1, N2>{});
 
-            // output tensor descriptor in register, src of threadwise copy
-            constexpr auto out_n0_n1_n2_k0_k1_k2_ho_wo_thread_desc =
-                reorder_tensor_descriptor_given_upper2lower(out_k0_k1_k2_n1_n0_ho_wo_n2_thread_desc,
-                                                            Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});
+            // output memory layout descriptor in device memory
+            constexpr auto out_n0_n1_n2_k0_k1_ho_wo_global_desc_old =
+                OutGlobalDesc::Fold(I1, Number<K1>{}).Fold(I0, Number<N1>{}, Number<N2>{});
 
-            // output memory layout descriptor in device memory, dst of threadwise copy
-            constexpr auto out_n0_n1_n2_k0_k1_k2_ho_wo_global_desc = transform_tensor_descriptor(
-                out_n_k_ho_wo_global_desc,
-                make_tuple(Unmerge<Sequence<N0, N1, N2>>{},
-                           Unmerge<Sequence<K / (K1 * K2), K1, K2>>{},
-                           PassThrough<Ho>{},
-                           PassThrough<Wo>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-                make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}, Sequence<6>{}, Sequence<7>{}));
+            constexpr auto out_n0_n1_n2_k0_k1_ho_wo_global_desc = make_native_tensor_descriptor(
+                out_n0_n1_n2_k0_k1_ho_wo_global_desc_old.GetLengths(),
+                out_n0_n1_n2_k0_k1_ho_wo_global_desc_old.GetStrides());
+
+            // output merged global tensor descriptor, dst of threadwise copy
+            constexpr auto out_k0_k1_n1_b_n2_global_desc = transform_tensor_descriptor(
+                out_n0_n1_n2_k0_k1_ho_wo_global_desc,
+                make_tuple(PassThrough<K0>{},
+                           PassThrough<K1>{},
+                           PassThrough<N1>{},
+                           Merge<Sequence<N0, Ho, Wo>>{},
+                           PassThrough<N2>{}),
+                make_tuple(Sequence<3>{},
+                           Sequence<4>{},
+                           Sequence<1>{},
+                           Sequence<0, 5, 6>{},
+                           Sequence<2>{}),
+                make_tuple(
+                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
 
             // calculate origin of thread output tensor on global memory
             //     blockwise GEMM c matrix starting index
@@ -406,41 +412,20 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
             const index_t b_thread_data_on_global =
                 b_block_data_on_global + c_thread_mtx_on_block.col / N2;
 
-            // output merged global tensor descriptor, for calculating origin of thread tensor
-            // in global memory
-            constexpr auto out_n0_n1_n2_k_ho_wo_global_desc = transform_tensor_descriptor(
-                out_n_k_ho_wo_global_desc,
-                make_tuple(Unmerge<Sequence<N0, N1, N2>>{},
-                           PassThrough<K>{},
-                           PassThrough<Ho>{},
-                           PassThrough<Wo>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-                make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}));
-
-            constexpr auto out_k_n1_b_n2_global_desc = transform_tensor_descriptor(
-                out_n0_n1_n2_k_ho_wo_global_desc,
-                make_tuple(PassThrough<K>{},
-                           PassThrough<N1>{},
-                           Merge<Sequence<N0, Ho, Wo>>{},
-                           PassThrough<N2>{}),
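In the new dst descriptor above, the Merge transform (its Sequence argument reconstructed from the surrounding dimension names) exposes N0, Ho and Wo as one merged "b" dimension, while the copy origin decomposes the flat k_thread_data_on_global into the folded (k0, k1) pair as k / K1 and k % K1. A small standalone sketch of the row-major merge/decompose arithmetic, with made-up sizes (not values from the patch):

#include <cassert>

// A merged index b over (n0, ho, wo) decomposes back in row-major order,
// mirroring what Merge<Sequence<N0, Ho, Wo>> does symbolically at compile time.
int main()
{
    constexpr int N0 = 2, Ho = 17, Wo = 17;

    const int b  = 123;                 // merged index, b in [0, N0 * Ho * Wo)
    const int n0 = b / (Ho * Wo);       // outermost coordinate
    const int ho = b % (Ho * Wo) / Wo;  // middle coordinate
    const int wo = b % Wo;              // innermost coordinate

    assert(b == (n0 * Ho + ho) * Wo + wo); // merge is invertible
    assert(n0 == 0 && ho == 7 && wo == 4);
    return 0;
}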
-                make_tuple(Sequence<3>{}, Sequence<1>{}, Sequence<0, 4, 5>{}, Sequence<2>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
-
-            // origin of dst in device memory
-            Float* p_out_thread_on_global =
-                p_out_global +
-                out_k_n1_b_n2_global_desc.CalculateOffset(
-                    {k_thread_data_on_global, 0, b_thread_data_on_global, 0});
-
-            ThreadwiseGenericTensorSliceCopy_v4r2<
-                decltype(out_n0_n1_n2_k0_k1_k2_ho_wo_thread_desc),
-                decltype(out_n0_n1_n2_k0_k1_k2_ho_wo_global_desc),
-                decltype(out_n0_n1_n2_k0_k1_k2_ho_wo_thread_desc.GetLengths()),
-                arithmetic_sequence_gen<0, 8, 1>::type,
-                7,
-                1,
-                1>({0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0})
-                .Run(p_out_thread, p_out_thread_on_global);
+            ThreadwiseGenericTensorSliceCopy_v4r2<
+                decltype(out_k0_k1_n1_b_n2_thread_desc),
+                decltype(out_k0_k1_n1_b_n2_global_desc),
+                decltype(out_k0_k1_n1_b_n2_thread_desc.GetLengths()),
+                arithmetic_sequence_gen<0, 5, 1>::type,
+                3,
+                1,
+                1>({0, 0, 0, 0, 0},
+                   {k_thread_data_on_global / K1,
+                    k_thread_data_on_global % K1,
+                    0,
+                    b_thread_data_on_global,
+                    0})
+                .template Run_amd_experiment<0, 2>(p_out_thread, p_out_global);
         }
     }
 };
diff --git a/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp b/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
index 0f9976f453..25cd5a819c 100644
--- a/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
+++ b/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
@@ -737,13 +737,23 @@ struct BlockwiseGenericTensorSliceCopy_v4
     template <typename TData>
     __device__ void RunLoadRegisterBuffer(const TData* p_src, TData* p_buffer) const
     {
+#if 0
         mThreadwiseLoad.Run(p_src, p_buffer);
+#else
+        // hardcoded: global to register
+        mThreadwiseLoad.template Run_amd_experiment<2, 0>(p_src, p_buffer);
+#endif
     }
 
     template <typename TData>
     __device__ void RunStoreRegisterBuffer(const TData* p_buffer, TData* p_dst) const
    {
+#if 0
         mThreadwiseStore.Run(p_buffer, p_dst);
+#else
+        // hardcoded: register to LDS
+        mThreadwiseStore.template Run_amd_experiment<0, 1>(p_buffer, p_dst);
+#endif
     }
 
     template <typename TData>
@@ -751,8 +761,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
     {
         TData p_buffer[GetRegisterBufferSize()];
 
-        mThreadwiseLoad.Run(p_src, p_buffer);
-        mThreadwiseStore.Run(p_buffer, p_dst);
+        RunLoadRegisterBuffer(p_src, p_buffer);
+        RunStoreRegisterBuffer(p_buffer, p_dst);
     }
 
     template <typename TData>
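RunLoadRegisterBuffer and RunStoreRegisterBuffer now pin each copy stage to fixed memory spaces, following the legend added in the threadwise copy below (0 = VGPR, 1 = LDS, 2 = global); the <2, 0> and <0, 1> template arguments shown here are reconstructed from the patch's own "global to register" / "register to LDS" comments. A reduced host-side sketch of dispatching on such compile-time memory-space tags, with if constexpr standing in for the static_if used by the patch (all names hypothetical):

#include <cstring>

enum MemorySpace
{
    Vgpr   = 0,
    Lds    = 1,
    Global = 2
};

// Simplified analogue of Run_amd_experiment's dispatch: the (src, dst)
// memory-space tags are template parameters, so the specialized
// global-memory path is selected at compile time, not at run time.
template <int SrcMemorySpace, int DstMemorySpace, typename T>
void copy_dispatch(const T* p_src, T* p_dst, int n)
{
    if constexpr(SrcMemorySpace == Global)
    {
        // global source: stands in for the buffer_load path of the patch
        for(int i = 0; i < n; ++i)
            p_dst[i] = p_src[i];
    }
    else
    {
        // VGPR/LDS source: plain copy
        std::memcpy(p_dst, p_src, n * sizeof(T));
    }
}

int main()
{
    float global_buf[4] = {0, 1, 2, 3};
    float regs[4]       = {};
    float lds_buf[4]    = {};

    copy_dispatch<Global, Vgpr>(global_buf, regs, 4); // "global to register"
    copy_dispatch<Vgpr, Lds>(regs, lds_buf, 4);       // "register to LDS"
    return 0;
}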
diff --git a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
index 525b7a04f1..68e6b5bb02 100644
--- a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
+++ b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
@@ -1276,6 +1276,107 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
         });
     }
 
+    // memory-space
+    //   0: VGPR
+    //   1: LDS
+    //   2: global-memory
+    template <index_t SrcMemorySpace, index_t DstMemorySpace>
+    __device__ void Run_amd_experiment(const TData* p_src, TData* p_dst) const
+    {
+        using src_vector_t = typename vector_type<TData, SrcDataPerAccess>::MemoryType;
+        using dst_vector_t = typename vector_type<TData, DstDataPerAccess>::MemoryType;
+
+        constexpr auto vector_access_dim = Number<VectorAccessDim>{};
+
+        constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
+        constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
+
+        constexpr auto long_vector_size = Number<math::lcm(SrcDataPerAccess, DstDataPerAccess)>{};
+
+        constexpr auto long_vector_access_lengths = SliceLengths::Modify(
+            vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
+
+        ford<decltype(long_vector_access_lengths)>{}([&](
+            auto long_vector_access_id) {
+
+            // data id w.r.t slicing-window
+            auto long_vector_data_begin_id = long_vector_access_id;
+            long_vector_data_begin_id(vector_access_dim) =
+                long_vector_size * long_vector_access_id[vector_access_dim];
+
+            // buffer to hold a long-vector
+            TData p_long_vector[long_vector_size];
+
+            // set 0
+            for(index_t i = 0; i < long_vector_size; ++i)
+            {
+                p_long_vector[i] = 0;
+            }
+
+            // load data from src to the long-vector buffer
+            for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
+            {
+                auto scalar_id = make_zero_array<index_t, nDim>();
+                scalar_id(vector_access_dim) = i * src_data_per_access;
+
+                const auto src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);
+
+                // check for padding
+                // TODO: still kind of messy
+                if(!src_coord.IsAnyLevelIndexInPaddingArea())
+                {
+                    const index_t src_offset = src_coord.GetOffset();
+
+                    const index_t buffer_offset = i * src_data_per_access;
+
+                    static_if<SrcMemorySpace == 2>{}([&](auto) {
+#if 0  // source code
+                        *reinterpret_cast<src_vector_t*>(&p_long_vector[buffer_offset]) =
+                            *reinterpret_cast<const src_vector_t*>(&p_src[src_offset]);
+#elif 1 // inline asm using buffer_load
+                        *reinterpret_cast<src_vector_t*>(&p_long_vector[buffer_offset]) =
+                            __buffer_load<TData, SrcDataPerAccess>(
+                                p_src, static_cast<uint32_t>(src_offset), static_cast<uint32_t>(0));
+#endif
+                    }).Else([&](auto) {
+                        // src can be all kinds of memory-space.
+                        *reinterpret_cast<src_vector_t*>(&p_long_vector[buffer_offset]) =
+                            *reinterpret_cast<const src_vector_t*>(&p_src[src_offset]);
+                    });
+                }
+            }
+
+            // store data from the long-vector buffer to dst
+            for(index_t i = 0; i < long_vector_size / dst_data_per_access; ++i)
+            {
+                auto scalar_id = make_zero_array<index_t, nDim>();
+                scalar_id(vector_access_dim) = i * dst_data_per_access;
+
+                const index_t buffer_offset = i * dst_data_per_access;
+
+                const index_t dst_offset =
+                    (mDstSliceOrigin + (long_vector_data_begin_id + scalar_id)).GetOffset();
+
+                static_if<DstMemorySpace == 2>{}([&](auto) {
+#if 0  // source code
+                    *reinterpret_cast<dst_vector_t*>(&p_dst[dst_offset]) =
+                        *reinterpret_cast<dst_vector_t*>(&p_long_vector[buffer_offset]);
+#elif 1 // inline asm using buffer_store
+                    __buffer_store<TData, DstDataPerAccess>(
+                        *reinterpret_cast<dst_vector_t*>(&p_long_vector[buffer_offset]),
+                        p_dst,
+                        dst_offset,
+                        0);
+#endif
+                }).Else([&](auto) {
+                    // dst can be all kinds of memory-space
+                    *reinterpret_cast<dst_vector_t*>(&p_dst[dst_offset]) =
+                        *reinterpret_cast<dst_vector_t*>(&p_long_vector[buffer_offset]);
+                });
+            }
+        });
+    }
+
     template <typename T, bool PositiveDirection>
     __device__ void MoveSrcSliceWindow(const T& step_sizes_,
                                        integral_constant<bool, PositiveDirection>)
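Run_amd_experiment keeps the long-vector structure of the existing Run: each access stages lcm(SrcDataPerAccess, DstDataPerAccess) elements through a small register buffer, so source and destination can use different vector widths, and only the global-memory sides are routed through __buffer_load/__buffer_store. A self-contained 1-D sketch of that staging pattern (hypothetical names, padding check omitted):

#include <algorithm>
#include <numeric>

// 1-D illustration of the long-vector staging in Run_amd_experiment:
// read src in chunks of SrcPerAccess, write dst in chunks of DstPerAccess,
// both tiling a buffer of lcm(SrcPerAccess, DstPerAccess) elements.
template <int SrcPerAccess, int DstPerAccess>
void staged_copy(const float* p_src, float* p_dst, int length)
{
    constexpr int long_vector_size = std::lcm(SrcPerAccess, DstPerAccess);

    for(int begin = 0; begin < length; begin += long_vector_size)
    {
        float p_long_vector[long_vector_size] = {}; // zero-initialized staging buffer

        // load: long_vector_size / SrcPerAccess accesses
        for(int i = 0; i < long_vector_size / SrcPerAccess; ++i)
            std::copy_n(p_src + begin + i * SrcPerAccess,
                        SrcPerAccess,
                        p_long_vector + i * SrcPerAccess);

        // store: long_vector_size / DstPerAccess accesses
        for(int i = 0; i < long_vector_size / DstPerAccess; ++i)
            std::copy_n(p_long_vector + i * DstPerAccess,
                        DstPerAccess,
                        p_dst + begin + i * DstPerAccess);
    }
}

int main()
{
    float src[8] = {0, 1, 2, 3, 4, 5, 6, 7};
    float dst[8] = {};
    staged_copy<4, 2>(src, dst, 8); // one 4-wide load, two 2-wide stores per pass
    return 0;
}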
diff --git a/composable_kernel/include/utility/config_amd.hpp.in b/composable_kernel/include/utility/config_amd.hpp.in
index 664d78b86b..1b57256a00 100644
--- a/composable_kernel/include/utility/config_amd.hpp.in
+++ b/composable_kernel/include/utility/config_amd.hpp.in
@@ -5,7 +5,7 @@
 #include "hip/hip_fp16.h"
 
 #define CK_DEVICE_BACKEND_AMD 1
-#define CK_USE_UNSIGNED_INDEX_TYPE 1
+#define CK_USE_UNSIGNED_INDEX_TYPE 0
 #define CK_USE_AMD_INLINE_ASM 1
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 1
 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1 0
diff --git a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded.hpp b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded.hpp
index e550581fc4..ac95e09d7e 100644
--- a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded.hpp
+++ b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded.hpp
@@ -51,7 +51,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded(InDesc,
     wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
     out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
 
-#if 0
+#if 1
     // BlockSize = 256, each thread hold 64 data
     constexpr index_t BlockSize = 256;
diff --git a/driver/src/driver.cpp b/driver/src/driver.cpp
index 0c2e91a5f8..e9fa7ba38a 100644
--- a/driver/src/driver.cpp
+++ b/driver/src/driver.cpp
@@ -76,12 +76,12 @@ int main(int argc, char* argv[])
 #if 0
     constexpr index_t N = 64;
-    constexpr index_t C = 16;
-    constexpr index_t HI = 34;
-    constexpr index_t WI = 34;
-    constexpr index_t K = 128;
-    constexpr index_t Y = 3;
-    constexpr index_t X = 3;
+    constexpr index_t C = 64;
+    constexpr index_t HI = 56;
+    constexpr index_t WI = 56;
+    constexpr index_t K = 256;
+    constexpr index_t Y = 1;
+    constexpr index_t X = 1;
 
     using ConvStrides   = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;
@@ -103,7 +103,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 1
+#elif 0
     // 1x1 filter, 8x8 image
    // cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42%
     constexpr index_t N = 64;
@@ -341,7 +341,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<3, 0>;
     using RightPads = Sequence<3, 0>;
-#elif 0
+#elif 1
     // 1x7 filter, 0x3 pad, 17x17 input
     constexpr index_t N = 128;
     constexpr index_t C = 128;
@@ -434,7 +434,7 @@ int main(int argc, char* argv[])
 #elif 0
     device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(
         in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
-#elif 1
+#elif 0
     device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
                                                          in_nchw,
                                                          wei_kcyx_desc,
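The driver case enabled above is the "1x7 filter, 0x3 pad, 17x17 input" problem, where the W padding exactly compensates for the 7-wide filter so the spatial size stays 17x17. A quick check of that output-size arithmetic (stride and dilation 1, as in the selected configurations; helper name hypothetical):

#include <cassert>

// Output extent of a 1-D convolution with stride 1 and dilation 1.
constexpr int conv_out_size(int in, int filter, int pad_left, int pad_right)
{
    return in + pad_left + pad_right - filter + 1;
}

int main()
{
    assert(conv_out_size(17, 1, 0, 0) == 17); // Ho: 1-wide filter, no H padding
    assert(conv_out_size(17, 7, 3, 3) == 17); // Wo: 7-wide filter, 3+3 W padding
    return 0;
}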