This commit is contained in:
mtgu0705
2025-05-13 04:15:53 -05:00
parent c11fef9197
commit 6dfe24c53e
3 changed files with 83 additions and 59 deletions

View File

@@ -155,22 +155,22 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
// clang-format on
#else
static constexpr ck::index_t MPerBlock = 128;
static constexpr bool MulRoutedWeight = true;
static constexpr ck::index_t MPerBlock = 16;
static constexpr bool MulRoutedWeight = true;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX<
A0Layout, B0Layout, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
ScaleBlockSize, 256,
MPerBlock, 128, 128,
ScaleBlockSize, 64,
MPerBlock, 16, 128,
32, 32,
16, 16,
8, 2,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
1, 1, S<1, 16, 1, 16>, S<2, 1, 1, 1>,
1, 1,
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
1, 1, S<1, 8, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, ck::index_t, A0DataType>;
// clang-format on
#endif
@@ -183,14 +183,14 @@ int main(int argc, char* argv[])
// per expert:
// GEMM shape
constexpr ck::index_t sorted_tile_num = 19;
constexpr ck::index_t valid_tile_num = 16;
constexpr ck::index_t sorted_tile_num = 2;
constexpr ck::index_t valid_tile_num = 2;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t experts = 2;
ck::index_t tokens = 832;
ck::index_t topk = 2;
@@ -285,7 +285,7 @@ int main(int argc, char* argv[])
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
Tensor<B1DataType> b1_e_n_k(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
{(N * Scale_Stride_BN), Scale_Stride_BN, 1}));
{(N * Scale_Stride_BN), 1, Scale_Stride_BN}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
@@ -371,34 +371,32 @@ int main(int argc, char* argv[])
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(),
max_token_id_dev.GetDeviceBuffer(),
a0_device_buf.GetDeviceBuffer(),
a1_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer(),
d2_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
tokens,
topk,
sorted_size,
N,
K,
StrideA,
Scale_Stride_AM,
StrideB,
Scale_Stride_BN,
StrideDs,
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op);
auto invoker = device_op.MakeInvoker();
// Build the argument object for device_op from the device buffers, the problem
// sizes (tokens, topk, sorted_size, N, K), the strides, KBatch and the three
// element-wise operators. Argument order must match MakeArgument's signature.
auto argument = device_op.MakeArgument(
sorted_token_ids_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(),
max_token_id_dev.GetDeviceBuffer(),
a0_device_buf.GetDeviceBuffer(),
a1_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
// Only the third D pointer carries a device buffer; the first two are nullptr.
// NOTE(review): StrideDs is still passed in full below -- confirm the kernel
// never dereferences the D0/D1 pointers in this configuration.
std::array<const void*, NumDTensor>{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
tokens,
topk,
sorted_size,
N,
K,
StrideA,
Scale_Stride_AM,
StrideB,
Scale_Stride_BN,
StrideDs,
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
@@ -439,19 +437,19 @@ int main(int argc, char* argv[])
Tensor<CShuffleDataType> c_t_n({tokens, N});
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeGemm2<A0DataType,
A1DataType,
B0DataType,
B1DataType,
D2DataType,
CShuffleDataType,
AccDataType,
PassThrough,
PassThrough,
CDEElementOp,
MulRoutedWeight,
float,
float>;
ck::tensor_operation::host::ReferenceMoeMXGemm2<A0DataType,
A1DataType,
B0DataType,
B1DataType,
D2DataType,
CShuffleDataType,
AccDataType,
PassThrough,
PassThrough,
CDEElementOp,
MulRoutedWeight,
float,
float>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
@@ -480,6 +478,28 @@ int main(int argc, char* argv[])
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
// Debug dump of the output tensors: prints every element of the device result
// and then of the host reference result, converted to float, one text line per
// token row (tokens rows of N values each).
// NOTE(review): this is compiled in unconditionally (#if 1) and emits
// 2 * tokens * N values to stdout -- presumably temporary debugging output;
// confirm whether it should be disabled before merging.
#if 1
printf("e_t_n_device_result:\n");
for(int t = 0; t < tokens; ++t)
{
for(int n = 0; n < N; ++n)
{
printf("%f ", ck::type_convert<float>(e_t_n_device_result(t, n)));
}
// End of one token row.
printf("\n");
}
printf("e_t_n_host_result:\n");
for(int t = 0; t < tokens; ++t)
{
for(int n = 0; n < N; ++n)
{
printf("%f ", ck::type_convert<float>(e_t_n_host_result(t, n)));
}
// End of one token row.
printf("\n");
}
#endif
return ck::utils::check_err(
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
? 0

View File

@@ -1167,7 +1167,7 @@ struct GridwiseMoeGemmMX
}
// check gridwise gemm pipeline
#if 1
#if 0
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)

View File

@@ -28,7 +28,7 @@ template <typename ADataType,
bool MulRoutedWeight = true,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct ReferenceMoeGemm2 : public device::BaseOperator
struct ReferenceMoeMXGemm2 : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
@@ -81,14 +81,18 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceMoeGemm2::Argument;
using Argument = ReferenceMoeMXGemm2::Argument;
float Run(const Argument& arg)
{
arg.c_t_n_.SetZero();
const ck::index_t SCALE_BLOCK = arg.b_e_n_k_.mDesc.GetLengths()[2];
auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = arg.a_t_k_k_.mDesc.GetLengths()[2];
auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = arg.a_t_k_k_.mDesc.GetLengths()[2];
const ck::index_t SCALE_BLOCK = K / arg.b_e_n_k_scale_.mDesc.GetLengths()[1];
if(m == 0 && n == 0)
{
printf("SCALE_BLOCK: %d\n", SCALE_BLOCK);
}
AccDataType v_acc{0};
ComputeTypeA v_a{0};
ComputeTypeB v_b{0};