enable do top k weights in moe stage1 gemm (#2094)

* add switch for mul topk weights

* fix bf16/f16 bugs

* complete
This commit is contained in:
lalala-sh
2025-04-18 10:45:49 +08:00
committed by GitHub
parent 213b203a3c
commit bcf5bb41be
8 changed files with 203 additions and 68 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -18,10 +18,12 @@ namespace host {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename D2DataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
bool MulRoutedWeight = false,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct ReferenceMoeGemm : public device::BaseOperator
@@ -36,6 +38,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
const Tensor<ADataType>& a_t_k,
const Tensor<BDataType>& b_e_n_k,
Tensor<CDataType>& c_t_k_n,
const Tensor<D2DataType>& d2,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
@@ -46,6 +49,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
a_t_k_{a_t_k},
b_e_n_k_{b_e_n_k},
c_t_k_n_{c_t_k_n},
d2_{d2},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
@@ -59,6 +63,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
const Tensor<ADataType>& a_t_k_;
const Tensor<BDataType>& b_e_n_k_;
Tensor<CDataType>& c_t_k_n_;
const Tensor<D2DataType>& d2_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
@@ -81,6 +86,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
const int topk_id = (arg.sorted_token_ids_(m) & 0xff000000) >> 24;
const int e = arg.expert_ids_(m / arg.sorted_tile_size_);
const int token_cnt = arg.a_t_k_.mDesc.GetLengths()[0];
D2DataType v_topk_w = arg.d2_(m, 0); // expert
if(t < token_cnt)
{
for(int k = 0; k < K; ++k)
@@ -128,6 +134,11 @@ struct ReferenceMoeGemm : public device::BaseOperator
}
CDataType v_c{0};
if constexpr(MulRoutedWeight)
{
v_acc *= v_topk_w;
}
arg.c_element_op_(v_c, v_acc);
arg.c_t_k_n_(t, topk_id, n) = v_c;
@@ -164,6 +175,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
const Tensor<ADataType>& a_t_k,
const Tensor<BDataType>& b_e_n_k,
Tensor<CDataType>& c_t_k_n,
const Tensor<D2DataType>& d2,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
@@ -175,6 +187,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
a_t_k,
b_e_n_k,
c_t_k_n,
d2,
a_element_op,
b_element_op,
c_element_op};

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -25,6 +25,7 @@ template <typename ADataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
bool MulRoutedWeight = false,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct ReferenceMoeGemm2 : public device::BaseOperator
@@ -143,7 +144,14 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
CDataType v_c{0};
D0DataType v_d0 = arg.d0_(m, n); // a
D0DataType v_d1 = arg.d1_(e, n); // b
arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_topk_w);
if constexpr(MulRoutedWeight)
{
arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_topk_w);
}
else
{
arg.c_element_op_(v_c, v_acc, v_d0, v_d1, 1.f);
}
arg.c_t_n_(t, n) += v_c;
}
};