Padding support for wave transfer (#3537)

* Add padding support with transpose

Also move the check before the write by storing is_src_valid during reading

* Add/modify instances to use wave transfer for gemm universal

The condition is changed so that the vector size of VMEM reads and LDS
writes must be equal to 8 in order to use the wave transfer

* Fix clang format

* Modify example

* Fix bwd data

* Add restriction for wave transfer with padding and transpose

Add test case which shows this limitation

* Fix validity checks 8 bit types

* Add validity check gemm_bias_add_reduce

* Add validity check grouped gemm tile loop

* Fix validity checks new flavours

* Minor fixes

* Fix clang format
This commit is contained in:
Enrico Degregori
2026-01-26 21:57:09 +01:00
committed by GitHub
parent bd5fec81af
commit 2e49b6b2f7
23 changed files with 385 additions and 50 deletions

View File

@@ -132,10 +132,6 @@ struct ABTransferWaveTiles
index_t,
index_t)
{
// Notes: padding is currently not supported with transpose
static_assert(!((PadMN || PadK) && ABDoTranspose),
"padding is currently not supported with transpose");
const index_t MN_grid = !PadMN ? sizeMN : MNPad;
const index_t K_grid = !PadK ? sizeK : KPad;

View File

@@ -362,23 +362,27 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::selected_wmma
.wave_size;
__host__ __device__ static constexpr bool AWaveTransferApplicable()
{
    // Wave transfer for A requires: the single-tensor, unpacked layout,
    // pipeline v1, 8-wide vector access on both the source read and the
    // LDS write, and B not pre-shuffled.
    constexpr bool single_unpacked = NumATensor == 1 && APackedSize == 1;
    constexpr bool vec_width_is_8  = ABlockTransferSrcScalarPerVector == 8 &&
                                    ABlockTransferDstScalarPerVector_AK1 == 8 && AK1Value == 8;
    constexpr bool pipeline_v1 = BlkGemmPipelineVer == BlockGemmPipelineVersion::v1;
    return !ForceThreadTileTransfer && single_unpacked && vec_width_is_8 && pipeline_v1 &&
           !IsBPreShuffled;
}
__host__ __device__ static constexpr bool BWaveTransferApplicable()
{
    // Wave transfer for B requires: the single-tensor, unpacked layout,
    // pipeline v1, and 8-wide vector access on both the source read and
    // the LDS write.
    constexpr bool single_unpacked = NumBTensor == 1 && BPackedSize == 1;
    constexpr bool vec_width_is_8  = BBlockTransferSrcScalarPerVector == 8 &&
                                    BBlockTransferDstScalarPerVector_BK1 == 8 && BK1Value == 8;
    constexpr bool pipeline_v1 = BlkGemmPipelineVer == BlockGemmPipelineVersion::v1;
    return !ForceThreadTileTransfer && single_unpacked && vec_width_is_8 && pipeline_v1;
}
// Limitations of the current implementation:
// - no multiAB
// - GemmSpecialization Default with transpose
#ifdef __gfx12__
// NOTE(review): the pasted diff lost its -/+ markers, leaving both the old
// inline applicability expressions and the new helper-based definitions in
// place (two definitions of each constant — a redefinition error). Only the
// helper-based definitions are kept; they subsume the old layout/GemmSpec
// conditions via AWaveTransferApplicable()/BWaveTransferApplicable().
static constexpr bool IsAWaveTransferApplicable = AWaveTransferApplicable();
static constexpr bool IsBWaveTransferApplicable = BWaveTransferApplicable();
// Interleaved wave tiles fit only when the block provides at least one
// (N-tile x K-pack) slot per wave.
static constexpr bool IsWaveTileInterleavedFitting =
    (NPerBlock / NPerWmma / NRepeat) * (KPerBlock / KPack) >= (BlockSize / WaveSize);
@@ -982,6 +986,55 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
return de_grid_desc_mblock_mperblock_nblock_nperblock;
}
// Conditions for Wave Transfer with transpose:
// - 16 bit type: K % 8 == 0 (4 subtiles of 8x8)
// - 8 bit type: K % 8 == 0 and M % 16 == 0 (2 subtiles of 8x16)
//
// Runtime validity of the A-side wave transfer for the given problem size.
// Only the transpose path (A not row-major) is constrained; when wave
// transfer is not applicable for A, or A is row-major, there is nothing to
// check and the function returns true.
// index_t is trivially copyable, so it is passed by value.
__host__ static constexpr bool CheckValidityAWaveTransfer(const index_t M, const index_t K)
{
    if constexpr(AWaveTransferApplicable() &&
                 !(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
    {
        // K must cover whole subtiles of the LDS store vector width.
        if(K % ABlockTransferDstScalarPerVector_AK1 != 0)
        {
            return false;
        }
        // 8-bit element types pair two source vectors along M (subtiles of
        // 8x16), so M must be a multiple of 2 * the source vector width.
        bool pass = true;
        static_for<0, NumATensor, 1>{}([&](auto i) {
            using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
            if constexpr(sizeof(ADataType_) == 1)
            {
                pass &= (M % (2 * ABlockTransferSrcScalarPerVector) == 0);
            }
        });
        return pass;
    }
    else
    {
        return true;
    }
}
// Runtime validity of the B-side wave transfer for the given problem size.
// Only the transpose path (B not column-major) is constrained:
// - 16 bit type: K % 8 == 0 (4 subtiles of 8x8)
// - 8 bit type:  K % 8 == 0 and N % 16 == 0 (2 subtiles of 8x16)
// When wave transfer is not applicable for B, or B is column-major, there
// is nothing to check and the function returns true.
// index_t is trivially copyable, so it is passed by value.
__host__ static constexpr bool CheckValidityBWaveTransfer(const index_t N, const index_t K)
{
    if constexpr(BWaveTransferApplicable() &&
                 !(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value))
    {
        // K must cover whole subtiles of the LDS store vector width.
        if(K % BBlockTransferDstScalarPerVector_BK1 != 0)
        {
            return false;
        }
        // 8-bit element types pair two source vectors along N (subtiles of
        // 8x16), so N must be a multiple of 2 * the source vector width.
        bool pass = true;
        static_for<0, NumBTensor, 1>{}([&](auto i) {
            using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
            if constexpr(sizeof(BDataType_) == 1)
            {
                pass &= (N % (2 * BBlockTransferSrcScalarPerVector) == 0);
            }
        });
        return pass;
    }
    else
    {
        return true;
    }
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Argument>
__host__ static constexpr bool CheckValidity(const Argument& karg,