diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index 1f488fef83..186416904c 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -121,12 +121,12 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr ck::index_t MPerBlock = 32; -static constexpr ck::index_t MXDLPerWave = 1; -static constexpr ck::index_t NXDLPerWave = 1; +static constexpr ck::index_t MPerBlock = 128; +static constexpr ck::index_t MXDLPerWave = 4; +static constexpr ck::index_t NXDLPerWave = 2; static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t NPerBlock = 128; -static constexpr ck::index_t MNPerXDL = 32; +static constexpr ck::index_t NPerBlock = 64; +static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); static constexpr ck::index_t Nswizzle = false; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); @@ -134,7 +134,7 @@ static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); static constexpr ck::index_t EVec = 16 / sizeof(EDataType); static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; -static constexpr ck::index_t Act_OP = 0; // 0: gelu, 1: silu, 2: swiglu +static constexpr ck::index_t ActOP = 2; // 0: gelu, 1: silu, 2: swiglu // using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off @@ -154,7 +154,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - 1, 1, S<1, 32, 1, 8>, S, + 2, 2, S<1, 32, 1, 8>, S, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, true, int32_t, A0DataType>; // clang-format on @@ -169,8 +169,8 @@ int main(int argc, char* argv[]) ck::index_t N = 4096; ck::index_t K = 6144; ck::index_t experts = 8; - ck::index_t sorted_tile_num = 8; - ck::index_t valid_tile_num = 8; + ck::index_t sorted_tile_num = 16; + ck::index_t valid_tile_num = 13; ck::index_t tokens = 64; ck::index_t topk = 2; @@ -232,9 +232,9 @@ int main(int argc, char* argv[]) // max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2,2, 2, 2, 2, 2,1,0,0,0}; // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; // int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} - max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} - int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + max_token_id.mData = {valid_size}; + int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} // max_token_id.mData = {valid_size}; for(int i = 0; i < sorted_tile_num; i++) { @@ -285,9 +285,9 @@ int main(int argc, char* argv[]) d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); break; case 2: - a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); - d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0, 1}); d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); break; case 3: diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp index 299ccb6a3e..05b015c4bb 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -172,7 +172,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| 4, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, false, int32_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, false, int32_t, A0DataType>; // kernel 2: 128->32x128x128 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp index ce102ff1ad..29750b8baa 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp @@ -314,6 +314,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< // Initialize C c_thread_buf.Clear(); + c_thread_buf_up.Clear(); __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp index e3c5f5e065..b5a4793716 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp index 727332d4c2..bee0b01a74 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 4c7b438bea..7582669e08 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1205,6 +1205,7 @@ struct GridwiseMoeGemm return {blockIdx.x, blockIdx.y}; } }(); + const index_t block_n_id = block_mn.first; const index_t block_m_id = block_mn.second; const index_t token0 = @@ -1320,7 +1321,7 @@ struct GridwiseMoeGemm KPerBlock); if constexpr(IsInputGemm) { - const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2; + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; const auto b_grid_buf_up = make_dynamic_buffer( p_b_grid_up + expert_id * expert_stride / BPackedSize, b_grid_desc_bpreshuffled.GetElementSpaceSize()); @@ -1402,92 +1403,166 @@ struct GridwiseMoeGemm constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); // mul scales - const float* p_scale_b = p_ds_grid[I1]; - if constexpr(PerTokenQuant) - { - constexpr index_t scale_stride = (IsInputGemm ? 2 : 1); - p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock + - get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl; - } - else - { - p_scale_b += expert_id; - } const float* p_sorted_weights_0 = p_ds_grid[I0]; - static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); - static_assert(M4 == 4); - const index_t m1 = get_warp_local_1d_id() / NWave; - const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; - vector_type scale_token_ids; - vector_type topk_weights; // for gemm2 only - static_for<0, NXdlPerWave, 1>{}([&](auto n0) { - const float scale_b = p_scale_b[n0 * NWave * PerTokenQuant]; - static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave - static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk - const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + - m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; - if constexpr(PerTokenQuant) - { - scale_token_ids = - *c_style_pointer_cast*>( - p_sorted_token_ids + m_pos); - } - if constexpr(!IsInputGemm) - { - topk_weights = *c_style_pointer_cast*>( - p_ds_grid[I2] + m_pos); - } - static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size - float scale_a = [&]() { - if constexpr(PerTokenQuant) + const float* p_scale_b = p_ds_grid[I1]; + if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr) + { + if constexpr(PerTokenQuant) + { + constexpr index_t scale_stride = (IsInputGemm ? 2 : 1); + p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock + + get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl; + } + else + { + p_scale_b += expert_id; + } + + static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); + static_assert(M4 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; + vector_type scale_token_ids; + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant]; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(PerTokenQuant) + { + scale_token_ids = + *c_style_pointer_cast*>( + p_sorted_token_ids + m_pos); + } + if constexpr(!IsInputGemm) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + float scale_a = [&]() { + if constexpr(PerTokenQuant) + { + index_t fused_token = scale_token_ids.AsType()[m4]; + const index_t token_offset = fused_token & 0xffffff; + return token_offset < problem.NumTokens + ? p_sorted_weights_0[token_offset] + : 0.0; + } + else + { + return p_sorted_weights_0[0]; + } + }(); + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + if constexpr(IsInputGemm) // gu fusion { - index_t fused_token = scale_token_ids.AsType()[m4]; - const index_t token_offset = fused_token & 0xffffff; - return token_offset < problem.NumTokens - ? p_sorted_weights_0[token_offset] - : 0.0; + if constexpr(ActivationOperation == Activation::silu) + { + tensor_operation::element_wise::Silu{}(c_thread_buf(cidx), + c_thread_buf(cidx)); + } + else if(ActivationOperation == Activation::gelu) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + auto gate = scale_a * scale_b * c_thread_buf[cidx]; + auto up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf(cidx) = gate * up; + } + else if(ActivationOperation == Activation::swiglu) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + auto gate = scale_a * scale_b * c_thread_buf[cidx]; + auto up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf(cidx) = gate * up; + } } else { - return p_sorted_weights_0[0]; + c_thread_buf(cidx) = scale_a * scale_b * + topk_weights.AsType()[m4] * + c_thread_buf[cidx]; } - }(); - constexpr index_t c_offset = - blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( - make_tuple(m0, n0, m2 * M4 + m4)); - constexpr auto cidx = Number{}; - if constexpr(IsInputGemm) // gu fusion - { - if constexpr(ActivationOperation == Activation::silu) - { - tensor_operation::element_wise::Silu{}(c_thread_buf(cidx), - c_thread_buf(cidx)); - } - else if(ActivationOperation == Activation::gelu) - { - tensor_operation::element_wise::Gelu{}(c_thread_buf(cidx), - c_thread_buf(cidx)); - } - else if(ActivationOperation == Activation::swiglu) - { - const float scale_up = - p_scale_b[(n0 * NPerXdl + problem.N) * PerTokenQuant]; - auto gate = scale_a * scale_b * c_thread_buf[cidx]; - auto up = scale_a * scale_up * c_thread_buf_up[cidx]; - gate = gate * math::rcp(1.0 + math::exp(-gate)); - c_thread_buf(cidx) = gate * up; - } - } - else - { - c_thread_buf(cidx) = scale_a * scale_b * - topk_weights.AsType()[m4] * - c_thread_buf[cidx]; - } + }); }); }); }); - }); + } + else + { + static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); + static_assert(M4 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(!IsInputGemm) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == Activation::silu) + { + tensor_operation::element_wise::Silu{}(c_thread_buf(cidx), + c_thread_buf(cidx)); + } + else if(ActivationOperation == Activation::gelu) + { + auto gate = c_thread_buf[cidx]; + auto up = c_thread_buf_up[cidx]; + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf(cidx) = gate * up; + } + else if(ActivationOperation == Activation::swiglu) + { + auto gate = c_thread_buf[cidx]; + auto up = c_thread_buf_up[cidx]; + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf(cidx) = gate * up; + } + } + else + { + c_thread_buf(cidx) = + topk_weights.AsType()[m4] * c_thread_buf[cidx]; + } + }); + }); + }); + }); + } constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp index 7b912ef362..7cd0a0fc7f 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once