mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Padding support for wave transfer (#3537)
* Add padding support with transpose Also move check before writing storing is_src_valid during reading * Add/modify instances to use wave transfer for gemm universal Condition is changed so now the vectorsize of vmem reading and lds writing must be equal to 8 in order to use the wave transfer * Fix clang format * Modify example * Fix bwd data * Add restriction for wave transfer with padding and transpose Add test case which shows this limitation * Fix validity checks 8 bit types * Add validity check gemm_bias_add_reduce * Add validity check grouped gemm tile loop * Fix validity checks new flavours * Minor fixes * Fix clang format
This commit is contained in:
@@ -833,6 +833,26 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector access
|
||||
static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) &&
|
||||
(BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),
|
||||
|
||||
@@ -606,6 +606,26 @@ struct DeviceBatchedGemmMultiD_Wmma_CShuffleV3
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
|
||||
@@ -588,6 +588,28 @@ struct DeviceBatchedGemmReduce_Wmma_CShuffleV3
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() &&
|
||||
!GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() &&
|
||||
!GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
typename GridwiseGemm::Argument gemm_arg{std::array<const void*, 1>{arg.p_a_grid_},
|
||||
std::array<const void*, 1>{arg.p_b_grid_},
|
||||
std::array<const void*, 0>{},
|
||||
|
||||
@@ -455,6 +455,26 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_Common
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
|
||||
@@ -471,6 +471,28 @@ struct DeviceGemmBiasAddReduce_Wmma_CShuffleV3
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() &&
|
||||
!GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() &&
|
||||
!GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
typename GridwiseGemm::Argument gemm_arg{
|
||||
std::array<const void*, 1>{arg.p_a_grid_},
|
||||
std::array<const void*, 1>{arg.p_b_grid_},
|
||||
|
||||
@@ -701,6 +701,28 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() &&
|
||||
!GridwiseGemmWelford::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() &&
|
||||
!GridwiseGemmWelford::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
typename GridwiseGemmWelford::Argument gemm_arg{
|
||||
std::array<const void*, 1>{arg.p_a_grid_},
|
||||
std::array<const void*, 1>{arg.p_b_grid_},
|
||||
|
||||
@@ -456,6 +456,28 @@ struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOpera
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() &&
|
||||
!GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() &&
|
||||
!GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
typename GridwiseGemm::Argument gemm_arg{std::array<const void*, 1>{arg.p_a_grid_},
|
||||
std::array<const void*, 1>{arg.p_b_grid_},
|
||||
std::array<const void*, 0>{},
|
||||
|
||||
@@ -421,6 +421,26 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
|
||||
}
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -393,6 +393,26 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(
|
||||
*dynamic_cast<const typename GridwiseGemm::Argument*>(&arg));
|
||||
}
|
||||
|
||||
@@ -450,8 +450,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
|
||||
BlkGemmPipelineVer,
|
||||
AComputeType,
|
||||
BComputeType,
|
||||
false,
|
||||
false>;
|
||||
false, // PermuteA
|
||||
false, // PermuteB
|
||||
false, // IsBPreShuffled
|
||||
true>; // ForceThreadTileTransfer
|
||||
|
||||
#define GridwiseGemmCTransposeTemplateParameters \
|
||||
ALayout, BLayout, DsLayout, ELayout, Tuple<ADataType>, Tuple<BDataType>, AccDataType, \
|
||||
@@ -467,7 +469,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
|
||||
ABlockLdsExtraM, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, \
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CShuffleBlockTransferScalarPerVector, BlkGemmPipeSched, BlkGemmPipelineVer, BComputeType, \
|
||||
AComputeType, false, false
|
||||
AComputeType, false, false, false, true
|
||||
|
||||
using GridwiseGemmCTranspose =
|
||||
std::conditional_t<CTranspose,
|
||||
|
||||
@@ -503,6 +503,29 @@ struct DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3
|
||||
bool supported = true;
|
||||
for(index_t i = 0; i < arg.group_count_; ++i)
|
||||
{
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(
|
||||
arg.gemm_descs_[i].M_, arg.gemm_descs_[i].K_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(
|
||||
arg.gemm_descs_[i].N_, arg.gemm_descs_[i].K_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::array<const void*, NumDTensor> placeholder_p_ds_grid{};
|
||||
std::array<index_t, NumDTensor> stride_Ds;
|
||||
std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin());
|
||||
|
||||
@@ -704,7 +704,28 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
|
||||
bool supported = true;
|
||||
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
|
||||
{
|
||||
const auto& a = arg.gemm_kernel_args_[i].karg_;
|
||||
const auto& a = arg.gemm_kernel_args_[i].karg_;
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(a.M, a.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(a.N, a.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool group_arg_valid = GridwiseGemm::CheckValidity(a);
|
||||
|
||||
if(not group_arg_valid)
|
||||
|
||||
Reference in New Issue
Block a user