mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
Fix v2 topk_weight calculation. Add SiLU asm kernels.
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v1_128x128x128.co
Executable file → Normal file
BIN
example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v1_128x128x128.co
Executable file → Normal file
Binary file not shown.
BIN
example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v1_32x128x128.co
Executable file → Normal file
BIN
example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v1_32x128x128.co
Executable file → Normal file
Binary file not shown.
BIN
example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v3_128x128x128.co
Executable file → Normal file
BIN
example/65_gemm_multiply_multiply/hsa/gfx942/moe_bs_stage2_v3_128x128x128.co
Executable file → Normal file
Binary file not shown.
Binary file not shown.
@@ -38,7 +38,7 @@ using B0DataType = F8;
|
||||
using B1DataType = F32;
|
||||
using EDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = EDataType;
|
||||
using CShuffleDataType = F32;
|
||||
using D2DataType = F32;
|
||||
using DsDataType = ck::Tuple<D2DataType>;
|
||||
|
||||
@@ -124,10 +124,10 @@ static constexpr ck::index_t Scale_Block_N = 128;
|
||||
static constexpr ck::index_t Scale_Block_K = 128;
|
||||
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
static constexpr bool MulRoutedWeight = false;
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
static constexpr ck::index_t MPerBlock = 32;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 16;
|
||||
@@ -179,7 +179,7 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>;
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
@@ -201,7 +201,7 @@ int main(int argc, char* argv[])
|
||||
// ck::index_t valid_tile_num = 13;
|
||||
ck::index_t sorted_tile_num = 259;
|
||||
ck::index_t valid_tile_num = 256;
|
||||
ck::index_t tokens = 8192;
|
||||
ck::index_t tokens = 4096;
|
||||
#else
|
||||
// deepseek
|
||||
ck::index_t N = 2048;
|
||||
|
||||
@@ -39,7 +39,7 @@ using B0DataType = F8;
|
||||
using B1DataType = F32;
|
||||
using EDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = EDataType;
|
||||
using CShuffleDataType = F32; //todo: change to EDataType
|
||||
using D2DataType = F32;
|
||||
using DsDataType = ck::Tuple<D2DataType>;
|
||||
|
||||
@@ -58,29 +58,27 @@ struct MulABScaleExpertWeight
|
||||
template <typename E, typename C, typename D2>
|
||||
__host__ __device__ constexpr void operator()(E& e, const C& c, const D2& d2) const;
|
||||
// for real kernel use
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<EDataType, EDataType, float>(EDataType& e, const EDataType& c, const float& d2) const
|
||||
{
|
||||
(void) d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<EDataType, float, float>(EDataType& e, const float& c, const float& d2) const
|
||||
{
|
||||
// for real kernel use
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<EDataType, EDataType, float>(EDataType& e, const EDataType& c, const float& d2) const
|
||||
{
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
// for reference cpu
|
||||
template <>
|
||||
e = ck::type_convert<EDataType>(c* d2);
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float, float, float>(float& e, const float& c, const float& d2) const
|
||||
{
|
||||
// for reference cpu
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
// for reference cpu
|
||||
e = ck::type_convert<EDataType>(c* d2);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -158,10 +156,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, false, int32_t, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, true, int32_t, A0DataType>;
|
||||
|
||||
#else
|
||||
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
static constexpr ck::index_t MPerBlock = 128; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
Row, Col, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
@@ -169,11 +167,11 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor
|
||||
MPerBlock, 128, 128,
|
||||
16, 16,
|
||||
16, 16,
|
||||
4, 2,
|
||||
4, 4,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, false, int32_t, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, true, int32_t, A0DataType>;
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
@@ -484,7 +482,7 @@ int main(int argc, char* argv[])
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceMoeGemm2BlockScale<float,
|
||||
float,
|
||||
CShuffleDataType,
|
||||
float,
|
||||
D2DataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
|
||||
@@ -254,9 +254,9 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
|
||||
|
||||
constexpr auto buffer_load_issue_point_b = 0;
|
||||
constexpr auto buffer_load_issue_point_interval_more =
|
||||
num_mfma_perstage / buffer_load_perstage_more;
|
||||
num_mfma_perstage / buffer_load_perstage_more ? num_mfma_perstage / buffer_load_perstage_more : 1;
|
||||
constexpr auto buffer_load_issue_point_interval_less =
|
||||
num_mfma_perstage / buffer_load_perstage_less;
|
||||
num_mfma_perstage / buffer_load_perstage_less ? num_mfma_perstage / buffer_load_perstage_less : 1;
|
||||
constexpr auto ds_write_issue_point = 0;
|
||||
constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0;
|
||||
|
||||
|
||||
@@ -355,6 +355,10 @@ struct DeviceMoeGemmBlockScale
|
||||
#if CK_USE_ASM_MOE_BLOCKSCALE
|
||||
(void)minimum_occupancy;
|
||||
(void)MemoryDataOp;
|
||||
//do_weight stage check
|
||||
if (MulRoutedWeight == IsInputGemm){
|
||||
throw std::runtime_error("MOE_BS_ASM Faild: Only gemm2 can do weight.\n");
|
||||
}
|
||||
// get .co file name for ASM. select by version and shape.
|
||||
std::string hsa_name = "";
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
@@ -371,7 +375,17 @@ struct DeviceMoeGemmBlockScale
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Faild: Gemm2 only support 32x128x128 or 64x128x1288.\n");
|
||||
throw std::runtime_error("MOE_BS_ASM Faild: Gemm2 only support 32x128x128 or 64x128x1288.\n");
|
||||
}
|
||||
if constexpr(ActivationOP == Activation::silu_and_mul){
|
||||
hsa_name += "_silu";
|
||||
}
|
||||
else if constexpr(ActivationOP == Activation::gelu_and_mul){
|
||||
hsa_name += "_gelu";
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("MOE_BS_ASM Faild: Gemm1 ACT only support silu or gelu.\n");
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -386,7 +400,7 @@ struct DeviceMoeGemmBlockScale
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Faild: Gemm2 only support 32x128x128 or 128x128x1288.\n");
|
||||
throw std::runtime_error("MOE_BS_ASM Faild: Gemm2 only support 32x128x128 or 128x128x1288.\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -400,7 +414,17 @@ struct DeviceMoeGemmBlockScale
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Faild: v3 only support 64x128x1288.\n");
|
||||
throw std::runtime_error("MOE_BS_ASM Faild: v3 only support 64x128x1288.\n");
|
||||
}
|
||||
if constexpr(ActivationOP == Activation::silu_and_mul){
|
||||
hsa_name += "_silu";
|
||||
}
|
||||
else if constexpr(ActivationOP == Activation::gelu_and_mul){
|
||||
hsa_name += "_gelu";
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("MOE_BS_ASM Faild: Gemm1 ACT only support silu or gelu.\n");
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -409,19 +433,19 @@ struct DeviceMoeGemmBlockScale
|
||||
{
|
||||
hsa_name = std::string("moe_bs_stage2_v3_128x128x128");
|
||||
}
|
||||
else if constexpr(MPerBlock == 64)
|
||||
{
|
||||
hsa_name = std::string("moe_bs_stage2_v3_64x128x128");
|
||||
}
|
||||
// else if constexpr(MPerBlock == 64)
|
||||
// {
|
||||
// hsa_name = std::string("moe_bs_stage2_v3_64x128x128");
|
||||
// }
|
||||
else
|
||||
{
|
||||
printf("Faild: v3 only support 128x128x1288 or 64x128x1288.\n");
|
||||
throw std::runtime_error("MOE_BS_ASM Faild: v3 only support 128x128x128.\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Faild: only support v1 or v3.\n");
|
||||
throw std::runtime_error("MOE_BS_ASM Faild: only support v1 or v3.\n");
|
||||
}
|
||||
// launch kernel
|
||||
if(has_main_k_block_loop)
|
||||
|
||||
@@ -288,14 +288,12 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
oob_val = oob_val & is_src_valid;
|
||||
if(i.value == ScatterWeightIdx)
|
||||
{
|
||||
auto data_types = SrcDatas{};
|
||||
using DataType = remove_cvref_t<decltype(data_types[i])>;
|
||||
static_assert(SrcScalarPerVectors{}[Number<ScatterWeightIdx>{}] == 1,
|
||||
"scatter weight dim, should only one vec");
|
||||
constexpr auto iScatter =
|
||||
SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto j) {
|
||||
src_vectors(i).template AsType<DataType>()(j) =
|
||||
src_vectors(i).template AsType<float>()(j) =
|
||||
scatter_weights(Number<iScatter>{});
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user