mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
updated
This commit is contained in:
@@ -155,22 +155,22 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
// clang-format on
|
||||
|
||||
#else
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
static constexpr ck::index_t MPerBlock = 16;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
|
||||
// clang-format off
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX<
|
||||
A0Layout, B0Layout, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
ScaleBlockSize, 256,
|
||||
MPerBlock, 128, 128,
|
||||
ScaleBlockSize, 64,
|
||||
MPerBlock, 16, 128,
|
||||
32, 32,
|
||||
16, 16,
|
||||
8, 2,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
|
||||
1, 1, S<1, 16, 1, 16>, S<2, 1, 1, 1>,
|
||||
1, 1,
|
||||
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
|
||||
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
|
||||
1, 1, S<1, 8, 1, 8>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
#endif
|
||||
@@ -183,14 +183,14 @@ int main(int argc, char* argv[])
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
constexpr ck::index_t sorted_tile_num = 19;
|
||||
constexpr ck::index_t valid_tile_num = 16;
|
||||
constexpr ck::index_t sorted_tile_num = 2;
|
||||
constexpr ck::index_t valid_tile_num = 2;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t experts = 2;
|
||||
ck::index_t tokens = 832;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
@@ -285,7 +285,7 @@ int main(int argc, char* argv[])
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<B1DataType> b1_e_n_k(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
|
||||
{(N * Scale_Stride_BN), Scale_Stride_BN, 1}));
|
||||
{(N * Scale_Stride_BN), 1, Scale_Stride_BN}));
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
|
||||
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
|
||||
@@ -371,34 +371,32 @@ int main(int argc, char* argv[])
|
||||
|
||||
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
|
||||
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument =
|
||||
device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(),
|
||||
expert_ids_dev.GetDeviceBuffer(),
|
||||
max_token_id_dev.GetDeviceBuffer(),
|
||||
a0_device_buf.GetDeviceBuffer(),
|
||||
a1_device_buf.GetDeviceBuffer(),
|
||||
b0_device_buf.GetDeviceBuffer(),
|
||||
b1_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
|
||||
d1_device_buf.GetDeviceBuffer(),
|
||||
d2_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
tokens,
|
||||
topk,
|
||||
sorted_size,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
Scale_Stride_AM,
|
||||
StrideB,
|
||||
Scale_Stride_BN,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(
|
||||
sorted_token_ids_dev.GetDeviceBuffer(),
|
||||
expert_ids_dev.GetDeviceBuffer(),
|
||||
max_token_id_dev.GetDeviceBuffer(),
|
||||
a0_device_buf.GetDeviceBuffer(),
|
||||
a1_device_buf.GetDeviceBuffer(),
|
||||
b0_device_buf.GetDeviceBuffer(),
|
||||
b1_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
tokens,
|
||||
topk,
|
||||
sorted_size,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
Scale_Stride_AM,
|
||||
StrideB,
|
||||
Scale_Stride_BN,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
@@ -439,19 +437,19 @@ int main(int argc, char* argv[])
|
||||
Tensor<CShuffleDataType> c_t_n({tokens, N});
|
||||
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceMoeGemm2<A0DataType,
|
||||
A1DataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
D2DataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
CDEElementOp,
|
||||
MulRoutedWeight,
|
||||
float,
|
||||
float>;
|
||||
ck::tensor_operation::host::ReferenceMoeMXGemm2<A0DataType,
|
||||
A1DataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
D2DataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
CDEElementOp,
|
||||
MulRoutedWeight,
|
||||
float,
|
||||
float>;
|
||||
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
@@ -480,6 +478,28 @@ int main(int argc, char* argv[])
|
||||
|
||||
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
#if 1
|
||||
printf("e_t_n_device_result:\n");
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
printf("%f ", ck::type_convert<float>(e_t_n_device_result(t, n)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
printf("e_t_n_host_result:\n");
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
printf("%f ", ck::type_convert<float>(e_t_n_host_result(t, n)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
return ck::utils::check_err(
|
||||
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
|
||||
? 0
|
||||
|
||||
@@ -1167,7 +1167,7 @@ struct GridwiseMoeGemmMX
|
||||
}
|
||||
|
||||
// check gridwise gemm pipeline
|
||||
#if 1
|
||||
#if 0
|
||||
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
|
||||
|
||||
if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
|
||||
|
||||
@@ -28,7 +28,7 @@ template <typename ADataType,
|
||||
bool MulRoutedWeight = true,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
struct ReferenceMoeGemm2 : public device::BaseOperator
|
||||
struct ReferenceMoeMXGemm2 : public device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public device::BaseArgument
|
||||
@@ -81,14 +81,18 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
|
||||
// Invoker
|
||||
struct Invoker : public device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceMoeGemm2::Argument;
|
||||
using Argument = ReferenceMoeMXGemm2::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
arg.c_t_n_.SetZero();
|
||||
const ck::index_t SCALE_BLOCK = arg.b_e_n_k_.mDesc.GetLengths()[2];
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
const int K = arg.a_t_k_k_.mDesc.GetLengths()[2];
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
const int K = arg.a_t_k_k_.mDesc.GetLengths()[2];
|
||||
const ck::index_t SCALE_BLOCK = K / arg.b_e_n_k_scale_.mDesc.GetLengths()[1];
|
||||
if(m == 0 && n == 0)
|
||||
{
|
||||
printf("SCALE_BLOCK: %d\n", SCALE_BLOCK);
|
||||
}
|
||||
AccDataType v_acc{0};
|
||||
ComputeTypeA v_a{0};
|
||||
ComputeTypeB v_b{0};
|
||||
|
||||
Reference in New Issue
Block a user