mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Enable multiplying by the top-k weights in the MoE stage-1 GEMM (#2094)
* Add a switch for multiplying by the top-k weights
* Fix bf16/f16 bugs
* Complete the change
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -18,10 +18,12 @@ namespace host {
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename D2DataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
bool MulRoutedWeight = false,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
struct ReferenceMoeGemm : public device::BaseOperator
|
||||
@@ -36,6 +38,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
const Tensor<ADataType>& a_t_k,
|
||||
const Tensor<BDataType>& b_e_n_k,
|
||||
Tensor<CDataType>& c_t_k_n,
|
||||
const Tensor<D2DataType>& d2,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
@@ -46,6 +49,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
a_t_k_{a_t_k},
|
||||
b_e_n_k_{b_e_n_k},
|
||||
c_t_k_n_{c_t_k_n},
|
||||
d2_{d2},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
@@ -59,6 +63,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
const Tensor<ADataType>& a_t_k_;
|
||||
const Tensor<BDataType>& b_e_n_k_;
|
||||
Tensor<CDataType>& c_t_k_n_;
|
||||
const Tensor<D2DataType>& d2_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
@@ -81,6 +86,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
const int topk_id = (arg.sorted_token_ids_(m) & 0xff000000) >> 24;
|
||||
const int e = arg.expert_ids_(m / arg.sorted_tile_size_);
|
||||
const int token_cnt = arg.a_t_k_.mDesc.GetLengths()[0];
|
||||
D2DataType v_topk_w = arg.d2_(m, 0); // expert
|
||||
if(t < token_cnt)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
@@ -128,6 +134,11 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
}
|
||||
CDataType v_c{0};
|
||||
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
v_acc *= v_topk_w;
|
||||
}
|
||||
|
||||
arg.c_element_op_(v_c, v_acc);
|
||||
|
||||
arg.c_t_k_n_(t, topk_id, n) = v_c;
|
||||
@@ -164,6 +175,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
const Tensor<ADataType>& a_t_k,
|
||||
const Tensor<BDataType>& b_e_n_k,
|
||||
Tensor<CDataType>& c_t_k_n,
|
||||
const Tensor<D2DataType>& d2,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
@@ -175,6 +187,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
|
||||
a_t_k,
|
||||
b_e_n_k,
|
||||
c_t_k_n,
|
||||
d2,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -25,6 +25,7 @@ template <typename ADataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
bool MulRoutedWeight = false,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
struct ReferenceMoeGemm2 : public device::BaseOperator
|
||||
@@ -143,7 +144,14 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
|
||||
CDataType v_c{0};
|
||||
D0DataType v_d0 = arg.d0_(m, n); // a
|
||||
D0DataType v_d1 = arg.d1_(e, n); // b
|
||||
arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_topk_w);
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_topk_w);
|
||||
}
|
||||
else
|
||||
{
|
||||
arg.c_element_op_(v_c, v_acc, v_d0, v_d1, 1.f);
|
||||
}
|
||||
arg.c_t_n_(t, n) += v_c;
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user