Update to gemm_reduce and batched_gemm_reduce (#213)

* [Experimental] Change to gemm+reduce and batched-gemm+reduce

* Use threadwise-reduce function to improve the gridwise_gemm_reduce_xdl_cshuffle kernel

* Tiny fix in device_batched_gemm_xdl.hpp

* clang-format library/src/utility/conv_fwd_util.cpp

[ROCm/composable_kernel commit: c77ae65d40]
This commit is contained in:
Qianfeng
2022-04-30 00:35:25 +08:00
committed by GitHub
parent 8ba53a00c8
commit 0b220df1a6
18 changed files with 350 additions and 375 deletions

View File

@@ -21,8 +21,7 @@ template <typename GridwiseGemm,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation,
typename D1ElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -44,8 +43,7 @@ __global__ void
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const D0ReduceOperation d0_reduce_op,
const D1ReduceOperation d1_reduce_op,
const D1ElementwiseOperation d1_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -82,8 +80,7 @@ __global__ void
a_element_op,
b_element_op,
c_element_op,
d0_reduce_op,
d1_reduce_op,
d1_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -99,8 +96,7 @@ __global__ void
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = d0_reduce_op;
ignore = d1_reduce_op;
ignore = d1_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -125,6 +121,7 @@ template <typename ALayout,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation,
typename D1ElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
@@ -161,8 +158,7 @@ template <typename ALayout,
struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation>
D1ElementwiseOperation>
{
using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle;
@@ -564,6 +560,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation,
D1ElementwiseOperation,
InMemoryDataOperationEnum::Set,
InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_AK0_M_AK1,
@@ -624,8 +621,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D0ReduceOperation d0_reduce_op,
D1ReduceOperation d1_reduce_op,
D1ElementwiseOperation d1_element_op,
index_t BatchCount)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
@@ -648,8 +644,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
d0_reduce_op_{d0_reduce_op},
d1_reduce_op_{d1_reduce_op}
d1_element_op_{d1_element_op}
{
if(GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
@@ -684,8 +679,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
D0ReduceOperation d0_reduce_op_;
D1ReduceOperation d1_reduce_op_;
D1ElementwiseOperation d1_element_op_;
};
// Invoker
@@ -740,8 +734,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation,
D1ElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -763,8 +756,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d0_reduce_op_,
arg.d1_reduce_op_,
arg.d1_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
@@ -782,8 +774,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation,
D1ElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -805,8 +796,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d0_reduce_op_,
arg.d1_reduce_op_,
arg.d1_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
@@ -865,8 +855,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D0ReduceOperation d0_reduce_op,
D1ReduceOperation d1_reduce_op,
D1ElementwiseOperation d1_element_op,
index_t BatchCount)
{
return Argument{p_a,
@@ -883,8 +872,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
a_element_op,
b_element_op,
c_element_op,
d0_reduce_op,
d1_reduce_op,
d1_element_op,
BatchCount};
}
@@ -905,8 +893,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D0ReduceOperation d0_reduce_op,
D1ReduceOperation d1_reduce_op,
D1ElementwiseOperation d1_element_op,
index_t BatchCount) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
@@ -923,8 +910,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
a_element_op,
b_element_op,
c_element_op,
d0_reduce_op,
d1_reduce_op,
d1_element_op,
BatchCount);
}

View File

@@ -107,7 +107,7 @@ __global__ void
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = compute_base_ptr_of_batch_;
ignore = compute_ptr_offset_of_batch;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}

View File

@@ -9,8 +9,7 @@ namespace device {
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation>
typename D1ElementwiseOperation>
struct DeviceGemmReduce : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
@@ -27,8 +26,7 @@ struct DeviceGemmReduce : public BaseOperator
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D0ReduceOperation d0_reduce_op,
D1ReduceOperation d1_reduce_op,
D1ElementwiseOperation d1_element_op,
ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
@@ -37,13 +35,11 @@ struct DeviceGemmReduce : public BaseOperator
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation>
typename D1ElementwiseOperation>
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation>>;
D1ElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation

View File

@@ -29,6 +29,7 @@ template <typename ALayout,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation,
typename D1ElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
@@ -65,8 +66,7 @@ template <typename ALayout,
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation>
D1ElementwiseOperation>
{
using DeviceOp = DeviceGemmReduce_Xdl_CShuffle;
@@ -382,6 +382,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation,
D1ElementwiseOperation,
InMemoryDataOperationEnum::Set,
InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_AK0_M_AK1,
@@ -440,8 +441,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D0ReduceOperation d0_reduce_op,
D1ReduceOperation d1_reduce_op)
D1ElementwiseOperation d1_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
@@ -457,8 +457,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
d0_reduce_op_{d0_reduce_op},
d1_reduce_op_{d1_reduce_op}
d1_element_op_{d1_element_op}
{
if(GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
@@ -491,8 +490,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
D0ReduceOperation d0_reduce_op_;
D1ReduceOperation d1_reduce_op_;
D1ElementwiseOperation d1_element_op_;
};
// Invoker
@@ -544,8 +542,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation,
D1ElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -565,8 +562,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d0_reduce_op_,
arg.d1_reduce_op_,
arg.d1_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
@@ -583,8 +579,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation,
D1ElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -604,8 +599,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d0_reduce_op_,
arg.d1_reduce_op_,
arg.d1_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
@@ -655,8 +649,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D0ReduceOperation d0_reduce_op,
D1ReduceOperation d1_reduce_op)
D1ElementwiseOperation d1_element_op)
{
return Argument{p_a,
p_b,
@@ -672,8 +665,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
a_element_op,
b_element_op,
c_element_op,
d0_reduce_op,
d1_reduce_op};
d1_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
@@ -693,8 +685,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D0ReduceOperation d0_reduce_op,
D1ReduceOperation d1_reduce_op,
D1ElementwiseOperation d1_element_op,
index_t /* KBatch */ = 1) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
@@ -711,8 +702,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
a_element_op,
b_element_op,
c_element_op,
d0_reduce_op,
d1_reduce_op);
d1_element_op);
}
// polymorphic

View File

@@ -5,20 +5,6 @@ namespace ck {
namespace tensor_operation {
namespace element_wise {
struct ReduceSum
{
__host__ __device__ static constexpr float GetReduceZeroValue() { return float(0); }
__host__ __device__ void Reduce(float& acc, float v) const { acc += v; }
};
struct ReduceSquareSum
{
__host__ __device__ static constexpr float GetReduceZeroValue() { return float(0); }
__host__ __device__ void Reduce(float& acc, float v) const { acc += v * v; }
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck

View File

@@ -8,6 +8,7 @@
#include "blockwise_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
#include "reduction_functions_threadwise.hpp"
namespace ck {
@@ -18,8 +19,7 @@ template <typename GridwiseGemm,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation,
typename D1ElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -39,8 +39,7 @@ __global__ void
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const D0ReduceOperation d0_reduce_op,
const D1ReduceOperation d1_reduce_op,
const D1ElementwiseOperation d1_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -60,8 +59,7 @@ __global__ void
a_element_op,
b_element_op,
c_element_op,
d0_reduce_op,
d1_reduce_op,
d1_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -76,8 +74,7 @@ __global__ void
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = d0_reduce_op;
ignore = d1_reduce_op;
ignore = d1_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -97,6 +94,7 @@ template <typename FloatAB,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation,
typename D1ElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
InMemoryDataOperationEnum DGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
@@ -372,8 +370,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const D0ReduceOperation& d0_reduce_op,
const D1ReduceOperation& d1_reduce_op,
const D1ElementwiseOperation& d1_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
@@ -741,13 +738,13 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));
// TODO: this should be implemented as a blockwise reduction
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
auto d1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
auto d1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
// reduce: threadwise copy from LDS to VGPR
@@ -763,7 +760,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
FloatCShuffle,
FloatCShuffle,
FloatReduceAcc,
decltype(c_reduce_block_desc_mperblock_nperblock),
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(c_reduce_thread_lengths_mperblock_nperblock),
@@ -775,7 +772,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// reduce: copy from VGPR to global
auto d0_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
FloatCShuffle,
FloatReduceAcc,
FloatD,
decltype(d_reduce_thread_desc_mblock_mperblock),
decltype(d_grid_desc_mblock_mperblock),
@@ -840,6 +837,28 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
using ThreadwiseReduce_D0 =
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
D0ReduceOperation,
false>;
using ThreadwiseReduce_D1 =
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
D1ReduceOperation,
false>;
const auto d0_zeroVal = D0ReduceOperation::GetReductionZeroVal();
const auto d1_zeroVal = D0ReduceOperation::GetReductionZeroVal();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d0_thread_buf(I) = d0_zeroVal; });
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d1_thread_buf(I) = d1_zeroVal; });
// reduce
{
// copy from LDS to VGPR
@@ -850,26 +869,20 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_reduce_thread_buf);
// reduce in VGPR
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
FloatReduceAcc d0_acc = d0_reduce_op.GetReduceZeroValue();
FloatReduceAcc d1_acc = d1_reduce_op.GetReduceZeroValue();
ThreadwiseReduce_D0::Reduce(c_reduce_thread_buf, d0_thread_buf);
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
constexpr auto offset =
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
make_tuple(im, in))>{};
d0_reduce_op.Reduce(d0_acc, c_reduce_thread_buf[offset]);
d1_reduce_op.Reduce(d1_acc, c_reduce_thread_buf[offset]);
d1_element_op(c_reduce_thread_buf(offset), c_reduce_thread_buf(offset));
});
constexpr index_t out_offset =
d_reduce_thread_desc_mperblock.CalculateOffset(make_tuple(im));
d0_thread_buf(Number<out_offset>{}) = d0_acc;
d1_thread_buf(Number<out_offset>{}) = d1_acc;
});
ThreadwiseReduce_D1::Reduce(c_reduce_thread_buf, d1_thread_buf);
// copy from VGPR to Global
d0_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock,
make_tuple(I0, I0),