diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp index 155eb5225c..ae543031de 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp @@ -276,13 +276,14 @@ struct DeviceGemmXdl_C_Shuffle const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); - const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K); float ave_time = 0; - if(has_main_k0_block_loop) + if(has_main_k_block_loop) { const auto kernel = kernel_gemm_xdlops_v3r1< GridwiseGemm, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp index fc9cd51c4f..a5e768f44e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -113,7 +113,7 @@ template < index_t CShuffleNXdlPerWavePerShuffle, typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, index_t CBlockTransferScalarPerVector_NWaveNPerXdl, - index_t NumPrefetch = 1> + index_t NumGemmKPrefetchStage = 1> struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 { static constexpr auto I0 = Number<0>{}; @@ -131,6 +131,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 static constexpr auto AK1 = Number{}; static constexpr auto BK1 = Number{}; + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { constexpr auto max_lds_align = AK1; @@ -246,21 +250,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) return false; - // check NumPrefetch - if constexpr(NumPrefetch == 1) - { - // 1-stage prefetch always supported - } - else if constexpr(NumPrefetch == 2) - { - // 2-stage prefetch currently only support even number of K0 loop - // TODO: add support for odd number of K0 loop - if(!((K / KPerBlock) % 2 == 0)) - { - return false; - } - } - else + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { return false; } @@ -290,12 +283,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 return grid_size; } - // TODO move this function into GEMM-pipeline class - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { - const bool has_main_k0_block_loop = ((K0 * AK1) / (NumPrefetch * KPerBlock)) > 1; + const index_t num_loop = K / KPerBlock; - return has_main_k0_block_loop; + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } __host__ __device__ static constexpr auto @@ -434,7 +426,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 1, AThreadTransferSrcResetCoordinateAfterRun, true, - NumPrefetch>( + NumGemmKPrefetchStage>( a_grid_desc_ak0_m_ak1, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, @@ -465,7 +457,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 1, BThreadTransferSrcResetCoordinateAfterRun, true, - NumPrefetch>( + NumGemmKPrefetchStage>( b_grid_desc_bk0_n_bk1, make_multi_index(0, n_block_data_idx_on_grid, 0), b_element_op, @@ -484,7 +476,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - NumPrefetch, - HasMainK0BlockLoop>{}; - const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_buf, - a_block_slice_copy_step, - b_grid_desc_bk0_n_bk1, - b_block_desc_bk0_n_bk1, - b_blockwise_copy, - b_grid_buf, - b_block_buf, - b_block_slice_copy_step, - blockwise_gemm, - c_thread_buf, - num_k_block_main_loop); + GridwiseGemmPipe::template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); // shuffle C and write out {