enable do top k weights in moe stage1 gemm (#2094)

* add switch for mul topk weights

* fix bf16/f16 bugs

* complete
This commit is contained in:
lalala-sh
2025-04-18 10:45:49 +08:00
committed by GitHub
parent 213b203a3c
commit bcf5bb41be
8 changed files with 203 additions and 68 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -67,6 +67,7 @@ template <typename ALayout,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
bool NSwizzle = false,
bool IsInputGemm = true,
bool MulRoutedWeight = true,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ComputeTypeA,
@@ -270,6 +271,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
MulRoutedWeight,
TailNumber::Odd>;
RunKernel(kernel);
}
@@ -280,6 +282,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
MulRoutedWeight,
TailNumber::Even>;
RunKernel(kernel);
}
@@ -295,6 +298,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
MulRoutedWeight,
TailNumber::Odd>;
RunKernel(kernel);
}
@@ -305,6 +309,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
MulRoutedWeight,
TailNumber::Even>;
RunKernel(kernel);
}
@@ -325,6 +330,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
IsInputGemm,
MulRoutedWeight,
TailNumber::Odd>;
RunKernel(kernel);
}