Fix v2 topk_weight calculation. Add silu asm.

This commit is contained in:
OscarXu
2025-05-20 13:42:06 +08:00
parent f87973a4ac
commit 9fdfff82ea
15 changed files with 60 additions and 40 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -38,7 +38,7 @@ using B0DataType = F8;
using B1DataType = F32;
using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = EDataType;
using CShuffleDataType = F32;
using D2DataType = F32;
using DsDataType = ck::Tuple<D2DataType>;
@@ -124,10 +124,10 @@ static constexpr ck::index_t Scale_Block_N = 128;
static constexpr ck::index_t Scale_Block_K = 128;
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul
static constexpr bool MulRoutedWeight = false;
#if 0
#if 1
static constexpr ck::index_t MPerBlock = 32;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 16;
@@ -179,7 +179,7 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>;
#endif
// clang-format on
@@ -201,7 +201,7 @@ int main(int argc, char* argv[])
// ck::index_t valid_tile_num = 13;
ck::index_t sorted_tile_num = 259;
ck::index_t valid_tile_num = 256;
ck::index_t tokens = 8192;
ck::index_t tokens = 4096;
#else
// deepseek
ck::index_t N = 2048;

View File

@@ -39,7 +39,7 @@ using B0DataType = F8;
using B1DataType = F32;
using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = EDataType;
using CShuffleDataType = F32; //todo: change to EDataType
using D2DataType = F32;
using DsDataType = ck::Tuple<D2DataType>;
@@ -58,29 +58,27 @@ struct MulABScaleExpertWeight
template <typename E, typename C, typename D2>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D2& d2) const;
// for real kernel use
template <>
__host__ __device__ constexpr void
operator()<EDataType, EDataType, float>(EDataType& e, const EDataType& c, const float& d2) const
{
(void) d2;
e = ck::type_convert<EDataType>(c);
}
template <>
__host__ __device__ constexpr void
operator()<EDataType, float, float>(EDataType& e, const float& c, const float& d2) const
{
// for real kernel use
(void)d2;
e = ck::type_convert<EDataType>(c);
}
template <>
__host__ __device__ constexpr void
operator()<EDataType, EDataType, float>(EDataType& e, const EDataType& c, const float& d2) const
{
(void)d2;
e = ck::type_convert<EDataType>(c);
}
// for reference cpu
template <>
e = ck::type_convert<EDataType>(c* d2);
}
template <>
__host__ __device__ constexpr void
operator()<float, float, float>(float& e, const float& c, const float& d2) const
{
// for reference cpu
(void)d2;
e = ck::type_convert<EDataType>(c);
// for reference cpu
e = ck::type_convert<EDataType>(c* d2);
}
};
@@ -158,10 +156,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, false, int32_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, true, int32_t, A0DataType>;
#else
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
static constexpr ck::index_t MPerBlock = 128; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
Row, Col, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
@@ -169,11 +167,11 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor
MPerBlock, 128, 128,
16, 16,
16, 16,
4, 2,
4, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, false, int32_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, true, int32_t, A0DataType>;
#endif
// clang-format on
@@ -484,7 +482,7 @@ int main(int argc, char* argv[])
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeGemm2BlockScale<float,
float,
CShuffleDataType,
float,
D2DataType,
AccDataType,
PassThrough,

View File

@@ -254,9 +254,9 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
constexpr auto buffer_load_issue_point_b = 0;
constexpr auto buffer_load_issue_point_interval_more =
num_mfma_perstage / buffer_load_perstage_more;
num_mfma_perstage / buffer_load_perstage_more ? num_mfma_perstage / buffer_load_perstage_more : 1;
constexpr auto buffer_load_issue_point_interval_less =
num_mfma_perstage / buffer_load_perstage_less;
num_mfma_perstage / buffer_load_perstage_less ? num_mfma_perstage / buffer_load_perstage_less : 1;
constexpr auto ds_write_issue_point = 0;
constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0;

View File

@@ -355,6 +355,10 @@ struct DeviceMoeGemmBlockScale
#if CK_USE_ASM_MOE_BLOCKSCALE
(void)minimum_occupancy;
(void)MemoryDataOp;
// do_weight stage check: the ASM path rejects any instantiation in which
// MulRoutedWeight equals IsInputGemm, i.e. both flags true or both flags false.
// NOTE(review): the message only describes the both-true case (routed weight
// requested on the input gemm); the both-false case also throws here —
// presumably the ASM gemm2 kernels always apply the routed weight; confirm.
if (MulRoutedWeight == IsInputGemm){
throw std::runtime_error("MOE_BS_ASM Faild: Only gemm2 can do weight.\n");
}
// get .co file name for ASM. select by version and shape.
std::string hsa_name = "";
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
@@ -371,7 +375,17 @@ struct DeviceMoeGemmBlockScale
}
else
{
printf("Faild: Gemm2 only support 32x128x128 or 64x128x1288.\n");
throw std::runtime_error("MOE_BS_ASM Faild: Gemm2 only support 32x128x128 or 64x128x1288.\n");
}
// Append the activation-specific suffix to hsa_name, selected at compile time
// from ActivationOP. Only silu_and_mul ("_silu") and gelu_and_mul ("_gelu")
// are accepted; any other activation throws std::runtime_error.
if constexpr(ActivationOP == Activation::silu_and_mul){
hsa_name += "_silu";
}
else if constexpr(ActivationOP == Activation::gelu_and_mul){
hsa_name += "_gelu";
}
else
{
throw std::runtime_error("MOE_BS_ASM Faild: Gemm1 ACT only support silu or gelu.\n");
}
}
else
@@ -386,7 +400,7 @@ struct DeviceMoeGemmBlockScale
}
else
{
printf("Faild: Gemm2 only support 32x128x128 or 128x128x1288.\n");
throw std::runtime_error("MOE_BS_ASM Faild: Gemm2 only support 32x128x128 or 128x128x1288.\n");
}
}
}
@@ -400,7 +414,17 @@ struct DeviceMoeGemmBlockScale
}
else
{
printf("Faild: v3 only support 64x128x1288.\n");
throw std::runtime_error("MOE_BS_ASM Faild: v3 only support 64x128x1288.\n");
}
// Append the activation-specific suffix to hsa_name, selected at compile time
// from ActivationOP. Only silu_and_mul ("_silu") and gelu_and_mul ("_gelu")
// are accepted; any other activation throws std::runtime_error.
if constexpr(ActivationOP == Activation::silu_and_mul){
hsa_name += "_silu";
}
else if constexpr(ActivationOP == Activation::gelu_and_mul){
hsa_name += "_gelu";
}
else
{
throw std::runtime_error("MOE_BS_ASM Faild: Gemm1 ACT only support silu or gelu.\n");
}
}
else
@@ -409,19 +433,19 @@ struct DeviceMoeGemmBlockScale
{
hsa_name = std::string("moe_bs_stage2_v3_128x128x128");
}
else if constexpr(MPerBlock == 64)
{
hsa_name = std::string("moe_bs_stage2_v3_64x128x128");
}
// else if constexpr(MPerBlock == 64)
// {
// hsa_name = std::string("moe_bs_stage2_v3_64x128x128");
// }
else
{
printf("Faild: v3 only support 128x128x1288 or 64x128x1288.\n");
throw std::runtime_error("MOE_BS_ASM Faild: v3 only support 128x128x128.\n");
}
}
}
else
{
printf("Faild: only support v1 or v3.\n");
throw std::runtime_error("MOE_BS_ASM Faild: only support v1 or v3.\n");
}
// launch kernel
if(has_main_k_block_loop)

View File

@@ -288,14 +288,12 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
oob_val = oob_val & is_src_valid;
if(i.value == ScatterWeightIdx)
{
auto data_types = SrcDatas{};
using DataType = remove_cvref_t<decltype(data_types[i])>;
static_assert(SrcScalarPerVectors{}[Number<ScatterWeightIdx>{}] == 1,
"scatter weight dim, should only one vec");
constexpr auto iScatter =
SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
static_for<0, SrcScalarPerVector, 1>{}([&](auto j) {
src_vectors(i).template AsType<DataType>()(j) =
src_vectors(i).template AsType<float>()(j) =
scatter_weights(Number<iScatter>{});
});
}