Padding support for wave transfer (#3537)

* Add padding support with transpose

Also move check before writing storing is_src_valid during reading

* Add/modify instances to use wave transfer for gemm universal

Condition is changed so now the vectorsize of vmem reading and lds
writing must be equal to 8 in order to use the wave transfer

* Fix clang format

* Modify example

* Fix bwd data

* Add restriction for wave transfer with padding and transpose

Add test case which shows this limitation

* Fix validity checks 8 bit types

* Add validity check gemm_bias_add_reduce

* Add validity check grouped gemm tile loop

* Fix validity checks new flavours

* Minor fixes

* Fix clang format
This commit is contained in:
Enrico Degregori
2026-01-26 21:57:09 +01:00
committed by GitHub
parent bd5fec81af
commit 2e49b6b2f7
23 changed files with 385 additions and 50 deletions

View File

@@ -833,6 +833,26 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
// check vector access
static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) &&
(BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),

View File

@@ -606,6 +606,26 @@ struct DeviceBatchedGemmMultiD_Wmma_CShuffleV3
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
return GridwiseGemm::CheckValidity(arg);
}

View File

@@ -588,6 +588,28 @@ struct DeviceBatchedGemmReduce_Wmma_CShuffleV3
return false;
}
if(ck::is_gfx12_supported() &&
!GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() &&
!GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
typename GridwiseGemm::Argument gemm_arg{std::array<const void*, 1>{arg.p_a_grid_},
std::array<const void*, 1>{arg.p_b_grid_},
std::array<const void*, 0>{},

View File

@@ -455,6 +455,26 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_Common
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
return GridwiseGemm::CheckValidity(arg);
}

View File

@@ -471,6 +471,28 @@ struct DeviceGemmBiasAddReduce_Wmma_CShuffleV3
return false;
}
if(ck::is_gfx12_supported() &&
!GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() &&
!GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
typename GridwiseGemm::Argument gemm_arg{
std::array<const void*, 1>{arg.p_a_grid_},
std::array<const void*, 1>{arg.p_b_grid_},

View File

@@ -701,6 +701,28 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3
return false;
}
if(ck::is_gfx12_supported() &&
!GridwiseGemmWelford::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() &&
!GridwiseGemmWelford::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
typename GridwiseGemmWelford::Argument gemm_arg{
std::array<const void*, 1>{arg.p_a_grid_},
std::array<const void*, 1>{arg.p_b_grid_},

View File

@@ -456,6 +456,28 @@ struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOpera
return false;
}
if(ck::is_gfx12_supported() &&
!GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() &&
!GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
typename GridwiseGemm::Argument gemm_arg{std::array<const void*, 1>{arg.p_a_grid_},
std::array<const void*, 1>{arg.p_b_grid_},
std::array<const void*, 0>{},

View File

@@ -421,6 +421,26 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
}
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
return GridwiseGemm::CheckValidity(arg);
}
};

View File

@@ -393,6 +393,26 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
return GridwiseGemm::CheckValidity(
*dynamic_cast<const typename GridwiseGemm::Argument*>(&arg));
}

View File

@@ -450,8 +450,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
BlkGemmPipelineVer,
AComputeType,
BComputeType,
false,
false>;
false, // PermuteA
false, // PermuteB
false, // IsBPreShuffled
true>; // ForceThreadTileTransfer
#define GridwiseGemmCTransposeTemplateParameters \
ALayout, BLayout, DsLayout, ELayout, Tuple<ADataType>, Tuple<BDataType>, AccDataType, \
@@ -467,7 +469,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
ABlockLdsExtraM, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, \
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CShuffleBlockTransferScalarPerVector, BlkGemmPipeSched, BlkGemmPipelineVer, BComputeType, \
AComputeType, false, false
AComputeType, false, false, false, true
using GridwiseGemmCTranspose =
std::conditional_t<CTranspose,

View File

@@ -503,6 +503,29 @@ struct DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3
bool supported = true;
for(index_t i = 0; i < arg.group_count_; ++i)
{
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(
arg.gemm_descs_[i].M_, arg.gemm_descs_[i].K_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(
arg.gemm_descs_[i].N_, arg.gemm_descs_[i].K_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
std::array<const void*, NumDTensor> placeholder_p_ds_grid{};
std::array<index_t, NumDTensor> stride_Ds;
std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin());

View File

@@ -704,7 +704,28 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
bool supported = true;
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{
const auto& a = arg.gemm_kernel_args_[i].karg_;
const auto& a = arg.gemm_kernel_args_[i].karg_;
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(a.M, a.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(a.N, a.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
bool group_arg_valid = GridwiseGemm::CheckValidity(a);
if(not group_arg_valid)