mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Padding support for wave transfer (#3537)
* Add padding support with transpose. Also move the check before writing, storing is_src_valid during reading * Add/modify instances to use wave transfer for gemm universal. The condition is changed so now the vector size of vmem reading and lds writing must be equal to 8 in order to use the wave transfer * Fix clang format * Modify example * Fix bwd data * Add restriction for wave transfer with padding and transpose. Add test case which shows this limitation * Fix validity checks for 8-bit types * Add validity check gemm_bias_add_reduce * Add validity check grouped gemm tile loop * Fix validity checks for new flavours * Minor fixes * Fix clang format
This commit is contained in:
@@ -132,10 +132,6 @@ struct ABTransferWaveTiles
|
||||
index_t,
|
||||
index_t)
|
||||
{
|
||||
// Notes: padding is currently not supported with transpose
|
||||
static_assert(!((PadMN || PadK) && ABDoTranspose),
|
||||
"padding is currently not supported with transpose");
|
||||
|
||||
const index_t MN_grid = !PadMN ? sizeMN : MNPad;
|
||||
const index_t K_grid = !PadK ? sizeK : KPad;
|
||||
|
||||
|
||||
@@ -362,23 +362,27 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::selected_wmma
|
||||
.wave_size;
|
||||
|
||||
// Whether the A-side can use the wave transfer path at all.
// Requires the v1 block-gemm pipeline, 8-wide vectors on both the global
// read and the LDS write side, a single (unpacked) A tensor, and no
// pre-shuffled B.
__host__ __device__ static constexpr bool AWaveTransferApplicable()
{
    // 8-wide vector access is mandatory on every stage of the transfer.
    const bool vectors_ok = ABlockTransferSrcScalarPerVector == 8 &&
                            ABlockTransferDstScalarPerVector_AK1 == 8 && AK1Value == 8;
    // Wave transfer is implemented only for a single, unpacked A tensor.
    const bool tensor_ok = NumATensor == 1 && APackedSize == 1;
    return !ForceThreadTileTransfer && tensor_ok && vectors_ok &&
           BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && !IsBPreShuffled;
}
|
||||
|
||||
// Whether the B-side can use the wave transfer path at all.
// Mirrors AWaveTransferApplicable(): v1 pipeline and 8-wide vectors on
// both the global read and the LDS write side, single unpacked B tensor.
__host__ __device__ static constexpr bool BWaveTransferApplicable()
{
    // 8-wide vector access is mandatory on every stage of the transfer.
    const bool vectors_ok = BBlockTransferSrcScalarPerVector == 8 &&
                            BBlockTransferDstScalarPerVector_BK1 == 8 && BK1Value == 8;
    // Wave transfer is implemented only for a single, unpacked B tensor.
    const bool tensor_ok = NumBTensor == 1 && BPackedSize == 1;
    return !ForceThreadTileTransfer && tensor_ok && vectors_ok &&
           BlkGemmPipelineVer == BlockGemmPipelineVersion::v1;
}
|
||||
|
||||
// Limitations of the current implementation:
|
||||
// - no multiAB
|
||||
// - GemmSpecialization Default with transpose
|
||||
#ifdef __gfx12__
|
||||
static constexpr bool IsAWaveTransferApplicable =
|
||||
!ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 &&
|
||||
((GemmSpec == tensor_operation::device::GemmSpecialization::Default &&
|
||||
!is_same_v<ALayout, tensor_layout::gemm::RowMajor>) ||
|
||||
is_same_v<ALayout, tensor_layout::gemm::RowMajor>) &&
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 && !IsBPreShuffled;
|
||||
static constexpr bool IsAWaveTransferApplicable = AWaveTransferApplicable();
|
||||
|
||||
static constexpr bool IsBWaveTransferApplicable =
|
||||
!ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 &&
|
||||
((GemmSpec == tensor_operation::device::GemmSpecialization::Default &&
|
||||
!is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>) ||
|
||||
is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>) &&
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8;
|
||||
static constexpr bool IsBWaveTransferApplicable = BWaveTransferApplicable();
|
||||
|
||||
static constexpr bool IsWaveTileInterleavedFitting =
|
||||
(NPerBlock / NPerWmma / NRepeat) * (KPerBlock / KPack) >= (BlockSize / WaveSize);
|
||||
@@ -982,6 +986,55 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
return de_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
}
|
||||
|
||||
// Conditions for Wave Transfer with transpose:
|
||||
// - 16 bit type: K % 8 == 0 (4 subtiles of 8x8)
|
||||
// - 8 bit type: K % 8 == 0 and M % 16 == 0 (2 subtiles of 8x16)
|
||||
// Runtime shape check for the A-side wave transfer when the transposing
// path is taken (A not row-major). K must be divisible by the LDS write
// vector width; 8-bit element types additionally require M to be a
// multiple of 2 * SrcScalarPerVector (the 8x16 subtile shape).
__host__ static constexpr bool CheckValidityAWaveTransfer(const index_t& M, const index_t& K)
{
    if constexpr(AWaveTransferApplicable() &&
                 !(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
    {
        // K dimension must tile evenly into the write vectors.
        if(K % ABlockTransferDstScalarPerVector_AK1 != 0)
        {
            return false;
        }
        // 8-bit types transpose via 8x16 subtiles; every such A tensor
        // needs M divisible by twice the read vector width.
        bool valid = true;
        static_for<0, NumATensor, 1>{}([&](auto i) {
            using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
            valid &= (sizeof(ADataType_) != 1 ||
                      M % (2 * ABlockTransferSrcScalarPerVector) == 0);
        });
        return valid;
    }
    else
    {
        // Non-transposing (or non-wave-transfer) path has no extra
        // shape restrictions here.
        return true;
    }
}
|
||||
|
||||
// Runtime shape check for the B-side wave transfer when the transposing
// path is taken (B not column-major). K must be divisible by the LDS
// write vector width; 8-bit element types additionally require N to be a
// multiple of 2 * SrcScalarPerVector (the 8x16 subtile shape).
__host__ static constexpr bool CheckValidityBWaveTransfer(const index_t& N, const index_t& K)
{
    if constexpr(BWaveTransferApplicable() &&
                 !(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value))
    {
        // K dimension must tile evenly into the write vectors.
        if(K % BBlockTransferDstScalarPerVector_BK1 != 0)
        {
            return false;
        }
        // 8-bit types transpose via 8x16 subtiles; every such B tensor
        // needs N divisible by twice the read vector width.
        bool valid = true;
        static_for<0, NumBTensor, 1>{}([&](auto i) {
            using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
            valid &= (sizeof(BDataType_) != 1 ||
                      N % (2 * BBlockTransferSrcScalarPerVector) == 0);
        });
        return valid;
    }
    else
    {
        // Non-transposing (or non-wave-transfer) path has no extra
        // shape restrictions here.
        return true;
    }
}
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
template <typename Argument>
|
||||
__host__ static constexpr bool CheckValidity(const Argument& karg,
|
||||
|
||||
Reference in New Issue
Block a user