mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 08:15:04 +00:00
adding gemm pipeline
This commit is contained in:
@@ -11,9 +11,10 @@
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_gemm_xdl.hpp"
|
||||
#include "device_gemm_xdl_c_shuffle.hpp"
|
||||
#include "device_gemm_xdl_cshuffle.hpp"
|
||||
//#include "device_gemm_xdl.hpp"
|
||||
//#include "device_gemm_xdl_c_shuffle.hpp"
|
||||
//#include "device_gemm_xdl_cshuffle.hpp"
|
||||
#include "device_gemm_xdl_cshuffle_v2.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "reference_gemm.hpp"
|
||||
#include "gemm_specialization.hpp"
|
||||
@@ -42,15 +43,39 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
|
||||
|
||||
// clang-format off
|
||||
#if 0
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | Type| Type| Type| DataType| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
//< Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
// // 1-stage prefetch
|
||||
< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
// // 2-stage prefetch
|
||||
// < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
#elif 1
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_v2
|
||||
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 8>, 8>;
|
||||
// < Row, Col, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 8>, 8>;
|
||||
#elif 1
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
|
||||
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
|
||||
// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>;
|
||||
// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>;
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#define CK_USE_LAUNCH_BOUNDS 1
|
||||
|
||||
#ifdef CK_USE_LAUNCH_BOUNDS
|
||||
#define CK_MAX_THREAD_PER_BLOCK 256
|
||||
#define CK_MAX_THREAD_PER_BLOCK 512
|
||||
#define CK_MIN_BLOCK_PER_CU 1
|
||||
#endif
|
||||
|
||||
|
||||
@@ -71,35 +71,35 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
|
||||
|
||||
static __device__ void RunABBlockTransferPipeline(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
ABlockTransfer& a_block_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
BBlockTransfer& b_block_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
index_t num_loop)
|
||||
{
|
||||
// global read 0
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
a_block_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_block_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
// move to 1
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// LDS write 0
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
a_block_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
// global Read 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
a_block_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
|
||||
// LDS write 0
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
b_block_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
// global Read 1
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
b_block_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
// main body
|
||||
// FIXME: HasMainLoop = (num_loop) > 2
|
||||
@@ -116,18 +116,18 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
|
||||
block_sync_lds();
|
||||
|
||||
// move to i + 2
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// LDS write i + 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
a_block_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
// global read i + 2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
a_block_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
|
||||
// LDS write i + 1
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
b_block_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
// global read i + 2
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
b_block_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
++i;
|
||||
} while(i < (num_loop - 2));
|
||||
@@ -142,8 +142,8 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write num_loop - 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
a_block_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_block_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
@@ -153,7 +153,7 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
|
||||
|
||||
static __device__ void RunBlockGemmPipeline(ABlockBuffer& a_block_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BlockwiseGemm& blockwise_gemm,
|
||||
const BlockwiseGemm& block_gemm,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop)
|
||||
{
|
||||
@@ -171,7 +171,7 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM i
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
@@ -192,7 +192,7 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 2
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
@@ -201,46 +201,45 @@ struct GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 1
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
ABlockTransfer& a_block_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
BBlockTransfer& b_block_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
const BlockwiseGemm& blockwise_gemm,
|
||||
const BlockwiseGemm& block_gemm,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop)
|
||||
{
|
||||
if(ABBlockTransferThreadGroup::IsBelong())
|
||||
{
|
||||
gridwise_gemm_pipeline.RunABBlockTransferPipeline(a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
num_loop);
|
||||
RunABBlockTransferPipeline(a_grid_desc,
|
||||
a_block_desc,
|
||||
a_block_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_copy_step,
|
||||
b_grid_desc,
|
||||
b_block_desc,
|
||||
b_block_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_copy_step,
|
||||
num_loop);
|
||||
}
|
||||
else if(BlockGemmThreadGroup::IsBelong())
|
||||
{
|
||||
gridwise_gemm_pipeline.RunBlockGemmPipeline(
|
||||
a_block_buf, b_block_buf, blockwise_gemm, c_thread_buf, num_loop);
|
||||
RunBlockGemmPipeline(a_block_buf, b_block_buf, block_gemm, c_thread_buf, num_loop);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -4,12 +4,11 @@
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_gemm_xdlops.hpp"
|
||||
#include "blockwise_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "blockwise_tensor_slice_transfer_v6r1.hpp"
|
||||
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "gridwise_gemm_pipeline_v1.hpp"
|
||||
#include "gridwise_gemm_pipeline_v2.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -118,11 +117,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
|
||||
using ThisThreadBlock =
|
||||
AnyThreadBlock<ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize>;
|
||||
|
||||
#if 1
|
||||
using ABBlockTransferThreadGroup = ThisThreadBlock;
|
||||
using BlockGemmThreadGroup = ThisThreadBlock;
|
||||
using CShuffleBlockTransferThreadGroup = ThisThreadBlock;
|
||||
#else
|
||||
struct ABBlockTransferThreadGroup
|
||||
{
|
||||
__device__ static constexpr index_t GetNumOfThread()
|
||||
@@ -157,7 +151,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
|
||||
};
|
||||
|
||||
using CShuffleBlockTransferThreadGroup = ThisThreadBlock;
|
||||
#endif
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
{
|
||||
@@ -494,7 +487,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock);
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
// gridwise GEMM pipeline
|
||||
const auto gridwise_gemm_pipeline =
|
||||
GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
|
||||
@@ -667,9 +660,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
|
||||
|
||||
// shuffle: blockwise copy C from LDS to global
|
||||
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
|
||||
ThisThreadBlock, // ThreadGroup
|
||||
CElementwiseOperation, // ElementwiseOperation,
|
||||
CGlobalMemoryDataOperation, // DstInMemOp,
|
||||
CShuffleBlockTransferThreadGroup, // ThreadGroup
|
||||
CElementwiseOperation, // ElementwiseOperation,
|
||||
CGlobalMemoryDataOperation, // DstInMemOp,
|
||||
Sequence<1,
|
||||
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
|
||||
1,
|
||||
|
||||
@@ -24,38 +24,38 @@ include_directories(BEFORE
|
||||
set(PROFILER_SOURCE
|
||||
src/profiler.cpp
|
||||
src/profile_gemm.cpp
|
||||
src/profile_gemm_bias_2d.cpp
|
||||
src/profile_gemm_bias_relu.cpp
|
||||
src/profile_gemm_bias_relu_add.cpp
|
||||
src/profile_gemm_reduce.cpp
|
||||
src/profile_batched_gemm.cpp
|
||||
src/profile_conv_fwd.cpp
|
||||
src/profile_conv_fwd_bias_relu.cpp
|
||||
src/profile_conv_fwd_bias_relu_add.cpp
|
||||
src/profile_conv_fwd_bias_relu_atomic_add.cpp
|
||||
src/profile_convnd_bwd_data.cpp
|
||||
src/profile_reduce.cpp
|
||||
src/profile_grouped_gemm.cpp
|
||||
src/profile_conv_bwd_weight.cpp
|
||||
src/profile_batched_gemm_reduce.cpp
|
||||
# src/profile_gemm_bias_2d.cpp
|
||||
# src/profile_gemm_bias_relu.cpp
|
||||
# src/profile_gemm_bias_relu_add.cpp
|
||||
# src/profile_gemm_reduce.cpp
|
||||
# src/profile_batched_gemm.cpp
|
||||
# src/profile_conv_fwd.cpp
|
||||
# src/profile_conv_fwd_bias_relu.cpp
|
||||
# src/profile_conv_fwd_bias_relu_add.cpp
|
||||
# src/profile_conv_fwd_bias_relu_atomic_add.cpp
|
||||
# src/profile_convnd_bwd_data.cpp
|
||||
# src/profile_reduce.cpp
|
||||
# src/profile_grouped_gemm.cpp
|
||||
# src/profile_conv_bwd_weight.cpp
|
||||
# src/profile_batched_gemm_reduce.cpp
|
||||
)
|
||||
|
||||
add_executable(ckProfiler ${PROFILER_SOURCE})
|
||||
|
||||
target_link_libraries(ckProfiler PRIVATE host_tensor)
|
||||
target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
|
||||
#target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_gemm_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
|
||||
#target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance)
|
||||
#target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance)
|
||||
#target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance)
|
||||
#target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
|
||||
#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
|
||||
#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
|
||||
#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
|
||||
#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance)
|
||||
#target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
|
||||
#target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
|
||||
#target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
|
||||
#target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
|
||||
#target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance)
|
||||
#target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
|
||||
|
||||
@@ -26,70 +26,70 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
return profile_gemm(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "gemm_bias_2d") == 0)
|
||||
{
|
||||
return profile_gemm_bias_2d(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "gemm_bias_relu") == 0)
|
||||
{
|
||||
return profile_gemm_bias_relu(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "gemm_bias_relu_add") == 0)
|
||||
{
|
||||
return profile_gemm_bias_relu_add(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "gemm_reduce") == 0)
|
||||
{
|
||||
return profile_gemm_reduce(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "batched_gemm") == 0)
|
||||
{
|
||||
return profile_batched_gemm(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "batched_gemm_reduce") == 0)
|
||||
{
|
||||
return profile_batched_gemm_reduce(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "grouped_gemm") == 0)
|
||||
{
|
||||
profile_grouped_gemm(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "conv_fwd") == 0)
|
||||
{
|
||||
return profile_conv_fwd(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0)
|
||||
{
|
||||
return profile_conv_fwd_bias_relu(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "conv_fwd_bias_relu_add") == 0)
|
||||
{
|
||||
return profile_conv_fwd_bias_relu_add(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "conv_fwd_bias_relu_atomic_add") == 0)
|
||||
{
|
||||
return profile_conv_fwd_bias_relu_atomic_add(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "conv1d_bwd_data") == 0)
|
||||
{
|
||||
return profile_convnd_bwd_data(argc, argv, 1);
|
||||
}
|
||||
else if(strcmp(argv[1], "conv2d_bwd_data") == 0)
|
||||
{
|
||||
return profile_convnd_bwd_data(argc, argv, 2);
|
||||
}
|
||||
else if(strcmp(argv[1], "conv3d_bwd_data") == 0)
|
||||
{
|
||||
return profile_convnd_bwd_data(argc, argv, 3);
|
||||
}
|
||||
else if(strcmp(argv[1], "reduce") == 0)
|
||||
{
|
||||
return profile_reduce(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "conv2d_bwd_weight") == 0)
|
||||
{
|
||||
return profile_conv_bwd_weight(argc, argv);
|
||||
}
|
||||
// else if(strcmp(argv[1], "gemm_bias_2d") == 0)
|
||||
// {
|
||||
// return profile_gemm_bias_2d(argc, argv);
|
||||
// }
|
||||
// else if(strcmp(argv[1], "gemm_bias_relu") == 0)
|
||||
// {
|
||||
// return profile_gemm_bias_relu(argc, argv);
|
||||
// }
|
||||
// else if(strcmp(argv[1], "gemm_bias_relu_add") == 0)
|
||||
// {
|
||||
// return profile_gemm_bias_relu_add(argc, argv);
|
||||
// }
|
||||
// else if(strcmp(argv[1], "gemm_reduce") == 0)
|
||||
// {
|
||||
// return profile_gemm_reduce(argc, argv);
|
||||
// }
|
||||
// else if(strcmp(argv[1], "batched_gemm") == 0)
|
||||
// {
|
||||
// return profile_batched_gemm(argc, argv);
|
||||
// }
|
||||
// else if(strcmp(argv[1], "batched_gemm_reduce") == 0)
|
||||
// {
|
||||
// return profile_batched_gemm_reduce(argc, argv);
|
||||
// }
|
||||
// else if(strcmp(argv[1], "grouped_gemm") == 0)
|
||||
// {
|
||||
// profile_grouped_gemm(argc, argv);
|
||||
// }
|
||||
// else if(strcmp(argv[1], "conv_fwd") == 0)
|
||||
// {
|
||||
// return profile_conv_fwd(argc, argv);
|
||||
// }
|
||||
// else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0)
|
||||
// {
|
||||
// return profile_conv_fwd_bias_relu(argc, argv);
|
||||
// }
|
||||
// else if(strcmp(argv[1], "conv_fwd_bias_relu_add") == 0)
|
||||
// {
|
||||
// return profile_conv_fwd_bias_relu_add(argc, argv);
|
||||
// }
|
||||
// else if(strcmp(argv[1], "conv_fwd_bias_relu_atomic_add") == 0)
|
||||
// {
|
||||
// return profile_conv_fwd_bias_relu_atomic_add(argc, argv);
|
||||
// }
|
||||
// else if(strcmp(argv[1], "conv1d_bwd_data") == 0)
|
||||
// {
|
||||
// return profile_convnd_bwd_data(argc, argv, 1);
|
||||
// }
|
||||
// else if(strcmp(argv[1], "conv2d_bwd_data") == 0)
|
||||
// {
|
||||
// return profile_convnd_bwd_data(argc, argv, 2);
|
||||
// }
|
||||
// else if(strcmp(argv[1], "conv3d_bwd_data") == 0)
|
||||
// {
|
||||
// return profile_convnd_bwd_data(argc, argv, 3);
|
||||
// }
|
||||
// else if(strcmp(argv[1], "reduce") == 0)
|
||||
// {
|
||||
// return profile_reduce(argc, argv);
|
||||
// }
|
||||
// else if(strcmp(argv[1], "conv2d_bwd_weight") == 0)
|
||||
// {
|
||||
// return profile_conv_bwd_weight(argc, argv);
|
||||
// }
|
||||
else
|
||||
{
|
||||
// clang-format off
|
||||
|
||||
Reference in New Issue
Block a user