From d87ddebb304c03ed162cee227010ef452bdb2b35 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Tue, 25 Feb 2025 03:06:55 +0000 Subject: [PATCH] revert back to v1 --- .../65_gemm_multiply_multiply/moe_gemm1.cpp | 11 +++++++---- ...se_gemm_pipeline_xdlops_b_preshuffle_v1.hpp | 3 +-- .../gpu/grid/gridwise_moe_gemm.hpp | 18 ++++++++++-------- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/example/65_gemm_multiply_multiply/moe_gemm1.cpp b/example/65_gemm_multiply_multiply/moe_gemm1.cpp index b1dd958e0e..caffe129ef 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1.cpp @@ -139,7 +139,7 @@ static constexpr ck::index_t BLOCKSIZE = 256; static constexpr ck::index_t NPerBlock = 128; static constexpr ck::index_t MNPerXDL = 32; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); -static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t Nswizzle = true; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); static constexpr ck::index_t EVec = 16 / sizeof(EDataType); @@ -175,9 +175,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| 4, 1, S<1, 32, 1, 8>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, Nswizzle, true, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>; // kernel 2: 128->32x128x128 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; +// DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, // clang-format on @@ -257,8 +258,10 @@ int main(int argc, char* argv[]) Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); Tensor max_token_id(HostTensorDescriptor({1 + sorted_tile_num})); // max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2,2, 2, 2, 2, 2,1,0,0,0}; - max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; - int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; + // int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + max_token_id.mData = {valid_size, 0, 2, 4, 6, 8, 10, 12, 14, 16}; + int eids[] = {0, 0,1, 1, 2,2, 3,3, 4,4, 5, 5, 6, 6, 7,7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} for (int i = 0; i < sorted_tile_num; i++) { expert_ids.mData[i] = eids[i]; } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp index c4c1fa2959..4bf6046e07 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp @@ -141,6 +141,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1>{}; + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); const float *p_sorted_weights_0 = p_ds_grid[I0]; - static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS StaticallyIndexedArray scatter_offsets; //= p_sorted_token_ids[c_token_pos]; StaticallyIndexedArray scatter_weights; //= for topk // too hack here, 2 specific for topk weights, fixme @@ -1568,8 +1568,6 @@ struct GridwiseMoeGemm auto dstidx = sfc_cde_block.GetIndex(access_id); const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); - // if(threadIdx.x==0 && blockIdx.x==0) - // printf("cidx %d %d tpos %d\n", dstidx(I0), dstidx(I1), c_token_pos); static_for<0, EMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; index_t token_offset = fused_token & 0xffffff; @@ -1581,13 +1579,13 @@ struct GridwiseMoeGemm const float *p_sorted_weights_2 = p_ds_grid[I2]; weight = weight * p_sorted_weights_2[c_token_pos + m0]; } + + // if(threadIdx.x % 8 == 0 && blockIdx.x == 0) + // printf("init off tid %d access %d tpos %d m %d off %d wei %f\n", threadIdx.x, dstidx(I1), c_token_pos, m0(), token_offset, weight); scatter_offsets(m0) = token_offset * problem.N; scatter_weights(m0) = weight; - // if(threadIdx.x % 16 == 0) - // printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0)); }); - // make sure it's safe to write to LDS block_sync_lds(); // each thread write its data from VGPR to LDS @@ -1605,7 +1603,11 @@ struct GridwiseMoeGemm c_ds_desc_refs, c_ds_buf_refs, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(c_grid_buf)); + tie(c_grid_buf), + scatter_offsets, + scatter_weights + ); + if constexpr(access_id < num_access - 1) { constexpr auto cde_lds_and_global_step =