mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
Extend XDL kernel to Support RDNA3/4 - Part 1 (#2606)
This commit is contained in:
@@ -176,8 +176,36 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
template <bool isWave64>
|
||||
static constexpr auto GetNXdlPerWave()
|
||||
{
|
||||
constexpr index_t Waves = isWave64 ? BlockSize / 64 : BlockSize / 32;
|
||||
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXDL);
|
||||
static_assert(MWaves > 0);
|
||||
|
||||
constexpr index_t NWaves = Waves / MWaves;
|
||||
if constexpr(NWaves == 0)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(NPerBlock % (NPerXDL * NWaves) == 0)
|
||||
{
|
||||
return NPerBlock / (NWaves * NPerXDL);
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<
|
||||
static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
|
||||
static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
|
||||
|
||||
template <index_t NXdlPerWave_>
|
||||
using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
@@ -199,7 +227,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
NXdlPerWave_,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
@@ -226,8 +254,10 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
ComputeTypeB,
|
||||
PermuteA,
|
||||
PermuteB>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
using Argument = typename GridwiseGemm64::Argument;
|
||||
|
||||
static constexpr index_t APackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t>)
|
||||
@@ -254,12 +284,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
///
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
/// @brief This function issues GPU kernel execution.
|
||||
/// @param arg The GPU kernel arguments.
|
||||
/// @param stream_config The HIP stream configuration helper structure.
|
||||
/// @return The kernel's average execution time (if time measurement is
|
||||
/// enabled).
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
template <typename GridwiseGemm>
|
||||
float RunImp(const typename GridwiseGemm::Argument& arg,
|
||||
const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
@@ -285,7 +312,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
Argument arg_ = arg;
|
||||
auto arg_ = arg;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
|
||||
@@ -297,7 +324,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
|
||||
sizeof(BDataType) / BPackedSize;
|
||||
|
||||
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
|
||||
ck::utility::RotatingMemWrapper<typename GridwiseGemm::Argument> rotating_mem(
|
||||
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
@@ -733,6 +760,31 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
/// @brief This function issues GPU kernel execution.
|
||||
/// @param arg The GPU kernel arguments.
|
||||
/// @param stream_config The HIP stream configuration helper structure.
|
||||
/// @return The kernel's average execution time (if time measurement is
|
||||
/// enabled).
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
{
|
||||
return RunImp<GridwiseGemm64>(arg, stream_config);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(NXdlPerWave32 > 0)
|
||||
{
|
||||
return RunImp<GridwiseGemm32>(
|
||||
reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg),
|
||||
stream_config);
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
@@ -754,9 +806,39 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
if(is_gfx11_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t>)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(sizeof(CDataType) == 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if(is_gfx11_supported() || is_gfx12_supported())
|
||||
{
|
||||
if(MPerXDL != 16 || NPerXDL != 16)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if(is_gfx11_supported())
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck::f8_t> ||
|
||||
std::is_same_v<ADataType, ck::bf8_t>)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
@@ -767,7 +849,29 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
{
|
||||
return GridwiseGemm64::CheckValidity(arg);
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(NXdlPerWave32 > 0)
|
||||
{
|
||||
return GridwiseGemm32::CheckValidity(
|
||||
reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
@@ -849,6 +953,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
index_t PrefetchStages = 0;
|
||||
index_t AMmaKStride = 0;
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
{
|
||||
PrefetchStages = GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
|
||||
AMmaKStride = GridwiseGemm64::BlockwiseGemmPipe::AMmaKStride;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(NXdlPerWave32 > 0)
|
||||
{
|
||||
PrefetchStages = GridwiseGemm32::BlockwiseGemmPipe::PrefetchStages;
|
||||
AMmaKStride = GridwiseGemm32::BlockwiseGemmPipe::AMmaKStride;
|
||||
}
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmXdlUniversal"
|
||||
<< "<"
|
||||
@@ -872,9 +995,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
|
||||
<< PrefetchStages << ", "
|
||||
<< "Kpack: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::AMmaKStride;
|
||||
<< AMmaKStride;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
|
||||
Reference in New Issue
Block a user