v3 function pass

This commit is contained in:
OscarXu
2025-05-16 03:42:48 -05:00
parent c5be9a501b
commit 39ff3fbf05
4 changed files with 1046 additions and 6 deletions

View File

@@ -177,7 +177,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX<
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>;
#endif
// clang-format on

View File

@@ -4,8 +4,9 @@
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_v1_moe_mx.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_moe_mx_gufusion_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_mx_moe_gufusion_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_v3_moe_mx.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_mx_mmoe_gufusion_v3.hpp"
namespace ck {
@@ -73,7 +74,7 @@ constexpr auto BlockGemmMXBPreshufflePipeline_Selector()
{
if constexpr(GUFusion)
{
return BlockwiseGemmXdlops_pipeline_bpreshuffle_moe_mx_gufusion_v1<
return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v1<
BlkGemmPipeSche,
ThreadBlockSize,
ScaleBlockSize,
@@ -98,7 +99,7 @@ constexpr auto BlockGemmMXBPreshufflePipeline_Selector()
}
else
{
return BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_moe_mx<
return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<
BlkGemmPipeSche,
ThreadBlockSize,
ScaleBlockSize,
@@ -126,11 +127,32 @@ constexpr auto BlockGemmMXBPreshufflePipeline_Selector()
{
if constexpr(GUFusion)
{
return nullptr;
return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
BlkGemmPipeSche,
ThreadBlockSize,
ScaleBlockSize,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else
{
return BlockwiseGemmXdlops_pipeline_bpreshuffle_v3_moe_mx<
return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3<
BlkGemmPipeSche,
ThreadBlockSize,
ScaleBlockSize,