mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
Update to gemm_reduce and batched_gemm_reduce (#213)
* [Experimental] Change to gemm+reduce and batched-gemm+reduce
* Use threadwise-reduce function to improve the gridwise_gemm_reduce_xdl_cshuffle kernel
* Tiny fix in device_batched_gemm_xdl.hpp
* clang-format library/src/utility/conv_fwd_util.cpp
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include "blockwise_tensor_slice_transfer_v6r1.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "gridwise_gemm_pipeline_v1.hpp"
|
||||
#include "reduction_functions_threadwise.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -18,8 +19,7 @@ template <typename GridwiseGemm,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename D0ReduceOperation,
|
||||
typename D1ReduceOperation,
|
||||
typename D1ElementwiseOperation,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -39,8 +39,7 @@ __global__ void
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const D0ReduceOperation d0_reduce_op,
|
||||
const D1ReduceOperation d1_reduce_op,
|
||||
const D1ElementwiseOperation d1_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
@@ -60,8 +59,7 @@ __global__ void
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d0_reduce_op,
|
||||
d1_reduce_op,
|
||||
d1_element_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
@@ -76,8 +74,7 @@ __global__ void
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = d0_reduce_op;
|
||||
ignore = d1_reduce_op;
|
||||
ignore = d1_element_op;
|
||||
ignore = a_grid_desc_ak0_m_ak1;
|
||||
ignore = b_grid_desc_bk0_n_bk1;
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
@@ -97,6 +94,7 @@ template <typename FloatAB,
|
||||
typename CElementwiseOperation,
|
||||
typename D0ReduceOperation,
|
||||
typename D1ReduceOperation,
|
||||
typename D1ElementwiseOperation,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
InMemoryDataOperationEnum DGlobalMemoryDataOperation,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
@@ -372,8 +370,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CElementwiseOperation& c_element_op,
|
||||
const D0ReduceOperation& d0_reduce_op,
|
||||
const D1ReduceOperation& d1_reduce_op,
|
||||
const D1ElementwiseOperation& d1_element_op,
|
||||
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
@@ -741,13 +738,13 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));
|
||||
|
||||
// TODO: this should be implemented as a blockwise reduction
|
||||
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
|
||||
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
|
||||
c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
|
||||
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
|
||||
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
|
||||
|
||||
auto d1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
|
||||
auto d1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
|
||||
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
|
||||
|
||||
// reduce: threadwise copy from LDS to VGPR
|
||||
@@ -763,7 +760,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
|
||||
auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
|
||||
FloatCShuffle,
|
||||
FloatCShuffle,
|
||||
FloatReduceAcc,
|
||||
decltype(c_reduce_block_desc_mperblock_nperblock),
|
||||
decltype(c_reduce_thread_desc_mperblock_nperblock),
|
||||
decltype(c_reduce_thread_lengths_mperblock_nperblock),
|
||||
@@ -775,7 +772,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
|
||||
// reduce: copy from VGPR to global
|
||||
auto d0_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
FloatCShuffle,
|
||||
FloatReduceAcc,
|
||||
FloatD,
|
||||
decltype(d_reduce_thread_desc_mblock_mperblock),
|
||||
decltype(d_grid_desc_mblock_mperblock),
|
||||
@@ -840,6 +837,28 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
|
||||
using ThreadwiseReduce_D0 =
|
||||
ThreadwiseReduction<FloatReduceAcc,
|
||||
decltype(c_reduce_thread_desc_mperblock_nperblock),
|
||||
decltype(d_reduce_thread_desc_mperblock),
|
||||
D0ReduceOperation,
|
||||
false>;
|
||||
|
||||
using ThreadwiseReduce_D1 =
|
||||
ThreadwiseReduction<FloatReduceAcc,
|
||||
decltype(c_reduce_thread_desc_mperblock_nperblock),
|
||||
decltype(d_reduce_thread_desc_mperblock),
|
||||
D1ReduceOperation,
|
||||
false>;
|
||||
|
||||
const auto d0_zeroVal = D0ReduceOperation::GetReductionZeroVal();
|
||||
const auto d1_zeroVal = D0ReduceOperation::GetReductionZeroVal();
|
||||
|
||||
static_for<0, mreduce_per_thread, 1>{}(
|
||||
[&](auto I) { d0_thread_buf(I) = d0_zeroVal; });
|
||||
static_for<0, mreduce_per_thread, 1>{}(
|
||||
[&](auto I) { d1_thread_buf(I) = d1_zeroVal; });
|
||||
|
||||
// reduce
|
||||
{
|
||||
// copy from LDS to VGPR
|
||||
@@ -850,26 +869,20 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
c_reduce_thread_buf);
|
||||
|
||||
// reduce in VGPR
|
||||
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
|
||||
FloatReduceAcc d0_acc = d0_reduce_op.GetReduceZeroValue();
|
||||
FloatReduceAcc d1_acc = d1_reduce_op.GetReduceZeroValue();
|
||||
ThreadwiseReduce_D0::Reduce(c_reduce_thread_buf, d0_thread_buf);
|
||||
|
||||
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
|
||||
static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
|
||||
constexpr auto offset =
|
||||
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
|
||||
make_tuple(im, in))>{};
|
||||
|
||||
d0_reduce_op.Reduce(d0_acc, c_reduce_thread_buf[offset]);
|
||||
d1_reduce_op.Reduce(d1_acc, c_reduce_thread_buf[offset]);
|
||||
d1_element_op(c_reduce_thread_buf(offset), c_reduce_thread_buf(offset));
|
||||
});
|
||||
|
||||
constexpr index_t out_offset =
|
||||
d_reduce_thread_desc_mperblock.CalculateOffset(make_tuple(im));
|
||||
|
||||
d0_thread_buf(Number<out_offset>{}) = d0_acc;
|
||||
d1_thread_buf(Number<out_offset>{}) = d1_acc;
|
||||
});
|
||||
|
||||
ThreadwiseReduce_D1::Reduce(c_reduce_thread_buf, d1_thread_buf);
|
||||
|
||||
// copy from VGPR to Global
|
||||
d0_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock,
|
||||
make_tuple(I0, I0),
|
||||
|
||||
Reference in New Issue
Block a user