enable multiplying top-k weights in moe stage1 gemm (#2094)

* add switch for mul topk weights

* fix bf16/f16 bugs

* complete

[ROCm/composable_kernel commit: bcf5bb41be]
This commit is contained in:
lalala-sh
2025-04-18 10:45:49 +08:00
committed by GitHub
parent 6d0890b6f4
commit 2d0b5aba13
8 changed files with 203 additions and 68 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
@@ -39,14 +39,16 @@ using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F32;
using D1DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using D2DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
using A0Layout = Row;
using B0Layout = Col;
using ELayout = Row;
using D0Layout = Row;
using D1Layout = Col;
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
using D2Layout = ELayout;
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
// for gate, a_scale, b_scale
struct MulABScale
@@ -83,9 +85,36 @@ struct MulABScaleSilu
}
};
// Elementwise epilogue for the MoE stage-1 GEMM: combines the accumulator c
// with the per-token scale d0, the per-expert scale d1 and the routed top-k
// expert weight d2 before converting to the output type.
struct MulABScaleExpertWeight
{
template <typename E, typename C, typename D0, typename D1, typename D2>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
// Device-kernel specialization (output converted to EDataType).
template <>
__host__ __device__ constexpr void operator()<EDataType, float, float, float, float>(
EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
// NOTE(review): d2 (the routed top-k weight) is deliberately ignored here —
// the kernel already folds it in outside this element op (it multiplies
// d0 * d2 on the scatter-weight path). TODO(felix): remove this hack once
// the kernel no longer pre-multiplies d2.
(void)d2;
e = ck::type_convert<EDataType>(c * d1 * d0);
}
// Host reference-CPU specialization: applies all three scales (d0, d1, d2).
template <>
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
// Rounds the result through EDataType even though e is float — presumably
// intentional so the reference matches the kernel's output precision
// (bf16/f16); NOTE(review): confirm, otherwise ck::type_convert<float>
// was meant here.
e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
}
};
using CDEElementOp = MulABScaleExpertWeight; // combine MulRoutedWeight = true
// using DsLayout = DsLayoutGate;
// using DsDataType = DsDataTypeGate;
using CDEElementOp = MulABScale;
// using CDEElementOp = MulABScale; // combine MulRoutedWeight = false
// using CDEElementOp = MulABScaleSiluMulGate;
@@ -133,11 +162,13 @@ static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 32;
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
static constexpr ck::index_t Nswizzle = true;
static constexpr bool MulRoutedWeight = false;
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
static constexpr ck::index_t D2Vec = 1;
// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// clang-format off
@@ -157,8 +188,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
2, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>;
2, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, MulRoutedWeight, A0DataType>;
// clang-format on
@@ -224,7 +255,7 @@ int main(int argc, char* argv[])
ck::index_t StrideB = K;
ck::index_t StrideE = N;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{1, 0};
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
ck::index_t KBatch = 1;
@@ -266,6 +297,7 @@ int main(int argc, char* argv[])
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
Tensor<EDataType> e_t_n_device_result(
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
@@ -273,6 +305,7 @@ int main(int argc, char* argv[])
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
switch(init_method)
@@ -283,24 +316,28 @@ int main(int argc, char* argv[])
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{-2, 2});
break;
case 2:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{});
break;
case 3:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{});
break;
default:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
}
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) *
sorted_token_ids.mDesc.GetElementSpaceSize());
@@ -310,6 +347,7 @@ int main(int argc, char* argv[])
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
// a0_t_k.savetxt("a.txt");
// d0_t_n.savetxt("d0_t_n.txt", "int");
@@ -320,6 +358,7 @@ int main(int argc, char* argv[])
a0_device_buf.ToDevice(a0_t_k.mData.data());
d0_device_buf.ToDevice(d0_t_n.mData.data());
d1_device_buf.ToDevice(d1_e_n.mData.data());
d2_device_buf.ToDevice(d2_e_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
@@ -342,7 +381,8 @@ int main(int argc, char* argv[])
a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()},
d1_device_buf.GetDeviceBuffer(),
d2_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
tokens,
topk,
@@ -392,10 +432,12 @@ int main(int argc, char* argv[])
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm<A0DataType,
B0DataType,
CShuffleDataType,
D2DataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
PassThrough,
MulRoutedWeight>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
@@ -406,6 +448,7 @@ int main(int argc, char* argv[])
a0_t_k,
b0_e_n_k,
c_t_k_n,
d2_e_n,
PassThrough{},
PassThrough{},
PassThrough{});
@@ -428,7 +471,8 @@ int main(int argc, char* argv[])
cde_element_op(e_t_n_host_result(t, topk_id, n),
c_t_k_n(t, topk_id, n),
d0_t_n(t, n),
d1_e_n(e, n));
d1_e_n(e, n),
1.f);
}
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
@@ -39,14 +39,15 @@ using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F32;
using D1DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using D2DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
using A0Layout = Row;
using B0Layout = Col;
using ELayout = Row;
using D0Layout = Row;
using D1Layout = Col;
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
using DsLayout = ck::Tuple<D0Layout, D1Layout, ELayout>;
// for gate, a_scale, b_scale
struct MulABScale
@@ -91,7 +92,39 @@ struct MulABScaleSilu
}
};
using CDEElementOp = MulABScale;
// Elementwise epilogue for the MoE stage-1 GEMM (pk-int4 variant): combines
// the accumulator c with the per-token scale d0, the per-expert scale d1 and
// the routed top-k expert weight d2 before converting to the output type.
struct MulABScaleExpertWeight
{
template <typename E, typename C, typename D0, typename D1, typename D2>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
// Device-kernel specialization (output converted to EDataType).
template <>
__host__ __device__ constexpr void operator()<EDataType, float, float, float, float>(
EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
// d2 (the routed top-k weight) is ignored here: the kernel folds it in
// outside this element op on the scatter-weight path.
(void)d2;
#if CK_USE_PK4_LAYOUT_SHUFFLE
// NOTE(review): the * 16 factor presumably compensates for a scaling
// baked into the pk4 layout shuffle of the int4 B operand — confirm.
e = ck::type_convert<EDataType>(c * d1 * d0 * 16);
#else
e = ck::type_convert<EDataType>(c * d1 * d0);
#endif
}
// Host reference-CPU specialization: applies all three scales (d0, d1, d2).
template <>
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
// Rounds the result through EDataType even though e is float — presumably
// intentional so the reference matches the kernel's output precision;
// NOTE(review): confirm, otherwise ck::type_convert<float> was meant here.
#if CK_USE_PK4_LAYOUT_SHUFFLE
e = ck::type_convert<EDataType>(c * d0 * d1 * d2 * 16);
#else
e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
#endif
}
};
using CDEElementOp = MulABScaleExpertWeight;
#if 1
void preShuffleBuffer(const I4* src, I4* dst, int N, int K, int NXdl)
@@ -164,6 +197,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
#else
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t Nswizzle = false;
static constexpr bool MulRoutedWeight = false;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
Row, Col, DsLayout, ELayout,
@@ -175,8 +209,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
4, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
1, 1, S<1, 32, 1, 8>, S<8, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>;
1, 1, S<1, 32, 1, 8>, S<8, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, MulRoutedWeight, A0DataType>;
// clang-format on
#endif
@@ -265,6 +299,7 @@ int main(int argc, char* argv[])
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
Tensor<EDataType> e_t_n_device_result(
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
@@ -283,18 +318,21 @@ int main(int argc, char* argv[])
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{-2, 2});
break;
case 2:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{});
break;
default:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
}
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) *
sorted_token_ids.mDesc.GetElementSpaceSize());
@@ -304,6 +342,7 @@ int main(int argc, char* argv[])
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize() / 2);
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
@@ -312,6 +351,7 @@ int main(int argc, char* argv[])
a0_device_buf.ToDevice(a0_t_k.mData.data());
d0_device_buf.ToDevice(d0_t_n.mData.data());
d1_device_buf.ToDevice(d1_e_n.mData.data());
d2_device_buf.ToDevice(d2_e_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
@@ -424,7 +464,8 @@ int main(int argc, char* argv[])
a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()},
d1_device_buf.GetDeviceBuffer(),
d2_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
tokens,
topk,
@@ -480,10 +521,12 @@ int main(int argc, char* argv[])
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm<A0DataType,
B0DataType,
CShuffleDataType,
D2DataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
PassThrough,
MulRoutedWeight>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
@@ -494,6 +537,7 @@ int main(int argc, char* argv[])
a0_t_k,
b0_e_n_k,
c_t_k_n,
d2_e_n,
PassThrough{},
PassThrough{},
PassThrough{});
@@ -516,7 +560,8 @@ int main(int argc, char* argv[])
cde_element_op(e_t_n_host_result(t, topk_id, n),
c_t_k_n(t, topk_id, n),
d0_t_n(t, n),
d1_e_n(e, n));
d1_e_n(e, n),
1.f);
}
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
@@ -135,6 +135,7 @@ static constexpr ck::index_t EVec = 2;
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
static constexpr ck::index_t D2Vec = 1;
static constexpr bool MulRoutedWeight = false;
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
@@ -164,7 +165,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
2, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, MulRoutedWeight, A0DataType>;
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
@@ -409,7 +410,8 @@ int main(int argc, char* argv[])
AccDataType,
PassThrough,
PassThrough,
CDEElementOp>;
CDEElementOp,
MulRoutedWeight>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
@@ -138,6 +138,7 @@ static constexpr ck::index_t EVec = 2;
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
static constexpr ck::index_t D2Vec = 1;
static constexpr bool MulRoutedWeight = true;
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// clang-format off
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
@@ -149,7 +150,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
1, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, MulRoutedWeight, A0DataType>;
// clang-format on
int main(int argc, char* argv[])
@@ -455,7 +456,8 @@ int main(int argc, char* argv[])
AccDataType,
PassThrough,
PassThrough,
CDEElementOp>;
CDEElementOp,
MulRoutedWeight>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -67,6 +67,7 @@ template <typename ALayout,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
bool NSwizzle = false,
bool IsInputGemm = true,
bool MulRoutedWeight = true,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ComputeTypeA,
@@ -270,6 +271,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
MulRoutedWeight,
TailNumber::Odd>;
RunKernel(kernel);
}
@@ -280,6 +282,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
MulRoutedWeight,
TailNumber::Even>;
RunKernel(kernel);
}
@@ -295,6 +298,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
MulRoutedWeight,
TailNumber::Odd>;
RunKernel(kernel);
}
@@ -305,6 +309,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
MulRoutedWeight,
TailNumber::Even>;
RunKernel(kernel);
}
@@ -325,6 +330,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
IsInputGemm,
MulRoutedWeight,
TailNumber::Odd>;
RunKernel(kernel);
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -31,6 +31,7 @@ template <typename GridwiseGemm,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
bool IsInputGemm = false,
bool MulRoutedWeight = true,
TailNumber TailNum = TailNumber::Even>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
@@ -44,19 +45,22 @@ __global__ void
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, IsInputGemm, TailNum>(
karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
GridwiseGemm::template Run<HasMainKBlockLoop,
CGlobalMemoryDataOperation,
IsInputGemm,
MulRoutedWeight,
TailNum>(karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
@@ -67,6 +71,7 @@ template <typename GridwiseGemm,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
bool IsInputGemm = false,
bool MulRoutedWeight = true,
TailNumber TailNum = TailNumber::Even>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
@@ -81,21 +86,23 @@ __global__ void
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::
template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, IsInputGemm, TailNum>(
karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
p_shared1,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop,
CGlobalMemoryDataOperation,
IsInputGemm,
MulRoutedWeight,
TailNum>(karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
p_shared1,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
@@ -1134,8 +1141,9 @@ struct GridwiseMoeGemm
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
bool IsInputGemm = true,
TailNumber TailNum = TailNumber::Odd>
bool IsInputGemm = true,
bool MulRoutedWeight = true,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run(const index_t* p_sorted_token_ids,
const index_t* p_sorted_expert_ids,
const index_t* p_max_token_id,
@@ -1492,7 +1500,7 @@ struct GridwiseMoeGemm
using CDEBlockTransferCluster =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; // hack fix felix
constexpr index_t scatter_weight_idx = 3; // hack fix felix
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
@@ -1579,10 +1587,13 @@ struct GridwiseMoeGemm
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
else
if constexpr(MulRoutedWeight)
{
const float* p_sorted_weights_2 = p_ds_grid[I2];
weight = weight * p_sorted_weights_2[c_token_pos + m0];
if constexpr(sizeof(ADataType) < 2)
weight = p_sorted_weights_2[c_token_pos + m0] * weight;
else
weight = p_sorted_weights_2[c_token_pos + m0];
}
scatter_offsets(m0) = token_offset * problem.N;
scatter_weights(m0) = weight;
@@ -1632,8 +1643,9 @@ struct GridwiseMoeGemm
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
bool IsInputGemm = true,
TailNumber TailNum = TailNumber::Odd>
bool IsInputGemm = true,
bool MulRoutedWeight = true,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
const index_t* p_sorted_expert_ids,
const index_t* p_max_token_id,
@@ -1998,7 +2010,7 @@ struct GridwiseMoeGemm
using CDEBlockTransferCluster =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; // hack fix felix
constexpr index_t scatter_weight_idx = 3; // hack fix felix
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
@@ -2086,10 +2098,13 @@ struct GridwiseMoeGemm
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
else
if constexpr(MulRoutedWeight)
{
const float* p_sorted_weights_2 = p_ds_grid[I2];
weight = weight * p_sorted_weights_2[c_token_pos + m0];
if constexpr(sizeof(ADataType) < 2)
weight = p_sorted_weights_2[c_token_pos + m0] * weight;
else
weight = p_sorted_weights_2[c_token_pos + m0];
}
scatter_offsets(m0) = token_offset * problem.N;
scatter_weights(m0) = weight;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -18,10 +18,12 @@ namespace host {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename D2DataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
bool MulRoutedWeight = false,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct ReferenceMoeGemm : public device::BaseOperator
@@ -36,6 +38,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
const Tensor<ADataType>& a_t_k,
const Tensor<BDataType>& b_e_n_k,
Tensor<CDataType>& c_t_k_n,
const Tensor<D2DataType>& d2,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
@@ -46,6 +49,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
a_t_k_{a_t_k},
b_e_n_k_{b_e_n_k},
c_t_k_n_{c_t_k_n},
d2_{d2},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
@@ -59,6 +63,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
const Tensor<ADataType>& a_t_k_;
const Tensor<BDataType>& b_e_n_k_;
Tensor<CDataType>& c_t_k_n_;
const Tensor<D2DataType>& d2_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
@@ -81,6 +86,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
const int topk_id = (arg.sorted_token_ids_(m) & 0xff000000) >> 24;
const int e = arg.expert_ids_(m / arg.sorted_tile_size_);
const int token_cnt = arg.a_t_k_.mDesc.GetLengths()[0];
D2DataType v_topk_w = arg.d2_(m, 0); // expert
if(t < token_cnt)
{
for(int k = 0; k < K; ++k)
@@ -128,6 +134,11 @@ struct ReferenceMoeGemm : public device::BaseOperator
}
CDataType v_c{0};
if constexpr(MulRoutedWeight)
{
v_acc *= v_topk_w;
}
arg.c_element_op_(v_c, v_acc);
arg.c_t_k_n_(t, topk_id, n) = v_c;
@@ -164,6 +175,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
const Tensor<ADataType>& a_t_k,
const Tensor<BDataType>& b_e_n_k,
Tensor<CDataType>& c_t_k_n,
const Tensor<D2DataType>& d2,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
@@ -175,6 +187,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
a_t_k,
b_e_n_k,
c_t_k_n,
d2,
a_element_op,
b_element_op,
c_element_op};

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -25,6 +25,7 @@ template <typename ADataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
bool MulRoutedWeight = false,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct ReferenceMoeGemm2 : public device::BaseOperator
@@ -143,7 +144,14 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
CDataType v_c{0};
D0DataType v_d0 = arg.d0_(m, n); // a
D0DataType v_d1 = arg.d1_(e, n); // b
arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_topk_w);
if constexpr(MulRoutedWeight)
{
arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_topk_w);
}
else
{
arg.c_element_op_(v_c, v_acc, v_d0, v_d1, 1.f);
}
arg.c_t_n_(t, n) += v_c;
}
};