Fix v2 topk_weight calculation. Add SiLU asm.

This commit is contained in:
OscarXu
2025-05-20 13:42:06 +08:00
parent f87973a4ac
commit 9fdfff82ea
15 changed files with 60 additions and 40 deletions

View File

@@ -254,9 +254,9 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
constexpr auto buffer_load_issue_point_b = 0;
constexpr auto buffer_load_issue_point_interval_more =
num_mfma_perstage / buffer_load_perstage_more;
num_mfma_perstage / buffer_load_perstage_more ? num_mfma_perstage / buffer_load_perstage_more : 1;
constexpr auto buffer_load_issue_point_interval_less =
num_mfma_perstage / buffer_load_perstage_less;
num_mfma_perstage / buffer_load_perstage_less ? num_mfma_perstage / buffer_load_perstage_less : 1;
constexpr auto ds_write_issue_point = 0;
constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0;

View File

@@ -355,6 +355,10 @@ struct DeviceMoeGemmBlockScale
#if CK_USE_ASM_MOE_BLOCKSCALE
(void)minimum_occupancy;
(void)MemoryDataOp;
//do_weight stage check
if (MulRoutedWeight == IsInputGemm){
throw std::runtime_error("MOE_BS_ASM Faild: Only gemm2 can do weight.\n");
}
// get .co file name for ASM. select by version and shape.
std::string hsa_name = "";
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
@@ -371,7 +375,17 @@ struct DeviceMoeGemmBlockScale
}
else
{
printf("Faild: Gemm2 only support 32x128x128 or 64x128x1288.\n");
throw std::runtime_error("MOE_BS_ASM Faild: Gemm2 only support 32x128x128 or 64x128x1288.\n");
}
if constexpr(ActivationOP == Activation::silu_and_mul){
hsa_name += "_silu";
}
else if constexpr(ActivationOP == Activation::gelu_and_mul){
hsa_name += "_gelu";
}
else
{
throw std::runtime_error("MOE_BS_ASM Faild: Gemm1 ACT only support silu or gelu.\n");
}
}
else
@@ -386,7 +400,7 @@ struct DeviceMoeGemmBlockScale
}
else
{
printf("Faild: Gemm2 only support 32x128x128 or 128x128x1288.\n");
throw std::runtime_error("MOE_BS_ASM Faild: Gemm2 only support 32x128x128 or 128x128x1288.\n");
}
}
}
@@ -400,7 +414,17 @@ struct DeviceMoeGemmBlockScale
}
else
{
printf("Faild: v3 only support 64x128x1288.\n");
throw std::runtime_error("MOE_BS_ASM Faild: v3 only support 64x128x1288.\n");
}
if constexpr(ActivationOP == Activation::silu_and_mul){
hsa_name += "_silu";
}
else if constexpr(ActivationOP == Activation::gelu_and_mul){
hsa_name += "_gelu";
}
else
{
throw std::runtime_error("MOE_BS_ASM Faild: Gemm1 ACT only support silu or gelu.\n");
}
}
else
@@ -409,19 +433,19 @@ struct DeviceMoeGemmBlockScale
{
hsa_name = std::string("moe_bs_stage2_v3_128x128x128");
}
else if constexpr(MPerBlock == 64)
{
hsa_name = std::string("moe_bs_stage2_v3_64x128x128");
}
// else if constexpr(MPerBlock == 64)
// {
// hsa_name = std::string("moe_bs_stage2_v3_64x128x128");
// }
else
{
printf("Faild: v3 only support 128x128x1288 or 64x128x1288.\n");
throw std::runtime_error("MOE_BS_ASM Faild: v3 only support 128x128x128.\n");
}
}
}
else
{
printf("Faild: only support v1 or v3.\n");
throw std::runtime_error("MOE_BS_ASM Faild: only support v1 or v3.\n");
}
// launch kernel
if(has_main_k_block_loop)

View File

@@ -288,14 +288,12 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
oob_val = oob_val & is_src_valid;
if(i.value == ScatterWeightIdx)
{
auto data_types = SrcDatas{};
using DataType = remove_cvref_t<decltype(data_types[i])>;
static_assert(SrcScalarPerVectors{}[Number<ScatterWeightIdx>{}] == 1,
"scatter weight dim, should only one vec");
constexpr auto iScatter =
SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
static_for<0, SrcScalarPerVector, 1>{}([&](auto j) {
src_vectors(i).template AsType<DataType>()(j) =
src_vectors(i).template AsType<float>()(j) =
scatter_weights(Number<iScatter>{});
});
}