mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Padding support for wave transfer (#3537)
* Add padding support with transpose. Also move the check before writing, storing is_src_valid during reading. * Add/modify instances to use wave transfer for gemm universal. The condition is changed so that the vector size of the vmem read and LDS write must be equal to 8 in order to use the wave transfer. * Fix clang format. * Modify example. * Fix bwd data. * Add restriction for wave transfer with padding and transpose; add a test case which shows this limitation. * Fix validity checks for 8-bit types. * Add validity check for gemm_bias_add_reduce. * Add validity check for grouped gemm tile loop. * Fix validity checks for new flavours. * Minor fixes. * Fix clang format.
This commit is contained in:
@@ -160,6 +160,7 @@ struct ThreadGroupTransferGlobal
|
||||
// check if src element is valid
|
||||
const bool is_src_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
|
||||
oob_thread_scratch_.template SetAsType<bool>(vgpr_data_idx_seq, is_src_valid);
|
||||
|
||||
// Vector length of elementwise operation
|
||||
constexpr auto get_elem_op_vec_len = []() {
|
||||
@@ -195,14 +196,12 @@ struct ThreadGroupTransferGlobal
|
||||
using dst_vector_type = vector_type_maker_t<DstData, VectorSize>;
|
||||
using dst_vector_t = typename dst_vector_type::type;
|
||||
|
||||
using vector_t = typename vector_type_maker<DstData, VectorSize>::type::type;
|
||||
|
||||
dst_vector_type op_r_v;
|
||||
|
||||
// Load data from memory in src_vector first
|
||||
src_vector_container src_vector =
|
||||
src_vector_container{grid_buf.template Get<src_vector_container_t, DoTranspose>(
|
||||
src_coord_.GetOffset(), true)};
|
||||
auto index = is_src_valid || !DoTranspose ? src_coord_.GetOffset() : 0;
|
||||
src_vector_container src_vector = src_vector_container{
|
||||
grid_buf.template Get<src_vector_container_t, DoTranspose>(index, true)};
|
||||
|
||||
// apply the src elementwise op and convert to DstData under the hood if needed
|
||||
static_for<0, VectorSize / elem_op_vec_len, 1>{}([&](auto idx) {
|
||||
@@ -213,9 +212,8 @@ struct ThreadGroupTransferGlobal
|
||||
// store result in dvgpr_ (static array holding loaded data).
|
||||
// At this point data is already converted to DstData type and
|
||||
// the elementwise operation has been applied
|
||||
dvgpr_.template SetAsType<dst_vector_t>(
|
||||
vgpr_data_idx_seq,
|
||||
is_src_valid ? op_r_v.template AsType<dst_vector_t>()[I0] : vector_t(0));
|
||||
src_dvgpr_.template SetAsType<dst_vector_t>(vgpr_data_idx_seq,
|
||||
op_r_v.template AsType<dst_vector_t>()[I0]);
|
||||
|
||||
// For each dimension move fwd, bwd or don't move
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
@@ -248,6 +246,39 @@ struct ThreadGroupTransferGlobal
|
||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||
constexpr auto ordered_fwd_step = StepsPerIteration{};
|
||||
|
||||
// OOB check
|
||||
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
|
||||
// calculate src data index and make sequence
|
||||
constexpr auto src_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}(
|
||||
[&](auto i) { ordered_idx(i) = ordered_src_access_idx[i]; });
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, src_dim_access_order);
|
||||
}();
|
||||
|
||||
// make sequence to access vgpr data. Add zero as last element of src_data_idx_seq
|
||||
constexpr auto vgpr_data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) {
|
||||
if constexpr(i.value < src_data_idx.Size())
|
||||
{
|
||||
return Number<src_data_idx[i]>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Number<0>{};
|
||||
}
|
||||
},
|
||||
Number<src_data_idx.Size() + 1>{});
|
||||
|
||||
auto op_r = src_dvgpr_.template GetAsType<dst_vector_t>(vgpr_data_idx_seq);
|
||||
const bool is_src_valid =
|
||||
oob_thread_scratch_.template GetAsType<bool>(vgpr_data_idx_seq);
|
||||
auto op_r_v = is_src_valid ? op_r : dst_vector_t(0);
|
||||
dst_dvgpr_.template SetAsType<dst_vector_t>(vgpr_data_idx_seq, op_r_v);
|
||||
});
|
||||
|
||||
// make forward steps
|
||||
// forward step for each iteration just add 1
|
||||
const auto dst_forward_steps = generate_tuple(
|
||||
@@ -352,7 +383,7 @@ struct ThreadGroupTransferGlobal
|
||||
dst_buf.template Set<dst_vector_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
true,
|
||||
dvgpr_.template GetAsType<dst_vector_t>(vgpr_data_idx_seq));
|
||||
dst_dvgpr_.template GetAsType<dst_vector_t>(vgpr_data_idx_seq));
|
||||
|
||||
// For each dimension move fwd, bwd or don't move
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
@@ -389,6 +420,14 @@ struct ThreadGroupTransferGlobal
|
||||
return make_naive_tensor_descriptor_packed(access_lengths_as_tuple);
|
||||
}
|
||||
|
||||
/// Builds the packed descriptor for the per-thread source OOB scratch
/// buffer: one slot per iteration, with a trailing unit dimension appended
/// so its rank matches the (iterations..., vector-index) addressing used by
/// the other thread-scratch descriptors in this struct.
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
{
    // Append a length-1 dimension after the per-iteration lengths.
    constexpr auto scratch_lengths =
        container_push_back(sequence_to_tuple_of_number(NumberOfIterations{}), Number<1>{});
    return make_naive_tensor_descriptor_packed(scratch_lengths);
}
|
||||
|
||||
static constexpr auto thread_data_scratch_desc_ = decltype(GetThreadScratchDataDescriptor()){};
|
||||
using ThreadScratchData = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
DstData,
|
||||
@@ -396,7 +435,17 @@ struct ThreadGroupTransferGlobal
|
||||
decltype(thread_data_scratch_desc_),
|
||||
true>;
|
||||
|
||||
ThreadScratchData dvgpr_;
|
||||
static constexpr auto src_oob_thread_scratch_desc_ =
|
||||
decltype(GetSrcThreadScratchDescriptor()){};
|
||||
using OOBThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
bool,
|
||||
1,
|
||||
decltype(src_oob_thread_scratch_desc_),
|
||||
true>;
|
||||
|
||||
ThreadScratchData src_dvgpr_;
|
||||
ThreadScratchData dst_dvgpr_;
|
||||
OOBThreadScratch oob_thread_scratch_;
|
||||
SrcCoord src_coord_;
|
||||
DstCoord dst_coord_;
|
||||
const ElementwiseOperation element_op_;
|
||||
|
||||
Reference in New Issue
Block a user