diff --git a/example/65_gemm_multiply_multiply/moe_gemm_fp16.cpp b/example/65_gemm_multiply_multiply/moe_gemm_fp16.cpp index 8cd93e95e7..6a2f158ba9 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm_fp16.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm_fp16.cpp @@ -131,13 +131,18 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, F16>; // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, F16>; < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, - AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, - 32, 128, 128, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + //threadnum, mblock, nblock, kblock + 256, 32, 128, 128, + // ak1, bk1 8, 8, + // mn_perxdl 32, 32, + // mn_xdlperwave 1, 1, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + // a,b: loadtranfer cluster, cluster order, srcorder, srcpervec, dstpervec, lds_extra + S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| @@ -162,7 +167,7 @@ int main(int argc, char* argv[]) ck::index_t N = 6144; ck::index_t K = 8192; ck::index_t experts = 8; - ck::index_t sorted_tile_num = 8; + ck::index_t sorted_tile_num = 1; ck::index_t sorted_tile_size = 32; ck::index_t SORTED_SIZE = sorted_tile_num * sorted_tile_size; ck::index_t tokens = 32; diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_mod8.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_mod8.hpp index f9a81210da..d452ed2e3c 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_mod8.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_mod8.hpp @@ -45,11 +45,12 @@ template struct ThreadGroupTensorSliceTransfer_v4r1_mod8 { + static constexpr auto I0 = Number<0>{}; static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; static constexpr index_t gather_num = thread_slice_lengths.At(Number{}); + static constexpr index_t mod_num = ThreadClusterLengths{}.At(I0); // Dirty HACK FELIX, TODO fix using Index = MultiIndex; - // using GatherIndex = MultiIndex; __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1_mod8( const SrcDesc& src_desc, @@ -86,7 +87,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { const auto src_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( - make_multi_index(ThreadGroup::GetThreadId() % 8)); + make_multi_index(ThreadGroup::GetThreadId() % mod_num)); threadwise_transfer_.SetSrcSliceOrigin(src_desc, src_block_slice_origin + src_thread_cluster_idx * thread_slice_lengths); @@ -104,7 +105,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_mod8 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( - make_multi_index(ThreadGroup::GetThreadId() % 8)); + make_multi_index(ThreadGroup::GetThreadId() % mod_num)); const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; threadwise_transfer_.SetSrcSliceOrigin(src_desc, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index 1c8cbbd193..3902ecd283 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -1127,16 +1127,18 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[block_m_id]); // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto MLoadThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto KLoadThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0) * ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); - constexpr auto MLoadRepeats = MPerBlock / MLoadThreads; - static_assert(MLoadRepeats == 1, "only support 1 line per thread now!"); - const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / KLoadThreads; - StaticallyIndexedArray token_offsets; //= p_sorted_token_ids[token_pos]; - static_for<0, MLoadRepeats, 1>{}([&](auto m0) { - token_offsets(m0) = p_sorted_token_ids[token_pos + MLoadThreads * m0] * problem.K; + constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); + constexpr auto AKThreads = AK0Threads * AK1Threads; + constexpr auto AMRepeats = MPerBlock / AMThreads; + // static_assert(MLoadRepeats == 1, "only support 1 line per thread now!"); + const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; + StaticallyIndexedArray gather_offsets; //= p_sorted_token_ids[token_pos]; + static_for<0, AMRepeats, 1>{}([&](auto m0) { + gather_offsets(m0) = p_sorted_token_ids[token_pos + m0] * problem.K; + printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0)); }); - // printf("threadIdx.x %d off %d\n", threadIdx.x, token_offsets(I0)); const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K); @@ -1194,7 +1196,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle a_block_desc_ak0_m_ak1, make_multi_index(0, 0, 0), ck::tensor_operation::element_wise::PassThrough{}, - token_offsets); + gather_offsets); // Thread-wise copy // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack @@ -1222,7 +1224,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle auto a_block_buf = make_dynamic_buffer( static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto a_block_slice_copy_step = make_multi_index(AK0Threads, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); // Blockwise GEMM pipeline diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp index df8e1fd7aa..5270c4bb32 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp @@ -178,15 +178,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather // maintain a container record is_src_valid, waiting for RunWrite use. const index_t ld_offset = src_coord_.GetOffset() + gather_offset; - const bool is_src_valid = ld_offset < src_desc.GetElementSpaceSize() * sizeof(SrcData);//hack felix, todo use coord + const bool is_src_valid = ld_offset < src_desc.GetElementSpaceSize();//hack felix, todo use coord //coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_) && (gather_offset < 32*512); src_oob_thread_scratch_tuple_(thread_scratch_id) .template SetAsType(src_data_idx_seq, is_src_valid); using src_vector_type = vector_type_maker_t; using src_vector_t = typename src_vector_type::type; - // if(blockIdx.x+blockIdx.y==0) - // printf("tid %d off %d %d\n", threadIdx.x, src_coord_.GetOffset(), gather_offset ); + if(threadIdx.x==0) + printf("use tid %d num %d off %d %d\n", threadIdx.x, ordered_src_access_idx[Number{}](), src_coord_.GetOffset(), gather_offset ); auto src_vector_container = src_vector_type{src_buf.template Get(ld_offset, true)}; @@ -235,7 +235,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather // printf("tid %d %f\n",threadIdx.x, type_convert(src_vector_container.template AsType()[idx])); // }); // } - constexpr auto move_on_dim = [&]() constexpr + auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; @@ -246,15 +246,20 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather move_on_dim_(i) &= ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; }); + move_on_dim_(i) &= i.value != ordered_gather_dim; + + // if(threadIdx.x==0) + // printf("i %d %d ordered_gather_dim %d\n", i.value, move_on_dim_(i), ordered_gather_dim); }); return move_on_dim_; } (); - // move src coord static_for<0, nDim, 1>{}([&](auto i) { - if constexpr(move_on_dim[i]) + if(threadIdx.x==0) + printf("use tid %d ori cord: %d i %d mov %d\n", threadIdx.x, src_coord_.GetOffset(), i.value, move_on_dim[i]); + if (move_on_dim[i]) { if constexpr(forward_sweep[i]) { @@ -267,7 +272,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); } } + if(threadIdx.x==0) + printf("use tid %d moved cord: %d\n", threadIdx.x, src_coord_.GetOffset()); }); + }); // move src coordinate back to slice origin (or not) diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp index bdf7821c20..30e820b45c 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp @@ -423,14 +423,14 @@ struct ThreadwiseTensorSliceTransfer_v7r3 dst_coords_[i].GetOffset(), is_dst_valid, dst_vectors[i].template AsType()[I0]); - if(1) { - static_for<0, DstScalarPerVector, 1>{}([&](auto idx) { - using DstData = remove_cvref_t>; - using print_vec_t = typename vector_type::type; - // printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_coords_[i].GetOffset(), is_dst_valid, - // type_convert(dst_vectors[i].template AsType()[idx])); - }); - } + // if(1) { + // static_for<0, DstScalarPerVector, 1>{}([&](auto idx) { + // using DstData = remove_cvref_t>; + // using print_vec_t = typename vector_type::type; + // printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_coords_[i].GetOffset(), is_dst_valid, + // type_convert(dst_vectors[i].template AsType()[idx])); + // }); + // } }); // move coordinate