diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp index eb331b9e80..aa03e59eb9 100644 --- a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp @@ -115,8 +115,6 @@ struct MulABScaleExpertWeight } }; -static constexpr bool MulRoutedWeight = true; - using CDEElementOp = MulABScaleExpertWeight; // combine MulRoutedWeight = true // using CDEElementOp = MulABScale; // combine MulRoutedWeight = true @@ -163,6 +161,9 @@ using CDEElementOp = MulABScaleExpertWeight; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul +static constexpr bool MulRoutedWeight = false; #if 0 static constexpr ck::index_t MPerBlock = 128; @@ -180,8 +181,7 @@ static constexpr ck::index_t EVec = 2; static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; static constexpr ck::index_t D2Vec = 1; -static constexpr bool MulRoutedWeight = true; -using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX // clang-format off < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, @@ -193,12 +193,11 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, 2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, ck::index_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Act_OP, Nswizzle, false, MulRoutedWeight, ck::index_t, A0DataType>; // clang-format on #else static constexpr ck::index_t MPerBlock = 128; -static constexpr bool MulRoutedWeight = true; // clang-format off using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX< @@ -213,8 +212,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX< S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Act_OP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>; -// clang-format on #endif +// clang-format on int main(int argc, char* argv[]) { @@ -283,22 +282,9 @@ int main(int argc, char* argv[]) Tensor max_token_id(HostTensorDescriptor({1 + sorted_tile_num})); max_token_id.mData = {valid_size}; // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3}; - int eids[sorted_tile_num]{}; for(int i = 0; i < sorted_tile_num; i++) { - if(i < valid_tile_num) - { - eids[i] = (i * experts) / valid_tile_num; - } - else - { - eids[i] = 3; - } - } - - for(int i = 0; i < sorted_tile_num; i++) - { - expert_ids.mData[i] = eids[i]; + expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts); } int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; int tokenid = 0; @@ -333,7 +319,7 @@ int main(int argc, char* argv[]) HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; - std::cout << "a1_t_k: " << a1_t_k_k.mDesc << std::endl; + std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; @@ -398,7 +384,7 @@ int main(int argc, char* argv[]) #if 1 preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), - N * experts, + N * 2 * experts, K, device_op.GetPreShuffleParameters()); #else diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp index d7c3fb8d8f..ef3d1a4353 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp @@ -2293,183 +2293,70 @@ struct GridwiseMoeGemmMX constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); // mul scales - const float* p_sorted_weights_0 = p_ds_grid[I0]; - const float* p_scale_b = p_ds_grid[I1]; - + static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); static_assert(M4 == 4); const index_t m1 = get_warp_local_1d_id() / NWave; const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; - if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr) - { - if constexpr(PerTokenQuant) - { - constexpr index_t scale_stride = (IsInputGemm ? 2 : 1); - p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock + - get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl; - } - else - { - p_scale_b += expert_id; - } - vector_type scale_token_ids; - vector_type topk_weights; - static_for<0, NXdlPerWave, 1>{}([&](auto n0) { - const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant]; - static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave - static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk - const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + - m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; - if constexpr(PerTokenQuant) + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + + if constexpr(IsInputGemm) // gu fusion { - scale_token_ids = - *c_style_pointer_cast*>( - p_sorted_token_ids + m_pos); - } - if constexpr(MulRoutedWeight) - { - topk_weights = *c_style_pointer_cast*>( - p_ds_grid[I2] + m_pos); - } - static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size - float scale_a = [&]() { - if constexpr(PerTokenQuant) - { - index_t fused_token = scale_token_ids.AsType()[m4]; - const index_t token_offset = fused_token & 0xffffff; - return token_offset < problem.NumTokens - ? p_sorted_weights_0[token_offset] - : 0.0; - } - else - { - return p_sorted_weights_0[0]; - } - }(); - constexpr index_t c_offset = - blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( - make_tuple(m0, n0, m2 * M4 + m4)); - constexpr auto cidx = Number{}; - if constexpr(IsInputGemm) // gu fusion + if constexpr(ActivationOperation == Activation::silu_and_mul) { - if constexpr(ActivationOperation == Activation::silu_and_mul) - { - const float scale_up = - p_scale_b[(n0 * NWave * NPerXdl + problem.N) * - PerTokenQuant]; - float gate = scale_a * scale_b * c_thread_buf[cidx]; - float up = scale_a * scale_up * c_thread_buf_up[cidx]; - if constexpr(MulRoutedWeight) - { - gate = gate * topk_weights.AsType()[m4]; - up = up * topk_weights.AsType()[m4]; - } - if constexpr(is_same_v, pk_i4_t>) - { - gate *= 16; - up *= 16; - } - tensor_operation::element_wise::Silu{}(gate, gate); - c_thread_buf_fp32(cidx) = gate * up; - } - else if(ActivationOperation == Activation::gelu_and_mul) - { - const float scale_up = - p_scale_b[(n0 * NWave * NPerXdl + problem.N) * - PerTokenQuant]; - float gate = scale_a * scale_b * c_thread_buf[cidx]; - float up = scale_a * scale_up * c_thread_buf_up[cidx]; - if constexpr(MulRoutedWeight) - { - gate = gate * topk_weights.AsType()[m4]; - up = up * topk_weights.AsType()[m4]; - } - if constexpr(is_same_v, pk_i4_t>) - { - gate *= 16; - up *= 16; - } - tensor_operation::element_wise::Gelu{}(gate, gate); - c_thread_buf_fp32(cidx) = gate * up; - } - } - else - { - c_thread_buf_fp32(cidx) = - scale_a * scale_b * c_thread_buf[cidx]; + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; if constexpr(MulRoutedWeight) { - c_thread_buf_fp32(cidx) = c_thread_buf_fp32(cidx) * - topk_weights.AsType()[m4]; + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; } - }); + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = topk_weights.AsType()[m4] * + c_thread_buf_fp32[cidx]; + } + } }); }); }); - } - else - { - vector_type topk_weights; // for gemm2 only - static_for<0, NXdlPerWave, 1>{}([&](auto n0) { - static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave - static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk - const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + - m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; - if constexpr(MulRoutedWeight) - { - topk_weights = *c_style_pointer_cast*>( - p_ds_grid[I2] + m_pos); - } - static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size - constexpr index_t c_offset = - blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( - make_tuple(m0, n0, m2 * M4 + m4)); - constexpr auto cidx = Number{}; - - if constexpr(IsInputGemm) // gu fusion - { - if constexpr(ActivationOperation == Activation::silu_and_mul) - { - float gate = c_thread_buf[cidx]; - float up = c_thread_buf_up[cidx]; - if constexpr(MulRoutedWeight) - { - gate = gate * topk_weights.AsType()[m4]; - up = up * topk_weights.AsType()[m4]; - } - tensor_operation::element_wise::Silu{}(gate, gate); - c_thread_buf_fp32(cidx) = gate * up; - } - else if(ActivationOperation == Activation::gelu_and_mul) - { - float gate = c_thread_buf[cidx]; - float up = c_thread_buf_up[cidx]; - if constexpr(MulRoutedWeight) - { - gate = gate * topk_weights.AsType()[m4]; - up = up * topk_weights.AsType()[m4]; - } - tensor_operation::element_wise::Gelu{}(gate, gate); - c_thread_buf_fp32(cidx) = gate * up; - } - } - else - { - c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; - if constexpr(MulRoutedWeight) - { - c_thread_buf_fp32(cidx) = topk_weights.AsType()[m4] * - c_thread_buf_fp32[cidx]; - } - } - }); - }); - }); - }); - } + }); constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp index c453396466..a0c1ce95fe 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp @@ -137,12 +137,15 @@ struct ReferenceMoeMXGemm2 : public device::BaseOperator f4_t f4 = 0; f4_t f4_up = 0; if(k % 2 == 1) + { f4 = (f4x2 >> 0) & 0xf; f4_up = (f4x2_up >> 0) & 0xf; + } else + { f4 = (f4x2 >> 4) & 0xf; - f4_up = (f4x2_upo >> 4) & 0xf; - + f4_up = (f4x2_up >> 4) & 0xf; + } v_b = type_convert(f4) * type_convert(b_scale); v_b_up = type_convert(f4_up) *