mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
GEMM pipeline v2 (#317)
* format
* improving pipeline
* fix typo
* format
* adding thread group
* adding thread group
* adding thread group
* adding gemm pipeline
* tweak
* refactor
* refactor
* add missing type convert
* refactor
* refactor
* refactor
* clean
* fix build
* refactor
* format
* clean up
* use remove_cvref_t
* clean
* use pipeline_v2 for gemm kernel
* Remove inconsistent indent
* Fix compilation errors due to incomplete merge process
* Add missing include directives
* Fix compilation errors in currently unused files
* Add license in newly added files
* Re-format touched files by clang-format-10
* Fix wrong template argument count of DeviceGemm<>
* Use language construct to choose between types
* Use language construct to choose GEMM example instance
* Fix compilation error due to interface change
* Re-use type alias to avoid duplication
* Unify type alias usage in source file
* Only use v2 pipeline in one gridwise GEMM type
* Remove no-longer used include directives
* Add static_assert() to check pipeline type requirements
* Revert "Add static_assert() to check pipeline type requirements"
This reverts commit f0985f0a13.
* clean
* clean
* clean
* clean
Co-authored-by: Chao Liu <chao.liu2@amd.com>
Co-authored-by: shaojiewang <wsjmessi@163.com>
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_xdl.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
@@ -44,8 +45,17 @@ using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl
|
||||
// clang-format off
|
||||
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
|
||||
// clang-format on
|
||||
|
||||
using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
// clang-format off
|
||||
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
@@ -53,6 +63,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
// clang-format on
|
||||
|
||||
using DeviceGemmInstance = DeviceGemmInstance0;
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user