diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp index b3b49ea13b..f4066a249b 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp @@ -175,7 +175,7 @@ constexpr ck::index_t DataPackedSize = 2; // Packed represent constexpr ck::index_t ScaleBlockSize = 32; // scaling block size constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 -static constexpr ck::index_t MPerBlock = 32; +static constexpr ck::index_t MPerBlock = 128; static constexpr bool MulRoutedWeight = true; // clang-format off @@ -183,14 +183,14 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic A0Layout, B0Layout, DsLayout, ELayout, A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, - ScaleBlockSize, 64, - MPerBlock, 32, KPerBlock, + ScaleBlockSize, 256, + MPerBlock, 128, KPerBlock, 16, 16, 16, 16, - 2, 2, - S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, - S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, - 2, 2, S<1, 8, 1, 8>, S<2, 1, 1, 1>, + 4, 4, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + 2, 4, S<1, 4, 1, 64>, S<2, 1, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>; // clang-format on @@ -202,14 +202,14 @@ int main(int argc, char* argv[]) // per expert: // GEMM shape - constexpr ck::index_t sorted_tile_num = 2; - constexpr ck::index_t valid_tile_num = 2; + constexpr ck::index_t sorted_tile_num = 13; + constexpr ck::index_t valid_tile_num = 13; ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock; ck::index_t N = 6144; ck::index_t K = 4096; - ck::index_t experts = 2; + ck::index_t experts = 8; ck::index_t tokens = 832; ck::index_t topk = 2; @@ -351,30 +351,44 @@ int main(int argc, char* argv[]) d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); break; case 3: - a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - d2_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); break; case 4: a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); - d2_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 5.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); break; case 5: - a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); - a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); - b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); break; case 6: - a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); - a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 8: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); break; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp index 8e14797efb..0350f3590b 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp @@ -965,6 +965,54 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3(), b_scale_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); +#if 0 + printf("blkIdx: %u, blkIdy: %u, tidx: %u, imxdl: %d, inxdl: " + "%d, ikxdl: %d, a_thread_vec=<0x%08x, 0x%08x, 0x%08x, " + "0x%08x>, " + "b_thread_vec=<0x%08x, 0x%08x, 0x%08x, 0x%08x>, " + "a_scale=0x%08x, " + "b_scale=0x%08x, c_thread_buf=<%.2f, %.2f, %.2f, %.2f>\n", + blockIdx.x, + blockIdx.y, + threadIdx.x, + im_minor, + in_minor, + ik_minor, + *(reinterpret_cast( + &(a_thread_vec.template AsType()[Number<0>{}]))), + *(reinterpret_cast( + &(a_thread_vec.template AsType()[Number<1>{}]))), + *(reinterpret_cast( + &(a_thread_vec.template AsType()[Number<2>{}]))), + *(reinterpret_cast( + &(a_thread_vec.template AsType()[Number<3>{}]))), + *(reinterpret_cast( + &(b_thread_vec.template AsType()[Number<0>{}]))), + *(reinterpret_cast( + &(b_thread_vec.template AsType()[Number<1>{}]))), + *(reinterpret_cast( + &(b_thread_vec.template AsType()[Number<2>{}]))), + *(reinterpret_cast( + &(b_thread_vec.template AsType()[Number<3>{}]))), + *(reinterpret_cast( + &(a_scale_thread_vec + .template AsType()[Number<0>{}]))), + *(reinterpret_cast( + &(b_scale_thread_vec + .template AsType()[Number<0>{}]))), + type_convert( + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[Number<0>{}]), + type_convert( + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[Number<1>{}]), + type_convert( + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[Number<2>{}]), + type_convert( + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[Number<3>{}])); +#endif }); }); if constexpr(m0.value < (MRepeat - LocalPrefetchStages)) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp index b2bfa582cf..b796321e3e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp @@ -42,12 +42,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_mxgemm(typename GridwiseGemm::Argument karg) { -#if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -79,12 +79,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg) { -#if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -885,7 +885,7 @@ struct GridwiseMoeGemmMX_BPreshuffle { // contiguous in LDS return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, AK1Number), + make_tuple(AK0Number, Number{}, AK1Number), make_tuple(AK1Number, Number{}, I1)); } // xor tensor transformation request more unnecessary vgpr usage, would cause register spill @@ -2025,17 +2025,30 @@ struct GridwiseMoeGemmMX_BPreshuffle problem.NPadded, problem.StrideC); - const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed( - make_tuple(problem.M / (MXdlPack * MPerXdl), + // We pad the M unconditionaly for Scale + const auto Padded_Scale_M = + math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize; + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( + make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl), math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / (KXdlPack * 64 / MPerXdl), - 64 * KXdlPack * MXdlPack / scale_pack_size_a)); + 64 * KXdlPack * MXdlPack / scale_pack_size_a), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / APackedSize)) * + MPerXdl * MXdlPack / scale_pack_size_a, + 64 * KXdlPack * MXdlPack / scale_pack_size_a, + 1)); - const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed( + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( make_tuple(problem.N / (NXdlPack * NPerXdl), math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / (KXdlPack * 64 / NPerXdl), - 64 * KXdlPack * NXdlPack / scale_pack_size_b)); + 64 * KXdlPack * NXdlPack / scale_pack_size_b), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / BPackedSize)) * + NPerXdl * NXdlPack / scale_pack_size_b, + 64 * KXdlPack * NXdlPack / scale_pack_size_b, + 1)); const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -2102,23 +2115,15 @@ struct GridwiseMoeGemmMX_BPreshuffle // N0, K0, Blocksize*KPack const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack); // Gride buffer creation const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); - -#if 0 - printf("blkx: %u, blky: %u, tidx: %u, a_grid_size: %ld\n", - blockIdx.x, - blockIdx.y, - threadIdx.x, - a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); - -#endif - - const auto b_grid_buf = make_dynamic_buffer( - p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize()); + const auto b_grid_buf = + make_dynamic_buffer( + p_b_grid + expert_id * expert_stride, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); // A, B scale buffer const auto a_scale_grid_buf = make_dynamic_buffer( @@ -2585,16 +2590,18 @@ struct GridwiseMoeGemmMX_BPreshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = @@ -2615,40 +2622,41 @@ struct GridwiseMoeGemmMX_BPreshuffle const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; constexpr index_t scatter_weight_idx = 3; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim - true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - >{c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());