Gemm reduce max (#209)

* [What] Rename the example
[Why] Prepare to add unary reduction

* Add global oparation to the parameter

* Add atomicmax

* Fix compile error

* Support atomicMax (hip library)

* Rename the reduction example

* Fix target name

* use p_d1_grid as the indicator directly

* Prevent performance issue. Let passthrough handle it.

* Implement the function template the specialize the float2

* No need to separate into two lines

* Remove empty line

* add comment

* Fix compile error due to merge from develop

* make the implementation of atomic_max / atomic_add explicit for each datatype

* Refine typo

* For future CI test

* Fix compiler error in ckProfiler

* Merge commit 'de2769e3a6695b38a20529261273ddc5cdaab2fe'

* simply use remove_pointer

* Rename type and var

* Refine example

* Modify reducemax example

* Fix bug in reduction

* Change initialize range

* Implement F64 version of atomicMax

* Move reduction  code together

* Add buffer atomic_max

* Fix coding style by clang-format

* Integrate new api of DeviceGemmReduce_Xdl_CShuffle

* Integrate Batch gemm reduction

* Fix example

* fix example

* clean up

* Fix batch gemm tensor operation

* Fix coding style

* Fix template augument

* Fix clang format

* Keep flexible of different stride for each D tensor

* Fix compile error for ckProfiler

* Fix typo

* [What] Fix naming
[Why] Prepare to add out elementop

* Add DoutElementOp

Co-authored-by: Chao Liu <chao.liu2@amd.com>
Co-authored-by: rocking <chunylai@amd.com>
This commit is contained in:
rocking5566
2022-05-20 10:56:56 +08:00
committed by GitHub
parent aafc3ac27a
commit 0ffe956ab1
28 changed files with 1298 additions and 626 deletions

View File

@@ -17,11 +17,12 @@ namespace device {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename FloatD,
typename DPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D1ElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsOutElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -37,13 +38,13 @@ __global__ void
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatD* __restrict__ p_d0_grid,
FloatD* __restrict__ p_d1_grid,
DPtrsGlobal p_ds_grid,
const index_t batch_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const D1ElementwiseOperation d1_element_op,
const DxsInElementwiseOperation dxs_in_element_op,
const DxsOutElementwiseOperation dxs_out_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -64,23 +65,24 @@ __global__ void
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetD1BasePtr(g_idx)));
static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetDBasePtr(g_idx, In)));
p_ds_grid(In) = p_ds_grid(In) + d_batch_offset;
});
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_d0_grid + d0_batch_offset,
p_d1_grid + d1_batch_offset,
p_ds_grid,
p_shared,
a_element_op,
b_element_op,
c_element_op,
d1_element_op,
dxs_in_element_op,
dxs_out_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -90,13 +92,13 @@ __global__ void
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = p_d0_grid;
ignore = p_d1_grid;
ignore = p_ds_grid;
ignore = batch_count;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = d1_element_op;
ignore = dxs_in_element_op;
ignore = dxs_out_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -118,13 +120,14 @@ template <typename ALayout,
typename GemmAccDataType,
typename CShuffleDataType,
typename ReduceAccDataType,
typename DDataType,
typename DPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation,
typename D1ElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
typename DxsOutElementwiseOperation,
typename DGlobalMemoryDataOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
@@ -159,10 +162,12 @@ template <typename ALayout,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D1ElementwiseOperation>
DxsInElementwiseOperation,
DxsOutElementwiseOperation>
{
using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle;
@@ -508,13 +513,11 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
index_t BatchStrideD0,
index_t BatchStrideD1)
index_t BatchStrideD)
: BatchStrideA_(BatchStrideA),
BatchStrideB_(BatchStrideB),
BatchStrideC_(BatchStrideC),
BatchStrideD0_(BatchStrideD0),
BatchStrideD1_(BatchStrideD1)
BatchStrideD_(BatchStrideD)
{
}
@@ -533,22 +536,20 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
return g_idx * static_cast<long_index_t>(BatchStrideC_);
}
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx) const
template <index_t I>
__host__ __device__ constexpr long_index_t GetDBasePtr(index_t g_idx,
Number<I> reduction_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideD0_);
}
__host__ __device__ constexpr long_index_t GetD1BasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideD1_);
// TODO - Support sequence of StrideD in MakeArgument()
(void)reduction_idx;
return g_idx * static_cast<long_index_t>(BatchStrideD_);
}
private:
index_t BatchStrideA_;
index_t BatchStrideB_;
index_t BatchStrideC_;
index_t BatchStrideD0_;
index_t BatchStrideD1_;
index_t BatchStrideD_;
};
// GridwiseGemm
@@ -558,15 +559,15 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
CShuffleDataType,
CDataType,
ReduceAccDataType,
DDataType,
DPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation,
D1ElementwiseOperation,
DxsReduceOperation,
DxsInElementwiseOperation,
DxsOutElementwiseOperation,
InMemoryDataOperationEnum::Set,
InMemoryDataOperationEnum::AtomicAdd,
DGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
@@ -615,8 +616,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
DDataType* p_d0_grid,
DDataType* p_d1_grid,
DPtrsGlobal p_ds_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
@@ -626,13 +626,13 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D1ElementwiseOperation d1_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsOutElementwiseOperation dxs_out_element_op,
index_t BatchCount)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
p_d0_grid_{p_d0_grid},
p_d1_grid_{p_d1_grid},
p_ds_grid_{p_ds_grid},
BatchCount_(BatchCount),
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
@@ -644,13 +644,13 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize()),
type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize()),
type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize())},
block_2_ctile_map_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
d1_element_op_{d1_element_op}
dxs_in_element_op_{dxs_in_element_op},
dxs_out_element_op_{dxs_out_element_op}
{
if(GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
@@ -670,8 +670,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
DDataType* p_d0_grid_;
DDataType* p_d1_grid_;
DPtrsGlobal p_ds_grid_;
index_t BatchCount_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
@@ -685,7 +684,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
D1ElementwiseOperation d1_element_op_;
DxsInElementwiseOperation dxs_in_element_op_;
DxsOutElementwiseOperation dxs_out_element_op_;
};
// Invoker
@@ -736,11 +736,12 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
DDataType,
DPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D1ElementwiseOperation,
DxsInElementwiseOperation,
DxsOutElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -758,13 +759,13 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_d0_grid_,
arg.p_d1_grid_,
arg.p_ds_grid_,
arg.BatchCount_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d1_element_op_,
arg.dxs_in_element_op_,
arg.dxs_out_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
@@ -778,11 +779,12 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
DDataType,
DPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D1ElementwiseOperation,
DxsInElementwiseOperation,
DxsOutElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -800,13 +802,13 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_d0_grid_,
arg.p_d1_grid_,
arg.p_ds_grid_,
arg.BatchCount_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d1_element_op_,
arg.dxs_in_element_op_,
arg.dxs_out_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
@@ -855,8 +857,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
DDataType* p_d0,
DDataType* p_d1,
DPtrsGlobal p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
@@ -866,14 +867,14 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D1ElementwiseOperation d1_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsOutElementwiseOperation dxs_out_element_op,
index_t BatchCount)
{
return Argument{p_a,
p_b,
p_c,
p_d0,
p_d1,
p_dxs,
MRaw,
NRaw,
KRaw,
@@ -883,7 +884,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
a_element_op,
b_element_op,
c_element_op,
d1_element_op,
dxs_in_element_op,
dxs_out_element_op,
BatchCount};
}
@@ -893,8 +895,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
void* p_d0,
void* p_d1,
DPtrsGlobal p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
@@ -904,14 +905,14 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D1ElementwiseOperation d1_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsOutElementwiseOperation dxs_out_element_op,
index_t BatchCount) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
static_cast<DDataType*>(p_d0),
static_cast<DDataType*>(p_d1),
p_dxs,
MRaw,
NRaw,
KRaw,
@@ -921,7 +922,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
a_element_op,
b_element_op,
c_element_op,
d1_element_op,
dxs_in_element_op,
dxs_out_element_op,
BatchCount);
}

View File

@@ -6,40 +6,47 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <typename AElementwiseOperation,
template <typename DPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D1ElementwiseOperation>
typename DxsInElementwiseOperation,
typename DxsOutElementwiseOperation>
struct DeviceGemmReduce : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
void* p_d0,
void* p_d1,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D1ElementwiseOperation d1_element_op,
ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
DPtrsGlobal p_dxs,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsOutElementwiseOperation dxs_out_element_op,
ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
template <typename DPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D1ElementwiseOperation>
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<AElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsOutElementwiseOperation>
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<DPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D1ElementwiseOperation>>;
DxsInElementwiseOperation,
DxsOutElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation

View File

@@ -26,13 +26,14 @@ template <typename ALayout,
typename GemmAccDataType,
typename CShuffleDataType,
typename ReduceAccDataType,
typename DDataType,
typename DPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation,
typename D1ElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
typename DxsOutElementwiseOperation,
typename DGlobalMemoryDataOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
@@ -67,10 +68,12 @@ template <typename ALayout,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D1ElementwiseOperation>
DxsInElementwiseOperation,
DxsOutElementwiseOperation>
{
using DeviceOp = DeviceGemmReduce_Xdl_CShuffle;
@@ -380,15 +383,15 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
CShuffleDataType,
CDataType,
ReduceAccDataType,
DDataType,
DPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D0ReduceOperation,
D1ReduceOperation,
D1ElementwiseOperation,
DxsReduceOperation,
DxsInElementwiseOperation,
DxsOutElementwiseOperation,
InMemoryDataOperationEnum::Set,
InMemoryDataOperationEnum::AtomicAdd,
DGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
@@ -435,8 +438,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
DDataType* p_d0_grid,
DDataType* p_d1_grid,
DPtrsGlobal p_ds_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
@@ -446,12 +448,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D1ElementwiseOperation d1_element_op)
DxsInElementwiseOperation dxs_in_element_op,
DxsOutElementwiseOperation dxs_out_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
p_d0_grid_{p_d0_grid},
p_d1_grid_{p_d1_grid},
p_ds_grid_{p_ds_grid},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
@@ -462,7 +464,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
d1_element_op_{d1_element_op}
dxs_in_element_op_{dxs_in_element_op},
dxs_out_element_op_{dxs_out_element_op}
{
if(GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
@@ -482,8 +485,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
DDataType* p_d0_grid_;
DDataType* p_d1_grid_;
DPtrsGlobal p_ds_grid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
@@ -495,7 +497,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
D1ElementwiseOperation d1_element_op_;
DxsInElementwiseOperation dxs_in_element_op_;
DxsOutElementwiseOperation dxs_out_element_op_;
};
// Invoker
@@ -543,11 +546,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
DDataType,
DPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D1ElementwiseOperation,
DxsInElementwiseOperation,
DxsOutElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -564,12 +568,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_d0_grid_,
arg.p_d1_grid_,
arg.p_ds_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d1_element_op_,
arg.dxs_in_element_op_,
arg.dxs_out_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
@@ -582,11 +586,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
DDataType,
DPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D1ElementwiseOperation,
DxsInElementwiseOperation,
DxsOutElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -603,12 +608,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_d0_grid_,
arg.p_d1_grid_,
arg.p_ds_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.d1_element_op_,
arg.dxs_in_element_op_,
arg.dxs_out_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
@@ -648,8 +653,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
DDataType* p_d0,
DDataType* p_d1,
DPtrsGlobal p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
@@ -659,13 +663,13 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D1ElementwiseOperation d1_element_op)
DxsInElementwiseOperation dxs_in_element_op,
DxsOutElementwiseOperation dxs_out_element_op)
{
return Argument{p_a,
p_b,
p_c,
p_d0,
p_d1,
p_dxs,
MRaw,
NRaw,
KRaw,
@@ -675,7 +679,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
a_element_op,
b_element_op,
c_element_op,
d1_element_op};
dxs_in_element_op,
dxs_out_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
@@ -684,8 +689,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
void* p_d0,
void* p_d1,
DPtrsGlobal p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
@@ -695,14 +699,14 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D1ElementwiseOperation d1_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsOutElementwiseOperation dxs_out_element_op,
index_t /* KBatch */ = 1) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
static_cast<DDataType*>(p_d0),
static_cast<DDataType*>(p_d1),
p_dxs,
MRaw,
NRaw,
KRaw,
@@ -712,7 +716,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
a_element_op,
b_element_op,
c_element_op,
d1_element_op);
dxs_in_element_op,
dxs_out_element_op);
}
// polymorphic

View File

@@ -15,11 +15,12 @@ namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename FloatD,
typename DPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D1ElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsOutElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -34,12 +35,12 @@ __global__ void
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatD* __restrict__ p_d0_grid,
FloatD* __restrict__ p_d1_grid,
DPtrsGlobal p_ds_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const D1ElementwiseOperation d1_element_op,
const DxsInElementwiseOperation dxs_in_element_op,
const DxsOutElementwiseOperation dxs_out_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -53,13 +54,13 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_d0_grid,
p_d1_grid,
p_ds_grid,
p_shared,
a_element_op,
b_element_op,
c_element_op,
d1_element_op,
dxs_in_element_op,
dxs_out_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -69,12 +70,12 @@ __global__ void
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = p_d0_grid;
ignore = p_d1_grid;
ignore = p_ds_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = d1_element_op;
ignore = dxs_in_element_op;
ignore = dxs_out_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -88,15 +89,15 @@ template <typename FloatAB,
typename FloatCShuffle,
typename FloatC,
typename FloatReduceAcc,
typename FloatD,
typename DPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D0ReduceOperation,
typename D1ReduceOperation,
typename D1ElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
typename DxsOutElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
InMemoryDataOperationEnum DGlobalMemoryDataOperation,
typename DGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
@@ -357,13 +358,13 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatD* __restrict__ p_d0_grid,
FloatD* __restrict__ p_d1_grid,
DPtrsGlobal p_ds_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const D1ElementwiseOperation& d1_element_op,
const DxsInElementwiseOperation& dxs_in_element_op,
const DxsOutElementwiseOperation& dxs_out_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
@@ -377,10 +378,6 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
auto d1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d1_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_work_idx =
@@ -527,7 +524,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_thread_buf,
num_k_block_main_loop);
// shuffle C and write out
// shuffle C + reduction + write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
@@ -666,6 +663,29 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
// TODO: this should be implemented as a blockwise reduction
// LDS c_reduce_block_desc_mperblock_nperblock
constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
@@ -716,16 +736,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto d_reduce_thread_desc_mblock_mperblock =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));
// TODO: this should be implemented as a blockwise reduction
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
auto d1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
// reduce: threadwise copy from LDS to VGPR
constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor(
CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{});
@@ -749,47 +762,29 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
1,
true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};
// reduce: copy from VGPR to global
auto d0_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
FloatReduceAcc,
FloatD,
decltype(d_reduce_thread_desc_mblock_mperblock),
decltype(d_grid_desc_mblock_mperblock),
ck::tensor_operation::element_wise::PassThrough,
Sequence<1, mreduce_per_thread>,
Sequence<0, 1>,
1,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
DGlobalMemoryDataOperation,
1,
false>{d_grid_desc_mblock_mperblock,
make_multi_index(block_work_idx[I0], // mblock
c_reduce_thread_data_idx_begin[I0]), // mperblock
ck::tensor_operation::element_wise::PassThrough{}};
auto dxs_reduce_thread_copy_vgpr_to_global = generate_tuple(
[&](auto I) {
auto p_d_grid = p_ds_grid[I];
auto d_out_element_op = dxs_out_element_op[I];
auto d1_reduce_thread_copy_vgpr_to_global = d0_reduce_thread_copy_vgpr_to_global;
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
return ThreadwiseTensorSliceTransfer_v1r3<
FloatReduceAcc,
remove_pointer_t<decltype(p_d_grid)>,
decltype(d_reduce_thread_desc_mblock_mperblock),
decltype(d_grid_desc_mblock_mperblock),
decltype(d_out_element_op),
Sequence<1, mreduce_per_thread>,
Sequence<0, 1>,
1,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
DGlobalMemoryDataOperation::At(I),
1,
false>{d_grid_desc_mblock_mperblock,
make_multi_index(block_work_idx[I0], // mblock
c_reduce_thread_data_idx_begin[I0]), // mperblock
d_out_element_op};
},
Number<p_ds_grid.Size()>{});
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
@@ -816,64 +811,73 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
using ThreadwiseReduce_D0 =
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
D0ReduceOperation,
false>;
using ThreadwiseReduce_D1 =
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
D1ReduceOperation,
false>;
const auto d0_zeroVal = D0ReduceOperation::GetReductionZeroVal();
const auto d1_zeroVal = D0ReduceOperation::GetReductionZeroVal();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d0_thread_buf(I) = d0_zeroVal; });
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d1_thread_buf(I) = d1_zeroVal; });
// reduce
// TODO - extract following into reduction_blockwise
{
// copy from LDS to VGPR
c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
c_shuffle_block_buf,
c_reduce_thread_desc_mperblock_nperblock,
make_tuple(I0, I0),
c_reduce_thread_buf);
// reduce in VGPR
ThreadwiseReduce_D0::Reduce(c_reduce_thread_buf, d0_thread_buf);
static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
auto& p_d_grid = p_ds_grid[In];
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
constexpr auto offset =
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
make_tuple(im, in))>{};
auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
d1_element_op(c_reduce_thread_buf(offset), c_reduce_thread_buf(offset));
auto d_thread_buf =
make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
auto& d_in_element_op = dxs_in_element_op[In];
auto& d_reduce_thread_copy_vgpr_to_global =
dxs_reduce_thread_copy_vgpr_to_global(In);
using DReduceOperation = remove_cvref_t<decltype(DxsReduceOperation{}[In])>;
using ThreadwiseReduce =
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
DReduceOperation,
false>;
// Global write Gemm shuffle + reduction
const auto d_zeroVal = DReduceOperation::GetReductionZeroVal();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d_thread_buf(I) = d_zeroVal; });
// reduce in VGPR
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
constexpr auto offset =
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
make_tuple(im, in))>{};
d_in_element_op(c_reduce_thread_buf(offset),
c_reduce_thread_buf(offset));
});
});
ThreadwiseReduce::Reduce(c_reduce_thread_buf, d_thread_buf);
// copy from VGPR to Global
d_reduce_thread_copy_vgpr_to_global.Run(
d_reduce_thread_desc_mblock_mperblock,
make_tuple(I0, I0),
d_thread_buf,
d_grid_desc_mblock_mperblock,
d_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
d_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
d_grid_desc_mblock_mperblock,
make_tuple(c_global_step[I0], c_global_step[I1]));
}
});
ThreadwiseReduce_D1::Reduce(c_reduce_thread_buf, d1_thread_buf);
// copy from VGPR to Global
d0_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock,
make_tuple(I0, I0),
d0_thread_buf,
d_grid_desc_mblock_mperblock,
d0_grid_buf);
d1_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock,
make_tuple(I0, I0),
d1_thread_buf,
d_grid_desc_mblock_mperblock,
d1_grid_buf);
}
if constexpr(access_id < num_access - 1)
@@ -883,18 +887,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
// move on D0
d0_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
d_grid_desc_mblock_mperblock,
make_tuple(c_global_step[I0], c_global_step[I1]));
// move on D1
d1_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
d_grid_desc_mblock_mperblock,
make_tuple(c_global_step[I0], c_global_step[I1]));
}
});
// Reduction
}
}
};