mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 06:44:36 +00:00
Extend XDL kernel to Support RDNA3/4 - Part 4 (#2724)
* Fix example
* fix build error
* update pk_i4 & moe test case
* fix all instance build (examples)
* fix batched_gemm_gemm (example)
* disable example_gemm_bias_softmax_gemm_permute on gfx11
* remove unnecessary disable gfx11
* update tests
* update tests2
[ROCm/composable_kernel commit: 321627aec5]
This commit is contained in:
@@ -78,11 +78,17 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
|
||||
///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C, D0, D1>|
|
||||
///###### RCR
|
||||
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 128, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
|
||||
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 4>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
// fp8 is not supported on gfx11
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
@@ -97,11 +97,12 @@ struct MultiplyMultiply
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr int KPack = 8;
|
||||
|
||||
void preShuffleBuffer(const F16* src, F16* dst, int N, int K, int NXdl)
|
||||
{
|
||||
int KPack = 16 / sizeof(F16);
|
||||
int NLane = NXdl;
|
||||
int KLane = 64 / NLane;
|
||||
int KLane = ck::get_warp_size() / NLane;
|
||||
|
||||
int K0 = K / (KLane * KPack);
|
||||
// K -> K0 KLane KPack
|
||||
@@ -147,12 +148,12 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
|
||||
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
|
||||
32, 128, 128,
|
||||
8, 8,
|
||||
32, 32,
|
||||
1, 1,
|
||||
KPack, KPack,
|
||||
16, 16,
|
||||
2, 2,
|
||||
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
|
||||
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
|
||||
1, 1, S<1, 16, 1, 16>, S<8, 8, 1>,
|
||||
1, 1, S<1, 16, 1, 16>, S<4, 4, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, F16>;
|
||||
// clang-format on
|
||||
|
||||
@@ -211,6 +212,12 @@ int main(int argc, char* argv[])
|
||||
exit(0);
|
||||
}
|
||||
|
||||
// temporarily disabled on gfx11
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -202,8 +202,6 @@ int main(int argc, char* argv[])
|
||||
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
constexpr auto I0 = ck::Number<0>{};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
@@ -218,7 +216,7 @@ int main(int argc, char* argv[])
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<ck::index_t, NumDTensor>{I0, I0},
|
||||
std::array<ck::index_t, NumDTensor>{StrideD, StrideD},
|
||||
StrideE,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -125,11 +125,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
|
||||
64, 128, 256,
|
||||
16, 16,
|
||||
32, 32,
|
||||
1, 2,
|
||||
16, 16,
|
||||
2, 4,
|
||||
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
1, 1, S<1, 32, 1, 8>, S<8, 8, 1>,
|
||||
1, 1, S<1, 32, 1, 8>, S<4, 4, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, I8>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -168,7 +168,7 @@ static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
|
||||
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
|
||||
static constexpr ck::index_t EVec = 8 / sizeof(EDataType);
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
|
||||
@@ -121,6 +121,7 @@ struct MulABScaleExpertWeight
|
||||
};
|
||||
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
static constexpr ck::index_t KPack = 32;
|
||||
|
||||
using CDEElementOp = MulABScaleExpertWeight; // combine MulRoutedWeight = true
|
||||
|
||||
@@ -129,7 +130,6 @@ using CDEElementOp = MulABScaleExpertWeight; // combine MulRoutedWeight = true
|
||||
#if 1
|
||||
void preShuffleBuffer(const I4* src, I4* dst, int N, int K, int NXdl)
|
||||
{
|
||||
int KPack = 32;
|
||||
int NLane = NXdl;
|
||||
int KLane = 64 / NLane;
|
||||
|
||||
@@ -169,18 +169,19 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t Act_OP = 1; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
|
||||
// clang-format off
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
|
||||
Row, Col, DsLayout, ELayout,
|
||||
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
256, MPerBlock, 64, 128,
|
||||
16, 32,
|
||||
16, KPack,
|
||||
16, 16,
|
||||
8, 1,
|
||||
4, 2,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
|
||||
2, 1, S<1, 32, 1, 8>, S<8, 1, 1>,
|
||||
2, 1, S<1, 32, 1, 8>, S<4, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Act_OP, Nswizzle, true, MulRoutedWeight, true, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
@@ -458,9 +459,10 @@ int main(int argc, char* argv[])
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
|
||||
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" ||
|
||||
ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
|
||||
std::cout << "This kernel support gfx942, gfx950, gfx11 and gfx12 only" << std::endl;
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
|
||||
@@ -85,11 +85,11 @@ struct MulABScaleExpertWeight
|
||||
}
|
||||
};
|
||||
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
static constexpr int KPack = 32 / sizeof(B0DataType);
|
||||
|
||||
void preShuffleBuffer(const I4* src, I4* dst, int N, int K, int NXdl)
|
||||
{
|
||||
int KPack = 32;
|
||||
int NLane = NXdl;
|
||||
int KLane = 64 / NLane;
|
||||
|
||||
@@ -135,7 +135,7 @@ static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t CShuffleNLane = 32;
|
||||
static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane;
|
||||
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t BK1 = 32 / sizeof(B0DataType);
|
||||
static constexpr ck::index_t BK1 = KPack;
|
||||
static constexpr ck::index_t EVec = 2;
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
@@ -414,9 +414,10 @@ int main(int argc, char* argv[])
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
|
||||
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" ||
|
||||
ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
|
||||
std::cout << "This kernel support gfx942, gfx950, gfx11 and gfx12 only" << std::endl;
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
|
||||
Reference in New Issue
Block a user