From 85f83330b5bb8271ed0f2920f87280ff559a3748 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 2 Apr 2025 07:02:09 +0000 Subject: [PATCH 1/4] fuse moe activation --- .../65_gemm_multiply_multiply/CMakeLists.txt | 6 + .../moe_gemm1_xdl_fp8.cpp | 105 +-- .../moe_gemm1_xdl_pk_i4.cpp | 120 ++-- .../moe_gemm2_xdl_fp8.cpp | 71 +- .../moe_gemm2_xdl_pk_i4.cpp | 21 +- ...dlops_b_preshuffle_gufusion_dequant_v1.hpp | 620 ++++++++++++++++++ ...peline_xdlops_b_preshuffle_gufusion_v1.hpp | 573 ++++++++++++++++ ..._pipeline_xdlops_b_preshuffle_selector.hpp | 141 ++-- .../blockwise_gemm_pipeline_xdlops_base.hpp | 25 +- ...oup_tensor_slice_transfer_v7r3_scatter.hpp | 10 +- .../gpu/device/impl/device_moe_gemm.hpp | 20 +- .../gpu/grid/gridwise_moe_gemm.hpp | 304 ++++++--- ...ise_tensor_slice_transfer_v7r3_scatter.hpp | 39 +- .../cpu/reference_moe_gemm.hpp | 89 ++- 14 files changed, 1760 insertions(+), 384 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index 95fd8bace8..fdc961f716 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -12,6 +12,12 @@ foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp) add_example_executable(example_moe_gemm2_xdl_pk_i4 moe_gemm2_xdl_pk_i4.cpp) + if(CK_hip_VERSION VERSION_LESS_EQUAL 6.3.42132) + set(EXAMPLE_COMPILE_OPTIONS) + list(APPEND EXAMPLE_COMPILE_OPTIONS -mllvm --amdgpu-enable-max-ilp-scheduling-strategy=1) + target_compile_options(example_moe_gemm1_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) + target_compile_options(example_moe_gemm2_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) + endif() set(target 1) endif() endforeach() diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index 6f6ae34273..bc94162cd7 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -36,7 +36,7 @@ using A0DataType = F8; using B0DataType = F8; using EDataType = F16; using AccDataType = F32; -using CShuffleDataType = F32; +using CShuffleDataType = EDataType; using D0DataType = F32; using D1DataType = F32; using DsDataType = ck::Tuple; @@ -59,27 +59,25 @@ struct MulABScale __host__ __device__ constexpr void operator()( EDataType& e, const float& c, const float& d0, const float& d1) const { - e = ck::type_convert(c * d1 * d0); + (void)d0; + (void)d1; + e = ck::type_convert(c); } -}; - -// for gate, a_scale, b_scale, fuse silu, -struct MulABScaleSilu -{ - template - __host__ __device__ constexpr void - operator()(E& e, const C& c, const D0& d0, const D1& d1) const; - template <> - __host__ __device__ constexpr void operator()(EDataType& e, - const float& c, - const float& d0, - const float& d1) const + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1) const { - // act - float x0 = 0; - ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0); - e = ck::type_convert(x0); + (void)d0; + (void)d1; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const EDataType& d0, const EDataType& d1) const + { + (void)d0; + (void)d1; + e = ck::type_convert(c); } }; @@ -87,8 +85,6 @@ struct MulABScaleSilu // using DsDataType = DsDataTypeGate; using CDEElementOp = MulABScale; -// using CDEElementOp = MulABScaleSiluMulGate; - void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl) { int KPack = 16 / sizeof(B0DataType); @@ -125,9 +121,9 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr ck::index_t MPerBlock = 128; -static constexpr ck::index_t MXDLPerWave = 2; -static constexpr ck::index_t NXDLPerWave = 2; +static constexpr ck::index_t MPerBlock = 32; +static constexpr ck::index_t MXDLPerWave = 1; +static constexpr ck::index_t NXDLPerWave = 1; static constexpr ck::index_t BLOCKSIZE = 256; static constexpr ck::index_t NPerBlock = 128; static constexpr ck::index_t MNPerXDL = 32; @@ -138,6 +134,7 @@ static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); static constexpr ck::index_t EVec = 16 / sizeof(EDataType); static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; +static constexpr ck::index_t ActOP = 0; // 0: gelu, 1: silu, 2: swiglu // using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off @@ -157,8 +154,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - 2, 1, S<1, 32, 1, 8>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, uint32_t, A0DataType>; + 1, 1, S<1, 32, 1, 8>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, true, int32_t, A0DataType>; // clang-format on @@ -172,9 +169,9 @@ int main(int argc, char* argv[]) ck::index_t N = 4096; ck::index_t K = 6144; ck::index_t experts = 8; - ck::index_t sorted_tile_num = 16; - ck::index_t valid_tile_num = 13; - ck::index_t tokens = 832; + ck::index_t sorted_tile_num = 8; + ck::index_t valid_tile_num = 8; + ck::index_t tokens = 64; ck::index_t topk = 2; // ck::index_t tokens = batch * topk; @@ -224,7 +221,7 @@ int main(int argc, char* argv[]) ck::index_t StrideB = K; ck::index_t StrideE = N; constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto StrideDs = std::array{1, 0}; + constexpr auto StrideDs = std::array{1, 1}; ck::index_t KBatch = 1; @@ -236,7 +233,8 @@ int main(int argc, char* argv[]) // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; // int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; - int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} // max_token_id.mData = {valid_size}; for(int i = 0; i < sorted_tile_num; i++) { @@ -261,10 +259,13 @@ int main(int argc, char* argv[]) // expert_ids.savetxt("expert_ids.txt", "int"); // sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); - Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); + // Tensor d1_e_n(HostTensorDescriptor({experts, N * 2}, {StrideDs[1] ? StrideDs[1] * + // N * 2: 1, StrideDs[1]})); + Tensor d1_e_n( + HostTensorDescriptor({experts, N * 2}, {StrideDs[1] * N * 2, StrideDs[1]})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); Tensor e_t_n_device_result( HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); @@ -278,10 +279,10 @@ int main(int argc, char* argv[]) { case 0: break; case 1: - a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d0_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d1_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); break; case 2: a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); @@ -290,10 +291,10 @@ int main(int argc, char* argv[]) d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); break; case 3: - a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); - d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); + d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); break; default: a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -311,7 +312,7 @@ int main(int argc, char* argv[]) DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); // a0_t_k.savetxt("a.txt"); - // d0_t_n.savetxt("d0_t_n.txt", "int"); + d0_t_n.savetxt("d0_t_n.txt", "float"); // d1_e_n.savetxt("d1_e_n.txt", "int"); sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data()); @@ -329,7 +330,8 @@ int main(int argc, char* argv[]) int NPerXdl = device_op.GetPreShuffleParameters(); - preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, NPerXdl); + preShuffleBuffer( + b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * 2 * experts, K, NPerXdl); b0_device_buf.ToDevice(b0_preshuffled.mData.data()); @@ -367,9 +369,9 @@ int main(int argc, char* argv[]) { float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - std::size_t flop = std::size_t(2) * tokens * topk * N * K; + std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K; std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K + - sizeof(B0DataType) * K * N * experts + + sizeof(B0DataType) * K * N * 2 * experts + sizeof(EDataType) * valid_tile_num * N; float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -394,7 +396,8 @@ int main(int argc, char* argv[]) AccDataType, PassThrough, PassThrough, - PassThrough>; + PassThrough, + ActOP>; auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); @@ -403,7 +406,9 @@ int main(int argc, char* argv[]) max_token_id, MPerBlock, a0_t_k, + d0_t_n, b0_e_n_k, + d1_e_n, c_t_k_n, PassThrough{}, PassThrough{}, @@ -432,10 +437,10 @@ int main(int argc, char* argv[]) } e_device_buf.FromDevice(e_t_n_device_result.mData.data()); - // e_t_n_device_result.savetxt("out.txt"); - // e_t_n_host_result.savetxt("ref.txt"); + e_t_n_device_result.savetxt("out.txt"); + e_t_n_host_result.savetxt("ref.txt"); return ck::utils::check_err( - e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) ? 0 : 1; } diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index badb2efb87..8a497f6905 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -36,7 +36,7 @@ using A0DataType = F8; using B0DataType = I4; using EDataType = F16; using AccDataType = F32; -using CShuffleDataType = F32; +using CShuffleDataType = F16; using D0DataType = F32; using D1DataType = F32; using DsDataType = ck::Tuple; @@ -55,42 +55,32 @@ struct MulABScale __host__ __device__ constexpr void operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1) const + { + (void)d0; + (void)d1; +#if CK_USE_PK4_LAYOUT_SHUFFLE + e = ck::type_convert(c * 16); +#else + e = ck::type_convert(c); +#endif + } template <> __host__ __device__ constexpr void operator()( EDataType& e, const float& c, const float& d0, const float& d1) const { + (void)d0; + (void)d1; #if CK_USE_PK4_LAYOUT_SHUFFLE - e = ck::type_convert(c * d1 * d0 * 16); + e = ck::type_convert(c * 16); #else - e = ck::type_convert(c * d1 * d0); + e = ck::type_convert(c); #endif } }; -// for gate, a_scale, b_scale, fuse silu, -struct MulABScaleSilu -{ - template - __host__ __device__ constexpr void - operator()(E& e, const C& c, const D0& d0, const D1& d1) const; - - template <> - __host__ __device__ constexpr void operator()(EDataType& e, - const float& c, - const float& d0, - const float& d1) const - { - // act - float x0 = 0; -#if CK_USE_PK4_LAYOUT_SHUFFLE - ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0 * 16); -#else - ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0); -#endif - e = ck::type_convert(x0); - } -}; - using CDEElementOp = MulABScale; #if 1 @@ -132,53 +122,24 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; -#if 0 -static constexpr ck::index_t MPerBlock = 64; -static constexpr ck::index_t MXDLPerWave = 1; -static constexpr ck::index_t NXDLPerWave = 2; -static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t NPerBlock = 128; -static constexpr ck::index_t MNPerXDL = 32; -static constexpr ck::index_t KPerBlock = 64 / sizeof(A0DataType); -static constexpr ck::index_t Nswizzle = false; -static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); -static constexpr ck::index_t BK1 = 32 / sizeof(B0DataType); -static constexpr ck::index_t EVec = 16 / sizeof(EDataType); -static constexpr ck::index_t D0Vec = 1; -static constexpr ck::index_t D1Vec = 1; -// clang-format off -using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< - Row, Col, DsLayout, ELayout, - A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, - AElementOp, BElementOp, CDEElementOp, GemmSpec, - BLOCKSIZE, MPerBlock, NPerBlock, KPerBlock, - AK1, BK1, - MNPerXDL, MNPerXDL, - MXDLPerWave, NXDLPerWave, - S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, - S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, - MXDLPerWave, 1, S<1, 32, 1, 8>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, ck::index_t, A0DataType>; -// clang-format on -#else static constexpr ck::index_t MPerBlock = 128; -static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t Act_OP = 0; // 0: gelu, 1: silu, 2: swiglu // clang-format off using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, - 256, MPerBlock, 128, 128, + 256, MPerBlock, 64, 128, 16, 32, - 32, 32, - 4, 1, + 16, 16, + 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, - 1, 1, S<1, 32, 1, 8>, S<8, 1, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, ck::index_t, A0DataType>; + 2, 1, S<1, 32, 1, 8>, S<8, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Act_OP, Nswizzle, true, true, ck::index_t, A0DataType>; // clang-format on -#endif int main(int argc, char* argv[]) { @@ -191,14 +152,14 @@ int main(int argc, char* argv[]) // experts = 8 // per expert: // GEMM shape - ck::index_t N = 14336 * 2; + ck::index_t N = 14336; ck::index_t K = 4096; ck::index_t experts = 8; ck::index_t sorted_tile_num = 16; ck::index_t valid_tile_num = 13; ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock; - ck::index_t tokens = 64; + ck::index_t tokens = 644; ck::index_t topk = 2; if(argc == 1) @@ -260,11 +221,12 @@ int main(int argc, char* argv[]) sorted_token_ids.mData[i] = tokens; } } + Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); - Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); + Tensor d1_e_n(HostTensorDescriptor({experts, N * 2}, {1, StrideDs[1]})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); Tensor e_t_n_device_result( HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); @@ -279,10 +241,10 @@ int main(int argc, char* argv[]) { case 0: break; case 1: - a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d0_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d1_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); break; case 2: a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); @@ -440,7 +402,8 @@ int main(int argc, char* argv[]) b_element_op, cde_element_op); - if(!device_op.IsSupportedArgument(argument)) + if(!device_op.IsSupportedArgument(argument) || + !(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) { throw std::runtime_error( "wrong! device_gemm with the specified compilation parameters does " @@ -450,9 +413,9 @@ int main(int argc, char* argv[]) { float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - std::size_t flop = std::size_t(2) * tokens * topk * N * K; + std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K; std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K + - sizeof(B0DataType) / 2 * K * N * experts + + sizeof(B0DataType) / 2 * K * N * 2 * experts + sizeof(EDataType) * valid_tile_num * N; float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -477,7 +440,8 @@ int main(int argc, char* argv[]) AccDataType, PassThrough, PassThrough, - PassThrough>; + PassThrough, + Act_OP>; auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); @@ -486,7 +450,9 @@ int main(int argc, char* argv[]) max_token_id, MPerBlock, a0_t_k, + d0_t_n, b0_e_n_k, + d1_e_n, c_t_k_n, PassThrough{}, PassThrough{}, diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp index 5d74c3bd53..05b015c4bb 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -36,7 +36,7 @@ using A0DataType = F8; using B0DataType = F8; using EDataType = F16; using AccDataType = F32; -using CShuffleDataType = F32; +using CShuffleDataType = F16; using D0DataType = F32; using D1DataType = F32; using D2DataType = F32; @@ -62,11 +62,19 @@ struct MulABScaleExpertWeight __host__ __device__ constexpr void operator()( EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const { - // for real kernel use - // warning: hack hack hack here!!!! ignore d0 right now as kernel mul d0 * d2 outside. - // tofix:felix (void)d0; - e = ck::type_convert(c * d1 * d2); + (void)d1; + (void)d2; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); } // for reference cpu template <> @@ -119,10 +127,10 @@ using CDEElementOp = MulABScaleExpertWeight; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr ck::index_t MPerBlock = 128; static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t MXDLPerWave = 2; -static constexpr ck::index_t NXDLPerWave = 2; +static constexpr ck::index_t MXDLPerWave = 4; +static constexpr ck::index_t NXDLPerWave = 4; static constexpr ck::index_t NPerBlock = 128; -static constexpr ck::index_t MNPerXDL = 32; +static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); // static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; //todo fix this constraint @@ -163,8 +171,8 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - 2, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, uint32_t, A0DataType>; + 4, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, false, int32_t, A0DataType>; // kernel 2: 128->32x128x128 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; @@ -184,8 +192,8 @@ int main(int argc, char* argv[]) ck::index_t N = 4096; ck::index_t K = 4096; ck::index_t experts = 8; - ck::index_t sorted_tile_num = 6; - ck::index_t valid_tile_num = 6; + ck::index_t sorted_tile_num = 16; + ck::index_t valid_tile_num = 13; ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock; ck::index_t tokens = 128; @@ -245,13 +253,13 @@ int main(int argc, char* argv[]) Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); Tensor max_token_id(HostTensorDescriptor({1})); // max_token_id.mData[0] = valid_size; - // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; - // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3}; - max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8}; + max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; + int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3}; + // max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8}; // int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} for(int i = 0; i < sorted_tile_num; i++) { - expert_ids.mData[i] = 0;//eids[i]; + expert_ids.mData[i] = eids[i]; } if(tokens * topk > valid_size) { @@ -409,22 +417,21 @@ int main(int argc, char* argv[]) e_device_buf.ToDevice(e_t_n_device_result.mData.data()); invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); - Tensor c_t_n({tokens, N}); + Tensor c_t_n({tokens, N}); - using ReferenceGemmInstance = - ck::tensor_operation::host::ReferenceMoeGemm2; - auto ref_moe_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_moe_gemm.MakeInvoker(); - auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm2; + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, expert_ids, max_token_id, MPerBlock, diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp index d5990d59e6..58f982bf46 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -62,11 +62,13 @@ struct MulABScaleExpertWeight EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const { (void)d0; + (void)d1; + (void)d2; #if CK_USE_PK4_LAYOUT_SHUFFLE - e = ck::type_convert(c * d1 * d2 * 16); + e = ck::type_convert(c * 16); #else - e = ck::type_convert(c * d1 * d2); + e = ck::type_convert(c); #endif } // for reference cpu @@ -125,10 +127,10 @@ using CDEElementOp = MulABScaleExpertWeight; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr ck::index_t MPerBlock = 128; static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t MXDLPerWave = 4; -static constexpr ck::index_t NXDLPerWave = 1; +static constexpr ck::index_t MXDLPerWave = 8; +static constexpr ck::index_t NXDLPerWave = 2; static constexpr ck::index_t NPerBlock = 128; -static constexpr ck::index_t MNPerXDL = 32; +static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); static constexpr ck::index_t CShuffleNLane = 32; static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; @@ -148,8 +150,8 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic MXDLPerWave, NXDLPerWave, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, - 1, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, ck::index_t, A0DataType>; + 2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, false, ck::index_t, A0DataType>; // clang-format on int main(int argc, char* argv[]) @@ -407,7 +409,8 @@ int main(int argc, char* argv[]) b_element_op, cde_element_op); - if(!device_op.IsSupportedArgument(argument)) + if(!device_op.IsSupportedArgument(argument) || + !(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) { throw std::runtime_error( "wrong! device_gemm with the specified compilation parameters does " diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp new file mode 100644 index 0000000000..ce102ff1ad --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp @@ -0,0 +1,620 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< + BlockGemmPipelineScheduler::Intrawave, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack> : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::AMmaKStride; + using Base::BMmaKStride; + using Base::c_thread_desc_; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + // B global + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A local + static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read + }); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + BBlockTransfer& b_blockwise_copy, + BBlockTransfer& b_blockwise_copy_up, + const BGridBuffer& b_grid_buf, + const BGridBuffer& b_grid_buf_up, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + CThreadBuffer& c_thread_buf_up, + index_t num_loop) const + + { + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto b_thread_dequant_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs_up; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + StaticallyIndexedArray{}> b_thread_dequant_bufs; + StaticallyIndexedArray{}> + b_thread_dequant_bufs_up; + + // Global prefetch A1 B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + // // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + + // // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(I0)); + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I0), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs_up(I0)); + + // Initialize C + c_thread_buf.Clear(); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf)); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[mfma_reg_buf] + [Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up + [mfma_reg_buf][Number{}]; + }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(local_read_buf)); + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs_up(local_read_buf)); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1)); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(I1)); + + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs_up(I1)); + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + else + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + using BThreadDequantCopy = ThreadwiseTensorSliceTransfer_StaticToStatic< + BDataType, + ComputeDataType, + decltype(b_block_desc_n0_n1_k0_k1), + decltype(b_block_desc_n0_n1_k0_k1), + tensor_operation::element_wise::PassThrough, + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + KPack>; + + const PassThrough b_element_op{}; + BThreadDequantCopy b_thread_dequant_copy_{b_element_op}; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp new file mode 100644 index 0000000000..b5a4793716 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp @@ -0,0 +1,573 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::AMmaKStride; + using Base::BMmaKStride; + using Base::c_thread_desc_; + using Base::MWaves; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = + HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves * 2; + constexpr auto mfma_interleave = MPerXDL == 32 ? 1 : 2; + // B global + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + if constexpr(MPerBlock >= 128 && NPerBlock >= 64) + { + __builtin_amdgcn_sched_group_barrier(0x008, 2 * mfma_interleave, 0); + } + else + { + __builtin_amdgcn_sched_group_barrier(0x008, mfma_interleave, 0); + } + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + // if constexpr(i.value < num_buffer_load_inst_a) { + // __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + // __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + // __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + // __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + // } + }); + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A local + static_for<0, MPerXDL == 32 ? num_ds_read_inst_a / 2 : num_ds_read_inst_a, 1>{}( + [&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, MPerXDL == 32 ? 2 : 1, 0); // DS read + }); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + BBlockTransfer& b_blockwise_copy, + BBlockTransfer& b_blockwise_copy_up, + const BGridBuffer& b_grid_buf, + const BGridBuffer& b_grid_buf_up, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + CThreadBuffer& c_thread_buf_up, + index_t num_loop) const + { + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs_up; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + // Global prefetch A1 B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + // // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + + // // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + c_thread_buf_up.Clear(); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf)); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[mfma_reg_buf] + [Number{}]; + }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1)); + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp index a94ef03008..074b5873ee 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp @@ -3,8 +3,10 @@ #pragma once +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp" @@ -33,57 +35,112 @@ template + index_t KPack, + bool GUFusion = false> constexpr auto BlockGemmBPreshufflePipeline_Selector() { if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { if constexpr(std::is_same::value) { - return BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}; + if constexpr(GUFusion) + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + else + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}; + } } else { - return BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1< - BlkGemmPipeSche, - BlockSize, - ADataType, - BDataType, - ComputeDataType, - AccDataType, - ATileDesc, - BTileDesc, - AMmaTileDesc, - BMmaTileDesc, - ABlockTransferSrcScalarPerVector, - BBlockTransferSrcScalarPerVector, - MPerBlock, - NPerBlock, - KPerBlock, - MPerXDL, - NPerXDL, - MRepeat, - NRepeat, - KPack>{}; + if constexpr(GUFusion) + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + else + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } } } else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp index 45ed6845c2..d086201dd7 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -46,7 +46,9 @@ struct BlockwiseGemmXdlops_pipeline_base static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0); static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); - static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); + // static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); + static constexpr index_t B_K1 = + BTileDesc{}.GetLength(Number < BTileDesc{}.GetNumOfDimension() == 4 ? 3 : 2 > {}); static constexpr auto xdlops_gemm = XdlopsGemm{}; @@ -181,6 +183,23 @@ struct BlockwiseGemmXdlops_pipeline_base using Tuple4 = decltype(CalculateAThreadOriginDataIndex()); + /** + * @brief Constructor for BlockwiseGemmXdlops_pipeline_base. + * + * This constructor initializes the thread copy objects for matrices A and B. + * It also performs several compile-time checks to ensure the correctness of the + * matrix tile descriptors. + * + * @param a_origin The origin data index for matrix A. + * @param b_origin The origin data index for matrix B. + * + * @note The constructor includes static assertions to ensure that: + * - The matrix tile descriptors for A and B are known at compile-time. + * - The number of threads in the thread block matches the product of MWaves, NWaves, and + * WaveSize. + * - The dimensions of the block are divisible by the product of the corresponding XDL and + * repeat dimensions. + */ __host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin = CalculateAThreadOriginDataIndex(), Tuple4 b_origin = CalculateBThreadOriginDataIndex()) @@ -316,7 +335,7 @@ struct BlockwiseGemmXdlops_pipeline_base return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( c_grid_desc_g_m0_n0_m1_n1_m2_n2); } - + __host__ __device__ static constexpr auto GetCThreadDesc() { return c_thread_desc_; } static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k; static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k; diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp index befdd4cf7c..bee0b01a74 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -134,13 +134,12 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter template __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs, - StaticallyIndexedArray& scatter_weights, Number thread_scratch_id = Number{}) { if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { - threadwise_transfer_.RunRead(src_descs, src_bufs, scatter_weights, thread_scratch_id); + threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id); } } @@ -170,10 +169,9 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter const SrcBuffers& src_bufs, const DstDescs& dst_descs, DstBuffers dst_bufs, - StaticallyIndexedArray& scatter_offsets, - StaticallyIndexedArray& scatter_weights) + StaticallyIndexedArray& scatter_offsets) { - RunRead(src_descs, src_bufs, scatter_weights); + RunRead(src_descs, src_bufs); RunWrite(dst_descs, dst_bufs, scatter_offsets); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp index d69693d326..096011051e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -65,8 +65,10 @@ template ; RunKernel(kernel); } @@ -267,7 +271,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } @@ -282,7 +285,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } @@ -292,7 +294,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } @@ -312,7 +313,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 56297c2340..4c7b438bea 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -12,7 +12,7 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp" @@ -26,11 +26,18 @@ namespace ck { // two lds chunks. // 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds // buffer when we declare __shared__ inside blkgemmpipe + +enum Activation +{ + gelu = 0, + silu = 1, + swiglu = 2 +}; + template __global__ void #if CK_USE_LAUNCH_BOUNDS @@ -44,7 +51,7 @@ __global__ void auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run( + GridwiseGemm::template Run( karg.p_sorted_token_ids, karg.p_sorted_expert_ids, karg.p_max_token_id, @@ -66,7 +73,6 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS @@ -81,21 +87,20 @@ __global__ void auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm:: - template Run_2Lds( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - p_shared1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run_2Lds( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -147,7 +152,10 @@ template - __host__ __device__ static auto - MakeCGridDescriptor_M_N(IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC) + __host__ __device__ static auto MakeCGridDescriptor_M_N( + IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC) { const auto c_grid_desc_mraw_nraw = [&]() { if constexpr(is_same::value) @@ -903,7 +912,8 @@ struct GridwiseMoeGemm NPerXdl, MXdlPerWave, NXdlPerWave, - KPack>())>; + KPack, + IsInputGemm>())>; __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { @@ -1135,7 +1145,6 @@ struct GridwiseMoeGemm template __device__ static void Run(const index_t* p_sorted_token_ids, const index_t* p_sorted_expert_ids, @@ -1221,7 +1230,8 @@ struct GridwiseMoeGemm } gather_offsets(m0) = static_cast(token_offset) * problem.K; }); - const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K); + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); // N0, K0, Blocksize*KPack const index_t n_block_data_idx_on_grid = @@ -1232,7 +1242,6 @@ struct GridwiseMoeGemm const auto b_grid_buf = make_dynamic_buffer( p_b_grid + expert_id * expert_stride / BPackedSize, b_grid_desc_bpreshuffled.GetElementSpaceSize()); - // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); @@ -1305,24 +1314,66 @@ struct GridwiseMoeGemm static_assert(std::is_default_constructible_v); auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); - + decltype(c_thread_buf) c_thread_buf_up; const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - - blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_buf, - a_block_slice_copy_step, - b_grid_desc_bpreshuffled, - b_blockwise_copy, - b_grid_buf, - b_block_buf, - b_block_slice_copy_step, - c_thread_buf, - num_k_block_main_loop); + if constexpr(IsInputGemm) + { + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2; + const auto b_grid_buf_up = make_dynamic_buffer( + p_b_grid_up + expert_id * expert_stride / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack * (get_thread_local_1d_id() % warpSize))); + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_blockwise_copy_up, + b_grid_buf, + b_grid_buf_up, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + c_thread_buf_up, + num_k_block_main_loop); + } + else + { + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + } // shuffle C and write out { @@ -1350,6 +1401,94 @@ struct GridwiseMoeGemm constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + // mul scales + const float* p_scale_b = p_ds_grid[I1]; + if constexpr(PerTokenQuant) + { + constexpr index_t scale_stride = (IsInputGemm ? 2 : 1); + p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock + + get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl; + } + else + { + p_scale_b += expert_id; + } + const float* p_sorted_weights_0 = p_ds_grid[I0]; + static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); + static_assert(M4 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; + vector_type scale_token_ids; + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + const float scale_b = p_scale_b[n0 * NWave * PerTokenQuant]; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(PerTokenQuant) + { + scale_token_ids = + *c_style_pointer_cast*>( + p_sorted_token_ids + m_pos); + } + if constexpr(!IsInputGemm) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + float scale_a = [&]() { + if constexpr(PerTokenQuant) + { + index_t fused_token = scale_token_ids.AsType()[m4]; + const index_t token_offset = fused_token & 0xffffff; + return token_offset < problem.NumTokens + ? p_sorted_weights_0[token_offset] + : 0.0; + } + else + { + return p_sorted_weights_0[0]; + } + }(); + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == Activation::silu) + { + tensor_operation::element_wise::Silu{}(c_thread_buf(cidx), + c_thread_buf(cidx)); + } + else if(ActivationOperation == Activation::gelu) + { + tensor_operation::element_wise::Gelu{}(c_thread_buf(cidx), + c_thread_buf(cidx)); + } + else if(ActivationOperation == Activation::swiglu) + { + const float scale_up = + p_scale_b[(n0 * NPerXdl + problem.N) * PerTokenQuant]; + auto gate = scale_a * scale_b * c_thread_buf[cidx]; + auto up = scale_a * scale_up * c_thread_buf_up[cidx]; + gate = gate * math::rcp(1.0 + math::exp(-gate)); + c_thread_buf(cidx) = gate * up; + } + } + else + { + c_thread_buf(cidx) = scale_a * scale_b * + topk_weights.AsType()[m4] * + c_thread_buf[cidx]; + } + }); + }); + }); + }); + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); @@ -1447,17 +1586,8 @@ struct GridwiseMoeGemm const auto ds_grid_buf = generate_tuple( [&](auto i) { - using DDataType = remove_cvref_t>; - const DDataType* ptr_ = p_ds_grid[i]; - // hack logic here to support different kind of strides. todo fix it. - // ascale t, 1; bscale E, N, 1, move ptr to E - if(i.value == 1) - { - ptr_ += - expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : 1); - } return make_dynamic_buffer( - ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize()); + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); }, Number{}); @@ -1520,7 +1650,7 @@ struct GridwiseMoeGemm Sequence, uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags IndexType, 1, // ScatterDim true, // OutputScatter: false, only use scatter weights @@ -1531,17 +1661,8 @@ struct GridwiseMoeGemm make_tuple(make_multi_index(0, 0, block_n_id, 0)), c_element_op}; - // using BufferType = std::conditional_t< - // std::is_same_v, - // decltype(make_long_dynamic_buffer(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize())), - // decltype(make_dynamic_buffer(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize())) - // >; - auto c_grid_buf = make_dynamic_buffer(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - // BufferType c_grid_buf = std::is_same_v ? - // make_long_dynamic_buffer(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()) : - // make_dynamic_buffer(p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - // space filling curve for threadwise C in VGPR + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); constexpr auto sfc_c_vgpr = SpaceFillingCurve, Sequence<0, 1, 2, 3, 4, 5, 6, 7>, @@ -1571,36 +1692,21 @@ struct GridwiseMoeGemm constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); - const float* p_sorted_weights_0 = p_ds_grid[I0]; static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS - StaticallyIndexedArray - scatter_offsets; //= p_sorted_token_ids[c_token_pos]; - StaticallyIndexedArray scatter_weights; //= for topk - // too hack here, 2 specific for topk weights, fixme - // const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] - // & 0xff000000) >> 24; + StaticallyIndexedArray scatter_offsets; auto dstidx = sfc_cde_block.GetIndex(access_id); const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); static_for<0, EMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; - IndexType token_offset = fused_token & 0xffffff; - float weight = token_offset < static_cast(problem.NumTokens) - ? p_sorted_weights_0[token_offset * problem.StrideDs[0]] - : 0.0; + IndexType token_offset = fused_token & 0xffffff; if constexpr(IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - else - { - const float* p_sorted_weights_2 = p_ds_grid[I2]; - weight = weight * p_sorted_weights_2[c_token_pos + m0]; - } scatter_offsets(m0) = static_cast(token_offset) * problem.N; - scatter_weights(m0) = weight; }); block_sync_lds(); @@ -1621,8 +1727,7 @@ struct GridwiseMoeGemm c_ds_buf_refs, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), tie(c_grid_buf), - scatter_offsets, - scatter_weights); + scatter_offsets); if constexpr(access_id < num_access - 1) { @@ -1647,7 +1752,6 @@ struct GridwiseMoeGemm template __device__ static void Run_2Lds(const index_t* p_sorted_token_ids, const index_t* p_sorted_expert_ids, @@ -1971,11 +2075,12 @@ struct GridwiseMoeGemm const DDataType* ptr_ = p_ds_grid[i]; // hack logic here to support different kind of strides. todo fix it. // ascale t, 1; bscale E, N, 1, move ptr to E - if(i.value == 1) - { - ptr_ += - expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : 1); - } + // if(i.value == 1) + // { + // ptr_ += + // expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : + // 1); + // } return make_dynamic_buffer( ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize()); }, @@ -2040,7 +2145,7 @@ struct GridwiseMoeGemm Sequence, uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags IndexType, 1, // ScatterDim true, // OutputScatter: false, only use scatter weights @@ -2083,15 +2188,9 @@ struct GridwiseMoeGemm constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); - const float* p_sorted_weights_0 = p_ds_grid[I0]; static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS - StaticallyIndexedArray - scatter_offsets; //= p_sorted_token_ids[c_token_pos]; - StaticallyIndexedArray scatter_weights; //= for topk - // too hack here, 2 specific for topk weights, fixme - // const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] - // & 0xff000000) >> 24; + StaticallyIndexedArray scatter_offsets; auto dstidx = sfc_cde_block.GetIndex(access_id); const index_t c_token_pos = @@ -2099,21 +2198,11 @@ struct GridwiseMoeGemm static_for<0, EMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; index_t token_offset = fused_token & 0xffffff; - float weight = token_offset < static_cast(problem.NumTokens) - ? p_sorted_weights_0[token_offset * problem.StrideDs[0]] - : 0.0; - // float weight = p_sorted_weights_0[token_offset * problem.StrideDs[0]]; if constexpr(IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - else - { - const float* p_sorted_weights_2 = p_ds_grid[I2]; - weight = weight * p_sorted_weights_2[c_token_pos + m0]; - } scatter_offsets(m0) = static_cast(token_offset) * problem.N; - scatter_weights(m0) = weight; }); block_sync_lds(); @@ -2134,8 +2223,7 @@ struct GridwiseMoeGemm c_ds_buf_refs, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), tie(c_grid_buf), - scatter_offsets, - scatter_weights); + scatter_offsets); if constexpr(access_id < num_access - 1) { diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp index 0170960c4b..7cd0a0fc7f 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -154,7 +154,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter enable_if_t = false> __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs, - StaticallyIndexedArray& scatter_weights, Number thread_scratch_id = Number{}) { // loop over space-filling curve @@ -173,31 +172,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter src_coords_[i]); oob_val = oob_val & is_src_valid; - if(i.value == ScatterWeightIdx) - { - static_assert(SrcScalarPerVectors{}[Number{}] == 1, - "scatter weight dim, should only one vec"); - constexpr auto iScatter = - SrcSpaceFillingCurve::GetIndex(iAccess)(Number{}); - static_for<0, SrcScalarPerVector, 1>{}([&](auto j) { - src_vectors(i).template AsType()(j) = - scatter_weights(Number{}); - }); - } - else if constexpr(SrcScalarPerVectors{}[i] == 1) - { - auto data_types = SrcDatas{}; - using DataType = remove_cvref_t; - const auto tmp = - src_bufs[i].template Get(src_coords_[i].GetOffset(), true); - static_for<0, SrcScalarPerVector, 1>{}( - [&](auto j) { src_vectors(i).template AsType()(j) = tmp; }); - } - else - { - src_vectors(i).template AsType()(I0) = - src_bufs[i].template Get(src_coords_[i].GetOffset(), true); - } + src_vectors(i).template AsType()(I0) = + src_bufs[i].template Get(src_coords_[i].GetOffset(), true); }); constexpr auto get_elem_op_vec_len = []() { @@ -421,7 +397,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter // loop over space-filling curve static_for<0, dst_num_access, 1>{}([&](auto iAccess) { - auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess]; + auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess]; IndexType scatter_offset = 0; if constexpr(OutputScatter) { @@ -432,7 +408,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter // copy data from buf_vectors into dst_bufs static_for<0, nDst, 1>{}([&](auto i) { using dst_vector_t = typename remove_cvref_t::type; - IndexType dst_offset = scatter_offset + (dst_coords_[i].GetOffset()); + IndexType dst_offset = scatter_offset + (dst_coords_[i].GetOffset()); const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize(); // coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], // dst_coords_[i]); @@ -491,10 +467,9 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter const SrcBuffers& src_bufs, const DstDescs& dst_descs, DstBuffers dst_bufs, - StaticallyIndexedArray& scatter_offsets, - StaticallyIndexedArray& scatter_weights) + StaticallyIndexedArray& scatter_offsets) { - RunRead(src_descs, src_bufs, scatter_weights); + RunRead(src_descs, src_bufs); RunWrite(dst_descs, dst_bufs, scatter_offsets); } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp index af735925ed..a50307e86f 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -22,11 +22,13 @@ template + index_t ActivationType_ = 0, + typename ComputeTypeA = CDataType, + typename ComputeTypeB = ComputeTypeA> struct ReferenceMoeGemm : public device::BaseOperator { // Argument + static constexpr auto ActivationType = ActivationType_; struct Argument : public device::BaseArgument { Argument(const Tensor& sorted_token_ids, @@ -34,7 +36,9 @@ struct ReferenceMoeGemm : public device::BaseOperator const Tensor& max_token_id, const index_t sorted_tile_size, const Tensor& a_t_k, + const Tensor& a_scale_t, const Tensor& b_e_n_k, + const Tensor& b_scale_e_n, Tensor& c_t_k_n, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, @@ -44,7 +48,9 @@ struct ReferenceMoeGemm : public device::BaseOperator max_token_id_{max_token_id}, sorted_tile_size_{sorted_tile_size}, a_t_k_{a_t_k}, + a_scale_t_{a_scale_t}, b_e_n_k_{b_e_n_k}, + b_scale_e_n_{b_scale_e_n}, c_t_k_n_{c_t_k_n}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, @@ -57,7 +63,9 @@ struct ReferenceMoeGemm : public device::BaseOperator const Tensor& max_token_id_; index_t sorted_tile_size_; const Tensor& a_t_k_; + const Tensor& a_scale_t_; const Tensor& b_e_n_k_; + const Tensor& b_scale_e_n_; Tensor& c_t_k_n_; AElementwiseOperation a_element_op_; @@ -72,11 +80,20 @@ struct ReferenceMoeGemm : public device::BaseOperator float Run(const Argument& arg) { - auto f_mk_kn_mn = [&](auto m, auto n) { + if constexpr(ActivationType > 2) + { + static_assert(false, "Not supported activation type"); + } + const int full_n = arg.c_t_k_n_.mDesc.GetLengths()[2]; + auto f_mk_kn_mn = [&](auto m, auto n) { const int K = arg.a_t_k_.mDesc.GetLengths()[1]; + AccDataType v_acc_up{0}; + ComputeTypeB v_b_up{0}; AccDataType v_acc{0}; + ComputeTypeA v_a{0}; ComputeTypeB v_b{0}; + const int t = arg.sorted_token_ids_(m) & 0xffffff; const int topk_id = (arg.sorted_token_ids_(m) & 0xff000000) >> 24; const int e = arg.expert_ids_(m / arg.sorted_tile_size_); @@ -96,7 +113,7 @@ struct ReferenceMoeGemm : public device::BaseOperator #if CK_USE_PK4_LAYOUT_SHUFFLE v_a = i4_to_f32_gfx9(i4); #else - v_a = i4 - 8; + v_a = i4 - 8; #endif } else @@ -106,37 +123,75 @@ struct ReferenceMoeGemm : public device::BaseOperator // same for B matrix if constexpr(is_same_v) { - uint8_t i4x2 = arg.b_e_n_k_(e, k, n).data; - uint8_t i4 = 0; + uint8_t i4x2 = arg.b_e_n_k_(e, k, n).data; + uint8_t i4x2_up = arg.b_e_n_k_(e, k, n + full_n).data; + uint8_t i4 = 0; + uint8_t i4_up = 0; if(k % 2 == 1) - i4 = (i4x2 >> 0) & 0xf; + { + i4 = (i4x2 >> 0) & 0xf; + i4_up = (i4x2_up >> 0) & 0xf; + } else - i4 = (i4x2 >> 4) & 0xf; + { + i4 = (i4x2 >> 4) & 0xf; + i4_up = (i4x2_up >> 4) & 0xf; + } #if CK_USE_PK4_LAYOUT_SHUFFLE - v_b = i4_to_f32_gfx9(i4); + v_b = i4_to_f32_gfx9(i4); + v_b_up = i4_to_f32_gfx9(i4_up); #else - v_b = i4 - 8; + v_b = i4 - 8; + v_b_up = i4_up - 8; #endif } else { arg.b_element_op_(v_b, arg.b_e_n_k_(e, k, n)); + if constexpr(ActivationType == 2) + { + arg.b_element_op_(v_b_up, arg.b_e_n_k_(e, k, n + full_n)); + } } v_acc += ck::type_convert(v_a) * ck::type_convert(v_b); + + if constexpr(ActivationType == 2) + { + v_acc_up += ck::type_convert(v_a) * + ck::type_convert(v_b_up); + } } CDataType v_c{0}; + CDataType v_c_up{0}; arg.c_element_op_(v_c, v_acc); - - arg.c_t_k_n_(t, topk_id, n) = v_c; + if constexpr(ActivationType == 2) + { + arg.c_element_op_(v_c_up, v_acc_up); + v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t); + v_c = v_c * (1.0 / (1.0 + math::exp(-v_c))); + v_c_up = v_c_up * arg.b_scale_e_n_(e, n + full_n) * arg.a_scale_t_(t); + arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up; + } + else + { + if constexpr(ActivationType == 1) + { + tensor_operation::element_wise::Silu{}(v_c, v_c); + } + else if constexpr(ActivationType == 0) + { + tensor_operation::element_wise::Gelu{}(v_c, v_c); + } + arg.c_t_k_n_(t, topk_id, n) = v_c; + } } }; const ck::index_t max_token_id = arg.max_token_id_(0); - make_ParallelTensorFunctor( - f_mk_kn_mn, max_token_id, arg.c_t_k_n_.mDesc.GetLengths()[2])( + make_ParallelTensorFunctor(f_mk_kn_mn, max_token_id, full_n)( std::thread::hardware_concurrency()); return 0; @@ -162,7 +217,9 @@ struct ReferenceMoeGemm : public device::BaseOperator const Tensor& max_token_id, const index_t sorted_tile_size, const Tensor& a_t_k, + const Tensor& a_scale_n, const Tensor& b_e_n_k, + const Tensor& b_scale_e_n, Tensor& c_t_k_n, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, @@ -173,7 +230,9 @@ struct ReferenceMoeGemm : public device::BaseOperator max_token_id, sorted_tile_size, a_t_k, + a_scale_n, b_e_n_k, + b_scale_e_n, c_t_k_n, a_element_op, b_element_op, From b2b34fffbb5c4e3f4e427209b1f02d80c5601539 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 2 Apr 2025 16:27:52 +0000 Subject: [PATCH 2/4] fix fp8 16x16 --- .../moe_gemm1_xdl_fp8.cpp | 28 +++++++++---------- .../gpu/grid/gridwise_moe_gemm.hpp | 5 ++-- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index bc94162cd7..186416904c 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -121,12 +121,12 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr ck::index_t MPerBlock = 32; -static constexpr ck::index_t MXDLPerWave = 1; -static constexpr ck::index_t NXDLPerWave = 1; +static constexpr ck::index_t MPerBlock = 128; +static constexpr ck::index_t MXDLPerWave = 4; +static constexpr ck::index_t NXDLPerWave = 2; static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t NPerBlock = 128; -static constexpr ck::index_t MNPerXDL = 32; +static constexpr ck::index_t NPerBlock = 64; +static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); static constexpr ck::index_t Nswizzle = false; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); @@ -134,7 +134,7 @@ static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); static constexpr ck::index_t EVec = 16 / sizeof(EDataType); static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; -static constexpr ck::index_t ActOP = 0; // 0: gelu, 1: silu, 2: swiglu +static constexpr ck::index_t ActOP = 2; // 0: gelu, 1: silu, 2: swiglu // using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off @@ -154,7 +154,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - 1, 1, S<1, 32, 1, 8>, S, + 2, 2, S<1, 32, 1, 8>, S, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, true, int32_t, A0DataType>; // clang-format on @@ -169,8 +169,8 @@ int main(int argc, char* argv[]) ck::index_t N = 4096; ck::index_t K = 6144; ck::index_t experts = 8; - ck::index_t sorted_tile_num = 8; - ck::index_t valid_tile_num = 8; + ck::index_t sorted_tile_num = 16; + ck::index_t valid_tile_num = 13; ck::index_t tokens = 64; ck::index_t topk = 2; @@ -232,9 +232,9 @@ int main(int argc, char* argv[]) // max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2,2, 2, 2, 2, 2,1,0,0,0}; // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; // int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} - max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} - int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + max_token_id.mData = {valid_size}; + int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} // max_token_id.mData = {valid_size}; for(int i = 0; i < sorted_tile_num; i++) { @@ -285,9 +285,9 @@ int main(int argc, char* argv[]) d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); break; case 2: - a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); - d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0, 1}); d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); break; case 3: diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 4c7b438bea..5a183321cd 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1421,7 +1421,7 @@ struct GridwiseMoeGemm vector_type scale_token_ids; vector_type topk_weights; // for gemm2 only static_for<0, NXdlPerWave, 1>{}([&](auto n0) { - const float scale_b = p_scale_b[n0 * NWave * PerTokenQuant]; + const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant]; static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + @@ -1471,7 +1471,8 @@ struct GridwiseMoeGemm else if(ActivationOperation == Activation::swiglu) { const float scale_up = - p_scale_b[(n0 * NPerXdl + problem.N) * PerTokenQuant]; + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; auto gate = scale_a * scale_b * c_thread_buf[cidx]; auto up = scale_a * scale_up * c_thread_buf_up[cidx]; gate = gate * math::rcp(1.0 + math::exp(-gate)); From 20f6674bf60f9e5a9de60d95b03c9a579be19b59 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 3 Apr 2025 02:46:01 +0000 Subject: [PATCH 3/4] fix no quant case --- .../gpu/grid/gridwise_moe_gemm.hpp | 212 +++++++++++------- 1 file changed, 134 insertions(+), 78 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 5a183321cd..0930a64b55 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1402,93 +1402,149 @@ struct GridwiseMoeGemm constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); // mul scales - const float* p_scale_b = p_ds_grid[I1]; - if constexpr(PerTokenQuant) - { - constexpr index_t scale_stride = (IsInputGemm ? 2 : 1); - p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock + - get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl; - } - else - { - p_scale_b += expert_id; - } const float* p_sorted_weights_0 = p_ds_grid[I0]; - static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); - static_assert(M4 == 4); - const index_t m1 = get_warp_local_1d_id() / NWave; - const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; - vector_type scale_token_ids; - vector_type topk_weights; // for gemm2 only - static_for<0, NXdlPerWave, 1>{}([&](auto n0) { - const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant]; - static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave - static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk - const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + - m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; - if constexpr(PerTokenQuant) - { - scale_token_ids = - *c_style_pointer_cast*>( - p_sorted_token_ids + m_pos); - } - if constexpr(!IsInputGemm) - { - topk_weights = *c_style_pointer_cast*>( - p_ds_grid[I2] + m_pos); - } - static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size - float scale_a = [&]() { - if constexpr(PerTokenQuant) + const float* p_scale_b = p_ds_grid[I1]; + if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr) + { + if constexpr(PerTokenQuant) + { + constexpr index_t scale_stride = (IsInputGemm ? 2 : 1); + p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock + + get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl; + } + else + { + p_scale_b += expert_id; + } + + static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); + static_assert(M4 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; + vector_type scale_token_ids; + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant]; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(PerTokenQuant) + { + scale_token_ids = + *c_style_pointer_cast*>( + p_sorted_token_ids + m_pos); + } + if constexpr(!IsInputGemm) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + float scale_a = [&]() { + if constexpr(PerTokenQuant) + { + index_t fused_token = scale_token_ids.AsType()[m4]; + const index_t token_offset = fused_token & 0xffffff; + return token_offset < problem.NumTokens + ? p_sorted_weights_0[token_offset] + : 0.0; + } + else + { + return p_sorted_weights_0[0]; + } + }(); + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + if constexpr(IsInputGemm) // gu fusion { - index_t fused_token = scale_token_ids.AsType()[m4]; - const index_t token_offset = fused_token & 0xffffff; - return token_offset < problem.NumTokens - ? p_sorted_weights_0[token_offset] - : 0.0; + if constexpr(ActivationOperation == Activation::silu) + { + tensor_operation::element_wise::Silu{}(c_thread_buf(cidx), + c_thread_buf(cidx)); + } + else if(ActivationOperation == Activation::gelu) + { + tensor_operation::element_wise::Gelu{}(c_thread_buf(cidx), + c_thread_buf(cidx)); + } + else if(ActivationOperation == Activation::swiglu) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + auto gate = scale_a * scale_b * c_thread_buf[cidx]; + auto up = scale_a * scale_up * c_thread_buf_up[cidx]; + gate = gate * math::rcp(1.0 + math::exp(-gate)); + c_thread_buf(cidx) = gate * up; + } } else { - return p_sorted_weights_0[0]; + c_thread_buf(cidx) = scale_a * scale_b * + topk_weights.AsType()[m4] * + c_thread_buf[cidx]; } - }(); - constexpr index_t c_offset = - blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( - make_tuple(m0, n0, m2 * M4 + m4)); - constexpr auto cidx = Number{}; - if constexpr(IsInputGemm) // gu fusion - { - if constexpr(ActivationOperation == Activation::silu) - { - tensor_operation::element_wise::Silu{}(c_thread_buf(cidx), - c_thread_buf(cidx)); - } - else if(ActivationOperation == Activation::gelu) - { - tensor_operation::element_wise::Gelu{}(c_thread_buf(cidx), - c_thread_buf(cidx)); - } - else if(ActivationOperation == Activation::swiglu) - { - const float scale_up = - p_scale_b[(n0 * NWave * NPerXdl + problem.N) * - PerTokenQuant]; - auto gate = scale_a * scale_b * c_thread_buf[cidx]; - auto up = scale_a * scale_up * c_thread_buf_up[cidx]; - gate = gate * math::rcp(1.0 + math::exp(-gate)); - c_thread_buf(cidx) = gate * up; - } - } - else - { - c_thread_buf(cidx) = scale_a * scale_b * - topk_weights.AsType()[m4] * - c_thread_buf[cidx]; - } + }); }); }); }); - }); + } + else + { + static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); + static_assert(M4 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(!IsInputGemm) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == Activation::silu) + { + tensor_operation::element_wise::Silu{}(c_thread_buf(cidx), + c_thread_buf(cidx)); + } + else if(ActivationOperation == Activation::gelu) + { + tensor_operation::element_wise::Gelu{}(c_thread_buf(cidx), + c_thread_buf(cidx)); + } + else if(ActivationOperation == Activation::swiglu) + { + auto gate = c_thread_buf[cidx]; + auto up = c_thread_buf_up[cidx]; + gate = gate * math::rcp(1.0 + math::exp(-gate)); + c_thread_buf(cidx) = gate * up; + } + } + else + { + c_thread_buf(cidx) = + topk_weights.AsType()[m4] * c_thread_buf[cidx]; + } + }); + }); + }); + }); + } constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); From c0c1c04b5049cd22429a47be86e1f98f2275c089 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 8 Apr 2025 06:26:57 +0000 Subject: [PATCH 4/4] fix bugs --- ...dlops_b_preshuffle_gufusion_dequant_v1.hpp | 1 + .../gpu/grid/gridwise_moe_gemm.hpp | 32 +++++++++++++++---- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp index ce102ff1ad..29750b8baa 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp @@ -314,6 +314,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< // Initialize C c_thread_buf.Clear(); + c_thread_buf_up.Clear(); __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 0930a64b55..7582669e08 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1205,6 +1205,7 @@ struct GridwiseMoeGemm return {blockIdx.x, blockIdx.y}; } }(); + const index_t block_n_id = block_mn.first; const index_t block_m_id = block_mn.second; const index_t token0 = @@ -1320,7 +1321,7 @@ struct GridwiseMoeGemm KPerBlock); if constexpr(IsInputGemm) { - const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2; + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; const auto b_grid_buf_up = make_dynamic_buffer( p_b_grid_up + expert_id * expert_stride / BPackedSize, b_grid_desc_bpreshuffled.GetElementSpaceSize()); @@ -1468,8 +1469,18 @@ struct GridwiseMoeGemm } else if(ActivationOperation == Activation::gelu) { - tensor_operation::element_wise::Gelu{}(c_thread_buf(cidx), - c_thread_buf(cidx)); + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + auto gate = scale_a * scale_b * c_thread_buf[cidx]; + auto up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf(cidx) = gate * up; } else if(ActivationOperation == Activation::swiglu) { @@ -1478,7 +1489,12 @@ struct GridwiseMoeGemm PerTokenQuant]; auto gate = scale_a * scale_b * c_thread_buf[cidx]; auto up = scale_a * scale_up * c_thread_buf_up[cidx]; - gate = gate * math::rcp(1.0 + math::exp(-gate)); + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Silu{}(gate, gate); c_thread_buf(cidx) = gate * up; } } @@ -1524,14 +1540,16 @@ struct GridwiseMoeGemm } else if(ActivationOperation == Activation::gelu) { - tensor_operation::element_wise::Gelu{}(c_thread_buf(cidx), - c_thread_buf(cidx)); + auto gate = c_thread_buf[cidx]; + auto up = c_thread_buf_up[cidx]; + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf(cidx) = gate * up; } else if(ActivationOperation == Activation::swiglu) { auto gate = c_thread_buf[cidx]; auto up = c_thread_buf_up[cidx]; - gate = gate * math::rcp(1.0 + math::exp(-gate)); + tensor_operation::element_wise::Silu{}(gate, gate); c_thread_buf(cidx) = gate * up; } }