Wmma support for gemm_bias_add_reduce (#3316)

* Add tests for gemm_bias_add_reduce

* Initial working implementation

* Generalize implementation of reduce epilogue

* Add tests for all layouts

* Add instances

* Fix test archs

* Fix xdl bug

* Remove library/profiler duplications

* Fix num_bytes error in profiler

* Fix typos

* Fix copyright
This commit is contained in:
Enrico Degregori
2026-01-07 19:27:16 +01:00
committed by GitHub
parent f9c6ba0403
commit aad4cf0985
15 changed files with 1424 additions and 141 deletions

View File

@@ -10,6 +10,7 @@ namespace ck {
template <typename ReduceAccDataType,
typename ReducePtrsGlobal,
typename D0ElementwiseOperation,
typename ReduceOperations,
typename ReduceInElementwiseOperations,
typename ReduceAccElementwiseOperations,
@@ -21,6 +22,7 @@ struct ReduceTrait_
{
using ReduceAccDataType_ = ReduceAccDataType;
using ReducePtrsGlobal_ = ReducePtrsGlobal;
using D0ElementwiseOperation_ = D0ElementwiseOperation;
using ReduceOperations_ = ReduceOperations;
using ReduceInElementwiseOperations_ = ReduceInElementwiseOperations;
using ReduceAccElementwiseOperations_ = ReduceAccElementwiseOperations;
@@ -148,11 +150,13 @@ struct EpilogueReduceCShuffle
typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid_,
const typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops_,
const typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops_,
const index_t MRaw_)
const index_t MRaw_,
const typename ReduceTrait::D0ElementwiseOperation_ d0_element_op_)
: p_reduces_grid(p_reduces_grid_),
reduce_in_element_ops(reduce_in_element_ops_),
reduce_out_element_ops(reduce_out_element_ops_),
MRaw(MRaw_),
d0_element_op{d0_element_op_},
reduce_grid_desc_m{MakeReduceGridDescriptor_M(MRaw)}
{
}
@@ -174,6 +178,13 @@ struct EpilogueReduceCShuffle
const index_t& block_m_id,
const index_t& block_n_id)
{
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
auto reduce_grid_desc_mblock_mperblock =
MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m);
@@ -216,29 +227,6 @@ struct EpilogueReduceCShuffle
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
GetCShuffleLDSDescriptor();
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// Thread transfer LDS to Vmem
auto cde_shuffle_block_copy_lds_and_global =
Base::template GetLDSToVmemEpilogueDescriptor<EGlobalMemoryDataOperation, EDataType>(
c_ds_desc_refs,
e_grid_desc_mblock_mperblock_nblock_nperblock,
cde_element_op,
block_m_id,
block_n_id);
// tuple of reference to C/Ds tensor buffers
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// LDS c_reduce_block_desc_mperblock_nperblock
constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
@@ -346,6 +334,68 @@ struct EpilogueReduceCShuffle
},
Number<NumReduce>{});
// multiple Ds
constexpr auto d_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<mreduce_per_thread>{}, I1, Number<nreduce_per_thread>{}));
constexpr auto ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock = generate_tuple(
[&](auto) { return d_reduce_thread_desc_mblock_mperblock_nblock_nperblock; },
Number<NumDTensor>{});
constexpr auto ds_thread_buf_size =
d_reduce_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
auto c01_thread_buf =
make_static_buffer<AddressSpaceEnum::Vgpr, typename ReduceTrait::ReduceAccDataType_>(
Number<ds_thread_buf_size>{});
auto ds_thread_copy_global_to_vgpr = generate_tuple(
[&](auto I) {
return ThreadwiseTensorSliceTransfer_v2<
remove_cvref_t<tuple_element_t<I.value, DsDataType>>,
typename ReduceTrait::ReduceAccDataType_,
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I]),
remove_cvref_t<
decltype(ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I])>,
Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>,
Sequence<0, 1, 2, 3>,
3,
ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_,
1,
true>(ds_grid_desc_mblock_mperblock_nblock_nperblock[I],
make_multi_index(
I0,
m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]));
},
Number<NumDTensor>{});
constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<mreduce_per_thread>{}, I1, Number<nreduce_per_thread>{}));
// Write E from Vgpr to Vmem
auto c_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
typename ReduceTrait::ReduceAccDataType_,
EDataType,
decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
tensor_operation::element_wise::PassThrough,
Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
3, // DstVectorDim
ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_,
EGlobalMemoryDataOperation,
1,
true>{e_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(I0,
m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]),
NumDTensor > 0 ? tensor_operation::element_wise::PassThrough{} : cde_element_op};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");
@@ -365,15 +415,6 @@ struct EpilogueReduceCShuffle
// make sure it's safe to read from LDS
block_sync_lds();
// each block loads its C data from LDS, D from global, applies elementwise
// operation and stores result E to global
cde_shuffle_block_copy_lds_and_global.Run(
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e_grid_buf));
{
c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
c_shuffle_block_buf,
@@ -381,6 +422,53 @@ struct EpilogueReduceCShuffle
make_tuple(I0, I0),
c_reduce_thread_buf);
// Note: currently multiple Ds supports only Bias + Add.
// It needs to be generalized for other operations (currently not needed)
if constexpr(NumDTensor > 0)
{
auto& d0_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(I0);
// d0 / d1 operations
d0_thread_copy_global_to_vgpr.Run(
ds_grid_desc_mblock_mperblock_nblock_nperblock[I0],
ds_grid_buf[I0],
ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I0],
make_tuple(I0, I0, I0, I0),
c01_thread_buf);
// c = activation(c + bias)
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
[&](auto i) {
typename ReduceTrait::ReduceAccDataType_ out;
cde_element_op(out, c_reduce_thread_buf(i) + c01_thread_buf(i));
c_reduce_thread_buf(i) = out;
});
auto& d1_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(I1);
d1_thread_copy_global_to_vgpr.Run(
ds_grid_desc_mblock_mperblock_nblock_nperblock[I1],
ds_grid_buf[I1],
ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I1],
make_tuple(I0, I0, I0, I0),
c01_thread_buf);
// c = c + c1_function(c1)
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
[&](auto i) {
d0_element_op(c01_thread_buf(i), c01_thread_buf(i));
c_reduce_thread_buf(i) += c01_thread_buf(i);
});
}
// Write E
c_reduce_thread_copy_vgpr_to_global.Run(
c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c_reduce_thread_buf,
e_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_buf);
// Reduction
static_for<0, NumReduce, 1>{}([&](auto In) {
auto& p_reduce_grid = p_reduces_grid[In];
@@ -448,14 +536,15 @@ struct EpilogueReduceCShuffle
{
constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
// move on Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_global_step);
static_for<0, NumDTensor, 1>{}([&](auto I) {
auto& d_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(I);
d_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
ds_grid_desc_mblock_mperblock_nblock_nperblock[I], cde_global_step);
});
// move on E
cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
c_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
e_grid_desc_mblock_mperblock_nblock_nperblock, cde_global_step);
}
});
}
@@ -464,6 +553,7 @@ struct EpilogueReduceCShuffle
typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops;
typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops;
index_t MRaw;
typename ReduceTrait::D0ElementwiseOperation_ d0_element_op;
ReduceGridDesc_M reduce_grid_desc_m;
};

View File

@@ -897,6 +897,8 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),