diff --git a/composable_kernel/include/tensor_description/tensor_descriptor.hpp b/composable_kernel/include/tensor_description/tensor_descriptor.hpp index 8f6a5a3e43..9cd51c61d6 100644 --- a/composable_kernel/include/tensor_description/tensor_descriptor.hpp +++ b/composable_kernel/include/tensor_description/tensor_descriptor.hpp @@ -307,6 +307,10 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, { // sanity check { + static_assert(NewTransforms::Size() == NewLowerDimensionOldVisibleIdss::Size() && + NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(), + "wrong! inconsistent number of transforms"); + constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); }, NewLowerDimensionOldVisibleIdss{}); diff --git a/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v4r1.hpp b/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v4r1.hpp index b2722bf078..aa37fc32f1 100644 --- a/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v4r1.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v4r1.hpp @@ -33,7 +33,8 @@ template + bool ThreadTransferDstResetCoordinateAfterRun, + index_t NumThreadScratch = 1> struct BlockwiseTensorSliceTransfer_v4r1 { static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); @@ -86,45 +87,39 @@ struct BlockwiseTensorSliceTransfer_v4r1 } } - template - __device__ void - RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + Number thread_scratch_id = Number{}) { if(BlockSize == thread_cluster_desc_.GetElementSize() or get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) { - threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks); + threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id); } } - template - __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id = Number{}) { if(BlockSize == thread_cluster_desc_.GetElementSize() or get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) { - threadwise_transfer_.RunRead(src_desc, src_buf); + threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id); } } - template - __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) - { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.RunWrite(dst_desc, dst_buf); - } - } - - template + template __device__ void Run(const SrcDesc& src_desc, const SrcBuffer& src_buf, const DstDesc& dst_desc, - DstBuffer& dst_buf) + DstBuffer& dst_buf, + Number thread_scratch_id) { - RunRead(src_desc, src_buf); - RunWrite(dst_desc, dst_buf); + RunRead(src_desc, src_buf, thread_scratch_id); + RunWrite(dst_desc, dst_buf, thread_scratch_id); } __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) @@ -136,21 +131,6 @@ struct BlockwiseTensorSliceTransfer_v4r1 } } - // SrcMoveSliceWindowStepHack to control index calculation move slice window - template - __device__ void - MoveSrcSliceWindow(const SrcDesc& src_desc, - const Index& step, - const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) - { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - 
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.MoveSrcSliceWindow( - src_desc, step, src_move_slice_window_step_hack); - } - } - __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) { if(BlockSize == thread_cluster_desc_.GetElementSize() or @@ -182,7 +162,8 @@ struct BlockwiseTensorSliceTransfer_v4r1 SrcScalarStrideInVector, DstScalarStrideInVector, ThreadTransferSrcResetCoordinateAfterRun, - ThreadTransferDstResetCoordinateAfterRun>; + ThreadTransferDstResetCoordinateAfterRun, + NumThreadScratch>; ThreadwiseTransfer threadwise_transfer_; }; diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_pipeline_v1.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_pipeline_v1.hpp new file mode 100644 index 0000000000..dcacd99ae1 --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_pipeline_v1.hpp @@ -0,0 +1,325 @@ +#ifndef CK_GRIDWISE_GEMM_PIPELINE_V1_HPP +#define CK_GRIDWISE_GEMM_PIPELINE_V1_HPP + +#include "common_header.hpp" + +namespace ck { + +template +struct GridwiseGemmPipeline_v1; + +// 1-stage prefetch +template +struct GridwiseGemmPipeline_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { +#if 0 + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } +#else + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + 
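+                // (clarifying comment, not in the original patch) the barrier above guarantees all
+                // threads have finished reading the current LDS tile in blockwise_gemm.Run, so it is
+                // now safe to advance the source windows and overwrite LDS with the prefetched tile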
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } +#endif + } +}; + +// 2-stage prefetch +template +struct GridwiseGemmPipeline_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // preload data into LDS + { + // Read 0 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Read 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); + } + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Write i + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Read i+2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + // Sync + block_sync_lds(); + + // Gemm i + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Write i+1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1); + + // Read i+3 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); + + // Sync + block_sync_lds(); + + // Gemm i+1 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + { + // Write num_loop - 2 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Sync + block_sync_lds(); + + // Gemm num_loop - 2 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + // Write num_loop - 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1); + + // Sync + block_sync_lds(); + + // Gemm num_loop - 1 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp 
b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index 751015e6b2..47622ad148 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -8,6 +8,7 @@ #include "blockwise_gemm_xdlops.hpp" #include "blockwise_tensor_slice_transfer_v4r1.hpp" #include "threadwise_tensor_slice_transfer.hpp" +#include "gridwise_gemm_pipeline_v1.hpp" namespace ck { @@ -21,7 +22,7 @@ template + bool HasMainK0BlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -40,17 +41,17 @@ __global__ void { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); } template + index_t CThreadTransferDstScalarPerVector, + index_t NumPrefetch = 1> struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 { static constexpr auto I0 = Number<0>{}; @@ -194,6 +196,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) return false; + // check NumPrefetch + if constexpr(NumPrefetch == 1) + { + // 1-stage prefetch is always supported + } + else if constexpr(NumPrefetch == 2) + { + // 2-stage prefetch currently only supports an even number of K0 loops + // TODO: add support for an odd number of K0 loops + if(!((K0 / K0PerBlock) % 2 == 0)) + { + return false; + } + } + else + { + return false; + } + // check M01, N01 constexpr auto M1 = Number{}; constexpr auto N1 = Number{}; @@ -219,9 +240,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return grid_size; } + // TODO: move this function into the GEMM-pipeline class __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) { - const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; + const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1; return has_main_k0_block_loop; } @@ -316,7 +338,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -381,7 +403,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 1, 1, AThreadTransferSrcResetCoordinateAfterRun, - true>( + true, + NumPrefetch>( a_grid_desc_k0_m_k1, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, @@ -411,7 +434,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 1, 1, BThreadTransferSrcResetCoordinateAfterRun, - true>( + true, + NumPrefetch>( b_grid_desc_k0_n_k1, make_multi_index(0, n_block_data_idx_on_grid, 0), b_element_op, @@ -455,51 +479,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - // preload data into LDS - { - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, 
b_grid_buf); + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + NumPrefetch, + HasMainK0BlockLoop>{}; - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - } + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); - // Initialize C - c_thread_buf.Clear(); - - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; - - do - { - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - - block_sync_lds(); - - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - - k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); - } - - // tail - { - block_sync_lds(); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } + gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); // output: register to global memory { diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r5.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r5.hpp deleted file mode 100644 index b4d7ef7d84..0000000000 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r5.hpp +++ /dev/null @@ -1,635 +0,0 @@ -#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R5_HPP -#define CK_GRIDWISE_GEMM_XDLOPS_V2R5_HPP - -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "blockwise_gemm_xdlops.hpp" -#include "blockwise_tensor_slice_transfer_v4r1.hpp" -#include "threadwise_tensor_slice_transfer_v1r4.hpp" - -namespace ck { - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_xdlops_v2r5( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const FloatC* __restrict__ p_c0_grid, - const FloatC* __restrict__ p_c1_grid, - const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) -{ - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB 
p_shared_block[shared_block_size]; - - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_c0_grid, - p_c1_grid, - p_shared_block, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); -} - -template -struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5 -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - - // K1 should be Number<...> - static constexpr auto K1 = Number{}; - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size = - math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); - - return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); - } - - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} - __host__ __device__ static constexpr bool - CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDesc_M_N& c_grid_desc_m_n, - index_t M01, - index_t N01) - { - static_assert(is_known_at_compile_time>::value, - "wrong! 
K1 need to be known at compile-time"); - - static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && - (NPerBlock % (NRepeat * NPerXDL)) == 0, - "Invalid tuning param!"); - - const auto M = a_grid_desc_k0_m_k1.GetLength(I1); - const auto N = b_grid_desc_k0_n_k1.GetLength(I1); - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - - if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && - K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && - K1 == b_grid_desc_k0_n_k1.GetLength(I2))) - return false; - - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) - return false; - - // check M01, N01 - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - if(!(M0 % M01 == 0 && N0 % N01 == 0)) - return false; - - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - return true; - } - - __host__ __device__ static constexpr index_t - CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); - - return grid_size; - } - - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) - { - const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; - - return has_main_k0_block_loop; - } - - // TODO fix this - template - __host__ __device__ static constexpr auto - MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N_any& c_grid_desc_m_n) - { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - using BlockwiseGemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; - - return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); - } - - // return block_id to C matrix tile idx (m0, n0) mapping - __host__ __device__ static constexpr auto - MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto M00 = M0 / M01; - const auto N00 = N0 / N01; - - const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(M00, M01)), - make_unmerge_transform(make_tuple(N00, N01))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); - - const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), - 
make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto cblockid_to_m0_n0_block_cluster_adaptor = - chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, - cblockid_to_m00_m01_n00_n01_block_cluster_adaptor); - - return cblockid_to_m0_n0_block_cluster_adaptor; - } - - using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); - - using C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C0GridDesc_M_N{})); - - using C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C1GridDesc_M_N{})); - - using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); - - template - __device__ static void - Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const FloatC* __restrict__ p_c0_grid, - const FloatC* __restrict__ p_c1_grid, - FloatAB* __restrict__ p_shared_block, - const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CElementwiseOperation& c_element_op, - const Block2CTileMap& block_2_ctile_map) - { - const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); - - auto c0_grid_buf = make_dynamic_buffer( - p_c0_grid, c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); - - auto c1_grid_buf = make_dynamic_buffer( - p_c1_grid, c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); - - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - - // divide block work by [M, N] - const auto block_work_idx = - block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - // HACK: this force m/n_block_data_idx_on_grid into SGPR - const index_t m_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); - - const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); - - // lds max alignment - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // A matrix blockwise copy - auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_K0_M_K1, 
- ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_grid_desc_k0_m_k1), - decltype(a_block_desc_k0_m_k1), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>( - a_grid_desc_k0_m_k1, - make_multi_index(0, m_block_data_idx_on_grid, 0), - a_element_op, - a_block_desc_k0_m_k1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - - // B matrix blockwise copy - auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_grid_desc_k0_n_k1), - decltype(b_block_desc_k0_n_k1), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( - b_grid_desc_k0_n_k1, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_k0_n_k1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[K0PerBlock, MPerBlock] is in LDS - // b_mtx[K0PerBlock, NPerBlock] is in LDS - // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in - // register - // sanity check - - auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; - - auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); - - FloatAB* p_a_block = p_shared_block; - FloatAB* p_b_block = p_shared_block + a_block_space_size; - - constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - - auto a_block_buf = make_dynamic_buffer( - p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize()); - auto b_block_buf = make_dynamic_buffer( - p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize()); - - // preload data into LDS - { - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - } - - // Initialize C - c_thread_buf.Clear(); - - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; - - do - { - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - - block_sync_lds(); - - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - - k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); - } - - // tail - { - block_sync_lds(); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } - - // output: register to global memory - { - constexpr auto 
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7); - - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - make_naive_tensor_descriptor_packed(make_tuple( - Number{}, Number{}, I1, I1, Number{}, I1, Number{}, I1)); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_grid = - m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; - - const index_t n_thread_data_on_grid = - n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_grid_idx = - m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_grid)); - - const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_grid_idx = - n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_grid)); - - auto c_thread_copy = - ThreadwiseTensorSliceTransfer_v1r4, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>{ - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(m_thread_data_on_grid_idx[I0], - n_thread_data_on_grid_idx[I0], - m_thread_data_on_grid_idx[I1], - n_thread_data_on_grid_idx[I1], - m_thread_data_on_grid_idx[I2], - m_thread_data_on_grid_idx[I3], - m_thread_data_on_grid_idx[I4], - n_thread_data_on_grid_idx[I2]), - c_element_op}; - - c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_grid_buf, - c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c0_grid_buf, - c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c1_grid_buf); - } - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r6.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r6.hpp deleted file mode 100644 index 7d6c86f516..0000000000 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r6.hpp +++ /dev/null @@ -1,617 +0,0 @@ -#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R6_HPP -#define CK_GRIDWISE_GEMM_XDLOPS_V2R6_HPP - -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include 
"tensor_descriptor_helper.hpp" -#include "blockwise_gemm_xdlops.hpp" -#include "blockwise_tensor_slice_transfer_v4r1.hpp" -#include "threadwise_tensor_slice_transfer_v1r5.hpp" - -namespace ck { - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_xdlops_v2r6( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const FloatC* __restrict__ p_c0_grid, - const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) -{ - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_c0_grid, - p_shared_block, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); -} - -template -struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6 -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - - // K1 should be Number<...> - static constexpr auto K1 = Number{}; - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size = - math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); - - return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); - } - - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} - __host__ __device__ static constexpr bool - CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDesc_M_N& c_grid_desc_m_n, - index_t M01, - index_t N01) - { - static_assert(is_known_at_compile_time>::value, - "wrong! 
K1 need to be known at compile-time"); - - static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && - (NPerBlock % (NRepeat * NPerXDL)) == 0, - "Invalid tuning param!"); - - const auto M = a_grid_desc_k0_m_k1.GetLength(I1); - const auto N = b_grid_desc_k0_n_k1.GetLength(I1); - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - - if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && - K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && - K1 == b_grid_desc_k0_n_k1.GetLength(I2))) - return false; - - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) - return false; - - // check M01, N01 - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - if(!(M0 % M01 == 0 && N0 % N01 == 0)) - return false; - - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - return true; - } - - __host__ __device__ static constexpr index_t - CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); - - return grid_size; - } - - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) - { - const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; - - return has_main_k0_block_loop; - } - - // TODO fix this - template - __host__ __device__ static constexpr auto - MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N_any& c_grid_desc_m_n) - { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - using BlockwiseGemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; - - return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); - } - - // return block_id to C matrix tile idx (m0, n0) mapping - __host__ __device__ static constexpr auto - MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto M00 = M0 / M01; - const auto N00 = N0 / N01; - - const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(M00, M01)), - make_unmerge_transform(make_tuple(N00, N01))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); - - const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), - 
make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto cblockid_to_m0_n0_block_cluster_adaptor = - chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, - cblockid_to_m00_m01_n00_n01_block_cluster_adaptor); - - return cblockid_to_m0_n0_block_cluster_adaptor; - } - - using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); - - using C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C0GridDesc_M_N{})); - - using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); - - template - __device__ static void - Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const FloatC* __restrict__ p_c0_grid, - FloatAB* __restrict__ p_shared_block, - const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CElementwiseOperation& c_element_op, - const Block2CTileMap& block_2_ctile_map) - { - const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); - - auto c0_grid_buf = make_dynamic_buffer( - p_c0_grid, c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); - - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - - // divide block work by [M, N] - const auto block_work_idx = - block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - // HACK: this force m/n_block_data_idx_on_grid into SGPR - const index_t m_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); - - const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); - - // lds max alignment - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // A matrix blockwise copy - auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_grid_desc_k0_m_k1), - decltype(a_block_desc_k0_m_k1), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>( - 
a_grid_desc_k0_m_k1, - make_multi_index(0, m_block_data_idx_on_grid, 0), - a_element_op, - a_block_desc_k0_m_k1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - - // B matrix blockwise copy - auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_grid_desc_k0_n_k1), - decltype(b_block_desc_k0_n_k1), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( - b_grid_desc_k0_n_k1, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_k0_n_k1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[K0PerBlock, MPerBlock] is in LDS - // b_mtx[K0PerBlock, NPerBlock] is in LDS - // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in - // register - // sanity check - - auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; - - auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); - - FloatAB* p_a_block = p_shared_block; - FloatAB* p_b_block = p_shared_block + a_block_space_size; - - constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - - auto a_block_buf = make_dynamic_buffer( - p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize()); - auto b_block_buf = make_dynamic_buffer( - p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize()); - - // preload data into LDS - { - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - } - - // Initialize C - c_thread_buf.Clear(); - - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; - - do - { - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - - block_sync_lds(); - - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - - k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); - } - - // tail - { - block_sync_lds(); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } - - // output: register to global memory - { - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2); - constexpr auto N1 = 
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7); - - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - make_naive_tensor_descriptor_packed(make_tuple( - Number{}, Number{}, I1, I1, Number{}, I1, Number{}, I1)); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_grid = - m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; - - const index_t n_thread_data_on_grid = - n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_grid_idx = - m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_grid)); - - const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_grid_idx = - n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_grid)); - - auto c_thread_copy = - ThreadwiseTensorSliceTransfer_v1r5, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>{ - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(m_thread_data_on_grid_idx[I0], - n_thread_data_on_grid_idx[I0], - m_thread_data_on_grid_idx[I1], - n_thread_data_on_grid_idx[I1], - m_thread_data_on_grid_idx[I2], - m_thread_data_on_grid_idx[I3], - m_thread_data_on_grid_idx[I4], - n_thread_data_on_grid_idx[I2]), - c_element_op}; - - c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_grid_buf, - c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c0_grid_buf); - } - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp index 14d8b10b3d..336617d9d4 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp @@ -9,6 +9,7 @@ #include "blockwise_tensor_slice_transfer_v4r1.hpp" #include "blockwise_tensor_slice_transfer_v6r1.hpp" #include "threadwise_tensor_slice_transfer.hpp" +#include "gridwise_gemm_pipeline_v1.hpp" namespace ck { @@ -22,7 +23,7 @@ template + bool HasMainK0BlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -42,7 +43,7 @@ __global__ void { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( + GridwiseGemm::template Run( p_a_grid, p_b_grid, p_c_grid, @@ -95,7 +96,8 @@ template < index_t 
CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - index_t CBlockTransferScalarPerVector_NWaveNPerXdl> + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + index_t NumPrefetch = 1> struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 { static constexpr auto I0 = Number<0>{}; @@ -228,6 +230,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) return false; + // check NumPrefetch + if constexpr(NumPrefetch == 1) + { + // 1-stage prefetch is always supported + } + else if constexpr(NumPrefetch == 2) + { + // 2-stage prefetch currently only supports an even number of K0 loops + // TODO: add support for an odd number of K0 loops + if(!((K0 / K0PerBlock) % 2 == 0)) + { + return false; + } + } + else + { + return false; + } + // check M01, N01 constexpr auto M1 = Number{}; constexpr auto N1 = Number{}; @@ -253,9 +274,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 return grid_size; } + // TODO: move this function into the GEMM-pipeline class __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) { - const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; + const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1; return has_main_k0_block_loop; } @@ -329,7 +351,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 using DefaultBlock2CTileMap = remove_cvref_t; - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -397,7 +419,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 1, 1, AThreadTransferSrcResetCoordinateAfterRun, - true>( + true, + NumPrefetch>( a_grid_desc_k0_m_k1, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, @@ -427,7 +450,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 1, 1, BThreadTransferSrcResetCoordinateAfterRun, - true>( + true, + NumPrefetch>( b_grid_desc_k0_n_k1, make_multi_index(0, n_block_data_idx_on_grid, 0), b_element_op, @@ -471,51 +495,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - // preload data into LDS - { - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + NumPrefetch, + HasMainK0BlockLoop>{}; - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - } + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); - // Initialize C - c_thread_buf.Clear(); - - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; - - do - { - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - - block_sync_lds(); - - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - blockwise_gemm.Run(a_block_buf, 
b_block_buf, c_thread_buf); - - block_sync_lds(); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - - k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); - } - - // tail - { - block_sync_lds(); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } + gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); // shuffle C and write out { diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r2.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r2.hpp index c566dc046f..588c16d01b 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r2.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r2.hpp @@ -9,6 +9,7 @@ #include "blockwise_tensor_slice_transfer_v4r1.hpp" #include "blockwise_tensor_slice_transfer_v6r2.hpp" #include "threadwise_tensor_slice_transfer.hpp" +#include "gridwise_gemm_pipeline_v1.hpp" namespace ck { @@ -23,7 +24,7 @@ template + bool HasMainK0BlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -46,7 +47,7 @@ __global__ void { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( + GridwiseGemm::template Run( p_a_grid, p_b_grid, p_c_grid, @@ -102,7 +103,8 @@ template < index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - index_t CBlockTransferScalarPerVector_NWaveNPerXdl> + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + index_t NumPrefetch = 1> struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 { static constexpr auto I0 = Number<0>{}; @@ -235,6 +237,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) return false; + // check NumPrefetch + if constexpr(NumPrefetch == 1) + { + // 1-stage prefetch is always supported + } + else if constexpr(NumPrefetch == 2) + { + // 2-stage prefetch currently only supports an even number of K0 loops + // TODO: add support for an odd number of K0 loops + if(!((K0 / K0PerBlock) % 2 == 0)) + { + return false; + } + } + else + { + return false; + } + // check M01, N01 constexpr auto M1 = Number{}; constexpr auto N1 = Number{}; @@ -260,9 +281,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 return grid_size; } + // TODO: move this function into the GEMM-pipeline class __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) { - const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; + const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1; return has_main_k0_block_loop; } @@ -342,7 +364,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 using DefaultBlock2CTileMap = remove_cvref_t; - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -417,7 +439,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 1, 1, AThreadTransferSrcResetCoordinateAfterRun, - true>( + true, + NumPrefetch>( a_grid_desc_k0_m_k1, make_multi_index(0, m_block_data_idx_on_grid, 0), 
a_element_op, @@ -447,7 +470,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 1, 1, BThreadTransferSrcResetCoordinateAfterRun, - true>( + true, + NumPrefetch>( b_grid_desc_k0_n_k1, make_multi_index(0, n_block_data_idx_on_grid, 0), b_element_op, @@ -491,51 +515,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - // preload data into LDS - { - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + NumPrefetch, + HasMainK0BlockLoop>{}; - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - } + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); - // Initialize C - c_thread_buf.Clear(); - - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; - - do - { - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - - block_sync_lds(); - - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - - k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); - } - - // tail - { - block_sync_lds(); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } + gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); // shuffle C and write out { diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r3.hpp index 337550819a..3f8b74f544 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r3.hpp @@ -9,6 +9,7 @@ #include "blockwise_tensor_slice_transfer_v4r1.hpp" #include "blockwise_tensor_slice_transfer_v6r3.hpp" #include "threadwise_tensor_slice_transfer.hpp" +#include "gridwise_gemm_pipeline_v1.hpp" namespace ck { @@ -24,7 +25,7 @@ template + bool HasMainK0BlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -50,7 +51,7 @@ __global__ void { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( + GridwiseGemm::template Run( p_a_grid, p_b_grid, p_c_grid, @@ -109,7 +110,8 @@ template < index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, typename 
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - index_t CBlockTransferScalarPerVector_NWaveNPerXdl> + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + index_t NumPrefetch = 1> struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 { static constexpr auto I0 = Number<0>{}; @@ -242,6 +244,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) return false; + // check NumPrefetch + if constexpr(NumPrefetch == 1) + { + // 1-stage prefetch is always supported + } + else if constexpr(NumPrefetch == 2) + { + // 2-stage prefetch currently only supports an even number of K0 loops + // TODO: add support for an odd number of K0 loops + if(!((K0 / K0PerBlock) % 2 == 0)) + { + return false; + } + } + else + { + return false; + } + // check M01, N01 constexpr auto M1 = Number{}; constexpr auto N1 = Number{}; @@ -267,9 +288,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 return grid_size; } + // TODO: move this function into the GEMM-pipeline class __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) { - const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; + const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1; return has_main_k0_block_loop; } @@ -354,7 +376,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 using DefaultBlock2CTileMap = remove_cvref_t; - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -510,51 +532,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - // preload data into LDS - { - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + NumPrefetch, + HasMainK0BlockLoop>{}; - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - } + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); - // Initialize C - c_thread_buf.Clear(); - - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; - - do - { - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step); - - a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); - - block_sync_lds(); - - b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf); - - k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); - } - - // tail - { - block_sync_lds(); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } + gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + 
b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); // shuffle C and write out { diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r1.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r1.hpp index 438f925306..b20b391196 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r1.hpp @@ -64,9 +64,10 @@ template // control whether to move back dst coordinate after each + bool DstResetCoordinateAfterRun, // control whether to move back dst coordinate after each // RunWrite(), will be fused with MoveDstSliceWindow to // save addr computation + index_t NumThreadScratch = 1> struct ThreadwiseTensorSliceTransfer_v3r1 { static constexpr index_t nDim = SliceLengths::Size(); @@ -78,6 +79,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + static constexpr auto I0 = Number<0>{}; + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1( const SrcDesc& src_desc, const Index& src_slice_origin, @@ -102,9 +105,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1 dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); } - template - __device__ void - RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + Number thread_scratch_id = Number{}) { static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, @@ -114,9 +118,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 is_same, remove_cvref_t>::value, "wrong! SrcBuffer and SrcData data type are inconsistent"); - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( @@ -138,8 +139,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; }); - return make_tensor_coordinate_step( - src_desc, forward_step_idx, src_step_hacks[I0][i]); + return make_tensor_coordinate_step(src_desc, forward_step_idx); }, Number{}); @@ -152,8 +152,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 backward_step_idx(j) = (i.value == j.value) ? 
-src_scalar_per_access[i] : 0; }); - return make_tensor_coordinate_step( - src_desc, backward_step_idx, src_step_hacks[I1][i]); + return make_tensor_coordinate_step(src_desc, backward_step_idx); }, Number{}); @@ -215,8 +214,9 @@ }); // copy data from src_vector_container into src_thread_scratch_ - src_thread_scratch_.template SetAsType( - src_data_idx_seq, src_vector_container.template AsType()[I0]); + src_thread_scratch_tuple_(thread_scratch_id) + .template SetAsType( + src_data_idx_seq, src_vector_container.template AsType()[I0]); constexpr auto move_on_dim = [&]() constexpr { @@ -263,12 +263,15 @@ } } - __device__ void TransferDataFromSrcThreadScratchToDstThreadScratch() + template + __device__ void + TransferDataFromSrcThreadScratchToDstThreadScratch(Number thread_scratch_id) { #if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE static_ford{}([&](auto idx) { // convert from SrcData to DstData here - dst_thread_scratch_(idx) = type_convert(src_thread_scratch_[idx]); + dst_thread_scratch_(idx) = + type_convert(src_thread_scratch_tuple_[thread_scratch_id][idx]); }); #else // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ @@ -318,7 +321,7 @@ const auto src_vector_refs = generate_tie( [&](auto i) -> const src_vector_t& { // i increment corresponds to movement in DstVectorDim - return src_thread_scratch_.GetVectorTypeReference( + return src_thread_scratch_tuple_[thread_scratch_id].GetVectorTypeReference( data_idx_seq + i * dst_scalar_step_in_vector); }, Number{}); @@ -342,19 +345,21 @@ { static_ford{}([&](auto idx) { // convert from SrcData to DstData here - dst_thread_scratch_(idx) = type_convert(src_thread_scratch_[idx]); + dst_thread_scratch_(idx) = + type_convert(src_thread_scratch_tuple_[thread_scratch_id][idx]); }); } #endif } - template - __device__ void - RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks) + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id = Number{}) { // if there is transpose, it's done here // TODO: move this elsewhere - TransferDataFromSrcThreadScratchToDstThreadScratch(); + TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id); static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, @@ -364,9 +369,6 @@ is_same, remove_cvref_t>::value, "wrong! SrcBuffer or DstBuffer data type is wrong"); - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - // src scalar per access on each dim // TODO: don't use this constexpr auto dst_scalar_per_access = generate_sequence( @@ -388,8 +390,7 @@ forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; }); - return make_tensor_coordinate_step( - dst_desc, forward_step_idx, dst_step_hacks[I0][i]); + return make_tensor_coordinate_step(dst_desc, forward_step_idx); }, Number{}); @@ -402,8 +403,7 @@ backward_step_idx(j) = (i.value == j.value) ? 
-dst_scalar_per_access[i] : 0; }); - return make_tensor_coordinate_step( - dst_desc, backward_step_idx, dst_step_hacks[I1][i]); + return make_tensor_coordinate_step(dst_desc, backward_step_idx); }, Number{}); @@ -515,39 +515,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 } } - template - __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) - { - constexpr index_t ntransform_src = remove_cvref_t::GetNumOfTransform(); - - constexpr auto zeros = typename uniform_sequence_gen::type{}; - - constexpr auto src_step_hacks = - make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), - generate_tuple([&](auto) { return zeros; }, Number{})); - - RunRead(src_desc, src_buf, src_step_hacks); - } - - template - __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) - { - // TODO: why need remove_cvref_t ? - constexpr index_t ntransform_dst = remove_cvref_t::GetNumOfTransform(); - - constexpr auto zeros = typename uniform_sequence_gen::type{}; - - constexpr auto dst_step_hacks = - make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), - generate_tuple([&](auto) { return zeros; }, Number{})); - - RunWrite(dst_desc, dst_buf, dst_step_hacks); - } - __device__ static constexpr auto GetSrcCoordinateResetStep() { - constexpr auto I0 = Number<0>{}; - // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( @@ -606,8 +575,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 __device__ static constexpr auto GetDstCoordinateResetStep() { - constexpr auto I0 = Number<0>{}; - // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto dst_scalar_per_access = generate_sequence( @@ -679,25 +646,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 move_tensor_coordinate(src_desc, src_coord_, adjusted_step); } - // src_slice_origin_step_idx need to be known at compile-time, for performance reason - template - __device__ void - MoveSrcSliceWindow(const SrcDesc& src_desc, - const Index& src_slice_origin_step_idx, - const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) - { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = - SrcResetCoordinateAfterRun ? src_slice_origin_step_idx - : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); - - // is it OK to construct a new step every time? 
- const auto adjusted_step = make_tensor_coordinate_step( - src_desc, adjusted_step_idx, src_move_slice_window_step_hack); - - move_tensor_coordinate(src_desc, src_coord_, adjusted_step); - } - // dst_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& dst_slice_origin_step_idx) @@ -815,19 +763,21 @@ struct ThreadwiseTensorSliceTransfer_v3r1 static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; - StaticTensorTupleOfVectorBuffer - src_thread_scratch_; + using SrcThreadScratch = StaticTensorTupleOfVectorBuffer; - StaticTensorTupleOfVectorBuffer - dst_thread_scratch_; + using DstThreadScratch = StaticTensorTupleOfVectorBuffer; + + StaticallyIndexedArray src_thread_scratch_tuple_; + + DstThreadScratch dst_thread_scratch_; SrcCoord src_coord_; DstCoord dst_coord_; diff --git a/device_operation/CMakeLists.txt b/device_operation/CMakeLists.txt index 5872b69b99..c54f4e4d92 100644 --- a/device_operation/CMakeLists.txt +++ b/device_operation/CMakeLists.txt @@ -26,6 +26,7 @@ set(DEVICE_GEMM_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp; diff --git a/device_operation/include/device_gemm_xdl.hpp b/device_operation/include/device_gemm_xdl.hpp index da047a5140..71a2e088fe 100644 --- a/device_operation/include/device_gemm_xdl.hpp +++ b/device_operation/include/device_gemm_xdl.hpp @@ -52,7 +52,8 @@ template + ck::index_t CThreadTransferDstScalarPerVector, + ck::index_t NumPrefetch = 1> struct DeviceGemmXdl : public DeviceGemm { @@ -218,7 +219,8 @@ struct DeviceGemmXdl BBlockLdsAddExtraN, Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector>; + CThreadTransferDstScalarPerVector, + NumPrefetch>; // Argument struct Argument : public BaseArgument @@ -494,7 +496,12 @@ struct DeviceGemmXdl << BlockSize << ", " << MPerBlock << ", " << NPerBlock << ", " - << K0PerBlock + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ">"; // clang-format on diff --git a/device_operation/include/device_gemm_xdl_c_shuffle.hpp b/device_operation/include/device_gemm_xdl_c_shuffle.hpp index 9aa1ab158d..faabd1a8ae 100644 --- a/device_operation/include/device_gemm_xdl_c_shuffle.hpp +++ b/device_operation/include/device_gemm_xdl_c_shuffle.hpp @@ -4,9 +4,7 @@ #include #include #include "device.hpp" -#include "device_base.hpp" #include "device_gemm.hpp" -#include "device_gemm_xdl.hpp" #include "common_header.hpp" #include "tensor_layout.hpp" #include "tensor_descriptor.hpp" @@ -54,7 +52,8 @@ template < index_t CShuffleMXdlPerWavePerShuffle, index_t 
CShuffleNXdlPerWavePerShuffle, typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - index_t CBlockTransferScalarPerVector_NWaveNPerXdl> + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + index_t NumPrefetch = 1> struct DeviceGemmXdl_C_Shuffle : public DeviceGemm { @@ -174,7 +173,8 @@ struct DeviceGemmXdl_C_Shuffle CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - CBlockTransferScalarPerVector_NWaveNPerXdl>; + CBlockTransferScalarPerVector_NWaveNPerXdl, + NumPrefetch>; // Argument struct Argument : public BaseArgument diff --git a/device_operation/src/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp b/device_operation/src/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..ee25f2ba40 --- /dev/null +++ b/device_operation/src/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,56 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#####################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Num| + //#####################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Prefetch| + //#####################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>, + 
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8, 2>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8, 2>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8, 2>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8, 2>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8, 2>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8, 2>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8, 2>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8, 2> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp b/device_operation/src/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp index 
41479f60e8..42b20fe21f 100644 --- a/device_operation/src/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/device_operation/src/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -26,23 +26,36 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = std::tuple< // clang-format off - //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 
8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 
128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +// irregular tile size +using device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 
4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1> // clang-format on >; @@ -50,6 +63,8 @@ void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( std::vector>& instances) { add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances{}); + add_device_operation_instances(instances, + device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances{}); } } // namespace device_gemm_instance diff --git a/example/1_gemm_xdl/gemm_xdl.cpp b/example/1_gemm_xdl/gemm_xdl.cpp index 5d289f40e8..fd2f79fb09 100644 --- a/example/1_gemm_xdl/gemm_xdl.cpp +++ b/example/1_gemm_xdl/gemm_xdl.cpp @@ -11,13 +11,23 @@ #include "host_tensor_generator.hpp" #include "host_gemm.hpp" #include "device_tensor.hpp" +#include "device_gemm_xdl.hpp" #include "device_gemm_xdl_c_shuffle.hpp" #include "element_wise_operation.hpp" #include "reference_gemm.hpp" +#include "gemm_specialization.hpp" template using S = ck::Sequence; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using ADataType = ck::half_t; using BDataType = ck::half_t; using CDataType = ck::half_t; @@ -31,45 +41,56 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization_t::MNPadding; + // clang-format off -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle< - ADataType, // ADataType - BDataType, // BDataType - CDataType, // CDataType - AccDataType, // AccDataType - ALayout, // ALayout - BLayout, // BLayout - CLayout, // CLayout - AElementOp, // AElementwiseOperation - BElementOp, // BElementwiseOperation - CElementOp, // CElementwiseOperation - 256, // BlockSize - 256, // MPerBlock - 128, // NPerBlock - 4, // K0PerBlock - 8, // K1 - 32, // MPerXDL - 32, // NPerXDL - 4, // MXdlPerWave - 2, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 8, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_K1 - true, // ABlockLdsAddExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 8, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_K1 - true, // BBlockLdsAddExtraN - 1, // CShuffleMXdlPerWavePerShuffle - 1, // CShuffleNXdlPerWavePerShuffle - S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl +#if 1 +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| 
CThreadTransfer| Num| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +// [256, 128, 4, 8], 1 stage, 2 occupancy + < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1>; +#elif 0 +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +// [128, 144, 8, 8], 1 stage, 1 occupancy, bounded by LDS size +// 99 TFlops, 120 blocks (1024x2160x3840) +// 99 TFlops, 960 blocks (4096x4320x3840) + < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1>; +// [128, 144, 4, 8], 1 stage, 2 occupancy, +// 92 TFlops, 120 blocks (1024x2160x3840) +// 120 TFlops, 240 blocks (1024x4320x3840) +// 128 TFlops, 960 blocks (4096x4320x3840) +// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1>; +// [ 64, 144, 8, 8], 1 stage, 2 occupancy/ +// 96 TFlops, 240 blocks (1024x2160x3840) +// 96 TFlops, 480 blocks (1024x4320x3840) +// 99 TFlops,1920 blocks (4096x4320x3840) +// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 144, 8, 8, 16, 16, 1, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1>; +// [ 64, 144, 8, 8], 2 stage, 2 occupancy +// 93 TFlops +// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 144, 8, 8, 16, 16, 1, 9, S<8, 32, 1>, S<1, 0, 
2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 2>; +// [ 64, 144, 4, 8], 1 stage, 2 occupancy +// 87 TFlops +// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 144, 4, 8, 16, 16, 1, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1>; +// [ 64, 144, 4, 8], 2 stage, 2 occupancy +// 85 TFlops +// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 144, 4, 8, 16, 16, 1, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 2>; +#elif 1 +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Num| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Prefetch| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +// [128, 144, 8, 8], 1 stage, 1 occupancy, bounded by LDS size + < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 1, 9, S<1, 1, 8, 1, 9, 2>, 8, 1>; +#endif // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: @@ -198,8 +219,8 @@ int main(int argc, char* argv[]) float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); diff --git a/example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp b/example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp index 26d3ea3f74..4f255fda9d 100644 --- a/example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp +++ b/example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp @@ -12,6 +12,7 @@ #include "device_tensor.hpp" #include "tensor_layout.hpp" #include "element_wise_operation.hpp" +#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" #include "reference_conv_fwd.hpp" #include "convolution_utility.hpp" @@ -35,45 +36,41 @@ using OutElementOp = ck::tensor_operation::element_wise::PassThrough; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; -// clang-format off using 
DeviceConvFwdInstance = ck::tensor_operation::device:: - DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< - InDataType, // InDataType - WeiDataType, // WeiDataType - OutDataType, // OutDataType - AccDataType, // AccDataType - InElementOp, // InElementwiseOperation - WeiElementOp, // WeiElementwiseOperation - OutElementOp, // OutElementwiseOperation - ConvFwdDefault, // ConvForwardSpecialization - 256, // BlockSize - 128, // MPerBlock - 256, // NPerBlock - 4, // K0PerBlock - 8, // K1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 8, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_K1 - true, // ABlockLdsAddExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 8, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_K1 - true, // BBlockLdsAddExtraN - 1, // CShuffleMXdlPerWavePerShuffle - 1, // CShuffleNXdlPerWavePerShuffle - S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl -// clang-format on + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvFwdDefault, // ConvForwardSpecialization + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 7, // CThreadTransferSrcDstVectorDim + 1>; // CThreadTransferDstScalarPerVector using ReferenceConvFwdInstance = ck::tensor_operation::host:: ReferenceConvFwd; diff --git a/profiler/include/profile_gemm_impl.hpp b/profiler/include/profile_gemm_impl.hpp index 9962c6579d..b3924e44a1 100644 --- a/profiler/include/profile_gemm_impl.hpp +++ b/profiler/include/profile_gemm_impl.hpp @@ -1,4 +1,5 @@ #pragma once +#include #include "config.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -30,6 +31,9 @@ void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector&); void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( + std::vector&); + void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector&); void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector&); void 
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector&); @@ -225,6 +229,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); } else if constexpr(is_same::value && is_same::value && @@ -293,8 +300,8 @@ void profile_gemm_impl(int do_verification, float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s, " << gemm_name << std::endl; + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << gemm_name << std::endl; if(tflops > best_tflops) {
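Note on the NumPrefetch / NumThreadScratch mechanism introduced by this patch: with NumPrefetch = 2 each blockwise copy owns two thread-scratch buffers, so the RunRead for K0-block i+2 can be issued while the tile for block i is still being consumed; this is also why CheckValidity rejects an odd number of K0 loops and why CalculateHasMainK0BlockLoop divides by NumPrefetch * K0PerBlock. The host-side C++ sketch below illustrates only that scheduling idea under simplified assumptions (scalar loads, no LDS, no synchronization); it is not the CK API, and the names run_read, a_scratch and b_scratch are illustrative.

#include <array>
#include <cstdio>
#include <vector>

int main()
{
    constexpr int K0PerBlock  = 4; // elements consumed per K0 loop
    constexpr int NumPrefetch = 2; // number of in-flight stages / scratch buffers

    const std::vector<float> a(32, 1.0f);
    const std::vector<float> b(32, 2.0f);
    const int num_loop = static_cast<int>(a.size()) / K0PerBlock; // must be even for 2-stage

    // two thread-scratch buffers per operand, selected like Number<thread_scratch_id>
    std::array<std::array<float, K0PerBlock>, NumPrefetch> a_scratch{};
    std::array<std::array<float, K0PerBlock>, NumPrefetch> b_scratch{};

    // stand-in for RunRead(desc, buf, thread_scratch_id): copy global -> scratch
    auto run_read = [&](int loop, int scratch_id) {
        for(int k = 0; k < K0PerBlock; ++k)
        {
            a_scratch[scratch_id][k] = a[loop * K0PerBlock + k];
            b_scratch[scratch_id][k] = b[loop * K0PerBlock + k];
        }
    };

    float c = 0.f;

    // prologue: preload both stages, mirroring the pipeline's two initial reads
    run_read(0, 0);
    run_read(1, 1);

    // main body: consume one scratch buffer while the other read is (conceptually)
    // in flight; each buffer is refilled exactly NumPrefetch K0-blocks before it
    // is consumed again
    for(int i = 0; i < num_loop; i += NumPrefetch)
    {
        for(int k = 0; k < K0PerBlock; ++k)
            c += a_scratch[0][k] * b_scratch[0][k];

        if(i + NumPrefetch < num_loop)
            run_read(i + NumPrefetch, 0);

        for(int k = 0; k < K0PerBlock; ++k)
            c += a_scratch[1][k] * b_scratch[1][k];

        if(i + NumPrefetch + 1 < num_loop)
            run_read(i + NumPrefetch + 1, 1);
    }

    std::printf("c = %f\n", c); // 32 * (1 * 2) = 64
}

The alternation is also why the compile-time thread_scratch_id parameter is threaded from BlockwiseTensorSliceTransfer_v4r1 down into ThreadwiseTensorSliceTransfer_v3r1: indexing the scratch tuple with a Number<> keeps the buffer selection static, so the per-thread register storage never needs runtime indexing.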