From 19c5d6e651d00d15b3909bf1ba44bf59df7f29cf Mon Sep 17 00:00:00 2001 From: rocking5566 Date: Tue, 22 Feb 2022 01:35:21 +0800 Subject: [PATCH] Gemm alpha beta profiler (fp32 & fp16) (#91) * [What] Refactor verification of gemm alpha_beta, move to reference operation [Why] Sync with other verification * Profile mk_nk for gemm bias 2d * Support bias 2d with mn * kn in profiler * Support bias 2d with km*kn and km*nk in profiler * Support fp32 bias 2d in profiler * format * format Co-authored-by: rocking Co-authored-by: Chao Liu --- device_operation/CMakeLists.txt | 73 ++-- .../device_gemm_xdl_c_shuffle_bias_2d.hpp | 2 +- ..._bias_2d_f16_f16_f16_km_kn_mn_instance.cpp | 52 +++ ..._bias_2d_f16_f16_f16_km_nk_mn_instance.cpp | 52 +++ ..._bias_2d_f16_f16_f16_mk_kn_mn_instance.cpp | 52 +++ ..._bias_2d_f16_f16_f16_mk_nk_mn_instance.cpp | 57 ++++ ..._bias_2d_f32_f32_f32_km_kn_mn_instance.cpp | 51 +++ ..._bias_2d_f32_f32_f32_km_nk_mn_instance.cpp | 51 +++ ..._bias_2d_f32_f32_f32_mk_kn_mn_instance.cpp | 51 +++ ..._bias_2d_f32_f32_f32_mk_nk_mn_instance.cpp | 56 ++++ .../gemm_xdl_alpha_beta.cpp | 65 ++-- profiler/CMakeLists.txt | 4 +- .../include/profile_gemm_bias_2d_impl.hpp | 311 ++++++++++++++++++ profiler/src/profile_gemm_bias_2d.cpp | 261 +++++++++++++++ profiler/src/profiler.cpp | 6 + .../include/reference_gemm_bias_2d.hpp | 133 ++++++++ 16 files changed, 1203 insertions(+), 74 deletions(-) create mode 100644 device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instance.cpp create mode 100644 device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instance.cpp create mode 100644 device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instance.cpp create mode 100644 device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instance.cpp create mode 100644 device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instance.cpp create mode 100644 device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instance.cpp create mode 100644 device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instance.cpp create mode 100644 device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instance.cpp create mode 100644 profiler/include/profile_gemm_bias_2d_impl.hpp create mode 100644 profiler/src/profile_gemm_bias_2d.cpp create mode 100644 reference_operation/include/reference_gemm_bias_2d.hpp diff --git a/device_operation/CMakeLists.txt b/device_operation/CMakeLists.txt index 31fa455301..440e16c2fa 100644 --- a/device_operation/CMakeLists.txt +++ b/device_operation/CMakeLists.txt @@ -13,7 +13,7 @@ include_directories(BEFORE ) # device_gemm_instance -set(DEVICE_GEMM_INSTANCE_SOURCE +set(DEVICE_GEMM_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp; @@ -30,23 +30,35 @@ set(DEVICE_GEMM_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp; -) +) + +# device_gemm_bias_2d_instance +set(DEVICE_GEMM_BIAS_2D_INSTANCE_SOURCE + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instance.cpp; +) # device_gemm_bias_relu_instance -set(DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE +set(DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instance.cpp; -) +) # device_gemm_bias_relu_add_instance -set(DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE +set(DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instance.cpp; -) +) set(DEVICE_BATCHED_GEMM_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/src/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp; @@ -56,39 +68,41 @@ set(DEVICE_BATCHED_GEMM_INSTANCE_SOURCE ) # device_conv2d_fwd_instance -set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE +set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp; -) +) # device_conv2d_fwd_bias_relu_instance -set(DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE +set(DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp; -) +) # device_conv2d_fwd_bias_relu_add_instance -set(DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE +set(DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp; -) +) # device_conv2d_fwd_bias_relu_atomic_add_instance -set(DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE +set(DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp; -) +) -add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) -add_library(device_gemm_bias_relu_instance SHARED ${DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE}) -add_library(device_gemm_bias_relu_add_instance SHARED ${DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE}) +add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) +add_library(device_gemm_bias_2d_instance SHARED ${DEVICE_GEMM_BIAS_2D_INSTANCE_SOURCE}) +add_library(device_gemm_bias_relu_instance SHARED ${DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE}) +add_library(device_gemm_bias_relu_add_instance SHARED ${DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE}) add_library(device_batched_gemm_instance SHARED ${DEVICE_BATCHED_GEMM_INSTANCE_SOURCE}) -add_library(device_conv2d_fwd_instance SHARED ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE}) -add_library(device_conv2d_fwd_bias_relu_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE}) -add_library(device_conv2d_fwd_bias_relu_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE}) -add_library(device_conv2d_fwd_bias_relu_atomic_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE}) +add_library(device_conv2d_fwd_instance SHARED ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE}) +add_library(device_conv2d_fwd_bias_relu_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE}) +add_library(device_conv2d_fwd_bias_relu_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE}) +add_library(device_conv2d_fwd_bias_relu_atomic_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE}) target_include_directories(device_gemm_instance SYSTEM PUBLIC $) +target_include_directories(device_gemm_bias_2d_instance SYSTEM PUBLIC $) target_include_directories(device_gemm_bias_relu_instance SYSTEM PUBLIC $) target_include_directories(device_gemm_bias_relu_add_instance SYSTEM PUBLIC $) target_include_directories(device_batched_gemm_instance SYSTEM PUBLIC $) @@ -98,6 +112,7 @@ target_include_directories(device_conv2d_fwd_bias_relu_add_instance SYSTEM PUBLI target_include_directories(device_conv2d_fwd_bias_relu_atomic_add_instance SYSTEM PUBLIC $) target_compile_features(device_gemm_instance PUBLIC) +target_compile_features(device_gemm_bias_2d_instance PUBLIC) target_compile_features(device_gemm_bias_relu_instance PUBLIC) target_compile_features(device_gemm_bias_relu_add_instance PUBLIC) target_compile_features(device_batched_gemm_instance PUBLIC) @@ -107,6 +122,7 @@ target_compile_features(device_conv2d_fwd_bias_relu_add_instance PUBLIC) target_compile_features(device_conv2d_fwd_bias_relu_atomic_add_instance PUBLIC) set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +set_target_properties(device_gemm_bias_2d_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_gemm_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_gemm_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_batched_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) @@ -115,11 +131,12 @@ set_target_properties(device_conv2d_fwd_bias_relu_instance PROPERTIES POSITION_I set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_conv2d_fwd_bias_relu_atomic_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_gemm_instance LIBRARY DESTINATION lib) -install(TARGETS device_gemm_bias_relu_instance LIBRARY DESTINATION lib) -install(TARGETS device_gemm_bias_relu_add_instance LIBRARY DESTINATION lib) +install(TARGETS device_gemm_instance LIBRARY DESTINATION lib) +install(TARGETS device_gemm_bias_2d_instance LIBRARY DESTINATION lib) +install(TARGETS device_gemm_bias_relu_instance LIBRARY DESTINATION lib) +install(TARGETS device_gemm_bias_relu_add_instance LIBRARY DESTINATION lib) install(TARGETS device_batched_gemm_instance LIBRARY DESTINATION lib) -install(TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib) -install(TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib) -install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib) -install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib) +install(TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib) +install(TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib) +install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib) +install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib) diff --git a/device_operation/include/device_gemm_xdl_c_shuffle_bias_2d.hpp b/device_operation/include/device_gemm_xdl_c_shuffle_bias_2d.hpp index 6ee7967382..fcdc512477 100644 --- a/device_operation/include/device_gemm_xdl_c_shuffle_bias_2d.hpp +++ b/device_operation/include/device_gemm_xdl_c_shuffle_bias_2d.hpp @@ -490,7 +490,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d auto str = std::stringstream(); // clang-format off - str << "DeviceGemmXdl" + str << "DeviceGemmXdl_C_Shuffle_Bias_2d" << "<" << BlockSize << ", " << MPerBlock << ", " diff --git a/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instance.cpp b/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..bd16850ee4 --- /dev/null +++ b/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,52 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AlphaBetaAdd = ck::tensor_operation::element_wise::AlphaBetaAdd; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instance.cpp b/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..12740ce256 --- /dev/null +++ b/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,52 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AlphaBetaAdd = ck::tensor_operation::element_wise::AlphaBetaAdd; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instance.cpp b/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..56db0475ef --- /dev/null +++ b/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,52 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AlphaBetaAdd = ck::tensor_operation::element_wise::AlphaBetaAdd; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instance.cpp b/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..b20ee8db69 --- /dev/null +++ b/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,57 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AlphaBetaAdd = ck::tensor_operation::element_wise::AlphaBetaAdd; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instance.cpp b/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..11984c36db --- /dev/null +++ b/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,51 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AlphaBetaAdd = ck::tensor_operation::element_wise::AlphaBetaAdd; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances = std::tuple< + // clang-format off + //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instance.cpp b/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..bd0a988059 --- /dev/null +++ b/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,51 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AlphaBetaAdd = ck::tensor_operation::element_wise::AlphaBetaAdd; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances = std::tuple< + // clang-format off + //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instance.cpp b/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..440ea1582e --- /dev/null +++ b/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,51 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AlphaBetaAdd = ck::tensor_operation::element_wise::AlphaBetaAdd; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances = std::tuple< + // clang-format off + //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instance.cpp b/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..fab885969f --- /dev/null +++ b/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,56 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_2d.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AlphaBetaAdd = ck::tensor_operation::element_wise::AlphaBetaAdd; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances = std::tuple< + // clang-format off + //#############################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#############################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle_Bias_2d< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, AlphaBetaAdd, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/example/8_gemm_xdl_alpha_beta/gemm_xdl_alpha_beta.cpp b/example/8_gemm_xdl_alpha_beta/gemm_xdl_alpha_beta.cpp index 2a7b6991e2..51a31bcfb7 100644 --- a/example/8_gemm_xdl_alpha_beta/gemm_xdl_alpha_beta.cpp +++ b/example/8_gemm_xdl_alpha_beta/gemm_xdl_alpha_beta.cpp @@ -14,6 +14,7 @@ #include "device_base.hpp" #include "device_gemm_xdl_c_shuffle_bias_2d.hpp" #include "element_wise_operation.hpp" +#include "reference_gemm_bias_2d.hpp" template using S = ck::Sequence; @@ -72,43 +73,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl // clang-format on -template -static void host_verify(const Tensor& a_m_k, - const Tensor& b_k_n, - const Tensor& c0_k_n, - Tensor& c_m_n, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CElementwiseOperation& c_element_op) -{ - auto f_mk_kn_mn = [&](auto m, auto n) { - const int K = a_m_k.mDesc.GetLengths()[1]; - - AccDataType v = 0; - AccDataType a = 0; - AccDataType b = 0; - for(int k = 0; k < K; ++k) - { - a_element_op(a, a_m_k(m, k)); - b_element_op(b, b_k_n(k, n)); - v += a * b; - } - - CType y = static_cast(v); - - c_element_op(c_m_n(m, n), y, c0_k_n(m, n)); - }; - - make_ParallelTensorFunctor(f_mk_kn_mn, - c_m_n.mDesc.GetLengths()[0], - c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency()); -} +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBias2D; int main(int argc, char* argv[]) { @@ -259,13 +231,18 @@ int main(int argc, char* argv[]) if(do_verification) { - host_verify(a_m_k, - b_k_n, - c0_m_n, - c_m_n_host_result, - AElementOp{}, - BElementOp{}, - CElementOp{alpha, beta}); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k, + b_k_n, + c0_m_n, + c_m_n_host_result, + AElementOp{}, + BElementOp{}, + CElementOp{alpha, beta}); + + ref_invoker.Run(ref_argument); check_error(c_m_n_host_result, c_m_n_device_result); } diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index a25e64f5ba..18b7a89363 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -14,9 +14,10 @@ include_directories(BEFORE ) # ck_profiler -set(PROFILER_SOURCE +set(PROFILER_SOURCE src/profiler.cpp src/profile_gemm.cpp + src/profile_gemm_bias_2d.cpp src/profile_gemm_bias_relu.cpp src/profile_gemm_bias_relu_add.cpp src/profile_conv_fwd.cpp @@ -30,6 +31,7 @@ add_executable(ckProfiler ${PROFILER_SOURCE}) target_link_libraries(ckProfiler PRIVATE host_tensor) target_link_libraries(ckProfiler PRIVATE device_gemm_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_bias_2d_instance) target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance) target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance) diff --git a/profiler/include/profile_gemm_bias_2d_impl.hpp b/profiler/include/profile_gemm_bias_2d_impl.hpp new file mode 100644 index 0000000000..94223c4f7a --- /dev/null +++ b/profiler/include/profile_gemm_bias_2d_impl.hpp @@ -0,0 +1,311 @@ +#pragma once +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_conv.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "device_gemm.hpp" +#include "reference_gemm_bias_2d.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using DeviceGemmAlphaBetaPtr = ck::tensor_operation::device::DeviceGemmBiasPtr< + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AlphaBetaAdd>; + +void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances( + std::vector&); + +void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances( + std::vector&); + +void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances( + std::vector&); + +void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances( + std::vector&); + +void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances( + std::vector&); + +void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances( + std::vector&); + +void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances( + std::vector&); + +void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances( + std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +void profile_gemm_bias_2d_impl(int do_verification, + int init_method, + bool do_log, + int nrepeat, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + float alpha, + float beta) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c0_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c0_m_n: " << c0_m_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + std::size_t num_thread = std::thread::hardware_concurrency(); + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + c0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + c0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + + // set zero to c_device_buf + c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::AlphaBetaAdd; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{alpha, beta}; + + if(do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBias2D; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c0_m_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c0_device_buf(sizeof(C0DataType) * c0_m_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + c0_device_buf.ToDevice(c0_m_n.mData.data()); + c_device_buf.ToDevice(c_m_n_device_result.mData.data()); + + // add device GEMM instances + std::vector + gemm_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + } + } + else if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); + } + } + + if(gemm_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = + gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c0_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string gemm_name = gemm_ptr->GetTypeString(); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + check_error(c_m_n_host_result, c_m_n_device_result); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c0 : ", c0_m_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << "does not support this GEMM problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/profile_gemm_bias_2d.cpp b/profiler/src/profile_gemm_bias_2d.cpp new file mode 100644 index 0000000000..29fabb3579 --- /dev/null +++ b/profiler/src/profile_gemm_bias_2d.cpp @@ -0,0 +1,261 @@ +#include +#include +#include +#include +#include +#include +#include "profile_gemm_bias_2d_impl.hpp" + +enum GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + MK_KN_NM, // 4 + MK_NK_NM, // 5 + KM_KN_NM, // 6 + KM_NK_NM, // 7 +}; + +enum GemmDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 +}; + +int profile_gemm_bias_2d(int argc, char* argv[]) +{ + if(!(argc == 16 || argc == 17)) + { + printf("arg1: tensor operation (gemm: GEMM+Bias)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg7: run kernel # of times (>1)\n"); + printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); + printf("arg14: alpha\n"); + printf("arg15: beta\n"); + printf("arg16: split k into mulitiple batch\n"); + exit(1); + } + + const int data_type = static_cast(std::stoi(argv[2])); + const int layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const int nrepeat = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + + const float alpha = std::stof(argv[14]); + const float beta = std::stof(argv[15]); + + int KBatch = 1; + + if(argc == 17) + KBatch = std::stoi(argv[16]); + + if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_gemm_bias_2d_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + alpha, + beta); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_gemm_bias_2d_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + alpha, + beta); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_gemm_bias_2d_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + alpha, + beta); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_gemm_bias_2d_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + alpha, + beta); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_gemm_bias_2d_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + alpha, + beta); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_gemm_bias_2d_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + alpha, + beta); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_gemm_bias_2d_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + alpha, + beta); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_gemm_bias_2d_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + alpha, + beta); + } + else + { + throw std::runtime_error("wrong! this data_type & layout is not implemented"); + } + + return 1; +} diff --git a/profiler/src/profiler.cpp b/profiler/src/profiler.cpp index 399ea8ee4d..c6a5a4cbc9 100644 --- a/profiler/src/profiler.cpp +++ b/profiler/src/profiler.cpp @@ -7,6 +7,7 @@ int profile_gemm(int, char*[]); int profile_batched_gemm(int, char*[]); +int profile_gemm_bias_2d(int, char*[]); int profile_gemm_bias_relu(int, char*[]); int profile_gemm_bias_relu_add(int, char*[]); int profile_conv_fwd(int, char*[]); @@ -20,6 +21,10 @@ int main(int argc, char* argv[]) { return profile_gemm(argc, argv); } + else if(strcmp(argv[1], "gemm_bias_2d") == 0) + { + return profile_gemm_bias_2d(argc, argv); + } else if(strcmp(argv[1], "gemm_bias_relu") == 0) { return profile_gemm_bias_relu(argc, argv); @@ -52,6 +57,7 @@ int main(int argc, char* argv[]) { // clang-format off printf("arg1: tensor operation (gemm: GEMM\n" + " gemm_bias_2d: GEMM+Bias(2D)\n" " gemm_bias_relu: GEMM+Bias+ReLU\n" " gemm_bias_relu_add: GEMM+Bias+ReLU+Add\n" " conv_fwd: ForwardConvolution\n" diff --git a/reference_operation/include/reference_gemm_bias_2d.hpp b/reference_operation/include/reference_gemm_bias_2d.hpp new file mode 100644 index 0000000000..7dd6fc9199 --- /dev/null +++ b/reference_operation/include/reference_gemm_bias_2d.hpp @@ -0,0 +1,133 @@ +#ifndef REFERENCE_GEMM_BIAS_BIAS_2D_HPP +#define REFERENCE_GEMM_BIAS_BIAS_2D_HPP + +#include +#include +#include "device_base.hpp" +#include "host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceGemmBias2D : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& b_k_n, + const Tensor& c0_m_n, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_m_k_{a_m_k}, + b_k_n_{b_k_n}, + c0_m_n_{c0_m_n}, + c_m_n_{c_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_m_k_; + const Tensor& b_k_n_; + const Tensor& c0_m_n_; + Tensor& c_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceGemmBias2D::Argument; + + float Run(const Argument& arg) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_m_k_.mDesc.GetLengths()[1]; + + AccDataType a = 0; + AccDataType b = 0; + AccDataType acc = 0; + + for(int k = 0; k < K; ++k) + { + arg.a_element_op_(a, arg.a_m_k_(m, k)); + arg.b_element_op_(b, arg.b_k_n_(k, n)); + acc += a * b; + } + + CDataType cast_acc = static_cast(acc); + arg.c_element_op_(arg.c_m_n_(m, n), cast_acc, arg.c0_m_n_(m, n)); + }; + + make_ParallelTensorFunctor( + f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, int) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& b_k_n, + const Tensor& c0_m_n, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{a_m_k, b_k_n, c0_m_n, c_m_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemmBias2D" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck +#endif