mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Update to gemm_reduce and batched_gemm_reduce (#213)
* [Experimental] Change to gemm+reduce and batched-gemm+reduce * Use threadwise-reduce function to improve the gridwise_gemm_reduce_xdl_cshuffle kernel * Tiny fix in device_batched_gemm_xdl.hpp * clang-format library/src/utility/conv_fwd_util.cpp
This commit is contained in:
@@ -21,8 +21,7 @@ template <typename GridwiseGemm,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename D0ReduceOperation,
|
||||
typename D1ReduceOperation,
|
||||
typename D1ElementwiseOperation,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -44,8 +43,7 @@ __global__ void
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const D0ReduceOperation d0_reduce_op,
|
||||
const D1ReduceOperation d1_reduce_op,
|
||||
const D1ElementwiseOperation d1_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
@@ -82,8 +80,7 @@ __global__ void
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d0_reduce_op,
|
||||
d1_reduce_op,
|
||||
d1_element_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
@@ -99,8 +96,7 @@ __global__ void
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = d0_reduce_op;
|
||||
ignore = d1_reduce_op;
|
||||
ignore = d1_element_op;
|
||||
ignore = a_grid_desc_ak0_m_ak1;
|
||||
ignore = b_grid_desc_bk0_n_bk1;
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
@@ -125,6 +121,7 @@ template <typename ALayout,
|
||||
typename CElementwiseOperation,
|
||||
typename D0ReduceOperation,
|
||||
typename D1ReduceOperation,
|
||||
typename D1ElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
@@ -161,8 +158,7 @@ template <typename ALayout,
|
||||
struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D0ReduceOperation,
|
||||
D1ReduceOperation>
|
||||
D1ElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle;
|
||||
|
||||
@@ -564,6 +560,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
CElementwiseOperation,
|
||||
D0ReduceOperation,
|
||||
D1ReduceOperation,
|
||||
D1ElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
@@ -624,8 +621,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
D0ReduceOperation d0_reduce_op,
|
||||
D1ReduceOperation d1_reduce_op,
|
||||
D1ElementwiseOperation d1_element_op,
|
||||
index_t BatchCount)
|
||||
: p_a_grid_{p_a_grid},
|
||||
p_b_grid_{p_b_grid},
|
||||
@@ -648,8 +644,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op},
|
||||
d0_reduce_op_{d0_reduce_op},
|
||||
d1_reduce_op_{d1_reduce_op}
|
||||
d1_element_op_{d1_element_op}
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(
|
||||
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
|
||||
@@ -684,8 +679,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
D0ReduceOperation d0_reduce_op_;
|
||||
D1ReduceOperation d1_reduce_op_;
|
||||
D1ElementwiseOperation d1_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -740,8 +734,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D0ReduceOperation,
|
||||
D1ReduceOperation,
|
||||
D1ElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -763,8 +756,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.d0_reduce_op_,
|
||||
arg.d1_reduce_op_,
|
||||
arg.d1_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
@@ -782,8 +774,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D0ReduceOperation,
|
||||
D1ReduceOperation,
|
||||
D1ElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -805,8 +796,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.d0_reduce_op_,
|
||||
arg.d1_reduce_op_,
|
||||
arg.d1_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
@@ -865,8 +855,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
D0ReduceOperation d0_reduce_op,
|
||||
D1ReduceOperation d1_reduce_op,
|
||||
D1ElementwiseOperation d1_element_op,
|
||||
index_t BatchCount)
|
||||
{
|
||||
return Argument{p_a,
|
||||
@@ -883,8 +872,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d0_reduce_op,
|
||||
d1_reduce_op,
|
||||
d1_element_op,
|
||||
BatchCount};
|
||||
}
|
||||
|
||||
@@ -905,8 +893,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
D0ReduceOperation d0_reduce_op,
|
||||
D1ReduceOperation d1_reduce_op,
|
||||
D1ElementwiseOperation d1_element_op,
|
||||
index_t BatchCount) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
@@ -923,8 +910,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d0_reduce_op,
|
||||
d1_reduce_op,
|
||||
d1_element_op,
|
||||
BatchCount);
|
||||
}
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ __global__ void
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = compute_base_ptr_of_batch_;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
ignore = block_2_ctile_map;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
}
|
||||
|
||||
@@ -9,8 +9,7 @@ namespace device {
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename D0ReduceOperation,
|
||||
typename D1ReduceOperation>
|
||||
typename D1ElementwiseOperation>
|
||||
struct DeviceGemmReduce : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
@@ -27,8 +26,7 @@ struct DeviceGemmReduce : public BaseOperator
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
D0ReduceOperation d0_reduce_op,
|
||||
D1ReduceOperation d1_reduce_op,
|
||||
D1ElementwiseOperation d1_element_op,
|
||||
ck::index_t BatchCount = 1) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
@@ -37,13 +35,11 @@ struct DeviceGemmReduce : public BaseOperator
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename D0ReduceOperation,
|
||||
typename D1ReduceOperation>
|
||||
typename D1ElementwiseOperation>
|
||||
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D0ReduceOperation,
|
||||
D1ReduceOperation>>;
|
||||
D1ElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -29,6 +29,7 @@ template <typename ALayout,
|
||||
typename CElementwiseOperation,
|
||||
typename D0ReduceOperation,
|
||||
typename D1ReduceOperation,
|
||||
typename D1ElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
@@ -65,8 +66,7 @@ template <typename ALayout,
|
||||
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D0ReduceOperation,
|
||||
D1ReduceOperation>
|
||||
D1ElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceGemmReduce_Xdl_CShuffle;
|
||||
|
||||
@@ -382,6 +382,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
CElementwiseOperation,
|
||||
D0ReduceOperation,
|
||||
D1ReduceOperation,
|
||||
D1ElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
@@ -440,8 +441,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
D0ReduceOperation d0_reduce_op,
|
||||
D1ReduceOperation d1_reduce_op)
|
||||
D1ElementwiseOperation d1_element_op)
|
||||
: p_a_grid_{p_a_grid},
|
||||
p_b_grid_{p_b_grid},
|
||||
p_c_grid_{p_c_grid},
|
||||
@@ -457,8 +457,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op},
|
||||
d0_reduce_op_{d0_reduce_op},
|
||||
d1_reduce_op_{d1_reduce_op}
|
||||
d1_element_op_{d1_element_op}
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(
|
||||
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
|
||||
@@ -491,8 +490,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
D0ReduceOperation d0_reduce_op_;
|
||||
D1ReduceOperation d1_reduce_op_;
|
||||
D1ElementwiseOperation d1_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -544,8 +542,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D0ReduceOperation,
|
||||
D1ReduceOperation,
|
||||
D1ElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -565,8 +562,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.d0_reduce_op_,
|
||||
arg.d1_reduce_op_,
|
||||
arg.d1_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
@@ -583,8 +579,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D0ReduceOperation,
|
||||
D1ReduceOperation,
|
||||
D1ElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -604,8 +599,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.d0_reduce_op_,
|
||||
arg.d1_reduce_op_,
|
||||
arg.d1_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
@@ -655,8 +649,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
D0ReduceOperation d0_reduce_op,
|
||||
D1ReduceOperation d1_reduce_op)
|
||||
D1ElementwiseOperation d1_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
@@ -672,8 +665,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d0_reduce_op,
|
||||
d1_reduce_op};
|
||||
d1_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
@@ -693,8 +685,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
D0ReduceOperation d0_reduce_op,
|
||||
D1ReduceOperation d1_reduce_op,
|
||||
D1ElementwiseOperation d1_element_op,
|
||||
index_t /* KBatch */ = 1) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
@@ -711,8 +702,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d0_reduce_op,
|
||||
d1_reduce_op);
|
||||
d1_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
|
||||
Reference in New Issue
Block a user