mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 19:09:59 +00:00
enable do top k weights in moe stage1 gemm (#2094)
* add switch for mul topk weights
* fix bf16/f16 bugs
* complete
[ROCm/composable_kernel commit: bcf5bb41be]
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -39,14 +39,16 @@ using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
||||
using D2DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
|
||||
|
||||
using A0Layout = Row;
|
||||
using B0Layout = Col;
|
||||
using ELayout = Row;
|
||||
using D0Layout = Row;
|
||||
using D1Layout = Col;
|
||||
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
|
||||
using D2Layout = ELayout;
|
||||
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
|
||||
|
||||
// for gate, a_scale, b_scale
|
||||
struct MulABScale
|
||||
@@ -83,9 +85,36 @@ struct MulABScaleSilu
|
||||
}
|
||||
};
|
||||
|
||||
struct MulABScaleExpertWeight
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1, typename D2>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
|
||||
// for real kernel use
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, float, float, float, float>(
|
||||
EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
// for real kernel use
|
||||
// warning: hack hack hack here!!!! ignore d0 right now as kernel mul d0 * d2 outside.
|
||||
// tofix:felix
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c * d1 * d0);
|
||||
}
|
||||
// for reference cpu
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
|
||||
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
// for reference cpu
|
||||
e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
|
||||
}
|
||||
};
|
||||
|
||||
using CDEElementOp = MulABScaleExpertWeight; // combine MulRoutedWeight = true
|
||||
// using DsLayout = DsLayoutGate;
|
||||
// using DsDataType = DsDataTypeGate;
|
||||
using CDEElementOp = MulABScale;
|
||||
// using CDEElementOp = MulABScale; // combine MulRoutedWeight = false
|
||||
|
||||
// using CDEElementOp = MulABScaleSiluMulGate;
|
||||
|
||||
@@ -133,11 +162,13 @@ static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 32;
|
||||
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t Nswizzle = true;
|
||||
static constexpr bool MulRoutedWeight = false;
|
||||
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
|
||||
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
static constexpr ck::index_t D2Vec = 1;
|
||||
// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
// clang-format off
|
||||
@@ -157,8 +188,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
2, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>;
|
||||
2, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, MulRoutedWeight, A0DataType>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
@@ -224,7 +255,7 @@ int main(int argc, char* argv[])
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{1, 0};
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
@@ -266,6 +297,7 @@ int main(int argc, char* argv[])
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
|
||||
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
|
||||
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
|
||||
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
|
||||
Tensor<EDataType> e_t_n_device_result(
|
||||
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
|
||||
@@ -273,6 +305,7 @@ int main(int argc, char* argv[])
|
||||
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
|
||||
std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
|
||||
std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
|
||||
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
|
||||
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
@@ -283,24 +316,28 @@ int main(int argc, char* argv[])
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{-2, 2});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{});
|
||||
break;
|
||||
default:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
}
|
||||
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) *
|
||||
sorted_token_ids.mDesc.GetElementSpaceSize());
|
||||
@@ -310,6 +347,7 @@ int main(int argc, char* argv[])
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
|
||||
// a0_t_k.savetxt("a.txt");
|
||||
// d0_t_n.savetxt("d0_t_n.txt", "int");
|
||||
@@ -320,6 +358,7 @@ int main(int argc, char* argv[])
|
||||
a0_device_buf.ToDevice(a0_t_k.mData.data());
|
||||
d0_device_buf.ToDevice(d0_t_n.mData.data());
|
||||
d1_device_buf.ToDevice(d1_e_n.mData.data());
|
||||
d2_device_buf.ToDevice(d2_e_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
@@ -342,7 +381,8 @@ int main(int argc, char* argv[])
|
||||
a0_device_buf.GetDeviceBuffer(),
|
||||
b0_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
|
||||
d1_device_buf.GetDeviceBuffer()},
|
||||
d1_device_buf.GetDeviceBuffer(),
|
||||
d2_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
tokens,
|
||||
topk,
|
||||
@@ -392,10 +432,12 @@ int main(int argc, char* argv[])
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm<A0DataType,
|
||||
B0DataType,
|
||||
CShuffleDataType,
|
||||
D2DataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
PassThrough,
|
||||
MulRoutedWeight>;
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
|
||||
@@ -406,6 +448,7 @@ int main(int argc, char* argv[])
|
||||
a0_t_k,
|
||||
b0_e_n_k,
|
||||
c_t_k_n,
|
||||
d2_e_n,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
@@ -428,7 +471,8 @@ int main(int argc, char* argv[])
|
||||
cde_element_op(e_t_n_host_result(t, topk_id, n),
|
||||
c_t_k_n(t, topk_id, n),
|
||||
d0_t_n(t, n),
|
||||
d1_e_n(e, n));
|
||||
d1_e_n(e, n),
|
||||
1.f);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -39,14 +39,15 @@ using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
||||
using D2DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
|
||||
|
||||
using A0Layout = Row;
|
||||
using B0Layout = Col;
|
||||
using ELayout = Row;
|
||||
using D0Layout = Row;
|
||||
using D1Layout = Col;
|
||||
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
|
||||
using DsLayout = ck::Tuple<D0Layout, D1Layout, ELayout>;
|
||||
|
||||
// for gate, a_scale, b_scale
|
||||
struct MulABScale
|
||||
@@ -91,7 +92,39 @@ struct MulABScaleSilu
|
||||
}
|
||||
};
|
||||
|
||||
using CDEElementOp = MulABScale;
|
||||
struct MulABScaleExpertWeight
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1, typename D2>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
|
||||
// for real kernel use
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, float, float, float, float>(
|
||||
EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
(void)d2;
|
||||
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
e = ck::type_convert<EDataType>(c * d1 * d0 * 16);
|
||||
#else
|
||||
e = ck::type_convert<EDataType>(c * d1 * d0);
|
||||
#endif
|
||||
}
|
||||
// for reference cpu
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
|
||||
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
// for reference cpu
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
e = ck::type_convert<EDataType>(c * d0 * d1 * d2 * 16);
|
||||
#else
|
||||
e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
#if 1
|
||||
void preShuffleBuffer(const I4* src, I4* dst, int N, int K, int NXdl)
|
||||
@@ -164,6 +197,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
|
||||
#else
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr bool MulRoutedWeight = false;
|
||||
// clang-format off
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
|
||||
Row, Col, DsLayout, ELayout,
|
||||
@@ -175,8 +209,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
|
||||
4, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
|
||||
1, 1, S<1, 32, 1, 8>, S<8, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>;
|
||||
1, 1, S<1, 32, 1, 8>, S<8, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, MulRoutedWeight, A0DataType>;
|
||||
// clang-format on
|
||||
#endif
|
||||
|
||||
@@ -265,6 +299,7 @@ int main(int argc, char* argv[])
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
|
||||
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
|
||||
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
|
||||
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
|
||||
Tensor<EDataType> e_t_n_device_result(
|
||||
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
|
||||
@@ -283,18 +318,21 @@ int main(int argc, char* argv[])
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{-2, 2});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{});
|
||||
break;
|
||||
default:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
}
|
||||
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) *
|
||||
sorted_token_ids.mDesc.GetElementSpaceSize());
|
||||
@@ -304,6 +342,7 @@ int main(int argc, char* argv[])
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize() / 2);
|
||||
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
|
||||
@@ -312,6 +351,7 @@ int main(int argc, char* argv[])
|
||||
a0_device_buf.ToDevice(a0_t_k.mData.data());
|
||||
d0_device_buf.ToDevice(d0_t_n.mData.data());
|
||||
d1_device_buf.ToDevice(d1_e_n.mData.data());
|
||||
d2_device_buf.ToDevice(d2_e_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
@@ -424,7 +464,8 @@ int main(int argc, char* argv[])
|
||||
a0_device_buf.GetDeviceBuffer(),
|
||||
b0_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
|
||||
d1_device_buf.GetDeviceBuffer()},
|
||||
d1_device_buf.GetDeviceBuffer(),
|
||||
d2_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
tokens,
|
||||
topk,
|
||||
@@ -480,10 +521,12 @@ int main(int argc, char* argv[])
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm<A0DataType,
|
||||
B0DataType,
|
||||
CShuffleDataType,
|
||||
D2DataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
PassThrough,
|
||||
MulRoutedWeight>;
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
|
||||
@@ -494,6 +537,7 @@ int main(int argc, char* argv[])
|
||||
a0_t_k,
|
||||
b0_e_n_k,
|
||||
c_t_k_n,
|
||||
d2_e_n,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
@@ -516,7 +560,8 @@ int main(int argc, char* argv[])
|
||||
cde_element_op(e_t_n_host_result(t, topk_id, n),
|
||||
c_t_k_n(t, topk_id, n),
|
||||
d0_t_n(t, n),
|
||||
d1_e_n(e, n));
|
||||
d1_e_n(e, n),
|
||||
1.f);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -135,6 +135,7 @@ static constexpr ck::index_t EVec = 2;
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
static constexpr ck::index_t D2Vec = 1;
|
||||
static constexpr bool MulRoutedWeight = false;
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
// clang-format off
|
||||
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
@@ -164,7 +165,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
2, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, MulRoutedWeight, A0DataType>;
|
||||
// kernel 2: 128->32x128x128
|
||||
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
|
||||
|
||||
@@ -409,7 +410,8 @@ int main(int argc, char* argv[])
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
CDEElementOp>;
|
||||
CDEElementOp,
|
||||
MulRoutedWeight>;
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -138,6 +138,7 @@ static constexpr ck::index_t EVec = 2;
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
static constexpr ck::index_t D2Vec = 1;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
// clang-format off
|
||||
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
@@ -149,7 +150,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
|
||||
1, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, MulRoutedWeight, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
@@ -455,7 +456,8 @@ int main(int argc, char* argv[])
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
CDEElementOp>;
|
||||
CDEElementOp,
|
||||
MulRoutedWeight>;
|
||||
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
|
||||
Reference in New Issue
Block a user