use mem_op::set when topk=1

This commit is contained in:
lalala-sh
2025-05-08 09:49:16 +00:00
parent 960b2bce1c
commit def952a178

View File

@@ -264,77 +264,152 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
constexpr auto MemoryDataOp =
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
if(has_main_k_block_loop)
if(IsInputGemm || arg.TopK == 1)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
constexpr auto MemoryDataOp = InMemoryDataOperationEnum::Set;
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_gemm<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_gemm<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_gemm<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_gemm<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
else
{
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
throw std::runtime_error("todo: only v1 & v2 support now");
}
}
#if 1
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
const auto kernel = kernel_moe_gemm<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
}
#endif
}
else
{
constexpr auto MemoryDataOp = InMemoryDataOperationEnum::AtomicAdd;
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_gemm<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_gemm<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}
else
{
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
throw std::runtime_error("todo: only v1 & v2 support now");
}
}
#if 1
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
const auto kernel = kernel_moe_gemm<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
}
else
{
throw std::runtime_error("todo: only v1 & v2 support now");
}
}
#if 1
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
const auto kernel = kernel_moe_gemm<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
}
#endif
}
return ave_time;
}