mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Gemm reduce max (#209)
* [What] Rename the example [Why] Prepare to add unary reduction * Add global oparation to the parameter * Add atomicmax * Fix compile error * Support atomicMax (hip library) * Rename the reduction example * Fix target name * use p_d1_grid as the indicator directly * Prevent performance issue. Let passthrough handle it. * Implement the function template the specialize the float2 * No need to separate into two lines * Remove empty line * add comment * Fix compile error due to merge from develop * make the implementation of atomic_max / atomic_add explicit for each datatype * Refine typo * For future CI test * Fix compiler error in ckProfiler * Merge commit 'de2769e3a6695b38a20529261273ddc5cdaab2fe' * simply use remove_pointer * Rename type and var * Refine example * Modify reducemax example * Fix bug in reduction * Change initialize range * Implement F64 version of atomicMax * Move reduction code together * Add buffer atomic_max * Fix coding style by clang-format * Integrate new api of DeviceGemmReduce_Xdl_CShuffle * Integrate Batch gemm reduction * Fix example * fix example * clean up * Fix batch gemm tensor operation * Fix coding style * Fix template augument * Fix clang format * Keep flexible of different stride for each D tensor * Fix compile error for ckProfiler * Fix typo * [What] Fix naming [Why] Prepare to add out elementop * Add DoutElementOp Co-authored-by: Chao Liu <chao.liu2@amd.com> Co-authored-by: rocking <chunylai@amd.com>
This commit is contained in:
@@ -17,11 +17,12 @@ namespace device {
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename FloatD,
|
||||
typename DPtrsGlobal,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename D1ElementwiseOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsOutElementwiseOperation,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -37,13 +38,13 @@ __global__ void
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
FloatD* __restrict__ p_d0_grid,
|
||||
FloatD* __restrict__ p_d1_grid,
|
||||
DPtrsGlobal p_ds_grid,
|
||||
const index_t batch_count,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const D1ElementwiseOperation d1_element_op,
|
||||
const DxsInElementwiseOperation dxs_in_element_op,
|
||||
const DxsOutElementwiseOperation dxs_out_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
@@ -64,23 +65,24 @@ __global__ void
|
||||
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
|
||||
|
||||
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
|
||||
const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetD1BasePtr(g_idx)));
|
||||
static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
|
||||
const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetDBasePtr(g_idx, In)));
|
||||
p_ds_grid(In) = p_ds_grid(In) + d_batch_offset;
|
||||
});
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_batch_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
p_c_grid + c_batch_offset,
|
||||
p_d0_grid + d0_batch_offset,
|
||||
p_d1_grid + d1_batch_offset,
|
||||
p_ds_grid,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d1_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
@@ -90,13 +92,13 @@ __global__ void
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_c_grid;
|
||||
ignore = p_d0_grid;
|
||||
ignore = p_d1_grid;
|
||||
ignore = p_ds_grid;
|
||||
ignore = batch_count;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = d1_element_op;
|
||||
ignore = dxs_in_element_op;
|
||||
ignore = dxs_out_element_op;
|
||||
ignore = a_grid_desc_ak0_m_ak1;
|
||||
ignore = b_grid_desc_bk0_n_bk1;
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
@@ -118,13 +120,14 @@ template <typename ALayout,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename ReduceAccDataType,
|
||||
typename DDataType,
|
||||
typename DPtrsGlobal,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename D0ReduceOperation,
|
||||
typename D1ReduceOperation,
|
||||
typename D1ElementwiseOperation,
|
||||
typename DxsReduceOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsOutElementwiseOperation,
|
||||
typename DGlobalMemoryDataOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
@@ -159,10 +162,12 @@ template <typename ALayout,
|
||||
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
|
||||
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
|
||||
struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D1ElementwiseOperation>
|
||||
DxsInElementwiseOperation,
|
||||
DxsOutElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle;
|
||||
|
||||
@@ -508,13 +513,11 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideC,
|
||||
index_t BatchStrideD0,
|
||||
index_t BatchStrideD1)
|
||||
index_t BatchStrideD)
|
||||
: BatchStrideA_(BatchStrideA),
|
||||
BatchStrideB_(BatchStrideB),
|
||||
BatchStrideC_(BatchStrideC),
|
||||
BatchStrideD0_(BatchStrideD0),
|
||||
BatchStrideD1_(BatchStrideD1)
|
||||
BatchStrideD_(BatchStrideD)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -533,22 +536,20 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideC_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx) const
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr long_index_t GetDBasePtr(index_t g_idx,
|
||||
Number<I> reduction_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideD0_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetD1BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideD1_);
|
||||
// TODO - Support sequence of StrideD in MakeArgument()
|
||||
(void)reduction_idx;
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideD_);
|
||||
}
|
||||
|
||||
private:
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB_;
|
||||
index_t BatchStrideC_;
|
||||
index_t BatchStrideD0_;
|
||||
index_t BatchStrideD1_;
|
||||
index_t BatchStrideD_;
|
||||
};
|
||||
|
||||
// GridwiseGemm
|
||||
@@ -558,15 +559,15 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
ReduceAccDataType,
|
||||
DDataType,
|
||||
DPtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D0ReduceOperation,
|
||||
D1ReduceOperation,
|
||||
D1ElementwiseOperation,
|
||||
DxsReduceOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsOutElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
DGlobalMemoryDataOperation,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
CGridDesc_M_N,
|
||||
@@ -615,8 +616,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
Argument(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
DDataType* p_d0_grid,
|
||||
DDataType* p_d1_grid,
|
||||
DPtrsGlobal p_ds_grid,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
@@ -626,13 +626,13 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
D1ElementwiseOperation d1_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsOutElementwiseOperation dxs_out_element_op,
|
||||
index_t BatchCount)
|
||||
: p_a_grid_{p_a_grid},
|
||||
p_b_grid_{p_b_grid},
|
||||
p_c_grid_{p_c_grid},
|
||||
p_d0_grid_{p_d0_grid},
|
||||
p_d1_grid_{p_d1_grid},
|
||||
p_ds_grid_{p_ds_grid},
|
||||
BatchCount_(BatchCount),
|
||||
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
|
||||
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
|
||||
@@ -644,13 +644,13 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize())},
|
||||
block_2_ctile_map_{},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op},
|
||||
d1_element_op_{d1_element_op}
|
||||
dxs_in_element_op_{dxs_in_element_op},
|
||||
dxs_out_element_op_{dxs_out_element_op}
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(
|
||||
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
|
||||
@@ -670,8 +670,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
DDataType* p_d0_grid_;
|
||||
DDataType* p_d1_grid_;
|
||||
DPtrsGlobal p_ds_grid_;
|
||||
index_t BatchCount_;
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
@@ -685,7 +684,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
D1ElementwiseOperation d1_element_op_;
|
||||
DxsInElementwiseOperation dxs_in_element_op_;
|
||||
DxsOutElementwiseOperation dxs_out_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -736,11 +736,12 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
DDataType,
|
||||
DPtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D1ElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsOutElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -758,13 +759,13 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_d0_grid_,
|
||||
arg.p_d1_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.BatchCount_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.d1_element_op_,
|
||||
arg.dxs_in_element_op_,
|
||||
arg.dxs_out_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
@@ -778,11 +779,12 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
DDataType,
|
||||
DPtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D1ElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsOutElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -800,13 +802,13 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_d0_grid_,
|
||||
arg.p_d1_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.BatchCount_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.d1_element_op_,
|
||||
arg.dxs_in_element_op_,
|
||||
arg.dxs_out_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
@@ -855,8 +857,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
DDataType* p_d0,
|
||||
DDataType* p_d1,
|
||||
DPtrsGlobal p_dxs,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
@@ -866,14 +867,14 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
D1ElementwiseOperation d1_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsOutElementwiseOperation dxs_out_element_op,
|
||||
index_t BatchCount)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
p_d0,
|
||||
p_d1,
|
||||
p_dxs,
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
@@ -883,7 +884,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d1_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op,
|
||||
BatchCount};
|
||||
}
|
||||
|
||||
@@ -893,8 +895,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
void* p_d0,
|
||||
void* p_d1,
|
||||
DPtrsGlobal p_dxs,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
@@ -904,14 +905,14 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
D1ElementwiseOperation d1_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsOutElementwiseOperation dxs_out_element_op,
|
||||
index_t BatchCount) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
static_cast<DDataType*>(p_d0),
|
||||
static_cast<DDataType*>(p_d1),
|
||||
p_dxs,
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
@@ -921,7 +922,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d1_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op,
|
||||
BatchCount);
|
||||
}
|
||||
|
||||
|
||||
@@ -6,40 +6,47 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
template <typename DPtrsGlobal,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename D1ElementwiseOperation>
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsOutElementwiseOperation>
|
||||
struct DeviceGemmReduce : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
void* p_d0,
|
||||
void* p_d1,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
D1ElementwiseOperation d1_element_op,
|
||||
ck::index_t BatchCount = 1) = 0;
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
DPtrsGlobal p_dxs,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsOutElementwiseOperation dxs_out_element_op,
|
||||
ck::index_t BatchCount = 1) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
template <typename DPtrsGlobal,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename D1ElementwiseOperation>
|
||||
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<AElementwiseOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsOutElementwiseOperation>
|
||||
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<DPtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D1ElementwiseOperation>>;
|
||||
DxsInElementwiseOperation,
|
||||
DxsOutElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -26,13 +26,14 @@ template <typename ALayout,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename ReduceAccDataType,
|
||||
typename DDataType,
|
||||
typename DPtrsGlobal,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename D0ReduceOperation,
|
||||
typename D1ReduceOperation,
|
||||
typename D1ElementwiseOperation,
|
||||
typename DxsReduceOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsOutElementwiseOperation,
|
||||
typename DGlobalMemoryDataOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
@@ -67,10 +68,12 @@ template <typename ALayout,
|
||||
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
|
||||
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
|
||||
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D1ElementwiseOperation>
|
||||
DxsInElementwiseOperation,
|
||||
DxsOutElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceGemmReduce_Xdl_CShuffle;
|
||||
|
||||
@@ -380,15 +383,15 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
ReduceAccDataType,
|
||||
DDataType,
|
||||
DPtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D0ReduceOperation,
|
||||
D1ReduceOperation,
|
||||
D1ElementwiseOperation,
|
||||
DxsReduceOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsOutElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
DGlobalMemoryDataOperation,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
CGridDesc_M_N,
|
||||
@@ -435,8 +438,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
Argument(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
DDataType* p_d0_grid,
|
||||
DDataType* p_d1_grid,
|
||||
DPtrsGlobal p_ds_grid,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
@@ -446,12 +448,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
D1ElementwiseOperation d1_element_op)
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsOutElementwiseOperation dxs_out_element_op)
|
||||
: p_a_grid_{p_a_grid},
|
||||
p_b_grid_{p_b_grid},
|
||||
p_c_grid_{p_c_grid},
|
||||
p_d0_grid_{p_d0_grid},
|
||||
p_d1_grid_{p_d1_grid},
|
||||
p_ds_grid_{p_ds_grid},
|
||||
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
|
||||
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
|
||||
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
|
||||
@@ -462,7 +464,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op},
|
||||
d1_element_op_{d1_element_op}
|
||||
dxs_in_element_op_{dxs_in_element_op},
|
||||
dxs_out_element_op_{dxs_out_element_op}
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(
|
||||
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
|
||||
@@ -482,8 +485,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
DDataType* p_d0_grid_;
|
||||
DDataType* p_d1_grid_;
|
||||
DPtrsGlobal p_ds_grid_;
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
@@ -495,7 +497,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
D1ElementwiseOperation d1_element_op_;
|
||||
DxsInElementwiseOperation dxs_in_element_op_;
|
||||
DxsOutElementwiseOperation dxs_out_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -543,11 +546,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
DDataType,
|
||||
DPtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D1ElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsOutElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -564,12 +568,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_d0_grid_,
|
||||
arg.p_d1_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.d1_element_op_,
|
||||
arg.dxs_in_element_op_,
|
||||
arg.dxs_out_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
@@ -582,11 +586,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
DDataType,
|
||||
DPtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D1ElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsOutElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -603,12 +608,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_d0_grid_,
|
||||
arg.p_d1_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.d1_element_op_,
|
||||
arg.dxs_in_element_op_,
|
||||
arg.dxs_out_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
@@ -648,8 +653,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
DDataType* p_d0,
|
||||
DDataType* p_d1,
|
||||
DPtrsGlobal p_dxs,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
@@ -659,13 +663,13 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
D1ElementwiseOperation d1_element_op)
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsOutElementwiseOperation dxs_out_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
p_d0,
|
||||
p_d1,
|
||||
p_dxs,
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
@@ -675,7 +679,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d1_element_op};
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
@@ -684,8 +689,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
void* p_d0,
|
||||
void* p_d1,
|
||||
DPtrsGlobal p_dxs,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
@@ -695,14 +699,14 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
D1ElementwiseOperation d1_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsOutElementwiseOperation dxs_out_element_op,
|
||||
index_t /* KBatch */ = 1) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
static_cast<DDataType*>(p_d0),
|
||||
static_cast<DDataType*>(p_d1),
|
||||
p_dxs,
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
@@ -712,7 +716,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d1_element_op);
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
|
||||
Reference in New Issue
Block a user