diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp index 497d8e7959..c8f5be068a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -617,7 +617,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 BlockToCTileMap_GemmStreamK block_2_ctile_map_streamk; }; @@ -1253,7 +1253,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 using Block2CTileMap_streamk = BlockToCTileMap_GemmStreamK; template , // typename ThreadClusterArrangeOrder, CShuffleDataType, // typename SrcData, - CShuffleDataType, // typename DstData, + AccDataType, // typename DstData, decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle), decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle), Sequence<0, 1, 2, 3>, // typename DimAccessOrder, @@ -1990,7 +1990,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 Problem& problem, void* p_workspace) { - const AElementwiseOperation a_element_op{}; const BElementwiseOperation b_element_op{}; const CElementwiseOperation c_element_op{}; @@ -2018,167 +2017,154 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - using Block2CTileMap_streamk2 = - BlockToCTileMap_GemmStreamK_v2; - Block2CTileMap_streamk2 block_2_ctile_map_streamk(problem.M, - problem.N, - AK0Number * problem.KPadded, - problem.Grid_size, - problem.Streamk_sel); - for(auto block_idx = get_block_1d_id(); - block_idx < block_2_ctile_map_streamk.get_grid_dims(); - block_idx += gridDim.x) + Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M, + problem.N, + AK0Number * problem.KPadded, + problem.num_cu, + problem.occupancy, + problem.num_sk_blocks); + + auto block_idx = get_block_1d_id(); + is_sk_block = static_cast(block_idx) < block_2_ctile_map_streamk.sk_num_blocks; + is_dp_block = + static_cast(block_idx) >= block_2_ctile_map_streamk.dp_start_block_idx && + static_cast(block_idx) < block_2_ctile_map_streamk.reduction_start_block_idx; + + block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end); + num_k_block_main_loop = iter_end - iter_start; + + uint32_t* p_semaphore = reinterpret_cast( + reinterpret_cast(p_workspace) + + block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType))); + + if constexpr(Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Reduction) { - is_sk_block = - static_cast(block_idx) < block_2_ctile_map_streamk.sk_num_blocks; - is_dp_block = - static_cast(block_idx) >= block_2_ctile_map_streamk.dp_start_block_idx && - static_cast(block_idx) < - block_2_ctile_map_streamk.reduction_start_block_idx; - - block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end); - num_k_block_main_loop = iter_end - iter_start; - - uint32_t* p_semaphore = reinterpret_cast( - reinterpret_cast(p_workspace) + - block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType))); - - if constexpr(Block2CTileMap_streamk2::ReductionStrategy == - StreamKReductionStrategy::Reduction) + is_reduction_block = static_cast(block_idx) >= + block_2_ctile_map_streamk.reduction_start_block_idx; + if(is_reduction_block) { - is_reduction_block = static_cast(block_idx) >= - block_2_ctile_map_streamk.reduction_start_block_idx; - if(is_reduction_block) - { - // descriptors - constexpr auto cluster_length_reduce = GetClusterLengthReduction(); - constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce); - const auto reduce_thread_cluster_idx = - reduce_desc.CalculateBottomIndex(make_multi_index(block_idx)); - const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0]; - const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1]; + // descriptors + constexpr auto cluster_length_reduce = GetClusterLengthReduction(); + constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce); + const auto reduce_thread_cluster_idx = + reduce_desc.CalculateBottomIndex(make_multi_index(block_idx)); + const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0]; + const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1]; - constexpr auto MReduceIters = math::integer_divide_ceil( - Number{}, cluster_length_reduce.At(I0)); - constexpr auto NReduceIters = math::integer_divide_ceil( - Number{}, - cluster_length_reduce.At(I1) * - Number{}); + constexpr auto MReduceIters = + math::integer_divide_ceil(Number{}, cluster_length_reduce.At(I0)); + constexpr auto NReduceIters = math::integer_divide_ceil( + Number{}, + cluster_length_reduce.At(I1) * + Number{}); - constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed( - make_tuple(I1, Number{})); - constexpr auto acc_thread_buf_store_desc = - make_naive_tensor_descriptor_packed(make_tuple( - I1, I1, I1, Number{})); + constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{})); + constexpr auto acc_thread_buf_store_desc = + make_naive_tensor_descriptor_packed(make_tuple( + I1, I1, I1, Number{})); - constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor(); + constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor(); - constexpr auto partial_acc_load_step_n = - make_multi_index(0, - cluster_length_reduce.At(I1) * - CShuffleBlockTransferScalarPerVector_NPerBlock); - constexpr auto partial_acc_load_step_n_reverse = make_multi_index( - 0, - -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) * - CShuffleBlockTransferScalarPerVector_NPerBlock); - constexpr auto partial_acc_load_step_m = - make_multi_index(cluster_length_reduce.At(I0), 0); + constexpr auto partial_acc_load_step_n = make_multi_index( + 0, + cluster_length_reduce.At(I1) * CShuffleBlockTransferScalarPerVector_NPerBlock); + constexpr auto partial_acc_load_step_n_reverse = + make_multi_index(0, + -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) * + CShuffleBlockTransferScalarPerVector_NPerBlock); + constexpr auto partial_acc_load_step_m = + make_multi_index(cluster_length_reduce.At(I0), 0); - constexpr auto partial_acc_store_step_n = - make_multi_index(0, - 0, - 0, - cluster_length_reduce.At(I1) * - CShuffleBlockTransferScalarPerVector_NPerBlock); - constexpr auto partial_acc_store_step_n_reverse = make_multi_index( - 0, - 0, - 0, - -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) * - CShuffleBlockTransferScalarPerVector_NPerBlock); - constexpr auto partial_acc_store_step_m = - make_multi_index(0, cluster_length_reduce.At(I0), 0, 0); + constexpr auto partial_acc_store_step_n = make_multi_index( + 0, + 0, + 0, + cluster_length_reduce.At(I1) * CShuffleBlockTransferScalarPerVector_NPerBlock); + constexpr auto partial_acc_store_step_n_reverse = + make_multi_index(0, + 0, + 0, + -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) * + CShuffleBlockTransferScalarPerVector_NPerBlock); + constexpr auto partial_acc_store_step_m = + make_multi_index(0, cluster_length_reduce.At(I0), 0, 0); - StaticBuffer - parcial_acc_buf; - StaticBuffer - acc_buf; + StaticBuffer + parcial_acc_buf; + StaticBuffer + acc_buf; - // start to compute - auto reduction_idx = - block_idx - block_2_ctile_map_streamk.reduction_start_block_idx; - auto spatial_idx = block_2_ctile_map_streamk.tile_to_spatial( - reduction_idx, problem.M, problem.N); + // start to compute + auto reduction_idx = + block_idx - block_2_ctile_map_streamk.reduction_start_block_idx; + auto spatial_idx = + block_2_ctile_map_streamk.tile_to_spatial(reduction_idx, problem.M, problem.N); - workgroup_barrier wg_barrier(p_semaphore); + workgroup_barrier wg_barrier(p_semaphore); - uint32_t tile_acc_offset_start = - block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx); - uint32_t tile_acc_offset_end = - block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx + - 1); + uint32_t tile_acc_offset_start = + block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx); + uint32_t tile_acc_offset_end = + block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx + 1); - uint32_t expected_count = tile_acc_offset_end - tile_acc_offset_start; + // uint32_t expected_count = tile_acc_offset_end - tile_acc_offset_start; - if(threadIdx.x == 0) - { - p_semaphore[reduction_idx] = 0; - } + // if(threadIdx.x == 0) + // { + // p_semaphore[reduction_idx] = 0; + // } - __syncthreads(); + // __syncthreads(); - auto acc_load = ThreadwiseTensorSliceTransfer_v2< - AccDataType, // SrcData, - AccDataType, // DstData, - decltype(c_partial_acc_block_m_n), // SrcDesc, - decltype(acc_thread_buf_load_desc), // DstDesc, - Sequence<1, - CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths, - Sequence<0, 1>, // DimAccessOrder, - 1, // SrcVectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector, - 1, // SrcScalarStrideInVector, - false // SrcResetCoordinateAfterRun, - >{c_partial_acc_block_m_n, - make_multi_index(thread_m_cluster_id, - thread_n_cluster_id * - CShuffleBlockTransferScalarPerVector_NPerBlock)}; + auto acc_load = ThreadwiseTensorSliceTransfer_v2< + AccDataType, // SrcData, + AccDataType, // DstData, + decltype(c_partial_acc_block_m_n), // SrcDesc, + decltype(acc_thread_buf_load_desc), // DstDesc, + Sequence<1, + CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths, + Sequence<0, 1>, // DimAccessOrder, + 1, // SrcVectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector, + 1, // SrcScalarStrideInVector, + false // SrcResetCoordinateAfterRun, + >{c_partial_acc_block_m_n, + make_multi_index(thread_m_cluster_id, + thread_n_cluster_id * + CShuffleBlockTransferScalarPerVector_NPerBlock)}; - auto acc_store = ThreadwiseTensorSliceTransfer_v1r3< - AccDataType, // SrcData, - CDataType, // DstData, - decltype(acc_thread_buf_store_desc), // SrcDesc, - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc, - CElementwiseOperation, // ElementwiseOperation, - Sequence<1, - 1, - 1, - CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths, - Sequence<0, 1, 2, 3>, // DimAccessOrder, - 3, // DstVectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // DstScalarPerVector, - InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp, - 1, // DstScalarStrideInVector, - false // DstResetCoordinateAfterRun, - >{c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]), - thread_m_cluster_id, - __builtin_amdgcn_readfirstlane(spatial_idx[I1]), - thread_n_cluster_id * - CShuffleBlockTransferScalarPerVector_NPerBlock), - CElementwiseOperation{}}; + auto acc_store = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, // SrcData, + CDataType, // DstData, + decltype(acc_thread_buf_store_desc), // SrcDesc, + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc, + CElementwiseOperation, // ElementwiseOperation, + Sequence<1, + 1, + 1, + CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths, + Sequence<0, 1, 2, 3>, // DimAccessOrder, + 3, // DstVectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // DstScalarPerVector, + InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp, + 1, // DstScalarStrideInVector, + false // DstResetCoordinateAfterRun, + >{c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]), + thread_m_cluster_id, + __builtin_amdgcn_readfirstlane(spatial_idx[I1]), + thread_n_cluster_id * + CShuffleBlockTransferScalarPerVector_NPerBlock), + CElementwiseOperation{}}; #if 0 if(threadIdx.x == 0) { @@ -2188,152 +2174,149 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 __builtin_amdgcn_readfirstlane(spatial_idx[I1])); } #endif - if(threadIdx.x == 0) - { - atomicAdd(&p_semaphore[reduction_idx], 1); - } + // if(threadIdx.x == 0) + // { + // atomicAdd(&p_semaphore[reduction_idx], 1); + // } - wg_barrier.wait_eq(p_semaphore[reduction_idx], expected_count); - using Accumulation = ck::detail:: - AccumulateWithNanCheck; + // wg_barrier.wait_eq(p_semaphore[reduction_idx], expected_count); + wg_barrier.wait_eq(0, block_2_ctile_map_streamk.sk_num_blocks); + using Accumulation = ck::detail:: + AccumulateWithNanCheck; - for(int i_m = 0; i_m < MReduceIters; i_m++) - { - static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) { - acc_buf.Clear(); - for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++) - { - auto c_partial_acc_buf = - make_dynamic_buffer( - reinterpret_cast(p_workspace) + - i * c_partial_acc_block_m_n.GetElementSpaceSize(), - c_partial_acc_block_m_n.GetElementSpaceSize()); - - acc_load.Run(c_partial_acc_block_m_n, - c_partial_acc_buf, - acc_thread_buf_load_desc, - make_tuple(I0, I0), - parcial_acc_buf); - - static_for<0, CShuffleBlockTransferScalarPerVector_NPerBlock, 1>{}( - [&](auto i_vec) { - constexpr auto offset = - acc_thread_buf_load_desc.CalculateOffset( - make_tuple(0, i_vec)); - Accumulation::Calculate(acc_buf(Number{}), - parcial_acc_buf[Number{}]); - }); - } - - if(thread_n_cluster_id * - CShuffleBlockTransferScalarPerVector_NPerBlock < - NPerBlock) - { - acc_store.Run(acc_thread_buf_store_desc, - make_tuple(I0, I0, I0, I0), - acc_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - } - if constexpr(NReduceIters != 1) - { - if constexpr(i_n_reduce != (NReduceIters - 1)) - { - acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, - partial_acc_load_step_n); - acc_store.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, - partial_acc_store_step_n); - } - else - { - acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, - partial_acc_load_step_n_reverse); - acc_store.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, - partial_acc_store_step_n_reverse); - } - } - }); + for(int i_m = 0; i_m < MReduceIters; i_m++) + { + static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) { + acc_buf.Clear(); + for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++) { - acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, - partial_acc_load_step_m); - acc_store.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, - partial_acc_store_step_m); + auto c_partial_acc_buf = + make_dynamic_buffer( + reinterpret_cast(p_workspace) + + i * c_partial_acc_block_m_n.GetElementSpaceSize(), + c_partial_acc_block_m_n.GetElementSpaceSize()); + + acc_load.Run(c_partial_acc_block_m_n, + c_partial_acc_buf, + acc_thread_buf_load_desc, + make_tuple(I0, I0), + parcial_acc_buf); + + static_for<0, CShuffleBlockTransferScalarPerVector_NPerBlock, 1>{}( + [&](auto i_vec) { + constexpr auto offset = + acc_thread_buf_load_desc.CalculateOffset( + make_tuple(0, i_vec)); + Accumulation::Calculate(acc_buf(Number{}), + parcial_acc_buf[Number{}]); + }); } + + if(thread_n_cluster_id * CShuffleBlockTransferScalarPerVector_NPerBlock < + NPerBlock) + { + acc_store.Run(acc_thread_buf_store_desc, + make_tuple(I0, I0, I0, I0), + acc_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + if constexpr(NReduceIters != 1) + { + if constexpr(i_n_reduce != (NReduceIters - 1)) + { + acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, + partial_acc_load_step_n); + acc_store.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + partial_acc_store_step_n); + } + else + { + acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, + partial_acc_load_step_n_reverse); + acc_store.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + partial_acc_store_step_n_reverse); + } + } + }); + { + acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, + partial_acc_load_step_m); + acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock, + partial_acc_store_step_m); } - - continue; } + + return; } + } - // offset for last acc buffer of this block - uint32_t block_acc_offset = - (block_2_ctile_map_streamk.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * - MPerBlock * NPerBlock; - while(true) - { + // offset for last acc buffer of this block + uint32_t block_acc_offset = + (block_2_ctile_map_streamk.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * + MPerBlock * NPerBlock; + while(true) + { - uint32_t current_iter_length = __builtin_amdgcn_readfirstlane( - block_2_ctile_map_streamk.get_current_iter_length( - iter_start, iter_end, num_k_block_main_loop)); - uint32_t tile_idx, iter_offset; - block_2_ctile_map_streamk.get_tile_idx_with_offset( - iter_end - 1, tile_idx, iter_offset); - iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1); + uint32_t current_iter_length = + __builtin_amdgcn_readfirstlane(block_2_ctile_map_streamk.get_current_iter_length( + iter_start, iter_end, num_k_block_main_loop)); + uint32_t tile_idx, iter_offset; + block_2_ctile_map_streamk.get_tile_idx_with_offset(iter_end - 1, tile_idx, iter_offset); + iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1); - auto block_work_idx = - block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N); + auto block_work_idx = + block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N); - const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); - const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); - // HACK: this force m/n_block_data_idx_on_grid into SGPR - const index_t m_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); - const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); - const index_t k0_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(iter_offset * AK0Number); + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + const index_t k0_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(iter_offset * AK0Number); - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = - GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = - GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - // A matrix blockwise copy - auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< - ThisThreadBlock, - AElementwiseOperation, - ck::tensor_operation::element_wise::PassThrough, - InMemoryDataOperationEnum::Set, - Sequence, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - ADataType, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( a_grid_desc_ak0_m_ak1, make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0), a_element_op, @@ -2341,30 +2324,30 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 make_multi_index(0, 0, 0), ck::tensor_operation::element_wise::PassThrough{}); - // B matrix blockwise copy - auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< - ThisThreadBlock, - BElementwiseOperation, - ck::tensor_operation::element_wise::PassThrough, - InMemoryDataOperationEnum::Set, - Sequence, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( b_grid_desc_bk0_n_bk1, make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0), b_element_op, @@ -2372,366 +2355,356 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 make_multi_index(0, 0, 0), ck::tensor_operation::element_wise::PassThrough{}); - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - auto a_block_buf_ping = make_dynamic_buffer( - static_cast(p_shared_0), - a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); - auto b_block_buf_ping = make_dynamic_buffer( - static_cast(p_shared_0) + - a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), - b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); - auto a_block_buf_pong = make_dynamic_buffer( - static_cast(p_shared_1), - a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); - auto b_block_buf_pong = make_dynamic_buffer( - static_cast(p_shared_1) + - a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), - b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); - auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); - auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); - constexpr auto a_block_slice_copy_step = - make_multi_index(KPerBlock / AK1Number, 0, 0); - constexpr auto b_block_slice_copy_step = - make_multi_index(KPerBlock / BK1Number, 0, 0); + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); - // Blockwise GEMM pipeline - static_assert(std::is_default_constructible_v); - auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; - auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); - num_k_block_main_loop = __builtin_amdgcn_readfirstlane( - (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / - KPerBlock); + num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); - blockwise_gemm_pipeline.template Run( - a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_bufs, - a_block_slice_copy_step, - b_grid_desc_bk0_n_bk1, - b_block_desc_bk0_n_bk1, - b_blockwise_copy, - b_grid_buf, - b_block_bufs, - b_block_slice_copy_step, - c_thread_buf, - num_k_block_main_loop); + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); - // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - // TODO: hacky, fix it! - // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle = - GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle(); + constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle = + GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle(); - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared_0), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + auto c_partial_acc_buf = + make_dynamic_buffer( + reinterpret_cast(p_workspace) + block_acc_offset, + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle .GetElementSpaceSize()); - auto c_partial_acc_buf = - make_dynamic_buffer( - reinterpret_cast(p_workspace) + block_acc_offset, - c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle - .GetElementSpaceSize()); + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple(make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per + // shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1, 3, 7>{})); - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per - // shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per - // shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<>{}, - Sequence<0, 2, 4, 5, 6>{}, - Sequence<>{}, - Sequence<1, 3, 7>{})); + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + // CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + false, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< - AccDataType, - CShuffleDataType, - decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), - decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), - ck::tensor_operation::element_wise::PassThrough, - Sequence, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - // CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * - NPerXdl>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - false, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_m_id, 0, block_n_id, 0), - c_element_op}; + // LDS to global partial acc + auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + // InMemoryDataOperationEnum::Set, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + AccDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be + // false, othre wise has scratch + false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be + // false, othre wise has scratch + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, + make_multi_index(0, 0, 0, 0), + c_element_op}; - // LDS to global partial acc - auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2< - ThisThreadBlock, // index_t BlockSize, - CElementwiseOperation, // ElementwiseOperation, - // InMemoryDataOperationEnum::Set, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * - NPerXdl>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CShuffleDataType, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be - // false, othre wise has scratch - false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be - // false, othre wise has scratch - {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, - make_multi_index(0, 0, 0, 0), - c_element_op}; + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = SpaceFillingCurve< - Sequence<1, MPerBlock, 1, NPerBlock>, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); + // make sure it's safe to read from LDS + block_sync_lds(); + c_shuffle_block_copy_lds_to_global.SetSrcSliceOrigin( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple(0, 0, 0, 0)); - // make sure it's safe to read from LDS - block_sync_lds(); - c_shuffle_block_copy_lds_to_global.SetSrcSliceOrigin( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple(0, 0, 0, 0)); - - if(is_dp_block) + if(is_dp_block) + { + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + else if(is_sk_block) + { + if constexpr(Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Atomic) { // each block copy its data from LDS to global c_shuffle_block_copy_lds_to_global .template Run( + InMemoryDataOperationEnum::AtomicAdd>( c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, c_shuffle_block_buf, c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_buf); } - else if(is_sk_block) + else if constexpr(Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Reduction) { - if constexpr(Block2CTileMap_streamk2::ReductionStrategy == - StreamKReductionStrategy::Atomic) - { - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global - .template Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - } - else if constexpr(Block2CTileMap_streamk2::ReductionStrategy == - StreamKReductionStrategy::Reduction) - { - // constexpr offset - c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin( + // constexpr offset + c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple(0, 0, 0, 0)); + + c_block_copy_lds_to_partial_acc.SetDstSliceOrigin( + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, + make_tuple(MXdlPerWave, 0, NXdlPerWave, 0)); + + c_block_copy_lds_to_partial_acc + .template Run( c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple(0, 0, 0, 0)); - - c_block_copy_lds_to_partial_acc.SetDstSliceOrigin( + c_shuffle_block_buf, c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, - make_tuple(MXdlPerWave, 0, NXdlPerWave, 0)); - - c_block_copy_lds_to_partial_acc - .template Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, - c_partial_acc_buf); - } + c_partial_acc_buf); } - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + } + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } - // exit condition - iter_end -= current_iter_length; - if(iter_end <= iter_start) - break; - if constexpr(Block2CTileMap_streamk2::ReductionStrategy == - StreamKReductionStrategy::Reduction) - { - block_acc_offset -= MPerBlock * NPerBlock; - } - // make sure next loop LDS is ready for use - block_sync_lds(); + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); } - if constexpr(Block2CTileMap_streamk2::ReductionStrategy == + // exit condition + iter_end -= current_iter_length; + if(iter_end <= iter_start) + break; + if constexpr(Block2CTileMap_streamk::ReductionStrategy == StreamKReductionStrategy::Reduction) { - if(is_sk_block) - { - // increase the counter for this tile - workgroup_barrier wg_barrier(p_semaphore); - wg_barrier.inc(0); - } + block_acc_offset -= MPerBlock * NPerBlock; + } + // make sure next loop LDS is ready for use + block_sync_lds(); + } + if constexpr(Block2CTileMap_streamk::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + if(is_sk_block) + { + // increase the counter for this tile + workgroup_barrier wg_barrier(p_semaphore); + wg_barrier.inc(0); } } }