mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Fix v2 topk_weight cal. Add silu asm.
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v1_128x128x128.co
Executable file → Normal file
BIN
example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v1_128x128x128.co
Executable file → Normal file
Binary file not shown.
BIN
example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v1_32x128x128.co
Executable file → Normal file
BIN
example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v1_32x128x128.co
Executable file → Normal file
Binary file not shown.
BIN
example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v3_128x128x128.co
Executable file → Normal file
BIN
example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v3_128x128x128.co
Executable file → Normal file
Binary file not shown.
Binary file not shown.
@@ -38,7 +38,7 @@ using B0DataType = F8;
|
||||
using B1DataType = F32;
|
||||
using EDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = EDataType;
|
||||
using CShuffleDataType = F32;
|
||||
using D2DataType = F32;
|
||||
using DsDataType = ck::Tuple<D2DataType>;
|
||||
|
||||
@@ -124,10 +124,10 @@ static constexpr ck::index_t Scale_Block_N = 128;
|
||||
static constexpr ck::index_t Scale_Block_K = 128;
|
||||
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
static constexpr bool MulRoutedWeight = false;
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
static constexpr ck::index_t MPerBlock = 32;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 16;
|
||||
@@ -179,7 +179,7 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>;
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
@@ -201,7 +201,7 @@ int main(int argc, char* argv[])
|
||||
// ck::index_t valid_tile_num = 13;
|
||||
ck::index_t sorted_tile_num = 259;
|
||||
ck::index_t valid_tile_num = 256;
|
||||
ck::index_t tokens = 8192;
|
||||
ck::index_t tokens = 4096;
|
||||
#else
|
||||
// deepseek
|
||||
ck::index_t N = 2048;
|
||||
|
||||
@@ -39,7 +39,7 @@ using B0DataType = F8;
|
||||
using B1DataType = F32;
|
||||
using EDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = EDataType;
|
||||
using CShuffleDataType = F32; //todo: change to EDataType
|
||||
using D2DataType = F32;
|
||||
using DsDataType = ck::Tuple<D2DataType>;
|
||||
|
||||
@@ -58,29 +58,27 @@ struct MulABScaleExpertWeight
|
||||
template <typename E, typename C, typename D2>
|
||||
__host__ __device__ constexpr void operator()(E& e, const C& c, const D2& d2) const;
|
||||
// for real kernel use
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<EDataType, EDataType, float>(EDataType& e, const EDataType& c, const float& d2) const
|
||||
{
|
||||
(void) d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<EDataType, float, float>(EDataType& e, const float& c, const float& d2) const
|
||||
{
|
||||
// for real kernel use
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<EDataType, EDataType, float>(EDataType& e, const EDataType& c, const float& d2) const
|
||||
{
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
// for reference cpu
|
||||
template <>
|
||||
e = ck::type_convert<EDataType>(c* d2);
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float, float, float>(float& e, const float& c, const float& d2) const
|
||||
{
|
||||
// for reference cpu
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
// for reference cpu
|
||||
e = ck::type_convert<EDataType>(c* d2);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -158,10 +156,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, false, int32_t, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, true, int32_t, A0DataType>;
|
||||
|
||||
#else
|
||||
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
static constexpr ck::index_t MPerBlock = 128; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
Row, Col, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
@@ -169,11 +167,11 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor
|
||||
MPerBlock, 128, 128,
|
||||
16, 16,
|
||||
16, 16,
|
||||
4, 2,
|
||||
4, 4,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, false, int32_t, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, true, int32_t, A0DataType>;
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
@@ -484,7 +482,7 @@ int main(int argc, char* argv[])
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceMoeGemm2BlockScale<float,
|
||||
float,
|
||||
CShuffleDataType,
|
||||
float,
|
||||
D2DataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
|
||||
Reference in New Issue
Block a user