From 276dfdd4574fb047af4c1ca5dddfbc4770e8712f Mon Sep 17 00:00:00 2001
From: Anthony Chang
Date: Fri, 28 Oct 2022 04:58:20 +0800
Subject: [PATCH] Input/output permutation for fused attention (#460)

* reopen masking attention instance now that CI is upgraded
* re-enable instances that previously failed on 9110
* enable ksize-kpadding pair validity test
* add non-masked attention+permute test; expose masking boolean to attention kernel handles
* disable bench
* fix test
* move files
* bulk rename batched_gemm_masking_scale_softmax_gemm_permute to batched_gemm_softmax_gemm_permute
* format
* amend rename
* disable bench in test
* add mask/no-mask test for non-permute attention kernels
* disable broken kernel instance
* example working
* add non-permuted problem statement to evaluate whether overhead comes from permutation or the extra kernel arg
* interface for bias addition, without implementing it
* test and profiler running
* tidy
* mask type determined by enum class
* unify example code
* move masking specialization to its own header
* align formats
* extract helper functions
* experiment with merging dims for attn w/ permute; shows perf parity with attn wo/ permute
* add tensor specialization to template args, since tensor spec 'packed' shows perf parity when permutation isn't needed
* remove redundant template args
* comment on 'packed' tensor specialization
* grouped attention with input/output permute example
* format
* clean up
* refactor acc0 tile visitor

Co-authored-by: shaojiewang
Co-authored-by: Chao Liu

[ROCm/composable_kernel commit: de37550f728ea27c683be3f367547db80cba68a8]
---
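Reviewer orientation: the new masking_specialization.hpp is not reproduced in full
in this patch excerpt, so the following is a minimal sketch (an assumption, not the
header's actual contents) of the enum-class mask type and of the predicate shape the
examples exercise via DeviceGemmInstance::C0MatrixMask(N).IsMaskedElement(m, n):

// Hedged sketch only; mirrors observed usage in the examples below,
// not the real header.
#include <cstdint>

enum class MaskingSpecialization
{
    MaskDisabled,        // plain softmax over all N scores
    MaskOutUpperTriangle // causal attention: keep scores with n <= m
};

struct C0MatrixMask // behavior sketched for the MaskOutUpperTriangle case
{
    explicit C0MatrixMask(int32_t n_total) : n_total_(n_total) {}

    // true => the Gemm0 score at (m, n) is set to -inf before softmax
    bool IsMaskedElement(int32_t m, int32_t n) const { return n > m; }

    int32_t n_total_;
};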
 .../CMakeLists.txt                            |   2 +
 ...le_scale_softmax_gemm_permute_xdl_fp16.cpp | 306 +-------
 ...mm_scale_softmax_gemm_permute_xdl_fp16.cpp | 296 +-------
 ...le_scale_softmax_gemm_permute_xdl_fp16.cpp | 159 ++++
 ...mm_scale_softmax_gemm_permute_xdl_fp16.cpp | 341 +--------
 ...atched_gemm_scale_softmax_gemm_permute.inc | 262 +++++++
 ...rouped_gemm_scale_softmax_gemm_permute.inc | 319 ++++++++
 include/ck/ck.hpp                             |   5 +
 .../tensor_space_filling_curve.hpp            |  10 +-
 .../gpu/block/blockwise_gemm_xdlops.hpp       |  36 +
 .../device_batched_gemm_softmax_gemm.hpp      |   3 +-
 ...vice_batched_gemm_softmax_gemm_permute.hpp |  67 +-
 ...vice_grouped_gemm_softmax_gemm_permute.hpp |  44 +-
 ...gemm_softmax_gemm_permute_xdl_cshuffle.hpp | 548 ++++++--------
 ...ed_contraction_multiple_d_xdl_cshuffle.hpp |   9 +-
 ...gemm_softmax_gemm_permute_xdl_cshuffle.hpp | 686 ++++++++----------
 ...batched_gemm_softmax_gemm_xdl_cshuffle.hpp |  31 +-
 .../gpu/device/masking_specialization.hpp     |  82 +++
 ...ched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp | 159 ++--
 .../tensor_operation/gpu/warp/xdlops_gemm.hpp |  13 +-
 .../transform_contraction_to_gemm.hpp         | 288 ++++++++
 ...emm_masking_scale_softmax_gemm_permute.hpp | 100 ---
 .../gpu/batched_gemm_softmax_gemm.hpp         |  40 +-
 .../gpu/batched_gemm_softmax_gemm_permute.hpp | 129 ++++
 .../gpu/CMakeLists.txt                        |   1 +
 ...6_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp |   8 +-
 ...6_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp |   6 +-
 .../CMakeLists.txt                            |   4 -
 ...6_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp |  85 ---
 ...6_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp |  80 +-
 .../CMakeLists.txt                            |   4 +
 ...6_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp | 133 ++++
 ...profile_batched_gemm_softmax_gemm_impl.hpp |  24 +-
 ...atched_gemm_softmax_gemm_permute_impl.hpp} | 263 ++++---
 test/CMakeLists.txt                           |   2 +-
 .../CMakeLists.txt                            |   5 -
 .../test_batched_gemm_softmax_gemm_fp16.cpp   |  16 +-
 .../test_batched_gemm_softmax_gemm_util.hpp   |  20 +-
 .../CMakeLists.txt                            |   5 +
 ...atched_gemm_softmax_gemm_permute_fp16.cpp} |  63 +-
 ...atched_gemm_softmax_gemm_permute_util.hpp} | 119 +--
 .../space_filling_curve.cpp                   |  77 +-
 42 files changed, 2654 insertions(+), 2196 deletions(-)
 create mode 100644 example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
 create mode 100644 example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc
 create mode 100644 example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc
 create mode 100644 include/ck/tensor_operation/gpu/device/masking_specialization.hpp
 create mode 100644 include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp
 delete mode 100644 library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute.hpp
 create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp
 delete mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute/CMakeLists.txt
 delete mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute/device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
 create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt
 create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
 rename profiler/include/{profile_batched_gemm_masking_scale_softmax_gemm_permute_impl.hpp => profile_batched_gemm_softmax_gemm_permute_impl.hpp} (55%)
 delete mode 100644 test/batched_gemm_masking_scale_softmax_gemm_permute/CMakeLists.txt
 create mode 100644 test/batched_gemm_softmax_gemm_permute/CMakeLists.txt
 rename test/{batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_fp16.cpp => batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp} (55%)
 rename test/{batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp => batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_util.hpp} (52%)

diff --git a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
index b43a810458..37187676b5 100644
--- a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
+++ b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
@@ -2,9 +2,11 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_
 add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
+add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)

 add_custom_target(example_gemm_scale_softmax_gemm)
 add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
 add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16)
 add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16)
 add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
+add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
index 20294bccf1..644adf300e 100644
--- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
+++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
@@ -33,9 +33,6 @@ using S = ck::Sequence;
 using F16 = ck::half_t;
 using F32 = float;

-using Row = ck::tensor_layout::gemm::RowMajor;
-using Col = ck::tensor_layout::gemm::ColumnMajor;
-
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

 using ADataType = F16;
@@ -44,13 +41,14 @@ using B1DataType = F16;
 using AccDataType = F32;
 using CShuffleDataType = F32;
 using CDataType = F16;
+using Acc0BiasDataType = ck::Tuple<>;
+using Acc1BiasDataType = ck::Tuple<>;

-using ALayout = Row;
-using B0Layout = Col;
-using B1Layout = Row;
-
-using CPermuteNumDims_G_M_O =
-    S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
+static constexpr ck::index_t NumDimG = 2;
+static constexpr ck::index_t NumDimM = 1;
+static constexpr ck::index_t NumDimN = 1;
+static constexpr ck::index_t NumDimK = 1;
+static constexpr ck::index_t NumDimO = 1;

 using AElementOp = PassThrough;
 using B0ElementOp = PassThrough;
@@ -59,17 +57,27 @@ using B1ElementOp = PassThrough;
 using CElementOp = PassThrough;

 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
+static constexpr auto MaskingSpec =
+    ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle;
+
+static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
+static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
+static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
+static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;

 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
-        ALayout,
-        B0Layout,
-        B1Layout,
-        CPermuteNumDims_G_M_O,
+        NumDimG,
+        NumDimM,
+        NumDimN,
+        NumDimK,
+        NumDimO,
         ADataType,
         B0DataType,
         B1DataType,
         CDataType,
+        Acc0BiasDataType,
+        Acc1BiasDataType,
         AccDataType,
         CShuffleDataType,
         AElementOp,
@@ -78,6 +86,10 @@ using DeviceGemmInstance =
         B1ElementOp,
         CElementOp,
         GemmSpec,
+        TensorSpecA,
+        TensorSpecB0,
+        TensorSpecB1,
+        TensorSpecC,
         1,
         256,
         128, // MPerBlock
@@ -118,7 +130,7 @@ using DeviceGemmInstance =
         2,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-        true>;          // MaskOutUpperTriangle
+        MaskingSpec>;   // MaskingSpecialization

 // Ref Gemm0: fp16 in, fp32 out
 using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm;
-int main(int argc, char* argv[])
-{
-    bool do_verification = true;
-    int init_method = 1;
-    bool time_kernel = false;
+#include "run_batched_gemm_scale_softmax_gemm_permute.inc"

-    // GEMM shape for A/B0/B1/C
-    // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
-    ck::index_t M = 512;
-    ck::index_t N = 512;
-    ck::index_t K = 64;
-    ck::index_t O = 128;
-    ck::index_t StrideA = -1;
-    ck::index_t StrideB0 = -1;
-    ck::index_t StrideB1 = -1;
-    ck::index_t BatchStrideA = -1;
-    ck::index_t BatchStrideB0 = -1;
-    ck::index_t BatchStrideB1 = -1;
-    float alpha = 1;
-
-    // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
-    // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
-    // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
-    ck::index_t G0 = 7;
-    ck::index_t G1 = 13;
-
-    if(argc == 1)
-    {
-        // use default case
-    }
-    else if(argc == 4)
-    {
-        do_verification = std::stoi(argv[1]);
-        init_method = std::stoi(argv[2]);
-        time_kernel = std::stoi(argv[3]);
-    }
-    else if(argc == 11)
-    {
-        do_verification = std::stoi(argv[1]);
-        init_method = std::stoi(argv[2]);
-        time_kernel = std::stoi(argv[3]);
-
-        M = std::stoi(argv[4]);
-        N = std::stoi(argv[5]);
-        K = std::stoi(argv[6]);
-        O = std::stoi(argv[7]);
-        G0 = std::stoi(argv[8]);
-        G1 = std::stoi(argv[9]);
-
-        alpha = std::stof(argv[10]);
-    }
-    else
-    {
-        printf("arg1: verification (0=no, 1=yes)\n");
-        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-        printf("arg3: time kernel (0=no, 1=yes)\n");
-        printf("arg4 to 11: M, N, K, O, G0, G1\n");
-        printf("arg10: scale (alpha)\n");
-        exit(0);
-    }
-
-    std::vector c_gs_ms_os_lengths{G0, G1, M, O};
-    std::vector c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
-
-    const int DefaultStrideA = ck::is_same_v ? K : M;
-    const int DefaultStrideB0 = ck::is_same_v ? N : K;
-    const int DefaultStrideB1 = ck::is_same_v ? O : N;
-
-    StrideA = (StrideA < 0) ? DefaultStrideA : StrideA;
-    StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
-    StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
-
-    const int DefaultBatchStrideA = (ck::is_same_v ? K : M) * StrideA;
-    const int DefaultBatchStrideB0 = (ck::is_same_v ? N : K) * StrideB0;
-    const int DefaultBatchStrideB1 = (ck::is_same_v ? O : N) * StrideB1;
-
-    BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA;
-    BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
-    BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
-
-    const int BatchCount = G0 * G1;
-
-    auto f_host_tensor_descriptor = [](std::size_t batch_count,
-                                       std::size_t row,
-                                       std::size_t col,
-                                       std::size_t stride,
-                                       std::size_t batch_stride,
-                                       auto layout) {
-        if(std::is_same::value)
-        {
-            return HostTensorDescriptor(std::vector({batch_count, row, col}),
-                                        std::vector({batch_stride, stride, 1}));
-        }
-        else
-        {
-            return HostTensorDescriptor(std::vector({batch_count, row, col}),
-                                        std::vector({batch_stride, 1, stride}));
-        }
-    };
-
-    // C_m_o = A_m_k * B0_k_n * B1_n_o
-    Tensor a_g_m_k(
-        f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
-    Tensor b0_g_k_n(
-        f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
-    Tensor b1_g_n_o(
-        f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
-    Tensor c_gs_ms_os_host_result(
-        std::vector(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
-        std::vector(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
-    Tensor c_gs_ms_os_device_result(
-        std::vector(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
-        std::vector(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
-
-    std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
-    std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
-    std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
-    std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;
-
-    switch(init_method)
-    {
-    case 0: break;
-    case 1:
-        a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5});
-        b0_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5});
-        b1_g_n_o.GenerateTensorValue(GeneratorTensor_2{-5, 5});
-        break;
-    case 2:
-        a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0});
-        b0_g_k_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0});
-        b1_g_n_o.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5});
-        break;
-    case 3:
-        a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2});
-        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Diagonal{});
-        b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{});
-        break;
-    default:
-        a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1});
-        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
-        b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{});
-    }
-
-    DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
-    DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
-    DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
-    DeviceMem c_gs_ms_os_device_buf(sizeof(CDataType) *
-                                    c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
-
-    a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
-    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
-    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
-
-    auto a_element_op = AElementOp{};
-    auto b0_element_op = B0ElementOp{};
-    auto acc0_element_op = Acc0ElementOp{alpha};
-    auto b1_element_op = B1ElementOp{};
-    auto c_element_op = CElementOp{};
-
-    // do GEMM
-    auto gemm = DeviceGemmInstance{};
-    auto invoker = gemm.MakeInvoker();
-    auto argument =
-        gemm.MakeArgument(static_cast(a_g_m_k_device_buf.GetDeviceBuffer()),
-                          static_cast(b0_g_k_n_device_buf.GetDeviceBuffer()),
-                          static_cast(b1_g_n_o_device_buf.GetDeviceBuffer()),
-                          static_cast(c_gs_ms_os_device_buf.GetDeviceBuffer()),
-                          M,
-                          N,
-                          K,
-                          O,
-                          BatchCount,
-                          c_gs_ms_os_lengths,
-                          c_gs_ms_os_strides,
-                          StrideA,
-                          StrideB0,
-                          StrideB1,
-                          BatchStrideA,
-                          BatchStrideB0,
-                          BatchStrideB1,
-                          a_element_op,
-                          b0_element_op,
-                          acc0_element_op,
-                          b1_element_op,
-                          c_element_op);
-
-    if(!gemm.IsSupportedArgument(argument))
-    {
-        std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
-
-        return 0;
-    }
-
-    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
-
-    std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
-    std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
-                             sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
-                            BatchCount;
-
-    float tflops = static_cast(flop) / 1.E9 / ave_time;
-
-    float gb_per_sec = num_btype / 1.E6 / ave_time;
-
-    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
-              << gemm.GetTypeString() << std::endl;
-
-    if(do_verification)
-    {
-        c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
-
-        // Output of Gemm0 is input A of Gemm1
-        Tensor acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
-
-        Tensor a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
-
-        Tensor c_g_m_o_host_result(std::vector{BatchCount, M, O},
-                                   std::vector{M * O, O, 1});
-
-        auto ref_gemm0 = ReferenceGemm0Instance{};
-        auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
-        auto ref_gemm0_argument = ref_gemm0.MakeArgument(
-            a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
-
-        // gemm 0
-        ref_gemm0_invoker.Run(ref_gemm0_argument);
-
-        // mask out upper triangle
-        acc0_g_m_n.ForEach([&](auto& self, auto idx) {
-            if(idx[1] < idx[2])
-                self(idx) = -ck::NumericLimits::Infinity();
-        });
-
-        auto ref_softmax = ReferenceSoftmaxInstance{};
-        auto ref_softmax_invoker = ref_softmax.MakeInvoker();
-        auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});
-
-        // softmax
-        ref_softmax_invoker.Run(ref_softmax_argument);
-
-        auto ref_gemm1 = ReferenceGemm1Instance{};
-        auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
-        auto ref_gemm1_argument = ref_gemm1.MakeArgument(
-            a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);
-
-        // gemm1
-        ref_gemm1_invoker.Run(ref_gemm1_argument);
-
-        // permute
-        c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
-            const size_t& g0 = idx[0];
-            const size_t& g1 = idx[1];
-
-            const size_t g = g0 * G1 + g1;
-
-            self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
-        });
-
-        return ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData)
-                   ? 0
-                   : 1;
-    }
-
-    return 0;
-}
+int main(int argc, char* argv[]) { return run(argc, argv); }
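Reviewer note: the host-side check in the example above reduces to the sketch
below: mask Gemm0 scores above the diagonal with -inf, then softmax each row
over the N axis. This is a standalone restatement; the function name and the
flat row-major buffer are illustrative, not part of the patch.

// Minimal sketch of lower-triangle (causal) mask + row softmax, assuming a
// flat row-major score buffer s of shape M x N.
#include <cmath>
#include <limits>
#include <vector>

void mask_softmax_rows(std::vector<float>& s, int M, int N)
{
    for(int m = 0; m < M; ++m)
    {
        for(int n = 0; n < N; ++n)
            if(n > m) // upper triangle, same predicate as idx[1] < idx[2] above
                s[m * N + n] = -std::numeric_limits<float>::infinity();

        float mx = -std::numeric_limits<float>::infinity();
        for(int n = 0; n < N; ++n) mx = std::fmax(mx, s[m * N + n]);

        float sum = 0.f; // exp(-inf - mx) == 0, so masked entries drop out
        for(int n = 0; n < N; ++n) sum += (s[m * N + n] = std::exp(s[m * N + n] - mx));
        for(int n = 0; n < N; ++n) s[m * N + n] /= sum;
    }
}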
diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp
index 8b2daec654..3727be02d4 100644
--- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp
+++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp
@@ -33,9 +33,6 @@ using S = ck::Sequence;
 using F16 = ck::half_t;
 using F32 = float;

-using Row = ck::tensor_layout::gemm::RowMajor;
-using Col = ck::tensor_layout::gemm::ColumnMajor;
-
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

 using ADataType = F16;
@@ -44,13 +41,14 @@ using B1DataType = F16;
 using AccDataType = F32;
 using CShuffleDataType = F32;
 using CDataType = F16;
+using Acc0BiasDataType = ck::Tuple<>;
+using Acc1BiasDataType = ck::Tuple<>;

-using ALayout = Row;
-using B0Layout = Col;
-using B1Layout = Row;
-
-using CPermuteNumDims_G_M_O =
-    S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
+static constexpr ck::index_t NumDimG = 2;
+static constexpr ck::index_t NumDimM = 1;
+static constexpr ck::index_t NumDimN = 1;
+static constexpr ck::index_t NumDimK = 1;
+static constexpr ck::index_t NumDimO = 1;

 using AElementOp = PassThrough;
 using B0ElementOp = PassThrough;
@@ -59,17 +57,27 @@ using B1ElementOp = PassThrough;
 using CElementOp = PassThrough;

 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
+static constexpr auto MaskingSpec =
+    ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
+
+static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
+static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
+static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
+static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;

 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
-        ALayout,
-        B0Layout,
-        B1Layout,
-        CPermuteNumDims_G_M_O,
+        NumDimG,
+        NumDimM,
+        NumDimN,
+        NumDimK,
+        NumDimO,
         ADataType,
         B0DataType,
         B1DataType,
         CDataType,
+        Acc0BiasDataType,
+        Acc1BiasDataType,
         AccDataType,
         CShuffleDataType,
         AElementOp,
@@ -78,6 +86,10 @@ using DeviceGemmInstance =
         B1ElementOp,
         CElementOp,
         GemmSpec,
+        TensorSpecA,
+        TensorSpecB0,
+        TensorSpecB1,
+        TensorSpecC,
         1,
         256,
         128, // MPerBlock
@@ -118,7 +130,7 @@ using DeviceGemmInstance =
         2,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-        false>;         // MaskOutUpperTriangle
+        MaskingSpec>;   // MaskingSpecialization

 // Ref Gemm0: fp16 in, fp32 out
 using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm;
-int main(int argc, char* argv[])
-{
-    bool do_verification = true;
-    int init_method = 1;
-    bool time_kernel = false;
+#include "run_batched_gemm_scale_softmax_gemm_permute.inc"

-    // GEMM shape for A/B0/B1/C
-    // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
-    ck::index_t M = 120;
-    ck::index_t N = 1000;
-    ck::index_t K = 64;
-    ck::index_t O = 128;
-    ck::index_t StrideA = -1;
-    ck::index_t StrideB0 = -1;
-    ck::index_t StrideB1 = -1;
-    ck::index_t BatchStrideA = -1;
-    ck::index_t BatchStrideB0 = -1;
-    ck::index_t BatchStrideB1 = -1;
-    float alpha = 1;
-
-    // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
-    // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
-    // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
-    ck::index_t G0 = 7;
-    ck::index_t G1 = 13;
-
-    if(argc == 1)
-    {
-        // use default case
-    }
-    else if(argc == 4)
-    {
-        do_verification = std::stoi(argv[1]);
-        init_method = std::stoi(argv[2]);
-        time_kernel = std::stoi(argv[3]);
-    }
-    else if(argc == 11)
-    {
-        do_verification = std::stoi(argv[1]);
-        init_method = std::stoi(argv[2]);
-        time_kernel = std::stoi(argv[3]);
-
-        M = std::stoi(argv[4]);
-        N = std::stoi(argv[5]);
-        K = std::stoi(argv[6]);
-        O = std::stoi(argv[7]);
-        G0 = std::stoi(argv[8]);
-        G1 = std::stoi(argv[9]);
-
-        alpha = std::stof(argv[10]);
-    }
-    else
-    {
-        printf("arg1: verification (0=no, 1=yes)\n");
-        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-        printf("arg3: time kernel (0=no, 1=yes)\n");
-        printf("arg4 to 11: M, N, K, O, G0, G1\n");
-        printf("arg10: scale (alpha)\n");
-        exit(0);
-    }
-
-    std::vector c_gs_ms_os_lengths{G0, G1, M, O};
-    std::vector c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
-
-    const int DefaultStrideA = ck::is_same_v ? K : M;
-    const int DefaultStrideB0 = ck::is_same_v ? N : K;
-    const int DefaultStrideB1 = ck::is_same_v ? O : N;
-
-    StrideA = (StrideA < 0) ? DefaultStrideA : StrideA;
-    StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
-    StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
-
-    const int DefaultBatchStrideA = (ck::is_same_v ? K : M) * StrideA;
-    const int DefaultBatchStrideB0 = (ck::is_same_v ? N : K) * StrideB0;
-    const int DefaultBatchStrideB1 = (ck::is_same_v ? O : N) * StrideB1;
-
-    BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA;
-    BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
-    BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
-
-    const int BatchCount = G0 * G1;
-
-    auto f_host_tensor_descriptor = [](std::size_t batch_count,
-                                       std::size_t row,
-                                       std::size_t col,
-                                       std::size_t stride,
-                                       std::size_t batch_stride,
-                                       auto layout) {
-        if(std::is_same::value)
-        {
-            return HostTensorDescriptor(std::vector({batch_count, row, col}),
-                                        std::vector({batch_stride, stride, 1}));
-        }
-        else
-        {
-            return HostTensorDescriptor(std::vector({batch_count, row, col}),
-                                        std::vector({batch_stride, 1, stride}));
-        }
-    };
-
-    // C_m_o = A_m_k * B0_k_n * B1_n_o
-    Tensor a_g_m_k(
-        f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
-    Tensor b0_g_k_n(
-        f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
-    Tensor b1_g_n_o(
-        f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
-    Tensor c_gs_ms_os_host_result(
-        std::vector(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
-        std::vector(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
-    Tensor c_gs_ms_os_device_result(
-        std::vector(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
-        std::vector(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
-
-    std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
-    std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
-    std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
-    std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;
-
-    switch(init_method)
-    {
-    case 0: break;
-    case 1:
-        a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5});
-        b0_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5});
-        b1_g_n_o.GenerateTensorValue(GeneratorTensor_2{-5, 5});
-        break;
-    case 2:
-        a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0});
-        b0_g_k_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0});
-        b1_g_n_o.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5});
-        break;
-    case 3:
-        a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2});
-        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Diagonal{});
-        b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{});
-        break;
-    default:
-        a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1});
-        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
-        b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{});
-    }
-
-    DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
-    DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
-    DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
-    DeviceMem c_gs_ms_os_device_buf(sizeof(CDataType) *
-                                    c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
-
-    a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
-    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
-    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
-
-    auto a_element_op = AElementOp{};
-    auto b0_element_op = B0ElementOp{};
-    auto acc0_element_op = Acc0ElementOp{alpha};
-    auto b1_element_op = B1ElementOp{};
-    auto c_element_op = CElementOp{};
-
-    // do GEMM
-    auto gemm = DeviceGemmInstance{};
-    auto invoker = gemm.MakeInvoker();
-    auto argument =
-        gemm.MakeArgument(static_cast(a_g_m_k_device_buf.GetDeviceBuffer()),
-                          static_cast(b0_g_k_n_device_buf.GetDeviceBuffer()),
-                          static_cast(b1_g_n_o_device_buf.GetDeviceBuffer()),
-                          static_cast(c_gs_ms_os_device_buf.GetDeviceBuffer()),
-                          M,
-                          N,
-                          K,
-                          O,
-                          BatchCount,
-                          c_gs_ms_os_lengths,
-                          c_gs_ms_os_strides,
-                          StrideA,
-                          StrideB0,
-                          StrideB1,
-                          BatchStrideA,
-                          BatchStrideB0,
-                          BatchStrideB1,
-                          a_element_op,
-                          b0_element_op,
-                          acc0_element_op,
-                          b1_element_op,
-                          c_element_op);
-
-    if(!gemm.IsSupportedArgument(argument))
-    {
-        std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
-
-        return 0;
-    }
-
-    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
-
-    std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
-    std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
-                             sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
-                            BatchCount;
-
-    float tflops = static_cast(flop) / 1.E9 / ave_time;
-
-    float gb_per_sec = num_btype / 1.E6 / ave_time;
-
-    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
-              << gemm.GetTypeString() << std::endl;
-
-    if(do_verification)
-    {
-        c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
-
-        // Output of Gemm0 is input A of Gemm1
-        Tensor acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
-
-        Tensor a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
-
-        Tensor c_g_m_o_host_result(std::vector{BatchCount, M, O},
-                                   std::vector{M * O, O, 1});
-
-        auto ref_gemm0 = ReferenceGemm0Instance{};
-        auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
-        auto ref_gemm0_argument = ref_gemm0.MakeArgument(
-            a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
-
-        ref_gemm0_invoker.Run(ref_gemm0_argument);
-
-        auto ref_softmax = ReferenceSoftmaxInstance{};
-        auto ref_softmax_invoker = ref_softmax.MakeInvoker();
-        auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});
-
-        ref_softmax_invoker.Run(ref_softmax_argument);
-
-        auto ref_gemm1 = ReferenceGemm1Instance{};
-        auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
-        auto ref_gemm1_argument = ref_gemm1.MakeArgument(
-            a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);
-
-        ref_gemm1_invoker.Run(ref_gemm1_argument);
-
-        c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
-            const size_t& g0 = idx[0];
-            const size_t& g1 = idx[1];
-
-            const size_t g = g0 * G1 + g1;
-
-            self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
-        });
-
-        return ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData)
-                   ? 0
-                   : 1;
-    }
-
-    return 0;
-}
+int main(int argc, char* argv[]) { return run(argc, argv); }
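Reviewer note: the reason the new interface can drop the layout tags entirely is
that a permuted layout is just a different stride table over the same buffer.
A standalone check of that identity, using this example's default dims (G1 = 13,
M = 120, K = 64; the function name is illustrative):

// With dims ordered (G0, G1, M, K), strides {M*G1*K, K, G1*K, 1} address the
// same memory as a contiguous row-major [G0][M][G1][K] array, so the kernel
// consumes the permuted layout in place, with no data movement.
#include <cassert>
#include <cstdint>

static int64_t offset(int64_t g0, int64_t g1, int64_t m, int64_t k,
                      const int64_t s[4]) // strides for dims (G0, G1, M, K)
{
    return g0 * s[0] + g1 * s[1] + m * s[2] + k * s[3];
}

int main()
{
    const int64_t G1 = 13, M = 120, K = 64;
    const int64_t permuted[4] = {M * G1 * K, K, G1 * K, 1}; // "[G0, M, G1, K]"

    const int64_t g0 = 3, g1 = 5, m = 7, k = 9;
    // contiguous [G0][M][G1][K] offset of the same element:
    assert(offset(g0, g1, m, k, permuted) == ((g0 * M + m) * G1 + g1) * K + k);
    return 0;
}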
diff --git a/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
new file mode 100644
index 0000000000..e4a71b0431
--- /dev/null
+++ b/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
@@ -0,0 +1,159 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+/*
+Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
+                                                                  |-----------------|
+                                                                         Gemm0
+                                                          |-------------------------------------|
+                                                                          Gemm1
+*/
+
+#include
+#include
+#include
+#include
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
+
+template
+using S = ck::Sequence;
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+using ADataType = F16;
+using B0DataType = F16;
+using B1DataType = F16;
+using AccDataType = F32;
+using CShuffleDataType = F32;
+using CDataType = F16;
+using Acc0BiasDataType = ck::Tuple<>;
+using Acc1BiasDataType = ck::Tuple<>;
+
+static constexpr ck::index_t NumDimG = 2;
+static constexpr ck::index_t NumDimM = 1;
+static constexpr ck::index_t NumDimN = 1;
+static constexpr ck::index_t NumDimK = 1;
+static constexpr ck::index_t NumDimO = 1;
+
+using AElementOp = PassThrough;
+using B0ElementOp = PassThrough;
+using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
+using B1ElementOp = PassThrough;
+using CElementOp = PassThrough;
+
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
+static constexpr auto MaskingSpec =
+    ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle;
+
+static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
+static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
+static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
+static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
+
+using DeviceGemmInstance =
+    ck::tensor_operation::device::DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle<
+        NumDimG,
+        NumDimM,
+        NumDimN,
+        NumDimK,
+        NumDimO,
+        ADataType,
+        B0DataType,
+        B1DataType,
+        CDataType,
+        Acc0BiasDataType,
+        Acc1BiasDataType,
+        AccDataType,
+        CShuffleDataType,
+        AElementOp,
+        B0ElementOp,
+        Acc0ElementOp,
+        B1ElementOp,
+        CElementOp,
+        GemmSpec,
+        TensorSpecA,
+        TensorSpecB0,
+        TensorSpecB1,
+        TensorSpecC,
+        1,
+        256,
+        128,            // MPerBlock
+        128,            // NPerBlock
+        32,             // KPerBlock
+        64,             // Gemm1NPerBlock
+        32,             // Gemm1KPerBlock
+        8,              // AK1
+        8,              // BK1
+        2,              // B1K1
+        32,             // MPerXDL
+        32,             // NPerXDL
+        1,              // MXdlPerWave
+        4,              // NXdlPerWave
+        2,              // Gemm1NXdlPerWave
+        S<4, 64, 1>,    // ABlockTransfer
+        S<1, 0, 2>,
+        S<1, 0, 2>,
+        2,
+        8,
+        8,
+        true,
+        S<4, 64, 1>,    // BBlockTransfer
+        S<1, 0, 2>,
+        S<1, 0, 2>,
+        2,
+        8,
+        8,
+        true,
+        S<16, 16, 1>,   // B1BlockTransfer
+        S<0, 2, 1>,
+        S<0, 2, 1>,
+        1,
+        4,
+        2,
+        false,
+        1,              // CShuffleMXdlPerWavePerShuffle
+        2,              // CShuffleNXdlPerWavePerShuffle
+        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
+        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
+        MaskingSpec>;   // MaskingSpecialization
+
+// Ref Gemm0: fp16 in, fp32 out
+using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm;
+
+// Ref Softmax: fp32 in, fp16 out
+using ReferenceSoftmaxInstance =
+    ck::tensor_operation::host::ReferenceSoftmax;
+
+// Ref Gemm1: fp16 in, fp16 out
+using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm;
+
+#include "run_grouped_gemm_scale_softmax_gemm_permute.inc"
+
+int main(int argc, char* argv[]) { return run(argc, argv); }
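Reviewer note: in attention terms (an interpretation; the diff itself only says
"fused attention"), the grouped problem above plausibly maps as

    A  [G0, G1, M, K]  =  Q   [batch, heads, seq_q, head_dim_qk]
    B0 [G0, G1, N, K]  =  K   [batch, heads, seq_k, head_dim_qk]   (Gemm0 = Q*K^T)
    B1 [G0, G1, O, N]  =  V^T [batch, heads, head_dim_v, seq_k]    (Gemm1 = P*V)
    C  [G0, G1, M, O]  =  Out [batch, heads, seq_q, head_dim_v]

with input_permute/output_permute choosing between [G0, G1, M, K]-contiguous
(batch-heads-seq-dim) and [G0, M, G1, K]-contiguous (batch-seq-heads-dim) storage.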
"run_grouped_gemm_scale_softmax_gemm_permute.inc" - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - exit(0); - } - - float alpha = 1; // scaling after 1st gemm - - std::size_t group_count = 13; - - // Problem descs - std::vector problem_descs; - std::vector p_a; - std::vector p_b0; - std::vector p_b1; - std::vector p_c; - - for(std::size_t i = 0; i < group_count; i++) - { - int M = 128 * (rand() % 8 + 1); - int N = 128 * (rand() % 8 + 1); - int K = 40; - int O = 40 * (rand() % 2 + 1); - int Batch = rand() % 8 + 1; - - const int StrideA = ck::is_same_v ? K : M; - const int StrideB0 = ck::is_same_v ? N : K; - const int StrideB1 = ck::is_same_v ? O : N; - - const int BatchStrideA = (ck::is_same_v ? K : M) * StrideA; - const int BatchStrideB0 = (ck::is_same_v ? N : K) * StrideB0; - const int BatchStrideB1 = (ck::is_same_v ? O : N) * StrideB1; - - std::vector c_gs_ms_os_lengths{Batch, M, O}; - std::vector c_gs_ms_os_strides{O, Batch * O, 1}; - - problem_descs.push_back({M, - N, - K, - O, - Batch, - StrideA, - StrideB0, - StrideB1, - BatchStrideA, - BatchStrideB0, - BatchStrideB1, - c_gs_ms_os_lengths, - c_gs_ms_os_strides}); - } - - auto f_host_tensor_descriptor = [](std::size_t batch_count, - std::size_t row, - std::size_t col, - std::size_t stride, - std::size_t batch_stride, - auto layout) { - if(std::is_same::value) - { - return HostTensorDescriptor(std::vector({batch_count, row, col}), - std::vector({batch_stride, stride, 1})); - } - else - { - return HostTensorDescriptor(std::vector({batch_count, row, col}), - std::vector({batch_stride, 1, stride})); - } - }; - - std::vector> a_tensors; - std::vector> b0_tensors; - std::vector> b1_tensors; - std::vector> c_tensors; - - using DeviceMemPtr = std::unique_ptr; - - std::vector a_tensors_device; - std::vector b0_tensors_device; - std::vector b1_tensors_device; - std::vector c_tensors_device; - - std::size_t flop = 0, num_byte = 0; - - std::cout << "group count " << group_count << ". 
printing first 4 groups\n"; - for(std::size_t i = 0; i < group_count; i++) - { - const auto& M = problem_descs[i].M; - const auto& N = problem_descs[i].N; - const auto& K = problem_descs[i].K; - const auto& O = problem_descs[i].O; - const auto& Batch = problem_descs[i].Batch; - const auto& StrideA = problem_descs[i].StrideA; - const auto& StrideB0 = problem_descs[i].StrideB0; - const auto& StrideB1 = problem_descs[i].StrideB1; - const auto& BatchStrideA = problem_descs[i].BatchStrideA; - const auto& BatchStrideB0 = problem_descs[i].BatchStrideB0; - const auto& BatchStrideB1 = problem_descs[i].BatchStrideB1; - const auto& c_gs_ms_os_lengths = problem_descs[i].c_gs_ms_os_lengths; - const auto& c_gs_ms_os_strides = problem_descs[i].c_gs_ms_os_strides; - - // C_m_o = A_m_k * B0_k_n * B1_n_o - Tensor a_g_m_k( - f_host_tensor_descriptor(Batch, M, K, StrideA, BatchStrideA, ALayout{})); - Tensor b0_g_k_n( - f_host_tensor_descriptor(Batch, K, N, StrideB0, BatchStrideB0, B0Layout{})); - Tensor b1_g_n_o( - f_host_tensor_descriptor(Batch, N, O, StrideB1, BatchStrideB1, B1Layout{})); - Tensor c_gs_ms_os_device_result( - std::vector(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()), - std::vector(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end())); - - flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch; - num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + - sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) * - Batch; - - if(i < 4) - { - std::cout << "a_g_m_k[" << i << "]: " << a_g_m_k.mDesc << ", " - << "b0_g_k_n[" << i << "]: " << b0_g_k_n.mDesc << ", " - << "b1_g_n_o[" << i << "]: " << b1_g_n_o.mDesc << ", " - << "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << std::endl; - } - - switch(init_method) - { - case 0: break; - case 1: - a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b0_g_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b1_g_n_o.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - break; - case 2: - a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b0_g_k_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b1_g_n_o.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - case 3: - a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b0_g_k_n.GenerateTensorValue(GeneratorTensor_Diagonal{}); - b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); - break; - default: - a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); - b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); - b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); - } - - a_tensors.push_back(a_g_m_k); - b0_tensors.push_back(b0_g_k_n); - b1_tensors.push_back(b1_g_n_o); - c_tensors.push_back(c_gs_ms_os_device_result); - - a_tensors_device.emplace_back( - std::make_unique(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize())); - b0_tensors_device.emplace_back( - std::make_unique(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize())); - b1_tensors_device.emplace_back( - std::make_unique(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize())); - c_tensors_device.emplace_back(std::make_unique( - sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize())); - - a_tensors_device[i]->ToDevice(a_g_m_k.mData.data()); - b0_tensors_device[i]->ToDevice(b0_g_k_n.mData.data()); - b1_tensors_device[i]->ToDevice(b1_g_n_o.mData.data()); - - p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); - p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer()); - p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer()); - 
p_c.push_back(c_tensors_device[i]->GetDeviceBuffer()); - } - - auto a_element_op = AElementOp{}; - auto b0_element_op = B0ElementOp{}; - auto acc0_element_op = Acc0ElementOp{alpha}; - auto b1_element_op = B1ElementOp{}; - auto c_element_op = CElementOp{}; - - // do GEMM - auto gemm = DeviceGemmInstance{}; - auto invoker = gemm.MakeInvoker(); - auto argument = gemm.MakeArgument(p_a, - p_b0, - p_b1, - p_c, - problem_descs, - a_element_op, - b0_element_op, - acc0_element_op, - b1_element_op, - c_element_op); - - // specify workspace for problem_desc - DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument)); - - gemm.SetWorkSpacePointer(&argument, problem_desc_workspace.GetDeviceBuffer()); - - if(!gemm.IsSupportedArgument(argument)) - { - std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; - - return 0; - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_byte / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << gemm.GetTypeString() << std::endl; - - bool pass = true; - if(do_verification) - { - for(std::size_t i = 0; i < group_count; i++) - { - const auto& M = problem_descs[i].M; - const auto& N = problem_descs[i].N; - const auto& O = problem_descs[i].O; - const auto& Batch = problem_descs[i].Batch; - const auto& c_gs_ms_os_lengths = problem_descs[i].c_gs_ms_os_lengths; - const auto& c_gs_ms_os_strides = problem_descs[i].c_gs_ms_os_strides; - - const auto& a_g_m_k = a_tensors[i]; - const auto& b0_g_k_n = b0_tensors[i]; - const auto& b1_g_n_o = b1_tensors[i]; - auto& c_gs_ms_os_device_result = c_tensors[i]; - auto& c_gs_ms_os_device_buf = *c_tensors_device[i]; - - Tensor c_gs_ms_os_host_result( - std::vector(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()), - std::vector(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end())); - - c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); - - // Output of Gemm0 is input A of Gemm1 - Tensor acc0_m_n(f_host_tensor_descriptor(Batch, M, N, N, M * N, Row{})); - - Tensor a1_g_m_n(f_host_tensor_descriptor(Batch, M, N, N, M * N, Row{})); - - Tensor c_g_m_o_host_result(std::vector{Batch, M, O}, - std::vector{M * O, O, 1}); - - auto ref_gemm0 = ReferenceGemm0Instance{}; - auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); - auto ref_gemm0_argument = ref_gemm0.MakeArgument( - a_g_m_k, b0_g_k_n, acc0_m_n, a_element_op, b0_element_op, acc0_element_op); - - ref_gemm0_invoker.Run(ref_gemm0_argument); - - auto ref_softmax = ReferenceSoftmaxInstance{}; - auto ref_softmax_invoker = ref_softmax.MakeInvoker(); - auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_m_n, a1_g_m_n, 1, 0, {2}); - - ref_softmax_invoker.Run(ref_softmax_argument); - - auto ref_gemm1 = ReferenceGemm1Instance{}; - auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); - auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n, - b1_g_n_o, - c_g_m_o_host_result, - PassThrough{}, - b1_element_op, - c_element_op); - - ref_gemm1_invoker.Run(ref_gemm1_argument); - - // Note: in this example, we merely permute the dimensions by changing underlying - // strides so we simply access data as-is - c_gs_ms_os_host_result.ForEach( - [&](auto& self, auto idx) { self(idx) = c_g_m_o_host_result(idx); }); - - bool pass_ = - ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData); - pass &= pass_; - } - } - - return pass ? 
0 : 1; -} +int main(int argc, char* argv[]) { return run(argc, argv); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc new file mode 100644 index 0000000000..5a373d7a27 --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc @@ -0,0 +1,262 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +int run(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape for A/B0/B1/C + // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o + ck::index_t M = 120; + ck::index_t N = 1000; + ck::index_t K = 64; + ck::index_t O = 128; + + // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape + // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) + // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3]) + ck::index_t G0 = 7; + ck::index_t G1 = 13; + + float alpha = 1; + + bool input_permute = false; + bool output_permute = true; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + O = std::stoi(argv[7]); + G0 = std::stoi(argv[8]); + G1 = std::stoi(argv[9]); + + alpha = std::stof(argv[10]); + + input_permute = std::stoi(argv[11]); + output_permute = std::stoi(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 11: M, N, K, O, G0, G1\n"); + printf("arg10: scale (alpha)\n"); + printf("arg11 to 12: input / output permute\n"); + exit(0); + } + + std::vector a_gs_ms_ks_lengths{G0, G1, M, K}; + std::vector a_gs_ms_ks_strides = + input_permute + ? std::vector{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] + : std::vector{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::vector b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::vector b0_gs_ns_ks_strides = + input_permute + ? std::vector{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K] + : std::vector{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::vector b1_gs_os_ns_lengths{G0, G1, O, N}; + std::vector b1_gs_os_ns_strides = + input_permute + ? std::vector{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O] + : std::vector{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::vector c_gs_ms_os_lengths{G0, G1, M, O}; + std::vector c_gs_ms_os_strides = + output_permute + ? 
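Reviewer note: the grouped device-op calling convention, condensed from the
example above (no new API surface; DeviceMem and StreamConfig are the CK utility
types already used there). The grouped op stages per-problem descriptors in a
caller-provided workspace, which is the one extra step versus the batched op:

// auto gemm     = DeviceGemmInstance{};
// auto invoker  = gemm.MakeInvoker();
// auto argument = gemm.MakeArgument(p_a, p_b0, p_b1, p_c, problem_descs,
//                                   a_element_op, b0_element_op,
//                                   acc0_element_op, b1_element_op, c_element_op);
//
// // stage per-group problem descriptors on the device
// DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
// gemm.SetWorkSpacePointer(&argument, problem_desc_workspace.GetDeviceBuffer());
//
// if(gemm.IsSupportedArgument(argument))
//     invoker.Run(argument, StreamConfig{nullptr, time_kernel});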
diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc
new file mode 100644
index 0000000000..5a373d7a27
--- /dev/null
+++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc
@@ -0,0 +1,262 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+int run(int argc, char* argv[])
+{
+    bool do_verification = true;
+    int init_method = 1;
+    bool time_kernel = false;
+
+    // GEMM shape for A/B0/B1/C
+    // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
+    ck::index_t M = 120;
+    ck::index_t N = 1000;
+    ck::index_t K = 64;
+    ck::index_t O = 128;
+
+    // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
+    // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
+    // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
+    ck::index_t G0 = 7;
+    ck::index_t G1 = 13;
+
+    float alpha = 1;
+
+    bool input_permute = false;
+    bool output_permute = true;
+
+    if(argc == 1)
+    {
+        // use default case
+    }
+    else if(argc == 4)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method = std::stoi(argv[2]);
+        time_kernel = std::stoi(argv[3]);
+    }
+    else if(argc == 13)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method = std::stoi(argv[2]);
+        time_kernel = std::stoi(argv[3]);
+
+        M = std::stoi(argv[4]);
+        N = std::stoi(argv[5]);
+        K = std::stoi(argv[6]);
+        O = std::stoi(argv[7]);
+        G0 = std::stoi(argv[8]);
+        G1 = std::stoi(argv[9]);
+
+        alpha = std::stof(argv[10]);
+
+        input_permute = std::stoi(argv[11]);
+        output_permute = std::stoi(argv[12]);
+    }
+    else
+    {
+        printf("arg1: verification (0=no, 1=yes)\n");
+        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
+        printf("arg4 to 9: M, N, K, O, G0, G1\n");
+        printf("arg10: scale (alpha)\n");
+        printf("arg11 to 12: input / output permute\n");
+        exit(0);
+    }
+
+    std::vector a_gs_ms_ks_lengths{G0, G1, M, K};
+    std::vector a_gs_ms_ks_strides =
+        input_permute
+            ? std::vector{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
+            : std::vector{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
+
+    std::vector b0_gs_ns_ks_lengths{G0, G1, N, K};
+    std::vector b0_gs_ns_ks_strides =
+        input_permute
+            ? std::vector{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
+            : std::vector{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
+
+    std::vector b1_gs_os_ns_lengths{G0, G1, O, N};
+    std::vector b1_gs_os_ns_strides =
+        input_permute
+            ? std::vector{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
+            : std::vector{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
+
+    std::vector c_gs_ms_os_lengths{G0, G1, M, O};
+    std::vector c_gs_ms_os_strides =
+        output_permute
+            ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
+            : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
+
+    Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
+    Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
+    Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
+    Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
+    Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
+
+    std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
+    std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl;
+    std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl;
+    std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;
+
+    switch(init_method)
+    {
+    case 0: break;
+    case 1:
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2});
+        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2});
+        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2});
+        break;
+    case 2:
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0});
+        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0});
+        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5});
+        break;
+    case 3:
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2});
+        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{});
+        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{});
+        break;
+    default:
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
+        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{});
+        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{});
+    }
+
+    DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
+    DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize());
+    DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize());
+    DeviceMem c_device_buf(sizeof(CDataType) *
+                           c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
+
+    a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
+    b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
+    b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());
+
+    auto a_element_op = AElementOp{};
+    auto b0_element_op = B0ElementOp{};
+    auto acc0_element_op = Acc0ElementOp{alpha};
+    auto b1_element_op = B1ElementOp{};
+    auto c_element_op = CElementOp{};
+
+    // do GEMM
+    // TODO ANT: replace array with vector?
+    auto gemm = DeviceGemmInstance{};
+    auto invoker = gemm.MakeInvoker();
+    auto argument = gemm.MakeArgument(
+        static_cast(a_device_buf.GetDeviceBuffer()),
+        static_cast(b0_device_buf.GetDeviceBuffer()),
+        static_cast(b1_device_buf.GetDeviceBuffer()),
+        static_cast(c_device_buf.GetDeviceBuffer()),
+        {}, // std::array p_acc0_biases;
+        {}, // std::array p_acc1_biases;
+        a_gs_ms_ks_lengths,
+        a_gs_ms_ks_strides,
+        b0_gs_ns_ks_lengths,
+        b0_gs_ns_ks_strides,
+        b1_gs_os_ns_lengths,
+        b1_gs_os_ns_strides,
+        c_gs_ms_os_lengths,
+        c_gs_ms_os_strides,
+        {}, // std::array, 1>{acc0_biases_gs_ms_ns_lengths},
+        {}, // std::array, 1>{acc0_biases_gs_ms_ns_strides},
+        {}, // std::array, 1>{acc1_biases_gs_ms_os_lengths},
+        {}, // std::array, 1>{acc1_biases_gs_ms_os_strides},
+        a_element_op,
+        b0_element_op,
+        acc0_element_op,
+        b1_element_op,
+        c_element_op);
+
+    if(!gemm.IsSupportedArgument(argument))
+    {
+        std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
+
+        return 0;
+    }
+
+    ck::index_t BatchCount = G0 * G1;
+
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
+
+    std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
+    std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
+                             sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
+                            BatchCount;
+
+    float tflops = static_cast(flop) / 1.E9 / ave_time;
+
+    float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
+              << gemm.GetTypeString() << std::endl;
+
+    if(do_verification)
+    {
+        c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
+
+        Tensor a_g_m_k({BatchCount, M, K});
+        Tensor b0_g_k_n({BatchCount, K, N});
+        Tensor b1_g_n_o({BatchCount, N, O});
+        Tensor acc0_g_m_n({BatchCount, M, N});        // scratch object after gemm0
+        Tensor a1_g_m_n({BatchCount, M, N});          // scratch object after softmax
+        Tensor c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
+
+        // permute
+        a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
+            a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+        });
+        b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
+            b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+        });
+        b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
+            b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+        });
+
+        // gemm 0
+        auto ref_gemm0 = ReferenceGemm0Instance{};
+        auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
+        auto ref_gemm0_argument = ref_gemm0.MakeArgument(
+            a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
+
+        ref_gemm0_invoker.Run(ref_gemm0_argument);
+
+        // masking
+        const auto mask = DeviceGemmInstance::C0MatrixMask(N);
+        acc0_g_m_n.ForEach([&](auto& self, auto idx) {
+            if(mask.IsMaskedElement(idx[1], idx[2]))
+                self(idx) = -ck::NumericLimits::Infinity();
+        });
+
+        // softmax
+        auto ref_softmax = ReferenceSoftmaxInstance{};
+        auto ref_softmax_invoker = ref_softmax.MakeInvoker();
+        auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});
+
+        ref_softmax_invoker.Run(ref_softmax_argument);
+
+        // gemm1
+        auto ref_gemm1 = ReferenceGemm1Instance{};
+        auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
+        auto ref_gemm1_argument = ref_gemm1.MakeArgument(
+            a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);
+
+        ref_gemm1_invoker.Run(ref_gemm1_argument);
+
+        // permute
+        c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
+            const size_t& g0 = idx[0];
+            const size_t& g1 = idx[1];
+
+            const size_t g = g0 * G1 + g1;
+
+            self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
+        });
+
+        return ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData)
+                   ? 0
+                   : 1;
+    }
+
+    return 0;
+}
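Reviewer note, a worked check of the perf accounting in the file above: with its
defaults (M=120, N=1000, K=64, O=128, G0=7, G1=13, so BatchCount = 91),

    flop = (2*M*N*K + 2*M*N*O) * BatchCount
         = (15,360,000 + 30,720,000) * 91 ≈ 4.19e9,

so an ave_time of 1 ms would be reported as roughly 4.19 TFlops.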
std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] + : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + problem_descs.push_back({a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ns_ks_lengths, + b0_gs_ns_ks_strides, + b1_gs_os_ns_lengths, + b1_gs_os_ns_strides, + c_gs_ms_os_lengths, + c_gs_ms_os_strides, + {}, // acc0_biases_gs_ms_ns_lengths + {}, // acc0_biases_gs_ms_ns_strides + {}, // acc1_biases_gs_ms_os_lengths + {}}); // acc1_biases_gs_ms_os_strides + + // C_m_o = A_m_k * B0_k_n * B1_n_o + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + int Batch = G0 * G1; + flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch; + num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + + sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) * + Batch; + + if(i < 4) + { + std::cout << "a_gs_ms_ks[" << i << "]: " << a_gs_ms_ks.mDesc << ", " + << "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", " + << "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", " + << "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << std::endl; + } + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{1}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + a_tensors.push_back(a_gs_ms_ks); + b0_tensors.push_back(b0_gs_ns_ks); + b1_tensors.push_back(b1_gs_os_ns); + c_tensors.push_back(c_gs_ms_os_device_result); + + a_tensors_device.emplace_back(std::make_unique( + sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize())); + b0_tensors_device.emplace_back(std::make_unique( + sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize())); + b1_tensors_device.emplace_back(std::make_unique( + sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize())); + c_tensors_device.emplace_back(std::make_unique( + sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize())); + + a_tensors_device[i]->ToDevice(a_gs_ms_ks.mData.data()); + b0_tensors_device[i]->ToDevice(b0_gs_ns_ks.mData.data()); + b1_tensors_device[i]->ToDevice(b1_gs_os_ns.mData.data()); + + p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); + p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer()); + p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer()); + p_c.push_back(c_tensors_device[i]->GetDeviceBuffer()); + } + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = 
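// How the example encodes a permuted physical layout purely through strides:
// the logical index order stays (G0, G1, M, K), but with input_permute the
// data is laid out as [G0, M, G1, K] in memory. A self-contained offset check
// with illustrative sizes (not library code) confirming the stride choice above:
#include <cassert>
#include <cstddef>

int main()
{
    const std::size_t G0 = 2, G1 = 3, M = 4, K = 5;
    const std::size_t s_g0 = M * G1 * K, s_g1 = K, s_m = G1 * K; // strides for (g0, g1, m, k)
    for(std::size_t g0 = 0; g0 < G0; ++g0)
        for(std::size_t g1 = 0; g1 < G1; ++g1)
            for(std::size_t m = 0; m < M; ++m)
                for(std::size_t k = 0; k < K; ++k)
                {
                    const std::size_t via_strides = g0 * s_g0 + g1 * s_g1 + m * s_m + k;
                    // Row-major walk of the physical [G0, M, G1, K] buffer:
                    const std::size_t physical = ((g0 * M + m) * G1 + g1) * K + k;
                    assert(via_strides == physical); // same element either way
                }
    return 0;
}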
gemm.MakeArgument(p_a, + p_b0, + p_b1, + p_c, + {}, // p_acc0_biases + {}, // p_acc1_biases + problem_descs, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op); + + // specify workspace for problem_desc + DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument)); + + gemm.SetWorkSpacePointer(&argument, problem_desc_workspace.GetDeviceBuffer()); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + bool pass = true; + if(do_verification) + { + for(std::size_t i = 0; i < group_count; i++) + { + const int& G0 = g0_g1_m_n_k_o[i][0]; + const int& G1 = g0_g1_m_n_k_o[i][1]; + const int& M = g0_g1_m_n_k_o[i][2]; + const int& N = g0_g1_m_n_k_o[i][3]; + const int& K = g0_g1_m_n_k_o[i][4]; + const int& O = g0_g1_m_n_k_o[i][5]; + + const auto& c_gs_ms_os_lengths = problem_descs[i].c_gs_ms_os_lengths; + const auto& c_gs_ms_os_strides = problem_descs[i].c_gs_ms_os_strides; + + const auto& a_gs_ms_ks = a_tensors[i]; + const auto& b0_gs_ns_ks = b0_tensors[i]; + const auto& b1_gs_os_ns = b1_tensors[i]; + auto& c_gs_ms_os_device_result = c_tensors[i]; + auto& c_gs_ms_os_device_buf = *c_tensors_device[i]; + + c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + Tensor a_g_m_k({G0 * G1, M, K}); + Tensor b0_g_k_n({G0 * G1, K, N}); + Tensor b1_g_n_o({G0 * G1, N, O}); + Tensor acc0_g_m_n({G0 * G1, M, N}); // scratch object after gemm0 + Tensor a1_g_m_n({G0 * G1, M, N}); // scratch object after softmax + Tensor c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1 + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + // permute + a_gs_ms_ks.ForEach([&](auto& self, auto idx) { + a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); + }); + b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { + b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); + }); + b1_gs_os_ns.ForEach([&](auto& self, auto idx) { + b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); + }); + + // gemm 0 + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + // masking + const auto mask = DeviceGemmInstance::C0MatrixMask(N); + acc0_g_m_n.ForEach([&](auto& self, auto idx) { + if(mask.IsMaskedElement(idx[1], idx[2])) + self(idx) = -ck::NumericLimits::Infinity(); + }); + + // softmax + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + // gemm 1 + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n, + b1_g_n_o, + c_g_m_o_host_result, + PassThrough{}, + b1_element_op, + c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // permute + c_gs_ms_os_host_result.ForEach([&](auto& self, 
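// The performance accounting used above, restated as a standalone sketch:
// gemm0 costs 2*M*N*K flops and gemm1 costs 2*M*N*O per batch (softmax work is
// not counted). With ave_time in milliseconds, flop / 1e9 / ms is TFLOP/s and
// bytes / 1e6 / ms is GB/s. Names here are illustrative, not library API:
#include <cstddef>

std::size_t attention_flops(std::size_t M, std::size_t N, std::size_t K,
                            std::size_t O, std::size_t batch)
{
    return (2 * M * N * K + 2 * M * N * O) * batch;
}

double tflops(std::size_t flop, float ave_time_ms)
{
    return static_cast<double>(flop) / 1.0e9 / ave_time_ms; // 1e9 flop/ms == 1 TFLOP/s
}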
auto idx) { + const size_t& g0 = idx[0]; + const size_t& g1 = idx[1]; + + const size_t g = g0 * G1 + g1; + + self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]); + }); + + bool pass_ = + ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData); + pass &= pass_; + } + } + + return pass ? 0 : 1; +} diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index ad85e23382..92018aacba 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -159,6 +159,11 @@ // tuning parameter #define CK_WORKAROUND_SWDEV_325164 0 +// workaround: disable broken fused attention kernel instance that does not pass validation +// issue found on mi100/#10738 combo when irregular KPerBlock attention kernel has acc0 scaling +// enabled +#define CK_WORKAROUND_DISABLE_BROKEN_ATTN_KERNEL_INSTANCE 1 + namespace ck { enum struct InMemoryDataOperationEnum diff --git a/include/ck/tensor_description/tensor_space_filling_curve.hpp b/include/ck/tensor_description/tensor_space_filling_curve.hpp index e9a990d857..17c9100b9f 100644 --- a/include/ck/tensor_description/tensor_space_filling_curve.hpp +++ b/include/ck/tensor_description/tensor_space_filling_curve.hpp @@ -14,7 +14,8 @@ namespace ck { template // # of scalars per access in each dimension + typename ScalarsPerAccess, + bool SnakeCurved = true> // # of scalars per access in each dimension struct SpaceFillingCurve { static constexpr index_t nDim = TensorLengths::Size(); @@ -136,9 +137,10 @@ struct SpaceFillingCurve Index ordered_idx; static_for<0, nDim, 1>{}([&](auto idim) { - ordered_idx(idim) = forward_sweep[idim] ? ordered_access_idx[idim] - : ordered_access_lengths[idim] - 1 - - ordered_access_idx[idim]; + ordered_idx(idim) = + !SnakeCurved || forward_sweep[idim] + ? ordered_access_idx[idim] + : ordered_access_lengths[idim] - 1 - ordered_access_idx[idim]; }); return container_reorder_given_old2new(ordered_idx, dim_access_order) * diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 025be9e961..ac484a0866 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -151,6 +151,27 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 return make_tuple(c_thread_m, c_thread_n); } + template + __device__ static auto + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i); + + return make_tuple(Number{}, + Number{}, + waveId_m, + waveId_n, + blk_idx[I0], + blk_idx[I1], + blk_idx[I2], + blk_idx[I3]); + } + __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1() { static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && @@ -724,6 +745,21 @@ struct BlockwiseGemmXdlops_v2 return make_tuple(c_thread_m, c_thread_n); } + template + __device__ static auto + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i); + + return make_tuple( + m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]); + } + using Tuple4 = decltype(CalculateAThreadOriginDataIndex()); __host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = 
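// The new SnakeCurved flag on SpaceFillingCurve selects between boustrophedon
// ("snake") traversal, which reverses direction on alternating sweeps so
// consecutive accesses stay adjacent, and a plain forward-only sweep. A 2D
// sketch of the two visit orders (the real curve is N-dimensional and
// vector-width aware):
#include <cstdio>

int main()
{
    const int rows = 3, cols = 4;
    for(int snake = 0; snake <= 1; ++snake)
    {
        std::printf(snake ? "snake  :" : "forward:");
        for(int r = 0; r < rows; ++r)
            for(int c = 0; c < cols; ++c)
            {
                // Odd sweeps run backwards when snaking, mirroring the
                // `!SnakeCurved || forward_sweep[idim]` selection above.
                const int cc = (snake && (r % 2 == 1)) ? cols - 1 - c : c;
                std::printf(" (%d,%d)", r, cc);
            }
        std::printf("\n");
    }
    return 0;
}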
CalculateAThreadOriginDataIndex(), diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp index 7d04f85749..c1f85e575c 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp @@ -24,7 +24,8 @@ template + typename CElementwiseOperation, + bool MaskOutUpperTriangle> // TODO: enum for mask type struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator { virtual std::unique_ptr diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp index 3d29ae4520..ff55519980 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp @@ -7,49 +7,60 @@ #include #include "device_base.hpp" +#include "ck/tensor_operation/gpu/device/masking_specialization.hpp" namespace ck { namespace tensor_operation { namespace device { -template +template + typename CElementwiseOperation, + MaskingSpecialization MaskingSpec> struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator { - virtual std::unique_ptr - MakeArgumentPointer(const void* p_a, - const void* p_b0, - const void* p_b1, - void* p_c, - ck::index_t M, - ck::index_t N, - ck::index_t K, - ck::index_t O, - ck::index_t Batch, - std::vector c_gs_ms_os_lengths, - std::vector c_gs_ms_os_strides, - ck::index_t StrideA, - ck::index_t StrideB0, - ck::index_t StrideB1, - ck::index_t BatchStrideA, - ck::index_t BatchStrideB0, - ck::index_t BatchStrideB1, - AElementwiseOperation a_element_op, - B0ElementwiseOperation b0_element_op, - Acc0ElementwiseOperation acc0_element_op, - B1ElementwiseOperation b1_element_op, - CElementwiseOperation c_element_op) = 0; + static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); + static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); + + virtual std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b0, + const void* p_b1, + void* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::vector& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths + const std::vector& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides + const std::vector& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths + const std::vector& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, + const std::array, NumAcc1Bias> + acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths + const std::array, NumAcc1Bias> + acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + Acc0ElementwiseOperation acc0_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp index 
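// Shape convention behind the widened MakeArgumentPointer interface: each
// lengths/strides vector concatenates the dimension groups, e.g. A is
// [G0..G_{NumDimG-1}, M0.., K0..] and therefore carries
// NumDimG + NumDimM + NumDimK entries. A host-side sketch with NumDimG = 2 and
// NumDimM = NumDimK = 1, matching the examples (sizes are illustrative):
#include <vector>

int main()
{
    using index_t = int; // stand-in for ck::index_t
    const index_t G0 = 3, G1 = 5, M = 256, K = 40;
    std::vector<index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
    // Non-permuted, packed [G0, G1, M, K] storage:
    std::vector<index_t> a_gs_ms_ks_strides{G1 * M * K, M * K, K, 1};
    // The device op later reads the last ("z") entry of each group when it
    // validates vector load width; see the raw-lengths checks further down.
    return a_gs_ms_ks_lengths.size() == 4 && a_gs_ms_ks_strides.size() == 4 ? 0 : 1;
}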
611e8bb1d4..b066a44585 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp @@ -7,46 +7,50 @@ #include #include "device_base.hpp" +#include "ck/tensor_operation/gpu/device/masking_specialization.hpp" namespace ck { namespace tensor_operation { namespace device { -template +template + typename CElementwiseOperation, + MaskingSpecialization MaskingSpec> struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator { struct ProblemDesc { - // Overall problem shape - index_t M; - index_t N; - index_t K; - index_t O; - index_t Batch; + std::vector a_gs_ms_ks_lengths; + std::vector a_gs_ms_ks_strides; - // Stride for A/B0/B1; layout determined by template args - index_t StrideA; - index_t StrideB0; - index_t StrideB1; - index_t BatchStrideA; - index_t BatchStrideB0; - index_t BatchStrideB1; + std::vector b0_gs_ns_ks_lengths; + std::vector b0_gs_ns_ks_strides; + + std::vector b1_gs_os_ns_lengths; + std::vector b1_gs_os_ns_strides; - // Lengths and strides for output C std::vector c_gs_ms_os_lengths; std::vector c_gs_ms_os_strides; + + std::vector> acc0_biases_gs_ms_ns_lengths; + std::vector> acc0_biases_gs_ms_ns_strides; + + std::vector> acc1_biases_gs_ms_os_lengths; + std::vector> acc1_biases_gs_ms_os_strides; }; virtual std::unique_ptr @@ -54,6 +58,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator std::vector p_b0_vec, std::vector p_b1_vec, std::vector p_c_vec, + std::vector> p_acc0_biases_vec, + std::vector> p_acc1_biases_vec, std::vector problem_desc_vec, AElementwiseOperation a_element_op, B0ElementwiseOperation b0_element_op, diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index 9719735612..946a757cee 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -14,6 +14,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp" +#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -54,9 +55,8 @@ __global__ void index_t right = group_count; index_t group_id = index_t((left + right) / 2); - while((!(block_id >= arg_ptr[group_id].block_start_ && - block_id < arg_ptr[group_id].block_end_)) && - left <= right) + while( + (!(block_id >= arg_ptr[group_id].block_start_ && block_id < arg_ptr[group_id].block_end_))) { if(block_id < arg_ptr[group_id].block_start_) { @@ -114,14 +114,17 @@ __global__ void // Computes C = A * B0 * B1 // ^^^^^^ (Acc0) // ^^^^^^^^^^^ (Acc1) -template +template struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle - : public DeviceGroupedGemmSoftmaxGemmPermute + CElementwiseOperation, + MaskingSpec> { - using DeviceOp = DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle; - using ProblemDesc = - typename DeviceGroupedGemmSoftmaxGemmPermute::ProblemDesc; + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, + "Number of dimension must be greater than 0"); + + static constexpr index_t 
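// How the grouped kernel locates its group: each group owns a contiguous
// block-id range [block_start_, block_end_), and the kernel binary-searches
// those sorted ranges. A host-side sketch of the same search; it assumes, as
// the launcher guarantees, that block_id falls inside some group's range:
#include <vector>

struct GroupRange
{
    int block_start_;
    int block_end_;
};

int find_group(const std::vector<GroupRange>& groups, int block_id)
{
    int left     = 0;
    int right    = static_cast<int>(groups.size());
    int group_id = (left + right) / 2;
    while(!(block_id >= groups[group_id].block_start_ &&
            block_id < groups[group_id].block_end_))
    {
        if(block_id < groups[group_id].block_start_)
            right = group_id; // target range lies to the left
        else
            left = group_id; // target range lies to the right
        group_id = (left + right) / 2;
    }
    return group_id;
}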
NumAcc0Bias = Acc0BiasDataType::Size(); + static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); + + // TODO ANT: implement bias combination + static_assert(NumAcc0Bias == 0 && NumAcc1Bias == 0, "Bias addition is unimplemented"); + +#if 0 + // TODO ANT: use alias + static constexpr index_t NumDimGemm0M = NumDimM; + static constexpr index_t NumDimGemm0N = NumDimN; + static constexpr index_t NumDimGemm0K = NumDimK; + static constexpr index_t NumDimGemm1M = NumDimM; + static constexpr index_t NumDimGemm1N = NumDimO; + static constexpr index_t NumDimGemm1K = NumDimN; +#endif + + using DeviceOp = DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle; + using ProblemDesc = typename DeviceGroupedGemmSoftmaxGemmPermute::ProblemDesc; static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; - static constexpr auto matrix_padder = - GemmGemmPadder{ - MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock}; + using Transform = TransformBatchedContractionContractionToBatchedGemmGemm< + Sequence, + Sequence, + GemmSpec, + ASpec, + BSpec, + B1Spec, + CSpec>; - static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) { - const auto a_grid_desc_mraw_kraw = [&]() { - if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(StrideA, I1)); - } - else if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(I1, StrideA)); - } - }(); - - const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); - - const auto M = a_grid_desc_m_k.GetLength(I0); - const auto K = a_grid_desc_m_k.GetLength(I1); - - const auto AK0 = K / AK1; - - return transform_tensor_descriptor(a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), + Number{}); } - static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector& b_gs_ns_ks_lengths_vec, + const std::vector& b_gs_ns_ks_strides_vec) { - const auto b_grid_desc_nraw_kraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(I1, StrideB)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(StrideB, I1)); - } - }(); - - const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); - - const auto N = b_grid_desc_n_k.GetLength(I0); - const auto K = b_grid_desc_n_k.GetLength(I1); - - const auto BK0 = K / BK1; - - return transform_tensor_descriptor(b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec), + Number{}); } - // Args: Gemm1KRaw, Gemm1NRaw, StrideB1 - static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + static
auto + MakeB1GridDescriptor_BK0_N_BK1(const std::vector& b1_gs_gemm1ns_gemm1ks_lengths_vec, + const std::vector& b1_gs_gemm1ns_gemm1ks_strides_vec) { - const auto b1_grid_desc_nraw_kraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(I1, StrideB)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(StrideB, I1)); - } - }(); - - const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw); - - const auto N = b1_grid_desc_n_k.GetLength(I0); - const auto K = b1_grid_desc_n_k.GetLength(I1); - - const auto B1K0 = K / B1K1; - - return transform_tensor_descriptor( - b1_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec, + b1_gs_gemm1ns_gemm1ks_strides_vec), + Number{}); } - // assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] - static auto MakeCGridDescriptor_M_N(const std::vector& c_gs_ms_ns_lengths_vec, - const std::vector& c_gs_ms_ns_strides_vec) + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); + using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {})); + using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); + using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); + using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); + using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); + using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); + + constexpr static auto make_MaskOutPredicate() { - constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0); - constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1); - constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N - - assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && - c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN); - - const auto to_tuple = [&](auto& vec, auto start, auto end) { - return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); - }; - - const auto c_ms_ns_lengths = to_tuple( - c_gs_ms_ns_lengths_vec, Number{}, Number{}); - const auto c_ms_ns_strides = to_tuple( - c_gs_ms_ns_strides_vec, Number{}, Number{}); - - // dimension Ids for M0, M1, ... - constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; - - // dimension Ids for N0, N1, ... - constexpr auto nDimIds = - typename arithmetic_sequence_gen::type{}; - - // lengths for M0, M1, ... - const auto mLengths = get_container_subset(c_ms_ns_lengths, mDimIds); - - // lengths for K0, K1, ... - const auto nLengths = get_container_subset(c_ms_ns_lengths, nDimIds); - - // naive tensor C[M0, M1, M2, ..., N0, N1, N2...] - const auto c_grid_desc_ms_ns = - make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides); - - // transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...] 
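// The descriptor type aliases above are obtained with decltype applied to the
// factory functions called on empty vectors, so the long transform types are
// never spelled out twice. The same idiom in miniature (illustrative only):
#include <utility>
#include <vector>

static auto MakeDesc(const std::vector<int>& lengths, const std::vector<int>& strides)
{
    return std::make_pair(lengths, strides); // stand-in for a tensor descriptor
}

using Desc = decltype(MakeDesc({}, {})); // type only; nothing is evaluated here

int main()
{
    Desc d = MakeDesc({2, 3}, {3, 1});
    return d.first.size() == 2 ? 0 : 1;
}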
- const auto c_grid_desc_mraw_nraw = transform_tensor_descriptor( - c_grid_desc_ms_ns, - make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), - make_tuple(mDimIds, nDimIds), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); - } - - // assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] - static auto MakeCGridDescriptor_G_M_N(const std::vector& c_gs_ms_ns_lengths_vec, - const std::vector& c_gs_ms_ns_strides_vec) - { - constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0); - constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1); - constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N - - assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && - c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN); - - const auto to_tuple = [&](auto& vec, auto start, auto end) { - return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); - }; - - const auto c_gs_ms_ns_lengths = - to_tuple(c_gs_ms_ns_lengths_vec, Number<0>{}, Number{}); - const auto c_gs_ms_ns_strides = - to_tuple(c_gs_ms_ns_strides_vec, Number<0>{}, Number{}); - - // dimension Ids for G0, G1, ... - constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{}; - - // dimension Ids for M0, M1, ... - constexpr auto mDimIds = - typename arithmetic_sequence_gen::type{}; - - // dimension Ids for N0, N1, ... - constexpr auto nDimIds = typename arithmetic_sequence_gen::type{}; - - // lengths for G0, G1, ... - const auto gLengths = get_container_subset(c_gs_ms_ns_lengths, gDimIds); - - // lengths for M0, M1, ... - const auto mLengths = get_container_subset(c_gs_ms_ns_lengths, mDimIds); - - // lengths for K0, K1, ... - const auto nLengths = get_container_subset(c_gs_ms_ns_lengths, nDimIds); - - // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] - const auto c_grid_desc_gs_ms_ns = - make_naive_tensor_descriptor(c_gs_ms_ns_lengths, c_gs_ms_ns_strides); - - // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * - // N2 * ...] - const auto c_grid_desc_g_mraw_nraw = - transform_tensor_descriptor(c_grid_desc_gs_ms_ns, - make_tuple(make_merge_transform(gLengths), - make_merge_transform(mLengths), - make_merge_transform(nLengths)), - make_tuple(gDimIds, mDimIds, nDimIds), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - // this desc is only for calculating batch offset so no padding needed - return c_grid_desc_g_mraw_nraw; - } - - using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); - using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); - using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1)); - using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N({}, {})); - using CGridDesc_G_M_N = decltype(MakeCGridDescriptor_G_M_N({}, {})); - - // to track the points which need to be set to -inf on C0 - // Note: no need to reset M padding value, because they will not be stored out. 
- struct C0MatrixMask - { - C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {} - - __host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; } - - __host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const + if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled) { - return n >= NRaw_; + return MaskDisabledPredicate{}; } - - __host__ __device__ bool IsMaskedElement(index_t m, index_t n) const + else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) { - return IsUpperTriangle(m, n) || IsNOutOfBound(n); + return MaskOutUpperTrianglePredicate{}; } - - private: - // index_t MRaw_; - index_t NRaw_; - }; + } + using C0MatrixMask = C0MatrixMask_impl; struct ComputeBasePtrOfStridedBatch { - ComputeBasePtrOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB, - index_t BatchStrideB1, - CGridDesc_G_M_N c_grid_desc_g_m_n) - : BatchStrideA_(BatchStrideA), - BatchStrideB_(BatchStrideB), - BatchStrideB1_(BatchStrideB1), + ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, + const BGridDesc_G_N_K& b_grid_desc_g_n_k, + const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, + const CGridDesc_G_M_N& c_grid_desc_g_m_n) + : a_grid_desc_g_m_k_(a_grid_desc_g_m_k), + b_grid_desc_g_n_k_(b_grid_desc_g_n_k), + b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), c_grid_desc_g_m_n_(c_grid_desc_g_m_n) { } __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const { - return g_idx * static_cast(BatchStrideA_); + return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); } __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const { - return g_idx * static_cast(BatchStrideB_); + return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); } __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const { - return g_idx * static_cast(BatchStrideB1_); + return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); } __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const @@ -469,9 +331,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle } private: - index_t BatchStrideA_; - index_t BatchStrideB_; - index_t BatchStrideB1_; + AGridDesc_G_M_K a_grid_desc_g_m_k_; + BGridDesc_G_N_K b_grid_desc_g_n_k_; + B1GridDesc_G_N_K b1_grid_desc_g_n_k_; CGridDesc_G_M_N c_grid_desc_g_m_n_; }; @@ -535,8 +397,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, - matrix_padder.PadN, - MaskOutUpperTriangle>; + Transform::matrix_padder.PadN, + MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>; using Block2CTileMap = OffsettedBlockToCTileMap; @@ -570,16 +432,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle struct GroupDeviceArg { - // problem definiton - index_t M; - index_t N; - index_t K; - index_t O; + // lengths for the last dimensions of overall problem for sanity check of vector load/store + std::vector raw_lengths_mz_nz_kz_gemm1nz_; - // Strides for the last dimensions of C for sanity check of vector load/store - index_t c_extent_lowest_; - index_t c_stride_lowest_; + // strides for the last dimensions of each tensor for sanity check of vector load/store + std::vector a_mz_kz_strides_; + std::vector b_nz_kz_strides_; + std::vector b1_nz_kz_strides_; + std::vector c_mz_gemm1nz_strides_; + // for gridwise gemm check CGridDesc_M_N c_grid_desc_m_n_; }; @@ -591,6 +453,8 @@ struct 
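// The masking refactor replaces the hard-coded upper-triangle test with a
// predicate chosen at compile time from MaskingSpec. A reduced sketch of the
// same pattern (simplified; like the real C0MatrixMask_impl, it also masks the
// out-of-bounds region n >= NRaw introduced by padding):
#include <cstdio>

struct MaskDisabledPredicateSketch
{
    bool operator()(int /*m*/, int /*n*/) const { return false; }
};

struct MaskOutUpperTrianglePredicateSketch
{
    bool operator()(int m, int n) const { return n > m; } // future positions
};

template <typename MaskOutPredicate>
struct C0MatrixMaskSketch
{
    explicit C0MatrixMaskSketch(int n_raw) : n_raw_(n_raw) {}
    bool IsMaskedElement(int m, int n) const
    {
        return MaskOutPredicate{}(m, n) || n >= n_raw_;
    }
    int n_raw_;
};

int main()
{
    C0MatrixMaskSketch<MaskOutUpperTrianglePredicateSketch> mask(4);
    std::printf("%d %d\n", mask.IsMaskedElement(2, 3), mask.IsMaskedElement(3, 2)); // 1 0
    return 0;
}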
DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle std::vector p_b_vec, std::vector p_b1_vec, std::vector p_c_vec, + std::vector> p_acc0_biases_vec, + std::vector> p_acc1_biases_vec, std::vector problem_desc_vec, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, @@ -603,6 +467,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle b1_element_op_{b1_element_op}, c_element_op_{c_element_op} { + // TODO ANT: implement bias addition group_count_ = problem_desc_vec.size(); if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() && @@ -611,6 +476,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size"); } + if(!(p_acc0_biases_vec.size() == p_acc1_biases_vec.size())) + { + throw std::runtime_error("wrong! acc0_bias_vec.size != acc1_bias_vec.size"); + } + grid_size_ = 0; for(std::size_t i = 0; i < group_count_; i++) @@ -620,14 +490,25 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle const auto p_b1_grid = static_cast(p_b1_vec[i]); const auto p_c_grid = static_cast(p_c_vec[i]); - const auto a_grid_desc_ak0_m_ak1 = DeviceOp::MakeAGridDescriptor_AK0_M_AK1( - problem_desc_vec[i].M, problem_desc_vec[i].K, problem_desc_vec[i].StrideA); - const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1( - problem_desc_vec[i].K, problem_desc_vec[i].N, problem_desc_vec[i].StrideB0); - const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeB1GridDescriptor_BK0_N_BK1( - problem_desc_vec[i].N, problem_desc_vec[i].O, problem_desc_vec[i].StrideB1); - const auto c_grid_desc_m_n = DeviceOp::MakeCGridDescriptor_M_N( - problem_desc_vec[i].c_gs_ms_os_lengths, problem_desc_vec[i].c_gs_ms_os_strides); + const auto& problem_desc = problem_desc_vec[i]; + + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides); + const auto b1_grid_desc_bk0_n_bk1 = MakeB1GridDescriptor_BK0_N_BK1( + problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N( + problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides); + + const auto a_grid_desc_g_m_k = Transform::MakeAGridDescriptor_G_M_K( + problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides); + const auto b_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K( + problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides); + const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K( + problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides); + const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N( + problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides); const auto c_grid_desc_mblock_mperblock_nblock_nperblock = GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -635,25 +516,32 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle const index_t BlockStart = grid_size_; const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart); - const index_t grid_size_grp = block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) * - problem_desc_vec[i].Batch; + const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0); + const index_t grid_size_grp = + block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) * batch_count; const index_t BlockEnd = grid_size_ + 
grid_size_grp; // batch stride - // TODO ANT: only keep batch stride in tensor desc to reduce scalar cache pressure - const auto c_grid_desc_g_m_n = DeviceOp::MakeCGridDescriptor_G_M_N( - problem_desc_vec[i].c_gs_ms_os_lengths, problem_desc_vec[i].c_gs_ms_os_strides); - const auto compute_base_ptr_of_batch = - ComputeBasePtrOfStridedBatch(problem_desc_vec[i].BatchStrideA, - problem_desc_vec[i].BatchStrideB0, - problem_desc_vec[i].BatchStrideB1, - c_grid_desc_g_m_n); + const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch( + a_grid_desc_g_m_k, b_grid_desc_g_n_k, b1_grid_desc_g_n_k, c_grid_desc_g_m_n); // C0 mask - const auto c0_matrix_mask = C0MatrixMask(problem_desc_vec[i].N); + const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1)); grid_size_ += grid_size_grp; + // for each group, make sure acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias and + // so on + if(!(problem_desc.acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias && + problem_desc.acc0_biases_gs_ms_ns_strides.size() == NumAcc0Bias && + problem_desc.acc1_biases_gs_ms_os_lengths.size() == NumAcc1Bias && + problem_desc.acc1_biases_gs_ms_os_strides.size() == NumAcc1Bias)) + { + throw std::runtime_error( + "wrong! number of biases in function argument does not " + "match that in template argument"); + } + group_kernel_args_.push_back({p_a_grid, p_b_grid, p_b1_grid, @@ -669,13 +557,20 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle BlockStart, BlockEnd}); - group_device_args_.push_back({problem_desc_vec[i].M, - problem_desc_vec[i].N, - problem_desc_vec[i].K, - problem_desc_vec[i].O, - problem_desc_vec[i].c_gs_ms_os_lengths.back(), - problem_desc_vec[i].c_gs_ms_os_strides.back(), - c_grid_desc_m_n}); + group_device_args_.push_back( + {{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], + problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1], + problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1], + problem_desc.b1_gs_os_ns_lengths[NumDimG + NumDimO - 1]}, + {problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}, + {problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN - 1], + problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]}, + {problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO - 1], + problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO + NumDimN - 1]}, + {problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1], + problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]}, + c_grid_desc_m_n}); } } @@ -788,6 +683,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle return false; } + // TODO ANT: Check if tensor specialization & strides mismatch + bool all_has_main_k_block_loop = true; bool some_has_main_k_block_loop = false; @@ -815,19 +712,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle // Note: we need raw lengths since threadwise copy can not handle vector load when // part of vector is out of bounds - const auto MRaw = device_arg.M; - const auto NRaw = device_arg.N; - const auto KRaw = device_arg.K; - const auto Gemm1NRaw = device_arg.O; + const auto MzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[0]; + const auto NzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[1]; + const auto KzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[2]; + const auto Gemm1NzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[3]; // Check scalar per vector requirement - const auto a_extent_lowest = - is_same_v ? KRaw : MRaw; - const auto b_extent_lowest = - is_same_v ? 
NRaw : KRaw; - const auto b1_extent_lowest = - is_same_v ? Gemm1NRaw : NRaw; - const auto c_extent_lowest = device_arg.c_extent_lowest_; + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw; + const auto c_extent_lowest = Gemm1NzRaw; if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 && @@ -837,8 +731,22 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle return false; } - // Check vector store requirement; assumes last dimension in N to be contiguous - if(device_arg.c_stride_lowest_ != 1) + // Check vector load/store requirement + const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2 + ? device_arg.a_mz_kz_strides_[1] + : device_arg.a_mz_kz_strides_[0]; + const auto b_stride_lowest = BBlockTransferSrcVectorDim == 2 + ? device_arg.b_nz_kz_strides_[1] + : device_arg.b_nz_kz_strides_[0]; + const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2 + ? device_arg.b1_nz_kz_strides_[1] + : device_arg.b1_nz_kz_strides_[0]; + const auto c_stride_lowest = + device_arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be + // contiguous + + if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) { return false; } @@ -873,6 +781,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle std::vector p_b_vec, std::vector p_b1_vec, std::vector p_c_vec, + std::vector> p_acc0_biases_vec, + std::vector> p_acc1_biases_vec, std::vector problem_desc_vec, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, @@ -884,6 +794,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle p_b_vec, p_b1_vec, p_c_vec, + p_acc0_biases_vec, + p_acc1_biases_vec, problem_desc_vec, a_element_op, b_element_op, @@ -895,21 +807,26 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle static auto MakeInvoker() { return Invoker{}; } // polymorphic - std::unique_ptr MakeArgumentPointer(std::vector p_a_vec, - std::vector p_b_vec, - std::vector p_b1_vec, - std::vector p_c_vec, - std::vector problem_desc_vec, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - AccElementwiseOperation acc_element_op, - B1ElementwiseOperation b1_element_op, - CElementwiseOperation c_element_op) override + std::unique_ptr + MakeArgumentPointer(std::vector p_a_vec, + std::vector p_b_vec, + std::vector p_b1_vec, + std::vector p_c_vec, + std::vector> p_acc0_biases_vec, + std::vector> p_acc1_biases_vec, + std::vector problem_desc_vec, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override { return std::make_unique(p_a_vec, p_b_vec, p_b1_vec, p_c_vec, + p_acc0_biases_vec, + p_acc1_biases_vec, problem_desc_vec, a_element_op, b_element_op, @@ -942,7 +859,12 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle << Gemm1NPerBlock << ", " << Gemm1KPerBlock << ", " << B1K1 << ", " - << getGemmSpecializationString(GemmSpec) << ">"; + << getGemmSpecializationString(GemmSpec) << ", " + << "ASpec" << getTensorSpecializationString(ASpec) << ", " + << "B0Spec" << getTensorSpecializationString(BSpec) << ", " + << "B1Spec" << getTensorSpecializationString(B1Spec) << ", " + << "CSpec" << getTensorSpecializationString(CSpec) << 
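// The vector-access validation above, reduced to its essence: the
// fastest-varying ("z") extent must be a multiple of the transfer's
// ScalarPerVector, and the vectorized dimension wants unit stride, because a
// partial vector straddling the tensor edge cannot be loaded safely. A
// simplified per-tensor sketch of the rule (assumed, not the exact check):
inline bool can_vector_access(int extent_lowest, int stride_lowest, int scalar_per_vector)
{
    return stride_lowest == 1 && extent_lowest % scalar_per_vector == 0;
}
// e.g. the examples pick K = 40, which passes for ScalarPerVector = 8
// (40 % 8 == 0) but would disqualify an instance demanding 16 (40 % 16 != 0).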
", " + << getMaskingSpecializationString(MaskingSpec) << ">"; // clang-format on return str.str(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp index bb3c09b427..2237ad944c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp @@ -130,8 +130,11 @@ namespace device { // D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] // E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] -// FIXME: TensorSpecialization::Packed specialization does not cover all packed tensor cases, it -// merely degenerates into TensorSpecialization::Default with NumDimG/M/N/K = 1 +// NOTE: TensorSpecialization::Packed specialized tensor is "packed" in a sense that each inner +// dimension in a dimension group (eg [G0, G1] in Gs, [M0, M1, M2] in Ms, etc.) are contiguous and +// ordered. Not in a sense that the tensor [G0, G1, ..., M0, M1, ..., N0, N1...] can be permuted +// while still being a contiguous, unpadded tensor. In other words, it merely degenerates into +// TensorSpecialization::Default with NumDimG/M/N/K = 1 // // Detail- Packed tensor satisfies // stride_0 = 1 @@ -147,7 +150,7 @@ namespace device { // essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1. // // Might need to expose dimension order to the interface to fully support -// TensorSpecialization::Packed. +// TensorSpecialization::Packed in a traditional sense of "packed" tensor template +template struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle - : public DeviceBatchedGemmSoftmaxGemmPermute + CElementwiseOperation, + MaskingSpec> { + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, + "Number of dimension must be greater than 0"); + + static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); + static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); + + // TODO ANT: implement bias combination + static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); + +#if 0 + // TODO ANT: use alias + static constexpr index_t NumDimGemm0M = NumDimM; + static constexpr index_t NumDimGemm0N = NumDimN; + static constexpr index_t NumDimGemm0K = NumDimK; + static constexpr index_t NumDimGemm1M = NumDimM; + static constexpr index_t NumDimGemm1N = NumDimO; + static constexpr index_t NumDimGemm1K = NumDimN; +#endif + using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle; static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; - static constexpr auto matrix_padder = - GemmGemmPadder{ - MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock}; + using Transform = TransformBatchedContractionContractionToBatchedGemmGemm< + Sequence, + Sequence, + GemmSpec, + ASpec, + BSpec, + B1Spec, + CSpec>; - static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) { - const auto a_grid_desc_mraw_kraw = [&]() { - if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(StrideA, I1)); - } - else if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, 
KRaw), - make_tuple(I1, StrideA)); - } - }(); - - const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); - - const auto M = a_grid_desc_m_k.GetLength(I0); - const auto K = a_grid_desc_m_k.GetLength(I1); - - const auto AK0 = K / AK1; - - return transform_tensor_descriptor(a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), + Number{}); } - static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector& b_gs_ns_ks_lengths_vec, + const std::vector& b_gs_ns_ks_strides_vec) { - const auto b_grid_desc_nraw_kraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(I1, StrideB)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(StrideB, I1)); - } - }(); - - const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); - - const auto N = b_grid_desc_n_k.GetLength(I0); - const auto K = b_grid_desc_n_k.GetLength(I1); - - const auto BK0 = K / BK1; - - return transform_tensor_descriptor(b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec), + Number{}); } - // Args: Gemm1KRaw, Gemm1NRaw, StrideB1 - static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + static auto + MakeB1GridDescriptor_BK0_N_BK1(const std::vector& b1_gs_gemm1ns_gemm1ks_lengths_vec, + const std::vector& b1_gs_gemm1ns_gemm1ks_strides_vec) { - const auto b1_grid_desc_nraw_kraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(I1, StrideB)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(StrideB, I1)); - } - }(); - - const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw); - - const auto N = b1_grid_desc_n_k.GetLength(I0); - const auto K = b1_grid_desc_n_k.GetLength(I1); - - const auto B1K0 = K / B1K1; - - return transform_tensor_descriptor( - b1_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec, + b1_gs_gemm1ns_gemm1ks_strides_vec), + Number{}); } - // assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] 
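// "Packed" in the TensorSpecialization::Packed sense means the strides of each
// dimension group are the suffix products of its lengths in the given order; a
// permuted-but-contiguous buffer does not qualify, so the permute path cannot
// use the Packed specialization. A quick standalone check (illustrative):
#include <vector>

bool is_packed(const std::vector<long>& len, const std::vector<long>& stride)
{
    long expect = 1;
    for(int i = static_cast<int>(len.size()) - 1; i >= 0; --i)
    {
        if(stride[i] != expect)
            return false; // permuted or padded layouts fail here
        expect *= len[i];
    }
    return true;
}
// A [G0, M, G1, K] buffer viewed through logical order (G0, G1, M, K) has
// strides {M*G1*K, K, G1*K, 1}: contiguous overall, yet not "packed" by the
// suffix-product rule.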
- static auto MakeCGridDescriptor_M_N(const std::vector& c_gs_ms_ns_lengths_vec, - const std::vector& c_gs_ms_ns_strides_vec) + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); + using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {})); + using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); + using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); + using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); + using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); + using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); + + constexpr static auto make_MaskOutPredicate() { - constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0); - constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1); - constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N - - assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && - c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN); - - const auto to_tuple = [&](auto& vec, auto start, auto end) { - return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); - }; - - const auto c_ms_ns_lengths = to_tuple( - c_gs_ms_ns_lengths_vec, Number{}, Number{}); - const auto c_ms_ns_strides = to_tuple( - c_gs_ms_ns_strides_vec, Number{}, Number{}); - - // dimension Ids for M0, M1, ... - constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; - - // dimension Ids for N0, N1, ... - constexpr auto nDimIds = - typename arithmetic_sequence_gen::type{}; - - // lengths for M0, M1, ... - const auto mLengths = get_container_subset(c_ms_ns_lengths, mDimIds); - - // lengths for K0, K1, ... - const auto nLengths = get_container_subset(c_ms_ns_lengths, nDimIds); - - // naive tensor C[M0, M1, M2, ..., N0, N1, N2...] - const auto c_grid_desc_ms_ns = - make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides); - - // transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...] - const auto c_grid_desc_mraw_nraw = transform_tensor_descriptor( - c_grid_desc_ms_ns, - make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), - make_tuple(mDimIds, nDimIds), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); - } - - // assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] - static auto MakeCGridDescriptor_G_M_N(const std::vector& c_gs_ms_ns_lengths_vec, - const std::vector& c_gs_ms_ns_strides_vec) - { - constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0); - constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1); - constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N - - assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && - c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN); - - const auto to_tuple = [&](auto& vec, auto start, auto end) { - return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); - }; - - const auto c_gs_ms_ns_lengths = - to_tuple(c_gs_ms_ns_lengths_vec, Number<0>{}, Number{}); - const auto c_gs_ms_ns_strides = - to_tuple(c_gs_ms_ns_strides_vec, Number<0>{}, Number{}); - - // dimension Ids for G0, G1, ... - constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{}; - - // dimension Ids for M0, M1, ... 
- constexpr auto mDimIds = - typename arithmetic_sequence_gen::type{}; - - // dimension Ids for N0, N1, ... - constexpr auto nDimIds = typename arithmetic_sequence_gen::type{}; - - // lengths for G0, G1, ... - const auto gLengths = get_container_subset(c_gs_ms_ns_lengths, gDimIds); - - // lengths for M0, M1, ... - const auto mLengths = get_container_subset(c_gs_ms_ns_lengths, mDimIds); - - // lengths for K0, K1, ... - const auto nLengths = get_container_subset(c_gs_ms_ns_lengths, nDimIds); - - // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] - const auto c_grid_desc_gs_ms_ns = - make_naive_tensor_descriptor(c_gs_ms_ns_lengths, c_gs_ms_ns_strides); - - // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * - // N2 * ...] - const auto c_grid_desc_g_mraw_nraw = - transform_tensor_descriptor(c_grid_desc_gs_ms_ns, - make_tuple(make_merge_transform(gLengths), - make_merge_transform(mLengths), - make_merge_transform(nLengths)), - make_tuple(gDimIds, mDimIds, nDimIds), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - // this desc is only for calculating batch offset so no padding needed - return c_grid_desc_g_mraw_nraw; - } - - using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); - using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); - using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1)); - using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N({}, {})); - using CGridDesc_G_M_N = decltype(MakeCGridDescriptor_G_M_N({}, {})); - - // to track the points which need to be set to -inf on C0 - // Note: no need to reset M padding value, because they will not be stored out. - struct C0MatrixMask - { - C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {} - - __host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; } - - __host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const + if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled) { - return n >= NRaw_; + return MaskDisabledPredicate{}; } - - __host__ __device__ bool IsMaskedElement(index_t m, index_t n) const + else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) { - return IsUpperTriangle(m, n) || IsNOutOfBound(n); + return MaskOutUpperTrianglePredicate{}; } - - private: - // index_t MRaw_; - index_t NRaw_; - }; + } + using C0MatrixMask = C0MatrixMask_impl; struct ComputeBasePtrOfStridedBatch { - ComputeBasePtrOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB, - index_t BatchStrideB1, - CGridDesc_G_M_N c_grid_desc_g_m_n) - : BatchStrideA_(BatchStrideA), - BatchStrideB_(BatchStrideB), - BatchStrideB1_(BatchStrideB1), + ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, + const BGridDesc_G_N_K& b_grid_desc_g_n_k, + const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, + const CGridDesc_G_M_N& c_grid_desc_g_m_n) + : a_grid_desc_g_m_k_(a_grid_desc_g_m_k), + b_grid_desc_g_n_k_(b_grid_desc_g_n_k), + b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), c_grid_desc_g_m_n_(c_grid_desc_g_m_n) { } __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const { - return g_idx * static_cast(BatchStrideA_); + return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); } __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const { - return g_idx * static_cast(BatchStrideB_); + return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); } __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t 
g_idx) const { - return g_idx * static_cast(BatchStrideB1_); + return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); } __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const @@ -457,9 +317,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle } private: - index_t BatchStrideA_; - index_t BatchStrideB_; - index_t BatchStrideB1_; + AGridDesc_G_M_K a_grid_desc_g_m_k_; + BGridDesc_G_N_K b_grid_desc_g_n_k_; + B1GridDesc_G_N_K b1_grid_desc_g_n_k_; CGridDesc_G_M_N c_grid_desc_g_m_n_; }; @@ -523,47 +383,59 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, - matrix_padder.PadN, - MaskOutUpperTriangle>; + Transform::matrix_padder.PadN, + MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>; // Argument // FIXME: constness struct Argument : public BaseArgument { - Argument(const ADataType* p_a_grid, - const BDataType* p_b_grid, - const B1DataType* p_b1_grid, - CDataType* p_c_grid, - index_t MRaw, - index_t NRaw, - index_t KRaw, - index_t Gemm1NRaw, // = ORaw - index_t Batch, - std::vector c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths - std::vector c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides - index_t StrideA, - index_t StrideB, - index_t StrideB1, - index_t BatchStrideA, - index_t BatchStrideB, - index_t BatchStrideB1, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - AccElementwiseOperation acc_element_op, - B1ElementwiseOperation b1_element_op, - CElementwiseOperation c_element_op) + Argument( + const ADataType* p_a_grid, + const BDataType* p_b_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::vector& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths + const std::vector& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides + const std::vector& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths + const std::vector& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, + const std::array, NumAcc1Bias> + acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths + const std::array, NumAcc1Bias> + acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) : p_a_grid_{p_a_grid}, p_b_grid_{p_b_grid}, p_b1_grid_{p_b1_grid}, p_c_grid_{p_c_grid}, - a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, - b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, - b1_grid_desc_bk0_n_bk1_{ - DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(NRaw, Gemm1NRaw, StrideB1)}, - c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, - c_gs_ms_gemm1ns_strides)}, - c_grid_desc_g_m_n_{DeviceOp::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths, - c_gs_ms_gemm1ns_strides)}, + a_grid_desc_ak0_m_ak1_{ + DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b_grid_desc_bk0_n_bk1_{ + 
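// Batch base pointers are now derived from the G-major descriptors instead of
// explicit BatchStride scalars: offset(g) = desc.CalculateOffset((g, 0, 0)),
// which collapses to g * batch_stride for a uniform layout but remains valid
// when G is not the outermost dimension. A scalar sketch (hypothetical names):
#include <cstdint>

struct GMNDescSketch
{
    std::int64_t stride_g, stride_m, stride_n;
    std::int64_t CalculateOffset(std::int64_t g, std::int64_t m, std::int64_t n) const
    {
        return g * stride_g + m * stride_m + n * stride_n;
    }
};

std::int64_t GetCBasePtrSketch(const GMNDescSketch& c_desc, int g_idx)
{
    return c_desc.CalculateOffset(g_idx, 0, 0); // old form: g_idx * BatchStrideC
}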
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, + b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1( + b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, + c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, + c_gs_ms_gemm1ns_strides)}, + a_grid_desc_g_m_k_{ + Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b_grid_desc_g_n_k_{ + Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, + b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K( + b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, + c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths, + c_gs_ms_gemm1ns_strides)}, c_grid_desc_mblock_mperblock_nblock_nperblock_{}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, a_element_op_{a_element_op}, @@ -571,14 +443,31 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle acc_element_op_{acc_element_op}, b1_element_op_{b1_element_op}, c_element_op_{c_element_op}, - batch_count_(Batch), + c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)}, + raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], + b_gs_ns_ks_lengths[NumDimG + NumDimN - 1], + b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1], + b1_gs_gemm1ns_gemm1ks_lengths[NumDimG + NumDimO - 1]}, + a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}, + b_nz_kz_strides_{b_gs_ns_ks_strides[NumDimG + NumDimN - 1], + b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]}, + b1_nz_kz_strides_{b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO - 1], + b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]}, + c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1], + c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]}, + batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, compute_base_ptr_of_batch_{ - BatchStrideA, BatchStrideB, BatchStrideB1, c_grid_desc_g_m_n_}, - c0_matrix_mask_{NRaw}, - raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw}, - c_extent_lowest_{c_gs_ms_gemm1ns_lengths.back()}, - c_stride_lowest_{c_gs_ms_gemm1ns_strides.back()} + a_grid_desc_g_m_k_, b_grid_desc_g_n_k_, b1_grid_desc_g_n_k_, c_grid_desc_g_m_n_} { + // TODO ANT: implement bias addition + ignore = p_acc0_biases; + ignore = p_acc1_biases; + ignore = acc0_biases_gs_ms_ns_lengths; + ignore = acc0_biases_gs_ms_ns_strides; + ignore = acc1_biases_gs_ms_gemm1ns_lengths; + ignore = acc1_biases_gs_ms_gemm1ns_strides; + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, b1_grid_desc_bk0_n_bk1_, @@ -591,34 +480,66 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle } } - // private: + void Print() const + { + std::cout << "a_grid_desc_g_m_k_: " << a_grid_desc_g_m_k_.GetLength(I0) << ", " + << a_grid_desc_g_m_k_.GetLength(I1) << ", " + << a_grid_desc_g_m_k_.GetLength(I2) << '\n'; + // a_grid_desc_g_m_k_.Print(); + std::cout << "b_grid_desc_g_n_k_: " << b_grid_desc_g_n_k_.GetLength(I0) << ", " + << b_grid_desc_g_n_k_.GetLength(I1) << ", " + << b_grid_desc_g_n_k_.GetLength(I2) << '\n'; + // b_grid_desc_g_n_k_.Print(); + std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", " + << b1_grid_desc_g_n_k_.GetLength(I1) << ", " + << b1_grid_desc_g_n_k_.GetLength(I2) << '\n'; + // b1_grid_desc_g_n_k_.Print(); + std::cout << "c_grid_desc_g_m_n_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", " + << 
c_grid_desc_g_m_n_.GetLength(I1) << ", " + << c_grid_desc_g_m_n_.GetLength(I2) << '\n'; + // c_grid_desc_g_m_n_.Print(); + } + + // pointers const ADataType* p_a_grid_; const BDataType* p_b_grid_; const B1DataType* p_b1_grid_; CDataType* p_c_grid_; + + // tensor descriptor AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; CGridDesc_M_N c_grid_desc_m_n_; + AGridDesc_G_M_K a_grid_desc_g_m_k_; + BGridDesc_G_N_K b_grid_desc_g_n_k_; + B1GridDesc_G_N_K b1_grid_desc_g_n_k_; CGridDesc_G_M_N c_grid_desc_g_m_n_; typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-c-tile map typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + + // element-wise op AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; AccElementwiseOperation acc_element_op_; B1ElementwiseOperation b1_element_op_; CElementwiseOperation c_element_op_; - index_t batch_count_; - ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; // check C0 masking and padding C0MatrixMask c0_matrix_mask_; // For robust IsSupportedArgument() check - std::vector raw_lengths_m_n_k_o_; - index_t c_extent_lowest_; - index_t c_stride_lowest_; + std::vector raw_lengths_mz_nz_kz_gemm1nz_; + std::vector a_mz_kz_strides_; + std::vector b_nz_kz_strides_; + std::vector b1_nz_kz_strides_; + std::vector c_mz_gemm1nz_strides_; + + index_t batch_count_; + ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; }; // Invoker @@ -628,13 +549,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.b1_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) + if(!DeviceOp::IsSupportedArgument(arg)) { - throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + throw std::runtime_error("wrong! 
unsupported argument"); } const index_t grid_size = @@ -719,17 +636,24 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { +#if 0 + arg.Print(); +#endif + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) { return false; } + // TODO ANT: Check if tensor specialization & strides mismatch + // Check if C permute dimension matches GEMM + GEMM shape const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0); const index_t c_gemm1n = arg.c_grid_desc_m_n_.GetLength(I1); const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1); + if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n)) { return false; @@ -737,19 +661,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle // Note: we need raw lengths since threadwise copy can not handle vector load when part of // vector is out of bounds - const auto MRaw = arg.raw_lengths_m_n_k_o_[0]; - const auto NRaw = arg.raw_lengths_m_n_k_o_[1]; - const auto KRaw = arg.raw_lengths_m_n_k_o_[2]; - const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3]; + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[0]; + const auto NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[1]; + const auto KzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[2]; + const auto Gemm1NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[3]; // Check scalar per vector requirement - const auto a_extent_lowest = - is_same_v ? KRaw : MRaw; - const auto b_extent_lowest = - is_same_v ? NRaw : KRaw; - const auto b1_extent_lowest = - is_same_v ? Gemm1NRaw : NRaw; - const auto c_extent_lowest = arg.c_extent_lowest_; + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw; + const auto c_extent_lowest = Gemm1NzRaw; if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 && @@ -759,8 +681,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle return false; } - // Check vector store requirement; assumes last dimension in N to be contiguous - if(arg.c_stride_lowest_ != 1) + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; + const auto b_stride_lowest = + BBlockTransferSrcVectorDim == 2 ? arg.b_nz_kz_strides_[1] : arg.b_nz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? 
arg.b1_nz_kz_strides_[1] : arg.b1_nz_kz_strides_[0]; + const auto c_stride_lowest = + arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be contiguous + + if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) { return false; } @@ -778,46 +710,51 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle return IsSupportedArgument(*dynamic_cast(p_arg)); } - static auto MakeArgument(const ADataType* p_a, - const BDataType* p_b, - const B1DataType* p_b1, - CDataType* p_c, - index_t MRaw, - index_t NRaw, - index_t KRaw, - index_t Gemm1NRaw, - index_t Batch, - std::vector c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths - std::vector c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides - index_t StrideA, - index_t StrideB, - index_t StrideB1, - index_t BatchStrideA, - index_t BatchStrideB, - index_t BatchStrideB1, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - AccElementwiseOperation acc_element_op, - B1ElementwiseOperation b1_element_op, - CElementwiseOperation c_element_op) + static auto MakeArgument( + const ADataType* p_a, + const BDataType* p_b, + const B1DataType* p_b1, + CDataType* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::vector& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths + const std::vector& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides + const std::vector& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths + const std::vector& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, + const std::array, NumAcc1Bias> + acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths + const std::array, NumAcc1Bias> + acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) { return Argument{p_a, p_b, p_b1, p_c, - MRaw, - NRaw, - KRaw, - Gemm1NRaw, - Batch, - c_gs_ms_gemm1ns_lengths, - c_gs_ms_gemm1ns_strides, - StrideA, - StrideB, - StrideB1, - BatchStrideA, - BatchStrideB, - BatchStrideB1, + p_acc0_biases, + p_acc1_biases, + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths + b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides + c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths + c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides + acc0_biases_gs_ms_ns_lengths, + acc0_biases_gs_ms_ns_strides, + acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths + acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides a_element_op, b_element_op, acc_element_op, @@ -829,47 +766,51 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle // polymorphic // FIXME: constness - std::unique_ptr - MakeArgumentPointer(const void* p_a, - const void* p_b, - const void* p_b1, - void* p_c, - index_t MRaw, - index_t NRaw, - index_t KRaw, - index_t Gemm1NRaw, - index_t Batch, - std::vector c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths - std::vector c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides - index_t StrideA, - index_t StrideB, - index_t StrideB1, - index_t 
BatchStrideA, - index_t BatchStrideB, - index_t BatchStrideB1, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - AccElementwiseOperation acc_element_op, - B1ElementwiseOperation b1_element_op, - CElementwiseOperation c_element_op) override + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b, + const void* p_b1, + void* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::vector& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths + const std::vector& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides + const std::vector& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths + const std::vector& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, + const std::array, NumAcc1Bias> + acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths + const std::array, NumAcc1Bias> + acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override { return std::make_unique(static_cast(p_a), static_cast(p_b), static_cast(p_b1), static_cast(p_c), - MRaw, - NRaw, - KRaw, - Gemm1NRaw, - Batch, - c_gs_ms_gemm1ns_lengths, - c_gs_ms_gemm1ns_strides, - StrideA, - StrideB, - StrideB1, - BatchStrideA, - BatchStrideB, - BatchStrideB1, + p_acc0_biases, // cast in struct Argument + p_acc1_biases, // cast in struct Argument + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths + b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides + c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths + c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides + acc0_biases_gs_ms_ns_lengths, + acc0_biases_gs_ms_ns_strides, + acc1_biases_gs_ms_gemm1ns_lengths, + acc1_biases_gs_ms_gemm1ns_strides, a_element_op, b_element_op, acc_element_op, @@ -901,7 +842,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle << Gemm1NPerBlock << ", " << Gemm1KPerBlock << ", " << B1K1 << ", " - << getGemmSpecializationString(GemmSpec) << ">"; + << getGemmSpecializationString(GemmSpec) << ", " + << "ASpec" << getTensorSpecializationString(ASpec) << ", " + << "B0Spec" << getTensorSpecializationString(BSpec) << ", " + << "B1Spec" << getTensorSpecializationString(B1Spec) << ", " + << "CSpec" << getTensorSpecializationString(CSpec) << ", " + << getMaskingSpecializationString(MaskingSpec) << ">"; // clang-format on return str.str(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp index cf4bd01f09..1f21f2d712 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -12,6 +12,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include 
"ck/tensor_operation/gpu/device/masking_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp" #include "ck/host_utility/device_prop.hpp" @@ -196,7 +197,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, - CElementwiseOperation> + CElementwiseOperation, + MaskOutUpperTriangle> { using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle; @@ -315,29 +317,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); } - // to track the points which need to be set to -inf on C0 - // Note: no need to reset M padding value, because they will not be stored out. - struct C0MatrixMask - { - C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {} - - __host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; } - - __host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const - { - return n >= NRaw_; - } - - __host__ __device__ bool IsMaskedElement(index_t m, index_t n) const - { - return IsUpperTriangle(m, n) || IsNOutOfBound(n); - } - - private: - // index_t MRaw_; - index_t NRaw_; - }; - struct ComputeBasePtrOfStridedBatch { ComputeBasePtrOfStridedBatch(index_t BatchStrideA, @@ -383,6 +362,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using C0MatrixMask = conditional_t, + C0MatrixMask_impl>; + // GridwiseGemm using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, // TODO: distinguish A/B datatype diff --git a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp new file mode 100644 index 0000000000..ea0f5897a7 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { + +enum struct MaskingSpecialization +{ + MaskDisabled, + MaskOutUpperTriangle +}; + +inline std::string getMaskingSpecializationString(const MaskingSpecialization& s) +{ + switch(s) + { + case MaskingSpecialization::MaskDisabled: return "MaskDisabled"; + case MaskingSpecialization::MaskOutUpperTriangle: return "MaskOutUpperTriangle"; + default: return "Unrecognized specialization!"; + } +} + +struct MaskDisabledPredicate +{ + __host__ __device__ constexpr bool operator()(index_t /*m*/, index_t /*n*/) const + { + return false; + }; + + __host__ __device__ constexpr bool + IsTileSkippable(index_t /*m*/, index_t /*n*/, index_t /*m_tile*/, index_t /*n_tile*/) const + { + return false; + } +}; + +struct MaskOutUpperTrianglePredicate +{ + __host__ __device__ constexpr bool operator()(index_t m, index_t n) const { return n > m; } + + __host__ __device__ constexpr bool + IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t /*n_tile*/) const + { + return operator()(m + m_tile - 1, n); + } +}; + +// to track the points which need to be set to -inf on C0 +// Note: no need to reset M padding value, because they will not be stored out. 
+template +struct C0MatrixMask_impl +{ + C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {} + + __host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const + { + return n >= NRaw_; + } + + __host__ __device__ constexpr bool IsMaskedElement(index_t m, index_t n) const + { + return predicate_(m, n) || IsNOutOfBound(n); + } + + __host__ __device__ constexpr bool + IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t n_tile) const + { + return predicate_.IsTileSkippable(m, n, m_tile, n_tile); + } + + private: + // index_t MRaw_; + index_t NRaw_; + MaskOutPredicate predicate_; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index d356d23132..ef12f29fc4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -336,36 +336,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); }; - template - struct ElementOpPredicatedResetNaNToMinusInf; - - template <> - struct ElementOpPredicatedResetNaNToMinusInf - { - template - __host__ __device__ void Run(OutT& y, const ElementOp& op, const InT& x) - { - if(ck::math::isnan(x)) - { - y = -ck::NumericLimits::Infinity(); - } - else - { - op(y, x); - } - } - }; - - template <> - struct ElementOpPredicatedResetNaNToMinusInf - { - template - __host__ __device__ void Run(OutT& y, const ElementOp& op, const InT& x) - { - op(y, x); - } - }; - template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -406,11 +376,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle return; } - // HACK: this force m/n_block_data_idx_on_grid into SGPR + // HACK: this force m/gemm1_n_block_data_idx_on_grid into SGPR const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); - const index_t n_block_data_idx_on_grid = + const index_t gemm1_n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock); // A matrix in LDS memory, dst of blockwise copy @@ -627,7 +597,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle true, // DstResetCoord NumGemmKPrefetchStage>( b1_grid_desc_bk0_n_bk1, - make_multi_index(0, n_block_data_idx_on_grid, 0), + make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0), b1_element_op, b1_block_desc_bk0_n_bk1, make_multi_index(0, 0, 0), @@ -745,29 +715,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle running_max = NumericLimits::Lowest(); running_max_new = NumericLimits::Lowest(); - // decoder lower triangular mask - const auto thread_cluster_idx = threadid_to_m_n_thread_cluster_adaptor.CalculateBottomIndex( - make_multi_index(get_thread_local_1d_id())); - const auto thread_m_cluster_id = thread_cluster_idx[I0]; - const auto thread_n_cluster_id = thread_cluster_idx[I1]; - const index_t MPerRepeat = MPerBlock / MXdlPerWave; - const index_t NPerRepeat = NPerBlock / NXdlPerWave; - const index_t mstart = m_block_data_idx_on_grid + thread_m_cluster_id; - // gemm1 K loop index_t gemm1_k_block_outer_index = 0; do { - if constexpr(MaskOutUpperTriangle) + auto n_block_data_idx_on_grid = + 
__builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock); + if(c0_matrix_mask.IsTileSkippable( + m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock)) { - auto gemm0_n_block_idx = - __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock); - if(c0_matrix_mask.IsUpperTriangle(m_block_data_idx_on_grid, gemm0_n_block_idx) && - c0_matrix_mask.IsUpperTriangle(m_block_data_idx_on_grid + MPerBlock - 1, - gemm0_n_block_idx)) - { - continue; - } + continue; } // gemm0 gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, @@ -789,60 +746,58 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle // do MNK padding or upper triangular masking if constexpr(MaskOutUpperTriangle || PadN) { - const index_t nstart = gemm1_k_block_outer_index * NPerBlock; + // 8d thread_desc in thread scope + constexpr auto c_thread_lengths = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(); - static_for<0, m0, 1>{}([&](auto m0_i) { - const index_t m_global = mstart + m0_i * MPerRepeat; - const index_t acc_idx_m0 = m0_i * n0 * n2 * n4; - static_for<0, n0, 1>{}([&](auto n0_i) { - // constexpr auto nrepeat_i = n0_i * NPerRepeat; - // const index_t nstartxdl = nstart + nrepeat_i; - const index_t nstartxdl = nstart + n0_i * NPerRepeat; - const index_t acc_idx_n0 = acc_idx_m0 + n0_i * n2 * n4; - static_for<0, n2, 1>{}([&](auto n2_i) { - const index_t nstartgroup = - nstartxdl + thread_n_cluster_id * n4 + n2_i * AccN3 * n4; - const index_t acc_idx_n2 = acc_idx_n0 + n2_i * n4; - static_for<0, n4, 1>{}([&](auto n4_i) { - const index_t n_global = nstartgroup + n4_i; - const auto acc_offset = Number{}; - if constexpr(MaskOutUpperTriangle) - { - if(c0_matrix_mask.IsMaskedElement(m_global, n_global)) - { - acc_thread_buf(acc_offset) = - -ck::NumericLimits::Infinity(); - } - else - { - acc_element_op(acc_thread_buf(acc_offset), - acc_thread_buf[acc_offset]); - } - } - else - { - // ignore m_global; - if(c0_matrix_mask.IsNOutOfBound(n_global)) - { - acc_thread_buf(acc_offset) = - -ck::NumericLimits::Infinity(); - } - else - { - acc_element_op(acc_thread_buf(acc_offset), - acc_thread_buf[acc_offset]); - } - } - }); - }); - }); + // 8d block_desc in block scope + constexpr auto c_block_lengths = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(); + + constexpr auto M0 = c_block_lengths[I0]; + constexpr auto N0 = c_block_lengths[I1]; + constexpr auto M1 = c_block_lengths[I2]; + constexpr auto N1 = c_block_lengths[I3]; + constexpr auto M2 = c_block_lengths[I4]; + constexpr auto N2 = c_block_lengths[I5]; + constexpr auto N3 = c_block_lengths[I6]; + constexpr auto N4 = c_block_lengths[I7]; + + // works like multi-dimension static_for (static_ford), but provides both the linear + // index as well as n-d index + using Acc0TileIterator = SpaceFillingCurve< + decltype(c_thread_lengths), + typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type, + typename uniform_sequence_gen::type, + false>; // SnakeCurved + + auto acc0_thread_origin = blockwise_gemm.CalculateCThreadOriginDataIndex8D( + Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{}); + + constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2)), + make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{})); + + static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) { + auto 
acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin; + auto m_local = + block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0]; + auto n_local = + block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1]; + auto m_global = m_local + m_block_data_idx_on_grid; + auto n_global = n_local + n_block_data_idx_on_grid; + if(c0_matrix_mask.IsMaskedElement(m_global, n_global)) + { + acc_thread_buf(i) = -ck::NumericLimits::Infinity(); + } + else + { + acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); + } }); } - else - { - static_for<0, acc_thread_buf.Size(), 1>{}( - [&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); }); - } block_sync_lds(); // wait for lds read in gemm0 blockwise gemm diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 0748ffbce5..4d53f0d816 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -593,7 +593,8 @@ struct XdlopsGemm static constexpr auto I4 = Number<4>{}; static constexpr auto I5 = Number<5>{}; - using CIndex = MultiIndex<2>; + using CIndex = MultiIndex<2>; + using CIndex4D = MultiIndex<4>; __device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; } @@ -822,6 +823,16 @@ struct XdlopsGemm return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset}; } + __device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */) + { + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td}; + } + static constexpr auto mfma = MfmaSelector{}; static constexpr auto mfma_instr = mfma.selected_mfma; diff --git a/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp new file mode 100644 index 0000000000..5fc11d9158 --- /dev/null +++ b/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp @@ -0,0 +1,288 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" + +namespace ck { +namespace tensor_operation { + +// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] +template +static auto MakeGridDescriptorPair(const std::vector& gs_ms_ns_lengths_vec, + const std::vector& gs_ms_ns_strides_vec) +{ + if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN)) + { + throw std::runtime_error("wrong! dimension must match input lengths"); + } + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto gs_ms_ns_lengths = + to_tuple(gs_ms_ns_lengths_vec, Number<0>{}, Number{}); + const auto gs_ms_ns_strides = + to_tuple(gs_ms_ns_strides_vec, Number<0>{}, Number{}); + + // dimension Ids for G0, G1, ... + constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{}; + + // dimension Ids for M0, M1, ... 
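+    // e.g. (illustration only) with NumDimG = 1, NumDimM = 2, NumDimN = 1 the
+    // generated ids are gDimIds = Sequence<0>, mDimIds = Sequence<1, 2> and
+    // nDimIds = Sequence<3>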
+    constexpr auto mDimIds =
+        typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};
+
+    // dimension Ids for N0, N1, ...
+    constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
+                                                              NumDimG + NumDimM + NumDimN,
+                                                              1>::type{};
+
+    // lengths for G0, G1, ...
+    const auto gLengths = get_container_subset(gs_ms_ns_lengths, gDimIds);
+
+    // lengths for M0, M1, ...
+    const auto mLengths = get_container_subset(gs_ms_ns_lengths, mDimIds);
+
+    // lengths for N0, N1, ...
+    const auto nLengths = get_container_subset(gs_ms_ns_lengths, nDimIds);
+
+    if constexpr(TensorSpec == device::TensorSpecialization::Packed)
+    {
+        auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{});
+        auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
+        auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
+        const auto grid_desc_g_mraw_nraw = make_naive_tensor_descriptor(
+            make_tuple(G, M, N),
+            make_tuple(gs_ms_ns_strides[Number<NumDimG - 1>{}],
+                       gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
+                       gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
+
+        const auto grid_desc_mraw_nraw = make_naive_tensor_descriptor(
+            make_tuple(M, N),
+            make_tuple(gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
+                       gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
+
+        return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
+    }
+    else
+    {
+        // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
+        const auto grid_desc_gs_ms_ns =
+            make_naive_tensor_descriptor(gs_ms_ns_lengths, gs_ms_ns_strides);
+
+        // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ..., NRaw = N0 * N1 *
+        // N2 * ...]
+        // Note: no padding is needed here since this descriptor is only used to compute G
+        // offsets; technically only a G descriptor is required, but G_M_N is returned for
+        // backward compatibility
+        const auto grid_desc_g_mraw_nraw =
+            transform_tensor_descriptor(grid_desc_gs_ms_ns,
+                                        make_tuple(make_merge_transform(gLengths),
+                                                   make_merge_transform(mLengths),
+                                                   make_merge_transform(nLengths)),
+                                        make_tuple(gDimIds, mDimIds, nDimIds),
+                                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+
+        const auto c_ms_ns_lengths = to_tuple(
+            gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
+        const auto c_ms_ns_strides = to_tuple(
+            gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
+
+        // transformed tensor C[MRaw = M0 * M1 * M2 * ..., NRaw = N0 * N1 *
+        // N2 * ...]
+ const auto grid_desc_ms_ns = make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides); + + const auto grid_desc_mraw_nraw = transform_tensor_descriptor( + grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds - Number{}, nDimIds - Number{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw); + } +} + +template + typename PerBlock_M_N_K_O, // Sequence<> + device::GemmSpecialization GemmSpec, + device::TensorSpecialization ASpec, + device::TensorSpecialization B0Spec, + device::TensorSpecialization B1Spec, + device::TensorSpecialization CSpec> +struct TransformBatchedContractionContractionToBatchedGemmGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + + static constexpr index_t NumDimG = NumDims_G_M_N_K_O::At(I0); + static constexpr index_t NumDimM = NumDims_G_M_N_K_O::At(I1); + static constexpr index_t NumDimN = NumDims_G_M_N_K_O::At(I2); + static constexpr index_t NumDimK = NumDims_G_M_N_K_O::At(I3); + static constexpr index_t NumDimO = NumDims_G_M_N_K_O::At(I4); + + static constexpr index_t MPerBlock = PerBlock_M_N_K_O::At(I0); + static constexpr index_t NPerBlock = PerBlock_M_N_K_O::At(I1); + static constexpr index_t KPerBlock = PerBlock_M_N_K_O::At(I2); + static constexpr index_t OPerBlock = PerBlock_M_N_K_O::At(I3); + + static constexpr auto matrix_padder = + device::GemmGemmPadder{ + MPerBlock, NPerBlock, KPerBlock, OPerBlock}; + + // + // A + // + static auto MakeAGridDescriptorPair(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) + { + return MakeGridDescriptorPair(a_gs_ms_ks_lengths_vec, + a_gs_ms_ks_strides_vec); + } + + // TODO: rename to G_MRaw_KRaw + static auto MakeAGridDescriptor_G_M_K(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) + { + return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first; + } + static auto MakeAGridDescriptor_M_K(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) + { + return matrix_padder.PadADescriptor_M_K( + MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second); + } + + template + __host__ __device__ static constexpr auto + MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k, const Number& AK1) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // + // B (alias of B0) + // + static auto MakeB0GridDescriptorPair(const std::vector& b0_gs_ns_ks_lengths_vec, + const std::vector& b0_gs_ns_ks_strides_vec) + { + return MakeGridDescriptorPair(b0_gs_ns_ks_lengths_vec, + b0_gs_ns_ks_strides_vec); + } + + // TODO: rename to G_MRaw_NRaw + static auto MakeB0GridDescriptor_G_N_K(const std::vector& b0_gs_ns_ks_lengths_vec, + const std::vector& b0_gs_ns_ks_strides_vec) + { + return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first; + } + static auto MakeB0GridDescriptor_N_K(const std::vector& b0_gs_ns_ks_lengths_vec, + const 
std::vector& b0_gs_ns_ks_strides_vec) + { + // alias of matrix_padder.PadB0Descriptor_N_K + return matrix_padder.PadBDescriptor_N_K( + MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).second); + } + + template + __host__ __device__ static constexpr auto + MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k, const Number& BK1) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // + // B1 + // + static auto MakeB1GridDescriptorPair(const std::vector& b1_gs_os_ns_lengths_vec, + const std::vector& b1_gs_os_ns_strides_vec) + { + return MakeGridDescriptorPair(b1_gs_os_ns_lengths_vec, + b1_gs_os_ns_strides_vec); + } + + // TODO: rename to G_NRaw_KRaw + static auto MakeB1GridDescriptor_G_N_K(const std::vector& b1_gs_os_ns_lengths_vec, + const std::vector& b1_gs_os_ns_strides_vec) + { + return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first; + } + static auto MakeB1GridDescriptor_N_K(const std::vector& b1_gs_os_ns_lengths_vec, + const std::vector& b1_gs_os_ns_strides_vec) + { + // alias of matrix_padder.PadB1Descriptor_O_N + return matrix_padder.PadB1Descriptor_N_K( + MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).second); + } + + template + __host__ __device__ static constexpr auto + MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k, const Number& B1K1) + { + const auto N = b1_grid_desc_n_k.GetLength(I0); + const auto K = b1_grid_desc_n_k.GetLength(I1); + + const auto B1K0 = K / B1K1; + + return transform_tensor_descriptor( + b1_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // + // C + // + static auto MakeCGridDescriptorPair(const std::vector& c_gs_ms_os_lengths_vec, + const std::vector& c_gs_ms_os_strides_vec) + { + return MakeGridDescriptorPair(c_gs_ms_os_lengths_vec, + c_gs_ms_os_strides_vec); + } + + // TODO: rename to G_MRaw_NRaw + static auto MakeCGridDescriptor_G_M_N(const std::vector& c_gs_ms_os_lengths_vec, + const std::vector& c_gs_ms_os_strides_vec) + { + return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first; + } + static auto MakeCGridDescriptor_M_N(const std::vector& c_gs_ms_os_lengths_vec, + const std::vector& c_gs_ms_os_strides_vec) + { + return matrix_padder.PadCDescriptor_M_N( + MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second); + } +}; + +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute.hpp deleted file mode 100644 index 61625ffb8b..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute.hpp +++ /dev/null @@ -1,100 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -template -using S = ck::Sequence; - -using CPermuteNumDims_G_M_O = - S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O - -void add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( - std::vector>>& instances); - -template -struct DeviceOperationInstanceFactory< - ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute> -{ - using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute; - - static auto GetInstances() - { - std::vector> op_ptrs; - - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) - { - if constexpr(is_same_v && is_same_v && - is_same_v && - is_same_v) - { - add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( - op_ptrs); - } - } - return op_ptrs; - } -}; - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp index d553f981d1..8a0b1b1fa7 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp @@ -28,9 +28,26 @@ void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_g F16, PassThrough, PassThrough, + Scale, PassThrough, PassThrough, - PassThrough>>>& instances); + false>>>& instances); + +void add_device_batched_gemm_masking_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + std::vector>>& instances); template + typename CDataType, + bool MaskOutUpperTriangle> struct DeviceOperationInstanceFactory< ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm> + MaskOutUpperTriangle>> { using DeviceOp = DeviceBatchedGemmSoftmaxGemm; + MaskOutUpperTriangle>; static auto GetInstances() { @@ -79,8 +99,16 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { - add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( - op_ptrs); + if constexpr(MaskOutUpperTriangle) + { + add_device_batched_gemm_masking_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + op_ptrs); + } + else + { + add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + op_ptrs); + } } } return op_ptrs; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp new file mode 100644 index 0000000000..9002fc382a --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
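+//
+// Instance factory header for batched GEMM + Softmax + GEMM with input/output
+// permutation; the MaskingSpec template parameter selects between the masked
+// (causal, MaskOutUpperTriangle) and non-masked (MaskDisabled) instance lists.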
+ +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( + std::vector, + ck::Tuple<>, + PassThrough, + PassThrough, + Scale, + PassThrough, + PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>>& + instances); + +void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( + std::vector< + std::unique_ptr, + ck::Tuple<>, + PassThrough, + PassThrough, + Scale, + PassThrough, + PassThrough, + MaskingSpecialization::MaskDisabled>>>& + instances); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2, + 1, + 1, + 1, + 1, + ADataType, + B0DataType, + B1DataType, + CDataType, + ck::Tuple<>, + ck::Tuple<>, + PassThrough, + PassThrough, + Scale, + PassThrough, + PassThrough, + MaskingSpec>> +{ + using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2, + 1, + 1, + 1, + 1, + ADataType, + B0DataType, + B1DataType, + CDataType, + ck::Tuple<>, + ck::Tuple<>, + PassThrough, + PassThrough, + Scale, + PassThrough, + PassThrough, + MaskingSpec>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) + { + add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( + op_ptrs); + } + else if(MaskingSpec == MaskingSpecialization::MaskDisabled) + { + add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( + op_ptrs); + } + } + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index d660f28493..c206c4dc04 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -6,6 +6,7 @@ function(add_instance_library INSTANCE_NAME) clang_tidy_check(${INSTANCE_NAME}) endfunction(add_instance_library INSTANCE_NAME) + file(GLOB dir_list LIST_DIRECTORIES true *) set(CK_DEVICE_INSTANCES) FOREACH(subdir_path ${dir_list}) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp index 724961d357..9b96194c87 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -36,10 +36,10 @@ using device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_inst //################################| | | | | | | 
| | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | - // DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, // failed validation on MI100 - // DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, // failed validation on MI100 - // DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, // failed validation on MI100 - // DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, // failed validation on MI100 + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp index 6f65c3d378..0713dfcd99 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp @@ -36,10 +36,10 @@ using device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_inst //################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | - // DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 4, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 
true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, // TODO: to enable; can trigger compiler crash in mainline #9110 but not in #10738 + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 4, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 4, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, - // DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 4, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, // TODO: to enable; can cause validation error on MI100 - // DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 4, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, // TODO: to enable; can cause validation error on MI100 + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 4, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 4, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>, DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, DeviceBatchedGemmGemm_Xdl_CShuffle< 
Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>, diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute/CMakeLists.txt deleted file mode 100644 index 7851fa36b6..0000000000 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -add_instance_library(device_batched_gemm_masking_scale_softmax_gemm_permute_instance - device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp -) - diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute/device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute/device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp deleted file mode 100644 index 006531a530..0000000000 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute/device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp +++ /dev/null @@ -1,85 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using CPermuteNumDims_G_M_O = - S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using Scale = ck::tensor_operation::element_wise::Scale; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; - -// c[g, m, n] = a[g, m, k] * b[g, n, k] -using device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = - std::tuple< - // clang-format off - // 2 of them are commented out because they trigger the clang-13 issue. 
- //##############################################| ALayout| B0Layout| B1Layout| CPermuteNumDims_G_M_O| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskOut| - //##############################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Upper| - //##############################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Triangle| - //##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>, - //DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>, - //DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, true>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, true>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, true>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 
8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, true>, - // Padded fallback kernel - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, true> - // clang-format on - >; - -void add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp index 524a521383..a77872a315 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -24,11 +24,13 @@ template using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; // c[g, m, n] = a[g, m, k] * b[g, n, k] +template using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = std::tuple< // clang-format off @@ -36,24 +38,25 @@ using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_ //#######################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Upper| //#######################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| 
Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Triangle| //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>, - DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>, - DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>, - DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>, - DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>, - DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>, - DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>, - DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, 
PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>, - DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, false>, - DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, false>, - DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, false>, - DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, false>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 
32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, Masking>, // Padded fallback kernel - DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, 
PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>, - DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false> + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking> // clang-format on >; +template using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_irregular_k_instances = std::tuple< // clang-format off @@ -61,12 +64,14 @@ using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_ //#######################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Upper| //#######################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Triangle| //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>, - DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, 
F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>, - DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>, - DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>, - DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false>, - DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false> + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, +#if CK_WORKAROUND_DISABLE_BROKEN_ATTN_KERNEL_INSTANCE == 0 + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, +#endif + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 128, 32, 4, 4, 
2, 32, 32, 1, 8, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking> // clang-format on >; @@ -81,16 +86,45 @@ void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_g F16, PassThrough, PassThrough, + Scale, PassThrough, PassThrough, - PassThrough>>>& instances) + false>>>& instances) { add_device_operation_instances( instances, - device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{}); + device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances< + false>{}); add_device_operation_instances( instances, - device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_irregular_k_instances{}); + device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_irregular_k_instances< + false>{}); +} + +void add_device_batched_gemm_masking_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances< + true>{}); + add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_irregular_k_instances< + true>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt new file mode 100644 index 0000000000..b5525b7386 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt @@ -0,0 +1,4 @@ +add_instance_library(device_batched_gemm_softmax_gemm_permute_instance + device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp +) + diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp new file mode 100644 index 0000000000..21da6895e6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. 
All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; + +static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +template +using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = + std::tuple< + // clang-format off + // #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec| + // #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | + // #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | + // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 
TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + 
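// The gmk_gnk_gno_gmo instances above and below describe their tensors through
// generic (G0, G1, M, K)-style lengths/strides instead of a fixed CLayout. A
// hedged sketch of how the profiler changes further down this patch express
// the physical [G0, M, G1, K] attention layout through those strides without
// moving any data (element type and sizes here are illustrative stand-ins):
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    const std::int64_t G0 = 2, G1 = 3, M = 4, K = 5;
    // lengths given in logical (G0, G1, M, K) order...
    std::vector<std::int64_t> a_gs_ms_ks_lengths{G0, G1, M, K};
    // ...while the strides encode physical [G0, M, G1, K] storage
    std::vector<std::int64_t> a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1};
    // element (g0, g1, m, k) sits at offset g0*M*G1*K + g1*K + m*G1*K + k
    const std::int64_t g0 = 1, g1 = 2, m = 3, k = 4;
    std::cout << g0 * a_gs_ms_ks_strides[0] + g1 * a_gs_ms_ks_strides[1] +
                     m * a_gs_ms_ks_strides[2] + k * a_gs_ms_ks_strides[3]
              << '\n'; // 1*60 + 2*5 + 3*15 + 4 = 119
    return 0;
}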
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, + // Padded fallback kernel + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 
128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> + // clang-format on + >; + +void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( + std::vector, + ck::Tuple<>, + PassThrough, + PassThrough, + Scale, + PassThrough, + PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>>& + instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances< + 2, + 1, + 1, + 1, + 1, + MaskingSpecialization::MaskOutUpperTriangle>{}); +} + +void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( + std::vector< + std::unique_ptr, + ck::Tuple<>, + PassThrough, + PassThrough, + Scale, + PassThrough, + PassThrough, + MaskingSpecialization::MaskDisabled>>>& + instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances< + 2, + 1, + 1, + 1, + 1, + MaskingSpecialization::MaskDisabled>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp b/profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp index 249fd1a885..6b0a25aca2 100644 --- a/profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp +++ b/profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp @@ -29,7 +29,8 @@ template + typename CLayout, + bool MaskOutUpperTriangle> bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, int init_method, bool do_log, @@ -46,16 +47,18 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, int BatchStrideA = -1, int BatchStrideB0 = -1, int BatchStrideB1 = -1, - int BatchStrideC = -1) + int BatchStrideC = -1, + float alpha = 1.f) { using Row = tensor_layout::gemm::RowMajor; using Col = tensor_layout::gemm::ColumnMajor; using PassThrough = tensor_operation::element_wise::PassThrough; + using Scale = tensor_operation::element_wise::Scale; using AElementOp = PassThrough; using B0ElementOp = PassThrough; - using Acc0ElementOp = PassThrough; + using Acc0ElementOp = Scale; using B1ElementOp = PassThrough; using CElementOp = PassThrough; using AccDataType = float; @@ -67,7 +70,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, AccDataType, AElementOp, B0ElementOp, - CElementOp>; + Acc0ElementOp>; // Ref Softmax: fp32 in, various type out using ReferenceSoftmaxInstance = @@ -185,7 +188,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, auto a_element_op = AElementOp{}; auto b0_element_op = B0ElementOp{}; - auto acc0_element_op = Acc0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; auto b1_element_op = B1ElementOp{}; auto c_element_op = CElementOp{}; @@ -201,7 +204,8 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, B0ElementOp, Acc0ElementOp, B1ElementOp, - CElementOp>; + CElementOp, + MaskOutUpperTriangle>; // get device op instances const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory< @@ -214,10 +218,16 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, auto ref_gemm0 = ReferenceGemm0Instance{}; auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); auto 
ref_gemm0_argument = ref_gemm0.MakeArgument( - a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, PassThrough{}); + a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, Scale{alpha}); ref_gemm0_invoker.Run(ref_gemm0_argument); + // mask out upper triangle + acc0_g_m_n.ForEach([&](auto& self, auto idx) { + if(MaskOutUpperTriangle && idx[1] < idx[2]) + self(idx) = -ck::NumericLimits::Infinity(); + }); + auto ref_softmax = ReferenceSoftmaxInstance{}; auto ref_softmax_invoker = ref_softmax.MakeInvoker(); auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}); diff --git a/profiler/include/profile_batched_gemm_masking_scale_softmax_gemm_permute_impl.hpp b/profiler/include/profile_batched_gemm_softmax_gemm_permute_impl.hpp similarity index 55% rename from profiler/include/profile_batched_gemm_masking_scale_softmax_gemm_permute_impl.hpp rename to profiler/include/profile_batched_gemm_softmax_gemm_permute_impl.hpp index 5cf1035620..5533a88d54 100644 --- a/profiler/include/profile_batched_gemm_masking_scale_softmax_gemm_permute_impl.hpp +++ b/profiler/include/profile_batched_gemm_softmax_gemm_permute_impl.hpp @@ -7,10 +7,10 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute.hpp" +#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -22,36 +22,32 @@ namespace ck { namespace profiler { -template -bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verification, - int init_method, - bool do_log, - bool time_kernel, - int M, - int N, - int K, - int O, - int G0, - int G1, - int StrideA = -1, - int StrideB0 = -1, - int StrideB1 = -1, - int BatchStrideA = -1, - int BatchStrideB0 = -1, - int BatchStrideB1 = -1, - float alpha = 1.f) + typename Acc0BiasesDataType, + typename Acc1BiasesDataType, + tensor_operation::device::MaskingSpecialization MaskingSpec> +bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int O, + int G0, + int G1, + float alpha = 1.f) { - using Row = tensor_layout::gemm::RowMajor; - using Col = tensor_layout::gemm::ColumnMajor; using PassThrough = tensor_operation::element_wise::PassThrough; using Scale = tensor_operation::element_wise::Scale; using AElementOp = PassThrough; @@ -60,6 +56,7 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi using B1ElementOp = PassThrough; using CElementOp = PassThrough; using AccDataType = float; + using tensor_operation::device::MaskingSpecialization; // Ref Gemm0: various type in, fp32 out using ReferenceGemm0Instance = tensor_operation::host::ReferenceBatchedGemm a_gs_ms_ks_lengths{G0, G1, M, K}; + std::vector a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1}; + + // B0 layout [G0, N, G1, K] + std::vector b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::vector b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1}; + + // B1 layout [G0, N, G1, O] + std::vector b1_gs_os_ns_lengths{G0, G1, O, N}; + std::vector b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * 
O}; + + // C layout [G0, M, G1, O] std::vector c_gs_ms_os_lengths{G0, G1, M, O}; std::vector c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1}; - const int DefaultStrideA = ck::is_same_v ? K : M; - const int DefaultStrideB0 = ck::is_same_v ? N : K; - const int DefaultStrideB1 = ck::is_same_v ? O : N; - - StrideA = (StrideA < 0) ? DefaultStrideA : StrideA; - StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0; - StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1; - - const int DefaultBatchStrideA = (ck::is_same_v ? K : M) * StrideA; - const int DefaultBatchStrideB0 = (ck::is_same_v ? N : K) * StrideB0; - const int DefaultBatchStrideB1 = (ck::is_same_v ? O : N) * StrideB1; - - BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA; - BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0; - BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1; - const int BatchCount = G0 * G1; - auto f_host_tensor_descriptor = [](std::size_t batch_count, - std::size_t row, - std::size_t col, - std::size_t stride, - std::size_t batch_stride, - auto layout) { - if(std::is_same::value) - { - return HostTensorDescriptor(std::vector({batch_count, row, col}), - std::vector({batch_stride, stride, 1})); - } - else - { - return HostTensorDescriptor(std::vector({batch_count, row, col}), - std::vector({batch_stride, 1, stride})); - } - }; + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - // C_m_o = A_m_k * B0_k_n * B1_n_o - Tensor a_g_m_k( - f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{})); - Tensor b0_g_k_n( - f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{})); - Tensor b1_g_n_o( - f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{})); - Tensor c_gs_ms_os_host_result( - std::vector(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()), - std::vector(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end())); - Tensor c_gs_ms_os_device_result( - std::vector(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()), - std::vector(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end())); - // Host verification: Output of Gemm0 is input A of Gemm1 - Tensor acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); - Tensor a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); - Tensor c_g_m_o_host_result(std::vector{BatchCount, M, O}, - std::vector{M * O, O, 1}); - - std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; - std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl; - std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl; + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; + std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl; std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl; std::srand(1); // work around test flakiness @@ -157,38 +120,38 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi // or not. May want to try exact same approach as the GPU kernel in the host reference // GEMM+Softmax+GEMM function to see if the accuracy discrepancy goes away. 
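// The host verification hunks above and below reduce to Scale{alpha}, a
// "mask out upper triangle" ForEach, then ReferenceSoftmax, with the mask
// selected by the MaskingSpecialization enum this patch introduces in
// masking_specialization.hpp. A plain-C++ per-row sketch of that sequence,
// assuming a single fp32 row; it is not the actual CK reference ops:
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

enum class MaskingSpecialization { MaskDisabled, MaskOutUpperTriangle };

void scale_mask_softmax_row(std::vector<float>& row,
                            int m, float alpha, MaskingSpecialization spec)
{
    const float neg_inf = -std::numeric_limits<float>::infinity();
    for(std::size_t n = 0; n < row.size(); ++n)
    {
        row[n] *= alpha; // Acc0ElementOp = Scale{alpha}
        // idx[1] < idx[2] in the reference: query m attends only to keys n <= m
        if(spec == MaskingSpecialization::MaskOutUpperTriangle &&
           m < static_cast<int>(n))
            row[n] = neg_inf;
    }
    float max_v = neg_inf; // subtract the row max for numerical stability
    for(float v : row)
        max_v = std::fmax(max_v, v);
    float sum = 0.f;
    for(float& v : row)
    {
        v = std::exp(v - max_v); // exp(-inf) == 0, so masked keys drop out
        sum += v;
    }
    for(float& v : row)
        v /= sum;
}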
Until then, // shrink the input value range as it is less likely to produce errors of around ~1e-3. - // a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - // b0_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - // b1_g_n_o.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b0_g_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b1_g_n_o.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + // a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + // b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + // b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); break; case 2: - a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b0_g_k_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b1_g_n_o.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); break; case 3: - a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b0_g_k_n.GenerateTensorValue(GeneratorTensor_Diagonal{}); - b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); break; default: - a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); - b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); - b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{1}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); } - DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSize()); - DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize()); - DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize()); - DeviceMem c_gs_ms_os_device_buf(sizeof(CDataType) * - c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()); + DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * + c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()); - a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data()); - b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data()); - b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data()); + a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data()); + b1_device_buf.ToDevice(b1_gs_os_ns.mData.data()); auto a_element_op = AElementOp{}; auto b0_element_op = B0ElementOp{}; @@ -196,20 +159,23 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi auto b1_element_op = B1ElementOp{}; auto c_element_op = CElementOp{}; - using DeviceOp = - tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute; + using DeviceOp = tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2, + 1, + 1, + 1, + 1, + ADataType, + B0DataType, + B1DataType, + CDataType, + ck::Tuple<>, + ck::Tuple<>, + 
AElementOp, + B0ElementOp, + Acc0ElementOp, + B1ElementOp, + CElementOp, + MaskingSpec>; // get device op instances const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory< @@ -219,6 +185,26 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi if(do_verification) { + c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + Tensor a_g_m_k({BatchCount, M, K}); + Tensor b0_g_k_n({BatchCount, K, N}); + Tensor b1_g_n_o({BatchCount, N, O}); + Tensor acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0 + Tensor a1_g_m_n({BatchCount, M, N}); // scratch object after softmax + Tensor c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1 + + // permute + a_gs_ms_ks.ForEach([&](auto& self, auto idx) { + a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); + }); + b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { + b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); + }); + b1_gs_os_ns.ForEach([&](auto& self, auto idx) { + b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); + }); + auto ref_gemm0 = ReferenceGemm0Instance{}; auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); auto ref_gemm0_argument = ref_gemm0.MakeArgument( @@ -228,7 +214,7 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi // mask out upper triangle acc0_g_m_n.ForEach([&](auto& self, auto idx) { - if(idx[1] < idx[2]) + if(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle && idx[1] < idx[2]) self(idx) = -ck::NumericLimits::Infinity(); }); @@ -265,23 +251,24 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi for(auto& op_ptr : op_ptrs) { auto argument_ptr = op_ptr->MakeArgumentPointer( - static_cast(a_g_m_k_device_buf.GetDeviceBuffer()), - static_cast(b0_g_k_n_device_buf.GetDeviceBuffer()), - static_cast(b1_g_n_o_device_buf.GetDeviceBuffer()), - static_cast(c_gs_ms_os_device_buf.GetDeviceBuffer()), - M, - N, - K, - O, - BatchCount, + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b0_device_buf.GetDeviceBuffer()), + static_cast(b1_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + {}, // std::array p_acc0_biases; + {}, // std::array p_acc1_biases; + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ns_ks_lengths, + b0_gs_ns_ks_strides, + b1_gs_os_ns_lengths, + b1_gs_os_ns_strides, c_gs_ms_os_lengths, c_gs_ms_os_strides, - StrideA, - StrideB0, - StrideB1, - BatchStrideA, - BatchStrideB0, - BatchStrideB1, + {}, // std::array, 1>{acc0_biases_gs_ms_ns_lengths}, + {}, // std::array, 1>{acc0_biases_gs_ms_ns_strides}, + {}, // std::array, 1>{acc1_biases_gs_ms_os_lengths}, + {}, // std::array, 1>{acc1_biases_gs_ms_os_strides}, a_element_op, b0_element_op, acc0_element_op, @@ -319,18 +306,18 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi if(do_verification) { - c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); pass = pass & ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData); if(do_log) { - LogRangeAsType(std::cout << "a_g_m_k: ", a_g_m_k.mData, ",") + LogRangeAsType(std::cout << "a_gs_ms_ks: ", a_gs_ms_ks.mData, ",") << std::endl; - LogRangeAsType(std::cout << "b0_g_k_n : ", b0_g_k_n.mData, ",") + LogRangeAsType(std::cout << "b0_gs_ns_ks : ", b0_gs_ns_ks.mData, ",") << std::endl; - LogRangeAsType(std::cout << "b1_g_n_o : ", b1_g_n_o.mData, ",") + 
LogRangeAsType(std::cout << "b1_gs_os_ns : ", b1_gs_os_ns.mData, ",") << std::endl; LogRangeAsType( std::cout << "c_gs_ms_os_host_result : ", c_gs_ms_os_host_result.mData, ",") diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index e1b0b9c6e6..edf17bcb69 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -41,7 +41,7 @@ add_subdirectory(batched_gemm) add_subdirectory(batched_gemm_reduce) add_subdirectory(batched_gemm_gemm) add_subdirectory(batched_gemm_softmax_gemm) -add_subdirectory(batched_gemm_masking_scale_softmax_gemm_permute) +add_subdirectory(batched_gemm_softmax_gemm_permute) add_subdirectory(grouped_gemm) add_subdirectory(reduce) add_subdirectory(convnd_fwd) diff --git a/test/batched_gemm_masking_scale_softmax_gemm_permute/CMakeLists.txt b/test/batched_gemm_masking_scale_softmax_gemm_permute/CMakeLists.txt deleted file mode 100644 index 9596858e74..0000000000 --- a/test/batched_gemm_masking_scale_softmax_gemm_permute/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -add_custom_target(test_batched_gemm_masking_scale_softmax_gemm_permute) - -add_gtest_executable(test_batched_gemm_masking_scale_softmax_gemm_permute_fp16 test_batched_gemm_masking_scale_softmax_gemm_permute_fp16.cpp) -target_link_libraries(test_batched_gemm_masking_scale_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_masking_scale_softmax_gemm_permute_instance) -add_dependencies(test_batched_gemm_masking_scale_softmax_gemm_permute test_batched_gemm_masking_scale_softmax_gemm_permute_fp16) \ No newline at end of file diff --git a/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp b/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp index 8d54711b51..5df7769d5f 100644 --- a/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp +++ b/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp @@ -9,9 +9,13 @@ class TestBatchedGemmSoftmaxGemmFP16 : public TestBatchedGemmSoftmaxGemm { }; +using Masked = std::true_type; +using NoMask = std::false_type; + // clang-format off using KernelTypes = ::testing::Types< - std::tuple + std::tuple, + std::tuple >; // clang-format on @@ -120,7 +124,6 @@ TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16_IrregularK) using ck::tensor_operation::device::GemmSpecialization; -// TODO: enable KPadding tests when it is implemented TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMatch) { int P = 120; // requires padding @@ -152,12 +155,12 @@ TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMismatch) // IsSupported(M, N, K, O) // clang-format off EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 120, 128)); - // EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 128, 120)); + EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 128, 120)); // Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0 - // EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 129, 128)); - // EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 130, 128)); + EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 129, 128)); + EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 130, 128)); // Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % 
diff --git a/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp b/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp
index eb7fb24b27..e9fd514cce 100644
--- a/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp
+++ b/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp
@@ -20,14 +20,15 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
 template <typename Tuple>
 struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
 {
-    using ADataType  = std::tuple_element_t<0, Tuple>;
-    using B0DataType = std::tuple_element_t<1, Tuple>;
-    using B1DataType = std::tuple_element_t<2, Tuple>;
-    using CDataType  = std::tuple_element_t<3, Tuple>;
-    using ALayout    = std::tuple_element_t<4, Tuple>;
-    using B0Layout   = std::tuple_element_t<5, Tuple>;
-    using B1Layout   = std::tuple_element_t<6, Tuple>;
-    using CLayout    = std::tuple_element_t<7, Tuple>;
+    using ADataType   = std::tuple_element_t<0, Tuple>;
+    using B0DataType  = std::tuple_element_t<1, Tuple>;
+    using B1DataType  = std::tuple_element_t<2, Tuple>;
+    using CDataType   = std::tuple_element_t<3, Tuple>;
+    using ALayout     = std::tuple_element_t<4, Tuple>;
+    using B0Layout    = std::tuple_element_t<5, Tuple>;
+    using B1Layout    = std::tuple_element_t<6, Tuple>;
+    using CLayout     = std::tuple_element_t<7, Tuple>;
+    using MaskingType = std::tuple_element_t<8, Tuple>;

     std::vector<std::vector<int>> lengths_ = {{256, 256, 64, 64, 4},
                                               {256, 256, 128, 128, 4},
@@ -54,7 +55,8 @@ struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
                                                        ALayout,
                                                        B0Layout,
                                                        B1Layout,
-                                                       CLayout>(
+                                                       CLayout,
+                                                       MaskingType::value>(
             verify_, 1, false, bench_, M, N, K, O, BatchCount);

         EXPECT_TRUE(pass);
diff --git a/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt b/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt
new file mode 100644
index 0000000000..e1a74c7843
--- /dev/null
+++ b/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt
@@ -0,0 +1,5 @@
+add_custom_target(test_batched_gemm_softmax_gemm_permute)
+
+add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp)
+target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
+add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16)
\ No newline at end of file
diff --git a/test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_fp16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp
similarity index 55%
rename from test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_fp16.cpp
rename to test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp
index 43cd60bca5..293acd6015 100644
--- a/test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_fp16.cpp
+++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp
@@ -2,7 +2,7 @@
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

 #include "gtest/gtest.h"
-#include "test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp"
+#include "test_batched_gemm_softmax_gemm_permute_util.hpp"

 template <typename Tuple>
 class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16
@@ -10,13 +10,18 @@ class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16
 {
 };

+using I1_t = ck::Number<1>;
+using I2_t = ck::Number<2>;
+
+using MaskDisabled_t =
+    ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>;
+using MaskOutUpperTriangle_t =
+    ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskOutUpperTriangle>;
+
 // clang-format off
-template <ck::index_t... Is>
-using S = ck::Sequence<Is...>;
-using CPermuteNumDims_G_M_O =
-    S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
 using KernelTypes = ::testing::Types<
-    std::tuple<F16, F16, F16, F16, Row, Col, Row, CPermuteNumDims_G_M_O>
+    std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, MaskDisabled_t>,
+    std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, MaskOutUpperTriangle_t>
     >;
 // clang-format on

@@ -91,7 +96,7 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddO)
     this->Run();
 }

-TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Bench_FP16_IrregularK)
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, DISABLED_Bench_FP16_IrregularK)
 {
     this->lengths_ = std::vector<std::vector<int>>{{256, 256, 160, 160, 1, 16},
                                                    {256, 64, 160, 64, 1, 16},
@@ -125,7 +130,6 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, DISABLED_Bench_FP1

 using ck::tensor_operation::device::GemmSpecialization;

-// TODO: enable KPadding tests when it is implemented
 TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMatch)
 {
     int P = 120; // requires padding
@@ -133,22 +137,22 @@ TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationS

     // IsSupported(M, N, K, O)
     // clang-format off
-    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, Q, Q));
-    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, Q, Q, Q));
-    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, P, Q, Q));
-    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, P, Q));
-    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, P, Q, Q));
-    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, Q, P, Q));
-    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, P, P, Q));
-    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, P, P, Q));
-    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, Q, P));
-    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, Q, Q, P));
-    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, P, Q, P));
-    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, P, P));
-    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, P, Q, P));
-    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, Q, P, P));
-    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, P, P, P));
-    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, P, P, P));
+    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, Q, Q));
+    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, Q, Q, Q));
+    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, P, Q, Q));
+    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, P, Q));
+    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, P, Q, Q));
+    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, Q, P, Q));
+    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, P, P, Q));
+    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, P, P, Q));
+    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, Q, P));
+    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, Q, Q, P));
+    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, P, Q, P));
+    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, P, P));
+    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, P, Q, P));
+    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, Q, P, P));
+    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, P, P, P));
+    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, P, P, P));
     // clang-format on
 }
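The Mask*_t aliases above wrap an enum value in a type, so a value-level choice can travel through gtest's type-parameterized test list. The pattern in isolation, using std::integral_constant in place of ck::integral_constant (names here are illustrative):

#include <tuple>
#include <type_traits>

enum class MaskingSpecialization { MaskDisabled, MaskOutUpperTriangle };

// Wrap an enum value as a type so it can sit inside std::tuple / ::testing::Types.
template <MaskingSpecialization V>
using MaskConst = std::integral_constant<MaskingSpecialization, V>;

using Case = std::tuple<float, MaskConst<MaskingSpecialization::MaskOutUpperTriangle>>;

// The typed test recovers the value again at compile time via ::value.
static_assert(std::tuple_element_t<1, Case>::value ==
              MaskingSpecialization::MaskOutUpperTriangle);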
@@ -156,13 +160,13 @@ TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationS
 {
     // IsSupported(M, N, K, O)
     // clang-format off
-    EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 120, 128));
-    // EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 128, 120));
+    EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 120, 128));
+    EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 128, 120));
     // Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0
-    // EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 129, 128));
-    // EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 130, 128));
+    EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 129, 128));
+    EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 130, 128));
     // Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0
-    // EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 128, 129));
+    EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 128, 129));
     // clang-format on
 }

@@ -174,6 +178,5 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, AdhocTest)
         {1020, 1020, 64, 128, 4, 6},
         {576, 576, 64, 64, 4, 6},
     };
-    this->bench_ = true;
     this->Run();
 }
diff --git a/test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_util.hpp
similarity index 52%
rename from test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp
rename to test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_util.hpp
index cd5d6389b0..990ef633c2 100644
--- a/test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp
+++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_util.hpp
@@ -4,10 +4,14 @@
 #include 
 #include 

+#include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
-#include "profiler/include/profile_batched_gemm_masking_scale_softmax_gemm_permute_impl.hpp"
+#include "profiler/include/profile_batched_gemm_softmax_gemm_permute_impl.hpp"
+
 using ck::tensor_operation::device::GemmSpecialization;
+using ck::tensor_operation::device::MaskingSpecialization;
+using ck::tensor_operation::device::TensorSpecialization;

 template <ck::index_t N>
 using I = ck::Number<N>;
@@ -20,14 +24,18 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
 template <typename Tuple>
 struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
 {
-    using ADataType             = std::tuple_element_t<0, Tuple>;
-    using B0DataType            = std::tuple_element_t<1, Tuple>;
-    using B1DataType            = std::tuple_element_t<2, Tuple>;
-    using CDataType             = std::tuple_element_t<3, Tuple>;
-    using ALayout               = std::tuple_element_t<4, Tuple>;
-    using B0Layout              = std::tuple_element_t<5, Tuple>;
-    using B1Layout              = std::tuple_element_t<6, Tuple>;
-    using CPermuteNumDims_G_M_O = std::tuple_element_t<7, Tuple>;
+    using NumDimGType      = std::tuple_element_t<0, Tuple>;
+    using NumDimMType      = std::tuple_element_t<1, Tuple>;
+    using NumDimNType      = std::tuple_element_t<2, Tuple>;
+    using NumDimKType      = std::tuple_element_t<3, Tuple>;
+    using NumDimOType      = std::tuple_element_t<4, Tuple>;
+    using ADataType        = std::tuple_element_t<5, Tuple>;
+    using B0DataType       = std::tuple_element_t<6, Tuple>;
+    using B1DataType       = std::tuple_element_t<7, Tuple>;
+    using CDataType        = std::tuple_element_t<8, Tuple>;
+    using Acc0BiasDataType = std::tuple_element_t<9, Tuple>;
+    using Acc1BiasDataType = std::tuple_element_t<10, Tuple>;
+    using MaskingType      = std::tuple_element_t<11, Tuple>;

     std::vector<std::vector<int>> lengths_ = {
         {256, 256, 64, 64, 6, 4},
@@ -42,15 +50,20 @@ struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test

     void RunSingle(int M, int N, int K, int O, int G0, int G1)
     {
-        bool pass = ck::profiler::profile_batched_gemm_masking_scale_softmax_gemm_permute_impl<
-            ADataType,
-            B0DataType,
-            B1DataType,
-            CDataType,
-            ALayout,
-            B0Layout,
-            B1Layout,
-            CPermuteNumDims_G_M_O>(verify_, 1, false, bench_, M, N, K, O, G0, G1);
+        bool pass =
+            ck::profiler::profile_batched_gemm_softmax_gemm_permute_impl<NumDimGType::value,
+                                                                         NumDimMType::value,
+                                                                         NumDimNType::value,
+                                                                         NumDimKType::value,
+                                                                         NumDimOType::value,
+                                                                         ADataType,
+                                                                         B0DataType,
+                                                                         B1DataType,
+                                                                         CDataType,
+                                                                         ck::Tuple<>,
+                                                                         ck::Tuple<>,
+                                                                         MaskingType::value>(
+                verify_, 1, false, bench_, M, N, K, O, G0, G1);

         EXPECT_TRUE(pass);
     }
@@ -72,19 +85,13 @@ struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
 };

 template <GemmSpecialization GemmSpec>
-struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
+struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128
 {
     using PassThrough = ck::tensor_operation::element_wise::PassThrough;
     using Scale       = ck::tensor_operation::element_wise::Scale;

-    using ALayout  = Row;
-    using B0Layout = Col;
-    using B1Layout = Row;
-
     template <ck::index_t... Is>
     using S = ck::Sequence<Is...>;
-    using CPermuteNumDims_G_M_O =
-        S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O

     using ADataType  = F16;
     using B0DataType = F16;
@@ -103,14 +110,17 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128

     using DeviceGemmGemmInstance =
         ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
-            ALayout,
-            B0Layout,
-            B1Layout,
-            CPermuteNumDims_G_M_O,
+            2,
+            1,
+            1,
+            1,
+            1,
             ADataType,
             B0DataType,
             B1DataType,
             CDataType,
+            ck::Tuple<>,
+            ck::Tuple<>,
             AccDataType,
             CShuffleDataType,
             AElementOp,
@@ -119,6 +129,10 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
             B1ElementOp,
             CElementOp,
             GemmSpec,
+            TensorSpecialization::Default, // ATensorSpec
+            TensorSpecialization::Default, // B0TensorSpec
+            TensorSpecialization::Default, // B1TensorSpec
+            TensorSpecialization::Default, // CTensorSpec
             1,
             256,
             128, // MPerBlock
@@ -159,29 +173,48 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
             2,              // CShuffleNXdlPerWavePerShuffle
             S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
             8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-            true>; // Masking
+            MaskingSpecialization::MaskOutUpperTriangle>; // MaskOutUpperTriangle

     bool IsSupported(int M, int N, int K, int O)
     {
+        const int G0 = 1, G1 = 1;
+
+        // A layout [G0, M, G1, K]
+        std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
+        std::vector<ck::index_t> a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1};
+
+        // B0 layout [G0, N, G1, K]
+        std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
+        std::vector<ck::index_t> b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1};
+
+        // B1 layout [G0, N, G1, O]
+        std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
+        std::vector<ck::index_t> b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O};
+
+        // C layout [G0, M, G1, O]
+        std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
+        std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
+
         auto gemm     = DeviceGemmGemmInstance{};
         auto invoker  = gemm.MakeInvoker();
         auto argument = gemm.MakeArgument(static_cast<ADataType*>(nullptr),
                                           static_cast<B0DataType*>(nullptr),
                                           static_cast<B1DataType*>(nullptr),
                                           static_cast<CDataType*>(nullptr),
-                                          M,
-                                          N,
-                                          K,
-                                          O,
-                                          0,            // BatchCount
-                                          {0, 0, M, O}, // gs ms ns lengths
-                                          {0, O, 0, 1}, // gs ms ns strides
-                                          0,            // StrideA
-                                          0,            // StrideB0
-                                          0,            // StrideB1
-                                          0,            // BatchStrideA
-                                          0,            // BatchStrideB0
-                                          0,            // BatchStrideB1
+                                          {}, // p_acc0_biases
+                                          {}, // p_acc1_biases
+                                          a_gs_ms_ks_lengths,
+                                          a_gs_ms_ks_strides,
+                                          b0_gs_ns_ks_lengths,
+                                          b0_gs_ns_ks_strides,
+                                          b1_gs_os_ns_lengths,
+                                          b1_gs_os_ns_strides,
+                                          c_gs_ms_os_lengths,
+                                          c_gs_ms_os_strides,
+                                          {}, // acc0_biases_gs_ms_ns_lengths
+                                          {}, // acc0_biases_gs_ms_ns_strides
+                                          {}, // acc1_biases_gs_ms_os_lengths
+                                          {}, // acc1_biases_gs_ms_os_strides
                                           PassThrough{}, // a_element_op
                                           PassThrough{}, // b0_element_op
                                           Scale{1.f},    // acc0_element_op
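The stride tables in IsSupported describe tensors stored as [G0, M, G1, K] (and analogously for B0/B1/C) while presenting them to the kernel in [gs..., ms..., ks...] order; interleaving G1 between M and K is the head-interleaved attention layout these permute kernels target. A small self-contained sketch of how such a stride vector is derived, assuming row-major storage as in the vectors above:

#include <cstdint>
#include <vector>

using index_t = std::int32_t;

// Strides of a row-major [G0, M, G1, K] buffer, reported in the kernel's
// [G0, G1, M, K] dimension order, matching a_gs_ms_ks_strides above.
std::vector<index_t> make_a_gs_ms_ks_strides(index_t /*G0*/, index_t G1, index_t M, index_t K)
{
    return {M * G1 * K, // G0: one whole [M, G1, K] block
            K,          // G1: heads interleaved between M-rows
            G1 * K,     // M : step over all G1 heads' K elements
            1};         // K : contiguous
}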
diff --git a/test/space_filling_curve/space_filling_curve.cpp b/test/space_filling_curve/space_filling_curve.cpp
index 500717dd2b..c7f6759e81 100644
--- a/test/space_filling_curve/space_filling_curve.cpp
+++ b/test/space_filling_curve/space_filling_curve.cpp
@@ -12,28 +12,91 @@

 using namespace ck;

-void traverse_using_space_filling_curve();
+void traverse_using_space_filling_curve_linear();
+void traverse_using_space_filling_curve_snakecurved();

 int main(int argc, char** argv)
 {
     (void)argc;
     (void)argv;
-    traverse_using_space_filling_curve();
+    traverse_using_space_filling_curve_linear();
+    traverse_using_space_filling_curve_snakecurved();
     return 0;
 }

-void traverse_using_space_filling_curve()
+void traverse_using_space_filling_curve_linear()
 {
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
     constexpr auto I2 = Number<2>{};

-    using TensorLengths     = Sequence<16, 10, 9>;
-    using DimAccessOrder    = Sequence<2, 0, 1>;
-    using ScalarsPerAccess  = Sequence<4, 2, 3>;
-    using SpaceFillingCurve = SpaceFillingCurve<TensorLengths, DimAccessOrder, ScalarsPerAccess>;
+    using TensorLengths     = Sequence<3, 2, 2>;
+    using DimAccessOrder    = Sequence<2, 0, 1>;
+    using ScalarsPerAccess  = Sequence<1, 1, 1>;
+    using SpaceFillingCurve =
+        SpaceFillingCurve<TensorLengths, DimAccessOrder, ScalarsPerAccess, false>;
+
+    constexpr auto expected = make_tuple(make_tuple(0, 0, 0),
+                                         make_tuple(0, 1, 0),
+                                         make_tuple(1, 0, 0),
+                                         make_tuple(1, 1, 0),
+                                         make_tuple(2, 0, 0),
+                                         make_tuple(2, 1, 0),
+                                         make_tuple(0, 0, 1),
+                                         make_tuple(0, 1, 1),
+                                         make_tuple(1, 0, 1),
+                                         make_tuple(1, 1, 1),
+                                         make_tuple(2, 0, 1),
+                                         make_tuple(2, 1, 1));
+
+    constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess();
+
+    static_assert(num_access == reduce_on_sequence(TensorLengths{} / ScalarsPerAccess{},
+                                                   math::multiplies{},
+                                                   Number<1>{}));
+
+    static_for<1, num_access, 1>{}([&](auto i) {
+        constexpr auto idx_curr = SpaceFillingCurve::GetIndex(i);
+
+        static_assert(idx_curr[I0] == expected[i][I0]);
+        static_assert(idx_curr[I1] == expected[i][I1]);
+        static_assert(idx_curr[I2] == expected[i][I2]);
+
+        constexpr auto backward_step = SpaceFillingCurve::GetBackwardStep(i);
+        constexpr auto expected_step = expected[i - I1] - expected[i];
+        static_assert(backward_step[I0] == expected_step[I0]);
+        static_assert(backward_step[I1] == expected_step[I1]);
+        static_assert(backward_step[I2] == expected_step[I2]);
+    });
+
+    static_for<0, num_access - 1, 1>{}([&](auto i) {
+        constexpr auto idx_curr = SpaceFillingCurve::GetIndex(i);
+
+        static_assert(idx_curr[I0] == expected[i][I0]);
+        static_assert(idx_curr[I1] == expected[i][I1]);
+        static_assert(idx_curr[I2] == expected[i][I2]);
+
+        constexpr auto forward_step  = SpaceFillingCurve::GetForwardStep(i);
+        constexpr auto expected_step = expected[i + I1] - expected[i];
+        static_assert(forward_step[I0] == expected_step[I0]);
+        static_assert(forward_step[I1] == expected_step[I1]);
+        static_assert(forward_step[I2] == expected_step[I2]);
+    });
+}
+
+void traverse_using_space_filling_curve_snakecurved()
+{
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+    constexpr auto I2 = Number<2>{};
+
+    using TensorLengths     = Sequence<16, 10, 9>;
+    using DimAccessOrder    = Sequence<2, 0, 1>;
+    using ScalarsPerAccess  = Sequence<4, 2, 3>;
+    using SpaceFillingCurve =
+        SpaceFillingCurve<TensorLengths, DimAccessOrder, ScalarsPerAccess, true>;

     constexpr auto expected = make_tuple(make_tuple(0, 0, 0),
                                          make_tuple(0, 2, 0),