Padding support for wave transfer (#3537)

* Add padding support with transpose

Also move the check before writing, storing is_src_valid during reading

* Add/modify instances to use wave transfer for gemm universal

The condition is changed so that the vector size of VMEM reads and LDS
writes must be equal to 8 in order to use the wave transfer

* Fix clang format

* Modify example

* Fix bwd data

* Add restriction for wave transfer with padding and transpose

Add test case which shows this limitation

* Fix validity checks 8 bit types

* Add validity check gemm_bias_add_reduce

* Add validity check grouped gemm tile loop

* Fix validity checks new flavours

* Minor fixes

* Fix clang format
This commit is contained in:
Enrico Degregori
2026-01-26 21:57:09 +01:00
committed by GitHub
parent bd5fec81af
commit 2e49b6b2f7
23 changed files with 385 additions and 50 deletions

View File

@@ -160,6 +160,7 @@ struct ThreadGroupTransferGlobal
// check if src element is valid
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
oob_thread_scratch_.template SetAsType<bool>(vgpr_data_idx_seq, is_src_valid);
// Vector length of elementwise operation
constexpr auto get_elem_op_vec_len = []() {
@@ -195,14 +196,12 @@ struct ThreadGroupTransferGlobal
using dst_vector_type = vector_type_maker_t<DstData, VectorSize>;
using dst_vector_t = typename dst_vector_type::type;
using vector_t = typename vector_type_maker<DstData, VectorSize>::type::type;
dst_vector_type op_r_v;
// Load data from memory in src_vector first
src_vector_container src_vector =
src_vector_container{grid_buf.template Get<src_vector_container_t, DoTranspose>(
src_coord_.GetOffset(), true)};
auto index = is_src_valid || !DoTranspose ? src_coord_.GetOffset() : 0;
src_vector_container src_vector = src_vector_container{
grid_buf.template Get<src_vector_container_t, DoTranspose>(index, true)};
// apply the src elementwise op and convert to DstData under the hood if needed
static_for<0, VectorSize / elem_op_vec_len, 1>{}([&](auto idx) {
@@ -213,9 +212,8 @@ struct ThreadGroupTransferGlobal
// store result in dvgpr_ (static array holding loaded data).
// At this point data is already converted to DstData type and
// the elementwise operation has been applied
dvgpr_.template SetAsType<dst_vector_t>(
vgpr_data_idx_seq,
is_src_valid ? op_r_v.template AsType<dst_vector_t>()[I0] : vector_t(0));
src_dvgpr_.template SetAsType<dst_vector_t>(vgpr_data_idx_seq,
op_r_v.template AsType<dst_vector_t>()[I0]);
// For each dimension move fwd, bwd or don't move
static_for<0, nDim, 1>{}([&](auto i) {
@@ -248,6 +246,39 @@ struct ThreadGroupTransferGlobal
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
constexpr auto ordered_fwd_step = StepsPerIteration{};
// OOB check
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// calculate src data index and make sequence
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}(
[&](auto i) { ordered_idx(i) = ordered_src_access_idx[i]; });
return container_reorder_given_old2new(ordered_idx, src_dim_access_order);
}();
// make sequence to access vgpr data. Add zero as last element of src_data_idx_seq
constexpr auto vgpr_data_idx_seq = generate_sequence_v2(
[&](auto i) {
if constexpr(i.value < src_data_idx.Size())
{
return Number<src_data_idx[i]>{};
}
else
{
return Number<0>{};
}
},
Number<src_data_idx.Size() + 1>{});
auto op_r = src_dvgpr_.template GetAsType<dst_vector_t>(vgpr_data_idx_seq);
const bool is_src_valid =
oob_thread_scratch_.template GetAsType<bool>(vgpr_data_idx_seq);
auto op_r_v = is_src_valid ? op_r : dst_vector_t(0);
dst_dvgpr_.template SetAsType<dst_vector_t>(vgpr_data_idx_seq, op_r_v);
});
// make forward steps
// forward step for each iteration just add 1
const auto dst_forward_steps = generate_tuple(
@@ -352,7 +383,7 @@ struct ThreadGroupTransferGlobal
dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset(),
true,
dvgpr_.template GetAsType<dst_vector_t>(vgpr_data_idx_seq));
dst_dvgpr_.template GetAsType<dst_vector_t>(vgpr_data_idx_seq));
// For each dimension move fwd, bwd or don't move
static_for<0, nDim, 1>{}([&](auto i) {
@@ -389,6 +420,14 @@ struct ThreadGroupTransferGlobal
return make_naive_tensor_descriptor_packed(access_lengths_as_tuple);
}
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
{
    // Packed descriptor over the per-thread iteration space, with a trailing
    // unit dimension appended to the lengths tuple.
    return make_naive_tensor_descriptor_packed(
        container_push_back(sequence_to_tuple_of_number(NumberOfIterations{}), Number<1>{}));
}
static constexpr auto thread_data_scratch_desc_ = decltype(GetThreadScratchDataDescriptor()){};
using ThreadScratchData = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
@@ -396,7 +435,17 @@ struct ThreadGroupTransferGlobal
decltype(thread_data_scratch_desc_),
true>;
ThreadScratchData dvgpr_;
static constexpr auto src_oob_thread_scratch_desc_ =
decltype(GetSrcThreadScratchDescriptor()){};
using OOBThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
bool,
1,
decltype(src_oob_thread_scratch_desc_),
true>;
ThreadScratchData src_dvgpr_;
ThreadScratchData dst_dvgpr_;
OOBThreadScratch oob_thread_scratch_;
SrcCoord src_coord_;
DstCoord dst_coord_;
const ElementwiseOperation element_op_;

View File

@@ -833,6 +833,26 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
// check vector access
static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) &&
(BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),

View File

@@ -606,6 +606,26 @@ struct DeviceBatchedGemmMultiD_Wmma_CShuffleV3
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
return GridwiseGemm::CheckValidity(arg);
}

View File

@@ -588,6 +588,28 @@ struct DeviceBatchedGemmReduce_Wmma_CShuffleV3
return false;
}
if(ck::is_gfx12_supported() &&
!GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() &&
!GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
typename GridwiseGemm::Argument gemm_arg{std::array<const void*, 1>{arg.p_a_grid_},
std::array<const void*, 1>{arg.p_b_grid_},
std::array<const void*, 0>{},

View File

@@ -455,6 +455,26 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_Common
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
return GridwiseGemm::CheckValidity(arg);
}

View File

@@ -471,6 +471,28 @@ struct DeviceGemmBiasAddReduce_Wmma_CShuffleV3
return false;
}
if(ck::is_gfx12_supported() &&
!GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() &&
!GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
typename GridwiseGemm::Argument gemm_arg{
std::array<const void*, 1>{arg.p_a_grid_},
std::array<const void*, 1>{arg.p_b_grid_},

View File

@@ -701,6 +701,28 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3
return false;
}
if(ck::is_gfx12_supported() &&
!GridwiseGemmWelford::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() &&
!GridwiseGemmWelford::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
typename GridwiseGemmWelford::Argument gemm_arg{
std::array<const void*, 1>{arg.p_a_grid_},
std::array<const void*, 1>{arg.p_b_grid_},

View File

@@ -456,6 +456,28 @@ struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOpera
return false;
}
if(ck::is_gfx12_supported() &&
!GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() &&
!GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
typename GridwiseGemm::Argument gemm_arg{std::array<const void*, 1>{arg.p_a_grid_},
std::array<const void*, 1>{arg.p_b_grid_},
std::array<const void*, 0>{},

View File

@@ -421,6 +421,26 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
}
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
return GridwiseGemm::CheckValidity(arg);
}
};

View File

@@ -393,6 +393,26 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
return GridwiseGemm::CheckValidity(
*dynamic_cast<const typename GridwiseGemm::Argument*>(&arg));
}

View File

@@ -450,8 +450,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
BlkGemmPipelineVer,
AComputeType,
BComputeType,
false,
false>;
false, // PermuteA
false, // PermuteB
false, // IsBPreShuffled
true>; // ForceThreadTileTransfer
#define GridwiseGemmCTransposeTemplateParameters \
ALayout, BLayout, DsLayout, ELayout, Tuple<ADataType>, Tuple<BDataType>, AccDataType, \
@@ -467,7 +469,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
ABlockLdsExtraM, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, \
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CShuffleBlockTransferScalarPerVector, BlkGemmPipeSched, BlkGemmPipelineVer, BComputeType, \
AComputeType, false, false
AComputeType, false, false, false, true
using GridwiseGemmCTranspose =
std::conditional_t<CTranspose,

View File

@@ -503,6 +503,29 @@ struct DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3
bool supported = true;
for(index_t i = 0; i < arg.group_count_; ++i)
{
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(
arg.gemm_descs_[i].M_, arg.gemm_descs_[i].K_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(
arg.gemm_descs_[i].N_, arg.gemm_descs_[i].K_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
std::array<const void*, NumDTensor> placeholder_p_ds_grid{};
std::array<index_t, NumDTensor> stride_Ds;
std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin());

View File

@@ -704,7 +704,28 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
bool supported = true;
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{
const auto& a = arg.gemm_kernel_args_[i].karg_;
const auto& a = arg.gemm_kernel_args_[i].karg_;
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(a.M, a.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(a.N, a.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
bool group_arg_valid = GridwiseGemm::CheckValidity(a);
if(not group_arg_valid)

View File

@@ -132,10 +132,6 @@ struct ABTransferWaveTiles
index_t,
index_t)
{
// Notes: padding is currently not supported with transpose
static_assert(!((PadMN || PadK) && ABDoTranspose),
"padding is currently not supported with transpose");
const index_t MN_grid = !PadMN ? sizeMN : MNPad;
const index_t K_grid = !PadK ? sizeK : KPad;

View File

@@ -362,23 +362,27 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::selected_wmma
.wave_size;
__host__ __device__ static constexpr bool AWaveTransferApplicable()
{
    // Wave transfer for matrix A requires: a single non-packed A tensor,
    // block-gemm pipeline v1, no forced thread-tile transfer, no pre-shuffled
    // B, and 8-wide vector accesses on both the global-read and LDS-write
    // sides (AK1 == 8).
    constexpr bool single_plain_tensor = NumATensor == 1 && APackedSize == 1;
    constexpr bool vec8_access         = ABlockTransferSrcScalarPerVector == 8 &&
                                 ABlockTransferDstScalarPerVector_AK1 == 8 && AK1Value == 8;
    constexpr bool pipeline_ok = BlkGemmPipelineVer == BlockGemmPipelineVersion::v1;
    return single_plain_tensor && vec8_access && pipeline_ok && !ForceThreadTileTransfer &&
           !IsBPreShuffled;
}
__host__ __device__ static constexpr bool BWaveTransferApplicable()
{
    // Wave transfer for matrix B requires: a single non-packed B tensor,
    // block-gemm pipeline v1, no forced thread-tile transfer, and 8-wide
    // vector accesses on both the global-read and LDS-write sides (BK1 == 8).
    constexpr bool single_plain_tensor = NumBTensor == 1 && BPackedSize == 1;
    constexpr bool vec8_access         = BBlockTransferSrcScalarPerVector == 8 &&
                                 BBlockTransferDstScalarPerVector_BK1 == 8 && BK1Value == 8;
    constexpr bool pipeline_ok = BlkGemmPipelineVer == BlockGemmPipelineVersion::v1;
    return single_plain_tensor && vec8_access && pipeline_ok && !ForceThreadTileTransfer;
}
// Limitations of the current implementation:
// - no multiAB
// - GemmSpecialization Default with transpose
#ifdef __gfx12__
static constexpr bool IsAWaveTransferApplicable =
!ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 &&
((GemmSpec == tensor_operation::device::GemmSpecialization::Default &&
!is_same_v<ALayout, tensor_layout::gemm::RowMajor>) ||
is_same_v<ALayout, tensor_layout::gemm::RowMajor>) &&
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 && !IsBPreShuffled;
static constexpr bool IsAWaveTransferApplicable = AWaveTransferApplicable();
static constexpr bool IsBWaveTransferApplicable =
!ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 &&
((GemmSpec == tensor_operation::device::GemmSpecialization::Default &&
!is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>) ||
is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>) &&
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8;
static constexpr bool IsBWaveTransferApplicable = BWaveTransferApplicable();
static constexpr bool IsWaveTileInterleavedFitting =
(NPerBlock / NPerWmma / NRepeat) * (KPerBlock / KPack) >= (BlockSize / WaveSize);
@@ -982,6 +986,55 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
return de_grid_desc_mblock_mperblock_nblock_nperblock;
}
// Conditions for Wave Transfer with transpose:
// - 16 bit type: K % 8 == 0 (4 subtiles of 8x8)
// - 8 bit type: K % 8 == 0 and M % 16 == 0 (2 subtiles of 8x16)
__host__ static constexpr bool CheckValidityAWaveTransfer(const index_t& M, const index_t& K)
{
    // Runtime size restrictions when wave transfer is used for A and A is not
    // row-major (the transposing path): K must be a multiple of the LDS write
    // vector width, and for 8-bit element types M must additionally be a
    // multiple of 2 * the source vector width. On any other configuration
    // there is no restriction and the check passes trivially.
    if constexpr(AWaveTransferApplicable() &&
                 !(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
    {
        if(K % ABlockTransferDstScalarPerVector_AK1 != 0)
        {
            return false;
        }
        bool valid = true;
        static_for<0, NumATensor, 1>{}([&](auto i) {
            using AElem = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
            // 8-bit types use 8x16 subtiles, so M needs 2x the vector width.
            if(sizeof(AElem) == 1 && M % (2 * ABlockTransferSrcScalarPerVector) != 0)
            {
                valid = false;
            }
        });
        return valid;
    }
    else
    {
        return true;
    }
}
__host__ static constexpr bool CheckValidityBWaveTransfer(const index_t& N, const index_t& K)
{
    // Runtime size restrictions when wave transfer is used for B and B is not
    // column-major (the transposing path): K must be a multiple of the LDS
    // write vector width, and for 8-bit element types N must additionally be
    // a multiple of 2 * the source vector width. On any other configuration
    // there is no restriction and the check passes trivially.
    if constexpr(BWaveTransferApplicable() &&
                 !(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value))
    {
        if(K % BBlockTransferDstScalarPerVector_BK1 != 0)
        {
            return false;
        }
        bool valid = true;
        static_for<0, NumBTensor, 1>{}([&](auto i) {
            using BElem = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
            // 8-bit types use 8x16 subtiles, so N needs 2x the vector width.
            if(sizeof(BElem) == 1 && N % (2 * BBlockTransferSrcScalarPerVector) != 0)
            {
                valid = false;
            }
        });
        return valid;
    }
    else
    {
        return true;
    }
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Argument>
__host__ static constexpr bool CheckValidity(const Argument& karg,