diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index 3f93733b48..f654550417 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -16,3 +16,7 @@ foreach(gpu IN LISTS GPU_TARGETS) set(target 1) endif() endforeach() +set(EXAMPLE_COMPILE_OPTIONS) +list(APPEND EXAMPLE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) +list(APPEND EXAMPLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --schedmodel=0 -mllvm -misched=gcn-iterative-max-occupancy-experimental") +target_compile_options(example_moe_gemm2_xdl_fp8_blockscale PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) \ No newline at end of file diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp index 889035b3ed..bdb7b4520a 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp @@ -158,11 +158,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< MPerBlock, 128, 128, 16, 16, 32, 32, - 2, 2, + 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 2, 1, S<1, 8, 1, 32>, S<2, 1, 1, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>; + 1, 1, S<1, 32, 1, 8>, S<2, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, false, false, A0DataType>; #endif // clang-format on diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp index 5d95d2f76e..9902128bdc 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp @@ -746,25 +746,21 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); __builtin_amdgcn_sched_barrier(0); - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_scale_thread_copy.Run(a_scale_grid_desc, - a_scale_grid_buf, - a_scale_thread_desc, - make_tuple(m0, I0), - a_scale_thread_buf); - a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, - a_scale_thread_copy_step.At(Number<0>{})); - }); + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_buf); if constexpr(NumKBlockPerScale == 1) { a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, - a_scale_thread_copy_step.At(Number<2>{})); + a_scale_thread_copy_step.At(Number<1>{})); } else { a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, - a_scale_thread_copy_step.At(Number<1>{})); + a_scale_thread_copy_step.At(Number<0>{})); } b_scale_thread_copy.Run(b_scale_grid_desc, @@ -786,25 +782,21 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_scale_thread_copy.Run(a_scale_grid_desc, - a_scale_grid_buf, - a_scale_thread_desc, - make_tuple(m0, I0), - a_scale_thread_buf); - a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, - a_scale_thread_copy_step.At(Number<0>{})); - }); + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_buf); if constexpr(NumKBlockPerScale == 1) { a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, - a_scale_thread_copy_step.At(Number<2>{})); + a_scale_thread_copy_step.At(Number<1>{})); } else { a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, - a_scale_thread_copy_step.At(Number<1>{})); + a_scale_thread_copy_step.At(Number<0>{})); } b_scale_thread_copy.Run(b_scale_grid_desc, @@ -970,25 +962,22 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0]; }); - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_scale_thread_copy.Run(a_scale_grid_desc, - a_scale_grid_buf, - a_scale_thread_desc, - make_tuple(m0, I0), - a_scale_thread_buf); - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); - }); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_buf); if constexpr(NumKBlockPerScale == 1) { a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{})); + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); } else { a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); } b_scale_thread_copy.Run(b_scale_grid_desc, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp index 7ad72039e3..68c3097d85 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -1464,43 +1464,6 @@ struct GridwiseMoeGemmBlockScale // shuffle C and write out { - // // print C - // printf("tid: %d, blkid: %d, " - // "c_thread_buf = <%1.f, %1.f, %1.f, %1.f, %1.f, %1.f, %1.f, %1.f, %1.f, %1.f, - // %1.f, %1.f, %1.f, %1.f, %1.f, %1.f\n", get_thread_local_1d_id(), block_m_id, - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<0>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<1>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<2>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<3>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<4>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<5>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<6>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<7>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<8>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<9>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<10>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<11>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<12>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<13>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<14>{}], - // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template - // AsType()[Number<15>{}]); - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, "wrong!"); @@ -1811,6 +1774,8 @@ struct GridwiseMoeGemmBlockScale const BDataType* p_b_grid, DsGridPointer& p_ds_grid, CDataType* p_c_grid, + const AScaleType* p_a_scale_grid, + const BScaleType* p_b_scale_grid, void* p_shared, void* p_shared1, const Problem& problem, @@ -1834,6 +1799,17 @@ struct GridwiseMoeGemmBlockScale problem.N, problem.NPadded, problem.StrideC); + + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(IsInputGemm ? problem.NumTokens + : problem.NumTokens * problem.TopK, + ScaleBlockM), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n, problem.MBlock, problem.NBlock); @@ -1891,7 +1867,9 @@ struct GridwiseMoeGemmBlockScale gather_offsets(m0) = token_offset * problem.K; }); const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K); - + const index_t expert_scale_stride = + __builtin_amdgcn_readfirstlane(math::integer_divide_ceil(problem.N, ScaleBlockN) * + math::integer_divide_ceil(problem.K, ScaleBlockK)); // N0, K0, Blocksize*KPack const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); @@ -1902,6 +1880,12 @@ struct GridwiseMoeGemmBlockScale p_b_grid + expert_id * expert_stride / BPackedSize, b_grid_desc_bpreshuffled.GetElementSpaceSize()); + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid + expert_id * expert_scale_stride, + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); @@ -1932,7 +1916,7 @@ struct GridwiseMoeGemmBlockScale AThreadTransferSrcResetCoordinateAfterRun, true, 1, - 2>(a_grid_desc_ak0_m_ak1, + BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1, make_multi_index(0, 0, 0), a_element_op, a_block_desc_ak0_m_ak1, @@ -1984,19 +1968,111 @@ struct GridwiseMoeGemmBlockScale (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_bufs, - a_block_slice_copy_step, - b_grid_desc_bpreshuffled, - b_blockwise_copy, - b_grid_buf, - b_block_bufs, - b_block_slice_copy_step, - c_thread_buf, - num_k_block_main_loop); + + //scale + constexpr index_t ScaleSliceSizeM = MXdlPerWave; + constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN); + constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK); + + // ScaleSliceSizeK is last dimension in A/B scale for vector memory access + // ScaleSliceSizeK is first dimension in C scale for packed math + constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + auto a_thread_offset = + get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl; + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, Number{})); + + // get each thread's offset in the scale tensor + // A scale + const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM; + + if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens) + return; + StaticallyIndexedArray scale_gather_offsets; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { + const index_t fused_token = + p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(!IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + scale_gather_offsets(m0) = + token_offset * math::integer_divide_ceil(problem.K, ScaleBlockK); + }); + + // printf("blkid: %d, tid:%d, a_thread_offset: %d, scale_gather_offsets: %d\n", block_m_id, + // threadIdx.x, a_thread_offset, + // scale_gather_offsets(Number<0>{})); + + auto a_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2_gather, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false, + MXdlPerWave>( + a_scale_grid_desc_am_ak, make_multi_index(0, 0), scale_gather_offsets); + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0)); + + // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1); + constexpr auto a_scale_thread_slice_copy_step = + make_tuple(make_multi_index(0, 0), make_multi_index(0, ScaleSliceSizeK)); + constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK); + + constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_scale_thread_desc, + c_thread_buf, + a_scale_grid_desc_am_ak, + a_scale_thread_desc, + a_scale_thread_copy, + a_scale_grid_buf, + a_scale_thread_slice_copy_step, + b_scale_grid_desc_bn_ak, + b_scale_thread_desc, + b_scale_thread_copy, + b_scale_grid_buf, + b_scale_thread_slice_copy_step, + num_k_block_main_loop); // shuffle C and write out { @@ -2006,23 +2082,24 @@ struct GridwiseMoeGemmBlockScale constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + // transposed XDL // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); // TODO: hacky, fix it! - // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + //only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); + constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); @@ -2031,24 +2108,24 @@ struct GridwiseMoeGemmBlockScale static_cast(p_shared), c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, make_tuple( make_freeze_transform(I0), make_unmerge_transform(make_tuple( Number{}, // M0 (MXdlPerWave) per shuffle M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), + M2)), // M2 = MPerXdl make_freeze_transform(I0), make_unmerge_transform(make_tuple( Number{}, // N0 (NXdlPerWave) per shuffle N1, // N1 = NWave - N2))), // N2 = NPerXdl + N2, // N2 * N3 * N4 = NPerXdl + N3, + N4))), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); // calculate origin of thread output tensor on global memory // blockwise GEMM c matrix starting index @@ -2058,56 +2135,56 @@ struct GridwiseMoeGemmBlockScale const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), make_tuple(Sequence<0, 1, 2>{}), make_tuple(Sequence<0>{})); + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( make_multi_index(n_thread_data_on_block)); // shuffle: threadwise copy C from VGPR to LDS auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3, + N2, + I1, + N4>, Sequence<0, 1, 2, 3, 4, 5, 6, 7>, 7, 1, InMemoryDataOperationEnum::Set, 1, true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, make_multi_index(0, 0, m_thread_data_on_block_idx[I1], n_thread_data_on_block_idx[I1], m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), + n_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I3], + n_thread_data_on_block_idx[I4]), ck::tensor_operation::element_wise::PassThrough{}}; using EDataType = CDataType; @@ -2162,7 +2239,7 @@ struct GridwiseMoeGemmBlockScale using CDEBlockTransferCluster = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; // hack fix felix + constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< ThisThreadBlock, decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), @@ -2202,16 +2279,16 @@ struct GridwiseMoeGemmBlockScale p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = - SpaceFillingCurve, + SpaceFillingCurve, Sequence<0, 1, 2, 3, 4, 5, 6, 7>, Sequence>{}; + N2, + 1, + N4>>{}; constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); @@ -2230,7 +2307,6 @@ struct GridwiseMoeGemmBlockScale constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); - const float* p_sorted_weights_0 = p_ds_grid[I0]; static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS StaticallyIndexedArray @@ -2243,16 +2319,14 @@ struct GridwiseMoeGemmBlockScale static_for<0, EMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; index_t token_offset = fused_token & 0xffffff; - float weight = token_offset < problem.NumTokens - ? p_sorted_weights_0[token_offset * problem.StrideDs[0]] - : 0.0; + float weight = token_offset < problem.NumTokens ? 1 : 0.0; if constexpr(IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); } else { - const float* p_sorted_weights_2 = p_ds_grid[I2]; + const float* p_sorted_weights_2 = p_ds_grid[I0]; weight = weight * p_sorted_weights_2[c_token_pos + m0]; } scatter_offsets(m0) = token_offset * problem.N; @@ -2262,10 +2336,10 @@ struct GridwiseMoeGemmBlockScale block_sync_lds(); // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, c_shuffle_block_buf); // make sure it's safe to read from LDS