diff --git a/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage1_v1_32x128x128.co b/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage1_v1_32x128x128_gelu.co similarity index 100% rename from example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage1_v1_32x128x128.co rename to example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage1_v1_32x128x128_gelu.co diff --git a/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage1_v1_32x128x128_silu.co b/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage1_v1_32x128x128_silu.co new file mode 100644 index 0000000000..50369a50a1 Binary files /dev/null and b/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage1_v1_32x128x128_silu.co differ diff --git a/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage1_v1_64x128x128.co b/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage1_v1_64x128x128_gelu.co similarity index 100% rename from example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage1_v1_64x128x128.co rename to example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage1_v1_64x128x128_gelu.co diff --git a/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage1_v1_64x128x128_silu.co b/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage1_v1_64x128x128_silu.co new file mode 100644 index 0000000000..ba4fa2221b Binary files /dev/null and b/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage1_v1_64x128x128_silu.co differ diff --git a/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage1_v3_64x128x128.co b/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage1_v3_64x128x128_gelu.co similarity index 100% rename from example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage1_v3_64x128x128.co rename to example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage1_v3_64x128x128_gelu.co diff --git a/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage1_v3_64x128x128_silu.co b/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage1_v3_64x128x128_silu.co new file mode 100644 index 0000000000..34b0f009ae Binary files /dev/null and b/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage1_v3_64x128x128_silu.co differ diff --git a/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v1_128x128x128.co b/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v1_128x128x128.co old mode 100755 new mode 100644 index 83aaa89665..73a70a6d05 Binary files a/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v1_128x128x128.co and b/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v1_128x128x128.co differ diff --git a/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v1_32x128x128.co b/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v1_32x128x128.co old mode 100755 new mode 100644 index b15bc9b360..9dcd7a4e64 Binary files a/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v1_32x128x128.co and b/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v1_32x128x128.co differ diff --git a/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v3_128x128x128.co b/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v3_128x128x128.co old mode 100755 new mode 100644 index 71759b072e..648d7ee52f Binary files a/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v3_128x128x128.co and b/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v3_128x128x128.co differ diff --git a/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v3_64x128x128.co b/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v3_64x128x128.co deleted file mode 100755 index 849c3fa06d..0000000000 Binary files a/example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v3_64x128x128.co and /dev/null differ diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp index 7e5099fd53..fee16b5472 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp @@ -38,7 +38,7 @@ using B0DataType = F8; using B1DataType = F32; using EDataType = F16; using AccDataType = F32; -using CShuffleDataType = EDataType; +using CShuffleDataType = F32; using D2DataType = F32; using DsDataType = ck::Tuple; @@ -124,10 +124,10 @@ static constexpr ck::index_t Scale_Block_N = 128; static constexpr ck::index_t Scale_Block_K = 128; static constexpr ck::index_t Nswizzle = false; -static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul +static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul static constexpr bool MulRoutedWeight = false; -#if 0 +#if 1 static constexpr ck::index_t MPerBlock = 32; static constexpr ck::index_t NPerBlock = 128; static constexpr ck::index_t MNPerXDL = 16; @@ -179,7 +179,7 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>; #endif // clang-format on @@ -201,7 +201,7 @@ int main(int argc, char* argv[]) // ck::index_t valid_tile_num = 13; ck::index_t sorted_tile_num = 259; ck::index_t valid_tile_num = 256; - ck::index_t tokens = 8192; + ck::index_t tokens = 4096; #else // deepseek ck::index_t N = 2048; diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp index e5471fc259..dd6f58f678 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp @@ -39,7 +39,7 @@ using B0DataType = F8; using B1DataType = F32; using EDataType = F16; using AccDataType = F32; -using CShuffleDataType = EDataType; +using CShuffleDataType = F32; //todo: change to EDataType using D2DataType = F32; using DsDataType = ck::Tuple; @@ -58,29 +58,27 @@ struct MulABScaleExpertWeight template __host__ __device__ constexpr void operator()(E& e, const C& c, const D2& d2) const; // for real kernel use + + template <> + __host__ __device__ constexpr void + operator()(EDataType& e, const EDataType& c, const float& d2) const + { + (void) d2; + e = ck::type_convert(c); + } template <> __host__ __device__ constexpr void operator()(EDataType& e, const float& c, const float& d2) const { - // for real kernel use - (void)d2; - e = ck::type_convert(c); - } - template <> - __host__ __device__ constexpr void - operator()(EDataType& e, const EDataType& c, const float& d2) const - { - (void)d2; - e = ck::type_convert(c); - } // for reference cpu - template <> + e = ck::type_convert(c* d2); + } + template <> __host__ __device__ constexpr void operator()(float& e, const float& c, const float& d2) const { - // for reference cpu - (void)d2; - e = ck::type_convert(c); + // for reference cpu + e = ck::type_convert(c* d2); } }; @@ -158,10 +156,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, 2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, false, int32_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, true, int32_t, A0DataType>; #else -static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< +static constexpr ck::index_t MPerBlock = 128; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< Row, Col, DsLayout, ELayout, A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, @@ -169,11 +167,11 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor MPerBlock, 128, 128, 16, 16, 16, 16, - 4, 2, + 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, false, int32_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, true, int32_t, A0DataType>; #endif // clang-format on @@ -484,7 +482,7 @@ int main(int argc, char* argv[]) using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm2BlockScale= 3 ? 1 : 0; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp index 2f8e854514..47114805be 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp @@ -355,6 +355,10 @@ struct DeviceMoeGemmBlockScale #if CK_USE_ASM_MOE_BLOCKSCALE (void)minimum_occupancy; (void)MemoryDataOp; + //do_weight stage check + if (MulRoutedWeight == IsInputGemm){ + throw std::runtime_error("MOE_BS_ASM Faild: Only gemm2 can do weight.\n"); + } // get .co file name for ASM. select by version and shape. std::string hsa_name = ""; if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) @@ -371,7 +375,17 @@ struct DeviceMoeGemmBlockScale } else { - printf("Faild: Gemm2 only support 32x128x128 or 64x128x1288.\n"); + throw std::runtime_error("MOE_BS_ASM Faild: Gemm2 only support 32x128x128 or 64x128x1288.\n"); + } + if constexpr(ActivationOP == Activation::silu_and_mul){ + hsa_name += "_silu"; + } + else if constexpr(ActivationOP == Activation::gelu_and_mul){ + hsa_name += "_gelu"; + } + else + { + throw std::runtime_error("MOE_BS_ASM Faild: Gemm1 ACT only support silu or gelu.\n"); } } else @@ -386,7 +400,7 @@ struct DeviceMoeGemmBlockScale } else { - printf("Faild: Gemm2 only support 32x128x128 or 128x128x1288.\n"); + throw std::runtime_error("MOE_BS_ASM Faild: Gemm2 only support 32x128x128 or 128x128x1288.\n"); } } } @@ -400,7 +414,17 @@ struct DeviceMoeGemmBlockScale } else { - printf("Faild: v3 only support 64x128x1288.\n"); + throw std::runtime_error("MOE_BS_ASM Faild: v3 only support 64x128x1288.\n"); + } + if constexpr(ActivationOP == Activation::silu_and_mul){ + hsa_name += "_silu"; + } + else if constexpr(ActivationOP == Activation::gelu_and_mul){ + hsa_name += "_gelu"; + } + else + { + throw std::runtime_error("MOE_BS_ASM Faild: Gemm1 ACT only support silu or gelu.\n"); } } else @@ -409,19 +433,19 @@ struct DeviceMoeGemmBlockScale { hsa_name = std::string("moe_bs_stage2_v3_128x128x128"); } - else if constexpr(MPerBlock == 64) - { - hsa_name = std::string("moe_bs_stage2_v3_64x128x128"); - } + // else if constexpr(MPerBlock == 64) + // { + // hsa_name = std::string("moe_bs_stage2_v3_64x128x128"); + // } else { - printf("Faild: v3 only support 128x128x1288 or 64x128x1288.\n"); + throw std::runtime_error("MOE_BS_ASM Faild: v3 only support 128x128x128.\n"); } } } else { - printf("Faild: only support v1 or v3.\n"); + throw std::runtime_error("MOE_BS_ASM Faild: only support v1 or v3.\n"); } // launch kernel if(has_main_k_block_loop) diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp index 11294f556e..4cbc9c16f6 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -288,14 +288,12 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter oob_val = oob_val & is_src_valid; if(i.value == ScatterWeightIdx) { - auto data_types = SrcDatas{}; - using DataType = remove_cvref_t; static_assert(SrcScalarPerVectors{}[Number{}] == 1, "scatter weight dim, should only one vec"); constexpr auto iScatter = SrcSpaceFillingCurve::GetIndex(iAccess)(Number{}); static_for<0, SrcScalarPerVector, 1>{}([&](auto j) { - src_vectors(i).template AsType()(j) = + src_vectors(i).template AsType()(j) = scatter_weights(Number{}); }); }