Padding support for wave transfer (#3537)

* Add padding support with transpose

Also move the out-of-bounds check before the write by storing is_src_valid during the read (see the first sketch after this list)

* Add/modify instances to use wave transfer for gemm universal

The condition is changed so that the vector sizes of the vmem read and the LDS
write must both equal 8 in order to use the wave transfer (see the second sketch after this list)

* Fix clang-format

* Modify example

* Fix bwd data

* Add restriction for wave transfer with padding and transpose

Add a test case which shows this limitation (see the third sketch after this list)

* Fix validity checks for 8-bit types

* Add validity check for gemm_bias_add_reduce

* Add validity check for grouped GEMM tile loop

* Fix validity checks for new flavours

* Minor fixes

* Fix clang-format
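
First sketch, for the padding and is_src_valid change above: a rough illustration of the deferred-masking pattern, using made-up names rather than the real CK types. The validity flag is captured while reading from global memory and kept in a small per-thread scratch; zero-filling of padded elements is deferred to a masking pass that runs before the staged data is written to LDS.

    #include <array>
    #include <cstddef>

    // Sketch of the deferred-masking scheme, assuming hypothetical simplified
    // buffers in place of CK's StaticTensorTupleOfVectorBuffer.
    template <typename T, std::size_t NumIters, std::size_t VectorSize>
    struct DeferredMaskSketch
    {
        std::array<std::array<T, VectorSize>, NumIters> src_scratch{}; // raw global loads
        std::array<std::array<T, VectorSize>, NumIters> dst_scratch{}; // staged for LDS
        std::array<bool, NumIters> oob_scratch{};                      // validity per vector

        // Read phase: record validity instead of masking immediately; an invalid
        // (padded) vector is fetched from a safe offset so the load stays in bounds.
        void read(const T* global, std::size_t offset, bool is_src_valid, std::size_t iter)
        {
            oob_scratch[iter]             = is_src_valid;
            const std::size_t safe_offset = is_src_valid ? offset : 0;
            for(std::size_t v = 0; v < VectorSize; ++v)
                src_scratch[iter][v] = global[safe_offset + v];
        }

        // Masking pass: zero the vectors whose source was out of bounds;
        // dst_scratch is what the LDS write then consumes.
        void mask()
        {
            for(std::size_t i = 0; i < NumIters; ++i)
                for(std::size_t v = 0; v < VectorSize; ++v)
                    dst_scratch[i][v] = oob_scratch[i] ? src_scratch[i][v] : T(0);
        }
    };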
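
Second sketch, for the vector-size condition: a hedged illustration of the kind of compile-time gate that condition implies (the trait name and parameters here are assumptions, not the actual CK predicate).

    // Hypothetical compile-time gate: the wave-transfer path is only taken when
    // both the global-memory (vmem) read and the LDS write use 8-wide vectors.
    template <int VmemReadVectorSize, int LdsWriteVectorSize>
    inline constexpr bool use_wave_transfer =
        (VmemReadVectorSize == 8) && (LdsWriteVectorSize == 8);

    static_assert(use_wave_transfer<8, 8>);  // wave transfer selected
    static_assert(!use_wave_transfer<4, 8>); // falls back to the default path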
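
Third sketch, for the padding-plus-transpose restriction: a minimal guard of the shape that bullet describes. The function name is made up, and the real check is presumably narrower than this blanket predicate; it stands in for whatever padded-and-transposed configuration the new test case shows to be unsupported.

    // Hypothetical argument-validity guard: an unsupported combination of a
    // padded source together with the transposing wave transfer is rejected
    // up front, matching the limitation the new test case exercises.
    constexpr bool is_wave_transfer_supported(bool has_padding, bool do_transpose)
    {
        return !(has_padding && do_transpose);
    }

    static_assert(is_wave_transfer_supported(true, false));  // padding alone: fine
    static_assert(!is_wave_transfer_supported(true, true));  // rejected combination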

Author:    Enrico Degregori
Date:      2026-01-26 21:57:09 +01:00
Committer: GitHub
Parent:    bd5fec81af
Commit:    2e49b6b2f7

23 changed files with 385 additions and 50 deletions

@@ -160,6 +160,7 @@ struct ThreadGroupTransferGlobal
         // check if src element is valid
         const bool is_src_valid =
             coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
+        oob_thread_scratch_.template SetAsType<bool>(vgpr_data_idx_seq, is_src_valid);
 
         // Vector length of elementwise operation
         constexpr auto get_elem_op_vec_len = []() {
@@ -195,14 +196,12 @@ struct ThreadGroupTransferGlobal
         using dst_vector_type = vector_type_maker_t<DstData, VectorSize>;
         using dst_vector_t    = typename dst_vector_type::type;
-        using vector_t = typename vector_type_maker<DstData, VectorSize>::type::type;
-
         dst_vector_type op_r_v;
 
         // Load data from memory in src_vector first
-        src_vector_container src_vector =
-            src_vector_container{grid_buf.template Get<src_vector_container_t, DoTranspose>(
-                src_coord_.GetOffset(), true)};
+        auto index = is_src_valid || !DoTranspose ? src_coord_.GetOffset() : 0;
+        src_vector_container src_vector = src_vector_container{
+            grid_buf.template Get<src_vector_container_t, DoTranspose>(index, true)};
 
         // apply the src elementwise op and convert to DstData under the hood if needed
         static_for<0, VectorSize / elem_op_vec_len, 1>{}([&](auto idx) {
@@ -213,9 +212,8 @@ struct ThreadGroupTransferGlobal
         // store result in dvgpr_ (static array holding loaded data).
         // At this point data is already converted to DstData type and
         // the elementwise operation has been applied
-        dvgpr_.template SetAsType<dst_vector_t>(
-            vgpr_data_idx_seq,
-            is_src_valid ? op_r_v.template AsType<dst_vector_t>()[I0] : vector_t(0));
+        src_dvgpr_.template SetAsType<dst_vector_t>(vgpr_data_idx_seq,
+                                                    op_r_v.template AsType<dst_vector_t>()[I0]);
 
         // For each dimension move fwd, bwd or don't move
         static_for<0, nDim, 1>{}([&](auto i) {
@@ -248,6 +246,39 @@ struct ThreadGroupTransferGlobal
             container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
 
         constexpr auto ordered_fwd_step = StepsPerIteration{};
 
+        // OOB check
+        static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
+            // calculate src data index and make sequence
+            constexpr auto src_data_idx = [&]() {
+                Index ordered_idx;
+                static_for<0, nDim, 1>{}(
+                    [&](auto i) { ordered_idx(i) = ordered_src_access_idx[i]; });
+                return container_reorder_given_old2new(ordered_idx, src_dim_access_order);
+            }();
+
+            // make sequence to access vgpr data. Add zero as last element of src_data_idx_seq
+            constexpr auto vgpr_data_idx_seq = generate_sequence_v2(
+                [&](auto i) {
+                    if constexpr(i.value < src_data_idx.Size())
+                    {
+                        return Number<src_data_idx[i]>{};
+                    }
+                    else
+                    {
+                        return Number<0>{};
+                    }
+                },
+                Number<src_data_idx.Size() + 1>{});
+
+            auto op_r = src_dvgpr_.template GetAsType<dst_vector_t>(vgpr_data_idx_seq);
+            const bool is_src_valid =
+                oob_thread_scratch_.template GetAsType<bool>(vgpr_data_idx_seq);
+            auto op_r_v = is_src_valid ? op_r : dst_vector_t(0);
+            dst_dvgpr_.template SetAsType<dst_vector_t>(vgpr_data_idx_seq, op_r_v);
+        });
+
         // make forward steps
         // forward step for each iteration just add 1
         const auto dst_forward_steps = generate_tuple(
@@ -352,7 +383,7 @@ struct ThreadGroupTransferGlobal
             dst_buf.template Set<dst_vector_t>(
                 dst_coord_.GetOffset(),
                 true,
-                dvgpr_.template GetAsType<dst_vector_t>(vgpr_data_idx_seq));
+                dst_dvgpr_.template GetAsType<dst_vector_t>(vgpr_data_idx_seq));
 
             // For each dimension move fwd, bwd or don't move
             static_for<0, nDim, 1>{}([&](auto i) {
@@ -389,6 +420,14 @@ struct ThreadGroupTransferGlobal
         return make_naive_tensor_descriptor_packed(access_lengths_as_tuple);
     }
 
+    __device__ static constexpr auto GetSrcThreadScratchDescriptor()
+    {
+        constexpr auto access_lengths_as_tuple =
+            container_push_back(sequence_to_tuple_of_number(NumberOfIterations{}), Number<1>{});
+        return make_naive_tensor_descriptor_packed(access_lengths_as_tuple);
+    }
+
     static constexpr auto thread_data_scratch_desc_ = decltype(GetThreadScratchDataDescriptor()){};
 
     using ThreadScratchData = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                                               DstData,
@@ -396,7 +435,17 @@ struct ThreadGroupTransferGlobal
                                                               decltype(thread_data_scratch_desc_),
                                                               true>;
-    ThreadScratchData dvgpr_;
+    static constexpr auto src_oob_thread_scratch_desc_ =
+        decltype(GetSrcThreadScratchDescriptor()){};
+    using OOBThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
+                                                             bool,
+                                                             1,
+                                                             decltype(src_oob_thread_scratch_desc_),
+                                                             true>;
+
+    ThreadScratchData src_dvgpr_;
+    ThreadScratchData dst_dvgpr_;
+    OOBThreadScratch oob_thread_scratch_;
 
     SrcCoord src_coord_;
     DstCoord dst_coord_;
     const ElementwiseOperation element_op_;
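
Taken together, the last two hunks replace the single dvgpr_ staging buffer with three per-thread buffers: the raw source data, the masked destination data, and the boolean validity flags. A toy summary of the resulting state (simplified stand-in types; the real members are descriptor-backed StaticTensorTupleOfVectorBuffer instances):

    // Toy model of the per-thread state after this commit.
    template <typename T, int NumVectors>
    struct ThreadStateSketch
    {
        T    src_dvgpr[NumVectors];          // raw vectors loaded from global memory
        T    dst_dvgpr[NumVectors];          // masked vectors staged for the LDS write
        bool oob_thread_scratch[NumVectors]; // one validity flag per loaded vector
    };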