Moe gemm activation (#2026)

* fix useless code and remove useless oob

* clang format

* fix coredump in e2e test

* fix2

* fix clang format

* fix output oob

* impl int64 but result not correct

* int64 index ok now

* input output all ok

* fix uint32

* revert v1 test

* use uint32

* work to support 13w (130k) tokens

* moe sorting fix moebuf

* fix merge

* update moe api fix aiter build

* fix build

* fuse silu

* silu ok

* ascale ok

* add silu

* change code

* gemm2 ok

* gu fusion compatible ok, fix warnings

* gu fusion for m32 m64 ok

* support bf16 cshuffle

* i4 gemm2 ok

* i4 gemm2 ok and i4 gemm1 build

* 16x16 run ok

* change flops; change cshuffle dtype

* fuse gelu silu act in moe gemm1

* fp8 with act ready

* int4 act ready

* remove useless changes

* remove useless code change

* fix clang format

* add the arch limit of int4 moe gemm

* fuse moe activation

* fix fp8 16x16

* fix no quant case

* fix bugs

* fix fp8 gufusion bug

* remove useless comments

* refine activation code & complete moe example

* fix int8 bugs

* merge tkw1

---------

Co-authored-by: coderfeli <coderfeli@163.com>
Co-authored-by: feli <felix.li@amd.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: root <root@hjbog-srdc-51.amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
lalala-sh
2025-04-23 10:35:34 +08:00
committed by GitHub
parent 94662b02d0
commit 39ba03f25d
19 changed files with 1975 additions and 496 deletions

View File

@@ -13,6 +13,12 @@ foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp)
add_example_executable(example_moe_gemm2_xdl_pk_i4 moe_gemm2_xdl_pk_i4.cpp)
if(CK_hip_VERSION VERSION_LESS_EQUAL 6.3.42132)
set(EXAMPLE_COMPILE_OPTIONS)
list(APPEND EXAMPLE_COMPILE_OPTIONS -mllvm --amdgpu-enable-max-ilp-scheduling-strategy=1)
target_compile_options(example_moe_gemm1_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
target_compile_options(example_moe_gemm2_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
endif()
set(target 1)
endif()
endforeach()

View File

@@ -25,7 +25,6 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
// using BF16 = ck::bhalf_t;
using F8 = ck::f8_t;
using F32 = float;
@@ -36,7 +35,7 @@ using A0DataType = F8;
using B0DataType = F8;
using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CShuffleDataType = EDataType;
using D0DataType = F32;
using D1DataType = F32;
using D2DataType = F32;
@@ -61,27 +60,25 @@ struct MulABScale
__host__ __device__ constexpr void operator()<EDataType, float, float, float>(
EDataType& e, const float& c, const float& d0, const float& d1) const
{
e = ck::type_convert<EDataType>(c * d1 * d0);
(void)d0;
(void)d1;
e = ck::type_convert<EDataType>(c);
}
};
// for gate, a_scale, b_scale, fuse silu,
struct MulABScaleSilu
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<EDataType, float, float>(EDataType& e,
const float& c,
const float& d0,
const float& d1) const
__host__ __device__ constexpr void operator()<EDataType, EDataType, float, float>(
EDataType& e, const EDataType& c, const float& d0, const float& d1) const
{
// act
float x0 = 0;
ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0);
e = ck::type_convert<EDataType>(x0);
(void)d0;
(void)d1;
e = ck::type_convert<EDataType>(c);
}
template <>
__host__ __device__ constexpr void operator()<EDataType, EDataType, EDataType, EDataType>(
EDataType& e, const EDataType& c, const EDataType& d0, const EDataType& d1) const
{
(void)d0;
(void)d1;
e = ck::type_convert<EDataType>(c);
}
};
@@ -95,11 +92,19 @@ struct MulABScaleExpertWeight
__host__ __device__ constexpr void operator()<EDataType, float, float, float, float>(
EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
// for real kernel use
// warning: hack hack hack here!!!! ignore d0 right now as kernel mul d0 * d2 outside.
// tofix:felix
(void)d0;
(void)d1;
(void)d2;
e = ck::type_convert<EDataType>(c * d1 * d0);
e = ck::type_convert<EDataType>(c);
}
template <>
__host__ __device__ constexpr void operator()<EDataType, EDataType, float, float, float>(
EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const
{
(void)d0;
(void)d1;
(void)d2;
e = ck::type_convert<EDataType>(c);
}
// for reference cpu
template <>
@@ -107,16 +112,14 @@ struct MulABScaleExpertWeight
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
// for reference cpu
e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
(void)d0;
(void)d1;
(void)d2;
e = ck::type_convert<EDataType>(c);
}
};
using CDEElementOp = MulABScaleExpertWeight; // combine MulRoutedWeight = true
// using DsLayout = DsLayoutGate;
// using DsDataType = DsDataTypeGate;
// using CDEElementOp = MulABScale; // combine MulRoutedWeight = false
// using CDEElementOp = MulABScaleSiluMulGate;
using CDEElementOp = MulABScaleExpertWeight;
void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl)
{
@@ -155,22 +158,21 @@ using BElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MXDLPerWave = 2;
static constexpr ck::index_t MXDLPerWave = 4;
static constexpr ck::index_t NXDLPerWave = 2;
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 32;
static constexpr ck::index_t NPerBlock = 64;
static constexpr ck::index_t MNPerXDL = 16;
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
static constexpr ck::index_t Nswizzle = true;
static constexpr bool MulRoutedWeight = false;
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
static constexpr ck::index_t D2Vec = 1;
// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul
static constexpr bool MulRoutedWeight = false;
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// clang-format off
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
@@ -188,8 +190,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
2, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, MulRoutedWeight, A0DataType>;
2, 2, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, true, int32_t, A0DataType>;
// clang-format on
@@ -201,15 +203,13 @@ int main(int argc, char* argv[])
// GEMM shape
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t K = 6144;
ck::index_t experts = 8;
ck::index_t sorted_tile_num = 8;
ck::index_t valid_tile_num = 8;
ck::index_t tokens = 128;
ck::index_t sorted_tile_num = 16;
ck::index_t valid_tile_num = 13;
ck::index_t tokens = 64;
ck::index_t topk = 2;
// ck::index_t tokens = batch * topk;
if(argc == 1)
{
// use default case
@@ -255,28 +255,22 @@ int main(int argc, char* argv[])
ck::index_t StrideB = K;
ck::index_t StrideE = N;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{1, 1, 1};
ck::index_t KBatch = 1;
// const ck::index_t experts = 8;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1 + sorted_tile_num}));
// max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2,2, 2, 2, 2, 2,1,0,0,0};
// max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
// int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
// max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8};
int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
max_token_id.mData = {valid_size};
int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3};
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = eids[i];
}
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
int tokenid = 0;
// sorted_token_ids.mData[0] = 0;
for(int i = 0; i < sorted_size; i++)
{
int tile_off = i % MPerBlock;
@@ -290,13 +284,12 @@ int main(int argc, char* argv[])
sorted_token_ids.mData[i] = tokens;
}
}
// expert_ids.savetxt("expert_ids.txt", "int");
// sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
Tensor<D1DataType> d1_e_n(
HostTensorDescriptor({experts, N * 2}, {StrideDs[1] * N * 2, StrideDs[1]}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
Tensor<EDataType> e_t_n_device_result(
@@ -304,6 +297,7 @@ int main(int argc, char* argv[])
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
@@ -312,25 +306,25 @@ int main(int argc, char* argv[])
{
case 0: break;
case 1:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{-2, 2});
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
break;
case 2:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0, 1});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{});
break;
case 3:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
break;
default:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
@@ -349,9 +343,7 @@ int main(int argc, char* argv[])
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
// a0_t_k.savetxt("a.txt");
// d0_t_n.savetxt("d0_t_n.txt", "int");
// d1_e_n.savetxt("d1_e_n.txt", "int");
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data());
@@ -369,7 +361,8 @@ int main(int argc, char* argv[])
int NPerXdl = device_op.GetPreShuffleParameters();
preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, NPerXdl);
preShuffleBuffer(
b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * 2 * experts, K, NPerXdl);
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
@@ -408,9 +401,9 @@ int main(int argc, char* argv[])
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * tokens * topk * N * K;
std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K;
std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K +
sizeof(B0DataType) * K * N * experts +
sizeof(B0DataType) * K * N * 2 * experts +
sizeof(EDataType) * valid_tile_num * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -437,6 +430,7 @@ int main(int argc, char* argv[])
PassThrough,
PassThrough,
PassThrough,
ActOP,
MulRoutedWeight>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
@@ -446,7 +440,9 @@ int main(int argc, char* argv[])
max_token_id,
MPerBlock,
a0_t_k,
d0_t_n,
b0_e_n_k,
d1_e_n,
c_t_k_n,
d2_e_n,
PassThrough{},
@@ -472,15 +468,14 @@ int main(int argc, char* argv[])
c_t_k_n(t, topk_id, n),
d0_t_n(t, n),
d1_e_n(e, n),
1.f);
d2_e_n(e, n));
}
}
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
// e_t_n_device_result.savetxt("out.txt");
// e_t_n_host_result.savetxt("ref.txt");
return ck::utils::check_err(
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
? 0
: 1;
}

View File

@@ -36,7 +36,7 @@ using A0DataType = F8;
using B0DataType = I4;
using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CShuffleDataType = F16;
using D0DataType = F32;
using D1DataType = F32;
using D2DataType = F32;
@@ -47,7 +47,8 @@ using B0Layout = Col;
using ELayout = Row;
using D0Layout = Row;
using D1Layout = Col;
using DsLayout = ck::Tuple<D0Layout, D1Layout, ELayout>;
using D2Layout = ELayout;
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
// for gate, a_scale, b_scale
struct MulABScale
@@ -56,42 +57,32 @@ struct MulABScale
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<EDataType, EDataType, float, float>(
EDataType& e, const EDataType& c, const float& d0, const float& d1) const
{
(void)d0;
(void)d1;
#if CK_USE_PK4_LAYOUT_SHUFFLE
e = ck::type_convert<EDataType>(c);
#else
e = ck::type_convert<EDataType>(c);
#endif
}
template <>
__host__ __device__ constexpr void operator()<EDataType, float, float, float>(
EDataType& e, const float& c, const float& d0, const float& d1) const
{
(void)d0;
(void)d1;
#if CK_USE_PK4_LAYOUT_SHUFFLE
e = ck::type_convert<EDataType>(c * d1 * d0 * 16);
e = ck::type_convert<EDataType>(c);
#else
e = ck::type_convert<EDataType>(c * d1 * d0);
e = ck::type_convert<EDataType>(c);
#endif
}
};
// for gate, a_scale, b_scale, fuse silu,
struct MulABScaleSilu
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<EDataType, float, float>(EDataType& e,
const float& c,
const float& d0,
const float& d1) const
{
// act
float x0 = 0;
#if CK_USE_PK4_LAYOUT_SHUFFLE
ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0 * 16);
#else
ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0);
#endif
e = ck::type_convert<EDataType>(x0);
}
};
struct MulABScaleExpertWeight
{
template <typename E, typename C, typename D0, typename D1, typename D2>
@@ -102,13 +93,19 @@ struct MulABScaleExpertWeight
__host__ __device__ constexpr void operator()<EDataType, float, float, float, float>(
EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
(void)d0;
(void)d1;
(void)d2;
#if CK_USE_PK4_LAYOUT_SHUFFLE
e = ck::type_convert<EDataType>(c * d1 * d0 * 16);
#else
e = ck::type_convert<EDataType>(c * d1 * d0);
#endif
e = ck::type_convert<EDataType>(c);
}
template <>
__host__ __device__ constexpr void operator()<EDataType, EDataType, float, float, float>(
EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const
{
(void)d0;
(void)d1;
(void)d2;
e = ck::type_convert<EDataType>(c);
}
// for reference cpu
template <>
@@ -116,15 +113,18 @@ struct MulABScaleExpertWeight
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
// for reference cpu
#if CK_USE_PK4_LAYOUT_SHUFFLE
e = ck::type_convert<EDataType>(c * d0 * d1 * d2 * 16);
#else
e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
#endif
(void)d0;
(void)d1;
(void)d2;
e = ck::type_convert<EDataType>(c);
}
};
using CDEElementOp = MulABScaleExpertWeight;
static constexpr bool MulRoutedWeight = true;
using CDEElementOp = MulABScaleExpertWeight; // combine MulRoutedWeight = true
// using CDEElementOp = MulABScale; // combine MulRoutedWeight = true
#if 1
void preShuffleBuffer(const I4* src, I4* dst, int N, int K, int NXdl)
@@ -165,54 +165,24 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
#if 0
static constexpr ck::index_t MPerBlock = 64;
static constexpr ck::index_t MXDLPerWave = 1;
static constexpr ck::index_t NXDLPerWave = 2;
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 32;
static constexpr ck::index_t KPerBlock = 64 / sizeof(A0DataType);
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
static constexpr ck::index_t BK1 = 32 / sizeof(B0DataType);
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
Row, Col, DsLayout, ELayout,
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
BLOCKSIZE, MPerBlock, NPerBlock, KPerBlock,
AK1, BK1,
MNPerXDL, MNPerXDL,
MXDLPerWave, NXDLPerWave,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
MXDLPerWave, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>;
// clang-format on
#else
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t Nswizzle = false;
static constexpr bool MulRoutedWeight = false;
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t Act_OP = 1; // 0: gelu_and_mul, 1: silu_and_mul
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
Row, Col, DsLayout, ELayout,
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
256, MPerBlock, 128, 128,
256, MPerBlock, 64, 128,
16, 32,
32, 32,
4, 1,
16, 16,
8, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
1, 1, S<1, 32, 1, 8>, S<8, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, MulRoutedWeight, A0DataType>;
2, 1, S<1, 32, 1, 8>, S<8, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Act_OP, Nswizzle, true, MulRoutedWeight, true, ck::index_t, A0DataType>;
// clang-format on
#endif
int main(int argc, char* argv[])
{
@@ -220,13 +190,10 @@ int main(int argc, char* argv[])
int init_method = 1;
bool time_kernel = true;
// tokens = 1
// topk = 1
// experts = 8
// per expert:
// GEMM shape
ck::index_t N = 4096 * 2;
ck::index_t K = 6144;
ck::index_t N = 14336;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t sorted_tile_num = 16;
ck::index_t valid_tile_num = 13;
@@ -266,20 +233,20 @@ int main(int argc, char* argv[])
ck::index_t StrideB = K;
ck::index_t StrideE = N;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0};
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
ck::index_t KBatch = 1;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1 + sorted_tile_num}));
max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 0, 0, 0};
max_token_id.mData = {valid_size};
int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3};
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = eids[i];
}
int token_per_tile = tokens * topk / valid_tile_num;
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
int tokenid = 0;
for(int i = 0; i < sorted_size; i++)
{
@@ -294,11 +261,12 @@ int main(int argc, char* argv[])
sorted_token_ids.mData[i] = tokens;
}
}
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N * 2}, {1, StrideDs[1]}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
Tensor<EDataType> e_t_n_device_result(
@@ -306,6 +274,7 @@ int main(int argc, char* argv[])
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
@@ -314,11 +283,11 @@ int main(int argc, char* argv[])
{
case 0: break;
case 1:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{-2, 2});
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
break;
case 2:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
@@ -497,9 +466,9 @@ int main(int argc, char* argv[])
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * tokens * topk * N * K;
std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K;
std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K +
sizeof(B0DataType) / 2 * K * N * experts +
sizeof(B0DataType) / 2 * K * N * 2 * experts +
sizeof(EDataType) * valid_tile_num * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -526,6 +495,7 @@ int main(int argc, char* argv[])
PassThrough,
PassThrough,
PassThrough,
Act_OP,
MulRoutedWeight>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
@@ -535,7 +505,9 @@ int main(int argc, char* argv[])
max_token_id,
MPerBlock,
a0_t_k,
d0_t_n,
b0_e_n_k,
d1_e_n,
c_t_k_n,
d2_e_n,
PassThrough{},
@@ -561,13 +533,13 @@ int main(int argc, char* argv[])
c_t_k_n(t, topk_id, n),
d0_t_n(t, n),
d1_e_n(e, n),
1.f);
d2_e_n(e, n));
}
}
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
return ck::utils::check_err(
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
? 0
: 1;
}

View File

@@ -25,7 +25,6 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
// using BF16 = ck::bhalf_t;
using F8 = ck::f8_t;
using F32 = float;
@@ -36,7 +35,7 @@ using A0DataType = F8;
using B0DataType = F8;
using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CShuffleDataType = F16;
using D0DataType = F32;
using D1DataType = F32;
using D2DataType = F32;
@@ -48,7 +47,6 @@ using ELayout = Row;
using D0Layout = Row;
using D1Layout = Col;
using D2Layout = ELayout;
// using DsLayoutGate = ck::Tuple<D0Layout, D1Layout>;
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
// d0: ascale, d1: bscale, d2:expert weight
@@ -62,11 +60,19 @@ struct MulABScaleExpertWeight
__host__ __device__ constexpr void operator()<EDataType, float, float, float, float>(
EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
// for real kernel use
// warning: hack hack hack here!!!! ignore d0 right now as kernel mul d0 * d2 outside.
// tofix:felix
(void)d0;
e = ck::type_convert<EDataType>(c * d1 * d2);
(void)d1;
(void)d2;
e = ck::type_convert<EDataType>(c);
}
template <>
__host__ __device__ constexpr void operator()<EDataType, EDataType, float, float, float>(
EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const
{
(void)d0;
(void)d1;
(void)d2;
e = ck::type_convert<EDataType>(c);
}
// for reference cpu
template <>
@@ -119,14 +125,12 @@ using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t MXDLPerWave = 2;
static constexpr ck::index_t NXDLPerWave = 2;
static constexpr ck::index_t MXDLPerWave = 4;
static constexpr ck::index_t NXDLPerWave = 4;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 32;
static constexpr ck::index_t MNPerXDL = 16;
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
// static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; //todo fix this constraint
// static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32;
static constexpr ck::index_t CShuffleNLane = 32;
static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane;
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
@@ -135,7 +139,7 @@ static constexpr ck::index_t EVec = 2;
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
static constexpr ck::index_t D2Vec = 1;
static constexpr bool MulRoutedWeight = false;
static constexpr bool MulRoutedWeight = true;
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
@@ -164,8 +168,8 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
2, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, MulRoutedWeight, A0DataType>;
4, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, int32_t, A0DataType>;
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
@@ -177,16 +181,13 @@ int main(int argc, char* argv[])
int init_method = 1;
bool time_kernel = true;
// tokens = 1
// topk = 1
// experts = 8
// per expert:
// GEMM shape
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t sorted_tile_num = 6;
ck::index_t valid_tile_num = 6;
ck::index_t sorted_tile_num = 16;
ck::index_t valid_tile_num = 13;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t tokens = 128;
@@ -212,6 +213,18 @@ int main(int argc, char* argv[])
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
}
else if(argc == 9)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
sorted_tile_num = std::stoi(argv[7]);
valid_tile_num = std::stoi(argv[8]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
@@ -229,15 +242,13 @@ int main(int argc, char* argv[])
ck::index_t KBatch = 1;
// const ck::index_t experts = 8;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
// max_token_id.mData[0] = valid_size;
// max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3};
max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8};
int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3};
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = eids[i];
@@ -249,7 +260,7 @@ int main(int argc, char* argv[])
}
int token_per_tile = tokens * topk / valid_tile_num;
int tokenid = 0;
// sorted_token_ids.mData[0] = 0;
for(int i = 0; i < sorted_size; i++)
{
int tile_off = i % MPerBlock;
@@ -263,8 +274,7 @@ int main(int argc, char* argv[])
sorted_token_ids.mData[i] = tokens;
}
}
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
@@ -315,12 +325,7 @@ int main(int argc, char* argv[])
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
// a0_t_k_k.savetxt("a.txt");
// expert_ids.savetxt("expert_ids.txt", "int");
// sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
// d0_t_n.savetxt("d0_t_n.txt", "int");
// d1_e_n.savetxt("d1_e_n.txt", "int");
// d2_e_n.savetxt("d2_e_n.txt", "int");
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data());
@@ -398,7 +403,7 @@ int main(int argc, char* argv[])
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
Tensor<CShuffleDataType> c_t_n({tokens, N});
Tensor<float> c_t_n({tokens, N});
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeGemm2<A0DataType,
@@ -406,7 +411,7 @@ int main(int argc, char* argv[])
D0DataType,
D1DataType,
D2DataType,
CShuffleDataType,
float,
AccDataType,
PassThrough,
PassThrough,
@@ -439,8 +444,7 @@ int main(int argc, char* argv[])
}
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
// e_t_n_device_result.savetxt("out.txt");
// e_t_n_host_result.savetxt("ref.txt");
return ck::utils::check_err(
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
? 0

View File

@@ -62,11 +62,13 @@ struct MulABScaleExpertWeight
EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
(void)d0;
(void)d1;
(void)d2;
#if CK_USE_PK4_LAYOUT_SHUFFLE
e = ck::type_convert<EDataType>(c * d1 * d2 * 16);
e = ck::type_convert<EDataType>(c * 16);
#else
e = ck::type_convert<EDataType>(c * d1 * d2);
e = ck::type_convert<EDataType>(c);
#endif
}
// for reference cpu
@@ -125,10 +127,10 @@ using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t MXDLPerWave = 4;
static constexpr ck::index_t NXDLPerWave = 1;
static constexpr ck::index_t MXDLPerWave = 8;
static constexpr ck::index_t NXDLPerWave = 2;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 32;
static constexpr ck::index_t MNPerXDL = 16;
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
static constexpr ck::index_t CShuffleNLane = 32;
static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane;
@@ -149,8 +151,8 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
MXDLPerWave, NXDLPerWave,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
1, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, MulRoutedWeight, A0DataType>;
2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, ck::index_t, A0DataType>;
// clang-format on
int main(int argc, char* argv[])
@@ -159,9 +161,6 @@ int main(int argc, char* argv[])
int init_method = 1;
bool time_kernel = true;
// tokens = 1
// topk = 1
// experts = 8
// per expert:
// GEMM shape
ck::index_t N = 4096;