From a8a82e0cfc63dc42ee0aa28871fe10bb070a935b Mon Sep 17 00:00:00 2001 From: coderfeli Date: Tue, 11 Feb 2025 01:54:08 +0000 Subject: [PATCH] fix warnings and impl scale for gemm2, build ok --- .../65_gemm_multiply_multiply/moe_gemm1.cpp | 5 +- .../65_gemm_multiply_multiply/moe_gemm2.cpp | 88 ++++++++++--------- .../gpu/device/impl/device_moe_gemm.hpp | 2 +- ...wise_tensor_slice_transfer_v3r1_gather.hpp | 2 +- ...ise_tensor_slice_transfer_v7r3_scatter.hpp | 2 +- .../cpu/reference_moe_gemm2.hpp | 27 ++++-- 6 files changed, 76 insertions(+), 50 deletions(-) diff --git a/example/65_gemm_multiply_multiply/moe_gemm1.cpp b/example/65_gemm_multiply_multiply/moe_gemm1.cpp index 2ae610ffea..5505d48ea4 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1.cpp @@ -152,6 +152,9 @@ static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; //todo fix this const static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); static constexpr ck::index_t EVec = 16 / sizeof(EDataType); +static constexpr ck::index_t D0Vec = 1; +static constexpr ck::index_t D1Vec = 1; +static constexpr ck::index_t D2Vec = 1; // using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off @@ -181,7 +184,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - 1, 1, S<1, 32, 1, 8>, S, + 1, 1, S<1, 32, 1, 8>, S, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, true, A0DataType>; // kernel 2: 128->32x128x128 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; diff --git a/example/65_gemm_multiply_multiply/moe_gemm2.cpp b/example/65_gemm_multiply_multiply/moe_gemm2.cpp index 791efb1b1e..5c32431c0b 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2.cpp @@ -40,9 +40,8 @@ using AccDataType = F32; using CShuffleDataType = F32; using D0DataType = F32; using D1DataType = F32; -using D2DataType = EDataType; -// using DsDataTypeGate = ck::Tuple; -using DsDataTypeUp = ck::Tuple; +using D2DataType = F32; +using DsDataType = ck::Tuple; using A0Layout = Row; using B0Layout = Col; @@ -51,35 +50,39 @@ using D0Layout = Row; using D1Layout = Col; using D2Layout = ELayout; // using DsLayoutGate = ck::Tuple; -using DsLayoutUp = ck::Tuple; +using DsLayout = ck::Tuple; -struct MultiplyMultiply +// d0: ascale, d1: bscale, d2:expert weight +struct MulABScaleExpertWeight { template __host__ __device__ constexpr void operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; - + //gpu template <> - __host__ __device__ constexpr void operator() + __host__ __device__ constexpr void operator() (EDataType& e, const float& c, const float& d0, const float& d1, - const D2DataType& d2) const + const float& d2) const { - // const float x0_f = c * d0 * d1; - (void)d0; (void)d1; (void)d2; - const float x0_f = c; - e = ck::type_convert(x0_f); + e = ck::type_convert(c * d0 * d1 * d2); + } + // for reference + template <> + __host__ __device__ constexpr void operator() + (float& e, + const float& c, + const float& d0, + const float& d1, + const float& d2) const + { + e = ck::type_convert(c * d0 * d1 * d2); } - }; -// using DsLayout = DsLayoutGate; -// using DsDataType = DsDataTypeGate; -using DsLayout = DsLayoutUp; -using DsDataType = DsDataTypeUp; -using CDEElementOp = MultiplyMultiply; +using CDEElementOp = MulABScaleExpertWeight; void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl) { @@ -115,7 +118,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; using AElementOp = PassThrough; using BElementOp = PassThrough; -using CDEElementOp = MultiplyMultiply; +using CDEElementOp = MulABScaleExpertWeight; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr ck::index_t MPerBlock = 32; @@ -126,6 +129,9 @@ static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); static constexpr ck::index_t EVec = 16 / sizeof(EDataType); +static constexpr ck::index_t D0Vec = 1; +static constexpr ck::index_t D1Vec = 1; +static constexpr ck::index_t D2Vec = 1; // using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off @@ -155,7 +161,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - CShuffleMXDLPerWave, 1, S<1, 16, 1, 16>, S, + CShuffleMXDLPerWave, 1, S<1, 16, 1, 16>, S, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, A0DataType>; // kernel 2: 128->32x128x128 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; @@ -232,16 +238,16 @@ int main(int argc, char* argv[]) Tensor a0_m_k(HostTensorDescriptor({SORTED_SIZE, K}, {K, 1})); Tensor b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1})); Tensor b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1})); - Tensor d0_t_n(HostTensorDescriptor({N, 1}, {1, 0})); - Tensor d1_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1})); - Tensor d2_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1})); + Tensor d0_t_n(HostTensorDescriptor({SORTED_SIZE, N}, {0, 0})); + Tensor d1_e_n(HostTensorDescriptor({experts, N}, {0, 0})); + Tensor d2_e_n(HostTensorDescriptor({experts, 1}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); e_t_n_device_result.SetZero(); std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; - std::cout << "d2_m_n: " << d2_m_n.mDesc << std::endl; - std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl; std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; @@ -252,38 +258,38 @@ int main(int argc, char* argv[]) a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{0, 2}); d0_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d1_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d2_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d1_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d2_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); break; case 2: a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); - d1_m_n.GenerateTensorValue(GeneratorTensor_1{}); - d2_m_n.GenerateTensorValue(GeneratorTensor_1{}); + d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); break; default: a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - d2_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize()); DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize()); - DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); - DeviceMem d2_device_buf(sizeof(D2DataType) * d2_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); a0_m_k.savetxt("a.txt"); sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data()); a0_device_buf.ToDevice(a0_m_k.mData.data()); d0_device_buf.ToDevice(d0_t_n.mData.data()); - d1_device_buf.ToDevice(d1_m_n.mData.data()); - d2_device_buf.ToDevice(d2_m_n.mData.data()); + d1_device_buf.ToDevice(d1_e_n.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); e_device_buf.ToDevice(e_t_n_device_result.mData.data()); auto a_element_op = AElementOp{}; @@ -358,26 +364,26 @@ int main(int argc, char* argv[]) using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm2; + CDEElementOp>; auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); - auto ref_argument = ref_moe_gemm.MakeArgument( - sorted_token_ids, expert_ids, sorted_tile_size, a0_m_k, b0_e_n_k, c_t_n, PassThrough{}, PassThrough{}, PassThrough{}); + sorted_token_ids, expert_ids, sorted_tile_size, a0_m_k, b0_e_n_k, d0_t_n, d1_e_n, d2_e_n, c_t_n, PassThrough{}, PassThrough{}, cde_element_op); ref_invoker.Run(ref_argument); for(int t = 0; t < tokens; ++t) { - - // const int t = sorted_token_ids(m); for(int n = 0; n < N; ++n) { - cde_element_op(e_t_n_host_result(t, n), c_t_n(t, n), d0_t_n(t, n), d1_m_n(t, n), d2_m_n(t, n)); + e_t_n_host_result(t, n) = ck::type_convert(c_t_n(t, n)); } } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp index b4d028fc21..570cd904ec 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp @@ -535,7 +535,7 @@ struct DeviceMoeGemm index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) + CElementwiseOperation c_element_op) override { // assert(0, "no impl"); return std::make_unique(nullptr, nullptr, diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp index d580a39647..f167fe6212 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp @@ -901,9 +901,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather SrcCoord src_coord_; DstCoord dst_coord_; - StaticallyIndexedArray gather_offsets_; const SrcElementwiseOperation src_element_op_; const DstElementwiseOperation dst_element_op_; + StaticallyIndexedArray gather_offsets_; }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp index 480fbc5ff0..2b51f48838 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -687,10 +687,10 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter using OOBVectorTuple = StaticallyIndexedArray; StaticallyIndexedArray oob_vectors_tuple_; - StaticallyIndexedArray scatter_offsets_; SrcCoords src_coords_; DstCoords dst_coords_; const ElementwiseOperation element_op_; + StaticallyIndexedArray scatter_offsets_; }; } // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp index 8633f061b1..22cbe55c98 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp @@ -17,6 +17,9 @@ namespace host { template & a_m_k, const Tensor& b_e_n_k, + const Tensor& d0, + const Tensor& d1, + const Tensor& d2, Tensor& c_t_n, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, @@ -42,6 +48,9 @@ struct ReferenceMoeGemm2 : public device::BaseOperator sorted_tile_size_{sorted_tile_size}, a_m_k_{a_m_k}, b_e_n_k_{b_e_n_k}, + d0_{d0}, + d1_{d1}, + d2_{d2}, c_t_n_{c_t_n}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, @@ -49,16 +58,19 @@ struct ReferenceMoeGemm2 : public device::BaseOperator { } - const Tensor& expert_ids_; const Tensor& sorted_token_ids_; + const Tensor& expert_ids_; + index_t sorted_tile_size_; const Tensor& a_m_k_; const Tensor& b_e_n_k_; + const Tensor& d0_; + const Tensor& d1_; + const Tensor& d2_; Tensor& c_t_n_; AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; CElementwiseOperation c_element_op_; - index_t sorted_tile_size_; }; // Invoker @@ -106,8 +118,10 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ck::type_convert(v_a) * ck::type_convert(v_b); } CDataType v_c{0}; - - arg.c_element_op_(v_c, v_acc); + D0DataType v_d0 = arg.d0_(m, n); // a + D0DataType v_d1 = arg.d1_(e, n); // b + D0DataType v_d2 = arg.d2_(e, 0); //expert + arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_d2); arg.c_t_n_(t, n) += v_c; } @@ -140,12 +154,15 @@ struct ReferenceMoeGemm2 : public device::BaseOperator const index_t sorted_tile_size, const Tensor& a_m_k, const Tensor& b_e_n_k, + const Tensor& d0, + const Tensor& d1, + const Tensor& d2, Tensor& c_t_n, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) { - return Argument{sorted_token_ids, expert_ids, sorted_tile_size, a_m_k, b_e_n_k, c_t_n, a_element_op, b_element_op, c_element_op}; + return Argument{sorted_token_ids, expert_ids, sorted_tile_size, a_m_k, b_e_n_k, d0, d1, d2, c_t_n, a_element_op, b_element_op, c_element_op}; } static auto MakeInvoker() { return Invoker{}; }