Extend XDL kernel to Support RDNA3/4 - Part 1 (#2606)

This commit is contained in:
linqunAMD
2025-08-23 05:46:30 +08:00
committed by GitHub
parent 0db21053e6
commit d6e49c5fde
11 changed files with 683 additions and 127 deletions

View File

@@ -35,20 +35,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if defined(__gfx9__)
enum struct Arch : bool
{
#if defined(__gfx950__)
is_gfx950_build = true,
#else
is_gfx950_build = false,
#endif
};
// skip building the instances with K1>=32 && PackedSize != 2 on pre-gfx950
if constexpr(static_cast<bool>(Arch::is_gfx950_build) ||
(GridwiseGemm::AK1Number < 32 && GridwiseGemm::BK1Number < 32) ||
(GridwiseGemm::AK1Number >= 32 && GridwiseGemm::APackedSize == 2) ||
(GridwiseGemm::BK1Number >= 32 && GridwiseGemm::BPackedSize == 2))
#if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -78,23 +66,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
{
#if defined(__gfx9__)
enum struct Arch : bool
#if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
// Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependency
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
{
#if defined(__gfx950__)
is_gfx950_build = true,
#else
is_gfx950_build = false,
#endif
};
// skip building the instances with K1>=32 && PackedSize != 2 on pre-gfx950
if constexpr(static_cast<bool>(Arch::is_gfx950_build) ||
(GridwiseGemm::AK1Number < 32 && GridwiseGemm::BK1Number < 32) ||
(GridwiseGemm::AK1Number >= 32 && GridwiseGemm::APackedSize == 2) ||
(GridwiseGemm::BK1Number >= 32 && GridwiseGemm::BPackedSize == 2))
{
// Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependency
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -696,12 +672,23 @@ struct GridwiseGemm_xdl_cshuffle_v3
__host__ void Print() const
{
std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
<< ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
<< "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
<< ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
// clang-format off
std::cout << "problem {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SC:" << StrideC << ", "
<< "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", "
<< "KRead:" << KRead << ", "
<< "KP:" << KPadded << ", "
<< "AK0:" << AK0 << ", "
<< "BK0:" << BK0 << ", "
<< "MBlock: " << MBlock << ", "
<< "NBlock: " << NBlock << "}" << std::endl;
// clang-format on
}
index_t M;
@@ -831,6 +818,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWaves = (NXdlPerWave * NPerXdl == 0) ? 0 : NPerBlock / (NXdlPerWave * NPerXdl);
constexpr index_t WaveSize = (MWaves * NWaves == 0) ? 64 : BlockSize / (MWaves * NWaves);
// A matrix in LDS memory, dst of blockwise copy
if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
@@ -888,7 +879,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
constexpr auto KThreadRead = 64 / MPerXdl;
constexpr auto KThreadRead = WaveSize / MPerXdl;
constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
@@ -969,6 +960,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
__device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWaves = (NXdlPerWave * NPerXdl == 0) ? 0 : NPerBlock / (NXdlPerWave * NPerXdl);
constexpr index_t WaveSize = (MWaves * NWaves == 0) ? 64 : BlockSize / (MWaves * NWaves);
// B matrix in LDS memory, dst of blockwise copy
if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
@@ -1022,7 +1016,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
constexpr auto KThreadRead = 64 / NPerXdl;
constexpr auto KThreadRead = WaveSize / NPerXdl;
constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
@@ -1169,12 +1163,99 @@ struct GridwiseGemm_xdl_cshuffle_v3
c_block_size * sizeof(CShuffleDataType));
}
// Compile-time filter used by the kernel entry points to decide whether this
// template instantiation should be built for the architecture currently being
// compiled. Returns false for parameter combinations the target cannot run.
template <InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__device__ static bool constexpr IsValidCompilationParameter()
{
#if defined(__gfx950__)
    constexpr bool is_gfx950_build = true;
#else
    constexpr bool is_gfx950_build = false;
#endif
    // skip building the instances with K1>=32 && PackedSize != 2 on pre-gfx950
    if constexpr(!(is_gfx950_build || (AK1Number < 32 && BK1Number < 32) ||
                   (AK1Number >= 32 && APackedSize == 2) ||
                   (BK1Number >= 32 && BPackedSize == 2)))
    {
        return false;
    }

    // Check tile size: on gfx11/gfx12 (RDNA3/4) only 16x16 XDL tiles are accepted
#if defined(__gfx11__) || defined(__gfx12__)
    if constexpr(MPerXdl != 16 || NPerXdl != 16)
    {
        return false;
    }
#endif

    // Check atomic caps of the C write-out: gfx11 only supports plain stores
    // (Set); other targets additionally allow non-Set operations on types of
    // at least 2 bytes.
#if defined(__gfx11__)
    constexpr bool SupportMemOp = CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set;
#else
    constexpr bool SupportMemOp =
        sizeof(CDataType) >= 2 || (CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set);
#endif
    if constexpr(!SupportMemOp)
    {
        return false;
    }

    // Check tile size: the wave counts implied by the tiling must be non-zero,
    // and the per-wave thread count they imply must match the hardware warp
    // size of the target being compiled (64 on CDNA, 32 on RDNA).
    if constexpr(MXdlPerWave > 0 && NXdlPerWave > 0)
    {
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
        if constexpr(MWaves > 0 && NWaves > 0)
        {
            constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
            return WaveSize == get_warp_size();
        }
        else
        {
            return false;
        }
    }
    else
    {
        return false;
    }
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__ static constexpr bool CheckValidity(const Argument& karg)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
if constexpr((MPerXdl * MXdlPerWave) == 0 || (NXdlPerWave * NPerXdl) == 0)
{
return false;
}
else
{
if constexpr((MPerBlock % (MPerXdl * MXdlPerWave) != 0) ||
(NPerBlock % (NXdlPerWave * NPerXdl) != 0))
{
return false;
}
else
{
if(BlockwiseGemmPipe::WaveSize != get_warp_size())
{
return false;
}
}
}
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||