From cd12af75af32d06183c2aaa26344b11fa0ad9f3d Mon Sep 17 00:00:00 2001 From: coderfeli Date: Wed, 26 Mar 2025 06:50:21 +0000 Subject: [PATCH] support bf16 cshuffle --- .../moe_gemm2_xdl_fp8.cpp | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp index 8fb6fbc742..002497e9d2 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp @@ -36,7 +36,7 @@ using A0DataType = F8; using B0DataType = F8; using EDataType = F16; using AccDataType = F32; -using CShuffleDataType = F32; +using CShuffleDataType = F16; using D0DataType = F32; using D1DataType = F32; using D2DataType = F32; @@ -67,6 +67,15 @@ struct MulABScaleExpertWeight (void)d2; e = ck::type_convert(c); } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } // for reference cpu template <> __host__ __device__ constexpr void operator()( @@ -408,7 +417,7 @@ int main(int argc, char* argv[]) e_device_buf.ToDevice(e_t_n_device_result.mData.data()); invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); - Tensor c_t_n({tokens, N}); + Tensor c_t_n({tokens, N}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm2