From 27fb28ed31ccdfa0c26a98c032eddd6fe4f77571 Mon Sep 17 00:00:00 2001 From: mtgu0705 Date: Tue, 4 Mar 2025 11:06:15 +0800 Subject: [PATCH] i4 support lds multiple shuffle --- ...oup_tensor_slice_transfer_v7r3_scatter.hpp | 24 +-- .../gpu/grid/gridwise_moe_gemm.hpp | 162 ++++++++++-------- ...ise_tensor_slice_transfer_v7r3_scatter.hpp | 46 +++-- 3 files changed, 132 insertions(+), 100 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp index 4bda56e48b..ac4c6f678e 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp @@ -65,16 +65,12 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter const StaticallyIndexedArray& src_block_slice_origins, const DstDescs& dst_descs, const StaticallyIndexedArray& dst_block_slice_origins, - const ElementwiseOperation& element_op, - const StaticallyIndexedArray &scatter_offsets, - const StaticallyIndexedArray &scatter_weights) + const ElementwiseOperation& element_op) : threadwise_transfer_(src_descs, StaticallyIndexedArray{}, dst_descs, StaticallyIndexedArray{}, - element_op, - scatter_offsets, - scatter_weights) + element_op) { static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() && nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() && @@ -129,12 +125,13 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter template __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs, + StaticallyIndexedArray &scatter_weights, Number thread_scratch_id = Number{}) { if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { - threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id); + threadwise_transfer_.RunRead(src_descs, src_bufs, scatter_weights, thread_scratch_id); } } @@ -144,15 +141,16 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter template __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs, + StaticallyIndexedArray &scatter_offsets, Number thread_scratch_id = Number{}) { if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { if constexpr(is_detected::value) - threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id); + threadwise_transfer_.RunWrite(dst_descs, dst_bufs, scatter_offsets, thread_scratch_id); else - threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id); + threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), scatter_offsets, thread_scratch_id); } } @@ -160,10 +158,12 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter __device__ void Run(const SrcDescs& src_descs, const SrcBuffers& src_bufs, const DstDescs& dst_descs, - DstBuffers dst_bufs) + DstBuffers dst_bufs, + StaticallyIndexedArray &scatter_offsets, + StaticallyIndexedArray &scatter_weights) { - RunRead(src_descs, src_bufs); - RunWrite(dst_descs, dst_bufs); + RunRead(src_descs, src_bufs, scatter_weights); + RunWrite(dst_descs, dst_bufs, scatter_offsets); } template diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 4be51c44f5..de905737da 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1497,32 +1497,6 @@ struct GridwiseMoeGemm using CDEBlockTransferCluster = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); - constexpr auto EMRepeats = MPerBlock / EMThreads; - constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); - const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats; - StaticallyIndexedArray scatter_offsets; //= p_sorted_token_ids[c_token_pos]; - StaticallyIndexedArray scatter_weights; //= for topk - // too hack here, 2 specific for topk weights, fixme - const float *p_sorted_weights_0 = p_ds_grid[I0]; - // const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24; - - static_for<0, EMRepeats, 1>{}([&](auto m0) { - const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; - index_t token_offset = fused_token & 0xffffff; - float weight = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]]; - if constexpr (IsInputGemm) - { - token_offset = token_offset * problem.TopK + (fused_token >> 24); - } else { - const float *p_sorted_weights_2 = p_ds_grid[I2]; - weight = weight * p_sorted_weights_2[c_token_pos + m0]; - } - scatter_offsets(m0) = token_offset * problem.N; - scatter_weights(m0) = weight; - // if(threadIdx.x % 16 == 0) - // printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0)); - }); constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; //hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< ThisThreadBlock, @@ -1558,9 +1532,7 @@ struct GridwiseMoeGemm idx_c_ds_block_begin, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op, - scatter_offsets, - scatter_weights}; + c_element_op}; auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); @@ -1589,8 +1561,37 @@ struct GridwiseMoeGemm CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); + constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; + constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); + const float *p_sorted_weights_0 = p_ds_grid[I0]; static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS + StaticallyIndexedArray scatter_offsets; //= p_sorted_token_ids[c_token_pos]; + StaticallyIndexedArray scatter_weights; //= for topk + // too hack here, 2 specific for topk weights, fixme + // const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24; + + auto dstidx = sfc_cde_block.GetIndex(access_id); + const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); + static_for<0, EMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; + index_t token_offset = fused_token & 0xffffff; + float weight = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]]; + if constexpr (IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } else { + const float *p_sorted_weights_2 = p_ds_grid[I2]; + weight = weight * p_sorted_weights_2[c_token_pos + m0]; + } + + // if(threadIdx.x % 8 == 0 && blockIdx.x == 0) + // printf("init off tid %d access %d tpos %d m %d off %d wei %f\n", threadIdx.x, dstidx(I1), c_token_pos, m0(), token_offset, weight); + scatter_offsets(m0) = token_offset * problem.N; + scatter_weights(m0) = weight; + }); + block_sync_lds(); // each thread write its data from VGPR to LDS @@ -1608,7 +1609,10 @@ struct GridwiseMoeGemm c_ds_desc_refs, c_ds_buf_refs, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(c_grid_buf)); + tie(c_grid_buf), + scatter_offsets, + scatter_weights + ); if constexpr(access_id < num_access - 1) { @@ -1664,16 +1668,32 @@ struct GridwiseMoeGemm const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); // constexpr int expert_tile_cnt[8] = {2, 1, 1, 2, 2, 2, 1, 2}; // const index_t b_block_id = blockIdx.x % problem.NBlock; + const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y; + if (expert_block_id * MPerBlock >= max_token_id) + return; + const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]); const auto block_mn = [&]() -> std::pair { if constexpr (NSwizzle) { - const index_t expert_block_id = blockIdx.x / problem.NBlock; - const index_t es = __builtin_amdgcn_readfirstlane(p_max_token_id[expert_block_id + 1]); - const index_t expert_swizzle = es > 0 ? es : 1; //p_max_token_id[expert_id + 1]; - const index_t expert_block_swizzle = expert_block_id / expert_swizzle; - const index_t b_block_id_swizzle = blockIdx.x % (problem.NBlock * expert_swizzle); - const index_t nid = __builtin_amdgcn_readfirstlane(b_block_id_swizzle % 8 + b_block_id_swizzle / (8 * expert_swizzle) * 8); - const index_t mid = __builtin_amdgcn_readfirstlane(expert_block_swizzle * expert_swizzle + b_block_id_swizzle / 8 % expert_swizzle); + // const index_t expert_block_id = blockIdx.x / problem.NBlock; // + // const index_t es = __builtin_amdgcn_readfirstlane(p_max_token_id[expert_block_id + 1]); + // const index_t expert_swizzle = es > 0 ? es : 1; //p_max_token_id[expert_id + 1]; + // const index_t expert_block_swizzle = expert_block_id / expert_swizzle; + // const index_t b_block_id_swizzle = blockIdx.x % (problem.NBlock * expert_swizzle); + // const index_t nid = __builtin_amdgcn_readfirstlane(b_block_id_swizzle % 8 + b_block_id_swizzle / (8 * expert_swizzle) * 8); + // const index_t mid = __builtin_amdgcn_readfirstlane(expert_block_swizzle * expert_swizzle + b_block_id_swizzle / 8 % expert_swizzle); + // if(threadIdx.x==0) + // printf("block, %d, mid, %d, nid, %d, ecnt, %d, expert %d \n", blockIdx.x, mid, nid, es, p_sorted_expert_ids[expert_block_id]); + + const index_t ecnt_prefix = p_max_token_id[1+expert_id]; + const index_t prefix_block = ecnt_prefix * problem.NBlock; + const index_t ecnt = p_max_token_id[2+expert_id] - ecnt_prefix; + const index_t expert_swizzle = ecnt > 0 ? ecnt : 1; //p_max_token_id[expert_id + 1]; // 2 + const index_t bid_new = blockIdx.x - prefix_block; + const index_t nid = __builtin_amdgcn_readfirstlane(bid_new % 8 + bid_new / (8 * expert_swizzle) * 8); + const index_t mid = __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle); + // if(threadIdx.x==0) + // printf("block, %d, mid, %d, nid, %d, ecnt, %d, expert %d \n", blockIdx.x, mid, nid, ecnt, expert_id); return {nid, mid}; } else { return {blockIdx.x, blockIdx.y}; @@ -1681,7 +1701,7 @@ struct GridwiseMoeGemm }(); const index_t block_n_id = block_mn.first; const index_t block_m_id = block_mn.second; - const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[block_m_id]); + // if (threadIdx.x==0) { // printf("bid %d, eid %d, es %d, esi %d, bsi %d, m %d, n %d\n", blockIdx.x, expert_id, expert_swizzle, expert_block_swizzle, b_block_id_swizzle, block_m_id, block_n_id); // } @@ -1695,7 +1715,7 @@ struct GridwiseMoeGemm constexpr auto AMRepeats = MPerBlock / AMThreads; const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; - if(token_pos >= max_token_id || token0 >= problem.NumTokens) + if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id || token0 >= problem.NumTokens) return; StaticallyIndexedArray gather_offsets; //= p_sorted_token_ids[token_pos]; static_for<0, AMRepeats, 1>{}([&](auto m0) { @@ -1989,32 +2009,6 @@ struct GridwiseMoeGemm using CDEBlockTransferCluster = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); - constexpr auto EMRepeats = MPerBlock / EMThreads; - constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); - const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats; - StaticallyIndexedArray scatter_offsets; //= p_sorted_token_ids[c_token_pos]; - StaticallyIndexedArray scatter_weights; //= for topk - // too hack here, 2 specific for topk weights, fixme - const float *p_sorted_weights_0 = p_ds_grid[I0]; - // const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24; - - static_for<0, EMRepeats, 1>{}([&](auto m0) { - const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; - index_t token_offset = fused_token & 0xffffff; - float weight = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]]; - if constexpr (IsInputGemm) - { - token_offset = token_offset * problem.TopK + (fused_token >> 24); - } else { - const float *p_sorted_weights_2 = p_ds_grid[I2]; - weight = weight * p_sorted_weights_2[c_token_pos + m0]; - } - scatter_offsets(m0) = token_offset * problem.N; - scatter_weights(m0) = weight; - // if(threadIdx.x % 16 == 0) - // printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0)); - }); constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; //hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< ThisThreadBlock, @@ -2050,9 +2044,7 @@ struct GridwiseMoeGemm idx_c_ds_block_begin, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op, - scatter_offsets, - scatter_weights}; + c_element_op}; auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); @@ -2081,8 +2073,37 @@ struct GridwiseMoeGemm CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); + constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; + constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); + const float *p_sorted_weights_0 = p_ds_grid[I0]; static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS + StaticallyIndexedArray scatter_offsets; //= p_sorted_token_ids[c_token_pos]; + StaticallyIndexedArray scatter_weights; //= for topk + // too hack here, 2 specific for topk weights, fixme + // const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24; + + auto dstidx = sfc_cde_block.GetIndex(access_id); + const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); + static_for<0, EMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; + index_t token_offset = fused_token & 0xffffff; + float weight = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]]; + if constexpr (IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } else { + const float *p_sorted_weights_2 = p_ds_grid[I2]; + weight = weight * p_sorted_weights_2[c_token_pos + m0]; + } + + // if(threadIdx.x % 8 == 0 && blockIdx.x == 0) + // printf("init off tid %d access %d tpos %d m %d off %d wei %f\n", threadIdx.x, dstidx(I1), c_token_pos, m0(), token_offset, weight); + scatter_offsets(m0) = token_offset * problem.N; + scatter_weights(m0) = weight; + }); + block_sync_lds(); // each thread write its data from VGPR to LDS @@ -2100,7 +2121,10 @@ struct GridwiseMoeGemm c_ds_desc_refs, c_ds_buf_refs, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(c_grid_buf)); + tie(c_grid_buf), + scatter_offsets, + scatter_weights + ); if constexpr(access_id < num_access - 1) { diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp index b62c18538d..fb1ea640ff 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -100,14 +100,10 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter const StaticallyIndexedArray& src_slice_origins, const DstDescs& dst_descs, const StaticallyIndexedArray& dst_slice_origins, - const ElementwiseOperation& element_op, - const StaticallyIndexedArray &scatter_offsets, - const StaticallyIndexedArray &scatter_weights) + const ElementwiseOperation& element_op) : src_coords_(MakeCoordinates(src_descs, src_slice_origins)), dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)), - element_op_(element_op), - scatter_offsets_(scatter_offsets), - scatter_weights_(scatter_weights) + element_op_(element_op) { static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, "wrong! cannot evenly divide"); @@ -158,6 +154,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter enable_if_t = false> __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs, + StaticallyIndexedArray &scatter_weights, Number thread_scratch_id = Number{}) { // loop over space-filling curve @@ -181,9 +178,9 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter static_assert(SrcScalarPerVectors{}[Number{}] == 1, "scatter weight dim, should only one vec"); constexpr auto iScatter = SrcSpaceFillingCurve::GetIndex(iAccess)(Number{}); // if(threadIdx.x % 8 ==0 ) - // printf("bid %d tid %d srcid %d sv %f\n", blockIdx.y, threadIdx.x, i.value, scatter_weights_(Number{})); + // printf("bid %d tid %d srcid %d sv %f\n", blockIdx.y, threadIdx.x, i.value, scatter_weights(Number{})); static_for<0, SrcScalarPerVector, 1>{}( - [&](auto j) { src_vectors(i).template AsType()(j) = scatter_weights_(Number{}); }); + [&](auto j) { src_vectors(i).template AsType()(j) = scatter_weights(Number{}); }); } else if constexpr(SrcScalarPerVectors{}[i] == 1) { @@ -418,6 +415,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter enable_if_t = false> __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs, + StaticallyIndexedArray &scatter_offsets, Number thread_scratch_id = Number{}) { OOBCheck(thread_scratch_id); @@ -430,13 +428,13 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter if constexpr (OutputScatter) { constexpr auto iScatter = DstSpaceFillingCurve::GetIndex(iAccess)(Number{}); - scatter_offset = scatter_offsets_(Number{}); + scatter_offset = scatter_offsets(Number{}); } // copy data from buf_vectors into dst_bufs static_for<0, nDst, 1>{}([&](auto i) { using dst_vector_t = typename remove_cvref_t::type; auto dst_offset = scatter_offset + dst_coords_[i].GetOffset(); - const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();//hack felix, todo use coord + const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize(); // coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], // dst_coords_[i]); @@ -449,11 +447,11 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter dst_offset, is_dst_valid, dst_vectors[i].template AsType()[I0]); - // if(1) { - // static_for<0, DstScalarPerVector, 1>{}([&](auto idx) { + // if(threadIdx.x%8 ==0 && blockIdx.x==0) { + // static_for<0, 1, 1>{}([&](auto idx) { // using DstData = remove_cvref_t>; // using print_vec_t = typename vector_type::type; - // printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_coords_[i].GetOffset(), is_dst_valid, + // printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_offset, is_dst_valid, // type_convert(dst_vectors[i].template AsType()[idx])); // }); // } @@ -509,10 +507,12 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter __device__ void Run(const SrcDescs& src_descs, const SrcBuffers& src_bufs, const DstDescs& dst_descs, - DstBuffers dst_bufs) + DstBuffers dst_bufs, + StaticallyIndexedArray &scatter_offsets, + StaticallyIndexedArray &scatter_weights) { - RunRead(src_descs, src_bufs); - RunWrite(dst_descs, dst_bufs); + RunRead(src_descs, src_bufs, scatter_weights); + RunWrite(dst_descs, dst_bufs, scatter_offsets); } __device__ static constexpr auto GetSrcCoordinateResetStep() @@ -683,8 +683,18 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter ? dst_slice_origin_step_idx : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + auto adjusted_step_idx_scatter = [&]() + { + Index step_; + static_for<0, nDim, 1>{}([&](auto i) { + step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : adjusted_step_idx[Number{}]; + }); + + return step_; + } + (); // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx); + const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx_scatter); move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step); } @@ -709,8 +719,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter SrcCoords src_coords_; DstCoords dst_coords_; const ElementwiseOperation element_op_; - StaticallyIndexedArray scatter_offsets_; - StaticallyIndexedArray scatter_weights_; }; } // namespace ck