mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
Gemm reduce max (#209)
* [What] Rename the example
[Why] Prepare to add unary reduction
* Add global operation to the parameter
* Add atomicmax
* Fix compile error
* Support atomicMax (hip library)
* Rename the reduction example
* Fix target name
* use p_d1_grid as the indicator directly
* Prevent performance issue. Let passthrough handle it.
* Implement the function template and specialize it for float2
* No need to separate into two lines
* Remove empty line
* add comment
* Fix compile error due to merge from develop
* make the implementation of atomic_max / atomic_add explicit for each datatype
* Refine typo
* For future CI test
* Fix compiler error in ckProfiler
* Merge commit 'de2769e3a6695b38a20529261273ddc5cdaab2fe'
* simply use remove_pointer
* Rename type and var
* Refine example
* Modify reducemax example
* Fix bug in reduction
* Change initialize range
* Implement F64 version of atomicMax
* Move reduction code together
* Add buffer atomic_max
* Fix coding style by clang-format
* Integrate new api of DeviceGemmReduce_Xdl_CShuffle
* Integrate Batch gemm reduction
* Fix example
* fix example
* clean up
* Fix batch gemm tensor operation
* Fix coding style
* Fix template argument
* Fix clang format
* Keep the flexibility of a different stride for each D tensor
* Fix compile error for ckProfiler
* Fix typo
* [What] Fix naming
[Why] Prepare to add out elementop
* Add DoutElementOp
Co-authored-by: Chao Liu <chao.liu2@amd.com>
Co-authored-by: rocking <chunylai@amd.com>
[ROCm/composable_kernel commit: 0ffe956ab1]
This commit is contained in:
@@ -76,6 +76,12 @@
|
||||
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
|
||||
#endif
|
||||
|
||||
#if defined(__gfx90a__) // for GPU code
|
||||
#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1
|
||||
#else
|
||||
#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0
|
||||
#endif
|
||||
|
||||
// inline asm
|
||||
#define CK_USE_AMD_INLINE_ASM 1
|
||||
|
||||
@@ -91,10 +97,11 @@
|
||||
// experimental feature: static tensor descriptor
|
||||
#define CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR 0
|
||||
|
||||
// experimental feature: buffer load/store/atomic-add OOB trick
|
||||
// experimental feature: buffer load/store/atomic-add/atomic-max OOB trick
|
||||
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
|
||||
#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
|
||||
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
|
||||
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1
|
||||
|
||||
// experimental feature: in-register sub-dword transpose
|
||||
#define CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE 1
|
||||
@@ -142,9 +149,23 @@ enum struct InMemoryDataOperationEnum
|
||||
{
|
||||
Set,
|
||||
AtomicAdd,
|
||||
AtomicMax,
|
||||
Add
|
||||
};
|
||||
|
||||
// Compile-time list of InMemoryDataOperationEnum values carried as a template parameter pack.
template <InMemoryDataOperationEnum... Is>
|
||||
struct InMemoryDataOperationEnumSequence
|
||||
{
|
||||
// number of enum values in the pack
static constexpr int mSize = sizeof...(Is);
|
||||
|
||||
// Returns the I-th value of the pack. I is not range-checked here; I == mSize reads the
// trailing dummy element (Set).
__host__ __device__ static constexpr InMemoryDataOperationEnum At(int I)
|
||||
{
|
||||
// the trailing dummy element keeps the array non-empty when mSize == 0, so the compiler
// does not complain about a zero-length array
|
||||
const InMemoryDataOperationEnum mData[mSize + 1] = {Is..., InMemoryDataOperationEnum::Set};
|
||||
return mData[I];
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: no longer needed, remove this
|
||||
enum struct ActivTypeEnum
|
||||
{
|
||||
|
||||
@@ -17,11 +17,12 @@ namespace device {
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename FloatD,
|
||||
typename DPtrsGlobal,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename D1ElementwiseOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsOutElementwiseOperation,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -37,13 +38,13 @@ __global__ void
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
FloatD* __restrict__ p_d0_grid,
|
||||
FloatD* __restrict__ p_d1_grid,
|
||||
DPtrsGlobal p_ds_grid,
|
||||
const index_t batch_count,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const D1ElementwiseOperation d1_element_op,
|
||||
const DxsInElementwiseOperation dxs_in_element_op,
|
||||
const DxsOutElementwiseOperation dxs_out_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
@@ -64,23 +65,24 @@ __global__ void
|
||||
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
|
||||
|
||||
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
|
||||
const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetD1BasePtr(g_idx)));
|
||||
static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
|
||||
const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetDBasePtr(g_idx, In)));
|
||||
p_ds_grid(In) = p_ds_grid(In) + d_batch_offset;
|
||||
});
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_batch_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
p_c_grid + c_batch_offset,
|
||||
p_d0_grid + d0_batch_offset,
|
||||
p_d1_grid + d1_batch_offset,
|
||||
p_ds_grid,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d1_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
@@ -90,13 +92,13 @@ __global__ void
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_c_grid;
|
||||
ignore = p_d0_grid;
|
||||
ignore = p_d1_grid;
|
||||
ignore = p_ds_grid;
|
||||
ignore = batch_count;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = d1_element_op;
|
||||
ignore = dxs_in_element_op;
|
||||
ignore = dxs_out_element_op;
|
||||
ignore = a_grid_desc_ak0_m_ak1;
|
||||
ignore = b_grid_desc_bk0_n_bk1;
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
@@ -118,13 +120,14 @@ template <typename ALayout,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename ReduceAccDataType,
|
||||
typename DDataType,
|
||||
typename DPtrsGlobal,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename D0ReduceOperation,
|
||||
typename D1ReduceOperation,
|
||||
typename D1ElementwiseOperation,
|
||||
typename DxsReduceOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsOutElementwiseOperation,
|
||||
typename DGlobalMemoryDataOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
@@ -159,10 +162,12 @@ template <typename ALayout,
|
||||
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
|
||||
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
|
||||
struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D1ElementwiseOperation>
|
||||
DxsInElementwiseOperation,
|
||||
DxsOutElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle;
|
||||
|
||||
@@ -508,13 +513,11 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideC,
|
||||
index_t BatchStrideD0,
|
||||
index_t BatchStrideD1)
|
||||
index_t BatchStrideD)
|
||||
: BatchStrideA_(BatchStrideA),
|
||||
BatchStrideB_(BatchStrideB),
|
||||
BatchStrideC_(BatchStrideC),
|
||||
BatchStrideD0_(BatchStrideD0),
|
||||
BatchStrideD1_(BatchStrideD1)
|
||||
BatchStrideD_(BatchStrideD)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -533,22 +536,20 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideC_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx) const
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr long_index_t GetDBasePtr(index_t g_idx,
|
||||
Number<I> reduction_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideD0_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetD1BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideD1_);
|
||||
// TODO - Support sequence of StrideD in MakeArgument()
|
||||
(void)reduction_idx;
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideD_);
|
||||
}
|
||||
|
||||
private:
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB_;
|
||||
index_t BatchStrideC_;
|
||||
index_t BatchStrideD0_;
|
||||
index_t BatchStrideD1_;
|
||||
index_t BatchStrideD_;
|
||||
};
|
||||
|
||||
// GridwiseGemm
|
||||
@@ -558,15 +559,15 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
ReduceAccDataType,
|
||||
DDataType,
|
||||
DPtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D0ReduceOperation,
|
||||
D1ReduceOperation,
|
||||
D1ElementwiseOperation,
|
||||
DxsReduceOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsOutElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
DGlobalMemoryDataOperation,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
CGridDesc_M_N,
|
||||
@@ -615,8 +616,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
Argument(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
DDataType* p_d0_grid,
|
||||
DDataType* p_d1_grid,
|
||||
DPtrsGlobal p_ds_grid,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
@@ -626,13 +626,13 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
D1ElementwiseOperation d1_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsOutElementwiseOperation dxs_out_element_op,
|
||||
index_t BatchCount)
|
||||
: p_a_grid_{p_a_grid},
|
||||
p_b_grid_{p_b_grid},
|
||||
p_c_grid_{p_c_grid},
|
||||
p_d0_grid_{p_d0_grid},
|
||||
p_d1_grid_{p_d1_grid},
|
||||
p_ds_grid_{p_ds_grid},
|
||||
BatchCount_(BatchCount),
|
||||
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
|
||||
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
|
||||
@@ -644,13 +644,13 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize())},
|
||||
block_2_ctile_map_{},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op},
|
||||
d1_element_op_{d1_element_op}
|
||||
dxs_in_element_op_{dxs_in_element_op},
|
||||
dxs_out_element_op_{dxs_out_element_op}
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(
|
||||
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
|
||||
@@ -670,8 +670,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
DDataType* p_d0_grid_;
|
||||
DDataType* p_d1_grid_;
|
||||
DPtrsGlobal p_ds_grid_;
|
||||
index_t BatchCount_;
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
@@ -685,7 +684,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
D1ElementwiseOperation d1_element_op_;
|
||||
DxsInElementwiseOperation dxs_in_element_op_;
|
||||
DxsOutElementwiseOperation dxs_out_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -736,11 +736,12 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
CDataType,
|
||||
DDataType,
|
||||
DPtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D1ElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsOutElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -758,13 +759,13 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_d0_grid_,
|
||||
arg.p_d1_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.BatchCount_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.d1_element_op_,
|
||||
arg.dxs_in_element_op_,
|
||||
arg.dxs_out_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
@@ -778,11 +779,12 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
CDataType,
|
||||
DDataType,
|
||||
DPtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D1ElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsOutElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -800,13 +802,13 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_d0_grid_,
|
||||
arg.p_d1_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.BatchCount_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.d1_element_op_,
|
||||
arg.dxs_in_element_op_,
|
||||
arg.dxs_out_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
@@ -855,8 +857,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
DDataType* p_d0,
|
||||
DDataType* p_d1,
|
||||
DPtrsGlobal p_dxs,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
@@ -866,14 +867,14 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
D1ElementwiseOperation d1_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsOutElementwiseOperation dxs_out_element_op,
|
||||
index_t BatchCount)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
p_d0,
|
||||
p_d1,
|
||||
p_dxs,
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
@@ -883,7 +884,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d1_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op,
|
||||
BatchCount};
|
||||
}
|
||||
|
||||
@@ -893,8 +895,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
void* p_d0,
|
||||
void* p_d1,
|
||||
DPtrsGlobal p_dxs,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
@@ -904,14 +905,14 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
D1ElementwiseOperation d1_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsOutElementwiseOperation dxs_out_element_op,
|
||||
index_t BatchCount) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
static_cast<DDataType*>(p_d0),
|
||||
static_cast<DDataType*>(p_d1),
|
||||
p_dxs,
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
@@ -921,7 +922,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d1_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op,
|
||||
BatchCount);
|
||||
}
|
||||
|
||||
|
||||
@@ -6,40 +6,47 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
template <typename DPtrsGlobal,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename D1ElementwiseOperation>
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsOutElementwiseOperation>
|
||||
struct DeviceGemmReduce : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
void* p_d0,
|
||||
void* p_d1,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
D1ElementwiseOperation d1_element_op,
|
||||
ck::index_t BatchCount = 1) = 0;
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
DPtrsGlobal p_dxs,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsOutElementwiseOperation dxs_out_element_op,
|
||||
ck::index_t BatchCount = 1) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
template <typename DPtrsGlobal,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename D1ElementwiseOperation>
|
||||
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<AElementwiseOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsOutElementwiseOperation>
|
||||
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<DPtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D1ElementwiseOperation>>;
|
||||
DxsInElementwiseOperation,
|
||||
DxsOutElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -26,13 +26,14 @@ template <typename ALayout,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename ReduceAccDataType,
|
||||
typename DDataType,
|
||||
typename DPtrsGlobal,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename D0ReduceOperation,
|
||||
typename D1ReduceOperation,
|
||||
typename D1ElementwiseOperation,
|
||||
typename DxsReduceOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsOutElementwiseOperation,
|
||||
typename DGlobalMemoryDataOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
@@ -67,10 +68,12 @@ template <typename ALayout,
|
||||
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
|
||||
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
|
||||
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D1ElementwiseOperation>
|
||||
DxsInElementwiseOperation,
|
||||
DxsOutElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceGemmReduce_Xdl_CShuffle;
|
||||
|
||||
@@ -380,15 +383,15 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
ReduceAccDataType,
|
||||
DDataType,
|
||||
DPtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D0ReduceOperation,
|
||||
D1ReduceOperation,
|
||||
D1ElementwiseOperation,
|
||||
DxsReduceOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsOutElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
DGlobalMemoryDataOperation,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
CGridDesc_M_N,
|
||||
@@ -435,8 +438,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
Argument(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
DDataType* p_d0_grid,
|
||||
DDataType* p_d1_grid,
|
||||
DPtrsGlobal p_ds_grid,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
@@ -446,12 +448,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
D1ElementwiseOperation d1_element_op)
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsOutElementwiseOperation dxs_out_element_op)
|
||||
: p_a_grid_{p_a_grid},
|
||||
p_b_grid_{p_b_grid},
|
||||
p_c_grid_{p_c_grid},
|
||||
p_d0_grid_{p_d0_grid},
|
||||
p_d1_grid_{p_d1_grid},
|
||||
p_ds_grid_{p_ds_grid},
|
||||
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
|
||||
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
|
||||
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
|
||||
@@ -462,7 +464,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op},
|
||||
d1_element_op_{d1_element_op}
|
||||
dxs_in_element_op_{dxs_in_element_op},
|
||||
dxs_out_element_op_{dxs_out_element_op}
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(
|
||||
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
|
||||
@@ -482,8 +485,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
DDataType* p_d0_grid_;
|
||||
DDataType* p_d1_grid_;
|
||||
DPtrsGlobal p_ds_grid_;
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
@@ -495,7 +497,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
D1ElementwiseOperation d1_element_op_;
|
||||
DxsInElementwiseOperation dxs_in_element_op_;
|
||||
DxsOutElementwiseOperation dxs_out_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -543,11 +546,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
CDataType,
|
||||
DDataType,
|
||||
DPtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D1ElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsOutElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -564,12 +568,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_d0_grid_,
|
||||
arg.p_d1_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.d1_element_op_,
|
||||
arg.dxs_in_element_op_,
|
||||
arg.dxs_out_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
@@ -582,11 +586,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
CDataType,
|
||||
DDataType,
|
||||
DPtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D1ElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsOutElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -603,12 +608,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_d0_grid_,
|
||||
arg.p_d1_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.d1_element_op_,
|
||||
arg.dxs_in_element_op_,
|
||||
arg.dxs_out_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
@@ -648,8 +653,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
DDataType* p_d0,
|
||||
DDataType* p_d1,
|
||||
DPtrsGlobal p_dxs,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
@@ -659,13 +663,13 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
D1ElementwiseOperation d1_element_op)
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsOutElementwiseOperation dxs_out_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
p_d0,
|
||||
p_d1,
|
||||
p_dxs,
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
@@ -675,7 +679,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d1_element_op};
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
@@ -684,8 +689,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
void* p_d0,
|
||||
void* p_d1,
|
||||
DPtrsGlobal p_dxs,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
@@ -695,14 +699,14 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
D1ElementwiseOperation d1_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsOutElementwiseOperation dxs_out_element_op,
|
||||
index_t /* KBatch */ = 1) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
static_cast<DDataType*>(p_d0),
|
||||
static_cast<DDataType*>(p_d1),
|
||||
p_dxs,
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
@@ -712,7 +716,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d1_element_op);
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
|
||||
@@ -15,11 +15,12 @@ namespace ck {
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename FloatD,
|
||||
typename DPtrsGlobal,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename D1ElementwiseOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsOutElementwiseOperation,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -34,12 +35,12 @@ __global__ void
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
FloatD* __restrict__ p_d0_grid,
|
||||
FloatD* __restrict__ p_d1_grid,
|
||||
DPtrsGlobal p_ds_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const D1ElementwiseOperation d1_element_op,
|
||||
const DxsInElementwiseOperation dxs_in_element_op,
|
||||
const DxsOutElementwiseOperation dxs_out_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
@@ -53,13 +54,13 @@ __global__ void
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_d0_grid,
|
||||
p_d1_grid,
|
||||
p_ds_grid,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d1_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
@@ -69,12 +70,12 @@ __global__ void
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_c_grid;
|
||||
ignore = p_d0_grid;
|
||||
ignore = p_d1_grid;
|
||||
ignore = p_ds_grid;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = d1_element_op;
|
||||
ignore = dxs_in_element_op;
|
||||
ignore = dxs_out_element_op;
|
||||
ignore = a_grid_desc_ak0_m_ak1;
|
||||
ignore = b_grid_desc_bk0_n_bk1;
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
@@ -88,15 +89,15 @@ template <typename FloatAB,
|
||||
typename FloatCShuffle,
|
||||
typename FloatC,
|
||||
typename FloatReduceAcc,
|
||||
typename FloatD,
|
||||
typename DPtrsGlobal,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename D0ReduceOperation,
|
||||
typename D1ReduceOperation,
|
||||
typename D1ElementwiseOperation,
|
||||
typename DxsReduceOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsOutElementwiseOperation,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
InMemoryDataOperationEnum DGlobalMemoryDataOperation,
|
||||
typename DGlobalMemoryDataOperation,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename CGridDesc_M_N,
|
||||
@@ -357,13 +358,13 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
FloatD* __restrict__ p_d0_grid,
|
||||
FloatD* __restrict__ p_d1_grid,
|
||||
DPtrsGlobal p_ds_grid,
|
||||
void* __restrict__ p_shared,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CElementwiseOperation& c_element_op,
|
||||
const D1ElementwiseOperation& d1_element_op,
|
||||
const DxsInElementwiseOperation& dxs_in_element_op,
|
||||
const DxsOutElementwiseOperation& dxs_out_element_op,
|
||||
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
@@ -377,10 +378,6 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_d0_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
|
||||
auto d1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_d1_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_work_idx =
|
||||
@@ -527,7 +524,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
|
||||
// shuffle C and write out
|
||||
// shuffle C + reduction + write out
|
||||
{
|
||||
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
|
||||
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
|
||||
@@ -666,6 +663,29 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
|
||||
c_element_op};
|
||||
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
1,
|
||||
1,
|
||||
M2,
|
||||
1,
|
||||
M4,
|
||||
1>>{};
|
||||
|
||||
// space filling curve for shuffled blockwise C in global mem
|
||||
constexpr auto sfc_c_global =
|
||||
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
|
||||
Sequence<0, 2, 1, 3>,
|
||||
Sequence<1,
|
||||
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
|
||||
1,
|
||||
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
|
||||
|
||||
// TODO: this should be implemented as a blockwise reduction
|
||||
// LDS c_reduce_block_desc_mperblock_nperblock
|
||||
constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
@@ -716,16 +736,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
constexpr auto d_reduce_thread_desc_mblock_mperblock =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));
|
||||
|
||||
// TODO: this should be implemented as a blockwise reduction
|
||||
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
|
||||
c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
|
||||
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
|
||||
|
||||
auto d1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
|
||||
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
|
||||
|
||||
// reduce: threadwise copy from LDS to VGPR
|
||||
constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor(
|
||||
CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{});
|
||||
@@ -749,47 +762,29 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
1,
|
||||
true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};
|
||||
|
||||
// reduce: copy from VGPR to global
|
||||
auto d0_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
FloatReduceAcc,
|
||||
FloatD,
|
||||
decltype(d_reduce_thread_desc_mblock_mperblock),
|
||||
decltype(d_grid_desc_mblock_mperblock),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<1, mreduce_per_thread>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
|
||||
DGlobalMemoryDataOperation,
|
||||
1,
|
||||
false>{d_grid_desc_mblock_mperblock,
|
||||
make_multi_index(block_work_idx[I0], // mblock
|
||||
c_reduce_thread_data_idx_begin[I0]), // mperblock
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
auto dxs_reduce_thread_copy_vgpr_to_global = generate_tuple(
|
||||
[&](auto I) {
|
||||
auto p_d_grid = p_ds_grid[I];
|
||||
auto d_out_element_op = dxs_out_element_op[I];
|
||||
|
||||
auto d1_reduce_thread_copy_vgpr_to_global = d0_reduce_thread_copy_vgpr_to_global;
|
||||
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
1,
|
||||
1,
|
||||
M2,
|
||||
1,
|
||||
M4,
|
||||
1>>{};
|
||||
|
||||
// space filling curve for shuffled blockwise C in global mem
|
||||
constexpr auto sfc_c_global =
|
||||
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
|
||||
Sequence<0, 2, 1, 3>,
|
||||
Sequence<1,
|
||||
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
|
||||
1,
|
||||
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
|
||||
return ThreadwiseTensorSliceTransfer_v1r3<
|
||||
FloatReduceAcc,
|
||||
remove_pointer_t<decltype(p_d_grid)>,
|
||||
decltype(d_reduce_thread_desc_mblock_mperblock),
|
||||
decltype(d_grid_desc_mblock_mperblock),
|
||||
decltype(d_out_element_op),
|
||||
Sequence<1, mreduce_per_thread>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
|
||||
DGlobalMemoryDataOperation::At(I),
|
||||
1,
|
||||
false>{d_grid_desc_mblock_mperblock,
|
||||
make_multi_index(block_work_idx[I0], // mblock
|
||||
c_reduce_thread_data_idx_begin[I0]), // mperblock
|
||||
d_out_element_op};
|
||||
},
|
||||
Number<p_ds_grid.Size()>{});
|
||||
|
||||
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
|
||||
|
||||
@@ -816,64 +811,73 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
|
||||
using ThreadwiseReduce_D0 =
|
||||
ThreadwiseReduction<FloatReduceAcc,
|
||||
decltype(c_reduce_thread_desc_mperblock_nperblock),
|
||||
decltype(d_reduce_thread_desc_mperblock),
|
||||
D0ReduceOperation,
|
||||
false>;
|
||||
|
||||
using ThreadwiseReduce_D1 =
|
||||
ThreadwiseReduction<FloatReduceAcc,
|
||||
decltype(c_reduce_thread_desc_mperblock_nperblock),
|
||||
decltype(d_reduce_thread_desc_mperblock),
|
||||
D1ReduceOperation,
|
||||
false>;
|
||||
|
||||
const auto d0_zeroVal = D0ReduceOperation::GetReductionZeroVal();
|
||||
const auto d1_zeroVal = D0ReduceOperation::GetReductionZeroVal();
|
||||
|
||||
static_for<0, mreduce_per_thread, 1>{}(
|
||||
[&](auto I) { d0_thread_buf(I) = d0_zeroVal; });
|
||||
static_for<0, mreduce_per_thread, 1>{}(
|
||||
[&](auto I) { d1_thread_buf(I) = d1_zeroVal; });
|
||||
|
||||
// reduce
|
||||
// TODO - extract following into reduction_blockwise
|
||||
{
|
||||
// copy from LDS to VGPR
|
||||
c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
|
||||
c_shuffle_block_buf,
|
||||
c_reduce_thread_desc_mperblock_nperblock,
|
||||
make_tuple(I0, I0),
|
||||
c_reduce_thread_buf);
|
||||
|
||||
// reduce in VGPR
|
||||
ThreadwiseReduce_D0::Reduce(c_reduce_thread_buf, d0_thread_buf);
|
||||
static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
|
||||
auto& p_d_grid = p_ds_grid[In];
|
||||
|
||||
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
|
||||
static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
|
||||
constexpr auto offset =
|
||||
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
|
||||
make_tuple(im, in))>{};
|
||||
auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_d_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
|
||||
|
||||
d1_element_op(c_reduce_thread_buf(offset), c_reduce_thread_buf(offset));
|
||||
auto d_thread_buf =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
|
||||
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
|
||||
|
||||
auto& d_in_element_op = dxs_in_element_op[In];
|
||||
|
||||
auto& d_reduce_thread_copy_vgpr_to_global =
|
||||
dxs_reduce_thread_copy_vgpr_to_global(In);
|
||||
|
||||
using DReduceOperation = remove_cvref_t<decltype(DxsReduceOperation{}[In])>;
|
||||
using ThreadwiseReduce =
|
||||
ThreadwiseReduction<FloatReduceAcc,
|
||||
decltype(c_reduce_thread_desc_mperblock_nperblock),
|
||||
decltype(d_reduce_thread_desc_mperblock),
|
||||
DReduceOperation,
|
||||
false>;
|
||||
|
||||
// Global write Gemm shuffle + reduction
|
||||
const auto d_zeroVal = DReduceOperation::GetReductionZeroVal();
|
||||
|
||||
static_for<0, mreduce_per_thread, 1>{}(
|
||||
[&](auto I) { d_thread_buf(I) = d_zeroVal; });
|
||||
|
||||
// reduce in VGPR
|
||||
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
|
||||
static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
|
||||
constexpr auto offset =
|
||||
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
|
||||
make_tuple(im, in))>{};
|
||||
|
||||
d_in_element_op(c_reduce_thread_buf(offset),
|
||||
c_reduce_thread_buf(offset));
|
||||
});
|
||||
});
|
||||
|
||||
ThreadwiseReduce::Reduce(c_reduce_thread_buf, d_thread_buf);
|
||||
|
||||
// copy from VGPR to Global
|
||||
d_reduce_thread_copy_vgpr_to_global.Run(
|
||||
d_reduce_thread_desc_mblock_mperblock,
|
||||
make_tuple(I0, I0),
|
||||
d_thread_buf,
|
||||
d_grid_desc_mblock_mperblock,
|
||||
d_grid_buf);
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
|
||||
d_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
|
||||
d_grid_desc_mblock_mperblock,
|
||||
make_tuple(c_global_step[I0], c_global_step[I1]));
|
||||
}
|
||||
});
|
||||
|
||||
ThreadwiseReduce_D1::Reduce(c_reduce_thread_buf, d1_thread_buf);
|
||||
|
||||
// copy from VGPR to Global
|
||||
d0_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock,
|
||||
make_tuple(I0, I0),
|
||||
d0_thread_buf,
|
||||
d_grid_desc_mblock_mperblock,
|
||||
d0_grid_buf);
|
||||
|
||||
d1_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock,
|
||||
make_tuple(I0, I0),
|
||||
d1_thread_buf,
|
||||
d_grid_desc_mblock_mperblock,
|
||||
d1_grid_buf);
|
||||
}
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
@@ -883,18 +887,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
// move on C
|
||||
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
|
||||
|
||||
// move on D0
|
||||
d0_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
|
||||
d_grid_desc_mblock_mperblock,
|
||||
make_tuple(c_global_step[I0], c_global_step[I1]));
|
||||
|
||||
// move on D1
|
||||
d1_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
|
||||
d_grid_desc_mblock_mperblock,
|
||||
make_tuple(c_global_step[I0], c_global_step[I1]));
|
||||
}
|
||||
});
|
||||
|
||||
// Reduction
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -258,6 +258,14 @@ __device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");
|
||||
|
||||
// buffer atomic-max fp64
|
||||
__device__ double
|
||||
llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
|
||||
int32x4_t rsrc, // dst_wave_buffer_resource
|
||||
int voffset, // dst_thread_addr_offset
|
||||
int soffset, // dst_wave_addr_offset
|
||||
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
|
||||
|
||||
template <typename T, index_t N>
|
||||
__device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
@@ -915,6 +923,71 @@ __device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::typ
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
__device__ void amd_buffer_atomic_max_impl(const typename vector_type<T, N>::type src_thread_data,
|
||||
int32x4_t dst_wave_buffer_resource,
|
||||
index_t dst_thread_addr_offset,
|
||||
index_t dst_wave_addr_offset)
|
||||
{
|
||||
static_assert((is_same<T, double>::value && (N == 1 || N == 2 || N == 4)),
|
||||
"wrong! not implemented");
|
||||
if constexpr(is_same<T, double>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_atomic_max_fp64(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
vector_type<double, 2> tmp{src_thread_data};
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + sizeof(double),
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
vector_type<double, 4> tmp{src_thread_data};
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + sizeof(double),
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<2>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 2 * sizeof(double),
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<3>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 3 * sizeof(double),
|
||||
0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// buffer_load requires:
|
||||
// 1) p_src_wave must point to global memory space
|
||||
// 2) p_src_wave must be a wavewise pointer.
|
||||
@@ -1046,4 +1119,39 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
|
||||
#endif
|
||||
}
|
||||
|
||||
// buffer_atomic_max requires:
|
||||
// 1) p_dst_wave must point to global memory
|
||||
// 2) p_dst_wave must be a wavewise pointer.
|
||||
// It is user's responsibility to make sure that is true.
|
||||
template <typename T, index_t N>
|
||||
__device__ void
|
||||
amd_buffer_atomic_max(const typename vector_type_maker<T, N>::type::type src_thread_data,
|
||||
T* p_dst_wave,
|
||||
const index_t dst_thread_element_offset,
|
||||
const bool dst_thread_element_valid,
|
||||
const index_t dst_element_space_size)
|
||||
{
|
||||
const int32x4_t dst_wave_buffer_resource =
|
||||
make_wave_buffer_resource(p_dst_wave, dst_element_space_size);
|
||||
|
||||
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
|
||||
|
||||
using vector_t = typename vector_type_maker<T, N>::type::type;
|
||||
using scalar_t = typename scalar_type<vector_t>::type;
|
||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
|
||||
|
||||
amd_buffer_atomic_max_impl<scalar_t, vector_size>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||
#else
|
||||
if(dst_thread_element_valid)
|
||||
{
|
||||
amd_buffer_atomic_max_impl<scalar_t, vector_size>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -32,7 +32,7 @@
|
||||
#include "debug.hpp"
|
||||
|
||||
#include "amd_buffer_addressing.hpp"
|
||||
#include "generic_memory_space_atomic_add.hpp"
|
||||
#include "generic_memory_space_atomic.hpp"
|
||||
#include "get_id.hpp"
|
||||
#include "synchronization.hpp"
|
||||
#include "amd_address_space.hpp"
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
#include "enable_if.hpp"
|
||||
#include "c_style_pointer_cast.hpp"
|
||||
#include "amd_buffer_addressing.hpp"
|
||||
#include "generic_memory_space_atomic_add.hpp"
|
||||
#include "generic_memory_space_atomic.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -125,6 +125,10 @@ struct DynamicBuffer
|
||||
{
|
||||
this->template AtomicAdd<X>(i, is_valid_element, x);
|
||||
}
|
||||
else if constexpr(Op == InMemoryDataOperationEnum::AtomicMax)
|
||||
{
|
||||
this->template AtomicMax<X>(i, is_valid_element, x);
|
||||
}
|
||||
else if constexpr(Op == InMemoryDataOperationEnum::Add)
|
||||
{
|
||||
auto tmp = this->template Get<X>(i, is_valid_element);
|
||||
@@ -326,6 +330,42 @@ struct DynamicBuffer
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
|
||||
typename scalar_type<remove_cvref_t<T>>::type>::value,
|
||||
bool>::type = false>
|
||||
__host__ __device__ void AtomicMax(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
|
||||
static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem");
|
||||
|
||||
#if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
|
||||
using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
|
||||
bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, double>;
|
||||
#else
|
||||
bool constexpr use_amd_buffer_addressing = false;
|
||||
#endif
|
||||
|
||||
if constexpr(use_amd_buffer_addressing)
|
||||
{
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
|
||||
x, p_data_, i, is_valid_element, element_space_size_);
|
||||
}
|
||||
else if(is_valid_element)
|
||||
{
|
||||
atomic_max<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
|
||||
|
||||
__host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
|
||||
|
||||
97
include/ck/utility/generic_memory_space_atomic.hpp
Normal file
97
include/ck/utility/generic_memory_space_atomic.hpp
Normal file
@@ -0,0 +1,97 @@
|
||||
#pragma once
|
||||
#include "data_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Caution: DO NOT REMOVE
|
||||
// intentionally have only declaration but no definition to cause compilation failure when trying to
|
||||
// instantiate this template. The purpose is to make the implementation of atomic_add explicit for
|
||||
// each datatype.
|
||||
template <typename X>
|
||||
__device__ X atomic_add(X* p_dst, const X& x);
|
||||
|
||||
template <>
|
||||
__device__ int32_t atomic_add<int32_t>(int32_t* p_dst, const int32_t& x)
|
||||
{
|
||||
return atomicAdd(p_dst, x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ uint32_t atomic_add<uint32_t>(uint32_t* p_dst, const uint32_t& x)
|
||||
{
|
||||
return atomicAdd(p_dst, x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ float atomic_add<float>(float* p_dst, const float& x)
|
||||
{
|
||||
return atomicAdd(p_dst, x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ float2_t atomic_add<float2_t>(float2_t* p_dst, const float2_t& x)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
const vector_type<float, 2> vx{x};
|
||||
vector_type<float, 2> vy{0};
|
||||
|
||||
vy.template AsType<float>()(I0) =
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst), vx.template AsType<float>()[I0]);
|
||||
vy.template AsType<float>()(I1) =
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, vx.template AsType<float>()[I1]);
|
||||
|
||||
return vy.template AsType<float2_t>()[I0];
|
||||
}
|
||||
|
||||
// Caution: DO NOT REMOVE
|
||||
// intentionally have only declaration but no definition to cause compilation failure when trying to
|
||||
// instantiate this template. The purpose is to make the implementation of atomic_max explicit for
|
||||
// each datatype.
|
||||
|
||||
template <typename X>
|
||||
__device__ X atomic_max(X* p_dst, const X& x);
|
||||
|
||||
template <>
|
||||
__device__ int32_t atomic_max<int32_t>(int32_t* p_dst, const int32_t& x)
|
||||
{
|
||||
return atomicMax(p_dst, x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ uint32_t atomic_max<uint32_t>(uint32_t* p_dst, const uint32_t& x)
|
||||
{
|
||||
return atomicMax(p_dst, x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ float atomic_max<float>(float* p_dst, const float& x)
|
||||
{
|
||||
return atomicMax(p_dst, x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ double atomic_max<double>(double* p_dst, const double& x)
|
||||
{
|
||||
return atomicMax(p_dst, x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ float2_t atomic_max<float2_t>(float2_t* p_dst, const float2_t& x)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
const vector_type<float, 2> vx{x};
|
||||
vector_type<float, 2> vy{0};
|
||||
|
||||
vy.template AsType<float>()(I0) =
|
||||
atomicMax(c_style_pointer_cast<float*>(p_dst), vx.template AsType<float>()[I0]);
|
||||
vy.template AsType<float>()(I1) =
|
||||
atomicMax(c_style_pointer_cast<float*>(p_dst) + 1, vx.template AsType<float>()[I1]);
|
||||
|
||||
return vy.template AsType<float2_t>()[I0];
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
@@ -1,44 +0,0 @@
|
||||
#pragma once
|
||||
#include "data_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename X>
|
||||
__device__ X atomic_add(X* p_dst, const X& x);
|
||||
|
||||
template <>
|
||||
__device__ int32_t atomic_add<int32_t>(int32_t* p_dst, const int32_t& x)
|
||||
{
|
||||
return atomicAdd(p_dst, x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ uint32_t atomic_add<uint32_t>(uint32_t* p_dst, const uint32_t& x)
|
||||
{
|
||||
return atomicAdd(p_dst, x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ float atomic_add<float>(float* p_dst, const float& x)
|
||||
{
|
||||
return atomicAdd(p_dst, x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ float2_t atomic_add<float2_t>(float2_t* p_dst, const float2_t& x)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
const vector_type<float, 2> vx{x};
|
||||
vector_type<float, 2> vy{0};
|
||||
|
||||
vy.template AsType<float>()(I0) =
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst), vx.template AsType<float>()[I0]);
|
||||
vy.template AsType<float>()(I1) =
|
||||
atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, vx.template AsType<float>()[I1]);
|
||||
|
||||
return vy.template AsType<float2_t>()[I0];
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
@@ -29,6 +29,9 @@ using remove_cv_t = typename std::remove_cv<T>::type;
|
||||
template <typename T>
|
||||
using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
|
||||
|
||||
// Pointee type of T; a non-pointer T is yielded unchanged (std::remove_pointer semantics).
template <typename T>
using remove_pointer_t = std::remove_pointer_t<T>;

// True exactly when T is a pointer type (std::is_pointer semantics).
template <typename T>
inline constexpr bool is_pointer_v = std::is_pointer_v<T>;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user