From da2659d5025289300c8281ebe2d6ca7f98eb759e Mon Sep 17 00:00:00 2001 From: coderfeli Date: Sat, 15 Mar 2025 14:26:30 +0000 Subject: [PATCH] input output all ok --- .../moe_gemm1_xdl_fp8.cpp | 10 ++-- .../moe_gemm1_xdl_pk_i4.cpp | 4 +- .../moe_gemm2_xdl_fp8.cpp | 2 +- .../moe_gemm2_xdl_pk_i4.cpp | 2 +- ...roup_tensor_slice_transfer_v4r1_gather.hpp | 4 +- ...oup_tensor_slice_transfer_v7r3_scatter.hpp | 6 ++- .../gpu/device/impl/device_moe_gemm.hpp | 2 + .../gpu/grid/gridwise_moe_gemm.hpp | 48 ++++++++++++------- ...wise_tensor_slice_transfer_v3r1_gather.hpp | 7 +-- ...ise_tensor_slice_transfer_v7r3_scatter.hpp | 9 ++-- include/ck/utility/dynamic_buffer.hpp | 13 ++--- 11 files changed, 66 insertions(+), 41 deletions(-) diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index a2421bd3d7..cb39aff6e4 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -158,7 +158,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| 2, 1, S<1, 32, 1, 8>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, ck::long_index_t, A0DataType>; // clang-format on @@ -235,12 +235,12 @@ int main(int argc, char* argv[]) // max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2,2, 2, 2, 2, 2,1,0,0,0}; // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; // int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} - // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; - // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} - max_token_id.mData = {valid_size}; + max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; + int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + // max_token_id.mData = {valid_size}; for(int i = 0; i < sorted_tile_num; i++) { - expert_ids.mData[i] = 0; + expert_ids.mData[i] = eids[i]; } int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; int tokenid = 0; diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index c06e595c0f..badb2efb87 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -159,7 +159,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, MXDLPerWave, 1, S<1, 32, 1, 8>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, ck::index_t, A0DataType>; // clang-format on #else static constexpr ck::index_t MPerBlock = 128; @@ -176,7 +176,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, S<8, 1, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, ck::index_t, A0DataType>; // clang-format on #endif diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp index 0d12441016..f134524c33 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp @@ -164,7 +164,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| 2, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, ck::index_t, A0DataType>; // kernel 2: 128->32x128x128 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp index c80b01d8c5..d5990d59e6 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp @@ -149,7 +149,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, 1, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, ck::index_t, A0DataType>; // clang-format on int main(int argc, char* argv[]) diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp index fd150b4fdc..92aef65388 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp @@ -41,6 +41,7 @@ template struct ThreadGroupTensorSliceTransfer_v4r1_gather @@ -58,7 +59,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_gather const DstDesc& dst_desc, const Index& dst_block_slice_origin, const DstElementwiseOperation& dst_element_op, - const StaticallyIndexedArray& gather_offsets) + const StaticallyIndexedArray& gather_offsets) : threadwise_transfer_(src_desc, make_zero_multi_index(), src_element_op, @@ -190,6 +191,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_gather DstScalarStrideInVector, ThreadTransferSrcResetCoordinateAfterRun, ThreadTransferDstResetCoordinateAfterRun, + IndexType, GatherDim, NumThreadScratch>; diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp index c42b0cb65e..befdd4cf7c 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp @@ -42,6 +42,7 @@ template __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs, - StaticallyIndexedArray& scatter_offsets, + StaticallyIndexedArray& scatter_offsets, Number thread_scratch_id = Number{}) { if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or @@ -169,7 +170,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter const SrcBuffers& src_bufs, const DstDescs& dst_descs, DstBuffers dst_bufs, - StaticallyIndexedArray& scatter_offsets, + StaticallyIndexedArray& scatter_offsets, StaticallyIndexedArray& scatter_weights) { RunRead(src_descs, src_bufs, scatter_weights); @@ -230,6 +231,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, + IndexType, ScatterDim, OutputScatter, ScatterWeightIdx, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp index 950fe0236d..d69693d326 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp @@ -67,6 +67,7 @@ template ) @@ -491,7 +492,7 @@ struct GridwiseMoeGemm template __host__ __device__ static auto - MakeCGridDescriptor_M_N(long_index_t M, long_index_t MPad, long_index_t N, long_index_t NPad, long_index_t StrideC) + MakeCGridDescriptor_M_N(IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC) { const auto c_grid_desc_mraw_nraw = [&]() { if constexpr(is_same::value) @@ -1210,7 +1211,7 @@ struct GridwiseMoeGemm if(token_pos >= max_token_id || token0 >= problem.NumTokens) return; - StaticallyIndexedArray gather_offsets; + StaticallyIndexedArray gather_offsets; static_for<0, AMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[token_pos + m0]; index_t token_offset = fused_token & 0xffffff; @@ -1218,7 +1219,7 @@ struct GridwiseMoeGemm { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - gather_offsets(m0) = token_offset * problem.K; + gather_offsets(m0) = static_cast(token_offset) * problem.K; }); const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K); @@ -1226,7 +1227,7 @@ struct GridwiseMoeGemm const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); - const auto a_grid_buf = make_dynamic_buffer( + const auto a_grid_buf = make_long_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( p_b_grid + expert_id * expert_stride / BPackedSize, @@ -1261,6 +1262,7 @@ struct GridwiseMoeGemm 1, AThreadTransferSrcResetCoordinateAfterRun, true, + IndexType, 1, BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1, make_multi_index(0, 0, 0), @@ -1519,6 +1521,7 @@ struct GridwiseMoeGemm uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, 1, // ScatterDim true, // OutputScatter: false, only use scatter weights scatter_weight_idx // ScatterWeightIdx: ascale @@ -1528,8 +1531,16 @@ struct GridwiseMoeGemm make_tuple(make_multi_index(0, 0, block_n_id, 0)), c_element_op}; - auto c_grid_buf = make_long_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + // using BufferType = std::conditional_t< + // std::is_same_v, + // decltype(make_long_dynamic_buffer(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize())), + // decltype(make_dynamic_buffer(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize())) + // >; + auto c_grid_buf = make_long_dynamic_buffer(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // BufferType c_grid_buf = std::is_same_v ? + // make_long_dynamic_buffer(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()) : + // make_dynamic_buffer(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = SpaceFillingCurve, @@ -1563,7 +1574,7 @@ struct GridwiseMoeGemm const float* p_sorted_weights_0 = p_ds_grid[I0]; static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS - StaticallyIndexedArray + StaticallyIndexedArray scatter_offsets; //= p_sorted_token_ids[c_token_pos]; StaticallyIndexedArray scatter_weights; //= for topk // too hack here, 2 specific for topk weights, fixme @@ -1575,7 +1586,7 @@ struct GridwiseMoeGemm block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); static_for<0, EMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; - long_index_t token_offset = fused_token & 0xffffff; + IndexType token_offset = fused_token & 0xffffff; float weight = token_offset < problem.NumTokens ? p_sorted_weights_0[token_offset * problem.StrideDs[0]] : 0.0; @@ -1588,7 +1599,7 @@ struct GridwiseMoeGemm const float* p_sorted_weights_2 = p_ds_grid[I2]; weight = weight * p_sorted_weights_2[c_token_pos + m0]; } - scatter_offsets(m0) = token_offset * problem.N; + scatter_offsets(m0) = static_cast(token_offset) * problem.N; scatter_weights(m0) = weight; }); @@ -1713,7 +1724,7 @@ struct GridwiseMoeGemm if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id || token0 >= problem.NumTokens) return; - StaticallyIndexedArray + StaticallyIndexedArray gather_offsets; //= p_sorted_token_ids[token_pos]; static_for<0, AMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[token_pos + m0]; @@ -1722,7 +1733,7 @@ struct GridwiseMoeGemm { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - gather_offsets(m0) = token_offset * problem.K; + gather_offsets(m0) = static_cast(token_offset) * problem.K; }); const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K); @@ -1730,7 +1741,7 @@ struct GridwiseMoeGemm const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); - const auto a_grid_buf = make_dynamic_buffer( + const auto a_grid_buf = make_long_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( p_b_grid + expert_id * expert_stride / BPackedSize, @@ -1765,6 +1776,7 @@ struct GridwiseMoeGemm 1, AThreadTransferSrcResetCoordinateAfterRun, true, + IndexType, 1, 2>(a_grid_desc_ak0_m_ak1, make_multi_index(0, 0, 0), @@ -2029,6 +2041,7 @@ struct GridwiseMoeGemm uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, 1, // ScatterDim true, // OutputScatter: false, only use scatter weights scatter_weight_idx // ScatterWeightIdx: ascale @@ -2038,7 +2051,10 @@ struct GridwiseMoeGemm make_tuple(make_multi_index(0, 0, block_n_id, 0)), c_element_op}; - auto c_grid_buf = make_long_dynamic_buffer( + auto c_grid_buf = std::is_same_v ? + make_long_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()): + make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = @@ -2073,7 +2089,7 @@ struct GridwiseMoeGemm const float* p_sorted_weights_0 = p_ds_grid[I0]; static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS - StaticallyIndexedArray + StaticallyIndexedArray scatter_offsets; //= p_sorted_token_ids[c_token_pos]; StaticallyIndexedArray scatter_weights; //= for topk // too hack here, 2 specific for topk weights, fixme @@ -2099,7 +2115,7 @@ struct GridwiseMoeGemm const float* p_sorted_weights_2 = p_ds_grid[I2]; weight = weight * p_sorted_weights_2[c_token_pos + m0]; } - scatter_offsets(m0) = token_offset * problem.N; + scatter_offsets(m0) = static_cast(token_offset) * problem.N; scatter_weights(m0) = weight; }); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp index c496622f9b..bd6fe772e4 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp @@ -41,6 +41,7 @@ template struct ThreadwiseTensorSliceTransfer_v3r1_gather @@ -88,7 +89,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather const DstDesc& dst_desc, const Index& dst_slice_origin, const DstElementwiseOperation& dst_element_op, - const StaticallyIndexedArray& gather_offsets) + const StaticallyIndexedArray& gather_offsets) : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), src_element_op_(src_element_op), @@ -221,7 +222,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather auto gather_offset = gather_offsets_(ordered_src_access_idx[Number{}]); - const index_t ld_offset = src_coord_.GetOffset() + gather_offset; + const IndexType ld_offset = src_coord_.GetOffset() + gather_offset; src_oob_thread_scratch_tuple_(thread_scratch_id) .template SetAsType(src_data_idx_seq, true); @@ -935,7 +936,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather DstCoord dst_coord_; const SrcElementwiseOperation src_element_op_; const DstElementwiseOperation dst_element_op_; - StaticallyIndexedArray gather_offsets_; + StaticallyIndexedArray gather_offsets_; }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp index 08fcdfd007..0170960c4b 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -43,6 +43,7 @@ template typename DstResetCoordinateAfterRunFlags, // Sequence + typename IndexType, index_t ScatterDim = 1, bool OutputScatter = true, index_t ScatterWeightIdx = 3, @@ -412,7 +413,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter enable_if_t = false> __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs, - StaticallyIndexedArray& scatter_offsets, + StaticallyIndexedArray& scatter_offsets, Number thread_scratch_id = Number{}) { OOBCheck(thread_scratch_id); @@ -421,7 +422,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter // loop over space-filling curve static_for<0, dst_num_access, 1>{}([&](auto iAccess) { auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess]; - long_index_t scatter_offset = 0; + IndexType scatter_offset = 0; if constexpr(OutputScatter) { constexpr auto iScatter = @@ -431,7 +432,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter // copy data from buf_vectors into dst_bufs static_for<0, nDst, 1>{}([&](auto i) { using dst_vector_t = typename remove_cvref_t::type; - long_index_t dst_offset = scatter_offset + (dst_coords_[i].GetOffset()); + IndexType dst_offset = scatter_offset + (dst_coords_[i].GetOffset()); const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize(); // coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], // dst_coords_[i]); @@ -490,7 +491,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter const SrcBuffers& src_bufs, const DstDescs& dst_descs, DstBuffers dst_bufs, - StaticallyIndexedArray& scatter_offsets, + StaticallyIndexedArray& scatter_offsets, StaticallyIndexedArray& scatter_weights) { RunRead(src_descs, src_bufs, scatter_weights); diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index bc80dea99f..25d3d104ef 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -225,8 +225,8 @@ struct DynamicBuffer static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! X should contain multiple T"); -#if 0 - bool constexpr use_amd_buffer_addressing = true; +#if CK_USE_AMD_BUFFER_LOAD + bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t); #else bool constexpr use_amd_buffer_addressing = false; #endif @@ -349,7 +349,7 @@ struct DynamicBuffer __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X)); #else // if(i >= 2169041600) - *c_style_pointer_cast(p_data_ + i) = x; + *c_style_pointer_cast(&p_data_[i]) = x; #endif } } @@ -380,12 +380,13 @@ struct DynamicBuffer (is_same_v, half_t> && scalar_per_x_vector % 2 == 0) || (is_same_v, bhalf_t> && scalar_per_x_vector % 2 == 0); #elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT) - bool constexpr use_amd_buffer_addressing = is_same_v, int32_t>; + bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t) && is_same_v, int32_t>; #elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT bool constexpr use_amd_buffer_addressing = + sizeof(IndexType) <= sizeof(int32_t) && ( is_same_v, float> || (is_same_v, half_t> && scalar_per_x_vector % 2 == 0) || - (is_same_v, bhalf_t> && scalar_per_x_vector % 2 == 0); + (is_same_v, bhalf_t> && scalar_per_x_vector % 2 == 0)); #else bool constexpr use_amd_buffer_addressing = false; #endif @@ -424,7 +425,7 @@ struct DynamicBuffer #if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 using scalar_t = typename scalar_type>::type; - bool constexpr use_amd_buffer_addressing = is_same_v, double>; + bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t) && is_same_v, double>; #else bool constexpr use_amd_buffer_addressing = false; #endif