From dd42d8e8fa09f54f96a41742e192021906e448d7 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Mon, 10 Mar 2025 09:56:15 +0000 Subject: [PATCH] support index type --- .../moe_pk_i4_gemm1.cpp | 2 +- .../moe_pk_i4_gemm2.cpp | 2 +- ..._group_tensor_slice_transfer_v4r1_mod8.hpp | 44 ++--- ...oup_tensor_slice_transfer_v7r3_scatter.hpp | 74 +++++--- .../gpu/device/impl/device_moe_gemm.hpp | 111 +++++------ .../gpu/grid/gridwise_moe_gemm.hpp | 178 +++++++++--------- ...wise_tensor_slice_transfer_v3r1_gather.hpp | 147 +++++---------- ...ise_tensor_slice_transfer_v7r3_scatter.hpp | 104 +++++----- 8 files changed, 322 insertions(+), 340 deletions(-) diff --git a/example/65_gemm_multiply_multiply/moe_pk_i4_gemm1.cpp b/example/65_gemm_multiply_multiply/moe_pk_i4_gemm1.cpp index 1dff53bbb8..1a1ee30767 100644 --- a/example/65_gemm_multiply_multiply/moe_pk_i4_gemm1.cpp +++ b/example/65_gemm_multiply_multiply/moe_pk_i4_gemm1.cpp @@ -182,7 +182,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, S<4, 1, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ck::index_t, Nswizzle, true, A0DataType>; // clang-format on #endif diff --git a/example/65_gemm_multiply_multiply/moe_pk_i4_gemm2.cpp b/example/65_gemm_multiply_multiply/moe_pk_i4_gemm2.cpp index 11aa7ab502..d2b43f4099 100644 --- a/example/65_gemm_multiply_multiply/moe_pk_i4_gemm2.cpp +++ b/example/65_gemm_multiply_multiply/moe_pk_i4_gemm2.cpp @@ -162,7 +162,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, 1, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1,ck::index_t, false, false, A0DataType>; // clang-format on int main(int argc, char* argv[]) diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_mod8.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_mod8.hpp index d452ed2e3c..29524013dc 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_mod8.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_mod8.hpp @@ -41,7 +41,8 @@ template struct ThreadGroupTensorSliceTransfer_v4r1_mod8 { @@ -59,7 +60,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8 const DstDesc& dst_desc, const Index& dst_block_slice_origin, const DstElementwiseOperation& dst_element_op, - const StaticallyIndexedArray &gather_offsets) + const StaticallyIndexedArray& gather_offsets) : threadwise_transfer_(src_desc, make_zero_multi_index(), src_element_op, @@ -178,25 +179,26 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8 using ThreadwiseTransfer = ThreadwiseTensorSliceTransfer_v3r1_gather; + SrcElementwiseOperation, + DstElementwiseOperation, + DstInMemOp, + SrcData, + DstData, + SrcDesc, + DstDesc, + SrcDimAccessOrder, + DstDimAccessOrder, + SrcVectorDim, + DstVectorDim, + SrcScalarPerVector, + DstScalarPerVector, + SrcScalarStrideInVector, + DstScalarStrideInVector, + ThreadTransferSrcResetCoordinateAfterRun, + ThreadTransferDstResetCoordinateAfterRun, + ScatterIdxType, + GatherDim, + NumThreadScratch>; ThreadwiseTransfer threadwise_transfer_; }; diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp index ac4c6f678e..1d5ecaa58e 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp @@ -42,8 +42,9 @@ template struct ThreadGroupTensorSliceTransfer_v7r3_scatter @@ -51,14 +52,15 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter static constexpr index_t nDim = remove_cvref_t>::GetNumOfDimension(); - static constexpr index_t mod_num = ThreadClusterLengths{}.At( Number<3>{}) ; // Dirty HACK FELIX, TODO fix + static constexpr index_t mod_num = + ThreadClusterLengths{}.At(Number<3>{}); // Dirty HACK FELIX, TODO fix static constexpr index_t nSrc = remove_cvref_t::Size(); static constexpr index_t nDst = remove_cvref_t::Size(); using Index = MultiIndex; static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; - static constexpr index_t scatter_num = thread_slice_lengths.At(Number{}); + static constexpr index_t scatter_num = thread_slice_lengths.At(Number{}); __device__ constexpr ThreadGroupTensorSliceTransfer_v7r3_scatter( const SrcDescs& src_descs, @@ -108,13 +110,20 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter const auto src_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( make_multi_index(ThreadGroup::GetThreadId())); const auto src_thread_slice_origins = generate_tuple( - [&](auto i) { return src_block_slice_origins[i] + src_thread_cluster_idx * thread_slice_lengths; }, + [&](auto i) { + return src_block_slice_origins[i] + + src_thread_cluster_idx * thread_slice_lengths; + }, Number{}); const auto dst_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( - make_multi_index( OutputScatter ? ThreadGroup::GetThreadId() % mod_num : ThreadGroup::GetThreadId())); + make_multi_index(OutputScatter ? ThreadGroup::GetThreadId() % mod_num + : ThreadGroup::GetThreadId())); const auto dst_thread_slice_origins = generate_tuple( - [&](auto i) { return dst_block_slice_origins[i] + dst_thread_cluster_idx * thread_slice_lengths; }, + [&](auto i) { + return dst_block_slice_origins[i] + + dst_thread_cluster_idx * thread_slice_lengths; + }, Number{}); threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins); @@ -125,7 +134,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter template __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs, - StaticallyIndexedArray &scatter_weights, + StaticallyIndexedArray& scatter_weights, Number thread_scratch_id = Number{}) { if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or @@ -141,16 +150,18 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter template __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs, - StaticallyIndexedArray &scatter_offsets, + StaticallyIndexedArray& scatter_offsets, Number thread_scratch_id = Number{}) { if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { if constexpr(is_detected::value) - threadwise_transfer_.RunWrite(dst_descs, dst_bufs, scatter_offsets, thread_scratch_id); + threadwise_transfer_.RunWrite( + dst_descs, dst_bufs, scatter_offsets, thread_scratch_id); else - threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), scatter_offsets, thread_scratch_id); + threadwise_transfer_.RunWrite( + dst_descs, tie(dst_bufs), scatter_offsets, thread_scratch_id); } } @@ -159,8 +170,8 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter const SrcBuffers& src_bufs, const DstDescs& dst_descs, DstBuffers dst_bufs, - StaticallyIndexedArray &scatter_offsets, - StaticallyIndexedArray &scatter_weights) + StaticallyIndexedArray& scatter_offsets, + StaticallyIndexedArray& scatter_weights) { RunRead(src_descs, src_bufs, scatter_weights); RunWrite(dst_descs, dst_bufs, scatter_offsets); @@ -206,24 +217,25 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter using ThreadwiseTransfer = ThreadwiseTensorSliceTransfer_v7r3_scatter; + DstDatas, + SrcDescs, + DstDescs, + ElementwiseOperation, + DstInMemOps, + decltype(thread_slice_lengths), + SrcDimAccessOrder, + DstDimAccessOrder, + SrcVectorDim, + DstVectorDim, + SrcScalarPerVectors, + DstScalarPerVector, + ThreadTransferSrcResetCoordinateAfterRunFlags, + ThreadTransferDstResetCoordinateAfterRunFlags, + ScatterIdxType, + ScatterDim, + OutputScatter, + ScatterWeightIdx, + NumThreadScratch>; ThreadwiseTransfer threadwise_transfer_; }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp index f6db7f5b6e..9a3778c047 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp @@ -66,8 +66,9 @@ template { static constexpr index_t NumDTensor = DsDataType::Size(); - using GridwiseGemm = - GridwiseMoeGemm< - ALayout, - BLayout, - DsLayout, - CLayout, - ADataType, - BDataType, - GemmAccDataType, - CShuffleDataType, - DsDataType, - CDataType, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - GemmSpec, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEShuffleBlockTransferScalarPerVectors, - BlkGemmPipeSched, - BlkGemmPipelineVer, - NSwizzle, - ComputeTypeA, - ComputeTypeB, - LDSTypeA, - LDSTypeB>; + using GridwiseGemm = + GridwiseMoeGemm; using Argument = typename GridwiseGemm::Argument; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 0d013e4929..03fdcf5574 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -146,7 +146,8 @@ template = max_token_id || token0 >= problem.NumTokens) return; - StaticallyIndexedArray gather_offsets; //= p_sorted_token_ids[token_pos]; + StaticallyIndexedArray gather_offsets; static_for<0, AMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[token_pos + m0]; index_t token_offset = fused_token & 0xffffff; @@ -1243,37 +1244,37 @@ struct GridwiseMoeGemm // dummy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); // A matrix blockwise copy - auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1_mod8, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - LDSTypeA, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - 1, - BlockwiseGemmPipe::GlobalBufferNum>( - a_grid_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - a_element_op, - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}, - gather_offsets); + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + ScatterIdxType, + 1, + BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}, + gather_offsets); // Thread-wise copy // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack @@ -1523,16 +1524,16 @@ struct GridwiseMoeGemm Sequence, uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - 1, //ScatterDim - true, //OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - > - {c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + ScatterIdxType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); @@ -1567,7 +1568,8 @@ struct GridwiseMoeGemm const float *p_sorted_weights_0 = p_ds_grid[I0]; static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS - StaticallyIndexedArray scatter_offsets; //= p_sorted_token_ids[c_token_pos]; + StaticallyIndexedArray + scatter_offsets; //= p_sorted_token_ids[c_token_pos]; StaticallyIndexedArray scatter_weights; //= for topk // too hack here, 2 specific for topk weights, fixme // const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24; @@ -1717,7 +1719,8 @@ struct GridwiseMoeGemm if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id || token0 >= problem.NumTokens) return; - StaticallyIndexedArray gather_offsets; //= p_sorted_token_ids[token_pos]; + StaticallyIndexedArray + gather_offsets; //= p_sorted_token_ids[token_pos]; static_for<0, AMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[token_pos + m0]; index_t token_offset = fused_token & 0xffffff; @@ -1749,37 +1752,37 @@ struct GridwiseMoeGemm // dummy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); // A matrix blockwise copy - auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1_mod8, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - LDSTypeA, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - 1, - 2>( - a_grid_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - a_element_op, - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}, - gather_offsets); + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + ScatterIdxType, + 1, + 2>(a_grid_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}, + gather_offsets); // Thread-wise copy // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack @@ -2035,16 +2038,16 @@ struct GridwiseMoeGemm Sequence, uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - 1, //ScatterDim - true, //OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - > - {c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + ScatterIdxType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); @@ -2079,7 +2082,8 @@ struct GridwiseMoeGemm const float *p_sorted_weights_0 = p_ds_grid[I0]; static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS - StaticallyIndexedArray scatter_offsets; //= p_sorted_token_ids[c_token_pos]; + StaticallyIndexedArray + scatter_offsets; //= p_sorted_token_ids[c_token_pos]; StaticallyIndexedArray scatter_weights; //= for topk // too hack here, 2 specific for topk weights, fixme // const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24; diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp index ffd1932113..0d3dab370a 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp @@ -31,8 +31,8 @@ template struct ThreadwiseTensorSliceTransfer_v3r1_gather { @@ -54,31 +55,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - static constexpr auto I8 = Number<8>{}; - static constexpr auto I10 = Number<10>{}; - static constexpr auto I12 = Number<12>{}; - static constexpr auto I13 = Number<13>{}; - static constexpr auto I14 = Number<14>{}; - static constexpr auto I16 = Number<16>{}; - - static constexpr index_t PackedSize = []() { - if constexpr(is_same_v, pk_i4_t>) - return 2; - else - return 1; - }(); - - static constexpr auto SrcScalarPerVector = Number{}; - static constexpr auto DstScalarPerVector = Number{}; - + static constexpr auto I0 = Number<0>{}; static constexpr index_t gather_num = SliceLengths{}.At(Number{}); __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1_gather( @@ -88,29 +65,26 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather const DstDesc& dst_desc, const Index& dst_slice_origin, const DstElementwiseOperation& dst_element_op, - const StaticallyIndexedArray &gather_offsets) + const StaticallyIndexedArray& gather_offsets) : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), src_element_op_(src_element_op), dst_element_op_(dst_element_op), gather_offsets_(gather_offsets) { - if constexpr(is_same_v, pk_i4_t>) - { - static_assert(is_same_v, remove_cvref_t>, - "SrcData != DstData"); - - static_assert( - SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0, - "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type"); - - static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose"); - } } __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) { - src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + + auto adjusted_origin_idx = [&]() { + Index idx; + static_for<0, nDim, 1>{}([&](auto i) { + idx(i) = i.value == GatherDim ? 0 : src_slice_origin_idx[Number{}]; + }); + return idx; + }(); + src_coord_ = make_tensor_coordinate(src_desc, adjusted_origin_idx); } __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) @@ -134,15 +108,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0, + static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0, "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector"); constexpr auto src_dim_access_order = SrcDimAccessOrder{}; - constexpr auto ordered_gather_dim = src_dim_access_order[GatherDim]; + constexpr auto ordered_gather_dim = src_dim_access_order[GatherDim]; constexpr auto ordered_src_access_lengths = container_reorder_given_new2old(src_access_lengths, src_dim_access_order); @@ -210,19 +183,23 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather constexpr auto src_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); - auto gather_offset = gather_offsets_(ordered_src_access_idx[Number{}]); + auto gather_offset = + gather_offsets_(ordered_src_access_idx[Number{}]); // maintain a container record is_src_valid, waiting for RunWrite use. const index_t ld_offset = src_coord_.GetOffset() + gather_offset; - const bool is_src_valid = ld_offset < src_desc.GetElementSpaceSize();//hack felix, todo use coord - //coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_) && (gather_offset < 32*512); + const bool is_src_valid = + ld_offset < + src_desc + .GetElementSpaceSize(); // hack felix, todo use coord + // coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, + // src_coord_) && (gather_offset < 32*512); src_oob_thread_scratch_tuple_(thread_scratch_id) .template SetAsType(src_data_idx_seq, is_src_valid); using src_vector_type = vector_type_maker_t; using src_vector_t = typename src_vector_type::type; - // if(threadIdx.x==0) - // printf("use tid %d num %d off %d %d\n", threadIdx.x, ordered_src_access_idx[Number{}](), src_coord_.GetOffset(), gather_offset ); + auto src_vector_container = src_vector_type{src_buf.template Get(ld_offset, true)}; @@ -236,22 +213,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather if constexpr(decltype(src_element_op_)::is_pack8_invocable) return math::min(8, SrcScalarPerVector); } - else if constexpr(is_detected::value) + if constexpr(is_detected::value) { if constexpr(decltype(src_element_op_)::is_pack4_invocable) return math::min(4, SrcScalarPerVector); } - else if constexpr(is_detected::value) + if constexpr(is_detected::value) { if constexpr(decltype(src_element_op_)::is_pack2_invocable) return math::min(2, SrcScalarPerVector); } - else - { - return 1; - } + return 1; }; constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); @@ -269,14 +241,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather src_thread_scratch_tuple_(thread_scratch_id) .template SetAsType(src_data_idx_seq, op_r_v.template AsType()[I0]); - - // if(1) { - // using print_vec_t = typename vector_type::type; - // static_for<0, SrcScalarPerVector, 1>{}([&](auto idx) { - // printf("tid %d %f\n",threadIdx.x, type_convert(src_vector_container.template AsType()[idx])); - // }); - // } - constexpr auto move_on_dim = [&]() constexpr + + auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; @@ -288,9 +254,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; }); move_on_dim_(i) &= i.value != ordered_gather_dim; - - // if(threadIdx.x==0) - // printf("i %d %d ordered_gather_dim %d\n", i.value, move_on_dim_(i), ordered_gather_dim); }); return move_on_dim_; @@ -298,9 +261,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather (); // move src coord static_for<0, nDim, 1>{}([&](auto i) { - // if(threadIdx.x==0) - // printf("use tid %d ori cord: %d i %d mov %d\n", threadIdx.x, src_coord_.GetOffset(), i.value, move_on_dim[i]); - if constexpr(move_on_dim[i]) + if(move_on_dim[i]) { if constexpr(forward_sweep[i]) { @@ -313,10 +274,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); } } - // if(threadIdx.x==0) - // printf("use tid %d moved cord: %d\n", threadIdx.x, src_coord_.GetOffset()); }); - }); // move src coordinate back to slice origin (or not) @@ -349,7 +307,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather // OOB Check constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; @@ -420,8 +378,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather (is_same>::value && SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) { - static_assert(!is_same_v, pk_i4_t>, - "in-register transpose is not supported for pk_i4_t"); // each transpose does // DstScalarPerVector # of src vectors in src_thread_scratch_ // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ @@ -482,12 +438,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather } else { - constexpr auto packed_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access; - - static_ford{}([&](auto idx) { + static_ford{}([&](auto idx) { dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx]; }); } @@ -515,7 +466,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather // src scalar per access on each dim // TODO: don't use this constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; @@ -609,16 +560,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather // copy data from dst_vector_container to dst_buf dst_buf.template Set( - dst_coord_.GetOffset() / PackedSize, + dst_coord_.GetOffset(), is_dst_valid, dst_vector_container.template AsType()[I0]); - - // if(1) { - // using print_vec_t = typename vector_type::type; - // static_for<0, DstScalarPerVector, 1>{}([&](auto idx) { - // printf("tid %d off %d valid %d val %f\n",threadIdx.x, dst_coord_.GetOffset(), is_dst_valid, type_convert(dst_vector_container.template AsType()[idx])); - // }); - // } + constexpr auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; @@ -669,7 +614,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; @@ -714,7 +659,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather constexpr auto reset_src_data_step = [&]() { Index reset_src_data_step_; - static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = i.value == GatherDim ? 0 : -src_data_idx[i]; }); + static_for<0, nDim, 1>{}([&](auto i) { + reset_src_data_step_(i) = i.value == GatherDim ? 0 : -src_data_idx[i]; + }); return reset_src_data_step_; }(); @@ -726,7 +673,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; @@ -811,7 +758,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather __device__ static constexpr auto GetSrcThreadScratchDescriptor() { constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; @@ -860,7 +807,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather __device__ static constexpr auto GetSrcOOBThreadScratchDescriptor() { constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; @@ -871,7 +818,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather { // 1st stage of transforms constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; @@ -951,7 +898,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather DstCoord dst_coord_; const SrcElementwiseOperation src_element_op_; const DstElementwiseOperation dst_element_op_; - StaticallyIndexedArray gather_offsets_; + StaticallyIndexedArray gather_offsets_; }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp index fb1ea640ff..a91fcf91b0 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -43,8 +43,9 @@ template typename DstResetCoordinateAfterRunFlags, // Sequence - index_t ScatterDim = 1, - bool OutputScatter = true, + typename ScatterIdxType = index_t, + index_t ScatterDim = 1, + bool OutputScatter = true, index_t ScatterWeightIdx = 3, index_t NumThreadScratch = 1> struct ThreadwiseTensorSliceTransfer_v7r3_scatter @@ -61,8 +62,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter static constexpr index_t nSrc = SrcDescs::Size(); static constexpr index_t nDst = DstDescs::Size(); - using Index = MultiIndex; - static constexpr index_t scatter_num = SliceLengths{}.At(Number{}); + using Index = MultiIndex; + static constexpr index_t scatter_num = SliceLengths{}.At(Number{}); // return a tuple of coordiantes for a tuple of tensor template {}([&](auto i) { dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]); - // printf("tid %d origin %d %d %d %d off %d\n", threadIdx.x, dst_slice_origin_idxs[i][I0], dst_slice_origin_idxs[i][I1], dst_slice_origin_idxs[i][I2], dst_slice_origin_idxs[i][I3], dst_coords_(i).GetOffset()); + // printf("tid %d origin %d %d %d %d off %d\n", threadIdx.x, + // dst_slice_origin_idxs[i][I0], dst_slice_origin_idxs[i][I1], + // dst_slice_origin_idxs[i][I2], dst_slice_origin_idxs[i][I3], + // dst_coords_(i).GetOffset()); }); } @@ -154,7 +158,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter enable_if_t = false> __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs, - StaticallyIndexedArray &scatter_weights, + StaticallyIndexedArray& scatter_weights, Number thread_scratch_id = Number{}) { // loop over space-filling curve @@ -173,14 +177,19 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter src_coords_[i]); oob_val = oob_val & is_src_valid; - if (i.value == ScatterWeightIdx) + if(i.value == ScatterWeightIdx) { - static_assert(SrcScalarPerVectors{}[Number{}] == 1, "scatter weight dim, should only one vec"); - constexpr auto iScatter = SrcSpaceFillingCurve::GetIndex(iAccess)(Number{}); + static_assert(SrcScalarPerVectors{}[Number{}] == 1, + "scatter weight dim, should only one vec"); + constexpr auto iScatter = + SrcSpaceFillingCurve::GetIndex(iAccess)(Number{}); // if(threadIdx.x % 8 ==0 ) - // printf("bid %d tid %d srcid %d sv %f\n", blockIdx.y, threadIdx.x, i.value, scatter_weights(Number{})); - static_for<0, SrcScalarPerVector, 1>{}( - [&](auto j) { src_vectors(i).template AsType()(j) = scatter_weights(Number{}); }); + // printf("bid %d tid %d srcid %d sv %f\n", blockIdx.y, threadIdx.x, i.value, + // scatter_weights(Number{})); + static_for<0, SrcScalarPerVector, 1>{}([&](auto j) { + src_vectors(i).template AsType()(j) = + scatter_weights(Number{}); + }); } else if constexpr(SrcScalarPerVectors{}[i] == 1) { @@ -189,7 +198,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter const auto tmp = src_bufs[i].template Get(src_coords_[i].GetOffset(), true); // if(threadIdx.x % 8 ==0 ) - // printf("bid %d tid %d srcid %d off %d v %f\n", blockIdx.y, threadIdx.x, i.value, src_coords_[i].GetOffset(), tmp); + // printf("bid %d tid %d srcid %d off %d v %f\n", blockIdx.y, threadIdx.x, + // i.value, src_coords_[i].GetOffset(), tmp); static_for<0, SrcScalarPerVector, 1>{}( [&](auto j) { src_vectors(i).template AsType()(j) = tmp; }); } @@ -415,7 +425,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter enable_if_t = false> __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs, - StaticallyIndexedArray &scatter_offsets, + StaticallyIndexedArray& scatter_offsets, Number thread_scratch_id = Number{}) { OOBCheck(thread_scratch_id); @@ -423,36 +433,37 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter // loop over space-filling curve static_for<0, dst_num_access, 1>{}([&](auto iAccess) { - auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess]; + auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess]; auto scatter_offset = 0; - if constexpr (OutputScatter) + if constexpr(OutputScatter) { - constexpr auto iScatter = DstSpaceFillingCurve::GetIndex(iAccess)(Number{}); + constexpr auto iScatter = + DstSpaceFillingCurve::GetIndex(iAccess)(Number{}); scatter_offset = scatter_offsets(Number{}); } // copy data from buf_vectors into dst_bufs static_for<0, nDst, 1>{}([&](auto i) { - using dst_vector_t = typename remove_cvref_t::type; - auto dst_offset = scatter_offset + dst_coords_[i].GetOffset(); + using dst_vector_t = typename remove_cvref_t::type; + auto dst_offset = scatter_offset + dst_coords_[i].GetOffset(); const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize(); - // coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], - // dst_coords_[i]); + // coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], + // dst_coords_[i]); constexpr InMemoryDataOperationEnum DstInMemOp = static_cast(DstInMemOps::At(i.value)); // if(threadIdx.x==0) - // printf("use tid %d off %d %d\n", threadIdx.x, dst_coords_[i].GetOffset(), scatter_offset ); + // printf("use tid %d off %d %d\n", threadIdx.x, dst_coords_[i].GetOffset(), + // scatter_offset ); dst_bufs(i).template Update( - dst_offset, - is_dst_valid, - dst_vectors[i].template AsType()[I0]); + dst_offset, is_dst_valid, dst_vectors[i].template AsType()[I0]); // if(threadIdx.x%8 ==0 && blockIdx.x==0) { // static_for<0, 1, 1>{}([&](auto idx) { // using DstData = remove_cvref_t>; // using print_vec_t = typename vector_type::type; - // printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_offset, is_dst_valid, - // type_convert(dst_vectors[i].template AsType()[idx])); + // printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_offset, + // is_dst_valid, type_convert(dst_vectors[i].template + // AsType()[idx])); // }); // } }); @@ -468,18 +479,20 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter static_for<0, nDim, 1>{}([&](auto i) { step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : forward_step[i]; - + // if(threadIdx.x==0) - // printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i), ordered_gather_dim); + // printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i), + // ordered_gather_dim); }); return step_; } (); static_for<0, nDst, 1>{}([&](auto i) { - move_tensor_coordinate(dst_descs[i], - dst_coords_(i), - make_tensor_coordinate_step(dst_descs[i], forward_step_scatter)); + move_tensor_coordinate( + dst_descs[i], + dst_coords_(i), + make_tensor_coordinate_step(dst_descs[i], forward_step_scatter)); }); } }); @@ -508,8 +521,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter const SrcBuffers& src_bufs, const DstDescs& dst_descs, DstBuffers dst_bufs, - StaticallyIndexedArray &scatter_offsets, - StaticallyIndexedArray &scatter_weights) + StaticallyIndexedArray& scatter_offsets, + StaticallyIndexedArray& scatter_weights) { RunRead(src_descs, src_bufs, scatter_weights); RunWrite(dst_descs, dst_bufs, scatter_offsets); @@ -535,15 +548,18 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter } else { - constexpr auto reset_step = DstSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + constexpr auto reset_step = + DstSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); auto reset_step_scatter = [&]() constexpr { Index step_; static_for<0, nDim, 1>{}([&](auto i) { - step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : reset_step[Number{}]; - + step_(i) = + (i.value == ScatterDim && OutputScatter) ? 0 : reset_step[Number{}]; + // if(threadIdx.x==0) - // printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i), ordered_gather_dim); + // printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i), + // ordered_gather_dim); }); return step_; @@ -683,18 +699,18 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter ? dst_slice_origin_step_idx : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); - auto adjusted_step_idx_scatter = [&]() - { + auto adjusted_step_idx_scatter = [&]() { Index step_; static_for<0, nDim, 1>{}([&](auto i) { - step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : adjusted_step_idx[Number{}]; + step_(i) = + (i.value == ScatterDim && OutputScatter) ? 0 : adjusted_step_idx[Number{}]; }); return step_; - } - (); + }(); // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx_scatter); + const auto adjusted_step = + make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx_scatter); move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step); }