mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Fix v2 topk_weight cal. Add silu asm.
This commit is contained in:
@@ -254,9 +254,9 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
|
||||
|
||||
constexpr auto buffer_load_issue_point_b = 0;
|
||||
constexpr auto buffer_load_issue_point_interval_more =
|
||||
num_mfma_perstage / buffer_load_perstage_more;
|
||||
num_mfma_perstage / buffer_load_perstage_more ? num_mfma_perstage / buffer_load_perstage_more : 1;
|
||||
constexpr auto buffer_load_issue_point_interval_less =
|
||||
num_mfma_perstage / buffer_load_perstage_less;
|
||||
num_mfma_perstage / buffer_load_perstage_less ? num_mfma_perstage / buffer_load_perstage_less : 1;
|
||||
constexpr auto ds_write_issue_point = 0;
|
||||
constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0;
|
||||
|
||||
|
||||
@@ -355,6 +355,10 @@ struct DeviceMoeGemmBlockScale
|
||||
#if CK_USE_ASM_MOE_BLOCKSCALE
|
||||
(void)minimum_occupancy;
|
||||
(void)MemoryDataOp;
|
||||
//do_weight stage check
|
||||
if (MulRoutedWeight == IsInputGemm){
|
||||
throw std::runtime_error("MOE_BS_ASM Faild: Only gemm2 can do weight.\n");
|
||||
}
|
||||
// get .co file name for ASM. select by version and shape.
|
||||
std::string hsa_name = "";
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
@@ -371,7 +375,17 @@ struct DeviceMoeGemmBlockScale
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Faild: Gemm2 only support 32x128x128 or 64x128x1288.\n");
|
||||
throw std::runtime_error("MOE_BS_ASM Faild: Gemm2 only support 32x128x128 or 64x128x1288.\n");
|
||||
}
|
||||
if constexpr(ActivationOP == Activation::silu_and_mul){
|
||||
hsa_name += "_silu";
|
||||
}
|
||||
else if constexpr(ActivationOP == Activation::gelu_and_mul){
|
||||
hsa_name += "_gelu";
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("MOE_BS_ASM Faild: Gemm1 ACT only support silu or gelu.\n");
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -386,7 +400,7 @@ struct DeviceMoeGemmBlockScale
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Faild: Gemm2 only support 32x128x128 or 128x128x1288.\n");
|
||||
throw std::runtime_error("MOE_BS_ASM Faild: Gemm2 only support 32x128x128 or 128x128x1288.\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -400,7 +414,17 @@ struct DeviceMoeGemmBlockScale
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Faild: v3 only support 64x128x1288.\n");
|
||||
throw std::runtime_error("MOE_BS_ASM Faild: v3 only support 64x128x1288.\n");
|
||||
}
|
||||
if constexpr(ActivationOP == Activation::silu_and_mul){
|
||||
hsa_name += "_silu";
|
||||
}
|
||||
else if constexpr(ActivationOP == Activation::gelu_and_mul){
|
||||
hsa_name += "_gelu";
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("MOE_BS_ASM Faild: Gemm1 ACT only support silu or gelu.\n");
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -409,19 +433,19 @@ struct DeviceMoeGemmBlockScale
|
||||
{
|
||||
hsa_name = std::string("moe_bs_stage2_v3_128x128x128");
|
||||
}
|
||||
else if constexpr(MPerBlock == 64)
|
||||
{
|
||||
hsa_name = std::string("moe_bs_stage2_v3_64x128x128");
|
||||
}
|
||||
// else if constexpr(MPerBlock == 64)
|
||||
// {
|
||||
// hsa_name = std::string("moe_bs_stage2_v3_64x128x128");
|
||||
// }
|
||||
else
|
||||
{
|
||||
printf("Faild: v3 only support 128x128x1288 or 64x128x1288.\n");
|
||||
throw std::runtime_error("MOE_BS_ASM Faild: v3 only support 128x128x128.\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Faild: only support v1 or v3.\n");
|
||||
throw std::runtime_error("MOE_BS_ASM Faild: only support v1 or v3.\n");
|
||||
}
|
||||
// launch kernel
|
||||
if(has_main_k_block_loop)
|
||||
|
||||
@@ -288,14 +288,12 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
oob_val = oob_val & is_src_valid;
|
||||
if(i.value == ScatterWeightIdx)
|
||||
{
|
||||
auto data_types = SrcDatas{};
|
||||
using DataType = remove_cvref_t<decltype(data_types[i])>;
|
||||
static_assert(SrcScalarPerVectors{}[Number<ScatterWeightIdx>{}] == 1,
|
||||
"scatter weight dim, should only one vec");
|
||||
constexpr auto iScatter =
|
||||
SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto j) {
|
||||
src_vectors(i).template AsType<DataType>()(j) =
|
||||
src_vectors(i).template AsType<float>()(j) =
|
||||
scatter_weights(Number<iScatter>{});
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user