diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index deca85ae64..3c1947c058 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -13,6 +13,12 @@ foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp) add_example_executable(example_moe_gemm2_xdl_pk_i4 moe_gemm2_xdl_pk_i4.cpp) + if(CK_hip_VERSION VERSION_LESS_EQUAL 6.3.42132) + set(EXAMPLE_COMPILE_OPTIONS) + list(APPEND EXAMPLE_COMPILE_OPTIONS -mllvm --amdgpu-enable-max-ilp-scheduling-strategy=1) + target_compile_options(example_moe_gemm1_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) + target_compile_options(example_moe_gemm2_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) + endif() set(target 1) endif() endforeach() diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index f594080755..3b31460953 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -25,7 +25,6 @@ template using S = ck::Sequence; using F16 = ck::half_t; -// using BF16 = ck::bhalf_t; using F8 = ck::f8_t; using F32 = float; @@ -36,7 +35,7 @@ using A0DataType = F8; using B0DataType = F8; using EDataType = F16; using AccDataType = F32; -using CShuffleDataType = F32; +using CShuffleDataType = EDataType; using D0DataType = F32; using D1DataType = F32; using D2DataType = F32; @@ -61,27 +60,25 @@ struct MulABScale __host__ __device__ constexpr void operator()( EDataType& e, const float& c, const float& d0, const float& d1) const { - e = ck::type_convert(c * d1 * d0); + (void)d0; + (void)d1; + e = ck::type_convert(c); } -}; - -// for gate, a_scale, b_scale, fuse silu, -struct MulABScaleSilu -{ - template - __host__ __device__ constexpr void - operator()(E& e, const C& c, const D0& d0, const D1& d1) const; - template <> - __host__ __device__ constexpr void operator()(EDataType& e, - const float& c, - const float& d0, - const float& d1) const + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1) const { - // act - float x0 = 0; - ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0); - e = ck::type_convert(x0); + (void)d0; + (void)d1; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const EDataType& d0, const EDataType& d1) const + { + (void)d0; + (void)d1; + e = ck::type_convert(c); } }; @@ -95,11 +92,19 @@ struct MulABScaleExpertWeight __host__ __device__ constexpr void operator()( EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const { - // for real kernel use - // warning: hack hack hack here!!!! ignore d0 right now as kernel mul d0 * d2 outside. 
- // tofix:felix + (void)d0; + (void)d1; (void)d2; - e = ck::type_convert(c * d1 * d0); + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); } // for reference cpu template <> @@ -107,16 +112,14 @@ struct MulABScaleExpertWeight float& e, const float& c, const float& d0, const float& d1, const float& d2) const { // for reference cpu - e = ck::type_convert(c * d0 * d1 * d2); + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); } }; -using CDEElementOp = MulABScaleExpertWeight; // combine MulRoutedWeight = true -// using DsLayout = DsLayoutGate; -// using DsDataType = DsDataTypeGate; -// using CDEElementOp = MulABScale; // combine MulRoutedWeight = false - -// using CDEElementOp = MulABScaleSiluMulGate; +using CDEElementOp = MulABScaleExpertWeight; void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl) { @@ -155,22 +158,21 @@ using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr ck::index_t MPerBlock = 128; -static constexpr ck::index_t MXDLPerWave = 2; +static constexpr ck::index_t MXDLPerWave = 4; static constexpr ck::index_t NXDLPerWave = 2; static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t NPerBlock = 128; -static constexpr ck::index_t MNPerXDL = 32; +static constexpr ck::index_t NPerBlock = 64; +static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); -static constexpr ck::index_t Nswizzle = true; -static constexpr bool MulRoutedWeight = false; +static constexpr ck::index_t Nswizzle = false; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); static constexpr ck::index_t EVec = 16 / sizeof(EDataType); static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; -static constexpr ck::index_t D2Vec = 1; -// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 -using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm +static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul +static constexpr bool MulRoutedWeight = false; +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, @@ -188,8 +190,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - 2, 1, S<1, 32, 1, 8>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, MulRoutedWeight, A0DataType>; + 2, 2, S<1, 32, 1, 8>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, true, int32_t, A0DataType>; // clang-format on @@ -201,15 +203,13 @@ int main(int argc, char* argv[]) // GEMM shape ck::index_t N = 4096; - ck::index_t K = 4096; + ck::index_t K = 6144; ck::index_t experts = 8; - ck::index_t sorted_tile_num = 8; - ck::index_t valid_tile_num = 8; - ck::index_t tokens = 
128; + ck::index_t sorted_tile_num = 16; + ck::index_t valid_tile_num = 13; + ck::index_t tokens = 64; ck::index_t topk = 2; - // ck::index_t tokens = batch * topk; - if(argc == 1) { // use default case @@ -255,28 +255,22 @@ int main(int argc, char* argv[]) ck::index_t StrideB = K; ck::index_t StrideE = N; constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto StrideDs = std::array{0, 0, 0}; + constexpr auto StrideDs = std::array{1, 1, 1}; ck::index_t KBatch = 1; - // const ck::index_t experts = 8; Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); Tensor max_token_id(HostTensorDescriptor({1 + sorted_tile_num})); - // max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2,2, 2, 2, 2, 2,1,0,0,0}; - // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; - // int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} - // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; - // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} - max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8}; - int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + max_token_id.mData = {valid_size}; + int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; for(int i = 0; i < sorted_tile_num; i++) { expert_ids.mData[i] = eids[i]; } int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; int tokenid = 0; - // sorted_token_ids.mData[0] = 0; + for(int i = 0; i < sorted_size; i++) { int tile_off = i % MPerBlock; @@ -290,13 +284,12 @@ int main(int argc, char* argv[]) sorted_token_ids.mData[i] = tokens; } } - // expert_ids.savetxt("expert_ids.txt", "int"); - // sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); - Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); + Tensor d1_e_n( + HostTensorDescriptor({experts, N * 2}, {StrideDs[1] * N * 2, StrideDs[1]})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); Tensor e_t_n_device_result( @@ -304,6 +297,7 @@ int main(int argc, char* argv[]) std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl; std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; @@ -312,25 +306,25 @@ int main(int argc, char* argv[]) { case 0: break; case 1: - a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d0_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d1_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d2_e_n.GenerateTensorValue(GeneratorTensor_3{-2, 2}); + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + 
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); break; case 2: - a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); - d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0, 1}); d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); d2_e_n.GenerateTensorValue(GeneratorTensor_3{}); break; case 3: - a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); - d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); - d2_e_n.GenerateTensorValue(GeneratorTensor_3{}); + d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); break; default: a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -349,9 +343,7 @@ int main(int argc, char* argv[]) DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); - // a0_t_k.savetxt("a.txt"); - // d0_t_n.savetxt("d0_t_n.txt", "int"); - // d1_e_n.savetxt("d1_e_n.txt", "int"); + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data()); max_token_id_dev.ToDevice(max_token_id.mData.data()); @@ -369,7 +361,8 @@ int main(int argc, char* argv[]) int NPerXdl = device_op.GetPreShuffleParameters(); - preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, NPerXdl); + preShuffleBuffer( + b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * 2 * experts, K, NPerXdl); b0_device_buf.ToDevice(b0_preshuffled.mData.data()); @@ -408,9 +401,9 @@ int main(int argc, char* argv[]) { float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - std::size_t flop = std::size_t(2) * tokens * topk * N * K; + std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K; std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K + - sizeof(B0DataType) * K * N * experts + + sizeof(B0DataType) * K * N * 2 * experts + sizeof(EDataType) * valid_tile_num * N; float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -437,6 +430,7 @@ int main(int argc, char* argv[]) PassThrough, PassThrough, PassThrough, + ActOP, MulRoutedWeight>; auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); @@ -446,7 +440,9 @@ int main(int argc, char* argv[]) max_token_id, MPerBlock, a0_t_k, + d0_t_n, b0_e_n_k, + d1_e_n, c_t_k_n, d2_e_n, PassThrough{}, @@ -472,15 +468,14 @@ int main(int argc, char* argv[]) c_t_k_n(t, topk_id, n), d0_t_n(t, n), d1_e_n(e, n), - 1.f); + d2_e_n(e, n)); } } e_device_buf.FromDevice(e_t_n_device_result.mData.data()); - // e_t_n_device_result.savetxt("out.txt"); - // e_t_n_host_result.savetxt("ref.txt"); + return ck::utils::check_err( - e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 
5e-1) ? 0 : 1; } diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index fb8a8b9826..3c3ef16198 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -36,7 +36,7 @@ using A0DataType = F8; using B0DataType = I4; using EDataType = F16; using AccDataType = F32; -using CShuffleDataType = F32; +using CShuffleDataType = F16; using D0DataType = F32; using D1DataType = F32; using D2DataType = F32; @@ -47,7 +47,8 @@ using B0Layout = Col; using ELayout = Row; using D0Layout = Row; using D1Layout = Col; -using DsLayout = ck::Tuple; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; // for gate, a_scale, b_scale struct MulABScale @@ -56,42 +57,32 @@ struct MulABScale __host__ __device__ constexpr void operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1) const + { + (void)d0; + (void)d1; +#if CK_USE_PK4_LAYOUT_SHUFFLE + e = ck::type_convert(c); +#else + e = ck::type_convert(c); +#endif + } template <> __host__ __device__ constexpr void operator()( EDataType& e, const float& c, const float& d0, const float& d1) const { + (void)d0; + (void)d1; #if CK_USE_PK4_LAYOUT_SHUFFLE - e = ck::type_convert(c * d1 * d0 * 16); + e = ck::type_convert(c); #else - e = ck::type_convert(c * d1 * d0); + e = ck::type_convert(c); #endif } }; -// for gate, a_scale, b_scale, fuse silu, -struct MulABScaleSilu -{ - template - __host__ __device__ constexpr void - operator()(E& e, const C& c, const D0& d0, const D1& d1) const; - - template <> - __host__ __device__ constexpr void operator()(EDataType& e, - const float& c, - const float& d0, - const float& d1) const - { - // act - float x0 = 0; -#if CK_USE_PK4_LAYOUT_SHUFFLE - ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0 * 16); -#else - ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0); -#endif - e = ck::type_convert(x0); - } -}; - struct MulABScaleExpertWeight { template @@ -102,13 +93,19 @@ struct MulABScaleExpertWeight __host__ __device__ constexpr void operator()( EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const { + (void)d0; + (void)d1; (void)d2; - -#if CK_USE_PK4_LAYOUT_SHUFFLE - e = ck::type_convert(c * d1 * d0 * 16); -#else - e = ck::type_convert(c * d1 * d0); -#endif + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); } // for reference cpu template <> @@ -116,15 +113,18 @@ struct MulABScaleExpertWeight float& e, const float& c, const float& d0, const float& d1, const float& d2) const { // for reference cpu -#if CK_USE_PK4_LAYOUT_SHUFFLE - e = ck::type_convert(c * d0 * d1 * d2 * 16); -#else - e = ck::type_convert(c * d0 * d1 * d2); -#endif + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); } }; -using CDEElementOp = MulABScaleExpertWeight; +static constexpr bool MulRoutedWeight = true; + +using CDEElementOp = MulABScaleExpertWeight; // combine MulRoutedWeight = true + +// using CDEElementOp = MulABScale; // combine MulRoutedWeight = true #if 1 void preShuffleBuffer(const I4* src, I4* dst, int N, int K, int NXdl) @@ -165,54 +165,24 @@ using AElementOp = PassThrough; using BElementOp = 
PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; -#if 0 -static constexpr ck::index_t MPerBlock = 64; -static constexpr ck::index_t MXDLPerWave = 1; -static constexpr ck::index_t NXDLPerWave = 2; -static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t NPerBlock = 128; -static constexpr ck::index_t MNPerXDL = 32; -static constexpr ck::index_t KPerBlock = 64 / sizeof(A0DataType); -static constexpr ck::index_t Nswizzle = false; -static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); -static constexpr ck::index_t BK1 = 32 / sizeof(B0DataType); -static constexpr ck::index_t EVec = 16 / sizeof(EDataType); -static constexpr ck::index_t D0Vec = 1; -static constexpr ck::index_t D1Vec = 1; -// clang-format off -using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< - Row, Col, DsLayout, ELayout, - A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, - AElementOp, BElementOp, CDEElementOp, GemmSpec, - BLOCKSIZE, MPerBlock, NPerBlock, KPerBlock, - AK1, BK1, - MNPerXDL, MNPerXDL, - MXDLPerWave, NXDLPerWave, - S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, - S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, - MXDLPerWave, 1, S<1, 32, 1, 8>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>; -// clang-format on -#else static constexpr ck::index_t MPerBlock = 128; -static constexpr ck::index_t Nswizzle = false; -static constexpr bool MulRoutedWeight = false; +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t Act_OP = 1; // 0: gelu_and_mul, 1: silu_and_mul // clang-format off using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, - 256, MPerBlock, 128, 128, + 256, MPerBlock, 64, 128, 16, 32, - 32, 32, - 4, 1, + 16, 16, + 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, - 1, 1, S<1, 32, 1, 8>, S<8, 1, 1, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, MulRoutedWeight, A0DataType>; + 2, 1, S<1, 32, 1, 8>, S<8, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Act_OP, Nswizzle, true, MulRoutedWeight, true, ck::index_t, A0DataType>; // clang-format on -#endif int main(int argc, char* argv[]) { @@ -220,13 +190,10 @@ int main(int argc, char* argv[]) int init_method = 1; bool time_kernel = true; - // tokens = 1 - // topk = 1 - // experts = 8 // per expert: // GEMM shape - ck::index_t N = 4096 * 2; - ck::index_t K = 6144; + ck::index_t N = 14336; + ck::index_t K = 4096; ck::index_t experts = 8; ck::index_t sorted_tile_num = 16; ck::index_t valid_tile_num = 13; @@ -266,20 +233,20 @@ int main(int argc, char* argv[]) ck::index_t StrideB = K; ck::index_t StrideE = N; constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto StrideDs = std::array{0, 0}; + constexpr auto StrideDs = std::array{0, 0, 0}; ck::index_t KBatch = 1; Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); Tensor max_token_id(HostTensorDescriptor({1 + sorted_tile_num})); - max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 0, 0, 0}; + max_token_id.mData = {valid_size}; int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 
5, 6, 7, 7, 3, 3, 3}; for(int i = 0; i < sorted_tile_num; i++) { expert_ids.mData[i] = eids[i]; } - int token_per_tile = tokens * topk / valid_tile_num; + int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; int tokenid = 0; for(int i = 0; i < sorted_size; i++) { @@ -294,11 +261,12 @@ int main(int argc, char* argv[]) sorted_token_ids.mData[i] = tokens; } } + Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); - Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); + Tensor d1_e_n(HostTensorDescriptor({experts, N * 2}, {1, StrideDs[1]})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); Tensor e_t_n_device_result( @@ -306,6 +274,7 @@ int main(int argc, char* argv[]) std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl; std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; @@ -314,11 +283,11 @@ int main(int argc, char* argv[]) { case 0: break; case 1: - a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d0_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d1_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d2_e_n.GenerateTensorValue(GeneratorTensor_3{-2, 2}); + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); break; case 2: a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); @@ -497,9 +466,9 @@ int main(int argc, char* argv[]) { float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - std::size_t flop = std::size_t(2) * tokens * topk * N * K; + std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K; std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K + - sizeof(B0DataType) / 2 * K * N * experts + + sizeof(B0DataType) / 2 * K * N * 2 * experts + sizeof(EDataType) * valid_tile_num * N; float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -526,6 +495,7 @@ int main(int argc, char* argv[]) PassThrough, PassThrough, PassThrough, + Act_OP, MulRoutedWeight>; auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); @@ -535,7 +505,9 @@ int main(int argc, char* argv[]) max_token_id, MPerBlock, a0_t_k, + d0_t_n, b0_e_n_k, + d1_e_n, c_t_k_n, d2_e_n, PassThrough{}, @@ -561,13 +533,13 @@ int main(int argc, char* argv[]) c_t_k_n(t, topk_id, n), d0_t_n(t, n), d1_e_n(e, n), - 1.f); + d2_e_n(e, n)); } } e_device_buf.FromDevice(e_t_n_device_result.mData.data()); return ck::utils::check_err( - e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) ? 
0 : 1; } diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp index 04f10b53ae..42d892fe26 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp @@ -25,7 +25,6 @@ template using S = ck::Sequence; using F16 = ck::half_t; -// using BF16 = ck::bhalf_t; using F8 = ck::f8_t; using F32 = float; @@ -36,7 +35,7 @@ using A0DataType = F8; using B0DataType = F8; using EDataType = F16; using AccDataType = F32; -using CShuffleDataType = F32; +using CShuffleDataType = F16; using D0DataType = F32; using D1DataType = F32; using D2DataType = F32; @@ -48,7 +47,6 @@ using ELayout = Row; using D0Layout = Row; using D1Layout = Col; using D2Layout = ELayout; -// using DsLayoutGate = ck::Tuple; using DsLayout = ck::Tuple; // d0: ascale, d1: bscale, d2:expert weight @@ -62,11 +60,19 @@ struct MulABScaleExpertWeight __host__ __device__ constexpr void operator()( EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const { - // for real kernel use - // warning: hack hack hack here!!!! ignore d0 right now as kernel mul d0 * d2 outside. - // tofix:felix (void)d0; - e = ck::type_convert(c * d1 * d2); + (void)d1; + (void)d2; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); } // for reference cpu template <> @@ -119,14 +125,12 @@ using CDEElementOp = MulABScaleExpertWeight; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr ck::index_t MPerBlock = 128; static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t MXDLPerWave = 2; -static constexpr ck::index_t NXDLPerWave = 2; +static constexpr ck::index_t MXDLPerWave = 4; +static constexpr ck::index_t NXDLPerWave = 4; static constexpr ck::index_t NPerBlock = 128; -static constexpr ck::index_t MNPerXDL = 32; +static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); -// static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; //todo fix this constraint -// static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32; static constexpr ck::index_t CShuffleNLane = 32; static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); @@ -135,7 +139,7 @@ static constexpr ck::index_t EVec = 2; static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; static constexpr ck::index_t D2Vec = 1; -static constexpr bool MulRoutedWeight = false; +static constexpr bool MulRoutedWeight = true; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off ///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -164,8 +168,8 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| 
NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - 2, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, MulRoutedWeight, A0DataType>; + 4, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, int32_t, A0DataType>; // kernel 2: 128->32x128x128 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; @@ -177,16 +181,13 @@ int main(int argc, char* argv[]) int init_method = 1; bool time_kernel = true; - // tokens = 1 - // topk = 1 - // experts = 8 // per expert: // GEMM shape ck::index_t N = 4096; ck::index_t K = 4096; ck::index_t experts = 8; - ck::index_t sorted_tile_num = 6; - ck::index_t valid_tile_num = 6; + ck::index_t sorted_tile_num = 16; + ck::index_t valid_tile_num = 13; ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock; ck::index_t tokens = 128; @@ -212,6 +213,18 @@ int main(int argc, char* argv[]) K = std::stoi(argv[5]); tokens = std::stoi(argv[6]); } + else if(argc == 9) + { + + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + sorted_tile_num = std::stoi(argv[7]); + valid_tile_num = std::stoi(argv[8]); + } else { printf("arg1: verification (0=no, 1=yes)\n"); @@ -229,15 +242,13 @@ int main(int argc, char* argv[]) ck::index_t KBatch = 1; - // const ck::index_t experts = 8; Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); Tensor max_token_id(HostTensorDescriptor({1})); - // max_token_id.mData[0] = valid_size; - // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; - // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3}; - max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8}; - int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + + max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; + int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3}; + for(int i = 0; i < sorted_tile_num; i++) { expert_ids.mData[i] = eids[i]; @@ -249,7 +260,7 @@ int main(int argc, char* argv[]) } int token_per_tile = tokens * topk / valid_tile_num; int tokenid = 0; - // sorted_token_ids.mData[0] = 0; + for(int i = 0; i < sorted_size; i++) { int tile_off = i % MPerBlock; @@ -263,8 +274,7 @@ int main(int argc, char* argv[]) sorted_token_ids.mData[i] = tokens; } } - expert_ids.savetxt("expert_ids.txt", "int"); - sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); @@ -315,12 +325,7 @@ int main(int argc, char* argv[]) DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); DeviceMem 
d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); - // a0_t_k_k.savetxt("a.txt"); - // expert_ids.savetxt("expert_ids.txt", "int"); - // sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); - // d0_t_n.savetxt("d0_t_n.txt", "int"); - // d1_e_n.savetxt("d1_e_n.txt", "int"); - // d2_e_n.savetxt("d2_e_n.txt", "int"); + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data()); max_token_id_dev.ToDevice(max_token_id.mData.data()); @@ -398,7 +403,7 @@ int main(int argc, char* argv[]) e_device_buf.ToDevice(e_t_n_device_result.mData.data()); invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); - Tensor c_t_n({tokens, N}); + Tensor c_t_n({tokens, N}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm2(c * d1 * d2 * 16); + e = ck::type_convert(c * 16); #else - e = ck::type_convert(c * d1 * d2); + e = ck::type_convert(c); #endif } // for reference cpu @@ -125,10 +127,10 @@ using CDEElementOp = MulABScaleExpertWeight; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr ck::index_t MPerBlock = 128; static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t MXDLPerWave = 4; -static constexpr ck::index_t NXDLPerWave = 1; +static constexpr ck::index_t MXDLPerWave = 8; +static constexpr ck::index_t NXDLPerWave = 2; static constexpr ck::index_t NPerBlock = 128; -static constexpr ck::index_t MNPerXDL = 32; +static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); static constexpr ck::index_t CShuffleNLane = 32; static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; @@ -149,8 +151,8 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic MXDLPerWave, NXDLPerWave, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, - 1, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, MulRoutedWeight, A0DataType>; + 2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, ck::index_t, A0DataType>; // clang-format on int main(int argc, char* argv[]) @@ -159,9 +161,6 @@ int main(int argc, char* argv[]) int init_method = 1; bool time_kernel = true; - // tokens = 1 - // topk = 1 - // experts = 8 // per expert: // GEMM shape ck::index_t N = 4096; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp new file mode 100644 index 0000000000..29750b8baa --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp @@ -0,0 +1,621 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< + BlockGemmPipelineScheduler::Intrawave, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack> : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::AMmaKStride; + using Base::BMmaKStride; + using Base::c_thread_desc_; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? 
TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + // B global + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A local + static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read + }); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + BBlockTransfer& b_blockwise_copy, + BBlockTransfer& b_blockwise_copy_up, + const BGridBuffer& b_grid_buf, + const BGridBuffer& b_grid_buf_up, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + CThreadBuffer& c_thread_buf_up, + index_t num_loop) const + + { + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto b_thread_dequant_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs_up; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + StaticallyIndexedArray{}> b_thread_dequant_bufs; + StaticallyIndexedArray{}> + b_thread_dequant_bufs_up; + + // Global prefetch A1 B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + // // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + + // // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + // B VGPR->VGPR dequant + 
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(I0)); + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I0), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs_up(I0)); + + // Initialize C + c_thread_buf.Clear(); + c_thread_buf_up.Clear(); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf)); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[mfma_reg_buf] + [Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up + [mfma_reg_buf][Number{}]; + }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(local_read_buf)); + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs_up(local_read_buf)); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1)); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + static_for<0, MRepeat, 
1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(I1)); + + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs_up(I1)); + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + else + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + 
xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + using BThreadDequantCopy = ThreadwiseTensorSliceTransfer_StaticToStatic< + BDataType, + ComputeDataType, + decltype(b_block_desc_n0_n1_k0_k1), + decltype(b_block_desc_n0_n1_k0_k1), + tensor_operation::element_wise::PassThrough, + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + KPack>; + + const PassThrough b_element_op{}; + BThreadDequantCopy b_thread_dequant_copy_{b_element_op}; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp new file mode 100644 index 0000000000..73749c6309 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp @@ -0,0 +1,573 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::AMmaKStride; + using Base::BMmaKStride; + using Base::c_thread_desc_; + using Base::MWaves; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = + HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves * 2; + constexpr auto mfma_interleave = MPerXDL == 32 ? 
1 : 2; + // B global + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + if constexpr(MPerBlock >= 128 && NPerBlock >= 64) + { + __builtin_amdgcn_sched_group_barrier(0x008, 2 * mfma_interleave, 0); + } + else + { + __builtin_amdgcn_sched_group_barrier(0x008, mfma_interleave, 0); + } + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + // if constexpr(i.value < num_buffer_load_inst_a) { + // __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + // __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + // __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + // __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + // } + }); + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A local + static_for<0, MPerXDL == 32 ? num_ds_read_inst_a / 2 : num_ds_read_inst_a, 1>{}( + [&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, MPerXDL == 32 ? 2 : 1, 0); // DS read + }); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + BBlockTransfer& b_blockwise_copy, + BBlockTransfer& b_blockwise_copy_up, + const BGridBuffer& b_grid_buf, + const BGridBuffer& b_grid_buf_up, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + CThreadBuffer& c_thread_buf_up, + index_t num_loop) const + { + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs_up; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + // Global prefetch A1 B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + // // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + + // // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + c_thread_buf_up.Clear(); + + 
__builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf)); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[mfma_reg_buf] + [Number{}]; + }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1)); + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); 
+ }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp index a94ef03008..074b5873ee 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp @@ -3,8 +3,10 @@ #pragma once +#include 
"ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp" @@ -33,57 +35,112 @@ template + index_t KPack, + bool GUFusion = false> constexpr auto BlockGemmBPreshufflePipeline_Selector() { if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { if constexpr(std::is_same::value) { - return BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}; + if constexpr(GUFusion) + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + else + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}; + } } else { - return BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1< - BlkGemmPipeSche, - BlockSize, - ADataType, - BDataType, - ComputeDataType, - AccDataType, - ATileDesc, - BTileDesc, - AMmaTileDesc, - BMmaTileDesc, - ABlockTransferSrcScalarPerVector, - BBlockTransferSrcScalarPerVector, - MPerBlock, - NPerBlock, - KPerBlock, - MPerXDL, - NPerXDL, - MRepeat, - NRepeat, - KPack>{}; + if constexpr(GUFusion) + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + else + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } } } else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp index d7ba2559ea..ce507ca8d3 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -46,7 +46,8 @@ struct BlockwiseGemmXdlops_pipeline_base static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0); static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); - static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); + static constexpr index_t B_K1 = + BTileDesc{}.GetLength(Number < BTileDesc{}.GetNumOfDimension() == 4 ? 
3 : 2 > {}); static constexpr auto xdlops_gemm = XdlopsGemm{}; @@ -333,7 +334,7 @@ struct BlockwiseGemmXdlops_pipeline_base return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( c_grid_desc_g_m0_n0_m1_n1_m2_n2); } - + __host__ __device__ static constexpr auto GetCThreadDesc() { return c_thread_desc_; } static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k; static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k; diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp index 859649185a..92aef65388 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp @@ -41,6 +41,7 @@ template struct ThreadGroupTensorSliceTransfer_v4r1_gather @@ -58,7 +59,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_gather const DstDesc& dst_desc, const Index& dst_block_slice_origin, const DstElementwiseOperation& dst_element_op, - const StaticallyIndexedArray& gather_offsets) + const StaticallyIndexedArray& gather_offsets) : threadwise_transfer_(src_desc, make_zero_multi_index(), src_element_op, @@ -190,6 +191,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_gather DstScalarStrideInVector, ThreadTransferSrcResetCoordinateAfterRun, ThreadTransferDstResetCoordinateAfterRun, + IndexType, GatherDim, NumThreadScratch>; diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp index cf758e4d5f..bee0b01a74 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
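(Context for the IndexType template parameter threaded through the gather/scatter transfers in this and the following hunks: element offsets such as token_offset * K or token_offset * N can exceed the 32-bit index_t range once the routed activation tensor is large, so the grid code casts the token offset to the wider index type before multiplying, and dynamic_buffer.hpp keeps AMD buffer addressing only when sizeof(IndexType) <= sizeof(int32_t). A minimal host-side illustration with made-up sizes:)

#include <cstdint>
#include <cstdio>

int main()
{
    // Made-up sizes: ~1M routed token rows with K = 6144 columns per row.
    const int32_t token_offset = 1000000;
    const int32_t K            = 6144;

    const int64_t wide      = static_cast<int64_t>(token_offset) * K; // 6'144'000'000
    const int32_t truncated = static_cast<int32_t>(wide);             // wraps, no longer a valid offset

    std::printf("64-bit offset: %lld\n", static_cast<long long>(wide));
    std::printf("same offset squeezed into 32 bits: %d\n", truncated);
    return 0;
}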
#pragma once @@ -42,6 +42,7 @@ template __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs, - StaticallyIndexedArray& scatter_weights, Number thread_scratch_id = Number{}) { if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { - threadwise_transfer_.RunRead(src_descs, src_bufs, scatter_weights, thread_scratch_id); + threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id); } } @@ -149,7 +149,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter template __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs, - StaticallyIndexedArray& scatter_offsets, + StaticallyIndexedArray& scatter_offsets, Number thread_scratch_id = Number{}) { if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or @@ -169,10 +169,9 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter const SrcBuffers& src_bufs, const DstDescs& dst_descs, DstBuffers dst_bufs, - StaticallyIndexedArray& scatter_offsets, - StaticallyIndexedArray& scatter_weights) + StaticallyIndexedArray& scatter_offsets) { - RunRead(src_descs, src_bufs, scatter_weights); + RunRead(src_descs, src_bufs); RunWrite(dst_descs, dst_bufs, scatter_offsets); } @@ -230,6 +229,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, + IndexType, ScatterDim, OutputScatter, ScatterWeightIdx, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp index 03db4bdd41..08d177035e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp @@ -65,9 +65,12 @@ template ; RunKernel(kernel); } @@ -281,8 +287,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } @@ -297,8 +301,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } @@ -308,8 +310,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } @@ -329,8 +329,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index a2d1114bbe..255fb8cff4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -12,7 +12,7 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp" @@ -26,12 +26,17 @@ namespace ck { // two lds chunks. // 2. 
Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds // buffer when we declare __shared__ inside blkgemmpipe + +enum Activation +{ + gelu_and_mul = 0, + silu_and_mul = 1 +}; + template __global__ void #if CK_USE_LAUNCH_BOUNDS @@ -45,22 +50,19 @@ __global__ void auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run(karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -70,8 +72,6 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS @@ -86,23 +86,20 @@ __global__ void auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run_2Lds(karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - p_shared1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run_2Lds( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -154,7 +151,12 @@ template ) @@ -497,8 +500,8 @@ struct GridwiseMoeGemm } template - __host__ __device__ static auto - MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + __host__ __device__ static auto MakeCGridDescriptor_M_N( + IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC) { const auto c_grid_desc_mraw_nraw = [&]() { if constexpr(is_same::value) @@ -909,7 +912,8 @@ struct GridwiseMoeGemm NPerXdl, MXdlPerWave, NXdlPerWave, - KPack>())>; + KPack, + IsInputGemm>())>; __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { @@ -1141,9 +1145,7 @@ struct GridwiseMoeGemm template + TailNumber TailNum = TailNumber::Odd> __device__ static void Run(const index_t* p_sorted_token_ids, const index_t* p_sorted_expert_ids, const index_t* p_max_token_id, @@ -1203,6 +1205,7 @@ struct GridwiseMoeGemm return {blockIdx.x, blockIdx.y}; } }(); + const index_t block_n_id = block_mn.first; const index_t block_m_id = block_mn.second; const index_t token0 = @@ -1218,7 +1221,7 @@ struct GridwiseMoeGemm if(token_pos >= max_token_id || token0 >= problem.NumTokens) return; - StaticallyIndexedArray gather_offsets; + StaticallyIndexedArray gather_offsets; static_for<0, AMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[token_pos + m0]; index_t token_offset = fused_token & 0xffffff; @@ -1226,9 +1229,10 @@ struct GridwiseMoeGemm { token_offset = token_offset * 
problem.TopK + (fused_token >> 24); } - gather_offsets(m0) = token_offset * problem.K; + gather_offsets(m0) = static_cast(token_offset) * problem.K; }); - const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K); + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); // N0, K0, Blocksize*KPack const index_t n_block_data_idx_on_grid = @@ -1239,7 +1243,6 @@ struct GridwiseMoeGemm const auto b_grid_buf = make_dynamic_buffer( p_b_grid + expert_id * expert_stride / BPackedSize, b_grid_desc_bpreshuffled.GetElementSpaceSize()); - // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); @@ -1269,6 +1272,7 @@ struct GridwiseMoeGemm 1, AThreadTransferSrcResetCoordinateAfterRun, true, + IndexType, 1, BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1, make_multi_index(0, 0, 0), @@ -1311,24 +1315,74 @@ struct GridwiseMoeGemm static_assert(std::is_default_constructible_v); auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + decltype(c_thread_buf) c_thread_buf_up; + + StaticBufferTupleOfVector + c_thread_buf_fp32; const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - - blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_buf, - a_block_slice_copy_step, - b_grid_desc_bpreshuffled, - b_blockwise_copy, - b_grid_buf, - b_block_buf, - b_block_slice_copy_step, - c_thread_buf, - num_k_block_main_loop); + if constexpr(IsInputGemm) + { + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; + const auto b_grid_buf_up = make_dynamic_buffer( + p_b_grid_up + expert_id * expert_stride / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack * (get_thread_local_1d_id() % warpSize))); + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_blockwise_copy_up, + b_grid_buf, + b_grid_buf_up, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + c_thread_buf_up, + num_k_block_main_loop); + } + else + { + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + } // shuffle C and write out { @@ -1356,6 +1410,185 @@ struct GridwiseMoeGemm constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + // mul scales + const float* p_sorted_weights_0 = p_ds_grid[I0]; + const float* p_scale_b = p_ds_grid[I1]; + + static_assert(M0 * M1 * M2 * M3 * 
M4 == MPerBlock); + static_assert(M4 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; + + if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr) + { + if constexpr(PerTokenQuant) + { + constexpr index_t scale_stride = (IsInputGemm ? 2 : 1); + p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock + + get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl; + } + else + { + p_scale_b += expert_id; + } + + vector_type scale_token_ids; + vector_type topk_weights; + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant]; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(PerTokenQuant) + { + scale_token_ids = + *c_style_pointer_cast*>( + p_sorted_token_ids + m_pos); + } + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + float scale_a = [&]() { + if constexpr(PerTokenQuant) + { + index_t fused_token = scale_token_ids.AsType()[m4]; + const index_t token_offset = fused_token & 0xffffff; + return token_offset < problem.NumTokens + ? p_sorted_weights_0[token_offset] + : 0.0; + } + else + { + return p_sorted_weights_0[0]; + } + }(); + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = + scale_a * scale_b * c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = c_thread_buf_fp32(cidx) * + topk_weights.AsType()[m4]; + } + } + }); + }); + }); + }); + } + else + { + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + 
p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = topk_weights.AsType()[m4] * + c_thread_buf_fp32[cidx]; + } + } + }); + }); + }); + }); + } + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); @@ -1453,17 +1686,8 @@ struct GridwiseMoeGemm const auto ds_grid_buf = generate_tuple( [&](auto i) { - using DDataType = remove_cvref_t>; - const DDataType* ptr_ = p_ds_grid[i]; - // hack logic here to support different kind of strides. todo fix it. - // ascale t, 1; bscale E, N, 1, move ptr to E - if(i.value == 1) - { - ptr_ += - expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : 1); - } return make_dynamic_buffer( - ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize()); + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); }, Number{}); @@ -1526,7 +1750,8 @@ struct GridwiseMoeGemm Sequence, uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, 1, // ScatterDim true, // OutputScatter: false, only use scatter weights scatter_weight_idx // ScatterWeightIdx: ascale @@ -1538,7 +1763,6 @@ struct GridwiseMoeGemm auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = SpaceFillingCurve, Sequence<0, 1, 2, 3, 4, 5, 6, 7>, @@ -1568,35 +1792,21 @@ struct GridwiseMoeGemm constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); - const float* p_sorted_weights_0 = p_ds_grid[I0]; static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS - StaticallyIndexedArray scatter_offsets; - StaticallyIndexedArray scatter_weights; //= for topk + StaticallyIndexedArray scatter_offsets; auto dstidx = sfc_cde_block.GetIndex(access_id); const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); static_for<0, EMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; - index_t token_offset = fused_token & 0xffffff; - float weight = token_offset < problem.NumTokens - ? 
p_sorted_weights_0[token_offset * problem.StrideDs[0]] - : 0.0; + IndexType token_offset = fused_token & 0xffffff; if constexpr(IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - if constexpr(MulRoutedWeight) - { - const float* p_sorted_weights_2 = p_ds_grid[I2]; - if constexpr(sizeof(ADataType) < 2) - weight = p_sorted_weights_2[c_token_pos + m0] * weight; - else - weight = p_sorted_weights_2[c_token_pos + m0]; - } - scatter_offsets(m0) = token_offset * problem.N; - scatter_weights(m0) = weight; + scatter_offsets(m0) = static_cast(token_offset) * problem.N; }); block_sync_lds(); @@ -1604,7 +1814,7 @@ struct GridwiseMoeGemm // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, + c_thread_buf_fp32, c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_shuffle_block_buf); @@ -1617,8 +1827,7 @@ struct GridwiseMoeGemm c_ds_buf_refs, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), tie(c_grid_buf), - scatter_offsets, - scatter_weights); + scatter_offsets); if constexpr(access_id < num_access - 1) { @@ -1643,9 +1852,7 @@ struct GridwiseMoeGemm template + TailNumber TailNum = TailNumber::Odd> __device__ static void Run_2Lds(const index_t* p_sorted_token_ids, const index_t* p_sorted_expert_ids, const index_t* p_max_token_id, @@ -1721,7 +1928,7 @@ struct GridwiseMoeGemm if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id || token0 >= problem.NumTokens) return; - StaticallyIndexedArray + StaticallyIndexedArray gather_offsets; //= p_sorted_token_ids[token_pos]; static_for<0, AMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[token_pos + m0]; @@ -1730,7 +1937,7 @@ struct GridwiseMoeGemm { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - gather_offsets(m0) = token_offset * problem.K; + gather_offsets(m0) = static_cast(token_offset) * problem.K; }); const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K); @@ -1773,6 +1980,7 @@ struct GridwiseMoeGemm 1, AThreadTransferSrcResetCoordinateAfterRun, true, + IndexType, 1, 2>(a_grid_desc_ak0_m_ak1, make_multi_index(0, 0, 0), @@ -1967,11 +2175,12 @@ struct GridwiseMoeGemm const DDataType* ptr_ = p_ds_grid[i]; // hack logic here to support different kind of strides. todo fix it. // ascale t, 1; bscale E, N, 1, move ptr to E - if(i.value == 1) - { - ptr_ += - expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : 1); - } + // if(i.value == 1) + // { + // ptr_ += + // expert_id * (problem.StrideDs[1] ? 
problem.StrideDs[1] * problem.N : + // 1); + // } return make_dynamic_buffer( ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize()); }, @@ -2036,7 +2245,8 @@ struct GridwiseMoeGemm Sequence, uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, 1, // ScatterDim true, // OutputScatter: false, only use scatter weights scatter_weight_idx // ScatterWeightIdx: ascale @@ -2078,12 +2288,9 @@ struct GridwiseMoeGemm constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); - const float* p_sorted_weights_0 = p_ds_grid[I0]; static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS - StaticallyIndexedArray - scatter_offsets; //= p_sorted_token_ids[c_token_pos]; - StaticallyIndexedArray scatter_weights; //= for topk + StaticallyIndexedArray scatter_offsets; auto dstidx = sfc_cde_block.GetIndex(access_id); const index_t c_token_pos = @@ -2091,23 +2298,11 @@ struct GridwiseMoeGemm static_for<0, EMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; index_t token_offset = fused_token & 0xffffff; - float weight = token_offset < problem.NumTokens - ? p_sorted_weights_0[token_offset * problem.StrideDs[0]] - : 0.0; if constexpr(IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - if constexpr(MulRoutedWeight) - { - const float* p_sorted_weights_2 = p_ds_grid[I2]; - if constexpr(sizeof(ADataType) < 2) - weight = p_sorted_weights_2[c_token_pos + m0] * weight; - else - weight = p_sorted_weights_2[c_token_pos + m0]; - } - scatter_offsets(m0) = token_offset * problem.N; - scatter_weights(m0) = weight; + scatter_offsets(m0) = static_cast(token_offset) * problem.N; }); block_sync_lds(); @@ -2128,8 +2323,7 @@ struct GridwiseMoeGemm c_ds_buf_refs, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), tie(c_grid_buf), - scatter_offsets, - scatter_weights); + scatter_offsets); if constexpr(access_id < num_access - 1) { diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp index bb9a452761..bd6fe772e4 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp @@ -41,6 +41,7 @@ template struct ThreadwiseTensorSliceTransfer_v3r1_gather @@ -88,7 +89,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather const DstDesc& dst_desc, const Index& dst_slice_origin, const DstElementwiseOperation& dst_element_op, - const StaticallyIndexedArray& gather_offsets) + const StaticallyIndexedArray& gather_offsets) : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), src_element_op_(src_element_op), @@ -221,7 +222,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather auto gather_offset = gather_offsets_(ordered_src_access_idx[Number{}]); - const index_t ld_offset = src_coord_.GetOffset() + gather_offset; + const IndexType ld_offset = src_coord_.GetOffset() + gather_offset; src_oob_thread_scratch_tuple_(thread_scratch_id) .template SetAsType(src_data_idx_seq, true); @@ -935,7 +936,7 @@ struct 
ThreadwiseTensorSliceTransfer_v3r1_gather DstCoord dst_coord_; const SrcElementwiseOperation src_element_op_; const DstElementwiseOperation dst_element_op_; - StaticallyIndexedArray gather_offsets_; + StaticallyIndexedArray gather_offsets_; }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp index 6a1c195dc1..7cd0a0fc7f 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -43,6 +43,7 @@ template typename DstResetCoordinateAfterRunFlags, // Sequence + typename IndexType, index_t ScatterDim = 1, bool OutputScatter = true, index_t ScatterWeightIdx = 3, @@ -153,7 +154,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter enable_if_t = false> __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs, - StaticallyIndexedArray& scatter_weights, Number thread_scratch_id = Number{}) { // loop over space-filling curve @@ -172,31 +172,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter src_coords_[i]); oob_val = oob_val & is_src_valid; - if(i.value == ScatterWeightIdx) - { - static_assert(SrcScalarPerVectors{}[Number{}] == 1, - "scatter weight dim, should only one vec"); - constexpr auto iScatter = - SrcSpaceFillingCurve::GetIndex(iAccess)(Number{}); - static_for<0, SrcScalarPerVector, 1>{}([&](auto j) { - src_vectors(i).template AsType()(j) = - scatter_weights(Number{}); - }); - } - else if constexpr(SrcScalarPerVectors{}[i] == 1) - { - auto data_types = SrcDatas{}; - using DataType = remove_cvref_t; - const auto tmp = - src_bufs[i].template Get(src_coords_[i].GetOffset(), true); - static_for<0, SrcScalarPerVector, 1>{}( - [&](auto j) { src_vectors(i).template AsType()(j) = tmp; }); - } - else - { - src_vectors(i).template AsType()(I0) = - src_bufs[i].template Get(src_coords_[i].GetOffset(), true); - } + src_vectors(i).template AsType()(I0) = + src_bufs[i].template Get(src_coords_[i].GetOffset(), true); }); constexpr auto get_elem_op_vec_len = []() { @@ -412,7 +389,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter enable_if_t = false> __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs, - StaticallyIndexedArray& scatter_offsets, + StaticallyIndexedArray& scatter_offsets, Number thread_scratch_id = Number{}) { OOBCheck(thread_scratch_id); @@ -420,8 +397,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter // loop over space-filling curve static_for<0, dst_num_access, 1>{}([&](auto iAccess) { - auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess]; - auto scatter_offset = 0; + auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess]; + IndexType scatter_offset = 0; if constexpr(OutputScatter) { constexpr auto iScatter = @@ -431,8 +408,10 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter // copy data from buf_vectors into dst_bufs static_for<0, nDst, 1>{}([&](auto i) { using dst_vector_t = typename remove_cvref_t::type; - auto dst_offset = scatter_offset + dst_coords_[i].GetOffset(); + IndexType dst_offset = scatter_offset + (dst_coords_[i].GetOffset()); const bool is_dst_valid = dst_offset < 
dst_descs[i].GetElementSpaceSize(); + // coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], + // dst_coords_[i]); constexpr InMemoryDataOperationEnum DstInMemOp = static_cast(DstInMemOps::At(i.value)); dst_bufs(i).template Update( @@ -488,10 +467,9 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter const SrcBuffers& src_bufs, const DstDescs& dst_descs, DstBuffers dst_bufs, - StaticallyIndexedArray& scatter_offsets, - StaticallyIndexedArray& scatter_weights) + StaticallyIndexedArray& scatter_offsets) { - RunRead(src_descs, src_bufs, scatter_weights); + RunRead(src_descs, src_bufs); RunWrite(dst_descs, dst_bufs, scatter_offsets); } diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index 1a0ea27eab..1d80f196b5 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -24,7 +24,8 @@ template + AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence, + typename IndexType = index_t> struct DynamicBuffer { using type = T; @@ -59,16 +60,16 @@ struct DynamicBuffer return BufferAddressSpace; } - __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; } + __host__ __device__ constexpr const T& operator[](IndexType i) const { return p_data_[i]; } - __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; } + __host__ __device__ constexpr T& operator()(IndexType i) { return p_data_[i]; } template >::type, typename scalar_type>::type>::value || !is_native_type(), bool>::type = false> - __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const + __host__ __device__ constexpr auto Get(IndexType i, bool is_valid_element) const { // X contains multiple T constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; @@ -79,7 +80,7 @@ struct DynamicBuffer "wrong! X should contain multiple T"); #if CK_USE_AMD_BUFFER_LOAD - bool constexpr use_amd_buffer_addressing = true; + bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t); #else bool constexpr use_amd_buffer_addressing = false; #endif @@ -140,7 +141,7 @@ struct DynamicBuffer typename enable_if>::type, typename scalar_type>::type>::value, bool>::type = false> - __host__ __device__ void Update(index_t i, bool is_valid_element, const X& x) + __host__ __device__ void Update(IndexType i, bool is_valid_element, const X& x) { if constexpr(Op == InMemoryDataOperationEnum::Set) { @@ -191,8 +192,8 @@ struct DynamicBuffer template __host__ __device__ void DirectCopyToLds(DstBuffer& dst_buf, - index_t src_offset, - index_t dst_offset, + IndexType src_offset, + IndexType dst_offset, bool is_valid_element) const { // Copy data from global to LDS memory using direct loads. @@ -214,7 +215,7 @@ struct DynamicBuffer typename scalar_type>::type>::value || !is_native_type(), bool>::type = false> - __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x) + __host__ __device__ void Set(IndexType i, bool is_valid_element, const X& x) { // X contains multiple T constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; @@ -224,8 +225,8 @@ struct DynamicBuffer static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! 
X should contain multiple T"); -#if CK_USE_AMD_BUFFER_STORE - bool constexpr use_amd_buffer_addressing = true; +#if CK_USE_AMD_BUFFER_LOAD + bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t); #else bool constexpr use_amd_buffer_addressing = false; #endif @@ -342,11 +343,12 @@ struct DynamicBuffer { if(is_valid_element) { -#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS +#if 0 X tmp = x; __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X)); #else + // if(i >= 2169041600) *c_style_pointer_cast(&p_data_[i]) = x; #endif } @@ -357,7 +359,7 @@ struct DynamicBuffer typename enable_if>::type, typename scalar_type>::type>::value, bool>::type = false> - __host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& x) + __host__ __device__ void AtomicAdd(IndexType i, bool is_valid_element, const X& x) { using scalar_t = typename scalar_type>::type; @@ -378,12 +380,14 @@ struct DynamicBuffer (is_same_v, half_t> && scalar_per_x_vector % 2 == 0) || (is_same_v, bhalf_t> && scalar_per_x_vector % 2 == 0); #elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT) - bool constexpr use_amd_buffer_addressing = is_same_v, int32_t>; + bool constexpr use_amd_buffer_addressing = + sizeof(IndexType) <= sizeof(int32_t) && is_same_v, int32_t>; #elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT bool constexpr use_amd_buffer_addressing = - is_same_v, float> || - (is_same_v, half_t> && scalar_per_x_vector % 2 == 0) || - (is_same_v, bhalf_t> && scalar_per_x_vector % 2 == 0); + sizeof(IndexType) <= sizeof(int32_t) && + (is_same_v, float> || + (is_same_v, half_t> && scalar_per_x_vector % 2 == 0) || + (is_same_v, bhalf_t> && scalar_per_x_vector % 2 == 0)); #else bool constexpr use_amd_buffer_addressing = false; #endif @@ -408,12 +412,12 @@ struct DynamicBuffer typename enable_if>::type, typename scalar_type>::type>::value, bool>::type = false> - __host__ __device__ void AtomicMax(index_t i, bool is_valid_element, const X& x) + __host__ __device__ void AtomicMax(IndexType i, bool is_valid_element, const X& x) { // X contains multiple T - constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + constexpr IndexType scalar_per_t_vector = scalar_type>::vector_size; - constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + constexpr IndexType scalar_per_x_vector = scalar_type>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! 
X should contain multiple T"); @@ -421,8 +425,9 @@ struct DynamicBuffer static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem"); #if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 - using scalar_t = typename scalar_type>::type; - bool constexpr use_amd_buffer_addressing = is_same_v, double>; + using scalar_t = typename scalar_type>::type; + bool constexpr use_amd_buffer_addressing = + sizeof(IndexType) <= sizeof(int32_t) && is_same_v, double>; #else bool constexpr use_amd_buffer_addressing = false; #endif @@ -455,6 +460,17 @@ __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize el p, element_space_size}; } +template +__host__ __device__ constexpr auto make_long_dynamic_buffer(T* p, + ElementSpaceSize element_space_size) +{ + return DynamicBuffer{ + p, element_space_size}; +} + template < AddressSpaceEnum BufferAddressSpace, AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence, diff --git a/include/ck/utility/tuple_helper.hpp b/include/ck/utility/tuple_helper.hpp index b1a0c1fc5d..ec055fb2a2 100644 --- a/include/ck/utility/tuple_helper.hpp +++ b/include/ck/utility/tuple_helper.hpp @@ -23,6 +23,13 @@ __host__ __device__ constexpr auto generate_tuple(F&& f, Number) return generate_tuple_for(f, make_index_sequence{}); } +template +__host__ __device__ constexpr auto generate_tuple(F&& f, LongNumber) +{ + return unpack([&f](auto&&... xs) { return make_tuple(f(xs)...); }, + typename arithmetic_sequence_gen<0, N, 1>::type{}); +} + template __host__ __device__ constexpr auto generate_tie(F&& f, Number) { diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp index 72c9dc86ac..120bf7484a 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp @@ -23,12 +23,14 @@ template + index_t ActivationType_ = 0, + bool MulRoutedWeight = true, + typename ComputeTypeA = CDataType, + typename ComputeTypeB = ComputeTypeA> struct ReferenceMoeGemm : public device::BaseOperator { // Argument + static constexpr auto ActivationType = ActivationType_; struct Argument : public device::BaseArgument { Argument(const Tensor& sorted_token_ids, @@ -36,7 +38,9 @@ struct ReferenceMoeGemm : public device::BaseOperator const Tensor& max_token_id, const index_t sorted_tile_size, const Tensor& a_t_k, + const Tensor& a_scale_t, const Tensor& b_e_n_k, + const Tensor& b_scale_e_n, Tensor& c_t_k_n, const Tensor& d2, AElementwiseOperation a_element_op, @@ -47,7 +51,9 @@ struct ReferenceMoeGemm : public device::BaseOperator max_token_id_{max_token_id}, sorted_tile_size_{sorted_tile_size}, a_t_k_{a_t_k}, + a_scale_t_{a_scale_t}, b_e_n_k_{b_e_n_k}, + b_scale_e_n_{b_scale_e_n}, c_t_k_n_{c_t_k_n}, d2_{d2}, a_element_op_{a_element_op}, @@ -61,7 +67,9 @@ struct ReferenceMoeGemm : public device::BaseOperator const Tensor& max_token_id_; index_t sorted_tile_size_; const Tensor& a_t_k_; + const Tensor& a_scale_t_; const Tensor& b_e_n_k_; + const Tensor& b_scale_e_n_; Tensor& c_t_k_n_; const Tensor& d2_; @@ -77,11 +85,17 @@ struct ReferenceMoeGemm : public device::BaseOperator float Run(const Argument& arg) { - auto f_mk_kn_mn = [&](auto m, auto n) { + static_assert(ActivationType < 2, "Not supported activation type"); + const int full_n = arg.c_t_k_n_.mDesc.GetLengths()[2]; + auto f_mk_kn_mn = [&](auto m, auto n) { const int K = 
arg.a_t_k_.mDesc.GetLengths()[1]; + AccDataType v_acc_up{0}; + ComputeTypeB v_b_up{0}; AccDataType v_acc{0}; + ComputeTypeA v_a{0}; ComputeTypeB v_b{0}; + const int t = arg.sorted_token_ids_(m) & 0xffffff; const int topk_id = (arg.sorted_token_ids_(m) & 0xff000000) >> 24; const int e = arg.expert_ids_(m / arg.sorted_tile_size_); @@ -102,7 +116,7 @@ struct ReferenceMoeGemm : public device::BaseOperator #if CK_USE_PK4_LAYOUT_SHUFFLE v_a = i4_to_f32_gfx9(i4); #else - v_a = i4 - 8; + v_a = i4 - 8; #endif } else @@ -112,42 +126,79 @@ struct ReferenceMoeGemm : public device::BaseOperator // same for B matrix if constexpr(is_same_v) { - uint8_t i4x2 = arg.b_e_n_k_(e, k, n).data; - uint8_t i4 = 0; + uint8_t i4x2 = arg.b_e_n_k_(e, k, n).data; + uint8_t i4x2_up = arg.b_e_n_k_(e, k, n + full_n).data; + uint8_t i4 = 0; + uint8_t i4_up = 0; if(k % 2 == 1) - i4 = (i4x2 >> 0) & 0xf; + { + i4 = (i4x2 >> 0) & 0xf; + i4_up = (i4x2_up >> 0) & 0xf; + } else - i4 = (i4x2 >> 4) & 0xf; + { + i4 = (i4x2 >> 4) & 0xf; + i4_up = (i4x2_up >> 4) & 0xf; + } #if CK_USE_PK4_LAYOUT_SHUFFLE - v_b = i4_to_f32_gfx9(i4); + v_b = i4_to_f32_gfx9(i4); + v_b_up = i4_to_f32_gfx9(i4_up); #else - v_b = i4 - 8; + v_b = i4 - 8; + v_b_up = i4_up - 8; #endif } else { arg.b_element_op_(v_b, arg.b_e_n_k_(e, k, n)); + arg.b_element_op_(v_b_up, arg.b_e_n_k_(e, k, n + full_n)); } v_acc += ck::type_convert(v_a) * ck::type_convert(v_b); + v_acc_up += ck::type_convert(v_a) * + ck::type_convert(v_b_up); } CDataType v_c{0}; - + CDataType v_c_up{0}; if constexpr(MulRoutedWeight) { v_acc *= v_topk_w; + v_acc_up *= v_topk_w; } arg.c_element_op_(v_c, v_acc); + arg.c_element_op_(v_c_up, v_acc_up); - arg.c_t_k_n_(t, topk_id, n) = v_c; + if constexpr(ActivationType == 1) + { + v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t); + if constexpr(is_same_v) + { + v_c_up *= 16; + v_c *= 16; + } + tensor_operation::element_wise::Silu{}(v_c, v_c); + v_c_up = v_c_up * arg.b_scale_e_n_(e, n + full_n) * arg.a_scale_t_(t); + arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up; + } + else if constexpr(ActivationType == 0) + { + v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t); + if constexpr(is_same_v) + { + v_c_up *= 16; + v_c *= 16; + } + tensor_operation::element_wise::Gelu{}(v_c, v_c); + v_c_up = v_c_up * arg.b_scale_e_n_(e, n + full_n) * arg.a_scale_t_(t); + arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up; + } } }; const ck::index_t max_token_id = arg.max_token_id_(0); - make_ParallelTensorFunctor( - f_mk_kn_mn, max_token_id, arg.c_t_k_n_.mDesc.GetLengths()[2])( + make_ParallelTensorFunctor(f_mk_kn_mn, max_token_id, full_n)( std::thread::hardware_concurrency()); return 0; @@ -173,7 +224,9 @@ struct ReferenceMoeGemm : public device::BaseOperator const Tensor& max_token_id, const index_t sorted_tile_size, const Tensor& a_t_k, + const Tensor& a_scale_n, const Tensor& b_e_n_k, + const Tensor& b_scale_e_n, Tensor& c_t_k_n, const Tensor& d2, AElementwiseOperation a_element_op, @@ -185,7 +238,9 @@ struct ReferenceMoeGemm : public device::BaseOperator max_token_id, sorted_tile_size, a_t_k, + a_scale_n, b_e_n_k, + b_scale_e_n, c_t_k_n, d2, a_element_op, diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp index fb5c71e30a..5c932fcb18 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp @@ -25,7 +25,7 @@ template struct 
ReferenceMoeGemm2 : public device::BaseOperator
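(For orientation, the per-element epilogue implemented by the fused gate/up (IsInputGemm) path above — both in the reference and in gridwise_moe_gemm.hpp — is roughly: scale the gate and up accumulators, optionally fold in the routed top-k weight, apply the activation to the gate half, and multiply. The sketch below is a stand-alone model with hypothetical inputs, showing the silu_and_mul case only; gelu_and_mul swaps in Gelu, and the pk_i4 path additionally rescales both halves by 16 before the activation.)

#include <cmath>
#include <cstdio>

// Stand-alone model of the silu_and_mul epilogue of the fused gate/up path:
// scale both accumulators, optionally apply the routed top-k weight to both
// halves (as the kernel does), activate the gate, multiply by up.
// All names and values are illustrative, not CK symbols.
static float silu(float x) { return x / (1.f + std::exp(-x)); }

static float moe_gemm1_epilogue(float acc_gate,     // gate-half MFMA accumulator
                                float acc_up,       // up-half MFMA accumulator
                                float a_scale,      // per-token activation scale
                                float b_scale_gate, // per-channel weight scale, gate half
                                float b_scale_up,   // per-channel weight scale, up half
                                float topk_weight)  // routed weight; 1.0 when MulRoutedWeight is false
{
    float gate = a_scale * b_scale_gate * acc_gate * topk_weight;
    float up   = a_scale * b_scale_up   * acc_up   * topk_weight;
    gate       = silu(gate);
    return gate * up; // written back through the CShuffle/scatter path
}

int main()
{
    std::printf("%f\n", moe_gemm1_epilogue(1.25f, -0.5f, 0.1f, 0.2f, 0.3f, 1.f));
    return 0;
}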