mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Padding support for wave transfer (#3537)
* Add padding support with transpose Also move check before writing storing is_src_valid during reading * Add/modify instances to use wave transfer for gemm universal Condition is changed so now the vectorsize of vmem reading and lds writing must be equal to 8 in order to use the wave transfer * Fix clang format * Modify example * Fix bwd data * Add restriction for wave transfer with padding and transpose Add test case which shows this limitation * Fix validity checks 8 bit types * Add validity check gemm_bias_add_reduce * Add validity check grouped gemm tile loop * Fix validity checks new flavours * Minor fixes * Fix clang format
This commit is contained in:
@@ -160,6 +160,7 @@ struct ThreadGroupTransferGlobal
|
||||
// check if src element is valid
|
||||
const bool is_src_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
|
||||
oob_thread_scratch_.template SetAsType<bool>(vgpr_data_idx_seq, is_src_valid);
|
||||
|
||||
// Vector length of elementwise operation
|
||||
constexpr auto get_elem_op_vec_len = []() {
|
||||
@@ -195,14 +196,12 @@ struct ThreadGroupTransferGlobal
|
||||
using dst_vector_type = vector_type_maker_t<DstData, VectorSize>;
|
||||
using dst_vector_t = typename dst_vector_type::type;
|
||||
|
||||
using vector_t = typename vector_type_maker<DstData, VectorSize>::type::type;
|
||||
|
||||
dst_vector_type op_r_v;
|
||||
|
||||
// Load data from memory in src_vector first
|
||||
src_vector_container src_vector =
|
||||
src_vector_container{grid_buf.template Get<src_vector_container_t, DoTranspose>(
|
||||
src_coord_.GetOffset(), true)};
|
||||
auto index = is_src_valid || !DoTranspose ? src_coord_.GetOffset() : 0;
|
||||
src_vector_container src_vector = src_vector_container{
|
||||
grid_buf.template Get<src_vector_container_t, DoTranspose>(index, true)};
|
||||
|
||||
// apply the src elementwise op and convert to DstData under the hood if needed
|
||||
static_for<0, VectorSize / elem_op_vec_len, 1>{}([&](auto idx) {
|
||||
@@ -213,9 +212,8 @@ struct ThreadGroupTransferGlobal
|
||||
// store result in dvgpr_ (static array holding loaded data).
|
||||
// At this point data is already converted to DstData type and
|
||||
// the elementwise operation has been applied
|
||||
dvgpr_.template SetAsType<dst_vector_t>(
|
||||
vgpr_data_idx_seq,
|
||||
is_src_valid ? op_r_v.template AsType<dst_vector_t>()[I0] : vector_t(0));
|
||||
src_dvgpr_.template SetAsType<dst_vector_t>(vgpr_data_idx_seq,
|
||||
op_r_v.template AsType<dst_vector_t>()[I0]);
|
||||
|
||||
// For each dimension move fwd, bwd or don't move
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
@@ -248,6 +246,39 @@ struct ThreadGroupTransferGlobal
|
||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||
constexpr auto ordered_fwd_step = StepsPerIteration{};
|
||||
|
||||
// OOB check
|
||||
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
|
||||
// calculate src data index and make sequence
|
||||
constexpr auto src_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}(
|
||||
[&](auto i) { ordered_idx(i) = ordered_src_access_idx[i]; });
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, src_dim_access_order);
|
||||
}();
|
||||
|
||||
// make sequence to access vgpr data. Add zero as last element of src_data_idx_seq
|
||||
constexpr auto vgpr_data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) {
|
||||
if constexpr(i.value < src_data_idx.Size())
|
||||
{
|
||||
return Number<src_data_idx[i]>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Number<0>{};
|
||||
}
|
||||
},
|
||||
Number<src_data_idx.Size() + 1>{});
|
||||
|
||||
auto op_r = src_dvgpr_.template GetAsType<dst_vector_t>(vgpr_data_idx_seq);
|
||||
const bool is_src_valid =
|
||||
oob_thread_scratch_.template GetAsType<bool>(vgpr_data_idx_seq);
|
||||
auto op_r_v = is_src_valid ? op_r : dst_vector_t(0);
|
||||
dst_dvgpr_.template SetAsType<dst_vector_t>(vgpr_data_idx_seq, op_r_v);
|
||||
});
|
||||
|
||||
// make forward steps
|
||||
// forward step for each iteration just add 1
|
||||
const auto dst_forward_steps = generate_tuple(
|
||||
@@ -352,7 +383,7 @@ struct ThreadGroupTransferGlobal
|
||||
dst_buf.template Set<dst_vector_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
true,
|
||||
dvgpr_.template GetAsType<dst_vector_t>(vgpr_data_idx_seq));
|
||||
dst_dvgpr_.template GetAsType<dst_vector_t>(vgpr_data_idx_seq));
|
||||
|
||||
// For each dimension move fwd, bwd or don't move
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
@@ -389,6 +420,14 @@ struct ThreadGroupTransferGlobal
|
||||
return make_naive_tensor_descriptor_packed(access_lengths_as_tuple);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
|
||||
{
|
||||
constexpr auto access_lengths_as_tuple =
|
||||
container_push_back(sequence_to_tuple_of_number(NumberOfIterations{}), Number<1>{});
|
||||
|
||||
return make_naive_tensor_descriptor_packed(access_lengths_as_tuple);
|
||||
}
|
||||
|
||||
static constexpr auto thread_data_scratch_desc_ = decltype(GetThreadScratchDataDescriptor()){};
|
||||
using ThreadScratchData = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
DstData,
|
||||
@@ -396,7 +435,17 @@ struct ThreadGroupTransferGlobal
|
||||
decltype(thread_data_scratch_desc_),
|
||||
true>;
|
||||
|
||||
ThreadScratchData dvgpr_;
|
||||
static constexpr auto src_oob_thread_scratch_desc_ =
|
||||
decltype(GetSrcThreadScratchDescriptor()){};
|
||||
using OOBThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
bool,
|
||||
1,
|
||||
decltype(src_oob_thread_scratch_desc_),
|
||||
true>;
|
||||
|
||||
ThreadScratchData src_dvgpr_;
|
||||
ThreadScratchData dst_dvgpr_;
|
||||
OOBThreadScratch oob_thread_scratch_;
|
||||
SrcCoord src_coord_;
|
||||
DstCoord dst_coord_;
|
||||
const ElementwiseOperation element_op_;
|
||||
|
||||
@@ -833,6 +833,26 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector access
|
||||
static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) &&
|
||||
(BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),
|
||||
|
||||
@@ -606,6 +606,26 @@ struct DeviceBatchedGemmMultiD_Wmma_CShuffleV3
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
|
||||
@@ -588,6 +588,28 @@ struct DeviceBatchedGemmReduce_Wmma_CShuffleV3
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() &&
|
||||
!GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() &&
|
||||
!GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
typename GridwiseGemm::Argument gemm_arg{std::array<const void*, 1>{arg.p_a_grid_},
|
||||
std::array<const void*, 1>{arg.p_b_grid_},
|
||||
std::array<const void*, 0>{},
|
||||
|
||||
@@ -455,6 +455,26 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_Common
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
|
||||
@@ -471,6 +471,28 @@ struct DeviceGemmBiasAddReduce_Wmma_CShuffleV3
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() &&
|
||||
!GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() &&
|
||||
!GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
typename GridwiseGemm::Argument gemm_arg{
|
||||
std::array<const void*, 1>{arg.p_a_grid_},
|
||||
std::array<const void*, 1>{arg.p_b_grid_},
|
||||
|
||||
@@ -701,6 +701,28 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() &&
|
||||
!GridwiseGemmWelford::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() &&
|
||||
!GridwiseGemmWelford::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
typename GridwiseGemmWelford::Argument gemm_arg{
|
||||
std::array<const void*, 1>{arg.p_a_grid_},
|
||||
std::array<const void*, 1>{arg.p_b_grid_},
|
||||
|
||||
@@ -456,6 +456,28 @@ struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOpera
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() &&
|
||||
!GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() &&
|
||||
!GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
typename GridwiseGemm::Argument gemm_arg{std::array<const void*, 1>{arg.p_a_grid_},
|
||||
std::array<const void*, 1>{arg.p_b_grid_},
|
||||
std::array<const void*, 0>{},
|
||||
|
||||
@@ -421,6 +421,26 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
|
||||
}
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -393,6 +393,26 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(
|
||||
*dynamic_cast<const typename GridwiseGemm::Argument*>(&arg));
|
||||
}
|
||||
|
||||
@@ -450,8 +450,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
|
||||
BlkGemmPipelineVer,
|
||||
AComputeType,
|
||||
BComputeType,
|
||||
false,
|
||||
false>;
|
||||
false, // PermuteA
|
||||
false, // PermuteB
|
||||
false, // IsBPreShuffled
|
||||
true>; // ForceThreadTileTransfer
|
||||
|
||||
#define GridwiseGemmCTransposeTemplateParameters \
|
||||
ALayout, BLayout, DsLayout, ELayout, Tuple<ADataType>, Tuple<BDataType>, AccDataType, \
|
||||
@@ -467,7 +469,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
|
||||
ABlockLdsExtraM, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, \
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CShuffleBlockTransferScalarPerVector, BlkGemmPipeSched, BlkGemmPipelineVer, BComputeType, \
|
||||
AComputeType, false, false
|
||||
AComputeType, false, false, false, true
|
||||
|
||||
using GridwiseGemmCTranspose =
|
||||
std::conditional_t<CTranspose,
|
||||
|
||||
@@ -503,6 +503,29 @@ struct DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3
|
||||
bool supported = true;
|
||||
for(index_t i = 0; i < arg.group_count_; ++i)
|
||||
{
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(
|
||||
arg.gemm_descs_[i].M_, arg.gemm_descs_[i].K_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(
|
||||
arg.gemm_descs_[i].N_, arg.gemm_descs_[i].K_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::array<const void*, NumDTensor> placeholder_p_ds_grid{};
|
||||
std::array<index_t, NumDTensor> stride_Ds;
|
||||
std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin());
|
||||
|
||||
@@ -704,7 +704,28 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
|
||||
bool supported = true;
|
||||
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
|
||||
{
|
||||
const auto& a = arg.gemm_kernel_args_[i].karg_;
|
||||
const auto& a = arg.gemm_kernel_args_[i].karg_;
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(a.M, a.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(a.N, a.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool group_arg_valid = GridwiseGemm::CheckValidity(a);
|
||||
|
||||
if(not group_arg_valid)
|
||||
|
||||
@@ -132,10 +132,6 @@ struct ABTransferWaveTiles
|
||||
index_t,
|
||||
index_t)
|
||||
{
|
||||
// Notes: padding is currently not supported with transpose
|
||||
static_assert(!((PadMN || PadK) && ABDoTranspose),
|
||||
"padding is currently not supported with transpose");
|
||||
|
||||
const index_t MN_grid = !PadMN ? sizeMN : MNPad;
|
||||
const index_t K_grid = !PadK ? sizeK : KPad;
|
||||
|
||||
|
||||
@@ -362,23 +362,27 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::selected_wmma
|
||||
.wave_size;
|
||||
|
||||
__host__ __device__ static constexpr bool AWaveTransferApplicable()
|
||||
{
|
||||
return !ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 &&
|
||||
ABlockTransferSrcScalarPerVector == 8 && ABlockTransferDstScalarPerVector_AK1 == 8 &&
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 &&
|
||||
!IsBPreShuffled;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool BWaveTransferApplicable()
|
||||
{
|
||||
return !ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 &&
|
||||
BBlockTransferSrcScalarPerVector == 8 && BBlockTransferDstScalarPerVector_BK1 == 8 &&
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8;
|
||||
}
|
||||
|
||||
// Limitations of the current implementation:
|
||||
// - no multiAB
|
||||
// - GemmSpecialization Default with transpose
|
||||
#ifdef __gfx12__
|
||||
static constexpr bool IsAWaveTransferApplicable =
|
||||
!ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 &&
|
||||
((GemmSpec == tensor_operation::device::GemmSpecialization::Default &&
|
||||
!is_same_v<ALayout, tensor_layout::gemm::RowMajor>) ||
|
||||
is_same_v<ALayout, tensor_layout::gemm::RowMajor>) &&
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 && !IsBPreShuffled;
|
||||
static constexpr bool IsAWaveTransferApplicable = AWaveTransferApplicable();
|
||||
|
||||
static constexpr bool IsBWaveTransferApplicable =
|
||||
!ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 &&
|
||||
((GemmSpec == tensor_operation::device::GemmSpecialization::Default &&
|
||||
!is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>) ||
|
||||
is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>) &&
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8;
|
||||
static constexpr bool IsBWaveTransferApplicable = BWaveTransferApplicable();
|
||||
|
||||
static constexpr bool IsWaveTileInterleavedFitting =
|
||||
(NPerBlock / NPerWmma / NRepeat) * (KPerBlock / KPack) >= (BlockSize / WaveSize);
|
||||
@@ -982,6 +986,55 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
return de_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
}
|
||||
|
||||
// Conditions for Wave Transfer with transpose:
|
||||
// - 16 bit type: K % 8 == 0 (4 subtiles of 8x8)
|
||||
// - 8 bit type: K % 8 == 0 and M % 16 == 0 (2 subtiles of 8x16)
|
||||
__host__ static constexpr bool CheckValidityAWaveTransfer(const index_t& M, const index_t& K)
|
||||
{
|
||||
if constexpr(AWaveTransferApplicable() &&
|
||||
!(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
|
||||
{
|
||||
if(!(K % ABlockTransferDstScalarPerVector_AK1 == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
bool pass = true;
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
|
||||
pass &= !(sizeof(ADataType_) == 1 &&
|
||||
!(M % (2 * ABlockTransferSrcScalarPerVector) == 0));
|
||||
});
|
||||
return pass;
|
||||
}
|
||||
else
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
__host__ static constexpr bool CheckValidityBWaveTransfer(const index_t& N, const index_t& K)
|
||||
{
|
||||
if constexpr(BWaveTransferApplicable() &&
|
||||
!(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value))
|
||||
{
|
||||
if(!(K % BBlockTransferDstScalarPerVector_BK1 == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
bool pass = true;
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
|
||||
pass &= !(sizeof(BDataType_) == 1 &&
|
||||
!(N % (2 * BBlockTransferSrcScalarPerVector) == 0));
|
||||
});
|
||||
return pass;
|
||||
}
|
||||
else
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
template <typename Argument>
|
||||
__host__ static constexpr bool CheckValidity(const Argument& karg,
|
||||
|
||||
Reference in New Issue
Block a user