From 4a7d8df85a71cdb62fe26430f3bf417347b9982c Mon Sep 17 00:00:00 2001 From: rocking5566 Date: Tue, 22 Mar 2022 04:59:51 +0800 Subject: [PATCH] Gemm_c_shuffle (4 layouts) X (fp32 bf16 int8) (#131) * [What] Separate fixpoint gemm from gemm example [Why] let example of gemm_int8 be pure gemm. [What] 1. Add gemm_requant_relu_requant, 2. Let CDataType be int32 in pure gemm, because no one uses int8 CDataType. It is also part of gemm_requant_relu_requant * Fix path * Revise cmakelist due to merge develop * Add gemm fp16 test * Extract PrepareGemmTensor * Extract TestGemm * Add test for different layout * Add 4 layouts of shuffle version of fp32 * Add 4 layouts of shuffle version of int8 * Add 4 layouts of shuffle version of bf16 * replace all DeviceGemmPtr_ with DeviceGemmNoOpPtr to fit naming convention * Add test for non-shuffle version of gemm * Fix typo * Print kernel information * Add rest of the fp32 kernel to the test * 1. Add rest of the fp16 device op. 2. Mark the invalid device operation Co-authored-by: rocking [ROCm/composable_kernel commit: 485ea46a40f6ed9310443a33541b494d042c57a8] --- .../gemm_xdl_alpha_beta.cpp | 10 +- .../03_gemm_bias_relu/gemm_xdl_bias_relu.cpp | 4 +- .../gemm_xdl_bias_relu_add.cpp | 6 +- .../gpu/gemm/CMakeLists.txt | 12 +- ...uffle_bf16_bf16_bf16_km_kn_mn_instance.cpp | 59 +++++ ...uffle_bf16_bf16_bf16_km_nk_mn_instance.cpp | 59 +++++ ...uffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp | 59 +++++ ..._shuffle_f32_f32_f32_km_kn_mn_instance.cpp | 58 +++++ ..._shuffle_f32_f32_f32_km_nk_mn_instance.cpp | 58 +++++ ..._shuffle_f32_f32_f32_mk_kn_mn_instance.cpp | 58 +++++ ..._shuffle_f32_f32_f32_mk_nk_mn_instance.cpp | 55 ++++ ...uffle_int8_int8_int8_km_kn_mn_instance.cpp | 58 +++++ ...uffle_int8_int8_int8_km_nk_mn_instance.cpp | 58 +++++ ...uffle_int8_int8_int8_mk_kn_mn_instance.cpp | 58 +++++ ...l_splitk_f16_f16_f16_mk_nk_mn_instance.cpp | 23 +- profiler/include/profile_gemm_impl.hpp | 80 +++++- profiler/src/profile_gemm.cpp | 120 +++++++++ 
test/gemm/CMakeLists.txt | 4 + test/gemm/gemm_bf16.cpp | 179 +++++-------- test/gemm/gemm_fp16.cpp | 154 +++++++++++ test/gemm/gemm_fp32.cpp | 196 +++++++------- test/gemm/gemm_int8.cpp | 167 ++++++------ test/gemm/gemm_util.hpp | 241 ++++++++++++++++++ test/include/test_util.hpp | 43 ++++ 24 files changed, 1497 insertions(+), 322 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp create mode 100644 test/gemm/gemm_fp16.cpp diff --git a/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp b/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp index 51a31bcfb7..bd937cdc07 100644 --- a/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp +++ b/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp @@ -157,9 +157,9 @@ int main(int argc, char* argv[]) Tensor 
a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor c0_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c0_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; @@ -172,12 +172,12 @@ int main(int argc, char* argv[]) case 1: a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - c0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + c0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; default: a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - c0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + c0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); } DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); diff --git a/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp b/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp index 4dc8d0b788..b4739ed47a 100644 --- a/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp +++ b/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp @@ -139,8 +139,8 @@ int main(int argc, char* argv[]) Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor 
c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // c0_n[n] Tensor c0_n(HostTensorDescriptor( diff --git a/example/04_gemm_bias_relu_add/gemm_xdl_bias_relu_add.cpp b/example/04_gemm_bias_relu_add/gemm_xdl_bias_relu_add.cpp index 3ce7e9848b..671cfd014f 100644 --- a/example/04_gemm_bias_relu_add/gemm_xdl_bias_relu_add.cpp +++ b/example/04_gemm_bias_relu_add/gemm_xdl_bias_relu_add.cpp @@ -141,15 +141,15 @@ int main(int argc, char* argv[]) Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // c0_n[n] Tensor c0_n(HostTensorDescriptor( std::vector({static_cast(N)}), std::vector({1}))); // c1_m_n[m ,n] - Tensor c1_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c1_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index 642df74a3d..5f057adcc5 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -8,12 +8,22 @@ set(DEVICE_GEMM_INSTANCE_SOURCE device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp; device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp; device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp; device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp; + 
device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp; device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp; device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp; device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp; device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp; device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp; device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp; device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp; device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp; @@ -25,7 +35,7 @@ set(DEVICE_GEMM_INSTANCE_SOURCE device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp; ) -add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) +add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) target_compile_features(device_gemm_instance PUBLIC) set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..dceb797302 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp @@ -0,0 +1,59 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + 
+namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances = std::tuple< + // clang-format off + //#####################| AData| BData| CData| AccData| CShuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| Type| Type| Type| Type| DataType| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 
32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 32, 4, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 
32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..33e33b4988 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp @@ -0,0 +1,59 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances = std::tuple< + // clang-format off + //#####################| AData| BData| CData| AccData| CShuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| Type| Type| Type| Type| DataType| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, 
BF16, BF16, F32, BF16, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, 
BF16, BF16, F32, BF16, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace 
tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..319db8ea7f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,59 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#####################| AData| BData| CData| AccData| CShuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| Type| Type| Type| Type| DataType| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| 
ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 
64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..d0b9fad3ff --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,58 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace 
device_gemm_instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances = std::tuple< + // clang-format off + //#####################|AData| BData| CData| AccData| CShuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| Type| Type| Type| Type| DataType| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, 
F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 32, 4, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, 
Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 16, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Row, Row, PassThrough, 
PassThrough, PassThrough, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..b6d2b5c285 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,58 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances = std::tuple< + // clang-format off + //#####################|AData| BData| CData| AccData| CShuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| 
CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| Type| Type| Type| Type| DataType| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 2, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 32, 2, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 
32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 32, 2, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 32, 8, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 32, 2, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 32, 8, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 32, 2, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 32, 8, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 32, 2, 4, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + 
DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 32, 8, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 32, 2, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 32, 8, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 2, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 8, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 
index 0000000000..551a9afb03 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,58 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances = std::tuple< + // clang-format off + //#####################|AData| BData| CData| AccData| CShuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| Type| Type| Type| Type| DataType| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + 
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 4, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 32, 4, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 32, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 32, 4, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 32, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 32, 4, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<8, 32, 
1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 32, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 32, 4, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 32, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 32, 4, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 32, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 32, 4, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 16, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 
8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 4, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..08b6e53c14 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,55 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances = 
std::tuple< + // clang-format off + //#####################|AData| BData| CData| AccData| CShuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| Type| Type| Type| Type| DataType| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 32, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 32, 4, 4, 32, 
32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 32, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 32, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 32, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 32, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 
0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 32, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4>, + DeviceGemmXdl_C_Shuffle< F32, F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 32, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..01a2b4c164 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp @@ -0,0 +1,58 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = 
float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances = std::tuple< + // clang-format off + //#####################| AData| BData| CData| AccData| CShuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| Type| Type| Type| Type| DataType| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 
1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, 
S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..a8be534fa1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instance.cpp @@ -0,0 +1,58 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances = std::tuple< + // clang-format off + //#####################| AData| BData| CData| AccData| CShuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| Type| Type| Type| 
Type| DataType| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + 
DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..c3752e2603 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp @@ -0,0 +1,58 @@ +#include 
+#include "config.hpp" +#include "device_gemm_xdl_c_shuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances = std::tuple< + // clang-format off + //#####################| AData| BData| CData| AccData| CShuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| Type| Type| Type| Type| DataType| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 
1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 
1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp index 7b79639b4e..4b3524c30e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp @@ -45,15 +45,15 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple< // clang-format on >; -using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| 
ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, true, 1, 9, S<1, 2, 1, 72>, 2> - // clang-format on - >; +// using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple< +// // clang-format off +// //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +// //#########################| 
| | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +// //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +// DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, true, 1, 9, S<1, 2, 1, 72>, 2> +// // clang-format on +// >; void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances( std::vector>& instances) @@ -61,8 +61,9 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances( add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances{}); - add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instances{}); + // FIXME - IsSupportedArgument() is false, need to check validity + // add_device_operation_instances( + // instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instances{}); } } // namespace device_gemm_instance diff --git a/profiler/include/profile_gemm_impl.hpp b/profiler/include/profile_gemm_impl.hpp index 30778351fa..409293a22a 100644 --- a/profiler/include/profile_gemm_impl.hpp +++ b/profiler/include/profile_gemm_impl.hpp @@ -26,16 +26,28 @@ void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector&); void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector&); void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( std::vector&); +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( + std::vector&); +void 
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( + std::vector&); void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(std::vector&); void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector&); void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector&); void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances( + std::vector&); void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances( std::vector&); +void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances( + std::vector&); void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( std::vector&); @@ -45,6 +57,11 @@ void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector&); void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(std::vector&); + void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector&); void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector&); void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector&); @@ -127,11 +144,6 @@ void profile_gemm_impl(int do_verification, const auto b_element_op = BElementOp{}; const auto c_element_op = CElementOp{}; - // if(do_verification) - // { - - // } - DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); @@ 
-159,6 +171,9 @@ void profile_gemm_impl(int do_verification, { ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); } } else if constexpr(is_same::value && @@ -174,6 +189,9 @@ void profile_gemm_impl(int do_verification, { ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); } } else if constexpr(is_same::value && @@ -189,6 +207,9 @@ void profile_gemm_impl(int do_verification, { ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); } } else if constexpr(is_same::value && @@ -204,6 +225,9 @@ void profile_gemm_impl(int do_verification, { ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); } } } @@ -291,23 +315,65 @@ void profile_gemm_impl(int do_verification, is_same::value) { if constexpr(is_same::value && - is_same::value && + is_same::value && is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) { ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(gemm_ptrs); } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + 
ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(gemm_ptrs); + } } else if constexpr(is_same::value && is_same::value && is_same::value) { if constexpr(is_same::value && - is_same::value && + is_same::value && is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) { ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemm_ptrs); } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(gemm_ptrs); + } } if(gemm_ptrs.size() <= 0) diff --git a/profiler/src/profile_gemm.cpp b/profiler/src/profile_gemm.cpp index d85eec1465..1cae0ded9e 100644 --- a/profiler/src/profile_gemm.cpp +++ b/profiler/src/profile_gemm.cpp @@ -223,6 +223,26 @@ int profile_gemm(int argc, char* argv[]) (StrideC < 0) ? N : StrideC, KBatch); } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? 
N : StrideC, + KBatch); + } else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN) { ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + KBatch); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + KBatch); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + KBatch); + } else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN) { ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + KBatch); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_gemm_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + KBatch); + } else { throw std::runtime_error("wrong! 
this GEMM data_type & layout is not implemented"); diff --git a/test/gemm/CMakeLists.txt b/test/gemm/CMakeLists.txt index 65f56bbd5a..83b3c1e2e3 100644 --- a/test/gemm/CMakeLists.txt +++ b/test/gemm/CMakeLists.txt @@ -2,6 +2,10 @@ add_test_executable(test_gemm_fp32 gemm_fp32.cpp) target_link_libraries(test_gemm_fp32 PRIVATE host_tensor) target_link_libraries(test_gemm_fp32 PRIVATE device_gemm_instance) +add_test_executable(test_gemm_fp16 gemm_fp16.cpp) +target_link_libraries(test_gemm_fp16 PRIVATE host_tensor) +target_link_libraries(test_gemm_fp16 PRIVATE device_gemm_instance) + add_test_executable(test_gemm_bf16 gemm_bf16.cpp) target_link_libraries(test_gemm_bf16 PRIVATE host_tensor) target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance) diff --git a/test/gemm/gemm_bf16.cpp b/test/gemm/gemm_bf16.cpp index b6d54fcae8..b60a496218 100644 --- a/test/gemm/gemm_bf16.cpp +++ b/test/gemm/gemm_bf16.cpp @@ -23,7 +23,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using DeviceGemmPtr_ = +using DeviceGemmNoOpPtr = ck::tensor_operation::device::DeviceGemmPtr; @@ -32,131 +32,80 @@ namespace ck { namespace tensor_operation { namespace device { namespace device_gemm_instance { -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(std::vector&); -} +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(std::vector&); +} // namespace device_gemm_instance } // namespace device } // namespace tensor_operation } // namespace ck -namespace { - -using BF16 = ck::bhalf_t; - -using ADataType = BF16; -using BDataType = BF16; -using CDataType = BF16; -using AccDataType = float; - -using ALayout = ck::tensor_layout::gemm::RowMajor; -using BLayout = 
ck::tensor_layout::gemm::ColumnMajor; -using CLayout = ck::tensor_layout::gemm::RowMajor; - -auto PrepareGemmTensor(const ck::gemm_util::GemmParams& params) -{ - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(std::is_same::value) - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({stride, 1})); - } - else - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({1, stride})); - } - }; - - // use fp32 host kernel to verify bf16 device kernel - Tensor a_m_k_bf16( - f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); - Tensor b_k_n_bf16( - f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); - Tensor c_m_n_device_bf16( - f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); - - Tensor a_m_k_fp32( - f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); - Tensor b_k_n_fp32( - f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); - Tensor c_m_n_host_fp32( - f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); - Tensor c_m_n_device_fp32( - f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); - - a_m_k_bf16.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - b_k_n_bf16.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - - bf16_to_f32_(a_m_k_bf16, a_m_k_fp32); - bf16_to_f32_(b_k_n_bf16, b_k_n_fp32); - - return std::make_tuple(a_m_k_bf16, - b_k_n_bf16, - c_m_n_device_bf16, - a_m_k_fp32, - b_k_n_fp32, - c_m_n_host_fp32, - c_m_n_device_fp32); -} - -bool TestGemm(DeviceGemmPtr_& gemmPtr) -{ - // Arrange - ck::gemm_util::GemmParams params; - params.M = 1024; - params.N = 1024; - params.K = 1024; - params.StrideA = 1024; - params.StrideB = 1024; - params.StrideC = 1024; - - auto host_tensors = PrepareGemmTensor(params); - const Tensor& a_bf16 = std::get<0>(host_tensors); - const Tensor& b_bf16 = std::get<1>(host_tensors); - Tensor& 
c_device_bf16 = std::get<2>(host_tensors); - Tensor& a_fp32 = std::get<3>(host_tensors); - Tensor& b_fp32 = std::get<4>(host_tensors); - Tensor& c_host_fp32 = std::get<5>(host_tensors); - Tensor& c_device_fp32 = std::get<6>(host_tensors); - - auto a_element_op = PassThrough{}; - auto b_element_op = PassThrough{}; - auto c_element_op = PassThrough{}; - - // use fp32 host kernel to verify bf16 device kernel - using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; - ck::gemm_util::RunHostGEMM( - a_fp32, b_fp32, c_host_fp32, a_element_op, b_element_op, c_element_op); - - // Act - ck::gemm_util::RunDeviceGEMM( - gemmPtr, params, a_bf16, b_bf16, c_device_bf16, a_element_op, b_element_op, c_element_op); - - bf16_to_f32_(c_device_bf16, c_device_fp32); - - // Assert - bool res = test_util::check_err( - c_device_fp32.mData, c_host_fp32.mData, "Error: incorrect results!", 1e-2f, 1e-3f); - - std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; - - return res; -} - -} // anonymous namespace - int main() { - std::vector gemmPtrs; - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(gemmPtrs); + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; bool res = true; + std::vector gemmPtrs; + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(gemmPtrs); for(auto& gemmPtr : gemmPtrs) { - res &= TestGemm(gemmPtr); + res &= ck::gemm_util::TestGemmBF16{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemmBF16{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(gemmPtrs); + + 
for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemmBF16{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemmBF16{}(gemmPtr); } std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; diff --git a/test/gemm/gemm_fp16.cpp b/test/gemm/gemm_fp16.cpp new file mode 100644 index 0000000000..4ed85d170d --- /dev/null +++ b/test/gemm/gemm_fp16.cpp @@ -0,0 +1,154 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "gemm_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_c_shuffle.hpp" +#include "element_wise_operation.hpp" +#include "gemm_specialization.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector&); + +void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(std::vector&); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector&); +void 
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(std::vector&); + +void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( + std::vector&); +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +int main() +{ + using ADataType = ck::half_t; + using BDataType = ck::half_t; + using CDataType = ck::half_t; + + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; + + bool res = true; + std::vector gemmPtrs; + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemmPtrs); + + for(auto& 
gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; +} diff --git a/test/gemm/gemm_fp32.cpp b/test/gemm/gemm_fp32.cpp index a4cae6db2b..7f73296545 100644 --- a/test/gemm/gemm_fp32.cpp +++ b/test/gemm/gemm_fp32.cpp @@ -23,7 +23,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using DeviceGemmPtr_ = +using DeviceGemmNoOpPtr = ck::tensor_operation::device::DeviceGemmPtr; @@ -32,106 +32,122 @@ namespace ck { namespace tensor_operation { namespace device { namespace device_gemm_instance { -void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector&); -} +void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector&); + +void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector&); + +void 
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(std::vector&); + +} // namespace device_gemm_instance } // namespace device } // namespace tensor_operation } // namespace ck -namespace { - -using ADataType = float; -using BDataType = float; -using CDataType = float; -using AccDataType = float; - -using ALayout = ck::tensor_layout::gemm::RowMajor; -using BLayout = ck::tensor_layout::gemm::ColumnMajor; -using CLayout = ck::tensor_layout::gemm::RowMajor; - -auto PrepareGemmTensor(const ck::gemm_util::GemmParams& params) -{ - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(std::is_same::value) - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({stride, 1})); - } - else - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({1, stride})); - } - }; - - Tensor a_m_k( - f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); - Tensor b_k_n( - f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); - Tensor c_m_n_host_result( - f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); - Tensor c_m_n_device_result( - f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); - - a_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - - return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result); -} - -bool TestGemm(DeviceGemmPtr_& gemmPtr) -{ - // Arrange - ck::gemm_util::GemmParams params; - params.M = 1024; - params.N = 1024; - params.K = 1024; - params.StrideA = 1024; - params.StrideB = 1024; - params.StrideC = 1024; - - auto host_tensors = PrepareGemmTensor(params); - 
const Tensor& a = std::get<0>(host_tensors); - const Tensor& b = std::get<1>(host_tensors); - Tensor& c_host = std::get<2>(host_tensors); - Tensor& c_device = std::get<3>(host_tensors); - - auto a_element_op = PassThrough{}; - auto b_element_op = PassThrough{}; - auto c_element_op = PassThrough{}; - - using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; - ck::gemm_util::RunHostGEMM( - a, b, c_host, a_element_op, b_element_op, c_element_op); - - // Act - ck::gemm_util::RunDeviceGEMM( - gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op); - - // Assert - bool res = test_util::check_err( - c_device.mData, c_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f); - - std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; - - return res; -} - -} // anonymous namespace - int main() { - std::vector gemmPtrs; - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemmPtrs); + using ADataType = float; + using BDataType = float; + using CDataType = float; + + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; bool res = true; + std::vector gemmPtrs; + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemmPtrs); for(auto& gemmPtr : gemmPtrs) { - res &= TestGemm(gemmPtr); + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemmPtrs); + 
ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemmPtrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); } std::cout << "TestGemm ..... " << (res ? 
"SUCCESS" : "FAILURE") << std::endl; diff --git a/test/gemm/gemm_int8.cpp b/test/gemm/gemm_int8.cpp index 464689bf16..0f4f1cbf01 100644 --- a/test/gemm/gemm_int8.cpp +++ b/test/gemm/gemm_int8.cpp @@ -23,7 +23,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using DeviceGemmPtr_ = +using DeviceGemmNoOpPtr = ck::tensor_operation::device::DeviceGemmPtr; @@ -32,105 +32,96 @@ namespace ck { namespace tensor_operation { namespace device { namespace device_gemm_instance { -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(std::vector&); } } // namespace device } // namespace tensor_operation } // namespace ck -namespace { - -using ADataType = int8_t; -using BDataType = int8_t; -using CDataType = int8_t; -using AccDataType = int32_t; - -using ALayout = ck::tensor_layout::gemm::RowMajor; -using BLayout = ck::tensor_layout::gemm::ColumnMajor; -using CLayout = ck::tensor_layout::gemm::RowMajor; - -auto PrepareGemmTensor(const ck::gemm_util::GemmParams& params) -{ - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(std::is_same::value) - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({stride, 1})); - } - else - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({1, stride})); - } - }; - - Tensor a_m_k( - f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); - Tensor b_k_n( - f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); - Tensor c_m_n_host_result( - f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); - Tensor 
c_m_n_device_result( - f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); - - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - - return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result); -} - -bool TestGemm(DeviceGemmPtr_& gemmPtr) -{ - // Arrange - ck::gemm_util::GemmParams params; - params.M = 1024; - params.N = 1024; - params.K = 1024; - params.StrideA = 1024; - params.StrideB = 1024; - params.StrideC = 1024; - - auto host_tensors = PrepareGemmTensor(params); - const Tensor& a = std::get<0>(host_tensors); - const Tensor& b = std::get<1>(host_tensors); - Tensor& c_host = std::get<2>(host_tensors); - Tensor& c_device = std::get<3>(host_tensors); - - auto a_element_op = PassThrough{}; - auto b_element_op = PassThrough{}; - auto c_element_op = PassThrough{}; - - using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; - ck::gemm_util::RunHostGEMM( - a, b, c_host, a_element_op, b_element_op, c_element_op); - - // Act - ck::gemm_util::RunDeviceGEMM( - gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op); - - // Assert - bool res = test_util::check_err(c_device.mData, c_host.mData, "Error: incorrect results!"); - - std::cout << (res ? 
"SUCCESS" : "FAILURE") << std::endl; - - return res; -} - -} // anonymous namespace - int main() { - std::vector gemmPtrs; - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemmPtrs); + using ADataType = int8_t; + using BDataType = int8_t; + using CDataType = int8_t; + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; + + std::vector gemmPtrs; bool res = true; + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(gemmPtrs); + for(auto& gemmPtr : gemmPtrs) { - res &= TestGemm(gemmPtr); + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); } std::cout << "TestGemm ..... " << (res ? 
"SUCCESS" : "FAILURE") << std::endl; diff --git a/test/gemm/gemm_util.hpp b/test/gemm/gemm_util.hpp index b7177545af..14d532defc 100644 --- a/test/gemm/gemm_util.hpp +++ b/test/gemm/gemm_util.hpp @@ -4,6 +4,10 @@ #include "config.hpp" #include "device.hpp" #include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "reference_gemm.hpp" +#include "tensor_layout.hpp" +#include "test_util.hpp" namespace ck { namespace gemm_util { @@ -98,6 +102,243 @@ void RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, c_m_n_device_buf.FromDevice(C.mData.data()); } +template +struct TestGemm +{ + auto PrepareGemmTensor(const ck::gemm_util::GemmParams& params) + { + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k( + f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); + Tensor b_k_n( + f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); + Tensor c_m_n_host_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor c_m_n_device_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + + auto f_generate_tensor_value = [](auto desc, auto type) { + using dataType = decltype(type); + + if(std::is_same::value) + { + desc.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + } + else + { + desc.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + }; + + f_generate_tensor_value(a_m_k, ADataType{}); + f_generate_tensor_value(b_k_n, BDataType{}); + + return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result); + } + + auto operator()(DeviceGemmPtr_& gemmPtr) + { + std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name + << ", CLayout = " << CLayout{}.name 
<< std::endl; + std::cout << gemmPtr->GetTypeString() << std::endl; + + // Arrange + ck::gemm_util::GemmParams params; + params.M = 1024; + params.N = 1024; + params.K = 1024; + params.StrideA = 1024; + params.StrideB = 1024; + params.StrideC = 1024; + + auto host_tensors = PrepareGemmTensor(params); + + const Tensor& a = std::get<0>(host_tensors); + const Tensor& b = std::get<1>(host_tensors); + Tensor& c_host = std::get<2>(host_tensors); + Tensor& c_device = std::get<3>(host_tensors); + + auto a_element_op = AElementwiseOperation{}; + auto b_element_op = BElementwiseOperation{}; + auto c_element_op = CElementwiseOperation{}; + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemm; + ck::gemm_util::RunHostGEMM( + a, b, c_host, a_element_op, b_element_op, c_element_op); + + // Act + ck::gemm_util::RunDeviceGEMM( + gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op); + + // Assert + bool res = false; + if(std::is_same::value) + { + res = test_util::check_err(c_device.mData, c_host.mData, "Error: incorrect results!"); + + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + res = test_util::check_err(c_device.mData, c_host.mData, "Error: incorrect results!"); + + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + res = test_util::check_err(c_device.mData, c_host.mData, "Error: incorrect results!"); + + std::cout << (res ? 
"SUCCESS" : "FAILURE") << std::endl; + } + + return res; + } +}; + +template +struct TestGemmBF16 +{ + using BF16 = ck::bhalf_t; + + auto PrepareGemmTensorBF16(const ck::gemm_util::GemmParams& params) + { + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + // use fp32 host kernel to verify bf16 device kernel + Tensor a_m_k_bf16( + f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); + Tensor b_k_n_bf16( + f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); + Tensor c_m_n_device_bf16( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + + Tensor a_m_k_fp32( + f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); + Tensor b_k_n_fp32( + f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); + Tensor c_m_n_host_fp32( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor c_m_n_device_fp32( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + + a_m_k_bf16.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b_k_n_bf16.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + + bf16_to_f32_(a_m_k_bf16, a_m_k_fp32); + bf16_to_f32_(b_k_n_bf16, b_k_n_fp32); + + return std::make_tuple(a_m_k_bf16, + b_k_n_bf16, + c_m_n_device_bf16, + a_m_k_fp32, + b_k_n_fp32, + c_m_n_host_fp32, + c_m_n_device_fp32); + } + + auto operator()(DeviceGemmPtr_& gemmPtr) + { + // Arrange + ck::gemm_util::GemmParams params; + params.M = 1024; + params.N = 1024; + params.K = 1024; + params.StrideA = 1024; + params.StrideB = 1024; + params.StrideC = 1024; + + auto host_tensors = PrepareGemmTensorBF16(params); + const Tensor& a_bf16 = std::get<0>(host_tensors); + const 
Tensor& b_bf16 = std::get<1>(host_tensors); + Tensor& c_device_bf16 = std::get<2>(host_tensors); + Tensor& a_fp32 = std::get<3>(host_tensors); + Tensor& b_fp32 = std::get<4>(host_tensors); + Tensor& c_host_fp32 = std::get<5>(host_tensors); + Tensor& c_device_fp32 = std::get<6>(host_tensors); + + auto a_element_op = AElementwiseOperation{}; + auto b_element_op = BElementwiseOperation{}; + auto c_element_op = CElementwiseOperation{}; + + // use fp32 host kernel to verify bf16 device kernel + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemm; + ck::gemm_util::RunHostGEMM( + a_fp32, b_fp32, c_host_fp32, a_element_op, b_element_op, c_element_op); + + // Act + ck::gemm_util::RunDeviceGEMM(gemmPtr, + params, + a_bf16, + b_bf16, + c_device_bf16, + a_element_op, + b_element_op, + c_element_op); + + bf16_to_f32_(c_device_bf16, c_device_fp32); + + // Assert + bool res = test_util::check_err( + c_device_fp32.mData, c_host_fp32.mData, "Error: incorrect results!", 1e-2f, 1e-3f); + + std::cout << (res ? 
"SUCCESS" : "FAILURE") << std::endl; + + return res; + }; +}; + } // namespace gemm_util } // namespace ck #endif diff --git a/test/include/test_util.hpp b/test/include/test_util.hpp index f779c3dd1d..f18055879c 100644 --- a/test/include/test_util.hpp +++ b/test/include/test_util.hpp @@ -54,6 +54,49 @@ check_err(const std::vector& out, return res; } +bool check_err(const std::vector<_Float16>& out, + const std::vector<_Float16>& ref, + const std::string& msg, + _Float16 rtol = static_cast<_Float16>(1e-3f), + _Float16 atol = static_cast<_Float16>(1e-3f)) +{ + if(out.size() != ref.size()) + { + std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl + << msg << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits<_Float16>::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + double out_ = double(out[i]); + double ref_ = double(ref[i]); + err = std::abs(out_ - ref_); + if(err > atol + rtol * std::abs(ref_) || !std::isfinite(out_) || !std::isfinite(ref_)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref[" + << i << "]: " << out_ << "!=" << ref_ << std::endl + << msg << std::endl; + } + res = false; + } + } + if(!res) + { + std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + } + return res; +} + template typename std::enable_if::value, bool>::type check_err( const std::vector& out, const std::vector& ref, const std::string& msg, T = 0, T = 0)