Enable multiplying by top-k routed weights in the MoE stage-1 GEMM (#2094)

* add a compile-time switch (MulRoutedWeight) for multiplying by the top-k weights

* fix bf16/f16 bugs

* complete

[ROCm/composable_kernel commit: bcf5bb41be]
This commit is contained in:
lalala-sh
2025-04-18 10:45:49 +08:00
committed by GitHub
parent 6d0890b6f4
commit 2d0b5aba13
8 changed files with 203 additions and 68 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -67,6 +67,7 @@ template <typename ALayout,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
bool NSwizzle = false,
bool IsInputGemm = true,
bool MulRoutedWeight = true,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ComputeTypeA,
@@ -270,6 +271,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
MulRoutedWeight,
TailNumber::Odd>;
RunKernel(kernel);
}
@@ -280,6 +282,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
MulRoutedWeight,
TailNumber::Even>;
RunKernel(kernel);
}
@@ -295,6 +298,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
MulRoutedWeight,
TailNumber::Odd>;
RunKernel(kernel);
}
@@ -305,6 +309,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
MulRoutedWeight,
TailNumber::Even>;
RunKernel(kernel);
}
@@ -325,6 +330,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
IsInputGemm,
MulRoutedWeight,
TailNumber::Odd>;
RunKernel(kernel);
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -31,6 +31,7 @@ template <typename GridwiseGemm,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
bool IsInputGemm = false,
bool MulRoutedWeight = true,
TailNumber TailNum = TailNumber::Even>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
@@ -44,19 +45,22 @@ __global__ void
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, IsInputGemm, TailNum>(
karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
GridwiseGemm::template Run<HasMainKBlockLoop,
CGlobalMemoryDataOperation,
IsInputGemm,
MulRoutedWeight,
TailNum>(karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
@@ -67,6 +71,7 @@ template <typename GridwiseGemm,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
bool IsInputGemm = false,
bool MulRoutedWeight = true,
TailNumber TailNum = TailNumber::Even>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
@@ -81,21 +86,23 @@ __global__ void
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::
template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, IsInputGemm, TailNum>(
karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
p_shared1,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop,
CGlobalMemoryDataOperation,
IsInputGemm,
MulRoutedWeight,
TailNum>(karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
p_shared1,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
@@ -1134,8 +1141,9 @@ struct GridwiseMoeGemm
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
bool IsInputGemm = true,
TailNumber TailNum = TailNumber::Odd>
bool IsInputGemm = true,
bool MulRoutedWeight = true,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run(const index_t* p_sorted_token_ids,
const index_t* p_sorted_expert_ids,
const index_t* p_max_token_id,
@@ -1492,7 +1500,7 @@ struct GridwiseMoeGemm
using CDEBlockTransferCluster =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; // hack fix felix
constexpr index_t scatter_weight_idx = 3; // hack fix felix
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
@@ -1579,10 +1587,13 @@ struct GridwiseMoeGemm
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
else
if constexpr(MulRoutedWeight)
{
const float* p_sorted_weights_2 = p_ds_grid[I2];
weight = weight * p_sorted_weights_2[c_token_pos + m0];
if constexpr(sizeof(ADataType) < 2)
weight = p_sorted_weights_2[c_token_pos + m0] * weight;
else
weight = p_sorted_weights_2[c_token_pos + m0];
}
scatter_offsets(m0) = token_offset * problem.N;
scatter_weights(m0) = weight;
@@ -1632,8 +1643,9 @@ struct GridwiseMoeGemm
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
bool IsInputGemm = true,
TailNumber TailNum = TailNumber::Odd>
bool IsInputGemm = true,
bool MulRoutedWeight = true,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
const index_t* p_sorted_expert_ids,
const index_t* p_max_token_id,
@@ -1998,7 +2010,7 @@ struct GridwiseMoeGemm
using CDEBlockTransferCluster =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; // hack fix felix
constexpr index_t scatter_weight_idx = 3; // hack fix felix
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
@@ -2086,10 +2098,13 @@ struct GridwiseMoeGemm
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
else
if constexpr(MulRoutedWeight)
{
const float* p_sorted_weights_2 = p_ds_grid[I2];
weight = weight * p_sorted_weights_2[c_token_pos + m0];
if constexpr(sizeof(ADataType) < 2)
weight = p_sorted_weights_2[c_token_pos + m0] * weight;
else
weight = p_sorted_weights_2[c_token_pos + m0];
}
scatter_offsets(m0) = token_offset * problem.N;
scatter_weights(m0) = weight;