mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
Wmma support for gemm_bias_add_reduce (#3316)
* Add tests for gemm_bias_add_reduce * Initial working implementation * Generalize implementation of reduce epilogue * Add tests for all layouts * Add instances * Fix test archs * Fix xdl bug * Remove library/profiler duplications * Fix num_byted error profiler * Fix typos * Fix copyright
This commit is contained in:
@@ -10,6 +10,7 @@ namespace ck {
|
||||
|
||||
template <typename ReduceAccDataType,
|
||||
typename ReducePtrsGlobal,
|
||||
typename D0ElementwiseOperation,
|
||||
typename ReduceOperations,
|
||||
typename ReduceInElementwiseOperations,
|
||||
typename ReduceAccElementwiseOperations,
|
||||
@@ -21,6 +22,7 @@ struct ReduceTrait_
|
||||
{
|
||||
using ReduceAccDataType_ = ReduceAccDataType;
|
||||
using ReducePtrsGlobal_ = ReducePtrsGlobal;
|
||||
using D0ElementwiseOperation_ = D0ElementwiseOperation;
|
||||
using ReduceOperations_ = ReduceOperations;
|
||||
using ReduceInElementwiseOperations_ = ReduceInElementwiseOperations;
|
||||
using ReduceAccElementwiseOperations_ = ReduceAccElementwiseOperations;
|
||||
@@ -148,11 +150,13 @@ struct EpilogueReduceCShuffle
|
||||
typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid_,
|
||||
const typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops_,
|
||||
const typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops_,
|
||||
const index_t MRaw_)
|
||||
const index_t MRaw_,
|
||||
const typename ReduceTrait::D0ElementwiseOperation_ d0_element_op_)
|
||||
: p_reduces_grid(p_reduces_grid_),
|
||||
reduce_in_element_ops(reduce_in_element_ops_),
|
||||
reduce_out_element_ops(reduce_out_element_ops_),
|
||||
MRaw(MRaw_),
|
||||
d0_element_op{d0_element_op_},
|
||||
reduce_grid_desc_m{MakeReduceGridDescriptor_M(MRaw)}
|
||||
{
|
||||
}
|
||||
@@ -174,6 +178,13 @@ struct EpilogueReduceCShuffle
|
||||
const index_t& block_m_id,
|
||||
const index_t& block_n_id)
|
||||
{
|
||||
// HACK: this forces m/n_block_data_idx_on_grid into SGPR
|
||||
const index_t m_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
|
||||
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
|
||||
|
||||
auto reduce_grid_desc_mblock_mperblock =
|
||||
MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m);
|
||||
|
||||
@@ -216,29 +227,6 @@ struct EpilogueReduceCShuffle
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
|
||||
GetCShuffleLDSDescriptor();
|
||||
|
||||
// tuple of reference to C/Ds tensor descriptors
|
||||
const auto c_ds_desc_refs = concat_tuple_of_reference(
|
||||
tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
|
||||
generate_tie([&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// Thread transfer LDS to Vmem
|
||||
auto cde_shuffle_block_copy_lds_and_global =
|
||||
Base::template GetLDSToVmemEpilogueDescriptor<EGlobalMemoryDataOperation, EDataType>(
|
||||
c_ds_desc_refs,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
cde_element_op,
|
||||
block_m_id,
|
||||
block_n_id);
|
||||
|
||||
// tuple of reference to C/Ds tensor buffers
|
||||
const auto c_ds_buf_refs = concat_tuple_of_reference(
|
||||
tie(c_shuffle_block_buf),
|
||||
generate_tie([&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_buf[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// LDS c_reduce_block_desc_mperblock_nperblock
|
||||
constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor(
|
||||
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
|
||||
@@ -346,6 +334,68 @@ struct EpilogueReduceCShuffle
|
||||
},
|
||||
Number<NumReduce>{});
|
||||
|
||||
// multiple Ds
|
||||
constexpr auto d_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1, Number<mreduce_per_thread>{}, I1, Number<nreduce_per_thread>{}));
|
||||
|
||||
constexpr auto ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock = generate_tuple(
|
||||
[&](auto) { return d_reduce_thread_desc_mblock_mperblock_nblock_nperblock; },
|
||||
Number<NumDTensor>{});
|
||||
|
||||
constexpr auto ds_thread_buf_size =
|
||||
d_reduce_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
|
||||
|
||||
auto c01_thread_buf =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, typename ReduceTrait::ReduceAccDataType_>(
|
||||
Number<ds_thread_buf_size>{});
|
||||
|
||||
auto ds_thread_copy_global_to_vgpr = generate_tuple(
|
||||
[&](auto I) {
|
||||
return ThreadwiseTensorSliceTransfer_v2<
|
||||
remove_cvref_t<tuple_element_t<I.value, DsDataType>>,
|
||||
typename ReduceTrait::ReduceAccDataType_,
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I]),
|
||||
remove_cvref_t<
|
||||
decltype(ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I])>,
|
||||
Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_,
|
||||
1,
|
||||
true>(ds_grid_desc_mblock_mperblock_nblock_nperblock[I],
|
||||
make_multi_index(
|
||||
I0,
|
||||
m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
|
||||
I0,
|
||||
n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]));
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
|
||||
constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1, Number<mreduce_per_thread>{}, I1, Number<nreduce_per_thread>{}));
|
||||
|
||||
// Write E from Vgpr to Vmem
|
||||
auto c_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
typename ReduceTrait::ReduceAccDataType_,
|
||||
EDataType,
|
||||
decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
tensor_operation::element_wise::PassThrough,
|
||||
Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>, // SliceLengths
|
||||
Sequence<0, 1, 2, 3>, // DimAccessOrder
|
||||
3, // DstVectorDim
|
||||
ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_,
|
||||
EGlobalMemoryDataOperation,
|
||||
1,
|
||||
true>{e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(I0,
|
||||
m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
|
||||
I0,
|
||||
n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]),
|
||||
NumDTensor > 0 ? tensor_operation::element_wise::PassThrough{} : cde_element_op};
|
||||
|
||||
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
|
||||
|
||||
static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");
|
||||
@@ -365,15 +415,6 @@ struct EpilogueReduceCShuffle
|
||||
|
||||
// make sure it's safe to read from LDS
|
||||
block_sync_lds();
|
||||
|
||||
// each block loads its C data from LDS, D from global, applies elementwise
|
||||
// operation and stores result E to global
|
||||
cde_shuffle_block_copy_lds_and_global.Run(
|
||||
c_ds_desc_refs,
|
||||
c_ds_buf_refs,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
tie(e_grid_buf));
|
||||
|
||||
{
|
||||
c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
|
||||
c_shuffle_block_buf,
|
||||
@@ -381,6 +422,53 @@ struct EpilogueReduceCShuffle
|
||||
make_tuple(I0, I0),
|
||||
c_reduce_thread_buf);
|
||||
|
||||
// Note: currently, the multiple-Ds path supports only Bias + Add.
|
||||
// It needs to be generalized for other operations (currently not needed)
|
||||
if constexpr(NumDTensor > 0)
|
||||
{
|
||||
auto& d0_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(I0);
|
||||
// d0 / d1 operations
|
||||
d0_thread_copy_global_to_vgpr.Run(
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock[I0],
|
||||
ds_grid_buf[I0],
|
||||
ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I0],
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c01_thread_buf);
|
||||
|
||||
// c = activation(c + bias)
|
||||
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
|
||||
[&](auto i) {
|
||||
typename ReduceTrait::ReduceAccDataType_ out;
|
||||
cde_element_op(out, c_reduce_thread_buf(i) + c01_thread_buf(i));
|
||||
c_reduce_thread_buf(i) = out;
|
||||
});
|
||||
|
||||
auto& d1_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(I1);
|
||||
|
||||
d1_thread_copy_global_to_vgpr.Run(
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock[I1],
|
||||
ds_grid_buf[I1],
|
||||
ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I1],
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c01_thread_buf);
|
||||
|
||||
// c = c + c1_function(c1)
|
||||
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
|
||||
[&](auto i) {
|
||||
d0_element_op(c01_thread_buf(i), c01_thread_buf(i));
|
||||
c_reduce_thread_buf(i) += c01_thread_buf(i);
|
||||
});
|
||||
}
|
||||
|
||||
// Write E
|
||||
c_reduce_thread_copy_vgpr_to_global.Run(
|
||||
c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_reduce_thread_buf,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_buf);
|
||||
|
||||
// Reduction
|
||||
static_for<0, NumReduce, 1>{}([&](auto In) {
|
||||
auto& p_reduce_grid = p_reduces_grid[In];
|
||||
|
||||
@@ -448,14 +536,15 @@ struct EpilogueReduceCShuffle
|
||||
{
|
||||
constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
|
||||
// move on Ds
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
|
||||
c_ds_desc_refs, i + I1, cde_global_step);
|
||||
static_for<0, NumDTensor, 1>{}([&](auto I) {
|
||||
auto& d_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(I);
|
||||
d_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock[I], cde_global_step);
|
||||
});
|
||||
|
||||
// move on E
|
||||
cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
|
||||
c_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock, cde_global_step);
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -464,6 +553,7 @@ struct EpilogueReduceCShuffle
|
||||
typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops;
|
||||
typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops;
|
||||
index_t MRaw;
|
||||
typename ReduceTrait::D0ElementwiseOperation_ d0_element_op;
|
||||
ReduceGridDesc_M reduce_grid_desc_m;
|
||||
};
|
||||
|
||||
|
||||
@@ -897,6 +897,8 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
block_sync_lds();
|
||||
|
||||
// each thread writes its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
|
||||
Reference in New Issue
Block a user