diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp index fee83abb2a..7e2a2874ba 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp @@ -143,7 +143,7 @@ constexpr ck::index_t DataPackedSize = 2; // Packed represent constexpr ck::index_t ScaleBlockSize = 32; // scaling block size constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 -static constexpr ck::index_t MPerBlock = 32; +static constexpr ck::index_t MPerBlock = 128; static constexpr bool MulRoutedWeight = true; // clang-format off @@ -151,15 +151,15 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic A0Layout, B0Layout, DsLayout, ELayout, A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, - ScaleBlockSize, 64, - MPerBlock, 32, KPerBlock, + ScaleBlockSize, 256, + MPerBlock, 128, KPerBlock, 16, 16, 16, 16, 4, 4, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>; // clang-format on int main(int argc, char* argv[]) @@ -170,14 +170,14 @@ int main(int argc, char* argv[]) // per expert: // GEMM shape - constexpr ck::index_t sorted_tile_num = 2; + constexpr ck::index_t sorted_tile_num = 13; constexpr ck::index_t valid_tile_num = sorted_tile_num; ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock; ck::index_t N = 6144; ck::index_t K = 4096; - ck::index_t experts = 2; + ck::index_t experts = 8; ck::index_t tokens = 832; ck::index_t topk = 2; @@ -418,7 +418,7 @@ int main(int argc, char* argv[]) auto b_element_op = BElementOp{}; auto cde_element_op = CDEElementOp{}; -#if 1 +#if 0 printf("a0_t_k_k:\n"); // for(int t = 0; t < tokens; ++t) // { @@ -671,7 +671,7 @@ int main(int argc, char* argv[]) e_device_buf.FromDevice(e_t_n_device_result.mData.data()); -#if 1 +#if 0 printf("e_t_n_device_result:\n"); for(int t = 0; t < tokens; ++t) { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp index e073021f0d..df5b709222 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp @@ -203,9 +203,6 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}([&](auto i) { if constexpr(i < mfma_stages_more) { - static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma < num_dswrite_per_issue_a) - { - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - } }); __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read } else { - static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma < num_dswrite_per_issue_a) - { - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - } }); __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read } @@ -274,23 +260,15 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}([&](auto i) { if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more) { - static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma < num_dswrite_per_issue_a) - { - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - } }); __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read } else { - static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma < num_dswrite_per_issue_b) - { - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - } }); __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read } @@ -392,14 +370,14 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}> b_scale_thread_bufs; // Global prefetch 1 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I0)); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(I0)); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); @@ -476,22 +454,11 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}([&](auto k) { - constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * - (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( [&](auto chunk) { @@ -503,7 +470,7 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}, I0, Number{}), - a_block_buf, + a_block_bufs(I0), a_thread_desc_, make_tuple(Number{}, I0, @@ -525,7 +492,7 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}, I0, Number{}), - b_block_buf, + b_block_bufs(I0), b_thread_desc_, make_tuple(Number{}, I0, @@ -537,6 +504,13 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}([&](auto m0) { @@ -652,22 +626,20 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}]; }); - using mfma_input_type_a = - typename vector_type::type; + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; - using mfma_input_type_b = - typename vector_type::type; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset( @@ -702,10 +674,11 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3 47 96 --> 111| 160 --> 175 224 --> 239| etc. // t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc. // k = 0 k = 1 - block_sync_lds(); + // __builtin_amdgcn_s_waitcnt(3952); + // block_sync_lds(); static_for<0, KRepeat, 1>{}([&](auto k) { - constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * - (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), @@ -719,7 +692,7 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}, I0, Number{}), - a_block_buf, + a_block_bufs(scale_mem_buf), a_thread_desc_, make_tuple(Number{}, I0, @@ -743,7 +716,7 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}, I0, Number{}), - b_block_buf, + b_block_bufs(scale_mem_buf), b_thread_desc_, make_tuple(Number{}, I0, @@ -801,10 +774,6 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}([&](auto m0) { static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { @@ -848,22 +817,20 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}]; }); - using mfma_input_type_a = - typename vector_type::type; + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; - using mfma_input_type_b = - typename vector_type::type; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset( make_tuple(m0, n0, imxdl, inxdl, 0)); @@ -885,11 +852,12 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}([&](auto k) { - constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * - (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( [&](auto chunk) { @@ -902,7 +870,7 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}, I0, Number{}), - a_block_buf, + a_block_bufs(I1), a_thread_desc_, make_tuple(Number{}, I0, @@ -925,7 +893,7 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}, I0, Number{}), - b_block_buf, + b_block_bufs(I1), b_thread_desc_, make_tuple(Number{}, I0, @@ -980,22 +948,20 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}]; }); - using mfma_input_type_a = - typename vector_type::type; + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; - using mfma_input_type_b = - typename vector_type::type; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset( make_tuple(m0, n0, imxdl, inxdl, 0)); @@ -1062,22 +1028,20 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}]; }); - using mfma_input_type_a = - typename vector_type::type; + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; - using mfma_input_type_b = - typename vector_type::type; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; - using mfma_scale_input_type_a = - typename vector_type::type; - using mfma_scale_input_type_b = - typename vector_type::type; + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset( make_tuple(m0, n0, imxdl, inxdl, 0)); @@ -1092,69 +1056,6 @@ struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3(), c_thread_buf.GetVectorTypeReference(Number{})); -#if 0 - printf( - "blkIdx: %u, blkIdy: %u, tidx: %u, imxdl: %d, inxdl: " - "%d, ikxdl: %d, a_thread_vec=<%.2f, %.2f, %.2f, %.2f>, " - "b_thread_vec=<%.2f, %.2f, %.2f, %.2f>, a_scale=%08x, " - "b_scale=%08x, c_thread_buf=<%.2f, %.2f, %.2f, %.2f>\n", - blockIdx.x, - blockIdx.y, - threadIdx.x, - imxdl.value, - inxdl.value, - ikxdl.value, - type_convert( - a_thread_vec - .template AsType()[Number<0>{}] - .unpack(Number<0>{})), - type_convert( - a_thread_vec - .template AsType()[Number<0>{}] - .unpack(Number<1>{})), - type_convert( - a_thread_vec - .template AsType()[Number<1>{}] - .unpack(Number<0>{})), - type_convert( - a_thread_vec - .template AsType()[Number<1>{}] - .unpack(Number<1>{})), - type_convert( - b_thread_vec - .template AsType()[Number<0>{}] - .unpack(Number<0>{})), - type_convert( - b_thread_vec - .template AsType()[Number<0>{}] - .unpack(Number<1>{})), - type_convert( - b_thread_vec - .template AsType()[Number<1>{}] - .unpack(Number<0>{})), - type_convert( - b_thread_vec - .template AsType()[Number<1>{}] - .unpack(Number<1>{})), - *(reinterpret_cast(&( - a_scale_thread_vec - .template AsType()[Number<0>{}]))), - *(reinterpret_cast(&( - b_scale_thread_vec - .template AsType()[Number<0>{}]))), - type_convert( - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[Number<0>{}]), - type_convert( - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[Number<1>{}]), - type_convert( - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[Number<2>{}]), - type_convert( - c_thread_buf.GetVectorTypeReference(Number{}) - .template AsType()[Number<3>{}])); -#endif }); }); }); diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp index 9bfa3e2365..3e9e501126 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -52,8 +52,7 @@ template + index_t GatherDim = 1> struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad { static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); @@ -67,31 +66,15 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - - // static constexpr index_t AK0 = SrcDesc{}.GetLength(I0); - // static constexpr index_t M = SrcDesc{}.GetLength(I1); - // static constexpr index_t AK1 = SrcDesc{}.GetLength(I2); static constexpr auto block_slice_lengths = BlockSliceLengths{}; static constexpr auto thread_cluster_lengths = ThreadClusterLengths{}; - static constexpr auto wave_thread_cluster_lengths = - Sequence{}; - static constexpr auto wave_cluster_lengths = - Sequence<1, ThreadGroup::GetNumOfThread() / 64, 1>{}; static constexpr auto thread_single_load_size = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); - - // CK_PRINT(); - // After a load, each thread moves by `thread_steps` instead of loading the next elements. // It makes the whole wavefront load contiguous memory, what is required for direct loads. - static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size; - static constexpr auto wave_single_load_size = - wave_thread_cluster_lengths * thread_single_load_size; + static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size; static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps; static constexpr index_t gather_num = thread_slice_lengths.At(Number{}); @@ -119,8 +102,12 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad // VALID: ThreadClusterLengths = [4, 16, 4] or [2, 32, 4] or [1, 64, 4] since in the // first iteration, threads 0-63 write [0, 0, 0] - [0, 15, 7] -> 128 consecutive // elements = 64 consecutive DWORDs. +#if defined(__gfx950__) int num_contiguous_dwords = 4; - bool is_contiguous = true; +#else + int num_contiguous_dwords = 1; +#endif + bool is_contiguous = true; static_for<0, nDim, 1>{}([&](auto i) { if(is_contiguous) { @@ -128,7 +115,6 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad } if(thread_slice_lengths[nDim - i - 1] > 1) { - CK_PRINT>(); is_contiguous = false; } }); @@ -189,6 +175,25 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId())); + constexpr auto wave_cluster_lengths = generate_sequence_v2( + [&](auto i) { + if constexpr(ThreadClusterArrangeOrder{}.At(i) == (nDim - 3)) + { + return Number{}; + } + else + { + return I1; + } + }, + Number{}); + + constexpr auto wave_thread_cluster_lengths = ThreadClusterLengths{} / wave_cluster_lengths; + constexpr auto wave_single_load_size = + wave_thread_cluster_lengths * thread_single_load_size; + constexpr auto wave_cluster_desc_ = + make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{}); + const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex( make_multi_index(ThreadGroup::GetThreadId() / 64)); @@ -276,52 +281,6 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad src_buf.template DirectCopyToLds, ScalarPerVector>( dst_buf, src_offset, dst_offset, true); -#if 0 - __builtin_amdgcn_s_waitcnt(3952); - block_sync_lds(); - printf("blkx: %u, blky: %u, tid: %u, red_id: %d src: %d (cal: %d, gather: %d), " - "dst_offset: " - "%d, a_dst_buffer=<0x%08x, 0x%08x, 0x%08x, 0x%08x>\n", - blockIdx.x, - blockIdx.y, - threadIdx.x, - static_cast(ordered_dst_access_idx[Number{}]), - src_offset, - src_coord_xor_.GetOffset(), - gather_offset, - dst_offset, - // *(reinterpret_cast(&(dst_buf[dst_offset + 0].data))), - *(reinterpret_cast( - &(dst_buf[dst_offset + 0 + 16 * threadIdx.x].data))), - *(reinterpret_cast( - &(dst_buf[dst_offset + 4 + 16 * threadIdx.x].data))), - *(reinterpret_cast( - &(dst_buf[dst_offset + 8 + 16 * threadIdx.x].data))), - *(reinterpret_cast( - &(dst_buf[dst_offset + 12 + 16 * threadIdx.x].data)))); - -#else - __builtin_amdgcn_s_waitcnt(3952); - block_sync_lds(); - printf("blkx: %u, blky: %u, tid: %u, thread_slice_lengths=<%d, %d, %d>, " - "src_coord_xor_=<%d, " - "%d, %d>, read_id: %d " - "src: %d (cal: %d, gather: %d)\n", - blockIdx.x, - blockIdx.y, - threadIdx.x, - thread_slice_lengths[0], - thread_slice_lengths[1], - thread_slice_lengths[2], - src_coord_xor_.GetIndex().At(I0), - src_coord_xor_.GetIndex().At(I1), - src_coord_xor_.GetIndex().At(I2), - static_cast(ordered_dst_access_idx[Number{}]), - src_offset, - src_coord_xor_.GetOffset(), - gather_offset); -#endif - constexpr auto move_src_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; @@ -432,8 +391,6 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad private: static constexpr auto thread_cluster_desc_ = make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); - static constexpr auto wave_cluster_desc_ = - make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{}); SrcCoord src_coord_; SrcCoord src_coord_xor_; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp index d45cb4068e..3b02e4647f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp @@ -256,31 +256,18 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle; - RunKernel(kernel); - } - else - { - const auto kernel = kernel_moe_mxgemm_2lds; - RunKernel(kernel); - } - } + const auto kernel = kernel_moe_mxgemm_2lds; + RunKernel(kernel); } else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) { @@ -310,26 +297,15 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle; - RunKernel(kernel); - } - else - { - const auto kernel = kernel_moe_mxgemm_2lds; - RunKernel(kernel); - } + const auto kernel = kernel_moe_mxgemm_2lds; + RunKernel(kernel); } else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp index 7af0760ab1..746d842294 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp @@ -129,8 +129,8 @@ template {}; static constexpr auto BK1Number = Number{}; + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = false; + static constexpr auto is_scale_mfma = true; + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto MXdlPack = 2; static constexpr auto NXdlPack = 2; static constexpr auto KXdlPack = 2; + //> KPack is at least the k_per_blk of selected mfma + // + // Should be a multiple of k_per_blk. + // TODO: Move this to blockwise pipeline base + // KPack in packed data types for pk A/B + static constexpr index_t APackedSize = packed_size_v; static constexpr index_t BPackedSize = packed_size_v; - static constexpr bool is_single_rate_mfma = false; - static constexpr auto is_scale_mfma = true; - using mfma_selector = MfmaSelector; - static constexpr index_t KPack = math::max( - math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk / APackedSize); + static constexpr index_t KPack = + math::max(lcm_AK1_BK1, mfma_selector::selected_mfma.k_per_blk / APackedSize); // static constexpr index_t NumTokens = 1; static constexpr index_t SortedTileSize = MPerBlock; @@ -362,12 +370,28 @@ struct GridwiseMoeGemmMXBNS // pad M, but not K const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)), make_right_pad_transform(M, MPad - M)), make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - return a_grid_desc_ak0_m_ak1; + const auto a_grid_desc_permuted = transform_tensor_descriptor( + a_grid_desc_ak0_m_ak1, + make_tuple(make_pass_through_transform(K / KPerBlock), + make_xor_with_modulo_transform(make_tuple(MPad, AK0Number)), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{})); + + const auto a_grid_desc = transform_tensor_descriptor( + a_grid_desc_permuted, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, AK0Number)), + make_pass_through_transform(MPad), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + return a_grid_desc; } else if constexpr(GemmSpec == GemmSpecialization::KPadding || GemmSpec == GemmSpecialization::NKPadding) @@ -439,8 +463,9 @@ struct GridwiseMoeGemmMXBNS GemmSpec != GemmSpecialization::Default), "pk_i4_t does not support padding"); static_assert(!(is_same_v, f4x2_pk_t> && - GemmSpec != GemmSpecialization::Default), - "f4x2_pk_t does not support padding"); + (GemmSpec != GemmSpecialization::Default && + GemmSpec != GemmSpecialization::MPadding)), + "f4x2_pk_t does not support K padding"); if constexpr(GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding) @@ -1368,6 +1393,10 @@ struct GridwiseMoeGemmMXBNS static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, "B scale pack data type too large!"); + static_assert(is_same_v && + is_same_v, + "A/B ElementwiseOperation should be PassThrough as load_to_lds is used!"); + #if 0 template (token_offset) * problem.K; }); -#if 0 - printf("blkx: %u, blky: %u, tidx: %u,AMThreads: %d, token_pos: %d, gather_offsets:<%d, %d, " - "%d, %d>\n", - blockIdx.x, - blockIdx.y, - threadIdx.x, - AMThreads, - token_pos, - gather_offsets[Number<0>{}], - gather_offsets[Number<1>{}], - gather_offsets[Number<2>{}], - gather_offsets[Number<3>{}]); -#endif - const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(