mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Extend XDL kernel to Support RDNA3/4 - Part 1 (#2606)
This commit is contained in:
@@ -35,20 +35,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
// __attribute__((amdgpu_waves_per_eu(1, 1)))
|
||||
kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if defined(__gfx9__)
|
||||
enum struct Arch : bool
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
is_gfx950_build = true,
|
||||
#else
|
||||
is_gfx950_build = false,
|
||||
#endif
|
||||
};
|
||||
// skip building the instances with K1>=32 && PackedSize != 2 on pre-gfx950
|
||||
if constexpr(static_cast<bool>(Arch::is_gfx950_build) ||
|
||||
(GridwiseGemm::AK1Number < 32 && GridwiseGemm::BK1Number < 32) ||
|
||||
(GridwiseGemm::AK1Number >= 32 && GridwiseGemm::APackedSize == 2) ||
|
||||
(GridwiseGemm::BK1Number >= 32 && GridwiseGemm::BPackedSize == 2))
|
||||
#if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
|
||||
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
|
||||
{
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
@@ -78,23 +66,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
// __attribute__((amdgpu_waves_per_eu(1, 1)))
|
||||
kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if defined(__gfx9__)
|
||||
enum struct Arch : bool
|
||||
#if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
|
||||
// Passing two LDS pointers is the key to telling the compiler that ds_read/write
|
||||
// operate on different LDS chunks at the same time without an order dependency
|
||||
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
is_gfx950_build = true,
|
||||
#else
|
||||
is_gfx950_build = false,
|
||||
#endif
|
||||
};
|
||||
// skip building the instances with K1>=32 && PackedSize != 2 on pre-gfx950
|
||||
if constexpr(static_cast<bool>(Arch::is_gfx950_build) ||
|
||||
(GridwiseGemm::AK1Number < 32 && GridwiseGemm::BK1Number < 32) ||
|
||||
(GridwiseGemm::AK1Number >= 32 && GridwiseGemm::APackedSize == 2) ||
|
||||
(GridwiseGemm::BK1Number >= 32 && GridwiseGemm::BPackedSize == 2))
|
||||
{
|
||||
// Passing two LDS pointers is the key to telling the compiler that ds_read/write
|
||||
// operate on different LDS chunks at the same time without an order dependency
|
||||
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
@@ -696,12 +672,23 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
__host__ void Print() const
|
||||
{
|
||||
std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
|
||||
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
|
||||
<< ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
|
||||
<< "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
|
||||
<< ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
|
||||
// clang-format off
|
||||
std::cout << "problem {"
|
||||
<< "M:" << M << ", "
|
||||
<< "N:" << N << ", "
|
||||
<< "K:" << K << ", "
|
||||
<< "SA:" << StrideA << ", "
|
||||
<< "SB:" << StrideB << ", "
|
||||
<< "SC:" << StrideC << ", "
|
||||
<< "MP:" << MPadded << ", "
|
||||
<< "NP:" << NPadded << ", "
|
||||
<< "KRead:" << KRead << ", "
|
||||
<< "KP:" << KPadded << ", "
|
||||
<< "AK0:" << AK0 << ", "
|
||||
<< "BK0:" << BK0 << ", "
|
||||
<< "MBlock: " << MBlock << ", "
|
||||
<< "NBlock: " << NBlock << "}" << std::endl;
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
index_t M;
|
||||
@@ -831,6 +818,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
{
|
||||
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
constexpr index_t NWaves = (NXdlPerWave * NPerXdl == 0) ? 0 : NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
constexpr index_t WaveSize = (MWaves * NWaves == 0) ? 64 : BlockSize / (MWaves * NWaves);
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
@@ -888,7 +879,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
|
||||
constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
|
||||
constexpr auto KThreadRead = 64 / MPerXdl;
|
||||
constexpr auto KThreadRead = WaveSize / MPerXdl;
|
||||
constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
|
||||
|
||||
constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
|
||||
@@ -969,6 +960,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
__device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
|
||||
{
|
||||
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
constexpr index_t NWaves = (NXdlPerWave * NPerXdl == 0) ? 0 : NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
constexpr index_t WaveSize = (MWaves * NWaves == 0) ? 64 : BlockSize / (MWaves * NWaves);
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
@@ -1022,7 +1016,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
|
||||
constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
|
||||
constexpr auto KThreadRead = 64 / NPerXdl;
|
||||
constexpr auto KThreadRead = WaveSize / NPerXdl;
|
||||
constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
|
||||
|
||||
constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
|
||||
@@ -1169,12 +1163,99 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
c_block_size * sizeof(CShuffleDataType));
|
||||
}
|
||||
|
||||
template <InMemoryDataOperationEnum CGlobalMemoryDataOperation>
|
||||
__device__ static bool constexpr IsValidCompilationParameter()
|
||||
{
|
||||
enum struct Arch : bool
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
is_gfx950_build = true,
|
||||
#else
|
||||
is_gfx950_build = false,
|
||||
#endif
|
||||
};
|
||||
|
||||
// skip building the instances with K1>=32 && PackedSize != 2 on pre-gfx950
|
||||
if constexpr(static_cast<bool>(Arch::is_gfx950_build) ||
|
||||
(AK1Number < 32 && BK1Number < 32) ||
|
||||
(AK1Number >= 32 && APackedSize == 2) ||
|
||||
(BK1Number >= 32 && BPackedSize == 2))
|
||||
{
|
||||
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check tile size
|
||||
#if defined(__gfx11__) || defined(__gfx12__)
|
||||
if constexpr(MPerXdl != 16 || NPerXdl != 16)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
// Check atomic caps
|
||||
#if defined(__gfx11__)
|
||||
constexpr bool SupportMemOp = CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set;
|
||||
#else
|
||||
constexpr bool SupportMemOp = sizeof(CDataType) >= 2 || (CGlobalMemoryDataOperation ==
|
||||
InMemoryDataOperationEnum::Set);
|
||||
#endif
|
||||
if constexpr(SupportMemOp == false)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check tile size
|
||||
if constexpr(MXdlPerWave > 0 && NXdlPerWave > 0)
|
||||
{
|
||||
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
if constexpr(MWaves > 0 && NWaves > 0)
|
||||
{
|
||||
constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
|
||||
if constexpr(WaveSize == get_warp_size())
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
|
||||
__host__ static constexpr bool CheckValidity(const Argument& karg)
|
||||
{
|
||||
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
|
||||
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
|
||||
"Invalid tuning param!");
|
||||
if constexpr((MPerXdl * MXdlPerWave) == 0 || (NXdlPerWave * NPerXdl) == 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr((MPerBlock % (MPerXdl * MXdlPerWave) != 0) ||
|
||||
(NPerBlock % (NXdlPerWave * NPerXdl) != 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(BlockwiseGemmPipe::WaveSize != get_warp_size())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
|
||||
|
||||
Reference in New Issue
Block a user