Extend XDL kernel to Support RDNA3/4 - Part 1 (#2606)

This commit is contained in:
linqunAMD
2025-08-23 05:46:30 +08:00
committed by GitHub
parent 0db21053e6
commit d6e49c5fde
11 changed files with 683 additions and 127 deletions

View File

@@ -176,8 +176,36 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
BElementwiseOperation,
CElementwiseOperation>
{
template <bool isWave64>
static constexpr auto GetNXdlPerWave()
{
constexpr index_t Waves = isWave64 ? BlockSize / 64 : BlockSize / 32;
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXDL);
static_assert(MWaves > 0);
constexpr index_t NWaves = Waves / MWaves;
if constexpr(NWaves == 0)
{
return 0;
}
else
{
if constexpr(NPerBlock % (NPerXDL * NWaves) == 0)
{
return NPerBlock / (NWaves * NPerXDL);
}
else
{
return 0;
}
}
}
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<
static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
template <index_t NXdlPerWave_>
using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3<
ALayout,
BLayout,
CLayout,
@@ -199,7 +227,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
NXdlPerWave_,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
@@ -226,8 +254,10 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
ComputeTypeB,
PermuteA,
PermuteB>;
// Gridwise GEMM instantiated with the wave64 NXdlPerWave value. The value is clamped
// to at least 1 even when GetNXdlPerWave<true>() yields 0.
// NOTE(review): presumably the clamp keeps this type (and the Argument alias taken
// from it) instantiable when wave64 is not a usable configuration — confirm.
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
// Gridwise GEMM instantiated with the wave32 NXdlPerWave value, passed through
// unclamped; the uses visible in this struct are guarded by
// `if constexpr(NXdlPerWave32 > 0)`.
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
using Argument = typename GridwiseGemm::Argument;
using Argument = typename GridwiseGemm64::Argument;
static constexpr index_t APackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t>)
@@ -254,12 +284,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
///
struct Invoker : public BaseInvoker
{
/// @brief This function issues GPU kernel execution.
/// @param arg The GPU kernel arguments.
/// @param stream_config The HIP stream configuration helper structure.
/// @return The kernel's average execution time (if time measurement is
/// enabled).
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
template <typename GridwiseGemm>
float RunImp(const typename GridwiseGemm::Argument& arg,
const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
@@ -285,7 +312,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
const auto Run = [&](const auto& kernel) {
if(stream_config.flush_cache)
{
Argument arg_ = arg;
auto arg_ = arg;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
@@ -297,7 +324,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
sizeof(BDataType) / BPackedSize;
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
ck::utility::RotatingMemWrapper<typename GridwiseGemm::Argument> rotating_mem(
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
rotating_mem.Print();
@@ -733,6 +760,31 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
return ave_time;
}
/// @brief Issues GPU kernel execution through the gridwise GEMM that matches the
///        wave size reported by get_warp_size().
/// @param arg The GPU kernel arguments.
/// @param stream_config The HIP stream configuration helper structure.
/// @return The kernel's average execution time (if time measurement is
///         enabled), or 0 when no gridwise GEMM is available for the
///         current wave size (its NXdlPerWave value is not > 0).
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
    const bool use_wave64 = (get_warp_size() == 64);

    if(!use_wave64)
    {
        if constexpr(NXdlPerWave32 > 0)
        {
            // `arg` is typed as the wave64 Argument; the wave32 path views the same
            // object through GridwiseGemm32's Argument type.
            using Argument32 = typename GridwiseGemm32::Argument;
            return RunImp<GridwiseGemm32>(reinterpret_cast<const Argument32&>(arg),
                                          stream_config);
        }
    }
    else
    {
        if constexpr(NXdlPerWave64 > 0)
        {
            return RunImp<GridwiseGemm64>(arg, stream_config);
        }
    }

    // No kernel was launched for this wave size.
    return 0;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
@@ -754,9 +806,39 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
return false;
}
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
if(arg.KBatch > 1)
{
return false;
if(is_gfx11_supported())
{
return false;
}
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t>)
{
return false;
}
if(sizeof(CDataType) == 1)
{
return false;
}
}
if(is_gfx11_supported() || is_gfx12_supported())
{
if(MPerXDL != 16 || NPerXDL != 16)
{
return false;
}
}
if(is_gfx11_supported())
{
if constexpr(std::is_same_v<ADataType, ck::f8_t> ||
std::is_same_v<ADataType, ck::bf8_t>)
{
return false;
}
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
@@ -767,7 +849,29 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
return false;
}
return GridwiseGemm::CheckValidity(arg);
if(get_warp_size() == 64)
{
if constexpr(NXdlPerWave64 > 0)
{
return GridwiseGemm64::CheckValidity(arg);
}
else
{
return false;
}
}
else
{
if constexpr(NXdlPerWave32 > 0)
{
return GridwiseGemm32::CheckValidity(
reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
}
else
{
return false;
}
}
}
// polymorphic
@@ -849,6 +953,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
index_t PrefetchStages = 0;
index_t AMmaKStride = 0;
if(get_warp_size() == 64)
{
if constexpr(NXdlPerWave64 > 0)
{
PrefetchStages = GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
AMmaKStride = GridwiseGemm64::BlockwiseGemmPipe::AMmaKStride;
}
}
else
{
if constexpr(NXdlPerWave32 > 0)
{
PrefetchStages = GridwiseGemm32::BlockwiseGemmPipe::PrefetchStages;
AMmaKStride = GridwiseGemm32::BlockwiseGemmPipe::AMmaKStride;
}
}
// clang-format off
str << "DeviceGemmXdlUniversal"
<< "<"
@@ -872,9 +995,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
<< PrefetchStages << ", "
<< "Kpack: "
<< GridwiseGemm::BlockwiseGemmPipe::AMmaKStride;
<< AMmaKStride;
// clang-format on
return str.str();