Fix v2 topk_weight cal. Add silu asm.

This commit is contained in:
OscarXu
2025-05-20 13:42:06 +08:00
parent f87973a4ac
commit 9fdfff82ea
15 changed files with 60 additions and 40 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -38,7 +38,7 @@ using B0DataType = F8;
using B1DataType = F32;
using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = EDataType;
using CShuffleDataType = F32;
using D2DataType = F32;
using DsDataType = ck::Tuple<D2DataType>;
@@ -124,10 +124,10 @@ static constexpr ck::index_t Scale_Block_N = 128;
static constexpr ck::index_t Scale_Block_K = 128;
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul
static constexpr bool MulRoutedWeight = false;
#if 0
#if 1
static constexpr ck::index_t MPerBlock = 32;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 16;
@@ -179,7 +179,7 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>;
#endif
// clang-format on
@@ -201,7 +201,7 @@ int main(int argc, char* argv[])
// ck::index_t valid_tile_num = 13;
ck::index_t sorted_tile_num = 259;
ck::index_t valid_tile_num = 256;
ck::index_t tokens = 8192;
ck::index_t tokens = 4096;
#else
// deepseek
ck::index_t N = 2048;

View File

@@ -39,7 +39,7 @@ using B0DataType = F8;
using B1DataType = F32;
using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = EDataType;
using CShuffleDataType = F32; //todo: change to EDataType
using D2DataType = F32;
using DsDataType = ck::Tuple<D2DataType>;
@@ -58,29 +58,27 @@ struct MulABScaleExpertWeight
template <typename E, typename C, typename D2>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D2& d2) const;
// for real kernel use
template <>
__host__ __device__ constexpr void
operator()<EDataType, EDataType, float>(EDataType& e, const EDataType& c, const float& d2) const
{
(void) d2;
e = ck::type_convert<EDataType>(c);
}
template <>
__host__ __device__ constexpr void
operator()<EDataType, float, float>(EDataType& e, const float& c, const float& d2) const
{
// for real kernel use
(void)d2;
e = ck::type_convert<EDataType>(c);
}
template <>
__host__ __device__ constexpr void
operator()<EDataType, EDataType, float>(EDataType& e, const EDataType& c, const float& d2) const
{
(void)d2;
e = ck::type_convert<EDataType>(c);
}
// for reference cpu
template <>
e = ck::type_convert<EDataType>(c* d2);
}
template <>
__host__ __device__ constexpr void
operator()<float, float, float>(float& e, const float& c, const float& d2) const
{
// for reference cpu
(void)d2;
e = ck::type_convert<EDataType>(c);
// for reference cpu
e = ck::type_convert<EDataType>(c* d2);
}
};
@@ -158,10 +156,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, false, int32_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, true, int32_t, A0DataType>;
#else
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
static constexpr ck::index_t MPerBlock = 128; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
Row, Col, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
@@ -169,11 +167,11 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor
MPerBlock, 128, 128,
16, 16,
16, 16,
4, 2,
4, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, false, int32_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, true, int32_t, A0DataType>;
#endif
// clang-format on
@@ -484,7 +482,7 @@ int main(int argc, char* argv[])
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeGemm2BlockScale<float,
float,
CShuffleDataType,
float,
D2DataType,
AccDataType,
PassThrough,